<a href="https://colab.research.google.com/github/ronakdm/input-marginalization/blob/main/input_marge_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install pytorch_pretrained_bert
!pip install transformers
!git clone https://github.com/ronakdm/input-marginalization.git

In [2]:
%%bash
cd input-marginalization
git pull
cd ..

Already up to date.


In [3]:
import sys
sys.path.append("input-marginalization")
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
from transformers import BertTokenizer, BertModel
from utils import generate_dataloaders
from models import LSTM
from torch.nn import LogSoftmax
import math
import torch.nn.functional as F

In [4]:
from google.colab import drive
drive.mount('/content/gdrive',force_remount=True)
save_dir = "/content/gdrive/My Drive/input-marginalization"

Mounted at /content/gdrive


In [5]:
SAMPLE_SIZE = 3
SIGMA = 1e-4
log_softmax = LogSoftmax(dim=0)

In [6]:
%%capture
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

In [7]:
cnn = torch.load(f"{save_dir}/cnn_sst2.pt")

In [8]:
def loaddata ():
  train_dataloader, validation_dataloader, test_dataloader = generate_dataloaders(1)
  return test_dataloader

In [9]:
def compute_probability(model, input_ids, attention_masks, label):
    logits = model(
        input_ids, token_type_ids=None, attention_mask=attention_masks, labels=label.repeat((len(input_ids))),
    ).logits
    
    return math.exp(logits[0][label])
    

In [10]:
def compute_probability2(model, input_ids, attention_masks, label):
  
    logits = model(
        input_ids.to(torch.int64), token_type_ids=None, attention_mask=attention_masks, labels=label.repeat((len(input_ids))),
    ).logits

    return torch.exp(torch.reshape(logits[:, label], (-1,)))
    

In [27]:
def calculate_woe(model, input_ids, attention_masks, label, sigma):
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  bert_model.to(device)
  
  #predictions is the probability distribution of each word in the vocabulary for each word in input sentence
  predictions = bert_model(input_ids)
  predictions = torch.squeeze(predictions)
  predictions = F.softmax(predictions, dim=1)

  #woe is the weight of evidence
  woe = []
  model.eval()

  with torch.no_grad():
    for j in range (len(predictions)):
      word_scores = predictions[j]
      input_batch = input_ids.clone().to(device)
      
      #word_scores_batch calculates the value of the MLM of Bert for each masked word
      #we put 0 for the first input which is unmasked
      word_scores_batch = [0]

      for k in range(len(word_scores)):
        if word_scores[k] > sigma:
           input_batch = torch.cat((input_batch, input_ids), 0)
           input_batch[len(input_batch)-1][j] = k
           word_scores_batch.append(word_scores[k].item())
      
      #probability_input calculates the p(label|sentence) of the target model given each masked input sentence
      probability_input = compute_probability2(model, input_batch, attention_masks, label)
      
      m = torch.dot(torch.tensor(word_scores_batch).to(device), probability_input)
      logodds_input = math.log(probability_input[0] / (1-probability_input[0]))
      logodds_m = math.log(m / (1-m))
      woe.append(logodds_input-logodds_m)
  return woe


In [23]:
def input_marg(model): 
  test_data = loaddata()
  device = "cuda" if next(model.parameters()).is_cuda else "cpu"
  iter_data = iter(test_data)

  for i in range(SAMPLE_SIZE):
    nextsample = next(iter_data)
    inputsequences = nextsample[0].to(device)
    inputmask =  nextsample[1].to(device)
    labels = nextsample[2].to(device)
    print("")
    print(tokenizer.convert_ids_to_tokens(inputsequences[0][1:11]))
    print(calculate_woe(model, torch.unsqueeze(inputsequences[0][1:11],0),torch.unsqueeze(inputmask[0][:11],0),  torch.unsqueeze(labels[0],0), SIGMA))
      

In [28]:
input_marg(cnn)

6,919 training samples.
  876 validation samples.
1,822 test samples.

['rates', 'an', '`', 'e', "'", 'for', 'effort', '-', '-', 'and']
[-0.023614781172927146, -0.5736915400651671, 0.01798837133650444, -0.35940286665283944, 0.0036026506841958517, 0.018793006829055603, 0.10160360009127833, 0.04824759663439149, 0.01074686440903283, 0.0003444468218778418]

['an', 'ideal', '##istic', 'love', 'story', 'that', 'brings', 'out', 'the', 'late']
[3.9840200238023047, 2.14271560182448, 2.260443213282069, 2.3856000592162365, 0.9963197812029252, 1.4108959269479553, 2.8481678784848943, 0.47058731449476277, 0.09098499923152481, 4.090750703705227]

['a', 'bold', 'and', 'sub', '##vers', '##ive', 'film', 'that', 'cuts', 'across']
[3.338206176788492, 3.094990815615888, 0.7888140804037334, 1.6613605814508112, 3.2641718414506165, 2.4920595978837055, 2.664964501682608, 1.7363559568443465, 0.6809756063880457, 1.6675010758344166]
