In [101]:
%load_ext autoreload
%autoreload 2

import os

import tiktoken
import torch
from model import EmbedTransformer
import numpy as np

SEED = 1337
torch.set_printoptions(precision = 6, profile = None)
# Apply the seed to torch's global RNG: get_batch below samples window offsets
# with torch.randint, so without this the batch (and every loss compared in the
# later cells) changes on each Restart & Run All.
torch.manual_seed(SEED)

def model_fn(model_path, device="mps", weights_dir="model_weights"):
    """
    Load an EmbedTransformer checkpoint and prepare it for inference.

    Args:
        model_path: checkpoint filename, relative to `weights_dir`.
        device: torch device used both as `map_location` for the checkpoint
            tensors and as the target of `model.to`. Defaults to "mps".
        weights_dir: directory containing the checkpoint files.
            Defaults to "model_weights".

    Returns:
        The model built by `EmbedTransformer.init_from_checkpoint`, switched to
        eval mode and moved to `device`.
    """
    # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
    # only load trusted checkpoints. `weights_only=True` would be safer but may
    # reject this checkpoint if it stores non-tensor objects (e.g. a config) —
    # TODO confirm against how the checkpoints are saved.
    checkpoint = torch.load(
        os.path.join(weights_dir, model_path), map_location=device
    )
    model = EmbedTransformer.init_from_checkpoint(checkpoint)
    model.eval()
    model.to(device)
    return model

def get_batch(split, data_dir, block_size, batch_size, device,
              dataset_name='full_harry_potter', generator=None):
    """
    Sample a random batch of (input, next-token target) windows from a token file.

    Reads `<data_dir>/<dataset_name>_train.bin` when `split == 'train'` and
    `<data_dir>/<dataset_name>_val.bin` for any other `split` value. Each file
    is read as a flat uint16 token array via a read-only `np.memmap`.

    Args:
        split: 'train' selects the training file; anything else selects the
            validation file.
        data_dir: directory containing the .bin token files.
        block_size: number of tokens in each sampled window.
        batch_size: number of windows to sample.
        device: torch device the returned tensors are moved to.
        dataset_name: filename prefix of the .bin files (the default reproduces
            the original hard-coded Harry Potter filenames).
        generator: optional `torch.Generator` used for the window offsets, so a
            batch can be drawn reproducibly without touching the global RNG.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size), where
        y[b, t] == token following x[b, t] in the file.

    Raises:
        ValueError: if the selected file has no more than `block_size` tokens,
            so not even one input/target window fits.
    """
    suffix = 'train' if split == 'train' else 'val'
    path = os.path.join(data_dir, f'{dataset_name}_{suffix}.bin')
    data = np.memmap(path, dtype=np.uint16, mode='r')
    n_tokens = len(data)
    if n_tokens <= block_size:
        raise ValueError(
            f"{path} has {n_tokens} tokens; need more than block_size={block_size} "
            "to build an input/target window"
        )
    # Largest start offset is n_tokens - block_size - 1, which keeps the
    # shifted target slice data[i+1 : i+1+block_size] in bounds.
    ix = torch.randint(n_tokens - block_size, (batch_size,), generator=generator)
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)

# Load the three checkpoints being compared (model_fn reads them from
# model_weights/ and places them on "mps"). Per their filenames: a baseline
# with a final LayerNorm, a small baseline without one, and the small "new" model.
model_base = model_fn(model_path="baseline_w_final_ln.pth")
model_base_wo = model_fn(model_path="small_baseline_wo_final_ln.pth")
model_new = model_fn(model_path="small_new.pth")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using flash attention.
Using flash attention.
Using flash attention.
Using flash attention.
Using flash attention.


In [102]:
# Display the ModelConfig of the baseline model loaded above (context size,
# embedding width, layer/head counts, final-LN and output-layer flags).
model_base.config

ModelConfig(context_size=50, n_embed=100, n_layer=3, n_head=2, use_new_output_layer=False, use_final_ln_layer=True, dropout_rate=0.1, new_output_layer_config=None, alphabet_size=50257, bias=False, use_flash=True)

In [103]:
# Draw one training window of context_size tokens; y holds the same window
# shifted forward by one token (the next-token targets).
x, y = get_batch(
    split='train',
    data_dir='../datasets/full_harry_potter/',
    block_size=model_base.config.context_size,
    batch_size=1,
    device="mps",
)
(x, y)

(tensor([[  319,    13,   366, 13056,   345,  1936,  5027,   293,   684,   262,
           1306,   530, 10564,    13,   350,   414,   198,   270,  2492,   470,
          46236,   532,     1,   198,   198,   464,  8966, 28077,   379,   326,
           2589,    11,   543,   373,  9670,    26,   379, 39157,   338,   938,
            198, 10879,    11,  6575,   550, 43713,   572,   465, 38753,    11]],
        device='mps:0'),
 tensor([[   13,   366, 13056,   345,  1936,  5027,   293,   684,   262,  1306,
            530, 10564,    13,   350,   414,   198,   270,  2492,   470, 46236,
            532,     1,   198,   198,   464,  8966, 28077,   379,   326,  2589,
             11,   543,   373,  9670,    26,   379, 39157,   338,   938,   198,
          10879,    11,  6575,   550, 43713,   572,   465, 38753,    11,   290]],
        device='mps:0'))

In [104]:
# Forward pass of the baseline model (use_final_ln_layer=True per its config)
# on the sampled batch; the call returns three values bound to logits, loss
# and preserve_1.
# NOTE(review): autograd is active here (outputs carry grad_fn); wrap in
# torch.no_grad() if no backward pass is planned. The per-weight mean/std lines
# in the output are printed from inside the forward pass — likely leftover
# debugging in model.py; confirm.
logits, loss, preserve_1 = model_base(x, y)
logits, loss, preserve_1

batch_attn_weights.weight
mean: 8.156077092280611e-05
std: 0.05263173580169678
tensor(0.070860, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: 0.0003068844380322844
std: 0.04025254026055336
tensor(0.037620, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
linear.weight
mean: 0.0005448497831821442
std: 0.0536167174577713
tensor(0.099277, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: -7.447772077284753e-05
std: 0.04853970929980278
tensor(0.043865, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
batch_attn_weights.weight
mean: -0.0004314107936806977
std: 0.0393262580037117
tensor(0.072712, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: 0.00015296989295165986
std: 0.04353002831339836
tensor(0.039068, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
linear.weight
mean: -0.0017789130797609687
std: 0.0697833821

(tensor([[  0.869456,  -1.662627,  ...,  -3.179193, -11.098945],
         [  0.332209,   3.751071,  ...,  -2.451365,  -8.403477],
         ...,
         [  3.052134,   1.488364,  ...,  -2.896238, -12.105692],
         [  0.611562,   3.301512,  ...,  -0.831245,  -7.881510]],
        device='mps:0', grad_fn=<ViewBackward0>),
 tensor(4.948982, device='mps:0', grad_fn=<NllLossBackward0>),
 tensor([[[ 2.895113,  0.172989,  ...,  0.390608, -0.076955],
          [-4.071156,  4.572210,  ...,  7.003213, 10.380942],
          ...,
          [ 6.528827, -5.803418,  ...,  1.061136,  0.733176],
          [-1.063884, -6.729303,  ...,  4.541925,  5.985963]]], device='mps:0',
        grad_fn=<AddBackward0>))

In [105]:
# Same batch through the small baseline checkpoint (no final LayerNorm, per
# its filename).
# NOTE(review): this rebinds logits, loss and preserve_1, discarding
# model_base's results from the previous cell — use distinct names if both
# models' outputs are meant to be compared later.
logits, loss, preserve_1 = model_base_wo(x, y)
logits, loss, preserve_1

batch_attn_weights.weight
mean: 0.004816184751689434
std: 0.06880626082420349
tensor(0.431733, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: -0.014437519945204258
std: 0.07449027895927429
tensor(0.412000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
linear.weight
mean: -0.0007336222915910184
std: 0.08442458510398865
tensor(0.468000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: 0.029539411887526512
std: 0.08249693363904953
tensor(0.496000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


(tensor([[-0.425212, -0.331839,  ..., -0.384078, -0.472897],
         [-0.342075, -0.257724,  ..., -0.302846, -0.385399],
         ...,
         [-0.228476, -0.160194,  ..., -0.195381, -0.264601],
         [-0.022605, -0.003160,  ..., -0.007545, -0.026356]], device='mps:0',
        grad_fn=<ViewBackward0>),
 tensor(0.632456, device='mps:0', grad_fn=<NllLossBackward0>),
 tensor([[[ 0.638241,  0.382952,  0.777406,  0.446957,  0.657296],
          [ 0.568348,  0.382884,  0.675938,  0.410851,  0.555484],
          [ 0.543656,  0.351955,  0.652079,  0.391024,  0.537854],
          [ 0.621775,  0.374104,  0.739263,  0.440920,  0.620805],
          [ 0.620462,  0.384013,  0.731432,  0.435339,  0.618149],
          [ 0.619639,  0.368268,  0.726957,  0.441147,  0.613560],
          [ 0.628867,  0.363562,  0.734036,  0.430102,  0.616002],
          [ 0.615435,  0.378937,  0.704111,  0.454140,  0.601360],
          [ 0.584162,  0.363212,  0.667127,  0.417013,  0.559340],
          [ 0.614444,  0.

In [106]:
# Same batch through the small "new" model. Its third output goes to
# preserve_2; after the previous cell, preserve_1 holds model_base_wo's result.
# NOTE(review): logits/loss are rebound again, so only model_new's values
# remain under those names.
logits, loss, preserve_2 = model_new(x, y)
logits, loss, preserve_2

batch_attn_weights.weight
mean: -0.03141744062304497
std: 0.11322087049484253
tensor(0.397867, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: -0.04464169964194298
std: 0.15843133628368378
tensor(0.404000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
linear.weight
mean: -0.02581316977739334
std: 0.15734650194644928
tensor(0.452000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
residual_proj.weight
mean: -0.006305675487965345
std: 0.19148971140384674
tensor(0.500000, device='mps:0')
@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@


(tensor([[-5.885533,  0.488109,  ..., -7.332145, -7.336850],
         [-5.082309,  0.456361,  ..., -6.294451, -6.298660],
         ...,
         [-3.824644,  0.319742,  ..., -4.648238, -4.647143],
         [ 2.766201, -0.212537,  ...,  3.364758,  3.363505]], device='mps:0',
        grad_fn=<ViewBackward0>),
 tensor(0.936638, device='mps:0', grad_fn=<NllLossBackward0>),
 tensor([[[-2.304789,  3.157949, -2.115094, -2.457049, -2.644482],
          [-1.956201,  2.715599, -1.721207, -2.087803, -2.403902],
          [-1.972106,  2.638837, -1.968992, -2.190049, -2.435396],
          [-2.200940,  2.977507, -2.015945, -2.333250, -2.605975],
          [-2.147825,  2.850530, -1.961850, -2.246897, -2.647933],
          [-2.113140,  2.848388, -1.967411, -2.225606, -2.614479],
          [-2.176913,  2.961578, -1.992229, -2.237164, -2.618507],
          [-2.058432,  2.829214, -1.925570, -2.315913, -2.570500],
          [-2.122394,  2.855711, -1.904035, -2.306274, -2.597177],
          [-2.004837,  2.