In [1]:
import torch
import json
import argparse
import pandas as pd
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
from transformers import AutoTokenizer, AutoModelForCausalLM
from scripts.models.prompt_processor import PromptProcessor
from scripts.utils import load_model

In [2]:
# Load the GenAudit annotation table.
df_genaudit = pd.read_csv('/home/ramprasad.sa/probing_summarization_factuality/datasets/Genaudit_annotations.csv')
# Ids look like "<annotator>#<document id>"; keep only the part after the last '#'.
df_genaudit['docid_processed'] = [raw_id.split('#')[-1] for raw_id in df_genaudit['id']]
df_genaudit.head()

Unnamed: 0.1,Unnamed: 0,id,source,summary,annotated_spans,model,origin,docid_processed
0,0,Rachel Usher#REDDIT-83:flanul2-ul2,both me and my girlfriend participate in winte...,I waited outside the locker rooms for my girlf...,nobody,flanul2,REDDIT,REDDIT-83:flanul2-ul2
1,1,Rachel Usher#REDDIT-83:llama70b-ul2,both me and my girlfriend participate in winte...,A high school student was waiting for his girl...,,llama70b,REDDIT,REDDIT-83:llama70b-ul2
2,2,Rachel Usher#REDDIT-83:falcon7b-ul2,both me and my girlfriend participate in winte...,The protagonist is concerned about a potential...,concerned about a potential delay in meeting<s...,falcon7b,REDDIT,REDDIT-83:falcon7b-ul2
3,3,Rachel Usher#REDDIT-83:llama7b-ul2,both me and my girlfriend participate in winte...,The writer and their girlfriend are both invol...,is late<sep>bus and doesn't realize that their...,llama7b,REDDIT,REDDIT-83:llama7b-ul2
4,4,Rachel Usher#REDDIT-83:mistral7b-ul2,both me and my girlfriend participate in winte...,The document describes the experience of a hig...,participates<sep>with their girlfriend<sep>ends,mistral7b,REDDIT,REDDIT-83:mistral7b-ul2


In [3]:

def get_nnsight_model_wrapper(model_name):
    """Load a tokenizer/model pair and wrap the model for nnsight tracing.

    Args:
        model_name: Identifier forwarded unchanged to ``load_model``.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is an nnsight
        ``LanguageModel`` built around the loaded checkpoint.
    """
    tokenizer, base_model = load_model(model_name)
    # Register an explicit padding token on the tokenizer.
    # NOTE(review): the model's embedding matrix is not resized here — confirm
    # that the '[PAD]' id is never actually fed to the model.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    wrapped_model = LanguageModel(
        base_model,
        tokenizer=tokenizer,
        device_map="auto",
        dispatch=True,
    )
    return tokenizer, wrapped_model

class MakeDataAIE:
    """Produce activation-patching records for the summary tokens of a prompt.

    Holds an nnsight-wrapped language model plus a ``PromptProcessor`` that
    turns (source, summary, annotated spans) into a prompt dictionary.

    NOTE(review): ``get_indirect_effect`` calls ``self.corrupt_prompt``,
    ``self.get_clean_logits``, ``self.get_corrupted_logits`` and
    ``self.get_patched_logits``; none of them is defined in this class as it
    stands, so they must be provided before that method can run.
    """

    def __init__(self,
                 model_name,
                 prompt_template,
                 prompt_template_path,
                 prompt_type):
        """Load the model and build the prompt processor.

        Args:
            model_name: Passed to ``get_nnsight_model_wrapper``.
            prompt_template: Forwarded to ``PromptProcessor``.
            prompt_template_path: Forwarded to ``PromptProcessor``.
            prompt_type: Forwarded to ``PromptProcessor``.
        """
        self.tokenizer, self.model = get_nnsight_model_wrapper(model_name)

        self.prompt_processor = PromptProcessor(prompt_template=prompt_template,
                                                prompt_template_path=prompt_template_path,
                                                prompt_type=prompt_type,
                                                tokenizer=self.tokenizer)

    def get_indirect_effect(self,
                            prompt_dict):
        """Patch clean hidden states into a corrupted run, per summary token and layer.

        Args:
            prompt_dict: Mapping that provides ``'prompt'``, ``'summary_labels'``,
                ``'instr_prefix_src_suffix_idx'`` (index of the first summary
                token) and ``'instr_idx'`` (positions ``0 .. instr_idx - 1`` are
                passed on as the indices to corrupt).

        Returns:
            A list with one entry per summary token. Each entry is a list with
            one dict per layer, holding ``layer``, ``target``, ``predicted``,
            ``factual_label``, ``prob_clean``, ``prob_corrupted`` and
            ``prob_patched``.

        Raises:
            AssertionError: If the number of summary tokens differs from the
                number of summary labels, or if a patched run fails the
                clean/corrupted hidden-state sanity checks.
        """
        # Trace the prompt once just to capture the token ids the model receives.
        with self.model.trace() as tracer:
            with tracer.invoke(prompt_dict['prompt']) as invoker:
                clean_tokens = self.model.input[1]["input_ids"].squeeze().save()

        prompt = prompt_dict['prompt']
        summary_start = prompt_dict['instr_prefix_src_suffix_idx']
        n_layers = len(self.model.model.layers)
        summary_labels = prompt_dict['summary_labels']
        summary_tokens = clean_tokens[summary_start:]
        corruption_idx = list(range(prompt_dict['instr_idx']))
        assert summary_tokens.shape[-1] == len(summary_labels)

        corrupted_prompt = self.corrupt_prompt(prompt_tokens=clean_tokens,
                                               corruption_idx=corruption_idx)

        _, clean_hs, clean_logits = self.get_clean_logits(prompt=prompt)

        # The corrupted run noises every token of the instruction.
        _, corrupted_hs, corrupted_logits = self.get_corrupted_logits(corrupted_prompt=corrupted_prompt)

        # NOTE(review): the ``prob_*`` values below are read straight from the
        # ``*_logits`` tensors returned by the helpers — confirm those helpers
        # already return probabilities rather than raw logits.
        summary_token_patching_results = []
        # Everything after instruction, prefix, source and suffix is the summary.
        for summ_idx, tgt_token in enumerate(summary_tokens):
            tidx = summary_start + summ_idx
            # The position just before the target is the one that predicts it.
            pred_pos = tidx - 1

            layer_wise_patching_results = []
            for layer_idx in range(n_layers):
                _, patched_hs, patched_logits = self.get_patched_logits(corrupted_prompt=corrupted_prompt,
                                                                        clean_hs=clean_hs,
                                                                        layer_idx=layer_idx,
                                                                        token_idx=pred_pos)

                # At the patched position the patched run must differ from the
                # corrupted run ...
                assert not torch.allclose(corrupted_hs[layer_idx][:, pred_pos, :],
                                          patched_hs[layer_idx][:, pred_pos, :])
                # ... and must match the clean run.
                assert torch.allclose(clean_hs[layer_idx][:, pred_pos, :],
                                      patched_hs[layer_idx][:, pred_pos, :])

                layer_wise_patching_results.append({
                    'layer': layer_idx,
                    'target': tgt_token.item(),
                    'predicted': torch.argmax(clean_logits[pred_pos]).item(),
                    'factual_label': summary_labels[summ_idx],
                    'prob_clean': clean_logits[pred_pos][tgt_token].item(),
                    'prob_corrupted': corrupted_logits[pred_pos][tgt_token].item(),
                    'prob_patched': patched_logits[pred_pos][tgt_token].item()
                })

            summary_token_patching_results.append(layer_wise_patching_results)

        return summary_token_patching_results

    def get_layerwise_causal_analysis(self,
                                      df,
                                      write_path):
        """Build the prompt dictionary for the first row that has annotated spans.

        The patching and writing steps are currently disabled, so this returns
        after the first annotated row instead of processing the whole frame.

        Args:
            df: DataFrame with ``source``, ``summary`` and ``annotated_spans``
                columns; rows whose ``annotated_spans`` is null are skipped.
            write_path: Output directory for the disabled writing step;
                currently unused.

        Returns:
            The prompt dictionary produced by
            ``PromptProcessor.make_prompt_token_labels`` for the first annotated
            row, or ``None`` when no row has annotated spans.
        """
        # Filter once and reuse the result for both the iterator and the bar length.
        annotated_rows = df[df['annotated_spans'].notnull()]

        for _, row in tqdm(annotated_rows.iterrows(), total=len(annotated_rows)):
            prompt_dict = self.prompt_processor.make_prompt_token_labels(source=row['source'],
                                                                         summary=row['summary'],
                                                                         nonfactual_spans=row['annotated_spans'])
            print(prompt_dict)
            # TODO: re-enable the full pipeline once the patching helpers exist:
            #   uid = '_'.join(row['id'].split())
            #   summary_patching_results = self.get_indirect_effect(prompt_dict=prompt_dict)
            #   self.write_patched_results(summary_patching_results=summary_patching_results,
            #                              write_path=f'{write_path}/{uid}.jsonl')
            return prompt_dict

        return None

In [4]:
# Experiment configuration: where the prompt templates live, which prompt
# flavour to use, and which model/dataset slice of the annotations to analyse.
prompt_template_path = '/home/ramprasad.sa/probing_summarization_factuality/datasets/prompt_templates/'
prompt_type = 'document_context'
# The doubled braces leave literal {placeholders} for a later str.format call.
prompt_template = f'{{instruction}}{{prompt_prefix}}{{source}}{{prompt_suffix}}{{summary}}'
model_name = 'mistral7b'
origin = 'XSUM'

# Keep only summaries written by the chosen model for the chosen dataset.
df_filtered = df_genaudit.loc[
    (df_genaudit['model'] == model_name) & (df_genaudit['origin'] == origin)
]

# Loads the checkpoint and builds the prompt processor.
nnsight_patcher = MakeDataAIE(
    model_name=model_name,
    prompt_template=prompt_template,
    prompt_template_path=prompt_template_path,
    prompt_type=prompt_type,
)





Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Rows of the filtered slice that carry at least one annotated span.
df_filtered[df_filtered['annotated_spans'].notnull()]

Unnamed: 0.1,Unnamed: 0,id,source,summary,annotated_spans,model,origin,docid_processed
112,112,Cynthia Lamanda#XSUM-35862185:mistral7b-ul2,England failed to build on the optimism genera...,The document discusses England's loss to the N...,in London<sep>the winner,mistral7b,XSUM,XSUM-35862185:mistral7b-ul2
373,373,Cynthia Lamanda#XSUM-34846955:mistral7b-ul2,Welsh housing associations directly contribute...,Welsh housing associations made a significant ...,£1.1,mistral7b,XSUM,XSUM-34846955:mistral7b-ul2
404,404,Rachel Usher#XSUM-34007864:mistral7b-ul2,The results have been published for more than ...,"The GCSE results have been published, showing ...",GC,mistral7b,XSUM,XSUM-34007864:mistral7b-ul2
416,416,Cynthia Lamanda#XSUM-37820092:mistral7b-ul2,Shanghai is trialling a unisex public toilet b...,Shanghai is trialting a unisex public toilet b...,trialting,mistral7b,XSUM,XSUM-37820092:mistral7b-ul2
500,500,Cynthia Lamanda#XSUM-19389161:mistral7b-ul2,"Indian Olympics bronze medallist, boxer MC Mar...",Bollywood director Sanjay Leela Bhansali plans...,with Mary Kom herself<sep>leadership,mistral7b,XSUM,XSUM-19389161:mistral7b-ul2
533,533,Rachel Usher#XSUM-38021627:mistral7b-ul2,In the Scottish Football Association's stateme...,The Scottish Football Association (SFA) made a...,regarding<sep>resignation<sep>finding replacem...,mistral7b,XSUM,XSUM-38021627:mistral7b-ul2
579,579,Cynthia Lamanda#XSUM-31763768:mistral7b-ul2,Bestival has revealed an all-female line-up in...,Bestival has announced an all-female line-up f...,male,mistral7b,XSUM,XSUM-31763768:mistral7b-ul2
688,688,Rachel Usher#XSUM-35643091:mistral7b-ul2,At Harper Adams University they are fitting tr...,Harper Adams University is conducting research...,and habitat,mistral7b,XSUM,XSUM-35643091:mistral7b-ul2
694,694,Rachel Usher#XSUM-39387550:mistral7b-ul2,US President Donald Trump has withdrawn his he...,"The American Health Care Act, a healthcare bil...",and the Republican Party,mistral7b,XSUM,XSUM-39387550:mistral7b-ul2


In [6]:
# Build the prompt dictionary from the annotated rows; no output directory is supplied.
prompt_dict = nnsight_patcher.get_layerwise_causal_analysis(df_filtered, write_path='')

  0%|                                                                                                                                                                                                                                                     | 0/9 [00:00<?, ?it/s]

attri 257
0
100
200


  0%|                                                                                                                                                                                                                                                     | 0/9 [00:00<?, ?it/s]

{'prompt': "Generate a summary for the following document in brief. When creating the summary, only use information that is present in the documentGenerate a summary for the following document in brief. When creating the summary, only use information that is present in the document CONTENT:England failed to build on the optimism generated by their thrilling victory in Germany as they were beaten in a friendly by the Netherlands, who have not even qualified for Euro 2016.. On a night when Wembley paid its respects to the late Dutch legend Johan Cruyff with applause in SUMMARY:The document discusses England's loss to the Netherlands in a friendly football match, which took place at Wembley Stadium in London.. Despite an impressive victory over Germany in their previous friendly, England failed to build on their momentum and lost 2-1 to the Netherlands.. Jamie Vardy, the man-of-the-match, scored England's goal with a slick passing move that led to a cross from Kyle Walker, but England's d




In [7]:
# Tokenize the clean prompt and keep its input ids for the cells below.
prompt_tokens = nnsight_patcher.tokenizer(prompt_dict['prompt'])
clean_tokens = prompt_tokens.input_ids

In [8]:
# Pull the individual prompt pieces out of the prompt dictionary.
# Each variable reads its own key; previously the prefix and suffix were both
# copied from 'prompt_template'.
prompt_template = prompt_dict['prompt_template']
prompt_prefix = prompt_dict['prompt_prefix']
prompt_suffix = prompt_dict['prompt_suffix']


In [32]:
from scripts.dataset_creators.coco_mask import Coco

coco_class = Coco()

prompt = prompt_dict['prompt']
N_LAYERS = len(nnsight_patcher.model.model.layers)
summary_labels = prompt_dict['summary_labels']
# Everything from this index onwards is the summary part of the prompt.
summary_tokens = clean_tokens[prompt_dict['instr_prefix_src_suffix_idx']:]
corruption_idx = list(range(prompt_dict['instr_idx']))

# Recover the summary text from its tokens. `summary` used to be read from
# leftover kernel state, which fails on a fresh Restart & Run All.
summary = nnsight_patcher.tokenizer.decode(summary_tokens)
print(summary, prompt_dict['summary'])

# Keywords of the summary that are candidates for masking in the source document.
summary_keywords = coco_class.get_masked_token_list(summary)

The document discusses England's loss to the Netherlands in a friendly football match, which took place at Wembley Stadium in London.. Despite an impressive victory over Germany in their previous friendly, England failed to build on their momentum and lost 2-1 to the Netherlands.. Jamie Vardy, the man-of-the-match, scored England's goal with a slick passing move that led to a cross from Kyle Walker, but England's defence was unable to protect their lead.. Danny Rose was punished for a sloppy pass that led to a penalty, and Vincent Janssen scored the winner from the spot. summary
DIFF in masked {',', 'an', 'and', 'but', 'a', '.', "'s", '..', 'Despite', 'the', 'The', '-'}


In [16]:
# Placeholder string used when masking spans of the source document.
mask_token = '_'
# Displayed to inspect which token ids the placeholder maps to.
nnsight_patcher.tokenizer(mask_token)

{'input_ids': [1, 583], 'attention_mask': [1, 1]}

In [24]:


# Id of the '_' placeholder; index 1 skips the first id the tokenizer emits.
corrupt_token = nnsight_patcher.tokenizer('_').input_ids[1]

summ_idx = 0
for tidx, tgt_token in enumerate(clean_tokens):
    # Tokens before this index belong to instruction, prefix, source and suffix.
    if tidx < prompt_dict['instr_prefix_src_suffix_idx']:
        continue

    assert tgt_token == summary_tokens[summ_idx]
    tgt_token_str = nnsight_patcher.tokenizer.decode(tgt_token)
    summ_idx += 1

    keywords_i = []
    # Skip empty and single-character tokens.
    if not tgt_token_str or len(tgt_token_str) <= 1:
        continue

    # Summary keywords that begin with this token's text, de-duplicated.
    keywords_i = list({keyword for keyword in summary_keywords if keyword.startswith(tgt_token_str)})

    masked_document = coco_class.mask_document(
        source_doc=prompt_dict['source_doc'],
        masked_token_list=keywords_i,
        mask_token=mask_token,
        mask_strategy='span',
    )
    # Nothing was masked in the source, so there is no new prompt to build.
    if masked_document == prompt_dict['source_doc']:
        continue

    # Rebuild the prompt around the masked source document.
    new_prompt = prompt_template.format(
        instruction=prompt_dict['instruction'],
        prompt_prefix=prompt_dict['prompt_prefix'],
        source=masked_document,
        prompt_suffix=prompt_dict['prompt_suffix'],
        summary=summary,
    )
        

In [30]:
#### clean_prompt
### corrupted prompt 

# prompt_dict
# summary_keywords

['document',
 'discusses',
 'England',
 'loss',
 'to',
 'Netherlands',
 'in',
 'friendly',
 'football',
 'match',
 'which',
 'took',
 'place',
 'at',
 'Wembley',
 'Stadium',
 'in',
 'London',
 'impressive',
 'victory',
 'over',
 'Germany',
 'in',
 'their',
 'previous',
 'friendly',
 'England',
 'failed',
 'build',
 'on',
 'their',
 'momentum',
 'lost',
 '2',
 '1',
 'to',
 'Netherlands',
 'Jamie',
 'Vardy',
 'man',
 'of',
 'match',
 'scored',
 'England',
 'goal',
 'with',
 'slick',
 'passing',
 'move',
 'that',
 'led',
 'to',
 'cross',
 'from',
 'Kyle',
 'Walker',
 'England',
 'defence',
 'was',
 'unable',
 'protect',
 'their',
 'lead',
 'Danny',
 'Rose',
 'was',
 'punished',
 'for',
 'sloppy',
 'pass',
 'that',
 'led',
 'to',
 'penalty',
 'Vincent',
 'Janssen',
 'scored',
 'winner',
 'from',
 'spot']