Skip to content

Commit

Permalink
update llmc export (#4584)
Browse files Browse the repository at this point in the history
* update example

* move train to optim

* rename

* b2
  • Loading branch information
Qazalin committed May 14, 2024
1 parent 355e1c1 commit 9aa5e02
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions examples/llm.c/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
from tinygrad.engine.schedule import create_schedule
- from tinygrad.engine.realize import run_schedule
+ from tinygrad.engine.realize import get_linearizer, run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.ops import BufferOps, LoadOps

Expand All @@ -24,6 +24,7 @@
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
B, T = 4, 64

+ Tensor.training = True
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
warmup_count = getenv("WARMUP", 3)
for i in range(warmup_count): # TODO: why does it take three and not two to stabilize
Expand All @@ -46,7 +47,7 @@
ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE])
srcs = {}
for ast in ast_dedup:
- k = Device["CLANG"].get_linearizer(*ast)
+ k = get_linearizer(Device["CLANG"].renderer, ast)
k.linearize()
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
srcs[ast] = (k.name, src)
Expand All @@ -62,7 +63,7 @@
if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
state_dict.update(grad_state_dict)
- state_dict.update({'adam_b1': optimizer.b1, 'adam_b2': optimizer.b2, 'adam_t': optimizer.t, 'adam_lr': optimizer.lr})
+ state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
inverse_state_dict = {v:k for k,v in state_dict.items()}
for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v):
nm = inverse_state_dict[p]
Expand Down

0 comments on commit 9aa5e02

Please sign in to comment.