In [1]:
import torch
from torch import nn
from model import *

In [2]:
# --- Model hyper-parameters ---
dim = 288
n_layers = 6
n_heads = 6
multiple_of = 32
dropout = 0.0  # defined here but not forwarded to ModelArgs below

# --- Shape of the random input batches used further down ---
batch_size = 2  # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 256

# Alternative, larger configuration kept for reference:
# ModelArgs(dim=768, n_layers=12, n_heads=12,
#           n_kv_heads=12, vocab_size=32000, multiple_of=32,
#           norm_eps=1e-05, max_seq_len=1024, dropout=0.0,
#           add_zero_attn=False, softmax='softmax1', flash=False)

# Build the configuration object directly from keyword arguments.
model_args = ModelArgs(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_heads,  # same number of key/value heads as query heads
    vocab_size=32000,
    multiple_of=multiple_of,
    max_seq_len=max_seq_len,
    # dropout=dropout,
    softmax="softmax1",
    flash=False,
)

In [3]:
# Instantiate the network from the configuration built in the previous cell.
# NOTE(review): no checkpoint is loaded anywhere in this notebook, so the weights
# are presumably whatever Transformer.__init__ produces — confirm in model.py.
model = Transformer(model_args)



In [4]:
# Optional and currently disabled: model = torch.compile(model)
# Put the model in evaluation mode; the trailing ';' hides the module repr
# that the cell would otherwise display.
model.eval();

## Hooks

In [54]:
def hook_fn(module, input, output):
    """Forward hook that tracks the running maximum of a module's output.

    After every forward pass the largest element of ``output`` is compared
    with the value already stored in ``module.max_att`` (if any) and the
    larger one is kept, so ``module.max_att`` holds the maximum seen over all
    forward passes since the attribute was first set.

    Args:
        module: Module the hook is registered on; it receives ``max_att``.
        input: Positional inputs of the forward call (unused).
        output: Tensor returned by the module's forward call.

    Returns:
        None, so the module's output is passed through unchanged.
    """
    # detach() keeps the stored statistic out of any autograd graph.
    current_max = output.detach().max()
    previous_max = getattr(module, "max_att", None)
    if previous_max is None:
        module.max_att = current_max
    else:
        # torch.maximum propagates NaN, so a NaN seen once stays visible
        # instead of being silently replaced on the next forward pass.
        module.max_att = torch.maximum(current_max, previous_max)

In [55]:

def add_hooks(model):
    """Register ``hook_fn`` as a forward hook on every layer's attention module.

    Args:
        model: Object whose ``layers`` iterable yields blocks that expose an
            ``attention`` module.

    Returns:
        list: One handle per layer, in layer order, as returned by
        ``register_forward_hook``.
    """
    return [block.attention.register_forward_hook(hook_fn) for block in model.layers]

# Keep the returned handles: calling .remove() on each one detaches its hook again.
handles = add_hooks(model)

In [56]:
# Run one batch of random integer ids (drawn from [0, vocab_size)) through the
# model so that the registered forward hooks fire once per layer.
# This is inference only, so gradients are disabled — the same way the
# "Compute metrics" cell runs its forward pass — to avoid building an autograd graph.
with torch.no_grad():
    x = torch.randint(0, model_args.vocab_size, (batch_size, model_args.max_seq_len))
    out = model(x)

In [49]:
# Show the running maximum that the forward hook stored on each attention module.
for layer in model.layers:
    print(layer.attention.max_att)

tensor(0.0571)
tensor(0.0575)
tensor(0.0736)
tensor(0.0824)
tensor(0.0726)
tensor(0.1038)


In [None]:
import torch

# Minimal stand-alone demo of when a forward hook actually runs.
m = torch.nn.Linear(2, 1)

# The hook ignores its arguments and only announces that it was called.
m.register_forward_hook(lambda module, inputs, output: print("I am capatin Hook"))

# Calling the module itself goes through Module.__call__, which runs the hook.
m(torch.randn(2))

# Calling .forward() directly bypasses __call__, so the hook is NOT triggered.
m.forward(torch.randn(2))


I am capatin Hook


tensor([-1.7607], grad_fn=<AddBackward0>)

## Compute metrics

In [6]:
# Inference only: push one random batch through the model without tracking
# gradients, then ask the model for its attention metrics.
# NOTE(review): compute_attention_metrics presumably reads state left behind by
# the forward pass just above — confirm in model.py.
with torch.no_grad():
    x = torch.randint(
        low=0,
        high=model_args.vocab_size,
        size=(batch_size, model_args.max_seq_len),
    )
    out = model(x)
    inf_norm, k = model.compute_attention_metrics()

In [7]:
# Display both metric lists; the recorded output has six entries in each,
# matching n_layers = 6.
inf_norm, k

([0.05343833565711975,
  0.05626192316412926,
  0.0662432461977005,
  0.07274584472179413,
  0.07878565788269043,
  0.0826156884431839],
 [9.609243933664462,
  2.1342359567583387,
  0.6189443758403588,
  -0.12834004447029024,
  -0.016824915476210123,
  0.009668920702321948])

In [6]:
# Collect the tensor stored on each layer's attention module.
# NOTE(review): this needs every attention module to expose `.output`; the hook
# defined in this notebook only sets `.max_att`, so `.output` must come from the
# model code (or from an earlier, since-edited hook) — verify on a fresh kernel.
outputs = [layer.attention.output for layer in model.layers]

In [7]:
# Sanity check: one stored tensor per layer is expected.
len(outputs)

6

In [8]:
from scipy.stats import kurtosis


In [9]:
# Inspect the first layer's stored tensor. This displays the full tensor repr;
# `outputs[0].shape` would give a compact view instead.
outputs[0]

tensor([[[-1.4942e-02,  1.7280e-02, -7.3087e-03,  ..., -5.6646e-03,
           2.7745e-02,  1.6313e-02],
         [-2.5999e-02,  1.2745e-02, -3.5869e-03,  ..., -1.5707e-02,
           1.2777e-02,  2.0607e-02],
         [-1.8710e-02,  1.7124e-02, -4.9153e-03,  ...,  9.6869e-03,
           3.3312e-03,  9.3342e-03],
         ...,
         [-1.4150e-03,  3.3449e-03, -4.0636e-03,  ..., -7.5439e-04,
           3.7374e-04,  5.6159e-04],
         [-1.3440e-03,  3.3190e-03, -4.2030e-03,  ...,  6.1608e-04,
           1.8159e-04,  7.8880e-04],
         [-1.9902e-03,  3.7543e-03, -4.6454e-03,  ..., -2.2938e-04,
           3.7203e-04,  7.0601e-04]],

        [[-1.4862e-02, -2.3684e-03, -3.0112e-02,  ...,  8.2180e-03,
          -1.0275e-02,  1.0324e-02],
         [-8.0560e-03,  5.2986e-03, -5.1081e-03,  ..., -8.5383e-03,
           1.9722e-05,  1.7366e-02],
         [-6.3437e-03,  8.5796e-03, -1.3878e-02,  ...,  3.0706e-03,
           1.1113e-03,  1.3156e-02],
         ...,
         [ 1.4875e-03, -1

In [10]:
# Excess kurtosis over all elements of each layer's stored tensor
# (scipy.stats.kurtosis defaults to Fisher's definition, where a normal
# distribution scores 0).
# Each tensor is explicitly detached, moved to the CPU, cast to float32 and
# converted to NumPy before SciPy sees it, so the cell also works for CUDA,
# half-precision or grad-tracking tensors.
[kurtosis(o.detach().flatten().cpu().float().numpy()) for o in outputs]

[7.876882591728345,
 2.360544078376952,
 0.9361386662094842,
 0.2454783971266279,
 0.4035440196696536,
 -0.017442972684702074]

In [None]:
# Average the per-layer kurtosis values into a single scalar tensor.
torch.tensor(
    [kurtosis(layer_out.flatten().cpu().float()) for layer_out in outputs]
).mean()