In [1]:
import os
os.chdir('/workspace/FutureGPT2/src/')
from evals.utils import *
from models.bigram_model import *
from models.mlp_model import *
from models.future_model import *
from data.utils import get_tokenizer
import datasets
from torch.utils.data import DataLoader
from torch import nn
from itertools import islice
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from tqdm import tqdm
import pandas as pd
import gc
from glob import glob

In [2]:
# Name of the base model; it selects which pre-tokenised copy of the corpus is loaded.
MODEL = 'MISTRAL'

# Read the saved dataset from disk, then ask it to hand back torch tensors on the GPU.
corpus_dir = f'/workspace/corpus/msmarco/msmarco_{MODEL}_64tokens_1m'
dataset = datasets.load_from_disk(corpus_dir)
dataset = dataset.with_format('torch', device=torch.device('cuda'))

# Only the 'test' split is used in the rest of the notebook.
test = dataset['test']

In [3]:
# Iterate the test split in fixed-size batches. Shuffling is left at the DataLoader default
# (off), so batches follow dataset order — the ids collected later rely on that order.
BATCH_SIZE = 128
loader = DataLoader(test, batch_size=BATCH_SIZE)

In [4]:
# Checkpoint to analyse. The file name encodes the sweep run and its settings; it is split
# across adjacent string literals purely for readability (they concatenate to one path).
ckpt_path = (
    '/workspace/checkpoints/'
    'MISTRAL-NECK-SWEEP_20240102-191556-Ec6c4'
    '_hidden_idxs-31_hidden_lb-0_token_lb-0'
    '_neck_cls-lstm_epoch=00-val_self_loss=3.81.ckpt'
)

In [5]:
# Restore the Lightning module from the checkpoint and move it to the GPU.
# strict=False tolerates state-dict keys that are absent from the checkpoint; the load
# warning printed below lists the base_model.* weights as such keys.
model = LitFutureModelWithNeck.load_from_checkpoint(ckpt_path, strict=False).to('cuda')
# This notebook only evaluates the model: leave training mode so layers such as dropout
# behave deterministically while the per-token losses are computed.
model.eval()
# don't reduce loss: keep one loss value per element instead of a single averaged scalar
model.loss_func = nn.CrossEntropyLoss(reduction='none')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

/home/wwu/.local/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:173: Found keys that are in the model state dict but not in the checkpoint: ['base_model.model.embed_tokens.weight', 'base_model.model.layers.0.self_attn.q_proj.weight', 'base_model.model.layers.0.self_attn.k_proj.weight', 'base_model.model.layers.0.self_attn.v_proj.weight', 'base_model.model.layers.0.self_attn.o_proj.weight', 'base_model.model.layers.0.mlp.gate_proj.weight', 'base_model.model.layers.0.mlp.up_proj.weight', 'base_model.model.layers.0.mlp.down_proj.weight', 'base_model.model.layers.0.input_layernorm.weight', 'base_model.model.layers.0.post_attention_layernorm.weight', 'base_model.model.layers.1.self_attn.q_proj.weight', 'base_model.model.layers.1.self_attn.k_proj.weight', 'base_model.model.layers.1.self_attn.v_proj.weight', 'base_model.model.layers.1.self_attn.o_proj.weight', 'base_model.model.layers.1.mlp.gate_proj.weight', 'base_model.model.layers.1.mlp.up_proj.weight', 'base_model.model.lay

In [6]:
# Ask PyTorch's caching allocator to release GPU memory it holds but is not currently using.
torch.cuda.empty_cache()

In [7]:
# Compute the un-reduced loss for every test sequence and remember which row id each
# line of the resulting matrix belongs to.
LOSSES_PER_SEQ = 63  # (seq_length-1)=63
losses = []
ids = []
# Inference only: without no_grad every forward pass records an autograd graph that is
# never used, which costs GPU memory and time.
with torch.no_grad():
    for batch in tqdm(loader):
        loss = model._compute_loss(batch)
        # One row per sequence, one column per loss position; moved to CPU so the GPU
        # only ever holds a single batch of results.
        losses.append(loss.self_loss.reshape(-1, LOSSES_PER_SEQ).detach().cpu())
        ids.extend(batch['id'])
# Clean up once after the loop rather than forcing a Python GC pass and an allocator
# flush on every batch.
gc.collect()
torch.cuda.empty_cache()
losses = torch.concatenate(losses, axis=0)

100%|██████████| 704/704 [1:23:52<00:00,  7.15s/it]


In [8]:
# Values and flat positions of the 10 largest losses (largest first).
topk_val, flat_positions = torch.topk(losses.flatten(), 10)
# For the 10 smallest losses instead: torch.topk((-losses).flatten(), 10)
# (the returned values are then the negated losses).

# Convert each flat position back into a (row of `losses`, column of `losses`) pair and
# stack the pairs so that topk_ind[i] == [row, column] for the i-th selected loss.
row_idx, col_idx = np.unravel_index(flat_positions.numpy(), losses.shape)
topk_ind = np.column_stack((row_idx, col_idx))

In [9]:
def get_row(data, id):
    """Fetch the text and token fields of the row whose 'id' equals `id`.

    `data` must support column access by name (`data['id']`, `data['text']`, ...) and the
    'id' column must offer `.index(value)`. The position of the first matching id is used
    to pick the same position from the 'text', 'input_ids' and 'attention_mask' columns.

    Returns a dict with exactly those three keys. Whatever `.index` raises for an unknown
    id (ValueError for a list) propagates to the caller.
    """
    position = data['id'].index(id)
    row = {}
    for column in ('text', 'input_ids', 'attention_mask'):
        row[column] = data[column][position]
    return row

In [10]:
# Hub identifier of the base model; its tokenizer is needed to turn token ids into text.
model_name = 'mistralai/Mistral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reverse vocabulary: get_vocab() maps token string -> id, Token maps id -> token string.
vocab = tokenizer.get_vocab()
Token = dict(zip(vocab.values(), vocab.keys()))

def topk(v, k=40, vocab=None):
    """Return the `k` highest-scoring tokens as a two-column DataFrame.

    Parameters
    ----------
    v : numpy.ndarray
        Raw logits (no softmax is applied). The array is flattened first, and each
        flattened position is used as the token id for the vocabulary lookup.
    k : int, default 40
        Number of entries to return. `k == 0` yields an empty frame; a `k` larger than
        the number of logits returns all of them.
    vocab : mapping, optional
        Token id -> token string. Defaults to the module-level `Token` mapping.

    Returns
    -------
    pandas.DataFrame
        Columns 'token' and 'logit', ordered from highest to lowest logit.

    Raises
    ------
    ValueError
        If `k` is negative.
    """
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k}')
    id_to_token = Token if vocab is None else vocab
    v = v.flatten()
    # Reverse to descending order *before* truncating: the previous `argsort()[-k:]`
    # returned every index for k == 0, because `[-0:]` is the whole array.
    idxs = v.argsort()[::-1][:k]
    ret = [(id_to_token[i], v[i]) for i in idxs]
    return pd.DataFrame(ret, columns=['token', 'logit'])

In [11]:
# For each selected (row, position) pair: re-run the model on the prefix of that row and
# print the base model's logits next to the future head's logits for the same target.
# Inspection only, so no autograd graph is needed.
with torch.no_grad():
    for (row_i, pos), val in zip(topk_ind, topk_val):
        row = get_row(test, ids[row_i])
        seq_len = len(row['input_ids'])
        input_ids = row['input_ids'][:pos + 2] # loss at seq idx n corresponds to forward pass at idx n+1
        input_ids = input_ids.unsqueeze(0) # add batch dim
        out = model({
            'input_ids': input_ids.to('cuda'),
            # every prefix token is real (no padding); build the mask directly on the GPU
            'attention_mask': torch.ones(input_ids.shape, device='cuda'),
        })
        base = out.logits[0, pos + 1, :]
        future = out.future_logits[0, pos, :]
        out_str = '|'.join(Token[i] for i in input_ids.cpu().flatten().numpy())
        # Show the token that actually follows the prefix, in parentheses, when the row
        # is long enough to have one (bound taken from the row instead of a literal 64).
        if pos + 2 < seq_len:
            out_str += '(' + Token[row['input_ids'][pos + 2].item()] + ')'
        print(out_str)
        print('BASE vs FUTURE:')
        print(pd.concat([
            topk(base.cpu().numpy(), k=10),
            topk(future.cpu().detach().numpy(), k=10)
        ], axis=1))
        # Second number recomputes the loss for this single position as the cross-entropy
        # of the future logits against the base model's softmax distribution.
        print('LOSS:', val.item(), nn.CrossEntropyLoss()(future, torch.softmax(base, dim=0)).item())


<s>|▁In|▁June|▁|2|0|1|2|,|▁Den|ny|'|s|▁opened|▁a|▁location|▁in|▁the|▁Las|▁Am(Ã)
BASE vs FUTURE:
   token      logit token     logit
0   éric  20.796623    ag  9.692738
1      Ã  12.286123    as  9.523293
2      é  11.724434    es  9.291723
3    ist  11.619908    os  8.810611
4  érica  11.352891    ad  8.638937
5     ig  11.206367     e  8.476806
6    ric  10.407958   pes  8.458270
7     el  10.261961    ac  8.347697
8    éri  10.244709     a  8.338440
9     ér   9.979428   ena  8.296510
LOSS: 18.58062171936035 18.580528259277344
<s>|▁|1|1|▁|3|▁|Ã|·|▁|1|1|▁|3|▁=|▁|1|,|3|3|1|▁True|▁False|.|▁We|egy|:|▁False|▁User|:|▁Sim|pl|ify|.|▁y|▁|5|▁|Â|·|▁y|▁|3|▁|Ã|·|▁y|▁|2|▁We(egy)
BASE vs FUTURE:
   token      logit  token     logit
0    egy  16.067135      '  8.341678
1    edy   9.887799   ▁can  8.096248
2   ▁can   9.612230     ▁=  8.027686
3      e   8.764141     ▁x  7.747435
4  ▁have   8.721185      â  7.640388
5  ▁need   8.693643  ▁have  7.571069
6  ▁will   7.806245   ▁are  7.564444
7  ▁know   7