In [2]:
import os
os.environ['HF_HOME'] = '/workspace/cache/huggingface/'
os.chdir('/workspace/FutureGPT2/src/')

from datasets import load_dataset, load_from_disk
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from itertools import islice
from collections import defaultdict
import re
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import gc
from matplotlib import pyplot as plt
import datasets
from torch.utils.data import DataLoader
import pandas as pd

%load_ext autoreload
%autoreload 2

In [3]:
model_name = 'EleutherAI/pythia-2.8b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
Token = {v: k for k, v in tokenizer.get_vocab().items()}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
vanilla_state = torch.load(
    '/workspace/checkpoints/checkpoints_PYTHIA_UNSHARD/PYTHIA-PILE10M64-VANILLA-fp16_model_name_pythia-2.8b_lr_6.40e-05_warmup_5.00e-02_global_step=1627.0_train_loss=2.44.ckpt'
)['state_dict']
myopic_state = torch.load(
    '/workspace/checkpoints/checkpoints_PYTHIA_UNSHARD/PYTHIA-PILE10M64-MYOPIC-fp16_model_name_pythia-2.8b_lr_6.40e-05_warmup_5.00e-02_global_step=1627.0_train_loss=2.67.ckpt'
)['state_dict']
vanilla = AutoModelForCausalLM.from_pretrained(model_name).to('cuda')
myopic = AutoModelForCausalLM.from_pretrained(model_name).to('cuda')
vanilla.load_state_dict({'.'.join(k.split('.')[1:]): v for k, v in vanilla_state.items()})
myopic.load_state_dict({'.'.join(k.split('.')[1:]): v for k, v in myopic_state.items()})

<All keys matched successfully>

In [5]:
torch.set_grad_enabled(False)
gc.collect()
torch.cuda.empty_cache()

In [6]:
def topk(v, k=10):
    # Takes in logits
    #v = softmax(v.flatten())
    if isinstance(v, torch.Tensor):
        v = v.cpu().numpy()
    v = v.flatten()
    idxs = v.argsort()[-k:][::-1]
    ret = [(Token[i], v[i]) for i in idxs]
    return pd.DataFrame(ret, columns=['token', 'logit'])

In [46]:
# test = load_from_disk('/workspace/corpus/the_pile/pile_PYTHIA_64tokens_20M')['test']
# test = test.cast_column('input_ids', datasets.Sequence(datasets.Value('int64')))
# test = test.with_format('torch', device='cuda')
# test_loader = DataLoader(test, batch_size=64)

In [85]:
# test = load_dataset(
#     'EleutherAI/pile-deduped-pythia-random-sampled',
#     split='train[:10000]'
# )
# test = test.rename_column('Tokens', 'input_ids')
# test = test.remove_columns([c for c in test.column_names if c != 'input_ids'])
# test = test.cast_column('input_ids', datasets.Sequence(datasets.Value('int64')))
# test = test.with_format('torch', device='cuda')
# test_loader = DataLoader(test, batch_size=64)

In [23]:
test = load_dataset('ms_marco', 'v1.1', split='test')
test = datasets.Dataset.from_list(
    [
        tokenizer('\n'.join(x['passage_text']), max_length=64, truncation=True)
        for x in test['passages']
    ]
)
test = test.with_format('torch', device='cuda')
test_loader = DataLoader(test, batch_size=64)

In [25]:
kls = []
for batch in tqdm(iter(test_loader)):
    gc.collect(); torch.cuda.empty_cache()
    vanilla_out = vanilla(**batch)
    myopic_out = myopic(**batch)
    vanilla_probs = F.softmax(vanilla_out.logits, dim=-1)
    vanilla_logprobs = F.log_softmax(vanilla_out.logits, dim=-1)
    myopic_logprobs = F.log_softmax(myopic_out.logits, dim=-1)
    kls.append((vanilla_probs * (vanilla_logprobs - myopic_logprobs)).sum(dim=-1).to('cpu'))

100%|██████████| 151/151 [07:04<00:00,  2.81s/it]


In [26]:
kl = torch.concat(kls, dim=0)

In [27]:
v, i = torch.topk(kl.flatten(), 20)
idxs = np.array(np.unravel_index(i.numpy(), kl.shape)).T

In [28]:
for i, j in idxs:
    i = int(i)
    j = int(j)
    seq1 = test[i]['input_ids'][:j+1]
    seq2 = test[i]['input_ids'][j+1:]
    out = ''.join(Token[x.item()] for x in seq1) + '|' + ''.join(Token[x.item()] for x in seq2)
    out = out.replace('Ġ', ' ').replace('Ċ', '\n')
    print(out)
    print(kl[i, j].item())
    vanilla_out = vanilla(input_ids=seq1.reshape(1, -1))
    myopic_out = myopic(input_ids=seq1.reshape(1, -1))
    vanilla_probs = F.softmax(vanilla_out.logits, dim=-1)
    vanilla_logprobs = F.log_softmax(vanilla_out.logits, dim=-1)
    myopic_logprobs = F.log_softmax(myopic_out.logits, dim=-1)
    print(
        pd.concat(
            [topk(vanilla_out.logits[0,-1]), topk(myopic_out.logits[0,-1])],
            axis=1
        )
    )
    #print((vanilla_probs * (vanilla_logprobs - myopic_logprobs)).sum(dim=-1)[0,-1].item())
    print('--------------------\n')

Please try again later. Monza About this sound listen is a city and comune on the River Lambro, a tributary of the Po in the Lombardy region of Italy, about 15 kilometres (9 miles) north-northeast of Milan. It is the capital of the Province of Monza and|
14.324137687683105
    token      logit     token      logit
0  ĠBrian  18.883274      Ġthe  14.316242
1    ĠâĢ¦  12.807458       Ġis  13.319670
2    Ġ...  12.581610       Ġof  12.702269
3     ...  12.282018      Ġhas  12.426638
4    Ġthe  12.103540        Ġa  12.392349
5   Brian  11.708487         ,  11.626892
6     Ġof  11.391341      Ġone  11.527664
7     âĢ¦  11.338493      Ġits  11.258924
8      Ġa  10.932683     Ġhome  11.164782
9       ,  10.850215  Ġborders  11.111320
--------------------

Demographics. As of 2001 India census[2], Hoshiarpur had a population of 148,243. Males constitute 53% of the population and females 47%. Hoshiarpur has an average literacy rate of 78%, higher than the national average of 59.5%|: male literac

In [70]:
(vanilla_probs * (vanilla_logprobs - myopic_logprobs)).sum(dim=-1).shape

torch.Size([1, 40])

In [11]:
''.join(Token[x.item()] for x in test[10]['input_ids']).replace('Ġ', ' ')

'3/06/2015 - Construction Industry Networking in Tees ValleyĊĊConstructing Excellence in the North East is delighted to bring you the Tees Valley Network. We are extremely pleased to announce this monthâĢĻs sponsor is Clugston Construction.ĊĊThis relaxed and informal networking event will provide an opportunity for'

In [133]:
input = tokenizer('173+204=3', return_tensors='pt').to('cuda')

In [134]:
pd.concat(
    [topk(vanilla(**input).logits[0,-1,:]), topk(myopic(**input).logits[0,-1,:])],
    axis=1
)

Unnamed: 0,token,logit,token.1,logit.1
0,Ċ,11.634122,*,11.176103
1,*,11.251266,.,11.171001
2,",",11.000799,Ċ,10.971567
3,.,10.844119,e,8.93327
4,(,10.661886,",",8.471508
5,+,10.438973,Ġ+,8.431137
6,Ġ+,9.861268,+,8.300894
7,Ġ(,9.305843,:,8.27208
8,Ġ*,9.231849,-,8.137039
9,Ġx,9.191463,Ġ=,8.11885


In [83]:
topk(myopic(**input).logits[0,-1,:])

Unnamed: 0,token,logit
0,ats,13.352908
1,ik,11.546961
2,ime,11.361863
3,ana,11.217769
4,etal,11.110579
5,are,11.05926
6,ina,10.908261
7,aj,10.889086
8,if,10.771758
9,ak,10.658453


In [39]:
vanilla(**input).logits.shape

torch.Size([1, 12, 50304])