In [42]:
import numpy as np
from torch import optim, nn, Tensor
from torch.nn import functional as F
import torch
import wandb
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import transformers
import lightning as L
from inspect import signature, _ParameterKind
import copy
import gc
import datasets
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import pandas as pd
from country_list import countries_for_language

In [21]:
model_name = 'mistralai/Mistral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
Token = {v: k for k, v in tokenizer.get_vocab().items()}
model = AutoModelForCausalLM.from_pretrained(model_name).to('cuda')

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [28]:
def topk(v, k=10):
    # Takes in logits
    #v = softmax(v.flatten())
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    v = v.flatten()
    idxs = v.argsort()[-k:][::-1]
    ret = [(Token[i], v[i]) for i in idxs]
    return pd.DataFrame(ret, columns=['token', 'logit'])

In [73]:
def print_tokens(s):
    tokens = tokenizer(s).input_ids
    print('|'.join(Token[x] for x in tokens))

In [23]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe3adf48220>

In [68]:
countries = countries_for_language('en')
countries = [c for _, c in countries if len(tokenizer(c)['input_ids']) == 2]
countries = [c for c in countries if c != 'Singapore'] # kind of an edge case...

In [70]:
template = 'The nation of {}, whose capital city is'
input = tokenizer([template.format(c) for c in countries], return_tensors='pt').to('cuda')

In [74]:
print_tokens('The nation of France, whose capital city is')

<s>|▁The|▁nation|▁of|▁France|,|▁whose|▁capital|▁city|▁is


In [120]:
out = model(**input, output_hidden_states=True, output_attentions=True)

In [102]:
out.attentions[0].shape

torch.Size([51, 32, 10, 10])

In [111]:
out.attentions[0].mean(dim=0)[0]

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.7842e-01, 2.1582e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.7448e-01, 2.2072e-02, 3.4500e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.4036e-01, 1.5855e-02, 2.3344e-03, 4.1448e-02, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.6279e-01, 1.2634e-02, 1.9308e-03, 1.7299e-02, 5.3518e-03, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.0677e-01, 8.3474e-03, 8.1267e-04, 1.2414e-02, 3.7718e-03, 6.7887e-02,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.2191e-01, 5.7350e-03, 4.3037e-04, 5.1827e-03, 2.5756e-03, 4.7140e-02,
         1.7025e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.5568e-01, 9.2763

In [None]:
ticks = ['The', 'nation', 'of', 'C', ',', 'whose', 'capital', 'city', 'is']
for i in range(32):
    for j in range(32):
        print(i, j)
        plt.imshow(out.attentions[i].mean(dim=0)[j][1:,1:].detach().cpu().numpy())
        plt.xticks(range(9), ticks)
        plt.yticks(range(9), ticks)
        plt.show()

In [113]:
out.hidden_states[0].shape

torch.Size([51, 10, 4096])

In [125]:
ret = 0
print(f'{0:.3f}', end=' ')

0.000 

In [131]:
for i in range(10):
    for j in range(32):
        h = out.hidden_states[j][:,i,:]
        h -= h.mean(dim=0)
        norm = torch.linalg.norm(h, ord=2).item()
        avg_norm = torch.linalg.norm(h, dim=1, ord=2).mean().item()
        ret = 0 if norm == 0 else norm / avg_norm
        #print(ret, end=' ')
        print(f'{ret:.1f}', end=' ')
    print()

0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
1.6 1.8 1.8 2.0 2.1 2.0 2.0 2.0 2.1 2.1 2.2 2.1 2.1 2.1 2.1 2.2 2.2 2.1 2.0 1.9 2.0 2.0 2.0 2.0 2.0 2.0 2.1 2.1 2.1 2.1 2.1 2.2 
0.0 2.5 2.4 2.7 3.5 3.2 2.7 2.8 2.9 2.9 3.0 3.1 2.9 3.0 3.1 3.1 2.9 2.9 2.9 2.6 2.4 2.3 2.3 2.3 2.3 2.3 2.3 2.3 2.4 2.4 2.5 2.6 
0.0 3.2 2.4 2.7 2.8 2.9 2.8 2.9 3.0 3.0 3.0 3.0 2.9 2.9 3.0 2.9 2.8 2.7 2.7 2.5 2.4 2.4 2.3 2.3 2.2 2.2 2.2 2.2 2.2 2.3 2.3 2.4 
0.0 3.9 3.7 3.0 3.2 3.1 2.3 2.2 2.2 2.3 2.4 2.5 2.6 2.7 2.8 2.9 2.9 3.0 3.3 3.2 2.7 2.7 2.6 2.5 2

In [79]:
len(out.hidden_states)

33

In [80]:
out.hidden_states[0][:,0,:].shape

torch.Size([51, 4096])