In [None]:
import sys, time
import torch
sys.path.insert(0, '../../Models')
sys.path.insert(0, '../../Utils')
from _utils import  sample_random_glue_stsb, collect_info_for_metric, save_info, \
                    get_continuation_mapping, get_continuous_attributions, get_continuous_raw_inputs,\
                    attr_normalizing_func
from preload_models import get_stsb_tokenizer_n_model

In [None]:
# Draw the examples to explain. Further down, `stsb_data_raw` is iterated as sentence pairs
# and `targets` as their ground-truth scores; `idxs` is only forwarded to save_info
# (presumably the dataset indices of the sampled rows — confirm in _utils).
stsb_data_raw, targets, idxs = sample_random_glue_stsb()

In [None]:
# Load the tokenizer/model pair from preload_models. generate_record relies on the model
# exposing get_embeddings(input_ids) and on model(embeddings) returning a single scalar.
tokenizer, model = get_stsb_tokenizer_n_model()

In [None]:
# Accumulators: generate_record appends to these through collect_info_for_metric,
# and the final cell hands them to save_info.
model_out_list = []
raw_attr_list = []
conti_attr_list = []
raw_input_list = []

In [None]:
from captum.attr import Lime
from captum._utils.models.linear_model import SkLearnLasso
from captum.attr import visualization 

In [None]:
# LIME explainer around `model`, using a scikit-learn Lasso (alpha=0.0003) as the
# interpretable model passed to Captum's Lime.
lime = Lime(model, interpretable_model=SkLearnLasso(alpha=0.0003))

In [None]:
def generate_record(raw_datum, target):
    """Explain one sentence pair with LIME and build its Captum visualization record.

    The two sentences are tokenized as a batch of two, joined into one id sequence,
    embedded with ``model.get_embeddings`` and scored by ``model``. ``lime`` then
    attributes that score to the input embeddings; the attributions are summed per
    token, L2-normalized and regrouped with the continuation mapping.

    Side effects: prints the prediction (multiplied by 5) and the attribution time, and
    records the prediction, raw attribution, regrouped attribution and regrouped input
    in the module-level lists through ``collect_info_for_metric``.

    Args:
        raw_datum: tuple/list of 2 sentences.
        target: ground-truth score; values above 2.5 are labelled 'Similar'
            (presumably on the 0-5 STS-B scale — confirm against the sampler).

    Returns:
        A ``visualization.VisualizationDataRecord`` for this pair.
    """
    # tokenizer operations
    tokenized = tokenizer(raw_datum, truncation=True, return_offsets_mapping=True)
    offset_mappings = tokenized['offset_mapping']
    # concatenate the two continuation mappings because the sentences are fed in together
    conti_map = get_continuation_mapping(offset_mappings[0]) + get_continuation_mapping(offset_mappings[1])

    # Join both id sequences; the first id of the second sentence is replaced by the
    # last id of the first sentence (i.e. an [END] token).
    first_ids, second_ids = tokenized['input_ids'][0], tokenized['input_ids'][1]
    end_id = first_ids[-1]
    tokenized_input_ids = first_ids + [end_id if pos == 0 else tok for pos, tok in enumerate(second_ids)]
    input_ids = torch.tensor(tokenized_input_ids).unsqueeze(0)
    detokenized = [t.replace('Ġ', '') for t in tokenizer.convert_ids_to_tokens(input_ids[0])]

    # feeding input forward
    input_emb = model.get_embeddings(input_ids)
    pred_prob = model(input_emb).item()
    print(f'pred_prob {pred_prob*5}')

    # categorizing results
    pred_class = 'Similar' if pred_prob > 0.5 else 'Not Similar'
    true_class = 'Similar' if target > 2.5 else 'Not Similar'

    # attribution algorithm working
    attri_start_time = time.time()
    attribution = lime.attribute(input_emb, n_samples=5000, show_progress=True, perturbations_per_eval=5000)
    # ':.2f' (two decimals); the previous ':2f' only set a minimum field width of 2
    print(f'attribution took {time.time() - attri_start_time:.2f} seconds')

    attribution[torch.isnan(attribution)] = 0
    # one score per token: drop the leading dim, then sum over dim 1
    # (assumes attribution is (1, seq_len, emb_dim) like input_emb — TODO confirm)
    word_attributions = attribution.squeeze(0).sum(dim=1)
    # L2-normalize, but leave an all-zero attribution untouched: the Lasso surrogate can
    # zero out every coefficient, and dividing by a zero norm would turn every score into NaN.
    attr_norm = torch.norm(word_attributions)
    if attr_norm > 0:
        word_attributions = word_attributions / attr_norm
    attr_score = torch.sum(word_attributions)
    attr_class = 'Similar' if attr_score > 0.5 else 'Not Similar'
    convergence_score = None

    # re-organizing tensors and arrays because words get split down into sub-tokens
    conti_attr = get_continuous_attributions(conti_map, word_attributions)
    raw_input = get_continuous_raw_inputs(conti_map, detokenized)

    # collect info for metrics later
    collect_info_for_metric(model_out_list, pred_prob, raw_attr_list, attribution,
                            conti_attr_list, conti_attr, raw_input_list, raw_input)

    visual_record = visualization.VisualizationDataRecord(word_attributions=conti_attr,
                                                         pred_prob=pred_prob,
                                                         pred_class=pred_class,
                                                         true_class=true_class,
                                                         attr_class=attr_class,
                                                         attr_score=attr_score,
                                                         raw_input=raw_input,
                                                         convergence_score=convergence_score)
    return visual_record
      
    

In [None]:
# NOTE(review): this cell is still in debug form — every sampled pair is overwritten by the
# hard-coded example below and the loop stops after the first iteration, so only one
# record is ever explained and collected.
for i, (datum_raw, target) in enumerate(zip(stsb_data_raw, targets), start=1):
    example_1 = "Here are many samples: Pbase.com Out of the 16k photos there, I'm sure some are with a d700."
    example_2 = "It's a 1:1 lens, so that means that the size of the subject will be the same size on the sensor."
    # NOTE(review): the pair is example_1 with itself (example_2 is unused) yet the target
    # is 1 — confirm whether [example_1, example_2] was intended.
    datum_raw, target = [example_1, example_1], 1
    print(f'Raw review: {datum_raw}') #datum expected to be a list of 2 sentences
    print(f'GT target: {target}')
    visual_record = generate_record(datum_raw, target)
    # visualize_text renders the table itself; printing its return value only added the
    # returned object's repr underneath the visualization.
    visualization.visualize_text([visual_record])
    break

In [None]:
# Hand the sampled data and everything collected by generate_record to save_info,
# with fname='lime_out.pkl'.
# NOTE(review): while the loop above breaks after one iteration, the four collected lists
# hold a single record but idxs/stsb_data_raw/targets hold the full sample — confirm
# save_info copes with the length mismatch.
save_info(idxs, stsb_data_raw, targets, model_out_list, raw_attr_list, conti_attr_list, raw_input_list, fname='lime_out.pkl')