In [1]:
import sys, time, pickle, torch
sys.path.insert(0, '../../Models')
sys.path.insert(0, '../../Utils')
sys.path.insert(0, '../../Preprocess')
import numpy as np
import pandas as pd
from preload_models import get_sst2_tok_n_model
from _utils import sample_random_glue_sst2, get_continuation_mapping, \
                    get_continuous_attributions, get_continuous_raw_inputs, \
                    collect_info_for_metric, save_info

In [2]:
# Draw the evaluation sample via the project helper. The three returned values
# are used below as the raw review strings, their ground-truth targets and the
# indices that are later passed to save_info.
# NOTE(review): whether the sample is seeded/reproducible is decided inside
# sample_random_glue_sst2 and is not visible from this notebook - confirm.
sst2_data_raw, targets, idxs = sample_random_glue_sst2()

Reusing dataset glue (/home/user/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-3b24abff24d1d8c0.arrow
Loading cached processed dataset at /home/user/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-5960909ab3834668.arrow


In [3]:
# Load the tokenizer / model pair through the project helper in preload_models.
# Both names are read as globals by generate_record below.
tokenizer, model = get_sst2_tok_n_model()

In [4]:
# Module-level accumulators. generate_record hands them to
# collect_info_for_metric on every call and save_info receives them at the end.
# Re-run this cell before re-running the explanation loop so that results from
# a previous run are not kept in the lists.
model_out_list, raw_attr_list, conti_attr_list, raw_input_list = [], [], [], []

In [5]:
from captum.attr import Lime
from captum._utils.models.linear_model import SkLearnLasso
from captum.attr import visualization 

In [35]:
# Captum LIME explainer around `model`, using a Lasso regressor (alpha=0.0001)
# as the interpretable surrogate model. Read as a global by generate_record.
lime = Lime(model, interpretable_model=SkLearnLasso(alpha=0.0001))

In [36]:
def generate_record(raw_review, target):
    """Explain one review with LIME and package the result for visualization.

    The review is tokenized, embedded and scored by the global ``model``; the
    global ``lime`` explainer then attributes the prediction to the input
    embedding. The model output, the raw attribution, the per-word attributions
    and the re-joined words are handed to ``collect_info_for_metric`` together
    with the module-level lists (``model_out_list``, ``raw_attr_list``,
    ``conti_attr_list``, ``raw_input_list``).

    Args:
        raw_review: Review text passed to the tokenizer.
        target: Ground-truth label; values above 0.5 are shown as 'Pos'.

    Returns:
        A ``visualization.VisualizationDataRecord`` for ``visualize_text``.
    """
    # tokenizer operations
    tokenized = tokenizer(raw_review, truncation=True, return_offsets_mapping=True)
    offset_mapping = tokenized['offset_mapping']
    conti_map = get_continuation_mapping(offset_mapping)
    input_ids = torch.tensor(tokenized['input_ids']).unsqueeze(0)
    # Drop the 'Ġ' marker from the token strings (presumably the tokenizer's
    # leading-space marker) so that the displayed tokens are plain text.
    detokenized = [t.replace('Ġ', '') for t in tokenizer.convert_ids_to_tokens(input_ids[0])]

    # feeding input forward
    input_emb = model.get_embeddings(input_ids)
    pred_prob = model(input_emb).item()

    # categorizing results
    pred_class = 'Pos' if pred_prob > 0.5 else 'Neg'
    true_class = 'Pos' if target > 0.5 else 'Neg'

    # attribution algorithm working
    attribution = lime.attribute(input_emb)
    # Collapse dim 1 after dropping the batch dim - assumes the attribution is
    # laid out as (1, tokens, embedding) so this yields one score per token.
    word_attributions = attribution.squeeze(0).sum(dim=1)
    # Normalise to unit L2 norm, but only when the norm is non-zero: if the
    # surrogate model returns all-zero coefficients, dividing by the norm would
    # turn every attribution (and the summed score) into NaN.
    attr_norm = torch.norm(word_attributions)
    if attr_norm > 0:
        word_attributions = word_attributions / attr_norm
    attr_score = torch.sum(word_attributions)
    # NOTE(review): 0.5 mirrors the probability threshold above, but attr_score
    # is a signed sum of attributions - confirm whether 0 was intended here.
    attr_class = 'Pos' if attr_score > 0.5 else 'Neg'
    convergence_score = None

    # re-organizing tensors and arrays because words get split down
    conti_attr = get_continuous_attributions(conti_map, word_attributions)
    raw_input = get_continuous_raw_inputs(conti_map, detokenized)

    # collect info for metrics later
    collect_info_for_metric(model_out_list, pred_prob, raw_attr_list, attribution, conti_attr_list, conti_attr, raw_input_list, raw_input)

    visual_record = visualization.VisualizationDataRecord(word_attributions=word_attributions,
                                                         pred_prob=pred_prob,
                                                         pred_class=pred_class,
                                                         true_class=true_class,
                                                         attr_class=attr_class,
                                                         attr_score=attr_score,
                                                         raw_input=raw_input,
                                                         convergence_score=convergence_score)

    return visual_record
      
    

In [37]:
# Number of sampled reviews to explain and display in this cell. The loop
# counter starts at 1, so exactly N_PREVIEW records are processed.
N_PREVIEW = 4

for i, (datum_raw, target) in enumerate(zip(sst2_data_raw, targets), start=1):
    print(f'Raw review: {datum_raw}')
    print(f'GT target: {target}')
    visual_record = generate_record(datum_raw, target)
    # visualize_text already renders the table in the cell output; printing its
    # return value only adds a '<IPython.core.display.HTML object>' line.
    visualization.visualize_text([visual_record])
    if i >= N_PREVIEW:
        break
   

Raw review: against all odds 
GT target: 1
asd 6


  model = cd_fast.enet_coordinate_descent(


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Pos,Pos (0.99),Neg,-0.65,#s against all odds #/s
,,,,


<IPython.core.display.HTML object>
Raw review: i had a dream that a smart comedy would come along to rescue me from a summer of teen-driven , toilet-humor codswallop 
GT target: 1




asd 25


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Pos,Pos (1.00),Pos,1.31,"#s i had a dream that a smart comedy would come along to rescue me from a summer of teen-driven , toilet-humor codswallop #/s"
,,,,


<IPython.core.display.HTML object>
Raw review: he makes sure the salton sea works the way a good noir should , keeping it tight and nasty 
GT target: 1




asd 22


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Pos,Pos (1.00),Neg,,"#s he makes sure the salton sea works the way a good noir should , keeping it tight and nasty #/s"
,,,,


<IPython.core.display.HTML object>
Raw review: suicidal 
GT target: 0
asd 4


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Neg,Neg (0.00),Pos,1.4,#s suicidal #/s
,,,,


<IPython.core.display.HTML object>


In [38]:
# Persist the sampled indices, raw reviews and targets together with the
# collected model outputs and attributions under the name 'lime_out.pkl'.
# NOTE(review): idxs / sst2_data_raw / targets cover the whole sample, while the
# four *_list containers only hold the records processed by the loop above
# (it breaks early) - confirm that this length mismatch is intended.
save_info(idxs, sst2_data_raw, targets, model_out_list, raw_attr_list, conti_attr_list, raw_input_list, fname='lime_out.pkl')

{'indices': [53963,
  64202,
  30501,
  50103,
  8727,
  17988,
  540,
  48977,
  25936,
  4415,
  28658,
  17550,
  28410,
  35211,
  60439,
  65301,
  53290,
  26647,
  23098,
  60767,
  60320,
  4954,
  40015,
  37441,
  3596,
  21218,
  46224,
  38829,
  17600,
  65782,
  25193,
  46255,
  17077,
  28351,
  28283,
  52895,
  67213,
  58524,
  18588,
  28334,
  47855,
  24000,
  63203,
  26377,
  60596,
  5452,
  45008,
  62398,
  44036,
  61886,
  20763,
  21053,
  51913,
  65777,
  57754,
  59641,
  32649,
  43453,
  34760,
  34688,
  19099,
  52078,
  23881,
  44711,
  35190,
  41947,
  42403,
  44743,
  37646,
  52526,
  2452,
  36111,
  28542,
  42135,
  14402,
  42458,
  57326,
  19943,
  47302,
  30747,
  24723,
  22180,
  25634,
  16318,
  26867,
  35611,
  61283,
  53360,
  54581,
  39428,
  54341,
  63990,
  26324,
  3295,
  25595,
  3354,
  15446,
  15854,
  46797,
  49950],
 'raw_data': ['against all odds ',
  'i had a dream that a smart comedy would come along to rescue