In [1]:
!pip install transformers
!pip install captum

Installing collected packages: captum
Successfully installed captum-0.5.0


In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from transformers import DistilBertForSequenceClassification, DistilBertConfig, DistilBertTokenizerFast

from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients

In [3]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
# import model
model_path = '/content/drive/MyDrive/news_peace_models/distilbert-uncased-train-on-processed-data'
model = DistilBertForSequenceClassification.from_pretrained(model_path)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
model.zero_grad()

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [5]:
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [6]:
# load data
import os
import json

data_path = '/content/drive/MyDrive/news-split-data-processed'

with open('{}/test_text.json'.format(data_path), 'r') as f:
    test_news = json.load(f)

with open('{}/test_label.json'.format(data_path), 'r') as f:
    test_labels = json.load(f)

In [7]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer(text, truncation=True, padding=True)['input_ids']
    # construct input token ids
    input_ids = text_ids
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * (len(text_ids) - 2) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def custom_forward(inputs):
    preds = model(inputs)[0]
    return torch.softmax(preds, dim = 1)[0][1].unsqueeze(-1)

def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def get_attributions(text):
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
    token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
    attention_mask = construct_attention_mask(input_ids)
    
    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
    attention_mask.detach()
    token_type_ids.detach() 
    ref_token_type_ids.detach()
    position_ids.detach()
    ref_position_ids.detach()
    ref_input_ids.detach()
    return attributions.cpu(), all_tokens, input_ids.cpu(), delta

lig = LayerIntegratedGradients(custom_forward, model.distilbert.embeddings)

In [None]:
# visualize example 1

idx = 0
text = test_news[idx]
label = test_labels[idx]
attributions,  all_tokens, input_ids, delta = get_attributions(text)
score = model(input_ids.to(device))[0]
prob = torch.softmax(score, dim = 1)[0][1]
attributions_sum = summarize_attributions(attributions)
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        prob,
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        label,
                        'peaceful',
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),peaceful,-6.85,[CLS] independent robert ss ##ent ##am ##u national unity platform presidential candidate says talks united states government human rights abuses military revealed addressing supporters campaign rally saturday didn ’ t reveal level engagement said citizens continue face brutality hands military remarks come two days secretary state michael po ##mp ##eo issued statement indicating closely paying attention actions individuals seeking imp ##ede democratic processes ahead general elections earlier eliot l eng ##el house committee foreign affairs wrote po ##mp ##eo secretary department treasury steven mn ##uchi ##n calling sanctions key military police officers particularly singled commander land forces lt gen peter el ##we ##lu commander special forces command sf ##c maj gen james bi ##run ##gi former sf ##c commander incoming commander troops somali maj gen william na ##bas ##a maj gen abel kan ##di ##ho chief military intelligence others deputy inspector general police maj gen steven sa ##bi ##iti mu ##zee ##yi commissioner police frank mw ##es ##ig ##wa police director crime intelligence col chris ser ##un ##jo ##gi dd ##am ##uli ##ra people killed last month security forces que ##lling protests resulting arrest according opposition presidential candidates also facing similar brutal ##izing interference campaigns security forces cites encounter li ##ra district friday police held team several hours one spot towing pad ##er district government hasn ’ t come formally respond po ##mp ##eo statements however national thanksgiving prayers president said country election won ’ t disrupted foreign influence ##rs elections peaceful nobody going disrupt heard people playing games foreigners backing well wish luck assure shall tolerate violence shall peaceful elections said previously rallied government suspend donation military seeking treatment following alleged torture military met top congressman bradley sherman highlighted described human rights violations repression freedoms government [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),peaceful,-6.85,[CLS] independent robert ss ##ent ##am ##u national unity platform presidential candidate says talks united states government human rights abuses military revealed addressing supporters campaign rally saturday didn ’ t reveal level engagement said citizens continue face brutality hands military remarks come two days secretary state michael po ##mp ##eo issued statement indicating closely paying attention actions individuals seeking imp ##ede democratic processes ahead general elections earlier eliot l eng ##el house committee foreign affairs wrote po ##mp ##eo secretary department treasury steven mn ##uchi ##n calling sanctions key military police officers particularly singled commander land forces lt gen peter el ##we ##lu commander special forces command sf ##c maj gen james bi ##run ##gi former sf ##c commander incoming commander troops somali maj gen william na ##bas ##a maj gen abel kan ##di ##ho chief military intelligence others deputy inspector general police maj gen steven sa ##bi ##iti mu ##zee ##yi commissioner police frank mw ##es ##ig ##wa police director crime intelligence col chris ser ##un ##jo ##gi dd ##am ##uli ##ra people killed last month security forces que ##lling protests resulting arrest according opposition presidential candidates also facing similar brutal ##izing interference campaigns security forces cites encounter li ##ra district friday police held team several hours one spot towing pad ##er district government hasn ’ t come formally respond po ##mp ##eo statements however national thanksgiving prayers president said country election won ’ t disrupted foreign influence ##rs elections peaceful nobody going disrupt heard people playing games foreigners backing well wish luck assure shall tolerate violence shall peaceful elections said previously rallied government suspend donation military seeking treatment following alleged torture military met top congressman bradley sherman highlighted described human rights violations repression freedoms government [SEP]
,,,,


In [8]:
# get global attribution score sum for each word
from tqdm import tqdm
VOCAB_SIZE = tokenizer.vocab_size

def get_global_attr_score_sum(list_of_texts):
    attr_score_sum_vec = np.zeros(VOCAB_SIZE)
    word_freq_vec = np.zeros(VOCAB_SIZE)

    for text in tqdm(list_of_texts):
        attr_score, _, input_ids, _ = get_attributions(text)
        attr_score = summarize_attributions(attr_score)
        word_freq_vec[input_ids.numpy()] += 1
        attr_score_sum_vec[input_ids.numpy()] += attr_score.numpy()

    return attr_score_sum_vec, word_freq_vec


def get_top_pos_words(word_weight_vec, n):
    top_n_indices = np.argsort(word_weight_vec)[::-1][:n]
    top_weight_vec = word_weight_vec[top_n_indices]
    return tokenizer.convert_ids_to_tokens(top_n_indices), top_weight_vec


def get_top_neg_words(word_weight_vec, n):
    top_n_indices = np.argsort(word_weight_vec)[:n]
    top_weight_vec = word_weight_vec[top_n_indices]
    return tokenizer.convert_ids_to_tokens(top_n_indices), top_weight_vec

In [None]:
# try finding global word importance using sampled data
attr_score_sum_vec, word_freq_vec = get_global_attr_score_sum(test_news[0:10000])
attr_score_avg_vec = attr_score_sum_vec / (word_freq_vec + 50)

# save attention_score_sum vec and word_freq vec
save_dir = '/content/drive/MyDrive/news-split-data-processed'
np.save('{}/integrated_grad_attr_score'.format(save_dir), attr_score_sum_vec)
np.save('{}/word_freq'.format(save_dir), word_freq_vec)

100%|██████████| 10000/10000 [2:55:35<00:00,  1.05s/it]


In [11]:
# load from file
import numpy as np

save_dir = '/content/drive/MyDrive/news-split-data-processed'
attr_score_sum_vec = np.load('{}/integrated_grad_attr_score.npy'.format(save_dir))
word_freq_vec = np.load('{}/word_freq.npy'.format(save_dir))

In [53]:
attr_score_avg_vec = attr_score_sum_vec / (word_freq_vec + 100)
pos_top_words, pos_top_weights = get_top_pos_words(attr_score_avg_vec, 100)

In [54]:
attr_score_avg_vec = attr_score_sum_vec / (word_freq_vec + 150)
neg_top_words, neg_top_weights = get_top_neg_words(attr_score_avg_vec, 100)

In [58]:
pos_top_words = [word for word in pos_top_words if '#' not in word and len(word) > 2][0: 20]
neg_top_words  = [word for word in neg_top_words if '#' not in word and len(word) > 2][0: 20]

In [56]:
pos_top_words

['europe',
 'colleagues',
 'conservative',
 'manchester',
 'martin',
 'nord',
 'kerry',
 'european',
 'infections',
 'winter',
 'extra',
 'conservatives',
 'stab',
 'russian',
 'investigation',
 'russia',
 'david',
 'ukraine',
 'tommy',
 'euros',
 'bergen',
 'review',
 'partner',
 'recent',
 'longer',
 'much',
 'berlin',
 'poland',
 'confirmed',
 'crime']

In [57]:
neg_top_words

['independent',
 'officials',
 'level',
 'state',
 'says',
 'started',
 'day',
 'district',
 'administration',
 'name',
 'pune',
 'various',
 'activities',
 'title',
 'shah',
 'days',
 'meet',
 'village',
 'raj',
 'road',
 'governor',
 'man',
 'board',
 'sub',
 'country',
 'capital',
 'sanctions',
 'government',
 'stated',
 'sides']