In [69]:
import os
import torch

from pathlib import Path 
from transformers import AutoTokenizer
from transformer.config import Config
from transformer.main import DemoTransformer



In [72]:
cfg = Config(n_layers=1)    
model = DemoTransformer(cfg=cfg)


In [74]:
model

DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

In [75]:
model_dir = Path('trained_models')
checkpoint_path = max(model_dir.glob('*.pth'), key = os.path.getctime)
print("this is check_point", checkpoint_path)
model.load_state_dict(torch.load(checkpoint_path))




this is check_point trained_models/model_weights_20241212_170753.pth


  model.load_state_dict(torch.load(checkpoint_path))


<All keys matched successfully>

In [76]:
model.eval()
tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [77]:
def run_inference(text: str):
    tokens = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        output = model(tokens['input_ids'])

    predicted_token_ids = torch.argmax(output, dim=-1)
    predicted_text = tokenizer.decode(predicted_token_ids[0][-1])
    print(f"Input: {text}")
    print(f"Model output: {predicted_text}")

In [78]:
text = "the cat in the"
run_inference(text)

Input: the cat in the
Model output:  hat


In [79]:
for name, param in model.named_parameters():
    print(f"Layer: {name} | Shape: {param.shape}")

Layer: embed.W_E | Shape: torch.Size([50257, 768])
Layer: pos_embed.W_pos | Shape: torch.Size([1024, 768])
Layer: blocks.0.ln1.W | Shape: torch.Size([768])
Layer: blocks.0.ln1.b | Shape: torch.Size([768])
Layer: blocks.0.attn.W_Q | Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.b_Q | Shape: torch.Size([12, 64])
Layer: blocks.0.attn.W_K | Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.b_K | Shape: torch.Size([12, 64])
Layer: blocks.0.attn.W_V | Shape: torch.Size([12, 768, 64])
Layer: blocks.0.attn.b_V | Shape: torch.Size([12, 64])
Layer: blocks.0.attn.W_O | Shape: torch.Size([12, 64, 768])
Layer: blocks.0.attn.b_O | Shape: torch.Size([768])
Layer: blocks.0.ln2.W | Shape: torch.Size([768])
Layer: blocks.0.ln2.b | Shape: torch.Size([768])
Layer: blocks.0.mlp.W_in | Shape: torch.Size([768, 3072])
Layer: blocks.0.mlp.b_in | Shape: torch.Size([3072])
Layer: blocks.0.mlp.W_out | Shape: torch.Size([3072, 768])
Layer: blocks.0.mlp.b_out | Shape: torch.Size([768])
Layer: ln_final.W

In [80]:
state_dict = model.state_dict()

In [81]:
state_dict

OrderedDict([('embed.W_E',
              tensor([[0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      ...,
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')),
             ('pos_embed.W_pos',
              tensor([[ 0.0140,  0.0258,  0.0459,  ...,  0.0137, -0.0026, -0.0019],
                      [ 0.0220, -0.0112, -0.0080,  ...,  0.0041, -0.0364,  0.0032],
                      [-0.0236, -0.0268, -0.0436,  ...,  0.0082, -0.0173, -0.0177],
                      ...,
                      [ 0.0392, -0.0148, -0.0047,  ...,  0.0118,  0.0329, -0.0472],
                      [-0.0336,  0.0032, -0.0078,  ..., -0.0124, -0.0120,  0.0004],
                      [-0.0079, -0.0422, -0.0315,  ...,  0.0301,  0.0048, -0.0014]],
                     device='mps:0')),

In [82]:
print(model.state_dict()['ln_final.W'])

tensor([1.0017, 0.9868, 1.0035, 0.9996, 1.0004, 0.9982, 1.0016, 0.9964, 0.9965,
        0.9927, 0.9926, 1.0029, 0.9903, 1.0035, 0.9937, 0.9836, 1.0008, 1.0026,
        0.9953, 1.0021, 0.9898, 0.9882, 1.0026, 0.9875, 0.9836, 0.9952, 0.9973,
        1.0000, 1.0000, 0.9844, 0.9858, 0.9923, 1.0040, 0.9966, 1.0004, 0.9946,
        1.0019, 0.9963, 1.0014, 1.0007, 0.9922, 1.0049, 1.0024, 0.9922, 0.9875,
        0.9993, 1.0040, 0.9962, 0.9851, 0.9873, 0.9965, 1.0017, 0.9923, 0.9932,
        0.9932, 1.0067, 0.9979, 1.0046, 1.0020, 0.9972, 1.0081, 0.9875, 0.9957,
        1.0092, 0.9919, 0.9943, 0.9854, 0.9841, 0.9895, 0.9844, 1.0022, 1.0013,
        1.0008, 0.9946, 1.0015, 1.0008, 0.9962, 1.0005, 0.9898, 1.0019, 0.9996,
        1.0026, 1.0005, 0.9998, 1.0017, 0.9869, 1.0024, 0.9908, 1.0008, 0.9811,
        0.9908, 0.9976, 1.0001, 1.0010, 0.9992, 1.0008, 0.9915, 1.0009, 0.9886,
        0.9959, 0.9835, 0.9916, 0.9948, 0.9989, 0.9989, 0.9942, 0.9977, 0.9930,
        0.9909, 0.9776, 0.9960, 0.9862, 