In [59]:
import torch
import pandas as pd
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm
from nnsight import LanguageModel, util
from transformers import AutoTokenizer, AutoModelForCausalLM
from nnsight.tracing.Proxy import Proxy

In [60]:

model_path = {'mistral7b': 'mistralai/Mistral-7B-Instruct-v0.1',
             'falcon7b': 'tiiuae/falcon-7b-instruct',
             'llama7b': '/work/frink/models/Llama-2-7b-chat-hf',
             'flanul2': 'google/flan-ul2'}

def load_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_path[model_name],
                                         cache_dir = '/scratch/ramprasad.sa/huggingface_models')

    
    return tokenizer

In [61]:
tokenizer = load_tokenizer('mistral7b')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})



1

In [4]:
model_name = 'mistral7b'
genaudit_read_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/Genaudit_annotations.csv'
df_genaudit = pd.read_csv(genaudit_read_path)
df_genaudit_mistral = df_genaudit[df_genaudit['model'] == model_name]
df_genaudit_mistral.head()[:1]
# df_genaudit_mistral

Unnamed: 0.1,Unnamed: 0,id,source,summary,annotated_spans,model,origin
4,4,Rachel Usher#REDDIT-83:mistral7b-ul2,both me and my girlfriend participate in winte...,The document describes the experience of a hig...,participates<sep>with their girlfriend<sep>ends,mistral7b,REDDIT


In [5]:
import json 
prompt_template = f'{{instruction}}{{prompt_prefix}}{{source}}{{prompt_suffix}}{{summary}}'

prompt_template_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/prompt_templates/'
prompt_type = 'document_context_causal'

with open(f'{prompt_template_path}/{prompt_type}.json', 'r') as fp:
    prompt_dict = json.load(fp)
prompt_dict, prompt_template


({'instruction': 'Generate a summary for the following document in brief. When creating the summary, only use information that is present in the document',
  'prompt_prefix_template': ' CONTENT: ',
  'prompt_suffix_template': ' SUMMARY: '},
 '{instruction}{prompt_prefix}{source}{prompt_suffix}{summary}')

In [6]:
import json


def strip_bos_eos_ids(ids):
    ids = ids[:, 1:] if ids[0][0] in [0,1,2] else ids
    return ids
    
class PromptProcessor():

    def __init__(self,
                 prompt_template,
                 prompt_template_path,
                prompt_type,
                tokenizer):

        self.prompt_template = prompt_template

        with open(f'{prompt_template_path}/{prompt_type}.json', 'r') as fp:
            self.prompt_dict = json.load(fp)
            
        self.instruction = self.prompt_dict['instruction'] if 'instruction' in self.prompt_dict else ''
        self.prefix = self.prompt_dict['prompt_prefix_template'] if 'prompt_prefix_template' in self.prompt_dict else ''
        self.suffix = self.prompt_dict['prompt_suffix_template'] if 'prompt_suffix_template' in self.prompt_dict else ''
        self.tokenizer = tokenizer

    def get_prompt_attributes_idx(self,
                                  prompt_ids,
                                  source,
                                  summary):
        instr_idx = -1
        instr_prefix_idx = -1
        instr_prefix_src_idx = -1
        instr_prefix_src_suffix_idx = -1
        print('attri', len(prompt_ids))
        for span_idx in range( len(prompt_ids)):
            if span_idx %100 == 0:
                print(span_idx)
            span_tokens = prompt_ids[1:span_idx]
            if self.tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                            prompt_prefix = '',
                                            source = '',
                                            prompt_suffix = '',
                                            summary = ''
                                           ).strip():
                instr_idx = span_idx
            
            if self.tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = '',
                                        prompt_suffix = '',
                                        summary = ''
                                       ).strip():
                instr_prefix_idx = span_idx
            
            if self.tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = source,
                                        prompt_suffix = '',
                                        summary = ''
                                       ).strip():
                instr_prefix_src_idx = span_idx
            
            if self.tokenizer.decode(span_tokens) == self.prompt_template.format(instruction = self.instruction,
                                        prompt_prefix = self.prefix,
                                        source = source,
                                        prompt_suffix = self.suffix,
                                        summary = ''
                                       ).strip():
                instr_prefix_src_suffix_idx = span_idx 
                
        return instr_idx, instr_prefix_idx, instr_prefix_src_idx, instr_prefix_src_suffix_idx
    
    def get_nonfactual_span_idx(self,
                                nonfactual_span,
                                summary_tokens,
                                start_idx = -100,
                                end_idx = -100):
        # print(len(summary_tokens), end_idx, nonfactual_span)
        for tok_idx, tok in enumerate(summary_tokens):
            if tok_idx > end_idx:
                tok_str = self.tokenizer.decode(tok)
                if tok_str in nonfactual_span[:len(tok_str)]:
                    start_idx = tok_idx
                
                if self.tokenizer.decode(summary_tokens[start_idx: tok_idx + 1]) == nonfactual_span:
                    end_idx = tok_idx
                    break
        
        return start_idx, end_idx

    def get_summary_labels(self,
                          summary_tokens,
                          nonfactual_spans):
        start_idx = -100
        end_idx = -100
        summary_labels = [0] * len(summary_tokens)
        nonfactual_spans = nonfactual_spans.split('<sep>')

        for i in range(len(nonfactual_spans)):
            nonfactual_span = nonfactual_spans.pop(0)
            start_idx, end_idx = self.get_nonfactual_span_idx(nonfactual_span,
                                             summary_tokens,
                                             start_idx = start_idx,
                                             end_idx = end_idx)
    
            for idx in range(start_idx, end_idx + 1):
                summary_labels[idx] = 1
        return summary_labels
        
    
    def make_prompt_token_labels(self,
                                 source,
                                 summary,
                                 nonfactual_spans):
        source = ' '.join(source.split(' ')[:50])
        prompt = self.prompt_template.format(instruction = self.instruction,
                                    prompt_prefix = self.prefix,
                                    source = source,
                                    prompt_suffix = self.suffix,
                                    summary = summary
                                   )

        
        prompt_tokens = tokenizer(prompt).input_ids
        instr_idx, instr_prefix_idx, instr_prefix_src_idx, instr_prefix_src_suffix_idx = self.get_prompt_attributes_idx(prompt_ids = prompt_tokens,
                                      source = source,
                                      summary = summary)
        
        summary_tokens = prompt_tokens[instr_prefix_src_suffix_idx:]
        
        summary_labels = self.get_summary_labels(summary_tokens,
                          nonfactual_spans)

        return_dict = {'prompt': prompt,
                      'instr_idx': instr_idx,
                      'instr_prefix_idx': instr_prefix_idx,
                      'instr_prefix_src_idx': instr_prefix_src_idx,
                      'instr_prefix_src_suffix_idx': instr_prefix_src_suffix_idx,
                      'summary_labels': summary_labels}
        return return_dict


    

    

In [7]:
prompt_processor = PromptProcessor(prompt_template = prompt_template,
                                   prompt_template_path = prompt_template_path,
                                   prompt_type = prompt_type,
                                   tokenizer = tokenizer
                                  )

In [8]:
counter = 0

for idx, row in df_genaudit_mistral[~df_genaudit_mistral['annotated_spans'].isnull()].iterrows():
    source = row['source']
    summary = row['summary']
    nonfactual_spans = row['annotated_spans']
    prompt_dict = prompt_processor.make_prompt_token_labels(source= source,
                                                            summary = summary,
                                                            nonfactual_spans = nonfactual_spans)
    if counter == 4:
        break
    counter += 1
    print(counter)

attri 231
0
100
200
1
attri 258
0
100
200
2
attri 214
0
100
200
3
attri 250
0
100
200
4
attri 165
0
100


In [9]:
prompt_dict.keys()


dict_keys(['prompt', 'instr_idx', 'instr_prefix_idx', 'instr_prefix_src_idx', 'instr_prefix_src_suffix_idx', 'summary_labels'])

In [10]:
model_path = {'mistral7b': 'mistralai/Mistral-7B-Instruct-v0.1',
             'falcon7b': 'tiiuae/falcon-7b-instruct',
             'llama7b': '/work/frink/models/Llama-2-7b-chat-hf',
             'flanul2': 'google/flan-ul2'}





def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_path[model_name],
                                         cache_dir = '/scratch/ramprasad.sa/huggingface_models')

    model = AutoModelForCausalLM.from_pretrained(model_path[model_name],
                                            cache_dir = '/scratch/ramprasad.sa/huggingface_models')
    model = model.to('cuda')
    return tokenizer, model

In [11]:
tokenizer, mistral_model = load_model(model_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = LanguageModel(mistral_model, tokenizer=tokenizer, device_map="auto", dispatch=True)
model



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [12]:
N_LAYERS = len(model.model.layers)
N_LAYERS

32

In [13]:

source = source
summary  = summary


In [14]:
prompt = prompt_dict['prompt']
instr_idx = prompt_dict['instr_idx']
instr_prefix_idx = prompt_dict['instr_prefix_idx']
instr_prefix_src_idx = prompt_dict['instr_prefix_src_idx']
instr_prefix_src_suffix_idx = prompt_dict['instr_prefix_src_suffix_idx']
instr_prefix_src_suffix_idx

97

In [15]:
# model.layers

In [57]:
def get_clean_logits(prompt,
                    model):
    with model.trace() as tracer:
        with tracer.invoke(prompt) as invoker:
            
            clean_logits = model.lm_head.output[0]
            clean_logits = F.softmax(clean_logits, dim = 1).detach().cpu().save()
            
            clean_hs = [
                model.model.layers[layer_idx].output[0].detach().cpu().save()
                for layer_idx in range(N_LAYERS)
            ]
            
            clean_input_embeddings = model.model.embed_tokens.output.detach().cpu().save()
            
    return clean_input_embeddings, clean_hs, clean_logits

def get_corrupted_logits(corrupted_prompt,
                        model):



    with model.trace() as tracer:
        with tracer.invoke(corrupted_prompt) as invoker:
            
            corrupted_logits = model.lm_head.output[0]
            corrupted_logits = F.softmax(corrupted_logits, dim = 1).detach().cpu().save()
            
            corrupted_hs = [
                model.model.layers[layer_idx].output[0].detach().cpu().save()
                for layer_idx in range(N_LAYERS)
            ]
            
            noised_embeddings = model.model.embed_tokens.output.detach().cpu().save()
    return noised_embeddings, corrupted_hs, corrupted_logits


def get_patched_logits(corrupted_prompt,
                      layer_idx,
                      token_idx):
    with model.trace() as tracer:
        with tracer.invoke(corrupted_prompt) as invoker:
            model.model.layers[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx].t[token_idx]
            
            patched_logits = model.lm_head.output[0]
            patched_logits = F.softmax(patched_logits, dim = 1).detach().cpu().save()
            patched_hs = [
                model.model.layers[layer_idx].output[0].detach().cpu().save()
                for layer_idx in range(N_LAYERS)
            ]
            
            patched_embeddings = model.model.embed_tokens.output.detach().cpu().save()
            
    return patched_embeddings, patched_hs, patched_logits


def corrupt_prompt(prompt_tokens,
                   corruption_idx):
    repl_token = tokenizer('_').input_ids[1:]
    assert(len(repl_token) == 1)
    repl_token = repl_token[0]

    corrupted_tokens = [repl_token if tok_idx in corruption_idx else tok.item() for tok_idx, tok in enumerate(prompt_tokens)] 
    corrupted_prompt = tokenizer.decode(corrupted_tokens)
    return corrupted_prompt


with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        clean_tokens = model.input[1]["input_ids"].squeeze().save()

clean_input_embeddings, clean_hs, clean_logits = get_clean_logits(prompt,
                    model)

corruption_idx = [i for i in range(prompt_dict['instr_idx'])]
corrupted_prompt = corrupt_prompt(prompt_tokens = clean_tokens,
               corruption_idx = corruption_idx)
noised_embeddings, corrupted_hs, corrupted_logits = get_corrupted_logits(corrupted_prompt,
                                                                             model = model)

##### corruption with replacement for summary tokens

# summary_tokens = clean_tokens[instr_prefix_src_suffix_idx:]
# summary_label = prompt_dict['summary_labels']
# assert(summary_tokens.shape[-1] == len(summary_label))
# summary_token_patching_results = []
# summ_idx = 0
# for tidx, tgt_token in enumerate(clean_tokens):
#         if tidx >= instr_prefix_src_suffix_idx:
#             assert(tgt_token == summary_tokens[summ_idx])
            

#             layer_wise_patching_results = []
# #             with model.trace() as tracer:
#             for layer_idx in range(len(model.model.layers)):
                    
#                     _, patched_hs, patched_logits =  get_patched_logits(prompt,
#                       noised_embeddings,
#                       layer_idx,
#                       token_idx = tidx - 1)

#                     append_dict = {
#                         'layer': layer_idx,
#                         'target': tgt_token.item(),
#                         'predicted': torch.argmax(clean_logits[tidx - 1]).item(),
#                         'factual_label': summary_label[summ_idx],
#                         'prob_clean': clean_logits[tidx - 1][tgt_token].item(),
#                         'prob_corrupted': corrupted_logits[tidx - 1][tgt_token].item(),
#                         'prob_patched': patched_logits[tidx - 1][tgt_token].item()
#                     }
#                     layer_wise_patching_results.append(append_dict)
#                     break
                

#             summary_token_patching_results.append(layer_wise_patching_results)
#             summ_idx += 1
#             if summ_idx == 10:
#                 break


    
    

In [58]:
corrupted_hs, c

NameError: name 'c' is not defined

In [None]:
len(summary_token_patching_results)
import re
import string

def check_only_trailing_punctuations(token_list):
    regex = re.compile('[%s]' % re.escape(string.punctuation))
    for idx in [0, len(token_list) - 1]:
        if idx <= len(token_list) - 1:
            if not regex.sub('', token_list[idx][1]).strip():
                token_list.pop(idx)
    return token_list

def add_word_list(token_list, word_list):
    token_list = check_only_trailing_punctuations(token_list)
    if token_list:
        word_list.append(token_list)
    return word_list
    
def get_word_based_idx(summary_tokens):
    all_words = []
    token_list = []
    for summ_tok_idx, summ_tok in enumerate(summary_tokens):
        prefix_tok_str = ''
        current_tok_str = tokenizer.decode(summ_tok)
        if summ_tok_idx > 0:
            prefix_tok_str = tokenizer.decode([summary_tokens[summ_tok_idx - 1], summ_tok])

        #### new token is start of a new word 
        if len(prefix_tok_str.split(' '))> 1:

            ### append stored tokens of a previous word
            if len(token_list) > 0:
                all_words = add_word_list(token_list, all_words)

            #### now reset token list for new word 
            token_list = [(summ_tok_idx, current_tok_str, summ_tok.item(), )]
            
        else:
            token_list.append((summ_tok_idx, current_tok_str, summ_tok.item()))

    all_words = add_word_list(token_list, all_words)
    return all_words



    # print(current_tok, len(word_list))
    # if len(prefix_str.split(' ')) 
    # if prev_tok is None or 
    # summ_tok_label = summary_label[summ_tok_idx]
    # summ_tok_layers_patches = summary_token_patching_results[summ_tok_idx]
    # for layer_idx, tgt_token, _, pred_token, p_clean, p_corr, p_patched in summ_tok_layers_patches:
    #     ie_score = (p_patched - p_corr) / (p_clean - p_corr)
        

In [102]:
summary_word_list = get_word_based_idx(summary_tokens)

In [103]:
summary_word_list[:5]

[[(0, 'The', 415)],
 [(1, 'document', 3248)],
 [(2, 'describes', 13966)],
 [(3, 'a', 264)],
 [(4, 'situation', 4620)]]

In [104]:
for summ_word in summary_word_list:
    print(summ_word)

[(0, 'The', 415)]
[(1, 'document', 3248)]
[(2, 'describes', 13966)]
[(3, 'a', 264)]
[(4, 'situation', 4620)]
[(5, 'where', 970)]
[(6, 'a', 264)]
[(7, 'person', 1338)]
[(8, 'switched', 21187)]
[(9, 'out', 575)]
[(10, 'their', 652)]
[(11, 'ear', 8120), (12, 'rings', 26661)]
[(13, 'and', 304)]
[(14, 'tried', 3851)]
[(15, 'to', 298)]
[(16, 'clean', 3587)]
[(17, 'them', 706)]
[(18, 'using', 1413)]
[(19, 'per', 660), (20, 'ox', 1142), (21, 'ide', 547)]
[(23, 'However', 2993)]
[(25, 'the', 272)]
[(26, 'per', 660), (27, 'ox', 1142), (28, 'ide', 547)]
[(29, 'was', 403)]
[(30, 'left', 1749)]
[(31, 'in', 297)]
[(32, 'the', 272)]
[(33, 'ear', 8120), (34, 'rings', 26661)]
[(35, 'overnight', 22128)]
[(36, 'and', 304)]
[(37, 'exposed', 13438)]
[(38, 'to', 298)]
[(39, 'light', 2061)]
[(41, 'causing', 13098)]
[(42, 'burn', 5698), (43, 's', 28713)]
[(44, 'and', 304)]
[(45, 'ble', 8012), (46, 'aching', 10028)]
[(47, 'on', 356)]
[(48, 'the', 272)]
[(49, 'person', 1338), (50, "'", 28742), (51, 's', 28713)]

In [85]:
len(summary_label)

68

In [91]:
summary_token_patching_results[1][:1]

[(0,
  3248,
  0,
  3248,
  0.45301350951194763,
  8.643341425340623e-05,
  8.798035560175776e-05)]