<a href="https://colab.research.google.com/github/ronakdm/input-marginalization/blob/main/snli_input_marge_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install pytorch_pretrained_bert
!pip install transformers
!rm -rf input-marginalization
!git clone https://github.com/ronakdm/input-marginalization.git

In [2]:
%%bash
# Refresh the project checkout without changing directory. Because `git pull`
# is now the script's last command, its exit status is the script's exit
# status; previously the trailing `cd ..` always ended the script with 0 and
# hid a failed pull.
git -C input-marginalization pull

Already up to date.


In [3]:
import sys
sys.path.append("input-marginalization")
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel
from utils import generate_snli_dataloader, SNLIDataset
from models import LSTM
from torch.nn import LogSoftmax
import math
import torch.nn.functional as F
from torch.utils.data import DataLoader
from metrics import continuous_colored_sentence

In [5]:
# Mount Google Drive inside the Colab runtime (remounting if it is already mounted).
from google.colab import drive
drive.mount('/content/gdrive',force_remount=True)
# Drive folder that holds the trained checkpoint; the model-loading cell reads
# f"{save_dir}/lstm_snli.pt" from here.
save_dir = "/content/gdrive/My Drive/NLP/imarg"

Mounted at /content/gdrive


In [15]:
# Number of test examples that input_marg draws and visualises.
SAMPLE_SIZE = 15
# Likelihood threshold passed to calculate_woe: only vocabulary tokens whose
# masked-LM probability exceeds this value are used as replacements.
SIGMA = 1e-4
# Log-softmax over dimension 0. Not referenced by any cell visible in this notebook.
log_softmax = LogSoftmax(dim=0)

In [7]:
%%capture
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

In [8]:
# Load the trained SNLI classifier that will be explained.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a source you trust.
# With CUDA available the tensors keep the device they were saved from (same
# behaviour as before); on a CPU-only runtime they are remapped to the CPU so a
# checkpoint saved from a GPU session still loads.
model = torch.load(
    f"{save_dir}/lstm_snli.pt",
    map_location=None if torch.cuda.is_available() else "cpu",
)

In [25]:
def loaddata():
  """Return a shuffled DataLoader (batch size 1) over the SNLI test split.

  The split is read from the pickled file inside the cloned repository, and
  the dataset's ``le.classes_`` attribute is printed as a side effect.
  """
  snli_test = SNLIDataset('input-marginalization/preprocessed_data/SNLI/snli_1.0/snli_test_string.pkl')
  print(snli_test.le.classes_)
  return DataLoader(snli_test, shuffle=True, batch_size=1)

In [10]:
def compute_probability(model, sentences, toktype_mask, label):
    """Return the model's probability of ``label`` for every row of ``sentences``.

    Args:
        model: callable invoked as ``model((s1, s2), labels=None)`` that
            returns logits with one row per input and classes along dim 1.
        sentences: 2-D tensor of token ids, one sequence per row.
        toktype_mask: 1-D tensor over sequence positions; positions equal to 0
            form the first segment and positions equal to 1 form the second.
        label: class index (int or one-element tensor) whose probability is
            extracted.

    Returns:
        1-D tensor holding the softmax probability of ``label`` for each row.
    """
    # Column-select each segment; this matches transposing, row-masking and
    # transposing back.
    first_segment = sentences[:, toktype_mask == 0]
    second_segment = sentences[:, toktype_mask == 1]

    class_probs = F.softmax(model((first_segment, second_segment), labels=None), dim=1)
    return class_probs[:, label].reshape(-1)

In [18]:
def _log_odds(probability, eps=1e-7):
  """Return log(p / (1 - p)), clamping p to [eps, 1 - eps] so it stays finite."""
  p = min(max(float(probability), eps), 1.0 - eps)
  return math.log(p / (1.0 - p))


def calculate_woe(model, sentences, label, sigma):
  """Compute a per-token weight of evidence (WoE) via input marginalization.

  For every position of the tokenized sentence pair, the token is replaced by
  each vocabulary token whose masked-LM probability exceeds ``sigma``. The
  target model's probability of ``label`` is averaged over those replacements,
  weighted by the masked-LM probabilities, and the WoE is the log-odds of the
  original input minus the log-odds of that marginal.

  Args:
    model: target classifier; it is put in eval mode and queried through
      ``compute_probability``.
    sentences: pair whose items 0 and 1 are handed to the tokenizer as the
      first and second text.
    label: class index (int or one-element tensor) being explained.
    sigma: probability threshold for candidate replacement tokens.

  Returns:
    ``((ids_0, woe_0), (ids_1, woe_1))``. The first pair covers the positions
    with token type 0. The second pair covers the positions with token type 1,
    prefixed with the last entry of the first segment.
  """
  # Use the device the model actually lives on (also correct for non-default
  # GPUs), and run the language model there too.
  device = next(model.parameters()).device
  bert_model.to(device)
  # The language model must be in eval mode, otherwise dropout randomly
  # perturbs the token probabilities used for marginalization.
  bert_model.eval()
  model.eval()

  print(sentences[0], sentences[1])
  tok = tokenizer(sentences[0], sentences[1], return_tensors="pt")
  input_ids = tok['input_ids'].to(device)
  toktype = tok['token_type_ids'].to(device)[0]

  #woe is the weight of evidence
  woe = []

  with torch.no_grad():
    #predictions is the probability distribution of each word in the vocabulary for each word in input sentence
    mlm_scores = bert_model(input_ids, token_type_ids=toktype.unsqueeze(0))
    if not torch.is_tensor(mlm_scores):
      # Tolerate implementations that wrap the scores in a tuple-like output
      # whose first element holds them.
      mlm_scores = mlm_scores[0]
    predictions = F.softmax(mlm_scores.squeeze(0), dim=1)

    for j in range(len(predictions)):
      word_scores = predictions[j]

      # Vocabulary ids whose likelihood at position j exceeds sigma, in
      # ascending id order.
      candidate_ids = torch.nonzero(word_scores > sigma).flatten()

      # Row 0 is the unmodified input; every following row has position j
      # replaced by one candidate token.
      input_batch = input_ids.repeat(candidate_ids.numel() + 1, 1)
      input_batch[1:, j] = candidate_ids

      #word_scores_batch holds the MLM probability of each substituted token
      #we put 0 for the first input which is unmasked
      word_scores_batch = torch.cat((word_scores.new_zeros(1), word_scores[candidate_ids]))

      #probability_input calculates the p(label|sentence) of the target model given each masked input sentence
      probability_input = compute_probability(model, input_batch, toktype, label)
      m = torch.dot(word_scores_batch, probability_input)

      # Clamped log-odds: a probability that saturates to exactly 0 or 1 would
      # otherwise give a math domain error or an infinite score.
      woe.append(_log_odds(probability_input[0]) - _log_odds(m))

  woe = torch.tensor(woe, device=device)

  ids = input_ids[0]
  first_segment = toktype == 0
  second_segment = toktype == 1
  first_ids, first_woe = ids[first_segment], woe[first_segment]
  # [-1:] keeps a length-1 tensor, so the last entry of the first segment can
  # be concatenated directly without rebuilding a tensor from a tensor.
  second_ids = torch.cat([first_ids[-1:], ids[second_segment]])
  second_woe = torch.cat([first_woe[-1:], woe[second_segment]])
  return (first_ids, first_woe), (second_ids, second_woe)


In [23]:
def input_marg(model, sample_size=None, sigma=None):
  """Print colour-coded weight-of-evidence attributions for SNLI test pairs.

  Args:
    model: target classifier forwarded to ``calculate_woe``.
    sample_size: number of examples to process; defaults to ``SAMPLE_SIZE``.
    sigma: candidate-token probability threshold; defaults to ``SIGMA``.

  For each example this prints the sentence pair (inside ``calculate_woe``),
  the tokens of the second segment, and the coloured rendering of both
  segments.
  """
  # Resolve defaults at call time so later edits to the config cell take effect.
  if sample_size is None:
    sample_size = SAMPLE_SIZE
  if sigma is None:
    sigma = SIGMA

  test_data = loaddata()

  # zip() stops as soon as either side is exhausted, so a dataset with fewer
  # than sample_size examples ends the loop cleanly instead of raising
  # StopIteration; range comes first so no extra batch is fetched at the end.
  for _, (sentences, labels) in zip(range(sample_size), test_data):
    print("")
    (s1, a1), (s2, a2) = calculate_woe(model, sentences, labels, sigma)
    print(tokenizer.convert_ids_to_tokens(s2))
    print('pre: ', continuous_colored_sentence(s1.unsqueeze(0), a1.unsqueeze(0),pretok=True, verbose=False))
    print('hypo: ', continuous_colored_sentence(s2.unsqueeze(0), a2.unsqueeze(0),pretok=True, verbose=False))

In [24]:
# Run the attribution demo on the LSTM checkpoint loaded earlier in the notebook.
input_marg(model)


('Shirtless man with long pole navigates covered boat down a palm tree lined river past a hut.',) ('A man without a shirt is on a river.',)




['[CLS]', 'shirt', '##less', 'man', 'with', 'long', 'pole', 'navigate', '##s', 'covered', 'boat', 'down', 'a', 'palm', 'tree', 'lined', 'river', 'past', 'a', 'hut', '.', '[SEP]']
tensor([ 1.5137e-01,  1.8670e-01,  8.4433e-02,  5.0246e-04,  6.8416e-04,
        -4.3368e-04,  1.6930e-02,  3.1362e-02,  9.1456e-03,  2.2176e-02,
         2.4020e-03,  3.9358e-04,  8.6103e-05,  3.6162e-03,  1.1089e-03,
         1.0514e-02,  1.8539e-03,  1.2440e-02,  1.2084e-04,  1.1828e-03,
         9.9907e-05,  3.0098e-02], device='cuda:0')
['[SEP]', 'a', 'man', 'without', 'a', 'shirt', 'is', 'on', 'a', 'river', '.', '[SEP]']
tensor([ 3.0098e-02,  1.3502e-03,  2.4498e-04,  3.3729e-03,  3.6641e-04,
         8.7076e-04,  2.4917e-03,  2.3083e-04,  5.7137e-05,  2.4577e-03,
         2.4633e-04, -1.2929e-01], device='cuda:0')
tensor([2])
pre:   [48;2;255;48;48mshirt[0m[48;2;255;0;0mless[0m [48;2;255;139;139mman[0m [48;2;255;254;254mwith[0m [48;2;255;254;254mlong[0m [48;2;0;0;255mpole[0m [48;2;255;231;2