In [1]:
'''2021 Spring NLP Final Project'''
'''Team Zulu'''
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim

import dataloader
from bilstm import BiLSTMs_Classifier

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-n475zl8j because the default path (/home/wlchen/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# load trained model
def load_checkpoint(load_path, model, model_type, device):  
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict["model_state_dict"])

    return model

In [3]:
def _encode_sentence(TEXT, sentence, device):
    """Tokenize ``sentence`` with the dataset field and map tokens to vocab ids.

    Returns ``(tokens, ids)``: the token list produced by ``TEXT.preprocess``
    and a LongTensor of shape (1, len(tokens)) placed on ``device``.

    Raises:
        ValueError: if preprocessing yields no tokens (nothing to explain).
    """
    tokens = TEXT.preprocess(sentence)
    if len(tokens) == 0:
        raise ValueError('test sentence is empty after preprocessing')
    ids = torch.tensor([[TEXT.vocab.stoi[tok] for tok in tokens]],
                       dtype=torch.long, device=device)
    return tokens, ids


def _compute_saliency(model, token_ids):
    """Return ``(pred_label, word_saliency)`` for one encoded sentence.

    The saliency of a word is the gradient of the predicted class score
    w.r.t. that word's embedding, reduced to one number per word by taking
    the max absolute gradient over dim 2 (presumably the embedding dim).
    """
    # NOTE(review): cuDNN's RNN backward only runs in training mode, which is
    # presumably why train() is used here. It also enables any dropout in the
    # model, so scores may vary between calls -- TODO confirm for this model.
    model.train()

    # Autograd only retains .grad on leaf tensors, so turn the embedding
    # output into a fresh leaf that requires grad.
    embedded = model.make_embedding(token_ids).detach().requires_grad_(True)

    preds = model.classifier(embedded)
    _, indices = torch.max(preds, 1)
    pred_label = int(indices[0])

    # Back-propagate the score of the predicted class only.
    preds[0, pred_label].backward()

    word_saliency, _ = torch.max(embedded.grad.abs(), dim=2)
    return pred_label, word_saliency


def _display_saliency(labels, word_saliency, sharpness=50):
    """Render per-word saliency as a colour-graded two-row table.

    Row 0 holds ``labels``; row 1 holds the saliency values. Those values are
    pushed through a softmax scaled by ``sharpness`` (to exaggerate contrast)
    and are expressed as percentages.
    """
    pd.set_option('display.max_rows', None)
    pd.set_option('display.max_columns', None)

    # softmax over dim 0 assumes the sequence axis comes first in
    # word_saliency (e.g. (seq_len, 1)) -- TODO confirm against make_embedding.
    # reshape(-1) rather than squeeze() keeps a 1-token sentence 1-D.
    scores = F.softmax(word_saliency * sharpness, dim=0).reshape(-1).cpu().numpy() * 100

    table = pd.DataFrame([list(labels), list(scores)])
    # Hide the last column, i.e. the final token (the `<end>` marker in the
    # examples used with this notebook).
    table = table.drop(columns=[len(labels) - 1])

    cmap = sns.diverging_palette(260, 10, n=9, as_cmap=True)
    # Explicit 2-D subset (row 1, all columns): a bare scalar subset is
    # resolved as a *column* label by newer pandas versions.
    styled = table.style.background_gradient(subset=pd.IndexSlice[1, :], cmap=cmap, axis=1)
    display(styled)


def saliency(test_sen, gt, device=device):
    """Predict a sentence's label with the trained BiLSTM and show word saliency.

    Loads the preprocessed dataset (for the text field, vocab and GloVe
    embeddings) and the checkpoint at ``model_checkpoints/bilstm.pt``. It then
    classifies ``test_sen``, prints the prediction next to ``gt`` and
    displays a per-word saliency table.

    Args:
        test_sen: raw sentence to explain.
        gt: ground-truth label, only printed for comparison.
        device: torch device used for the model and the input tensor.
    """
    # hyper-parameters (presumably those the checkpoint was trained with;
    # load_state_dict fails on shape mismatches)
    batch_size = 16
    output_size = 2
    hidden_size = 256
    embedding_size = 300
    glove = '840B'

    model_pth = 'model_checkpoints/bilstm.pt'
    data_pth = 'preprocessed_tweet_data/'

    # Only the text field, vocab size and embeddings are needed here.
    TEXT, vocab_size, word_embeddings, _, _, _ = dataloader.load_dataset(data_pth, device, glove)

    model = BiLSTMs_Classifier(batch_size, output_size, hidden_size, vocab_size,
                               embedding_size, word_embeddings, train_embedding=False)
    model = load_checkpoint(model_pth, model, 'bilstms', device)
    model.to(device)

    tokens, token_ids = _encode_sentence(TEXT, test_sen, device)
    pred_label, word_saliency = _compute_saliency(model, token_ids)

    print('predict:      ' + ('fake' if pred_label == 1 else 'real'))
    print('ground truth: {}'.format(gt))

    # Prefer the raw words as column labels (original casing) when they line
    # up one-to-one with the model's tokens. Otherwise use the model's tokens,
    # so every score sits under the token it was computed for.
    raw_tokens = test_sen.split(' ')
    labels = raw_tokens if len(raw_tokens) == len(tokens) else tokens
    _display_saliency(labels, word_saliency)

In [4]:
# Example 1: a tweet whose ground-truth label is "fake" (ends with the <end> token).
test_data = dict(
    test_sen="do not believe that $URL$ @criscarter80 I aint No Fool unamused face unamused face fire fire fire #LakeShow #LAbron #KobeDaGOAT $URL$ <end>",
    gt="fake",
)

In [5]:
saliency(**test_data)  # keys 'test_sen' and 'gt' match saliency's parameter names

Length of Text Vocabulary: 139881
Vector size of Text Vocabulary:  torch.Size([139881, 300])
predict:      fake
ground truth: fake


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20
0,do,not,believe,that,$URL$,@criscarter80,I,aint,No,Fool,unamused,face,unamused,face,fire,fire,fire,#LakeShow,#LAbron,#KobeDaGOAT,$URL$
1,3.545461,3.793398,4.919140,6.593657,5.274556,8.139013,7.271398,7.108112,7.343378,5.915275,5.356946,5.455403,4.899125,4.963765,4.301469,3.489053,2.825715,2.514616,2.215898,1.946804,1.353247


In [6]:
# Example 2: a post whose ground-truth label is "real" (ends with the <end> token).
test_data = dict(
    test_sen="Today is hard . My chest feels heavy . My anxiety is through the roof . Verge of tears . I want this to end already . I am not looking for advice . I just need a place to say it . <end>",
    gt="real",
)

In [7]:
saliency(**test_data)  # keys 'test_sen' and 'gt' match saliency's parameter names

Length of Text Vocabulary: 139881
Vector size of Text Vocabulary:  torch.Size([139881, 300])
predict:      real
ground truth: real


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42
0,Today,is,hard,.,My,chest,feels,heavy,.,My,anxiety,is,through,the,roof,.,Verge,of,tears,.,I,want,this,to,end,already,.,I,am,not,looking,for,advice,.,I,just,need,a,place,to,say,it,.
1,5.082132,6.133394,4.790653,5.349478,4.235262,5.054872,7.047950,7.117216,5.356193,3.755646,3.638745,4.404605,4.011426,4.469451,4.105107,5.234344,4.642946,3.063004,2.047132,1.932239,1.287312,1.101402,0.917207,0.720460,0.558211,0.467331,0.404748,0.317489,0.275760,0.255156,0.227866,0.201181,0.182639,0.176057,0.160284,0.153682,0.145274,0.137963,0.132209,0.127162,0.123152,0.121399,0.125250
