In [1]:
import numpy as np
from transformers import AutoTokenizer
import torch
import json
from model.longformer_tfidf import LongformerTFIDFForSequenceClassification
from data_collator.data_collator_tfidf import DataCollatorTFIDF

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "models/longformer-tfidf/longformer-large-seed100"
max_length = 4096
truncation_side = 'left'

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=max_length)

device = torch.device('cuda:0')
tokenizer = AutoTokenizer.from_pretrained(model_path, truncation_side=truncation_side)
model = LongformerTFIDFForSequenceClassification.from_pretrained(model_path, 
                                                            num_labels=2).to(device)
model.eval()

EnsembleLongformerForSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 1024, padding_idx=1)
      (position_embeddings): Embedding(4098, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0): LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (query_global): Linear(in_features=1024, out_features=1024, bias=True)
              (key_global): Linear(in_features=1024, out_features=1024, bias=Tr

In [3]:
# Load ILDC expert
import glob
import re
files = sorted(glob.glob('data/ILDC/ILDC_expert/source/*.txt'))
len(files)

ori_texts = [open(file).read() for file in files]
texts = [' '.join(text.split('\n')) for text in ori_texts]
texts = [re.sub('\s+', ' ', text).strip() for text in texts]
texts = [text.lower() for text in texts]

In [24]:
from tqdm import tqdm
import numpy as np

import pickle as pkl
vectorizer = pkl.load(open('tfidf_vectorizer-threshold350.pkl', 'rb'))

data_collator = DataCollatorTFIDF(tokenizer=tokenizer)

def process_input(texts):
    features = []
    for text in texts:
        feature = tokenizer(text, truncation=True, max_length=max_length)
        tfidf_vector = np.array(vectorizer.transform([text]).todense())[0]
        feature['tfidf_feature'] = tfidf_vector
        features.append(feature)
    batch = data_collator(features)
    return batch

bs = 2

predicts = []

for i in tqdm(range(0, len(texts), bs)):
    batch = texts[i:i+bs]
    batch = process_input(batch)

    for k in batch:
        batch[k] = batch[k].to(device)

    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits.cpu().numpy()
        predict = logits.argmax(axis=1)
        predicts.extend(predict)
        
print(predicts)

100%|██████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:07<00:00,  3.54it/s]

[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1]





In [25]:
import scipy as sp

def predict_a_text(texts):
    batch = process_input(texts)
    for k in batch:
        batch[k] = batch[k].to(device)
    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits.cpu().numpy()
        
    scores = (np.exp(logits).T / np.exp(logits).sum(-1)).T
    val = sp.special.logit(scores)
    return val
    

import shap
import re

from nltk.tokenize.punkt import PunktSentenceTokenizer

def custom_tokenizer_nltk(text, return_offsets_mapping=True):
    sentences = []
    offset_mapping = []
    
    for start, end in PunktSentenceTokenizer().span_tokenize(text):
        length = end - start
        sentences.append(text[start:end])
        offset_mapping.append((start, end))
        
    out = {
        "input_ids": sentences,
        "offset_mapping": offset_mapping
    }
    
    return out

def get_end_of_text(text):
    text = re.sub(" decision \?\?$", "", text)
    last_tokens = tokenizer.encode(text, padding=True, max_length=4096, truncation=True)
    end_of_text = tokenizer.decode(last_tokens)[3:-4] # remove <s> </s>
    return end_of_text

masker = shap.maskers.Text(custom_tokenizer_nltk)
labels = ['denied', 'accepted']
explainer = shap.Explainer(predict_a_text, 
                           masker, 
                           output_names=labels, 
                           algorithm="permutation")

In [26]:
end_of_texts = [get_end_of_text(text) for text in texts]
shap_values = [explainer([end_of_text], batch_size=16) for end_of_text in end_of_texts]

Permutation explainer: 2it [00:30, 30.45s/it]                                                                              
Permutation explainer: 2it [00:22, 22.13s/it]                                                                              
Permutation explainer: 2it [00:37, 37.80s/it]                                                                              
Permutation explainer: 2it [00:22, 22.61s/it]                                                                              
Permutation explainer: 2it [00:21, 21.47s/it]                                                                              
Permutation explainer: 2it [00:39, 39.04s/it]                                                                              
Permutation explainer: 2it [00:24, 24.60s/it]                                                                              
Permutation explainer: 2it [00:26, 26.10s/it]                                                                              
Permutat

In [27]:
import pickle as pkl
pkl.dump(shap_values, open('shap_values_ildc-longformer-tfidf.pkl', 'wb'))

In [35]:
import re

def get_topk_sentences(shap_value, predict, sentences, k=10):
    sorted_inds = np.argsort(shap_value.values[:, predict])[::-1]
    sorted_inds = [x for x in sorted_inds if shap_value.values[x, predict] > 0]
    topk_sentences = np.array(sentences)[sorted_inds][:k]
    return topk_sentences

def get_end_of_text(text):
#     text = re.sub(" DECISION \?\?$", "", text)
    last_tokens = tokenizer.encode(text, padding=True, max_length=4096, truncation=True)
    end_of_text = tokenizer.decode(last_tokens)[3:-4] # remove <s> </s>
    return end_of_text

texts_upper = [' '.join(text.split('\n')) for text in ori_texts]
texts_upper = [re.sub('\s+', ' ', text).strip() for text in texts_upper]
texts_upper = [get_end_of_text(text) for text in texts_upper]
explanations = []

for shap_value, predict, text_upper in zip(shap_values, predicts, texts_upper):
    shap_value = shap_value[0]
    sentences = []

    for start, end in PunktSentenceTokenizer().span_tokenize(text_upper):
        length = end - start
        sentences.append(text_upper[start:end])
    
#     print(len(sentences), shap_value.data.shape[0])
#     sentences = shap_value.data
    
    k = int(len(sentences) * 0.4)
    # k = 10

    top_sentences = get_topk_sentences(shap_value, predict, sentences, k=k)
    top_sentences = [re.search(re.escape(s.strip()), text_upper, re.IGNORECASE).group() for s in top_sentences]
    explain = ' '.join(top_sentences).strip()
    explain = re.sub('\s+', ' ', explain)
    explanations.append(explain)

In [37]:
occ_exp = {}
for explain, file in zip(explanations, files):
    filename = file.split('/')[-1]
    occ_exp[filename] = explain

In [38]:
import nltk
from nltk.tokenize import word_tokenize 
from rouge import Rouge 
import nltk.translate
from tqdm import tqdm
import numpy as np

# In[77]:


def get_BLEU_score(ref_text, machine_text):
    tok_ref_text = word_tokenize(ref_text)
    tok_machine_text = word_tokenize(machine_text)
    sc = nltk.translate.bleu_score.sentence_bleu([tok_ref_text], tok_machine_text, weights = (0.5,0.5))
    return sc

def jaccard_similarity(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    union = set(query).union(set(document))
    if(len(union)==0):
        return 0
    return len(intersection)/len(union)

def overlap_coefficient_min(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    den = min(len(set(query)),len(set(document)))
    if(den==0):
        return 0
    return len(intersection)/den

def overlap_coefficient_max(query, document):
    query = word_tokenize(query)
    document = word_tokenize(document)
    intersection = set(query).intersection(set(document))
    den = max(len(set(query)),len(set(document)))
    if(den==0):
        return 0
    return len(intersection)/den

def occ_result_maker(Rank_initial, Rank_final, occ_exp, gold_exp):
    rouge1 = []
    rouge2 = []
    rougel = []
    jaccard = []
    bleu = []
    meteor = []
    overlap_min = []
    overlap_max = []
    
    files = list(gold_exp.keys())
    
    for u in range(5):
        user = "User " + str(u+1)
        r1 = []
        r2 = []
        rl = []
        jacc = []
        bl = []
        met = []
        omin = []
        omax = []
        
        for i in tqdm(range(len(files))):
            f = files[i]
            ref_text = ""
            for rank in range(Rank_initial, Rank_final+1, 1):
                if(gold_exp[f][user]["exp"]["Rank" + str(rank)]!=""):
                    ref_text += gold_exp[f][user]["exp"]["Rank" + str(rank)] + " "
                
            machine_text = occ_exp[f]
            machine_text = machine_text.lower()
            ref_text = ref_text.lower()
            
            if(ref_text == ""):
                continue
            rouge = Rouge()
            score = rouge.get_scores(machine_text, ref_text)
            r1.append(score[0]['rouge-1']['f'])
            r2.append(score[0]['rouge-2']['f'])
            rl.append(score[0]['rouge-l']['f'])
            jacc.append(jaccard_similarity(ref_text, machine_text))
            omin.append(overlap_coefficient_min(ref_text, machine_text))
            omax.append(overlap_coefficient_max(ref_text, machine_text))
            bl.append(get_BLEU_score(ref_text, machine_text))
            
#             print('===', ref_text)
#             print('===', machine_text)
            met.append(nltk.translate.meteor_score.meteor_score([ref_text.split()], machine_text.split()))
            
        rouge1.append(np.mean(r1))
        rouge2.append(np.mean(r2))
        rougel.append(np.mean(rl))
        jaccard.append(np.mean(jacc))
        overlap_min.append(np.mean(omin))
        overlap_max.append(np.mean(omax))
        bleu.append(np.mean(bl))
        meteor.append(np.mean(met))
        
    print("ROUGE-1 : {:}".format(rouge1) + "\n\n")
    print("ROUGE-2 : {:}".format(rouge2) + "\n\n")
    print("ROUGE-L : {:}".format(rougel)+ "\n\n")
    print("Jaccard : {:}".format(jaccard)+ "\n\n")
    print("Overmin : {:}".format(overlap_min)+ "\n\n")
    print("Overmax : {:}".format(overlap_max)+ "\n\n")
    print("BLEU    : {:}".format(bleu)+ "\n\n")
    print("METEOR  : {:}".format(meteor)+ "\n\n") 
            

In [39]:
import json

gold_exp = json.load(open('data/ILDC/gold_explanations_ranked.json'))

occ_result_maker(1, 10, occ_exp, gold_exp)

100%|██████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:19<00:00,  2.83it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:12<00:00,  4.64it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:25<00:00,  2.18it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:27<00:00,  2.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:14<00:00,  3.85it/s]

ROUGE-1 : [0.4217008427892159, 0.4235414892922794, 0.3925100223482196, 0.40051477922754647, 0.4008767238462048]


ROUGE-2 : [0.2734233987140901, 0.2511180389544381, 0.2622810186645104, 0.2704161595139528, 0.2388428937439398]


ROUGE-L : [0.40684062920739195, 0.3889178917862468, 0.38244524692006365, 0.39546020438287777, 0.37403833251896657]


Jaccard : [0.2895249496197833, 0.29340945890742465, 0.2644349635930708, 0.25950958198140095, 0.2710472908625645]


Overmin : [0.827674193497443, 0.661714637551502, 0.8929833173074578, 0.9221186760643102, 0.6821402184935772]


Overmax : [0.30958354751446987, 0.35279167905110903, 0.2736727432506641, 0.26637414048187236, 0.32089326786049976]


BLEU    : [0.051088720556393595, 0.15902416000622038, 0.028584337002642474, 0.028340908556390126, 0.12687817054964767]


METEOR  : [0.13906619140502868, 0.20915424453880904, 0.11173817812724796, 0.10976786488200796, 0.1886860583743743]





