<a href="https://colab.research.google.com/github/ronakdm/input-marginalization/blob/main/input_marge_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Install both BERT libraries imported by this notebook and clone the project
# repository (its directory is appended to sys.path by the import cell, so it is
# presumably where the `utils` and `models` modules come from).
# %%capture hides all output of this cell, including install or clone errors.
# NOTE(review): package versions are not pinned, so a fresh runtime may resolve
# different releases than the ones this notebook was last run with.
!pip install pytorch_pretrained_bert
!pip install transformers
!git clone https://github.com/ronakdm/input-marginalization.git

In [2]:
%%bash
# Update the existing clone to the latest commit of its tracked branch.
cd input-marginalization
git pull
# The script runs in its own shell, so this final `cd` does not change the
# notebook's working directory.
cd ..

Already up to date.


In [3]:
import sys
sys.path.append("input-marginalization")
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel
from utils import generate_dataloaders
from models import LSTM
from torch.nn import LogSoftmax
import math
import torch.nn.functional as F

In [5]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; force_remount=True remounts even when a
# mount is already present.
drive.mount('/content/gdrive',force_remount=True)
# Drive folder that the classifier checkpoints are read from by the torch.load cell.
save_dir = "/content/gdrive/My Drive/input-marginalization"

Mounted at /content/gdrive


In [6]:
# Number of batches input_marg() pulls from the test dataloader.
SAMPLE_SIZE = 5
# Passed to calculate_woe() as `sigma`: only vocabulary entries whose masked-LM
# probability is strictly greater than this value are substituted into the input.
SIGMA = 1e-4
# NOTE(review): not referenced by any other cell visible in this notebook —
# possibly a leftover.
log_softmax = LogSoftmax(dim=0)

In [7]:
%%capture
# BertTokenizer resolves to the `transformers` class here: the import cell imports
# the name from pytorch_pretrained_bert first and then rebinds it from transformers.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# BertForMaskedLM is imported only from pytorch_pretrained_bert. calculate_woe()
# uses this model to score replacement tokens for every input position.
# NOTE(review): .eval() is not called in this cell — confirm dropout is disabled
# before the model is used for inference.
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

In [8]:
# Load the three pickled classifier checkpoints (BERT, CNN and LSTM, going by the
# file names) from the Drive folder.
# SECURITY: torch.load unpickles arbitrary Python objects — only load checkpoint
# files that come from a trusted source.
# With a GPU present the default device mapping is kept (map_location=None); on a
# CPU-only runtime the tensors are remapped to the CPU so that checkpoints saved
# from a CUDA session still load instead of raising.
_map_location = None if torch.cuda.is_available() else "cpu"
bert = torch.load(f"{save_dir}/bert_sst2.pt", map_location=_map_location)
cnn = torch.load(f"{save_dir}/cnn_sst2.pt", map_location=_map_location)
lstm = torch.load(f"{save_dir}/lstm_sst2.pt", map_location=_map_location)

In [9]:
def loaddata():
  """Return only the test dataloader produced by ``generate_dataloaders(1)``.

  ``generate_dataloaders`` is unpacked into exactly three values (train,
  validation, test); the first two are discarded.
  NOTE(review): the argument ``1`` is presumably the batch size — confirm
  against ``utils.generate_dataloaders``.
  """
  _train_split, _validation_split, test_split = generate_dataloaders(1)
  return test_split

In [10]:
def compute_probability(model, input_ids, attention_masks, label):
    """Return ``exp`` of the first row's logit for ``label`` as a Python float.

    The model is called with ``input_ids``, ``token_type_ids=None``, the given
    attention mask, and ``label`` repeated once per row of ``input_ids``; only
    ``.logits[0][label]`` of the result is used.

    NOTE(review): exponentiating the logit yields a probability only if the
    model emits log-probabilities — confirm for each classifier.
    """
    batch_labels = label.repeat((len(input_ids)))
    output = model(
        input_ids,
        token_type_ids=None,
        attention_mask=attention_masks,
        labels=batch_labels,
    )
    first_row = output.logits[0]
    return math.exp(first_row[label])
    

In [11]:
def compute_probability2(model, input_ids, attention_masks, label):
    """Return ``exp(logit[label])`` for every row of ``input_ids`` as a 1-D tensor.

    ``input_ids`` is cast to int64 before the forward pass, and ``label`` is
    repeated once per row to form the ``labels`` argument. The ``label`` column
    of ``.logits`` is flattened and exponentiated.

    NOTE(review): as with ``compute_probability``, the result is a probability
    only if the model emits log-probabilities — confirm.
    """
    ids_int64 = input_ids.to(torch.int64)
    batch_labels = label.repeat((len(input_ids)))
    output = model(
        ids_int64,
        token_type_ids=None,
        attention_mask=attention_masks,
        labels=batch_labels,
    )
    label_column = output.logits[:, label]
    return torch.exp(label_column.reshape(-1))
    

In [15]:
def calculate_woe(model, input_ids, attention_masks, label, sigma):
  """Compute a weight-of-evidence score for every position of one input sequence.

  For each position ``j`` the global ``bert_model`` provides a probability
  distribution over the vocabulary. Every vocabulary id whose probability is
  strictly greater than ``sigma`` is substituted at position ``j``, the target
  ``model`` is evaluated on all substituted sequences in one batch (via
  ``compute_probability2``), and the resulting probabilities are averaged with
  the language-model probabilities as weights to give the marginal ``m``.
  The score is ``log_odds(p(label | original)) - log_odds(m)``.

  Args:
    model: target classifier; its first parameter decides whether CUDA is used.
    input_ids: token ids of a single sequence, shape ``(1, seq_len)``.
    attention_masks: forwarded unchanged to ``compute_probability2``.
    label: one-element tensor with the class index to explain.
    sigma: probability threshold for candidate replacement tokens.

  Returns:
    list of ``seq_len`` Python floats, one score per position.

  NOTE(review): the language model sees the *unmasked* sequence; position ``j``
  is not replaced by a mask token before predicting it. Confirm whether that is
  intended.
  NOTE(review): ``attention_masks`` keeps its original batch size while the
  substituted batch is larger — this relies on the target model accepting that.
  """
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  bert_model.to(device)
  # Disable dropout in both networks so repeated calls give the same scores.
  bert_model.eval()
  model.eval()
  input_ids = input_ids.to(device)

  # Probabilities are clamped away from 0 and 1 so the log-odds stay finite.
  eps = 1e-12

  def _log_odds(probability):
    p = min(max(float(probability), eps), 1.0 - eps)
    return math.log(p / (1.0 - p))

  #woe is the weight of evidence
  woe = []

  with torch.no_grad():
    #predictions is the probability distribution of each word in the vocabulary for each word in input sentence
    # squeeze(0) removes only the batch axis, so a length-1 sequence keeps its
    # (seq_len, vocab) shape.
    predictions = torch.squeeze(bert_model(input_ids), 0)
    predictions = F.softmax(predictions, dim=-1)

    for j in range(len(predictions)):
      word_scores = predictions[j]
      candidate_ids = torch.nonzero(word_scores > sigma).reshape(-1)

      # Row 0 is the untouched input; rows 1.. each carry one candidate at position j.
      input_batch = input_ids.repeat(candidate_ids.numel() + 1, 1)
      input_batch[1:, j] = candidate_ids.to(input_batch.dtype)

      # Weight 0 for the untouched row keeps it out of the marginal.
      weights = torch.cat((word_scores.new_zeros(1), word_scores[candidate_ids]))

      #probability_input calculates the p(label|sentence) of the target model given each masked input sentence
      probability_input = compute_probability2(model, input_batch, attention_masks, label)

      m = torch.dot(weights, probability_input.to(weights.dtype))
      woe.append(_log_odds(probability_input[0]) - _log_odds(m))
  return woe


In [13]:
def input_marg(model, sample_size=None, max_tokens=19):
  """Compute weight-of-evidence scores for the first few test batches.

  For every batch taken from ``loaddata()`` only the first sequence is used.
  Position 0 is dropped and the next ``max_tokens`` positions are kept (with the
  default this is the ``[1:20]`` window), for both the ids and the mask. The
  window is not filtered further, so separator or padding ids inside it are
  scored like any other token.

  Args:
    model: target classifier; its first parameter decides whether CUDA is used.
    sample_size: number of batches to process; ``None`` uses ``SAMPLE_SIZE``.
      If the dataloader yields fewer batches, processing stops early instead
      of raising ``StopIteration``.
    max_tokens: number of positions kept after position 0.

  Returns:
    list of ``(label, woe, tokens)`` tuples: ``label`` is a one-element tensor,
    ``woe`` the list returned by ``calculate_woe`` and ``tokens`` the output of
    ``tokenizer.convert_ids_to_tokens`` for the same window.
  """
  from itertools import islice

  if sample_size is None:
    sample_size = SAMPLE_SIZE

  test_data = loaddata()
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  window = slice(1, 1 + max_tokens)
  results = []

  for nextsample in islice(test_data, sample_size):
    inputsequences = nextsample[0].to(device)
    inputmask = nextsample[1].to(device)
    labels = nextsample[2].to(device)
    print("")
    print(labels)
    tokens = tokenizer.convert_ids_to_tokens(inputsequences[0][window])
    label = torch.unsqueeze(labels[0], 0)
    woe = calculate_woe(
      model,
      torch.unsqueeze(inputsequences[0][window], 0),
      torch.unsqueeze(inputmask[0][window], 0),
      label,
      SIGMA,
    )
    results.append((label, woe, tokens))
  return results
      

In [17]:
# Score the first SAMPLE_SIZE test batches with the CNN classifier; each entry
# is a (label, woe scores, tokens) tuple.
cnnresults = input_marg(cnn)

6,919 training samples.
  876 validation samples.
1,822 test samples.

tensor([0], device='cuda:0')

tensor([1], device='cuda:0')

tensor([0], device='cuda:0')

tensor([0], device='cuda:0')

tensor([1], device='cuda:0')


In [16]:
# Score the first SAMPLE_SIZE test batches with the LSTM classifier; each entry
# is a (label, woe scores, tokens) tuple.
lstmresults = input_marg(lstm)

6,919 training samples.
  876 validation samples.
1,822 test samples.

tensor([0], device='cuda:0')

tensor([0], device='cuda:0')

tensor([1], device='cuda:0')

tensor([0], device='cuda:0')

tensor([1], device='cuda:0')


In [18]:
#bertresults = input_marg(bert)

In [19]:
# Dump the raw (label, woe scores, tokens) tuples for the CNN, separated by blank lines.
for s in cnnresults:
  print("")
  print(s)


(tensor([0], device='cuda:0'), [3.4610748016280053, 2.38397474852603, 2.3923195258668084, 1.7607884267928884, 0.3674926114160879, 1.5414679254725554, 3.1376562405806236, 0.1259094629247639, 2.3533538663678235, 1.3535274124980452, 2.104138825536487, 2.059403313810178, 0.3496032190855063, 2.0087733035546256, 1.8335877501232627, 2.165256869315688, 2.0007512542773425, 0.3124485696755457, 1.4257422518866631], ['a', 'cum', '##bers', '##ome', 'and', 'cl', '##iche', '-', 'ridden', 'movie', 'grease', '##d', 'with', 'every', 'emotional', 'device', 'known', 'to', 'man'])

(tensor([1], device='cuda:0'), [0.019363717935888314, -0.061479726320976535, -0.00011158547597389656, 0.0010579135350781144, -0.5096246861903642, -0.83539128521819, 0.00045235341203397894, -0.98840695155201, -0.7205186279995546, 0.0007020201621914524, 0.0004937733128724808, -0.0056011366841750565, -0.052475071732195566, -0.06809665905804474, -0.3865993082092616, 0.09922096857926999, -0.0063797004028489646, 0.02325916561745256, 

In [20]:
# Dump the raw (label, woe scores, tokens) tuples for the LSTM, separated by blank lines.
for s in lstmresults:
  print("")
  print(s)


(tensor([0], device='cuda:0'), [0.37156821811137064, 0.017492801567271066, 0.013098964308289385, -0.12523953950215794, 0.6860030198666149, 0.0040648343216771465, 0.01789189489033749, 0.22077135663777514, 0.0012861338516603205, 0.3946944524937588, 0.006771245770374801, 0.7494196675600019, 0.003535269832497856, 0.889286613006903, -0.0007121696756314799, -0.15560619923878383, 0.8239208374883507, -0.08432024024142892, 0.5903470888925955], ['what', "'", 's', 'at', 'stake', 'in', 'this', 'film', 'is', 'nothing', 'more', 'than', 'an', 'obsolete', ',', 'if', 'irritating', ',', 'notion'])

(tensor([0], device='cuda:0'), [0.40896555142323227, -0.5955950919284161, 0.3057715897392918, -0.04506092700470665, 0.001421766975496841, 0.06626859221042797, 0.4055718487843695, -0.05357000897940517, -0.3228402221788481, 0.426460677970432, 0.15528935255175463, 0.04751550712606961, -0.0020060389923851063, 0.6506775834303204, -0.003255061313046259, 0.11698716214179333, 0.8466005020088545, 0.001986620493516278

In [30]:
# LSTM results whose label was printed as tensor([1]) by the input_marg(lstm) run
# (entries 2 and 4). The indices are hard-coded from that output.
# NOTE(review): assumes label 1 means "positive" and that the dataloader order is
# the same on every run — confirm before re-running.
lstmsentences_pos = [
  lstmresults[2], lstmresults[4]
]

In [31]:
# CNN results whose label was printed as tensor([1]) by the input_marg(cnn) run
# (entries 1 and 4). The indices are hard-coded from that output.
# NOTE(review): assumes label 1 means "positive" and that the dataloader order is
# the same on every run — confirm before re-running.
cnnsentences_pos = [
  cnnresults[1], cnnresults[4]
]

In [32]:
# LSTM results whose label was printed as tensor([0]) by the input_marg(lstm) run
# (entries 0 and 1; entry 3 has the same label but is not selected).
# NOTE(review): assumes label 0 means "negative" and that the dataloader order is
# the same on every run — confirm before re-running.
lstmsentences_neg = [
  lstmresults[0], lstmresults[1]
]

In [33]:
# CNN results whose label was printed as tensor([0]) by the input_marg(cnn) run
# (entries 0 and 3; entry 2 has the same label but is not selected).
# NOTE(review): assumes label 0 means "negative" and that the dataloader order is
# the same on every run — confirm before re-running.
cnnsentences_neg = [
  cnnresults[0], cnnresults[3]
]

In [None]:
# Raw example sentences, apparently meant for the BERT run.
# NOTE(review): not referenced by any other cell visible in this notebook.
bertsentences = [
"it's a lovely film with lovely performances by buy and accorsi .",
"more romantic , more emotional and ultimately more satisfying than the teary-eyed original .",
"it's a bit disappointing that it only manages to be decent instead of dead brilliant .",
"suffers from the lack of a compelling or comprehensible narrative ."
]

In [35]:
# LSTM pos
for labels, auclist, tokens in lstmsentences_pos:
  colored_sentence(lstm, tokens, auclist)

 [48;2;255;0;0mit[0m [48;2;255;230;234m'[0m [48;2;255;230;234ms[0m [48;2;225;102;102mend[0m [48;2;225;102;102mearing[0m [48;2;255;230;234mto[0m [48;2;204;229;255mhear[0m [48;2;255;230;234mmadame[0m [48;2;255;230;234md[0m [48;2;255;230;234m.[0m [48;2;255;230;234mrefer[0m [48;2;255;230;234mto[0m [48;2;255;230;234mher[0m [48;2;255;230;234mhusband[0m [48;2;255;230;234mas[0m [48;2;102;178;225m`[0m [48;2;255;230;234mjackie[0m [48;2;255;230;234m'[0m [48;2;255;230;234m-[0m
 [48;2;255;0;0m-[0m [48;2;255;0;0ml[0m [48;2;225;102;102mrb[0m [48;2;255;230;234m-[0m [48;2;255;230;234ma[0m [48;2;255;230;234m-[0m [48;2;225;102;102mrr[0m [48;2;255;0;0mb[0m [48;2;255;230;234m-[0m [48;2;255;0;0mhollywood[0m [48;2;255;0;0msheen[0m [48;2;225;102;102mbed[0m [48;2;255;230;234mev[0m [48;2;225;102;102mils[0m [48;2;255;230;234mthe[0m [48;2;225;102;102mfilm[0m [48;2;255;230;234mfrom[0m [48;2;102;178;225mthe[0m [48;2;255;230;234mvery[0m


In [34]:
# CNN pos
for labels, auclist, tokens in cnnsentences_pos:
  colored_sentence(cnn, tokens, auclist)

 [48;2;255;230;234mit[0m [48;2;204;229;255mhelps[0m [48;2;204;229;255mthat[0m [48;2;255;230;234mthe[0m [48;2;0;0;255mcentral[0m [48;2;0;0;255mperformers[0m [48;2;255;230;234mare[0m [48;2;0;0;255mexperienced[0m [48;2;0;0;255mactors[0m [48;2;255;230;234m,[0m [48;2;255;230;234mand[0m [48;2;204;229;255mthat[0m [48;2;204;229;255mthey[0m [48;2;204;229;255mknow[0m [48;2;0;0;255mtheir[0m [48;2;255;230;234mroles[0m [48;2;204;229;255mso[0m [48;2;255;230;234mwell[0m [48;2;255;230;234m.[0m
 [48;2;225;102;102m.[0m [48;2;255;230;234m.[0m [48;2;255;230;234m.[0m [48;2;255;230;234mthere[0m [48;2;255;230;234mare[0m [48;2;255;0;0menough[0m [48;2;255;0;0mmoments[0m [48;2;255;230;234mof[0m [48;2;255;0;0mheartbreak[0m[48;2;255;230;234ming[0m [48;2;255;0;0mhonesty[0m [48;2;255;230;234mto[0m [48;2;255;230;234mkeep[0m [48;2;255;230;234mone[0m [48;2;255;0;0mglued[0m [48;2;255;230;234mto[0m [48;2;255;230;234mthe[0m [48;2;255;230;234mscreen

In [37]:
# LSTM neg
for labels, auclist, tokens in lstmsentences_neg:
  colored_sentence(lstm, tokens, auclist)

 [48;2;255;204;204mwhat[0m [48;2;255;230;234m'[0m [48;2;255;230;234ms[0m [48;2;102;178;225mat[0m [48;2;225;102;102mstake[0m [48;2;255;230;234min[0m [48;2;255;230;234mthis[0m [48;2;255;230;234mfilm[0m [48;2;255;230;234mis[0m [48;2;255;204;204mnothing[0m [48;2;255;230;234mmore[0m [48;2;225;102;102mthan[0m [48;2;255;230;234man[0m [48;2;225;102;102mobsolete[0m [48;2;204;229;255m,[0m [48;2;102;178;225mif[0m [48;2;225;102;102mirritating[0m [48;2;204;229;255m,[0m [48;2;225;102;102mnotion[0m
 [48;2;255;204;204mfocuses[0m [48;2;0;0;255mon[0m [48;2;255;204;204mjoan[0m [48;2;204;229;255m'[0m [48;2;255;230;234ms[0m [48;2;255;230;234mraging[0m [48;2;255;204;204mhormones[0m [48;2;204;229;255mand[0m [48;2;0;0;255msl[0m [48;2;255;204;204medge[0m [48;2;255;230;234mhammer[0m [48;2;255;230;234ms[0m [48;2;204;229;255mthe[0m [48;2;225;102;102maudience[0m [48;2;204;229;255mwith[0m [48;2;255;230;234mspanish[0m [48;2;225;102;102minquisitio

In [36]:
# CNN neg
for labels, auclist, tokens in cnnsentences_neg:
  colored_sentence(cnn, tokens, auclist)

 [48;2;255;0;0ma[0m [48;2;255;0;0mcum[0m[48;2;255;0;0mbers[0m[48;2;255;0;0mome[0m [48;2;255;204;204mand[0m [48;2;255;0;0mcl[0m[48;2;255;0;0miche[0m [48;2;255;230;234m-[0m [48;2;255;0;0mridden[0m [48;2;255;0;0mmovie[0m [48;2;255;0;0mgrease[0m[48;2;255;0;0md[0m [48;2;255;204;204mwith[0m [48;2;255;0;0mevery[0m [48;2;255;0;0memotional[0m [48;2;255;0;0mdevice[0m [48;2;255;0;0mknown[0m [48;2;255;204;204mto[0m [48;2;255;0;0mman[0m
 [48;2;255;0;0mthe[0m [48;2;255;230;234mkids[0m [48;2;255;230;234moften[0m [48;2;102;178;225mappear[0m [48;2;255;230;234mto[0m [48;2;0;0;255mbe[0m [48;2;102;178;225mreading[0m [48;2;255;230;234mthe[0m [48;2;204;229;255mlines[0m [48;2;204;229;255mand[0m [48;2;255;230;234mare[0m [48;2;102;178;225mincapable[0m [48;2;255;230;234mof[0m [48;2;255;230;234mconvey[0m[48;2;0;0;255ming[0m [48;2;255;230;234many[0m [48;2;255;230;234memotion[0m [48;2;255;230;234m.[0m [48;2;255;230;234m[SEP][0m


In [25]:
def colored_sentence(model, tokenized_sentence, auclist):
    """Print a tokenised sentence with an ANSI background colour per token.

    Each token is coloured by the first threshold its score strictly exceeds,
    from dark red (score > 1) down to dark blue (score <= -0.2). Tokens that
    start with ``##`` are treated as word-piece continuations: the prefix is
    removed and the token is glued to the previous one; every other token is
    preceded by a single space (including the first one).

    The input list is left unmodified, so rendering the same tokens again
    produces the same output.

    Args:
        model: unused; kept so existing calls keep working.
        tokenized_sentence: sequence of token strings.
        auclist: per-token scores, indexed in parallel with the tokens.

    Returns:
        None. The coloured sentence is written with ``print``.
    """
    # (exclusive lower bound, RGB background), checked from top to bottom.
    # NOTE(review): the -0.05 and -0.1 bands share one colour, as in the
    # original palette — possibly a copy-paste; confirm the intended shade.
    palette = [
        (1, (255, 0, 0)),
        (0.5, (225, 102, 102)),
        (0.3, (255, 204, 204)),
        (0, (255, 230, 234)),
        (-0.05, (204, 229, 255)),
        (-0.1, (204, 229, 255)),
        (-0.2, (102, 178, 225)),
    ]
    # Used when the score exceeds none of the bounds.
    strongest_blue = (0, 0, 255)

    pieces = []
    for i, token in enumerate(tokenized_sentence):
        is_continuation = token.startswith("##")
        text = token[2:] if is_continuation else token
        score = auclist[i]

        rgb = next(
            (colour for bound, colour in palette if score > bound),
            strongest_blue,
        )
        cell = "\033[48;2;{};{};{}m{}\033[0m".format(rgb[0], rgb[1], rgb[2], text)
        pieces.append(cell if is_continuation else " " + cell)

    print("".join(pieces))