<a href="https://colab.research.google.com/github/ronakdm/input-marginalization/blob/main/Copy_of_input_marge_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install pytorch_pretrained_bert
!pip install transformers
!git clone https://github.com/ronakdm/input-marginalization.git

In [None]:
%%bash
# Bring the cloned project up to date. `git -C` runs the pull inside the checkout
# without changing directory, so a missing checkout makes the command fail instead
# of running `git pull` in whatever the current directory happens to be.
git -C input-marginalization pull

Already up to date.


In [None]:
import sys
sys.path.append("input-marginalization")
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel
from utils import generate_dataloaders
from models import LSTM
from torch.nn import LogSoftmax
import math
import torch.nn.functional as F

In [None]:
# Mount Google Drive into the Colab VM (force_remount=True requests a fresh mount).
from google.colab import drive
drive.mount('/content/gdrive',force_remount=True)
# Drive folder from which the `cnn_sst2.pt` checkpoint is loaded later in the notebook.
save_dir = "/content/gdrive/My Drive/input-marginalization"

Mounted at /content/gdrive


In [None]:
# Number of test examples that input_marg() explains and prints.
SAMPLE_SIZE = 3
# Threshold handed to calculate_woe(): only replacement tokens whose MLM
# probability is above this value are used in the marginalisation.
SIGMA = 1e-4
# Log-softmax over dim 0. NOTE(review): not referenced by any cell visible in this notebook.
log_softmax = LogSoftmax(dim=0)

In [None]:
%%capture
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

In [None]:
# Load the pickled classifier (presumably the SST-2 CNN, going by the file name).
# NOTE(security): torch.load unpickles arbitrary objects -- only load checkpoints you trust.
# map_location keeps the default behaviour when a GPU is available; on a CPU-only
# runtime it remaps CUDA-saved tensors to the CPU instead of failing to deserialize.
cnn = torch.load(
    f"{save_dir}/cnn_sst2.pt",
    map_location=None if torch.cuda.is_available() else "cpu",
)

In [None]:
def loaddata():
  """Return the test dataloader produced by ``generate_dataloaders(1)``.

  ``generate_dataloaders`` yields three loaders (train, validation, test);
  the first two are discarded here.
  """
  _train_loader, _validation_loader, test_loader = generate_dataloaders(1)
  return test_loader

In [None]:
def compute_probability2(model, input_ids, attention_masks, label):
    """Return the model's probability of ``label`` for every row of ``input_ids``.

    Args:
        model: callable accepting ``(input_ids, attention_mask=..., labels=...)``
            and returning an object with a ``logits`` attribute.
        input_ids: batch of token ids; cast to int64 before the forward pass.
        attention_masks: forwarded to the model unchanged.
        label: one-element tensor holding the class index; it is repeated so the
            model receives one label per row of ``input_ids``.

    Returns:
        1-D tensor with one probability per input row (softmax over dim 1 of
        the logits, column ``label``).
    """
    batch_size = len(input_ids)
    outputs = model(
        input_ids.to(torch.int64),
        attention_mask=attention_masks,
        labels=label.repeat(batch_size),
    )
    class_probs = F.softmax(outputs.logits, dim=1)
    # Indexing with the one-element label tensor keeps a trailing axis; flatten it.
    return class_probs[:, label].reshape(-1)
    
    

In [None]:
def _log_odds(p, eps=1e-12):
  """Natural-log odds of probability ``p``, clamped to ``[eps, 1 - eps]`` so that a
  saturated probability (exactly 0 or 1) yields a finite value instead of inf/NaN
  or a math domain error."""
  p = min(max(float(p), eps), 1.0 - eps)
  return math.log(p / (1.0 - p))


def calculate_woe(model, input_ids, attention_masks, label, sigma):
  """Compute a weight-of-evidence score for every position of one sentence.

  For position ``j`` every vocabulary token whose MLM probability exceeds
  ``sigma`` is substituted into the sentence, the target ``model`` is queried on
  all substituted sentences, and the label probabilities are averaged with the
  MLM probabilities as weights to obtain the marginal ``m``. The score is
  ``logodds(p(label | original)) - logodds(m)`` (natural logarithm).

  Args:
    model: target classifier; must be usable with ``compute_probability2``.
    input_ids: token ids of a single sentence, shape ``(1, seq_len)``.
    attention_masks: forwarded unchanged to ``compute_probability2``.
    label: one-element tensor with the class index to explain.
    sigma: MLM probability threshold for candidate replacement tokens.

  Returns:
    List of ``seq_len`` Python floats, one score per position.

  Side effects:
    Moves the global ``bert_model`` to the device of ``model`` and puts both
    models into eval mode.
  """
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  bert_model.to(device)
  # Eval mode disables dropout; without it the MLM distribution changes per call.
  bert_model.eval()
  model.eval()
  # Keep the ids on the same device as both models for the forward passes below.
  input_ids = input_ids.to(device)

  #woe is the weight of evidence
  woe = []

  with torch.no_grad():
    # predictions[j] is the MLM probability distribution over the vocabulary at position j.
    # NOTE(review): the MLM sees the sentence with position j still in place (not
    # replaced by [MASK]) -- confirm this is the intended likelihood model.
    mlm_output = bert_model(input_ids)
    if not torch.is_tensor(mlm_output):
      # Also accept wrappers that return an output object or a tuple instead of a tensor.
      mlm_output = mlm_output.logits if hasattr(mlm_output, "logits") else mlm_output[0]
    # squeeze(0) drops only the batch axis, so one-token sentences keep shape (1, vocab).
    predictions = F.softmax(mlm_output.squeeze(0), dim=-1)

    for j in range(len(predictions)):
      word_scores = predictions[j]
      # Vocabulary ids that are likely enough to be worth substituting at position j.
      candidate_ids = torch.nonzero(word_scores > sigma).flatten()

      # Row 0 is the unmodified sentence; row i >= 1 has candidate i-1 at position j.
      input_batch = input_ids.repeat(len(candidate_ids) + 1, 1)
      input_batch[1:, j] = candidate_ids

      # Weight 0 for the unmodified sentence so it does not enter the marginal.
      weights = torch.cat((word_scores.new_zeros(1), word_scores[candidate_ids]))

      #probability_input holds p(label|sentence) of the target model for every row of the batch
      probability_input = compute_probability2(model, input_batch, attention_masks, label)
      m = torch.dot(weights, probability_input)
      woe.append(_log_odds(probability_input[0]) - _log_odds(m))
  return woe


In [None]:
def input_marg(model, max_tokens=10):
  """Print tokens and weight-of-evidence scores for the first test examples.

  Up to ``SAMPLE_SIZE`` batches are taken from the test dataloader. For the
  first sequence of each batch, the ``max_tokens`` tokens following position 0
  are printed together with their ``calculate_woe`` scores.

  Args:
    model: target classifier to explain; its device decides where tensors go.
    max_tokens: number of tokens after position 0 to explain (default 10, the
      window that was previously hard-coded).

  Returns:
    None. Results are printed.
  """
  test_data = loaddata()
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  # Position 0 is skipped (presumably the [CLS] token -- confirm against the dataloader).
  # NOTE(review): the window can include [SEP]/[PAD] positions for short sentences.
  window = slice(1, 1 + max_tokens)

  # zip() stops cleanly when the loader holds fewer than SAMPLE_SIZE batches.
  for _, nextsample in zip(range(SAMPLE_SIZE), test_data):
    inputsequences = nextsample[0].to(device)
    inputmask = nextsample[1].to(device)
    labels = nextsample[2].to(device)

    token_ids = inputsequences[0][window]
    # Slice the mask with the same window as the ids so both stay aligned.
    token_mask = inputmask[0][window]

    print("")
    print(tokenizer.convert_ids_to_tokens(token_ids))
    print(calculate_woe(model, torch.unsqueeze(token_ids, 0), torch.unsqueeze(token_mask, 0), torch.unsqueeze(labels[0], 0), SIGMA))
      

In [None]:
# Explain the loaded `cnn` model: prints tokens and weight-of-evidence scores for SAMPLE_SIZE test examples.
input_marg(cnn)

6,919 training samples.
  876 validation samples.
1,822 test samples.

['glorious', '##ly', 'goofy', '-', 'l', '##rb', '-', 'and', 'go', '##ry']
[0.6360708526641077, 0.5205873643734907, -0.05158250303300005, -0.18309880637190987, 0.2416764006098786, 0.22148587930136387, -0.7405489300327377, 0.00102917462197466, -0.011354074059133856, 0.03784839513373134]

['sorry', ',', 'charlie', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
[1.297631683081951, 0.12134430704954913, -0.7411654629968591, -0.5226832001363202, 0.2521776600292692, 0.31093861638171205, 0.6917379951676591, 0.870505358450974, 0.9866025639505712, 0.6954753759481772]

['zhang', 'yi', '##mo', '##u', 'delivers', 'warm', ',', 'genuine', 'characters', 'who']
[0.4395113611917867, 0.607800883227863, 2.022660577388778, 1.755081696906517, 1.4283542730537588, 1.1987196021466933, 0.20366822120572348, 0.8403306200253711, 0.8149360267820158, 0.7691622882304443]
