In [21]:
import os
import transformer_lens
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformer_lens import utils

import accelerate
import bitsandbytes
import torch
import plotly
import plotly.express as px
import einops
import numpy as np
import psutil
import pandas as pd
import random
import json
import tqdm
from torch.utils.data import DataLoader
from datasets import Dataset


In [5]:
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x1df18a2fef0>

# Load Chat Response

In [6]:
data_file = 'D:/Code/entity_tracking_update/chat_response/entity_tracking_3e_2o_1u_prompt_config_3_Llama-2-7b-chat-hf.jsonl'

with open(data_file, encoding="utf-8") as f:
    data = [json.loads(line) for line in f]


In [7]:
responses = []
for i in range(len(data)):
    response = data[i]["response"]
    responses.append(response)

In [8]:
name1 = data_file.split('.')[0]

name2 = name1.split('/')[-1]
name2

'entity_tracking_3e_2o_1u_prompt_config_3_Llama-2-7b-chat-hf'

# Load Model

In [9]:
# model_name ="meta-llama/Llama-2-7b-chat-hf"
model_name_base ="stanford-crfm/alias-gpt2-small-x21"
model_name_tuned = "stanford-crfm/alias-gpt2-small-x21"
hf_token = "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"


In [10]:
tokenizer = AutoTokenizer.from_pretrained(model_name_tuned, token=hf_token)


## Get Length of the prompt and response

In [11]:
prmopt_file =  'D:/Code/entity_tracking_update/data/entity_tracking_3e_2o_1u_prompt_config_3.jsonl'
with open(prmopt_file, encoding="utf-8") as f:
    prompts_str = [json.loads(line) for line in f]

In [12]:

for i in range(len(prompts_str)):
    prompt = prompts_str[i]["prompt"]
        
    input_tokens = tokenizer(prompt, return_tensors="pt")
    tokens = torch.flatten(input_tokens['input_ids'])

    print(len(tokens))

prompt_length = len(tokens)

95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95
95


In [13]:
len_response_token = []
for i in range(len(responses)):
    response = responses[i]
        
    response_token = tokenizer(response, return_tensors="pt")
    tokens = torch.flatten(response_token['input_ids'])

    print(len(tokens))
    len_response_token = np.append(len_response_token,len(tokens))
# len_prompt = len(tokens)

189
140
181
180
193
178
142
189
180
195
180
181
142
164
195
141
171
191
193
178
141
172
168
178
194
195
195
195
194
168
195
168
178
138
168
139
185
193
136
191
163
180
158
158
176
173
164
174
176
163
177
178
168
167
172
166
138
136
184
175
183
193
139
138
178
165
170
168
195
175
178
137
168
176
164
178
179
193
168
193
161
137
169
193
184
141
167
164
194
170
190
170
182
139
170
138
193
174
176
174


In [14]:
min_response_len = int(min(len_response_token))
min_response_len

136

# Logits of Last Layer

In [15]:
class ModelHelper:
    def __init__(self,model_name, token, device=None, load_in_8bit=False):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

        hf_model = AutoModelForCausalLM.from_pretrained(model_name, token=token,
                                                          device_map='auto')
        self.model = HookedTransformer.from_pretrained(model_name,
                                             hf_model=hf_model,
                                             fold_ln=False,
                                             fold_value_biases=False,
                                             center_writing_weights=False,
                                             center_unembed=False,
                                             tokenizer=self.tokenizer,
                                             device = self.device)

        # print(self.model)
        # print(self.model.cfg.n_layers)
        self.device = next(self.model.parameters()).device
        self.d_vocab = self.model.cfg.d_vocab
        self.n_layers = self.model.cfg.n_layers
        
    def logits_all_layers(self, text):
        inputs = self.tokenizer(text,return_tensors="pt")
        seq_len = inputs["input_ids"].shape[1]

        # Get residual output for each layer
        z_name_filter = lambda name: name.endswith("resid_post")
        self.model.reset_hooks()
        _,cache = self.model.run_with_cache(
        inputs["input_ids"],
        names_filter = z_name_filter,
        return_type = None
        )
        
        layer_logit_all = torch.zeros(self.n_layers,seq_len,self.d_vocab)
        for layer in range(self.model.cfg.n_layers):
            resid_ln = self.model.ln_final(cache[f'blocks.{layer}.hook_resid_post'])
            layer_logit = self.model.unembed(resid_ln)
            layer_logit_all[layer,:,:] = layer_logit
        return layer_logit_all
                    
    def logits_last_layers(self, text):
        # logits=**Unembed**(**LayerNorm**(**final_residual_stream**))
        # from text to tokens
        inputs = self.tokenizer(text,return_tensors="pt")
        seq_len = inputs["input_ids"].shape[1]

         
        # Get residual output for the last layer
        # get name filter for the residial of last layer 
        resid_post_hook_name = utils.get_act_name("resid_post", self.model.cfg.n_layers - 1)
        resid_post_name_filter = lambda name: name == resid_post_hook_name
        # run with hook
        self.model.reset_hooks()
        _,cache = self.model.run_with_cache(
        inputs["input_ids"],
        names_filter = resid_post_name_filter,
        return_type = None
        )
        # layer norm
        resid_ln = self.model.ln_final(cache[f'blocks.{self.model.cfg.n_layers - 1}.hook_resid_post'])
        # unembed
        layer_logit = self.model.unembed(resid_ln)
        return layer_logit           

In [16]:

model_base = ModelHelper(model_name_base,hf_token, load_in_8bit=False)



Loaded pretrained model stanford-crfm/alias-gpt2-small-x21 into HookedTransformer


In [17]:
model_tuned = ModelHelper(model_name_tuned,hf_token, load_in_8bit=False)

Loaded pretrained model stanford-crfm/alias-gpt2-small-x21 into HookedTransformer


# 1. Get Probability Distribution

In [74]:
logits_last_layer_base = model_base.logits_last_layers(responses[3])
logits_last_layer_tuned = model_tuned.logits_last_layers(responses[3])


prob_base =  logits_last_layer_base.softmax(dim=-1)
prob_tuned = logits_last_layer_tuned.softmax(dim=-1)

prob_base =torch.squeeze(prob_base)
prob_tuned =torch.squeeze(prob_tuned)

In [98]:
logits_last_layer_base.shape

torch.Size([1, 180, 50257])

In [137]:
prob_base.shape

torch.Size([180, 50257])

# 2. Calculate KL Divergence

In [160]:

def KL(P,Q):
     """ Epsilon is used here to avoid conditional code for
     checking that neither P nor Q is equal to 0. """
     epsilon = 0.00001

     # You may want to instead make copies to avoid changing the np arrays.
     P = P+epsilon
     Q = Q+epsilon

     divergence = np.sum(P*np.log(P/Q))
     return divergence

In [194]:
kl_all = []
for ii in range(len(prob_base)):
    kl = KL(prob_base[ii,:].to('cpu').numpy(),prob_tuned[ii,:].to('cpu').numpy())
    kl_all = np.append(kl_all,kl)
    

In [195]:
kl_all.shape

(180,)

In [196]:
kl_all


array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# 3. Prob and Rank of Anaswer Token

In [197]:
input_tokens['input_ids'].shape

torch.Size([1, 180])

In [198]:

input_tokens = tokenizer(responses[3], return_tensors="pt")

tokens = torch.flatten(input_tokens['input_ids'])
tokens = tokens.to(prob_base.device)

In [199]:
prob_answer_tokens_base = prob_base[:, :-1].gather(dim=-1, index=tokens[1:].unsqueeze(-1)).squeeze(-1)
prob_answer_tokens_tuned = prob_tuned[:, :-1].gather(dim=-1, index=tokens[1:].unsqueeze(-1)).squeeze(-1)


In [200]:
prob_answer_tokens_base.shape

torch.Size([179])

In [201]:
prob_base.shape

torch.Size([180, 50257])

In [202]:
answer_ranks = []
for index in range(1,len(tokens)):
    answer_token = tokens[index]
    token_prob = prob_base[index-1]
    sorted_token_probs, sorted_token_values = token_prob.sort(descending=True)
    correct_rank = torch.arange(len(sorted_token_values))[
    (sorted_token_values == answer_token).cpu()
    ].item()
    answer_ranks.append( correct_rank)

In [203]:
len(answer_ranks)

179

In [205]:
len(answer_ranks)

179

# Main loop

In [27]:
def compre_two_model_last_layer(response):

    # 1. get last layer logit
    logits_last_layer_base = model_base.logits_last_layers(response)
    logits_last_layer_tuned = model_tuned.logits_last_layers(response)
    
    # From logit to softmax prbability distribution
    prob_base =  logits_last_layer_base.softmax(dim=-1)
    prob_tuned = logits_last_layer_tuned.softmax(dim=-1)
    
    prob_base =torch.squeeze(prob_base)
    prob_tuned =torch.squeeze(prob_tuned)

    # 2. get the KL divergence
    def KL(P,Q):
     """ Epsilon is used here to avoid conditional code for
     checking that neither P nor Q is equal to 0. """
     epsilon = 0.00001
     P = P+epsilon
     Q = Q+epsilon

     divergence = np.sum(P*np.log(P/Q))
     return divergence

    # For each position in the sequence
    kl_all = []
    for ii in range(len(prob_base)):
        kl = KL(prob_base[ii,:].to('cpu').numpy(),prob_tuned[ii,:].to('cpu').numpy())
        kl_all = np.append(kl_all,kl)



    # 3. Get probability of the answer tokens
    
    input_tokens = tokenizer(response, return_tensors="pt")
    tokens = torch.flatten(input_tokens['input_ids'])
    tokens = tokens.to(prob_base.device)

    prob_answer_tokens_base = prob_base[:, :-1].gather(dim=-1, index=tokens[1:].unsqueeze(-1)).squeeze(-1)
    prob_answer_tokens_tuned = prob_tuned[:, :-1].gather(dim=-1, index=tokens[1:].unsqueeze(-1)).squeeze(-1)

    prob_answer_tokens_base = prob_answer_tokens_base.to("cpu").numpy()
    prob_answer_tokens_tuned = prob_answer_tokens_tuned.to("cpu").numpy()

    # 4. Get the rank of the answer token
    def answer_tokens_ranks(tokens,probs):
        answer_ranks = []
        for index in range(1,len(tokens)):
            answer_token = tokens[index]
            token_prob = probs[index-1]
            sorted_token_probs, sorted_token_values = token_prob.sort(descending=True)
            correct_rank = torch.arange(len(sorted_token_values))[
            (sorted_token_values == answer_token).cpu()
            ].item()
            answer_ranks.append(correct_rank)
        return answer_ranks

    answer_ranks_base =  answer_tokens_ranks(tokens,prob_base)
    answer_ranks_tuned =  answer_tokens_ranks(tokens,prob_tuned)

    return kl_all[prompt_length:], prob_answer_tokens_base[prompt_length:],prob_answer_tokens_tuned[prompt_length:], answer_ranks_base[prompt_length:],answer_ranks_tuned[prompt_length:]

In [19]:
response

"Given the description after 'Description:', write a true statement about all boxes and their contents after 'Statement:'. Make sure to keep track of the changes and update the contents of the boxes according to the changes.\n\nDescription: Box A contains the letter. Box B contains the phone. Box C contains nothing. Leo moves the phone from Box B to Box C. Box A has no change in its content.\n\nStatement: Let's think step by step.\n\n1. Box A contains the letter.\n2. Leo moves the phone from Box B to Box C.\n3. Box B has no phone now.\n4. Box C contains the phone now.\n\nWhat is the content of each box after the above steps?\n\nPlease answer with a true statement for each box and its contents after the given steps.</s>"

In [22]:
compare_result = []
for response in responses:
    kl, prob_answer_tokens_base,prob_answer_tokens_tuned, rank_answer_tokens_base,rank_answer_tokens_tuned = compre_two_model_last_layer(response)
    entry = {}
    entry["kl"] = kl.tolist()
    entry["prob_answer_tokens_base"] = prob_answer_tokens_base.tolist()
    entry["prob_answer_tokens_tuned"] = prob_answer_tokens_tuned.tolist()
    entry["rank_answer_tokens_base"] = rank_answer_tokens_base
    entry["rank_answer_tokens_tuned"] = rank_answer_tokens_tuned

    compare_result = np.append(compare_result,entry)

In [25]:
prob_answer_tokens_base

array([9.9994969e-01, 5.3442858e-02, 7.1736890e-01, 1.5275199e-02,
       8.0417198e-01, 3.9830172e-01, 5.6373912e-01, 6.9525409e-01,
       6.4983737e-01, 6.0942136e-02, 2.3240160e-08, 9.9929285e-01,
       1.2664686e-04, 6.1277401e-01, 9.2264408e-01, 9.3880737e-01,
       9.6634322e-01, 9.8813498e-01, 5.4413527e-01, 9.9347335e-01,
       9.9314958e-01, 9.3289989e-01, 9.5410103e-01, 2.6599032e-01,
       2.0343391e-08, 9.9951375e-01, 6.1964047e-01, 2.2152728e-01,
       1.2404143e-01, 8.2407445e-01, 9.8161045e-06, 1.2965071e-04,
       8.8133997e-01, 5.4067689e-01, 2.9401271e-08, 9.9893993e-01,
       6.4284551e-01, 6.3810110e-01, 1.2558898e-01, 4.6371002e-02,
       1.7151535e-01, 8.0989644e-02, 9.2533576e-01, 7.3697096e-01,
       9.9997032e-01, 8.2334568e-04, 1.0981990e-01, 4.7037312e-01,
       1.3078442e-02, 8.7872791e-01, 1.4307107e-02, 9.2140436e-01,
       3.2061106e-03, 3.6821789e-01, 2.1464918e-03, 2.2792988e-02,
       9.3806744e-01, 5.8900595e-01, 9.9997377e-01, 1.0352468e

In [26]:
prob_answer_tokens_tuned

array([9.9994969e-01, 5.3442858e-02, 7.1736890e-01, 1.5275199e-02,
       8.0417198e-01, 3.9830172e-01, 5.6373912e-01, 6.9525409e-01,
       6.4983737e-01, 6.0942136e-02, 2.3240160e-08, 9.9929285e-01,
       1.2664686e-04, 6.1277401e-01, 9.2264408e-01, 9.3880737e-01,
       9.6634322e-01, 9.8813498e-01, 5.4413527e-01, 9.9347335e-01,
       9.9314958e-01, 9.3289989e-01, 9.5410103e-01, 2.6599032e-01,
       2.0343391e-08, 9.9951375e-01, 6.1964047e-01, 2.2152728e-01,
       1.2404143e-01, 8.2407445e-01, 9.8161045e-06, 1.2965071e-04,
       8.8133997e-01, 5.4067689e-01, 2.9401271e-08, 9.9893993e-01,
       6.4284551e-01, 6.3810110e-01, 1.2558898e-01, 4.6371002e-02,
       1.7151535e-01, 8.0989644e-02, 9.2533576e-01, 7.3697096e-01,
       9.9997032e-01, 8.2334568e-04, 1.0981990e-01, 4.7037312e-01,
       1.3078442e-02, 8.7872791e-01, 1.4307107e-02, 9.2140436e-01,
       3.2061106e-03, 3.6821789e-01, 2.1464918e-03, 2.2792988e-02,
       9.3806744e-01, 5.8900595e-01, 9.9997377e-01, 1.0352468e

In [23]:
rank_answer_tokens_base

[32,
 989,
 188,
 40,
 6037,
 3,
 7,
 231,
 2,
 385,
 5,
 8,
 14,
 350,
 2,
 3,
 0,
 22,
 1,
 76,
 0,
 0,
 44,
 0,
 1,
 14,
 1,
 0,
 0,
 119,
 5,
 8,
 0,
 8,
 0,
 0,
 15,
 18,
 0,
 0,
 2,
 0,
 0,
 0,
 200,
 0,
 17,
 13,
 3,
 0,
 62,
 45,
 0,
 0,
 0,
 0,
 51,
 1,
 0,
 0,
 0,
 91,
 0,
 9629,
 198,
 0,
 0,
 2,
 0,
 1,
 0,
 0,
 0,
 0,
 2,
 2,
 3,
 2,
 19,
 1,
 2,
 1,
 0,
 1,
 0,
 13,
 0,
 64,
 0,
 67,
 530,
 0,
 0,
 0,
 0,
 0,
 2,
 0,
 8,
 0,
 0,
 0,
 0,
 0,
 1,
 124,
 0,
 28,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 186,
 0,
 0,
 1,
 1,
 0,
 94,
 73,
 0,
 0,
 97,
 0,
 0,
 0,
 1,
 2,
 1,
 1,
 0,
 0,
 0,
 26,
 2,
 0,
 10,
 0,
 8,
 0,
 9,
 0,
 47,
 5,
 0,
 0,
 0,
 80,
 32,
 5,
 1,
 2,
 0,
 20,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 60,
 3,
 626,
 28,
 0]

In [24]:
rank_answer_tokens_tuned

[0,
 2,
 0,
 8,
 0,
 0,
 0,
 0,
 0,
 1,
 124,
 0,
 28,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 186,
 0,
 0,
 1,
 1,
 0,
 94,
 73,
 0,
 0,
 97,
 0,
 0,
 0,
 1,
 2,
 1,
 1,
 0,
 0,
 0,
 26,
 2,
 0,
 10,
 0,
 8,
 0,
 9,
 0,
 47,
 5,
 0,
 0,
 0,
 80,
 32,
 5,
 1,
 2,
 0,
 20,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 60,
 3,
 626,
 28,
 0]

# Save Results

In [269]:
name1 = data_file.split('.')[0]

name2 = name1.split('/')[-1]


'entity_tracking_3e_2o_1u_prompt_config_3_Llama-2-7b-chat-hf'

In [271]:
save_path = 'D://Code/entity_tracking_update/compare_chat_tuned_models'

In [287]:
save_name = 'compare_models_' + name2
result_file =  save_path + os.sep + save_name + '.jsonl'
with open(result_file, "w") as outfile:
    for entry in compare_result:
        json.dump(entry, outfile)
        outfile.write('\n')

In [273]:
save_name

'compare_models_entity_tracking_3e_2o_1u_prompt_config_3_Llama-2-7b-chat-hf'