In [None]:
# dependencies installments
!pip install transformers
!pip install captum

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# load data
import os
import json

data_path = '/content/drive/MyDrive/news-split-data-processed'

# with open('{}/train_text.json'.format(data_path), 'r') as f:
#     train_text = json.load(f)

with open('{}/test_text.json'.format(data_path), 'r') as f:
    test_data = json.load(f)

# with open('{}/train_label.json'.format(data_path), 'r') as f:
#     train_label = json.load(f)

with open('{}/test_label.json'.format(data_path), 'r') as f:
    test_label = json.load(f)

In [None]:
# define dataset
import torch
from torch.utils.data import Dataset

class NewsDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


In [None]:
# load model
import numpy as np
from torch.utils.data import DataLoader
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertConfig, DistilBertTokenizerFast, DistilBertForSequenceClassification
from tqdm import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model_path = '/content/drive/MyDrive/news_peace_models/distilbert-uncased-train-on-processed-data'
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.to(device)
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [None]:
# evaluate model

test_encodings = tokenizer(test_data, truncation=True, padding=True)
test_dataset = NewsDataset(test_encodings, test_label)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)


pred = []
labels = []
with torch.no_grad():
  for batch in tqdm(test_loader):
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      labels += list(batch['labels'])
      outputs = model(input_ids, attention_mask=attention_mask)
      pred += list(outputs.logits.argmax(axis=1).cpu().numpy())


100%|██████████| 652/652 [12:14<00:00,  1.13s/it]


In [None]:
# evaluate by computing confusion matrix
from sklearn.metrics import confusion_matrix, precision_score, recall_score

conf_mat = confusion_matrix(labels, pred)
precision = precision_score(labels, pred)
recall = recall_score(labels, pred)

print('confusion matrix:')
print(conf_mat)
print('recall: {}, precision: {}'.format(recall, precision))

confusion matrix:
[[19886   368]
 [  291 21166]]
recall: 0.986437992263597, precision: 0.9829107457973437


In [None]:
# def function to get attention sum score
import torch
from transformers import DistilBertConfig, DistilBertTokenizerFast, DistilBertForSequenceClassification


config = DistilBertConfig.from_pretrained(model_path, output_attentions=True)
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model_attention = DistilBertForSequenceClassification.from_pretrained(model_path, config=config)

model_attention.to(device)
# get attention score of text
# return an array of attention score
def get_attention_score(text):
    encodings = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings['attention_mask'].to(device)
    with torch.no_grad():
      out = model_attention(input_ids=input_ids, attention_mask=attention_mask)

    att_sum = out.attentions[0]\
                 .cpu()\
                 .reshape(12, len(input_ids[0]), len(input_ids[0]))\
                 .sum(axis=[0, 1]).numpy()
    # ignore [CLS] and [SEP] in bert tokens
    return att_sum[1: -1] / max(att_sum[1: -1]), input_ids.cpu()[0][1: -1]


In [None]:
# define function to visualize attention score of input text
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML

def colorize(words, color_array):
    cmap=matplotlib.cm.Blues
    template = '<span class="barcode"; style="color: black; background-color: {}">{}</span>'
    colored_string = ''
    length = 0
    line_len = 50
    for word, color in zip(words, color_array):
        if (length + len(word)) // line_len  - length // line_len == 1:
            word += '\n'
        length += len(word)
        color *= 0.8
        color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        colored_string += template.format(color, '&nbsp' + word + '&nbsp')

    #display(HTML(colored_string))
    return colored_string

In [None]:
# an example to get attention score of input
idx = 2300
text = test_data[idx]
label = test_label[idx]
print('label of text: {}'.format(label))
att_score, input_ids = get_attention_score(text)
tokens = tokenizer.convert_ids_to_tokens(input_ids)
display(HTML(colorize(tokens, att_score)))


label of text: 0


In [None]:
# get global attribution score sum for each word
from tqdm import tqdm
VOCAB_SIZE = tokenizer.vocab_size

def get_global_attr_score_sum(list_of_texts):
    attr_score_sum_vec = np.zeros(VOCAB_SIZE)
    word_freq_vec = np.zeros(VOCAB_SIZE)

    for text in tqdm(list_of_texts):
        attr_score, input_ids = get_attention_score(text)
        word_freq_vec[input_ids.numpy()] += 1
        attr_score_sum_vec[input_ids.numpy()] += attr_score

    return attr_score_sum_vec, word_freq_vec


def get_top_words(word_weight_vec, n):
    top_n_indices = np.argsort(word_weight_vec)[::-1][:n]
    top_weight_vec = word_weight_vec[top_n_indices]
    return tokenizer.convert_ids_to_tokens(top_n_indices), top_weight_vec


In [None]:
pos_data = []
neg_data = []

for i, label in enumerate(test_label):
    if label == 0:
        neg_data.append(test_data[i])
    else:
        pos_data.append(test_data[i])

sample_pos_data = pos_data[0: 1000]
sample_neg_data = neg_data[0: 1000]

In [None]:
# try finding global word importance using sampled data

## for peaceful news
pos_attr_score_sum_vec, pos_word_freq_vec = get_global_attr_score_sum(pos_data)
# save
save_dir = '/content/drive/MyDrive/news-split-data-processed'
np.save('{}/pos_attr_score'.format(save_dir), pos_attr_score_sum_vec)
np.save('{}/pos_word_freq'.format(save_dir), pos_word_freq_vec)


100%|██████████| 21457/21457 [05:18<00:00, 67.31it/s]


In [34]:
pos_attr_score_avg_vec = pos_attr_score_sum_vec / (pos_word_freq_vec + 200)
pos_top_words, pos_top_weights = get_top_words(pos_attr_score_avg_vec, 30)
pos_top_words, pos_top_weights

(['##virus',
  'privacy',
  '##ø',
  'vaccine',
  'email',
  'according',
  'infections',
  'says',
  'sabotage',
  'advertisement',
  'please',
  'billion',
  'taxes',
  'countries',
  'prices',
  'conservatives',
  'ukraine',
  'isn',
  'thursday',
  'weekend',
  'sanctions',
  'nations',
  'amid',
  'installations',
  'wednesday',
  '##inen',
  'symptoms',
  'manchester',
  'customers',
  'officials'],
 array([0.81675715, 0.80593075, 0.75064789, 0.74524113, 0.7212161 ,
        0.71870005, 0.71578149, 0.69966735, 0.6934266 , 0.69015828,
        0.68613266, 0.67994335, 0.66877981, 0.66474767, 0.66274381,
        0.65596524, 0.65109709, 0.64679513, 0.64563156, 0.64245407,
        0.63688827, 0.6357619 , 0.63079662, 0.62982324, 0.62826459,
        0.62427367, 0.62367907, 0.62113535, 0.62073201, 0.61841205]))

In [17]:
## for unpeaceful news
neg_attr_score_sum_vec, neg_word_freq_vec = get_global_attr_score_sum(neg_data)
# save
np.save('{}/neg_attr_score'.format(save_dir), neg_attr_score_sum_vec)
np.save('{}/neg_word_freq'.format(save_dir), neg_word_freq_vec)


100%|██████████| 20254/20254 [05:26<00:00, 62.04it/s]


In [33]:
neg_attr_score_avg_vec = neg_attr_score_sum_vec / (neg_word_freq_vec + 200)
neg_top_words, neg_top_weights = get_top_words(neg_attr_score_avg_vec, 30)
neg_top_words, neg_top_weights

(['##virus',
  'sanctions',
  'billion',
  'according',
  'says',
  'countries',
  'raj',
  'nations',
  'presidential',
  '##pur',
  'tourism',
  'calendar',
  'spokesperson',
  'arrested',
  'vaccine',
  'elections',
  'unesco',
  '##while',
  'shall',
  'commissioner',
  'stakeholders',
  'congress',
  'abu',
  '##icia',
  'ensure',
  'agencies',
  'bengal',
  'officials',
  'minister',
  'ministry'],
 array([0.8017075 , 0.75088854, 0.72878839, 0.70891597, 0.70287592,
        0.6715634 , 0.66953148, 0.6631209 , 0.6624641 , 0.65966674,
        0.65788452, 0.65118907, 0.64760272, 0.64590822, 0.64407004,
        0.63879274, 0.63672333, 0.63469735, 0.63376146, 0.63234317,
        0.63230968, 0.63206509, 0.63125794, 0.63016866, 0.62740668,
        0.62405477, 0.61691774, 0.6136403 , 0.6135506 , 0.61266212]))

In [35]:
set(pos_top_words) - set(neg_top_words)

{'##inen',
 '##ø',
 'advertisement',
 'amid',
 'conservatives',
 'customers',
 'email',
 'infections',
 'installations',
 'isn',
 'manchester',
 'please',
 'prices',
 'privacy',
 'sabotage',
 'symptoms',
 'taxes',
 'thursday',
 'ukraine',
 'wednesday',
 'weekend'}

In [36]:
set(neg_top_words) - set(pos_top_words)

{'##icia',
 '##pur',
 '##while',
 'abu',
 'agencies',
 'arrested',
 'bengal',
 'calendar',
 'commissioner',
 'congress',
 'elections',
 'ensure',
 'minister',
 'ministry',
 'presidential',
 'raj',
 'shall',
 'spokesperson',
 'stakeholders',
 'tourism',
 'unesco'}