10_llm-training-speed - state dict #580

@d-kleine

Description

Bug description

This is a minor issue, since the affected code line is commented out:

# model.load_state_dict(torch.load("model.pth", weights_only=True))

I just noticed that loading the saved model.pth via model.load_state_dict() fails, because torch.compile() prepends an _orig_mod prefix to every parameter name in the saved state dict:

(llms) dk@Eclipse:/mnt/c/Users/dk/Desktop/LLMs-from-scratch/ch05/10_llm-training-speed$ python 01_opt_single_gpu.py 
Traceback (most recent call last):
  File "/mnt/c/Users/dk/Desktop/LLMs-from-scratch/ch05/10_llm-training-speed/01_opt_single_gpu.py", line 507, in <module>
    model.load_state_dict(torch.load("model.pth", weights_only=True))
  File "/home/dk/miniconda3/envs/llms/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GPTModel:
        Missing key(s) in state_dict: "tok_emb.weight", "pos_emb.weight", "trf_blocks.0.att.qkv.weight", "trf_blocks.0.att.proj.weight", "trf_blocks.0.att.proj.bias", "trf_blocks.0.ff.layers.0.weight", "trf_blocks.0.ff.layers.0.bias", "trf_blocks.0.ff.layers.2.weight", "trf_blocks.0.ff.layers.2.bias", "trf_blocks.0.norm1.weight", "trf_blocks.0.norm1.bias", "trf_blocks.0.norm2.weight", "trf_blocks.0.norm2.bias", "trf_blocks.1.att.qkv.weight", "trf_blocks.1.att.proj.weight", "trf_blocks.1.att.proj.bias", "trf_blocks.1.ff.layers.0.weight", "trf_blocks.1.ff.layers.0.bias", "trf_blocks.1.ff.layers.2.weight", "trf_blocks.1.ff.layers.2.bias", "trf_blocks.1.norm1.weight", "trf_blocks.1.norm1.bias", "trf_blocks.1.norm2.weight", "trf_blocks.1.norm2.bias", "trf_blocks.2.att.qkv.weight", "trf_blocks.2.att.proj.weight", "trf_blocks.2.att.proj.bias", "trf_blocks.2.ff.layers.0.weight", "trf_blocks.2.ff.layers.0.bias", "trf_blocks.2.ff.layers.2.weight", "trf_blocks.2.ff.layers.2.bias", "trf_blocks.2.norm1.weight", "trf_blocks.2.norm1.bias", "trf_blocks.2.norm2.weight", "trf_blocks.2.norm2.bias", "trf_blocks.3.att.qkv.weight", "trf_blocks.3.att.proj.weight", "trf_blocks.3.att.proj.bias", "trf_blocks.3.ff.layers.0.weight", "trf_blocks.3.ff.layers.0.bias", "trf_blocks.3.ff.layers.2.weight", "trf_blocks.3.ff.layers.2.bias", "trf_blocks.3.norm1.weight", "trf_blocks.3.norm1.bias", "trf_blocks.3.norm2.weight", "trf_blocks.3.norm2.bias", "trf_blocks.4.att.qkv.weight", "trf_blocks.4.att.proj.weight", "trf_blocks.4.att.proj.bias", "trf_blocks.4.ff.layers.0.weight", "trf_blocks.4.ff.layers.0.bias", "trf_blocks.4.ff.layers.2.weight", "trf_blocks.4.ff.layers.2.bias", "trf_blocks.4.norm1.weight", "trf_blocks.4.norm1.bias", "trf_blocks.4.norm2.weight", "trf_blocks.4.norm2.bias", "trf_blocks.5.att.qkv.weight", "trf_blocks.5.att.proj.weight", "trf_blocks.5.att.proj.bias", "trf_blocks.5.ff.layers.0.weight", "trf_blocks.5.ff.layers.0.bias", "trf_blocks.5.ff.layers.2.weight", 
"trf_blocks.5.ff.layers.2.bias", "trf_blocks.5.norm1.weight", "trf_blocks.5.norm1.bias", "trf_blocks.5.norm2.weight", "trf_blocks.5.norm2.bias", "trf_blocks.6.att.qkv.weight", "trf_blocks.6.att.proj.weight", "trf_blocks.6.att.proj.bias", "trf_blocks.6.ff.layers.0.weight", "trf_blocks.6.ff.layers.0.bias", "trf_blocks.6.ff.layers.2.weight", "trf_blocks.6.ff.layers.2.bias", "trf_blocks.6.norm1.weight", "trf_blocks.6.norm1.bias", "trf_blocks.6.norm2.weight", "trf_blocks.6.norm2.bias", "trf_blocks.7.att.qkv.weight", "trf_blocks.7.att.proj.weight", "trf_blocks.7.att.proj.bias", "trf_blocks.7.ff.layers.0.weight", "trf_blocks.7.ff.layers.0.bias", "trf_blocks.7.ff.layers.2.weight", "trf_blocks.7.ff.layers.2.bias", "trf_blocks.7.norm1.weight", "trf_blocks.7.norm1.bias", "trf_blocks.7.norm2.weight", "trf_blocks.7.norm2.bias", "trf_blocks.8.att.qkv.weight", "trf_blocks.8.att.proj.weight", "trf_blocks.8.att.proj.bias", "trf_blocks.8.ff.layers.0.weight", "trf_blocks.8.ff.layers.0.bias", "trf_blocks.8.ff.layers.2.weight", "trf_blocks.8.ff.layers.2.bias", "trf_blocks.8.norm1.weight", "trf_blocks.8.norm1.bias", "trf_blocks.8.norm2.weight", "trf_blocks.8.norm2.bias", "trf_blocks.9.att.qkv.weight", "trf_blocks.9.att.proj.weight", "trf_blocks.9.att.proj.bias", "trf_blocks.9.ff.layers.0.weight", "trf_blocks.9.ff.layers.0.bias", "trf_blocks.9.ff.layers.2.weight", "trf_blocks.9.ff.layers.2.bias", "trf_blocks.9.norm1.weight", "trf_blocks.9.norm1.bias", "trf_blocks.9.norm2.weight", "trf_blocks.9.norm2.bias", "trf_blocks.10.att.qkv.weight", "trf_blocks.10.att.proj.weight", "trf_blocks.10.att.proj.bias", "trf_blocks.10.ff.layers.0.weight", "trf_blocks.10.ff.layers.0.bias", "trf_blocks.10.ff.layers.2.weight", "trf_blocks.10.ff.layers.2.bias", "trf_blocks.10.norm1.weight", "trf_blocks.10.norm1.bias", "trf_blocks.10.norm2.weight", "trf_blocks.10.norm2.bias", "trf_blocks.11.att.qkv.weight", "trf_blocks.11.att.proj.weight", "trf_blocks.11.att.proj.bias", "trf_blocks.11.ff.layers.0.weight", 
"trf_blocks.11.ff.layers.0.bias", "trf_blocks.11.ff.layers.2.weight", "trf_blocks.11.ff.layers.2.bias", "trf_blocks.11.norm1.weight", "trf_blocks.11.norm1.bias", "trf_blocks.11.norm2.weight", "trf_blocks.11.norm2.bias", "final_norm.weight", "final_norm.bias", "out_head.weight". 
        Unexpected key(s) in state_dict: "_orig_mod.tok_emb.weight", "_orig_mod.pos_emb.weight", "_orig_mod.trf_blocks.0.att.qkv.weight", "_orig_mod.trf_blocks.0.att.proj.weight", "_orig_mod.trf_blocks.0.att.proj.bias", "_orig_mod.trf_blocks.0.ff.layers.0.weight", "_orig_mod.trf_blocks.0.ff.layers.0.bias", "_orig_mod.trf_blocks.0.ff.layers.2.weight", "_orig_mod.trf_blocks.0.ff.layers.2.bias", "_orig_mod.trf_blocks.0.norm1.weight", "_orig_mod.trf_blocks.0.norm1.bias", "_orig_mod.trf_blocks.0.norm2.weight", "_orig_mod.trf_blocks.0.norm2.bias", "_orig_mod.trf_blocks.1.att.qkv.weight", "_orig_mod.trf_blocks.1.att.proj.weight", "_orig_mod.trf_blocks.1.att.proj.bias", "_orig_mod.trf_blocks.1.ff.layers.0.weight", "_orig_mod.trf_blocks.1.ff.layers.0.bias", "_orig_mod.trf_blocks.1.ff.layers.2.weight", "_orig_mod.trf_blocks.1.ff.layers.2.bias", "_orig_mod.trf_blocks.1.norm1.weight", "_orig_mod.trf_blocks.1.norm1.bias", "_orig_mod.trf_blocks.1.norm2.weight", "_orig_mod.trf_blocks.1.norm2.bias", "_orig_mod.trf_blocks.2.att.qkv.weight", "_orig_mod.trf_blocks.2.att.proj.weight", "_orig_mod.trf_blocks.2.att.proj.bias", "_orig_mod.trf_blocks.2.ff.layers.0.weight", "_orig_mod.trf_blocks.2.ff.layers.0.bias", "_orig_mod.trf_blocks.2.ff.layers.2.weight", "_orig_mod.trf_blocks.2.ff.layers.2.bias", "_orig_mod.trf_blocks.2.norm1.weight", "_orig_mod.trf_blocks.2.norm1.bias", "_orig_mod.trf_blocks.2.norm2.weight", "_orig_mod.trf_blocks.2.norm2.bias", "_orig_mod.trf_blocks.3.att.qkv.weight", "_orig_mod.trf_blocks.3.att.proj.weight", "_orig_mod.trf_blocks.3.att.proj.bias", "_orig_mod.trf_blocks.3.ff.layers.0.weight", "_orig_mod.trf_blocks.3.ff.layers.0.bias", "_orig_mod.trf_blocks.3.ff.layers.2.weight", "_orig_mod.trf_blocks.3.ff.layers.2.bias", "_orig_mod.trf_blocks.3.norm1.weight", "_orig_mod.trf_blocks.3.norm1.bias", "_orig_mod.trf_blocks.3.norm2.weight", "_orig_mod.trf_blocks.3.norm2.bias", "_orig_mod.trf_blocks.4.att.qkv.weight", "_orig_mod.trf_blocks.4.att.proj.weight", 
"_orig_mod.trf_blocks.4.att.proj.bias", "_orig_mod.trf_blocks.4.ff.layers.0.weight", "_orig_mod.trf_blocks.4.ff.layers.0.bias", "_orig_mod.trf_blocks.4.ff.layers.2.weight", "_orig_mod.trf_blocks.4.ff.layers.2.bias", "_orig_mod.trf_blocks.4.norm1.weight", "_orig_mod.trf_blocks.4.norm1.bias", "_orig_mod.trf_blocks.4.norm2.weight", "_orig_mod.trf_blocks.4.norm2.bias", "_orig_mod.trf_blocks.5.att.qkv.weight", "_orig_mod.trf_blocks.5.att.proj.weight", "_orig_mod.trf_blocks.5.att.proj.bias", "_orig_mod.trf_blocks.5.ff.layers.0.weight", "_orig_mod.trf_blocks.5.ff.layers.0.bias", "_orig_mod.trf_blocks.5.ff.layers.2.weight", "_orig_mod.trf_blocks.5.ff.layers.2.bias", "_orig_mod.trf_blocks.5.norm1.weight", "_orig_mod.trf_blocks.5.norm1.bias", "_orig_mod.trf_blocks.5.norm2.weight", "_orig_mod.trf_blocks.5.norm2.bias", "_orig_mod.trf_blocks.6.att.qkv.weight", "_orig_mod.trf_blocks.6.att.proj.weight", "_orig_mod.trf_blocks.6.att.proj.bias", "_orig_mod.trf_blocks.6.ff.layers.0.weight", "_orig_mod.trf_blocks.6.ff.layers.0.bias", "_orig_mod.trf_blocks.6.ff.layers.2.weight", "_orig_mod.trf_blocks.6.ff.layers.2.bias", "_orig_mod.trf_blocks.6.norm1.weight", "_orig_mod.trf_blocks.6.norm1.bias", "_orig_mod.trf_blocks.6.norm2.weight", "_orig_mod.trf_blocks.6.norm2.bias", "_orig_mod.trf_blocks.7.att.qkv.weight", "_orig_mod.trf_blocks.7.att.proj.weight", "_orig_mod.trf_blocks.7.att.proj.bias", "_orig_mod.trf_blocks.7.ff.layers.0.weight", "_orig_mod.trf_blocks.7.ff.layers.0.bias", "_orig_mod.trf_blocks.7.ff.layers.2.weight", "_orig_mod.trf_blocks.7.ff.layers.2.bias", "_orig_mod.trf_blocks.7.norm1.weight", "_orig_mod.trf_blocks.7.norm1.bias", "_orig_mod.trf_blocks.7.norm2.weight", "_orig_mod.trf_blocks.7.norm2.bias", "_orig_mod.trf_blocks.8.att.qkv.weight", "_orig_mod.trf_blocks.8.att.proj.weight", "_orig_mod.trf_blocks.8.att.proj.bias", "_orig_mod.trf_blocks.8.ff.layers.0.weight", "_orig_mod.trf_blocks.8.ff.layers.0.bias", "_orig_mod.trf_blocks.8.ff.layers.2.weight", 
"_orig_mod.trf_blocks.8.ff.layers.2.bias", "_orig_mod.trf_blocks.8.norm1.weight", "_orig_mod.trf_blocks.8.norm1.bias", "_orig_mod.trf_blocks.8.norm2.weight", "_orig_mod.trf_blocks.8.norm2.bias", "_orig_mod.trf_blocks.9.att.qkv.weight", "_orig_mod.trf_blocks.9.att.proj.weight", "_orig_mod.trf_blocks.9.att.proj.bias", "_orig_mod.trf_blocks.9.ff.layers.0.weight", "_orig_mod.trf_blocks.9.ff.layers.0.bias", "_orig_mod.trf_blocks.9.ff.layers.2.weight", "_orig_mod.trf_blocks.9.ff.layers.2.bias", "_orig_mod.trf_blocks.9.norm1.weight", "_orig_mod.trf_blocks.9.norm1.bias", "_orig_mod.trf_blocks.9.norm2.weight", "_orig_mod.trf_blocks.9.norm2.bias", "_orig_mod.trf_blocks.10.att.qkv.weight", "_orig_mod.trf_blocks.10.att.proj.weight", "_orig_mod.trf_blocks.10.att.proj.bias", "_orig_mod.trf_blocks.10.ff.layers.0.weight", "_orig_mod.trf_blocks.10.ff.layers.0.bias", "_orig_mod.trf_blocks.10.ff.layers.2.weight", "_orig_mod.trf_blocks.10.ff.layers.2.bias", "_orig_mod.trf_blocks.10.norm1.weight", "_orig_mod.trf_blocks.10.norm1.bias", "_orig_mod.trf_blocks.10.norm2.weight", "_orig_mod.trf_blocks.10.norm2.bias", "_orig_mod.trf_blocks.11.att.qkv.weight", "_orig_mod.trf_blocks.11.att.proj.weight", "_orig_mod.trf_blocks.11.att.proj.bias", "_orig_mod.trf_blocks.11.ff.layers.0.weight", "_orig_mod.trf_blocks.11.ff.layers.0.bias", "_orig_mod.trf_blocks.11.ff.layers.2.weight", "_orig_mod.trf_blocks.11.ff.layers.2.bias", "_orig_mod.trf_blocks.11.norm1.weight", "_orig_mod.trf_blocks.11.norm1.bias", "_orig_mod.trf_blocks.11.norm2.weight", "_orig_mod.trf_blocks.11.norm2.bias", "_orig_mod.final_norm.weight", "_orig_mod.final_norm.bias", "_orig_mod.out_head.weight".
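A minimal sketch of one possible workaround: strip the "_orig_mod." prefix from the checkpoint keys before calling model.load_state_dict(). The helper below is hypothetical (not from the repo) and is demonstrated on a plain dict so the key handling is easy to follow; with a real checkpoint you would first obtain the dict via state_dict = torch.load("model.pth", weights_only=True).

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Remove the wrapper prefix that torch.compile() adds to parameter names."""
    return {k.removeprefix(prefix): v for k, v in state_dict.items()}

# Demo on a plain dict standing in for the loaded checkpoint:
saved = {
    "_orig_mod.tok_emb.weight": 1,
    "_orig_mod.final_norm.bias": 2,
}
cleaned = strip_compile_prefix(saved)
print(sorted(cleaned))  # ['final_norm.bias', 'tok_emb.weight']
# model.load_state_dict(cleaned) would then match the uncompiled model's keys.
```

Alternatively, the prefix can be avoided at save time: a torch.compile()-wrapped module exposes the original module as model._orig_mod, so saving model._orig_mod.state_dict() should produce keys without the prefix.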
