In [9]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging as hf_logging
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import Dataset
from sklearn.utils.class_weight import compute_class_weight
import nltk
from nltk.corpus import stopwords
import argparse
import numpy as np
import string
import difflib
import string
import re 
import difflib
import string
from sklearn.model_selection import train_test_split
from scripts.utils import load_model, longest_common_substring_ignore_punctuations

# Fetch the NLTK stopword corpus (no-op when it is already cached locally)
# and keep the English list around for later cells.
nltk.download('stopwords')
stop_words = stopwords.words('english')

# Only show errors from the Hugging Face libraries.
hf_logging.set_verbosity_error()

# Prefer the GPU, but fall back to the CPU so that the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/ramprasad.sa/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [6]:
class ModelStateDataset():
    """Builds (prompt + summary) token sequences with per-token factuality
    labels and extracts the model's hidden states / attentions for them.

    Typical use::

        ds = ModelStateDataset(model_name, tokenizer=None)
        ds.make_prompt_ids(prompt_prefix, prompt_suffix)
        example = ds.get_internal_states(doc, summary, nonfactual_spans)

    ``make_prompt_ids`` must be called before any example is encoded.

    NOTE (carried over from the original code): "Do not use layer 0" --
    presumably this refers to ``hidden_states[0]`` (the embedding output);
    confirm with downstream consumers.
    """

    # Token ids that are treated as a leading special token (e.g. BOS) and
    # stripped when a text fragment is tokenized on its own.
    # NOTE(review): this is a heuristic; for tokenizers that do not prepend a
    # special token, a real token with id 0/1/2 would be dropped -- confirm.
    _LEADING_SPECIAL_IDS = (0, 1, 2)

    # Label value used to pad the summary label sequence.
    _LABEL_PAD_ID = -100

    def __init__(self,
                 model_name,
                 tokenizer,
                 doc_truncate = 500,
                 summary_truncate = 500):
        """Load the model/tokenizer pair for ``model_name``.

        Args:
            model_name: identifier forwarded to ``load_model``.
            tokenizer: accepted for backward compatibility only; the tokenizer
                returned by ``load_model`` is the one that is used.
            doc_truncate: maximum number of document tokens kept.
            summary_truncate: maximum number of summary tokens (and labels) kept.
        """
        tokenizer, model = load_model(model_name)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS then.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        self.model = model
        self.tokenizer = tokenizer
        self.doc_truncate = doc_truncate
        self.summary_truncate = summary_truncate
        # Populated by make_prompt_ids().
        self.prompt_prefix = None
        self.prompt_suffix = None
        self.prefix_ids = None
        self.suffix_ids = None

    @classmethod
    def _strip_leading_special(cls, ids):
        """Drop the first id when it is one of ``_LEADING_SPECIAL_IDS``.

        Accepts either a flat list of ids or a 2-D ``(1, n)`` tensor and
        returns the same kind of object. Empty inputs are returned unchanged.
        """
        if torch.is_tensor(ids):
            if ids.shape[-1] > 0 and int(ids[0, 0]) in cls._LEADING_SPECIAL_IDS:
                return ids[:, 1:]
            return ids
        ids = list(ids)
        if ids and ids[0] in cls._LEADING_SPECIAL_IDS:
            return ids[1:]
        return ids

    def make_prompt_ids(self,
                       prompt_prefix,
                       prompt_suffix):
        """Tokenize the text placed before and after the document.

        Example values (from the original notebook)::

            instruction = "Generate a summary for the following document in brief. ..."
            prompt_prefix = f'{instruction}\\nDocument:'
            prompt_suffix = '\\nSummary:'

        The prefix keeps whatever leading special token the tokenizer emits
        (it starts the sequence); the suffix has it stripped because it is
        spliced into the middle of the sequence.
        """
        self.prompt_prefix = prompt_prefix
        self.prompt_suffix = prompt_suffix
        self.prefix_ids = self.tokenizer(prompt_prefix, return_tensors="pt").input_ids
        suffix_ids = self.tokenizer(prompt_suffix, return_tensors="pt").input_ids
        self.suffix_ids = self._strip_leading_special(suffix_ids)
        return

    def get_tokens_labels(self,
                         inconsistent_spans,
                         summary):
        """Tokenize ``summary`` word by word and label each token.

        A word (summary split on single spaces) gets label 1 when it lies
        inside the located occurrence of any inconsistent span, else 0; every
        token of a word inherits the word's label.

        Args:
            inconsistent_spans: iterable of span strings.
            summary: the summary text.

        Returns:
            ``(summ_tokens, summ_token_labels)`` -- two equally long lists.
        """
        words = summary.split(' ')
        labels = [0] * len(words)

        for nonfactual_span in inconsistent_spans:
            processed_nonfactual_span = longest_common_substring_ignore_punctuations(nonfactual_span, summary)

            # Locate the span literally. The text may contain regex
            # metacharacters ("(", "$", "?", ...), so it must not be used as a
            # pattern; an empty/None result counts as "not found".
            start_idx = summary.find(processed_nonfactual_span) if processed_nonfactual_span else -1
            if start_idx < 0:
                print('ERR', nonfactual_span, '|', summary)
                print('***'* 13)
                continue
            end_idx = start_idx + len(processed_nonfactual_span)  # exclusive

            curr_char_idx = 0
            for widx, w in enumerate(words):
                end_char_idx = curr_char_idx + (len(w) - 1)  # inclusive
                # NOTE(review): end_idx is exclusive while end_char_idx is
                # inclusive, so a word may extend one character past the span
                # (e.g. "cat," for a span ending in "cat"). Kept as in the
                # original; confirm this leniency is intended.
                if curr_char_idx >= start_idx and end_char_idx <= end_idx:
                    labels[widx] = 1
                curr_char_idx = end_char_idx + 2  # skip the separating space

        summ_tokens = []
        summ_token_labels = []
        for w, l in zip(words, labels):
            # Empty words (from consecutive spaces) may yield no tokens at all.
            word_tokens = self._strip_leading_special(self.tokenizer(f'{w}').input_ids)
            summ_tokens += word_tokens
            summ_token_labels += [l] * len(word_tokens)

        assert(len(summ_tokens) == len(summ_token_labels))
        return summ_tokens, summ_token_labels

    def _encode_example(self, doc, summary, inconsistent_spans):
        """Return truncated ``(prompt_ids, summ_tokens, summ_token_labels)``.

        ``prompt_ids`` is the 1-D tensor prefix + document + suffix.
        """
        if getattr(self, 'prefix_ids', None) is None or getattr(self, 'suffix_ids', None) is None:
            raise RuntimeError('call make_prompt_ids(prompt_prefix, prompt_suffix) before encoding examples')

        doc_ids = self.tokenizer(doc, return_tensors="pt").input_ids
        doc_ids = self._strip_leading_special(doc_ids)
        summ_tokens, summ_token_labels = self.get_tokens_labels(inconsistent_spans, summary)

        #### truncate all #####
        doc_ids = doc_ids[:, :self.doc_truncate]
        summ_tokens = summ_tokens[:self.summary_truncate]
        summ_token_labels = summ_token_labels[:self.summary_truncate]

        prompt_ids = torch.cat([self.prefix_ids, doc_ids, self.suffix_ids], dim = -1).squeeze(0)
        return prompt_ids, summ_tokens, summ_token_labels

    def _assemble(self, prompt_ids, summ_tokens, summ_token_labels, padding):
        """Concatenate prompt and summary ids and optionally right-pad."""
        # Explicit dtype: torch.tensor([]) would otherwise default to float.
        summ_ids = torch.tensor(summ_tokens, dtype = prompt_ids.dtype)
        all_tokens = torch.cat([prompt_ids, summ_ids], dim = -1)

        if padding:
            # The longest possible sequence includes the prompt prefix and
            # suffix, not only the truncated document and summary.
            max_rows = (self.prefix_ids.shape[-1] + self.doc_truncate
                        + self.suffix_ids.shape[-1] + self.summary_truncate)
            pad_rows = max(max_rows - all_tokens.shape[-1], 0)
            pad_ids = torch.full((pad_rows,), self.tokenizer.pad_token_id, dtype = all_tokens.dtype)
            all_tokens = torch.cat([all_tokens, pad_ids], dim = -1)

            pad_tokens = max(self.summary_truncate - len(summ_token_labels), 0)
            summ_token_labels = summ_token_labels + [self._LABEL_PAD_ID] * pad_tokens
            summ_token_labels = torch.tensor(summ_token_labels)

        return all_tokens, summ_tokens, summ_token_labels

    def make_truncated_prompt_tokens_labels(self,
                                            doc,
                                            summary,
                                            inconsistent_spans,
                                            padding = False):
        """Build the full token sequence for one (doc, summary) pair.

        The original notebook sketched a dict-shaped result with the keys
        'all_prompt_tokens', 'summ_tokens', 'summ_tokens_labels' and their
        '*_padded' variants; that sketch was never completed, so a tuple is
        returned instead.

        Args:
            doc: source document text.
            summary: summary text.
            inconsistent_spans: list of span strings to label as 1.
            padding: when True, right-pad ``all_tokens`` with the pad token to
                prefix + doc_truncate + suffix + summary_truncate ids and pad
                the labels with -100 up to ``summary_truncate``.

        Returns:
            ``(all_tokens, summ_tokens, summ_token_labels)``: a 1-D tensor, a
            list of (unpadded) summary token ids, and the labels -- a list
            when ``padding`` is False, a tensor when it is True.
        """
        prompt_ids, summ_tokens, summ_token_labels = self._encode_example(doc, summary, inconsistent_spans)
        return self._assemble(prompt_ids, summ_tokens, summ_token_labels, padding)

    def _model_device(self):
        """Device the input ids must be moved to (where the model lives)."""
        model_device = getattr(self.model, 'device', None)
        if model_device is not None:
            return model_device
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device('cpu')

    def get_internal_states(self,
                            doc,
                            summary,
                            nonfactual_spans,
                            padding = False,
                            row = None):
        """Run the model on one example and collect its internal states.

        Args:
            doc: source document text.
            summary: summary text.
            nonfactual_spans: '<sep>'-separated span string; any non-string
                value (e.g. NaN/None) means "no inconsistent spans".
            padding: forwarded to the token assembly (see
                ``make_truncated_prompt_tokens_labels``).
            row: optional mapping; when it contains 'id', that value is copied
                into the result.

        Returns:
            dict with 'all_tokens', 'hidden_states' and 'attentions' (each the
            per-layer outputs concatenated along dim 0, on CPU), 'source_len'
            (number of prompt tokens), 'summary_len' (number of unpadded
            summary tokens), 'summary_token_labels' and optionally 'id'.
        """
        if isinstance(nonfactual_spans, str):
            # Skip empty fragments produced by leading/trailing/double '<sep>'.
            inconsistent_spans = [span for span in nonfactual_spans.split('<sep>') if span.strip()]
        else:
            inconsistent_spans = []

        prompt_ids, summ_tokens, summ_token_labels = self._encode_example(doc, summary, inconsistent_spans)
        source_len = prompt_ids.shape[-1]
        summary_len = len(summ_tokens)
        all_tokens, summ_tokens, summ_token_labels = self._assemble(prompt_ids,
                                                                    summ_tokens,
                                                                    summ_token_labels,
                                                                    padding)

        with torch.no_grad():
            outputs = self.model(all_tokens.unsqueeze(0).to(self._model_device()),
               output_hidden_states = True,
               output_attentions = True,
               return_dict = True)

        example_dict = {}
        example_dict['all_tokens'] = all_tokens
        example_dict['hidden_states'] = torch.cat(outputs['hidden_states']).cpu()
        example_dict['attentions'] = torch.cat(outputs['attentions']).cpu()
        example_dict['source_len'] = torch.tensor([source_len])
        example_dict['summary_len'] = torch.tensor([summary_len])
        example_dict['summary_token_labels'] = summ_token_labels
        if row is not None and 'id' in row:
            example_dict['id'] = row['id']

        # Release the (potentially large) device-side outputs promptly.
        del outputs
        return example_dict


Collecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl.metadata (2.8 kB)
Collecting click (from nltk)
  Downloading click-8.1.7-py3-none-any.whl.metadata (3.0 kB)
Downloading nltk-3.8.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading click-8.1.7-py3-none-any.whl (97 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m97.9/97.9 kB[0m [31m566.3 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: click, nltk
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
typer 0.3.2 requires click<7.2.0,>=7.1.1, but you have click 8.1.7 which is incompatible.[0m[31m
[0mSuccessfully installed click-8.1.7 nltk-3.8.1
Note: you may need to restart the kernel to use updated packages.
