In [1]:
# need to do this before transformer imports
import os
os.environ['HF_HOME'] = '/workspace/cache/huggingface/'

import os
os.chdir('/workspace/FutureGPT2/src/')
from evals.utils import *
from models.bigram_model import *
from models.mlp_model import *
from models.future_model import *
from data.utils import get_tokenizer
import datasets
from torch.utils.data import DataLoader
from torch import nn
from itertools import islice
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
from torch import nn

from tqdm import tqdm
import pandas as pd
import gc
from glob import glob
import numpy as np
import copy

%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

In [2]:
def invert(f, y, start=1, eps=1e-4, max_steps=10000):
    r'''
    Given monotonic increasing function f with domain [0,\infty), returns f^{-1}(y)

    The answer is built up as a running sum of steps. Starting from ``start``,
    the step is doubled while f is still below y and then halved while f is
    above y; the surviving step is added to the running offset and the search
    repeats from there, using that step as the next initial step size.

    Args:
        f: callable of one scalar argument x >= 0, assumed monotonic increasing.
        y: target value of f.
        start: initial step size.
        eps: absolute tolerance on |f(x) - y| (a tolerance on the function
            value, not on x).
        max_steps: cap on the total number of step updates, so that a target
            f never reaches fails loudly instead of looping forever.

    Returns:
        x >= 0 such that |f(x) - y| < eps; ``0`` if f(0) already matches.

    Raises:
        AssertionError: if f(0) > y + eps, i.e. the target lies below the
            range of f on the domain.
        RuntimeError: if no x within tolerance is found in ``max_steps``
            step updates.
    '''
    f_zero = f(0)
    if f_zero > y + eps:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped under `python -O` and carries a message.
        raise AssertionError(
            'invert: f(0)=%r is already above the target y=%r' % (f_zero, y)
        )
    if np.abs(f_zero - y) < eps:
        return 0

    budget = [max_steps]

    def spend():
        # Count one step update; abort once the budget is used up.
        budget[0] -= 1
        if budget[0] < 0:
            raise RuntimeError(
                'invert: no x with |f(x) - y| < eps found within %d steps '
                '(is y reachable by f?)' % max_steps
            )

    offset = 0    # sum of the steps accepted so far; f(offset) <= y
    step = start  # initial step size for the current round
    while True:
        x = step
        if np.abs(f(offset + x) - y) < eps:
            return offset + x
        # Grow the step until it reaches or overshoots the target ...
        while f(offset + x) < y:
            x *= 2
            spend()
        # ... then shrink it until it no longer overshoots.
        while f(offset + x) > y:
            x /= 2
            spend()
        offset = offset + x
        step = x
        spend()
        if np.abs(f(offset) - y) < eps:
            return offset

In [3]:
# Memo used by constr_lsqr: maps a key derived from the matrix A to
# (A^T A, SVD of A^T A) so the decomposition is computed once per matrix.
lsqr_cache = dict()

In [4]:
def constr_lsqr(y, A, c):
    r'''
    Finds inf_w \|y-Aw\|_2 s.t. \|w\|_2<=c

    The solution has the ridge form w = (A^T A + lam I)^{-1} A^T y.
    lam is 0 when the unregularized solution already satisfies the norm bound;
    otherwise `invert` searches for the lam at which \|w\|_2 meets c (within
    invert's tolerance). The norm of w for a given lam is evaluated cheaply
    from the cached SVD of A^T A.

    Args:
        y: right-hand side; a column vector (m, 1) or a 1-D array (m,).
        A: design matrix of shape (m, n).
        c: upper bound on the 2-norm of w.

    Returns:
        w with the same number of dimensions as ``A.T @ y``.

    Side effects:
        Stores (A^T A, svd(A^T A)) in the module-level ``lsqr_cache`` and
        prints a message when the SVD is computed.
    '''
    import hashlib  # local import: only needed to build the cache key

    # Key on shape + dtype + a digest of the raw bytes. A tuple of all elements
    # would create one Python object per matrix entry and ignore the shape.
    key = (
        A.shape,
        A.dtype.str,
        hashlib.blake2b(A.tobytes(), digest_size=16).hexdigest(),
    )
    if key not in lsqr_cache:
        print('calcing SVD!')
        ATA = A.T @ A
        lsqr_cache[key] = ATA, np.linalg.svd(ATA)
        print('SVD done!')

    ATA, (_, S, VT) = lsqr_cache[key]
    ATy = A.T @ y
    VTATy = VT @ ATy
    # Give S one trailing axis per extra axis of VTATy so the division below
    # is row-wise for both 1-D and column-vector y.
    S = S.reshape((-1,) + (1,) * (VTATy.ndim - 1))

    # \|w(lam)\| decreases as lam grows; negate it because `invert` expects an
    # increasing function.
    neg_norm = lambda lam: -np.linalg.norm(VTATy / (S + lam))
    if neg_norm(0) >= -c:
        lam = 0
    else:
        lam = invert(neg_norm, -c)
    # Solve the regularized normal equations directly instead of forming the
    # explicit inverse.
    return np.linalg.solve(ATA + lam * np.eye(A.shape[1]), ATy)

In [5]:
class Intervene(nn.Module):
    '''
    Replaces (some subset of) input hidden_states with new_states
    Expects (batch_size, seq_length, embed_dim)

    The first ``new_states.shape[1]`` sequence positions are overwritten; any
    later positions pass through unchanged. The incoming tensor itself is left
    untouched: the patch is applied to a copy.

    Args:
        new_states: tensor of shape (batch_size, n, embed_dim). It is stored as
            a plain attribute (not a parameter or buffer), so it must already
            be on the same device as the hidden states it will replace.
    '''
    def __init__(self, new_states):
        super().__init__()
        self.new_states = new_states

    def forward(self, hidden_states, **kwargs):
        '''
        Return ``(patched_hidden_states, None, None)``.

        Extra keyword arguments are accepted and ignored. The two trailing
        ``None`` entries are placeholders -- presumably so the module can stand
        in for a decoder layer in the model's layer list; confirm against the
        caller.
        '''
        # Batch size and embedding width must agree; sequence length may differ.
        for i in [0, 2]:
            assert self.new_states.shape[i] == hidden_states.shape[i], (
                'Intervene: dim %d mismatch, new_states %s vs hidden_states %s'
                % (i, tuple(self.new_states.shape), tuple(hidden_states.shape))
            )
        # Clone first so the caller's tensor (which the model may also have
        # recorded as an intermediate hidden state) is not modified in place.
        patched = hidden_states.clone()
        patched[:, :self.new_states.shape[1], :] = self.new_states
        return patched, None, None

In [6]:
# Hugging Face model id; used below to load the matching tokenizer.
model_name = 'mistralai/Mistral-7B-v0.1'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Inverse vocabulary: token id -> token string (get_vocab() maps string -> id).
Token = {v: k for k, v in tokenizer.get_vocab().items()}

In [8]:
def print_tokens(s):
    '''Print the tokenization of `s` as token strings joined by '|'.'''
    token_ids = tokenizer(s)['input_ids']
    pieces = [Token[token_id] for token_id in token_ids]
    print('|'.join(pieces))

In [9]:
def topk(v, k=10):
    '''
    Return the k largest entries of a logit vector as a DataFrame.

    Args:
        v: array of logits indexed by token id; it is flattened first.
        k: number of entries to return. k=0 gives an empty frame.

    Returns:
        DataFrame with columns ['token', 'logit'], largest logit first. The
        values are the raw logits; no softmax is applied.
    '''
    flat = v.flatten()
    # Reverse, then truncate. Slicing with [-k:] before reversing would return
    # the whole vocabulary for k=0, because -0 == 0.
    order = flat.argsort()[::-1][:k]
    rows = [(Token[i], flat[i]) for i in order]
    return pd.DataFrame(rows, columns=['token', 'logit'])

In [10]:
# Pick a checkpoint file matching the sweep pattern.
# NOTE: [0] takes whichever match glob lists first and raises IndexError if
# nothing matches.
ckpt = glob(
    '/workspace/checkpoints/MISTRAL-NECK-SWEEP_*_hidden_idxs-31_hidden_lb-0_token_lb--1_neck_cls-mlp_*',
)[0]

In [11]:
# Show which checkpoint file was selected.
ckpt

'/workspace/checkpoints/MISTRAL-NECK-SWEEP_20231231-131900-D4c1E_hidden_idxs-31_hidden_lb-0_token_lb--1_neck_cls-mlp_epoch=00-val_self_loss=4.63.ckpt'

In [12]:
# Restore the model from the checkpoint; strict=False is passed through to the
# loader -- presumably to tolerate state-dict key mismatches; TODO confirm.
model = LitFutureModelWithNeck.load_from_checkpoint(ckpt, strict=False)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
# Keep an untouched copy of the decoder layers so they can be restored after
# Intervene modules have been spliced into model.base_model.model.layers.
orig_layers = copy.deepcopy(model.base_model.model.layers)

In [13]:
# Pull the relevant weights out as NumPy arrays for offline linear algebra.
neck = model.future_neck.layers[0].weight.data.cpu().numpy()  # weight of the neck's first layer
D = model.base_model.lm_head.weight.data.cpu().numpy()  # lm_head weight -- presumably (vocab, hidden); TODO confirm
E = model.base_model.model.embed_tokens.weight.data.cpu().numpy()  # input embedding table
A = D @ neck  # neck weight composed with lm_head: maps a neck input straight to per-token scores

In [106]:
# Inspect the token ids that 'platinum' is split into.
tokenizer('platinum')

{'input_ids': [1, 549, 28250], 'attention_mask': [1, 1, 1]}

In [14]:
# One-hot vocabulary indicators for the last two token ids of 'platinum'.
platinum_ids = tokenizer('platinum').input_ids  # tokenize once, reuse for both pieces
vocab_size = E.shape[0]  # E.T @ t1 below already requires t1 to have E.shape[0] rows
t1 = np.zeros((vocab_size, 1))
t1[platinum_ids[-2]] = 1
t2 = np.zeros((vocab_size, 1))
t2[platinum_ids[-1]] = 1
# Embedding vector of the second-to-last piece, as a column.
v1 = E.T @ t1

In [15]:
# Inference only: turn autograd off globally for the rest of the session.
# NOTE(review): `torch` is never imported explicitly in this notebook;
# presumably it arrives through one of the star imports -- confirm.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb3986122f0>

In [103]:
# Show how 'platinum' is split into token strings.
print_tokens('platinum')

<s>|▁pl|atinum


In [16]:
# Tokenize the prompt and move the tensors to the GPU (needs a CUDA device).
# NOTE(review): `input` shadows the Python builtin; later cells unpack it as
# **input, so renaming it means updating those cells too.
input = tokenizer('My favorite element of the periodic table is', return_tensors='pt').to('cuda')

In [21]:
# Reset the model to its original layer stack: drop the current (possibly
# patched) layers, free their memory, then install a fresh copy of the saved
# originals so orig_layers itself stays pristine.
del model.base_model.model.layers
gc.collect()
torch.cuda.empty_cache()
model.base_model.model.layers = copy.deepcopy(orig_layers)
gc.collect()
torch.cuda.empty_cache()

In [17]:
# Baseline forward pass; output_hidden_states=True keeps the per-layer hidden
# states that later cells read via out.hidden_states.
out = model.base_model(**input, output_hidden_states=True)

In [18]:
# Top next-token logits at the last prompt position (batch item 0).
topk(out.logits[0,-1,:].cpu().numpy())

Unnamed: 0,token,logit
0,▁the,10.955291
1,▁arg,9.615419
2,▁carbon,9.581732
3,▁gold,9.542505
4,▁b,9.452555
5,▁hydro,9.40939
6,▁silver,9.310298
7,▁probably,9.113802
8,▁flu,9.051174
9,▁mer,9.013393


In [19]:
# Hidden state at index 31 for the last prompt position, as a column vector.
h = out.hidden_states[31][0,-1,:].detach().cpu().numpy().reshape((-1, 1))
#topk(A @ np.concatenate([h, v1]))
# Top tokens under the linear map A applied to h alone (the variant above that
# also feeds v1 is disabled).
topk(A @ h)

Unnamed: 0,token,logit
0,on,8.596912
1,ium,8.572868
2,▁element,8.255547
3,icon,8.056475
4,gen,7.965827
5,um,7.789818
6,od,7.607457
7,▁the,7.550673
8,an,7.540741
9,cury,7.414806


In [20]:
#target = 800 * t2 - A @ np.concatenate([h, v1], axis=0)  # output logits have norm around ~800
# Residual that a perturbation v must produce through A so that A @ (h + v)
# becomes 800 * t2, i.e. all mass on the last 'platinum' piece. The factor 800
# follows the note above about the typical logit norm.
target = 800 * t2 - A @ h

In [31]:
# Sweep the norm budget over 1, 3, ..., 39 and solve the norm-constrained
# least-squares problem for each; the printed norm shows how closely the
# solution meets the bound.
# NOTE: `eps` here is the norm bound passed as constr_lsqr's `c`, not the
# tolerance of invert().
# Only the first 4096 columns of A are used -- presumably the part acting on
# the hidden state; TODO confirm.
v2_dict = {}
for eps in range(1, 40, 2):
    v2_dict[eps] = constr_lsqr(target, A[:,:4096], eps)
    print(eps, np.linalg.norm(v2_dict[eps]))

1 1.000094168294038
3 3.0000024176526274
5 4.9999936743781515
7 6.999983563878301
9 9.000068985477538
11 10.999995821994638
13 13.000050517633445
15 15.000021930273176
17 17.000056975894278
19 19.00001069258778
21 20.999901042381783
23 22.999962233140653
25 25.00004614683991
27 27.000099062531113
29 29.000081017694367
31 30.999967419911545
33 33.00006762607578
35 34.99993622984038
37 37.000032519153024
39 38.99994214636536


In [32]:
# Effect of each perturbation as predicted by the linear map A alone (no model
# forward pass): top tokens of A @ (h + v) for every norm budget.
for eps in v2_dict:
    print(eps)
    print(topk(A @ (h + v2_dict[eps])))

1
      token     logit
0       ium  8.912441
1      icon  8.591976
2  ▁element  8.486414
3       gen  8.478638
4      cury  8.344257
5        on  8.338239
6       rom  7.823799
7       ith  7.773242
8        od  7.733794
9        um  7.712451
3
      token     logit
0      cury  9.747360
1      icon  9.031064
2  ▁element  8.890864
3      obal  8.889878
4       ium  8.809325
5       gen  8.791047
6      osph  8.410558
7    atinum  8.271071
8     xygen  8.206147
9     rogen  8.154701
5
      token     logit
0      cury  9.410071
1    atinum  8.897562
2  ▁element  8.372938
3      obal  7.942448
4     xygen  7.929935
5   ▁atomic  7.664288
6     ▁atom  7.656118
7      icon  7.591521
8      ▁bor  7.267469
9   ▁silver  7.192181
7
      token     logit
0    atinum  9.534046
1      cury  8.490225
2  ▁element  7.115472
3   ▁silver  6.783089
4     xygen  6.742288
5      obal  6.721434
6   ▁atomic  6.512379
7     ▁atom  6.277526
8      icon  6.204196
9     ▁gold  6.107668
9
      token      logit

In [28]:
# Splice a placeholder Intervene module into the layer list at position 31; the
# module previously at 31 moves to 32.
# NOTE: not idempotent -- every re-run inserts one more module. Restore the
# layers from orig_layers before running this again.
model.base_model.model.layers.insert(31, Intervene(None)) # dummy, to be replaced

In [33]:
# Apply each perturbation inside the real model and read off the next-token
# logits: add v to the index-31 hidden state at the last position, install it
# via an Intervene module at layers[31], and rerun the forward pass.
# NOTE: this assumes the placeholder Intervene was inserted at index 31 first;
# otherwise the assignment below replaces whatever module sits at layers[31].
for eps in v2_dict:
    print(eps)
    new_state = copy.deepcopy(out.hidden_states[31])#[:,:-1,:] # Don't overwrite last token
    new_state[0,-1,:] += torch.Tensor(v2_dict[eps].flatten()).to('cuda')
    model.base_model.model.layers[31] = Intervene(new_state)
    new_out = model.base_model(**input)
    print(topk(new_out.logits[0,-1,:].cpu().numpy()))

1
       token      logit
0       ▁the  10.734898
1       ▁arg   9.385143
2    ▁carbon   9.366797
3      ▁gold   9.301023
4    ▁silver   9.223804
5     ▁hydro   9.166210
6         ▁b   9.157669
7  ▁probably   9.082529
8      ▁iron   8.965122
9       ▁flu   8.943411
3
       token      logit
0       ▁the  10.274455
1    ▁silver   8.973610
2  ▁probably   8.970881
3      ▁iron   8.847486
4    ▁carbon   8.811537
5       ▁arg   8.788260
6      ▁gold   8.664562
7       ▁flu   8.638989
8      ▁lead   8.539852
9     ▁hydro   8.539461
5
         token     logit
0         ▁the  9.782408
1      ▁silver  8.819941
2    ▁probably  8.761072
3        ▁iron  8.586790
4        ▁gall  8.411600
5        ▁lead  8.225731
6         ▁cad  8.205074
7      ▁carbon  8.153930
8         ▁flu  8.147808
9  ▁definitely  8.090841
7
         token     logit
0         ▁the  9.435711
1      ▁silver  8.943274
2    ▁probably  8.581954
3        ▁gall  8.356562
4        ▁iron  8.279730
5         ▁cad  8.208963
6  ▁definitely

In [24]:

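# To undo the intervention, uncomment the next line; it removes the module at index 31 of the layer list:
#model.base_model.model.layers.pop(31)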

In [26]:
# Re-display the baseline top-k for comparison.
# NOTE(review): the saved output also contains a second table that this line
# alone would not produce; presumably left over from an earlier version of the
# cell.
topk(out.logits[0,-1,:].cpu().numpy())

Unnamed: 0,token,logit
0,▁the,10.955291
1,▁arg,9.615419
2,▁carbon,9.581732
3,▁gold,9.542505
4,▁b,9.452555
5,▁hydro,9.40939
6,▁silver,9.310298
7,▁probably,9.113802
8,▁flu,9.051174
9,▁mer,9.013393


Unnamed: 0,token,logit
0,atinum,16.078453
1,▁silver,7.7636
2,▁pl,7.306924
3,▁sil,6.999122
4,▁Pl,6.256324
5,▁the,6.25101
6,▁Silver,6.151449
7,▁tit,6.095162
8,▁probably,5.724842
9,▁ces,5.584713


In [65]:
# One row per entry of out.hidden_states (indices 0..32), one column per
# sequence position (0..8): the hidden-state norm truncated to an integer.
# NOTE: 33 and 9 are hard-coded for this model and this prompt.
for i in range(33):
    print(' '.join(
        str(int(out.hidden_states[i][0,j,:].norm().item())) for j in range(9)
        )
    )

0 0 0 0 0 0 0 0 0
6 0 0 0 0 0 0 0 0
265 0 0 0 0 0 0 0 0
265 0 0 0 0 0 0 0 0
265 1 1 0 0 0 0 1 0
265 1 1 1 1 1 1 1 1
265 1 1 1 1 1 1 1 1
265 2 2 1 1 1 2 2 1
265 2 2 2 2 2 2 2 2
265 4 2 2 2 2 2 2 2
264 4 3 2 2 2 2 2 2
265 5 3 3 3 3 3 3 3
264 5 3 3 3 3 3 3 3
264 5 4 3 3 3 4 4 3
264 5 4 4 4 4 4 4 4
264 6 5 5 5 5 5 5 5
263 6 6 5 5 5 5 6 5
263 8 7 6 6 6 6 7 6
264 10 9 8 8 7 7 8 7
264 12 10 9 9 9 8 10 8
267 16 13 11 11 11 9 11 10
267 19 15 12 13 12 11 13 12
267 20 17 14 14 14 12 15 14
268 22 19 15 16 15 13 16 15
268 24 21 17 17 17 15 18 17
268 24 23 18 19 19 17 19 18
268 25 25 20 20 22 18 20 20
268 26 26 22 21 24 20 22 21
268 27 28 23 23 26 23 24 24
267 28 30 25 26 28 26 26 27
268 32 35 29 31 33 31 30 32
237 38 40 33 36 38 35 35 36
345 398 459 372 395 396 324 399 388


In [61]:
# Inspect the index-31 hidden state.
# NOTE(review): the saved output is a torch.Size, which suggests this cell
# ended in `.shape` when it was last run -- confirm.
out.hidden_states[31]

torch.Size([1, 9, 4096])

In [51]:
# Top next-token logits at the last position.
# NOTE(review): the saved output lists US states, which does not match the
# periodic-table prompt used above; presumably from an earlier run with a
# different prompt.
topk(out.logits.detach().cpu().numpy()[0,-1,:])

Unnamed: 0,token,logit
0,▁Michigan,13.245934
1,▁Wisconsin,13.089794
2,▁Minnesota,12.887695
3,▁Indiana,12.72505
4,▁Iowa,12.543724
5,▁Missouri,12.224394
6,▁Ohio,12.20534
7,▁Illinois,12.15834
8,▁South,11.714095
9,▁definitely,11.326981


In [15]:
# Inspect the bound `insert` method of the layer list; its repr also prints the
# current layer stack. (The expression previously ended in a dangling '.',
# which is a syntax error; the saved output shows the attribute was `insert`.)
model.base_model.model.layers.insert

<bound method ModuleList.insert of ModuleList(
  (0-31): 32 x MistralDecoderLayer(
    (self_attn): MistralAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)>

In [63]:
# Show the tokenization of a second example prompt.
print_tokens('My favorite state in the Midwest is Michigan')

<s>|▁My|▁favorite|▁state|▁in|▁the|▁Mid|west|▁is|▁Michigan


In [53]:
# Look up the token string for id 1.
Token[1]

'<s>'

In [42]:
# Small random problem for sanity-checking constr_lsqr.
# NOTE(review): this rebinds the global `A` that earlier cells use for
# D @ neck; those cells give different results if re-run after this one.
# NOTE(review): no random seed is set, so the saved outputs below are not
# reproducible.
A = np.random.normal(0, 1, (100, 10))
y = np.random.normal(0, 1, (100, 1))

In [43]:
# Solve the random problem with norm bound 0.1.
# NOTE(review): the saved output is a bare number (238.84...), presumably a
# debug print of lam from an earlier version of constr_lsqr.
w = constr_lsqr(y, A, 0.1)

238.84033203125


In [44]:
# Check that the solution's norm matches the requested bound of 0.1.
np.linalg.norm(w)

0.1000000058535913

In [48]:
# Cross-check: ridge closed form with the multiplier hard-coded to 238.84033
# (the value shown in the saved output of the constr_lsqr call); should match w.
np.linalg.inv(A.T @ A + 238.84033 * np.eye(10)) @ A.T @ y

array([[ 0.00752921],
       [ 0.02647566],
       [-0.0483614 ],
       [-0.02094679],
       [ 0.05765701],
       [-0.01300482],
       [-0.05031044],
       [-0.01093733],
       [ 0.00570315],
       [ 0.01697087]])

In [46]:
# Display the constr_lsqr solution for comparison with the closed form.
w

array([[ 0.00752921],
       [ 0.02647566],
       [-0.0483614 ],
       [-0.02094679],
       [ 0.05765701],
       [-0.01300482],
       [-0.05031044],
       [-0.01093733],
       [ 0.00570315],
       [ 0.01697087]])