__Objective__: To train a PyTorch BERT model on the _Change My View_ dataset and use it to generate attention heatmaps for ad hominem tweets

__Runtime__: GPU

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset/model paths under /content/gdrive used below are readable;
# force_remount=True remounts even if a previous session already mounted it.
drive.mount('/content/gdrive', force_remount=True)

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm 
import pickle as pkl
import matplotlib.pyplot as plt
from matplotlib import colors

# Training PyTorch BERT

In [None]:
import os

def read_split(dir):
    '''Load one labelled split of the Change My View data.

    The file is a CSV whose first row is a header (skipped). In every other row the
    first column is the label and the second the comment text. Rows are parsed with
    the ``csv`` module, so quoted texts that contain commas, quotes or newlines are
    read correctly. Extra, unquoted commas in the text column are re-joined, so they
    are tolerated too.

    :param dir: path of the CSV file (name kept for backward compatibility)
    :return: (texts, labels), where labels[i] is 1 if row i is labelled 'AH', else 0
    :raises ValueError: if a non-empty row has fewer than two columns
    '''
    import csv

    texts = []
    labels = []
    with open(dir, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # header row
        for row in reader:
            if not row:  # blank line
                continue
            if len(row) < 2:
                raise ValueError('expected "label,text" row, got %r' % (row,))
            label = row[0].strip()
            # Only reachable for unquoted commas: keep them as part of the text.
            text = ','.join(row[1:]).strip()
            texts.append(text)
            labels.append(1 if label == 'AH' else 0)
    return texts, labels


# Load the labelled train/test splits from Drive (label 'AH' -> 1, any other label -> 0)
train_texts, train_labels = read_split('/content/gdrive/MyDrive/DL/dataset/pytorch/train.csv')
test_texts, test_labels = read_split('/content/gdrive/MyDrive/DL/dataset/pytorch/test.csv')

In [None]:
!pip install transformers
from transformers import BertTokenizer, BertForSequenceClassification

In [None]:
# Tokenizer matching the pretrained 'bert-base-uncased' checkpoint fine-tuned below
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Every example is truncated or padded to exactly max_seq_length tokens,
# so all encodings share one length and can be batched directly.
max_seq_length = 64
train_encodings = tokenizer(train_texts, truncation=True, max_length=max_seq_length, padding="max_length")
test_encodings = tokenizer(test_texts, truncation=True, max_length=max_seq_length, padding="max_length")

In [None]:
import torch

class CustomDataset(torch.utils.data.Dataset):
    '''Map-style dataset pairing tokenizer output with integer class labels.

    ``encodings`` maps feature names (e.g. ``input_ids``) to per-example sequences;
    ``labels`` holds one label per example. Indexing returns a dict with one tensor
    per feature plus a ``labels`` tensor.
    '''

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        sample = dict()
        for feature, column in self.encodings.items():
            sample[feature] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

    def __len__(self):
        return len(self.labels)

# Wrap the encoded splits so a DataLoader can index them
# (test_dataset is not used by any later cell shown here).
train_dataset = CustomDataset(train_encodings, train_labels)
test_dataset = CustomDataset(test_encodings, test_labels)

In [None]:
from torch.utils.data import DataLoader
from transformers import AdamW

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.to(device)
model.train()

In [None]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
optim = AdamW(model.parameters(), lr=5e-5)

In [None]:
from tqdm import tqdm
# Fine-tune for 3 epochs over the shuffled training batches (model.train() was set when the model was created).
for epoch in range(3):
    for batch in tqdm(train_loader):
        optim.zero_grad()  # clear gradients accumulated by the previous step
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]  # with `labels` passed, the first output is the classification loss
        loss.backward()
        optim.step()

In [None]:
# Saving best-practices: with the default file names, the model can be reloaded with from_pretrained()

output_dir = './utkbert/'

# first run: create the target directory
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

print("Saving model to %s" % output_dir)

# save_pretrained() writes the model weights/config and the tokenizer files so that
# from_pretrained(output_dir) can restore both.
# A DataParallel/DistributedDataParallel wrapper keeps the real model in `.module`.
model_to_save = getattr(model, 'module', model)
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Visualization set-up

In [None]:
from transformers import BertModel, BertTokenizer
import re

In [None]:
# Reload the fine-tuned checkpoint saved above as a plain BertModel. output_attentions=True makes
# the forward pass also return the per-layer attention maps read by attention_scores().
# NOTE(review): 'utkbert' is a relative path -- assumes the working directory is unchanged since saving.
model_version = 'utkbert'
do_lower_case = True
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

In [None]:
INTENSITY = 70

def attention_scores(text, layers=None, heads=None):
    '''Average the attention paid by the first token ([CLS]) to every token of *text*.

    The text is tokenized with the global ``tokenizer`` and run through the global
    ``model``. For every token j, including [CLS] and [SEP], the attention weight
    at row 0 / column j is averaged over every selected (layer, head) pair.

    :param text: input string
    :param layers: iterable of layer indices (default: every layer returned by the model)
    :param heads: iterable of head indices (default: every head of a layer)
    :return: (tokens, scores), two lists of equal length
    :raises ValueError: if ``layers`` or ``heads`` is empty
    '''
    inputs = tokenizer.encode_plus(
        text, None, return_tensors='pt', add_special_tokens=True,
        # Over-long inputs used to raise inside the model and were silently dropped by
        # the callers' try/except; truncate them to the position table instead.
        truncation=True, max_length=model.config.max_position_embeddings,
    )
    input_ids = inputs['input_ids']
    # Inference only: no_grad avoids building and holding an autograd graph per call.
    with torch.no_grad():
        attention = model(input_ids)[-1]  # one tensor per layer, indexed [batch, head, row, col]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())  # batch index 0
    layers = list(range(len(attention)) if layers is None else layers)
    heads = list(range(attention[0].shape[1]) if heads is None else heads)
    if not layers or not heads:
        raise ValueError('layers and heads must be non-empty')
    # Row 0 of each selected head -> (n_layers, n_heads, seq_len).
    cls_rows = torch.stack([attention[layer][0, heads, 0, :] for layer in layers])
    # Sum in float64, matching the previous per-element Python-float accumulation.
    matrix = (cls_rows.double().sum(dim=(0, 1)) / (len(layers) * len(heads))).tolist()
    return (tokens, matrix)

In [None]:
def clean_array(w, a, cap=None):
    '''Re-join WordPiece continuations ('##xx') into whole words.

    Each merged word gets the mean attention of its pieces. The previous running
    pairwise average over-weighted later pieces for words of three or more pieces.
    The result is then passed through ``clean_apos``.

    :param w: list of word-piece tokens (without [CLS]/[SEP])
    :param a: list of attention scores, one per token
    :param cap: upper bound forwarded to ``clean_apos`` (default: INTENSITY)
    :return: (words, scores)
    '''
    words, sums, counts = [], [], []
    for token, score in zip(w, a):
        if token.startswith('##') and words:
            words[-1] += token[2:]
            sums[-1] += score
            counts[-1] += 1
        else:
            # A leading '##' piece has nothing to attach to; keep it as its own word.
            words.append(token)
            sums.append(score)
            counts.append(1)
    means = [total / n for total, n in zip(sums, counts)]
    return clean_apos(words, means, cap=cap)

def clean_apos(w, a, cap=None):
    '''Glue a standalone apostrophe token to its neighbours ('don', "'", 't' -> "don't").

    The merged word's score is the sum of the three scores, capped at ``cap``.
    An apostrophe at the very start or end has nothing to merge with. Such an
    apostrophe is kept as its own word; previously it raised IndexError.

    :param w: list of words
    :param a: list of scores, one per word
    :param cap: maximum merged score (default: INTENSITY)
    :return: (words, scores)
    '''
    if cap is None:
        cap = INTENSITY
    words, scores = [], []
    i, n = 0, len(w)
    while i < n:
        if w[i] == '\'' and words and i + 1 < n:
            words[-1] += w[i] + w[i + 1]
            scores[-1] = min(cap, scores[-1] + a[i] + a[i + 1])
            i += 2
        else:
            words.append(w[i])
            scores.append(a[i])
            i += 1
    return words, scores

In [None]:
def top_three_tokens(text):
    '''Highlight the three strongest, well-separated attention peaks of *text*.

    Words are ranked by their averaged [CLS] attention (ties: later position first).
    A word is accepted as a peak when it is more than 2 positions away from every
    peak accepted so far; selection stops at three peaks. Each peak and its
    immediate neighbours get score INTENSITY; every other word gets 0.

    :param text: sanitized input text
    :return: (words, scores), two lists of equal length
    :raises ValueError: if fewer than three separated peaks exist (e.g. very short text);
        callers rely on an exception here to skip such texts
    '''
    words, attentions = attention_scores(text)
    words = words[1:-1]  # drop [CLS] and [SEP]
    attentions = attentions[1:-1]
    assert len(words) == len(attentions)
    words, attentions = clean_array(words, attentions)
    assert len(words) == len(attentions)

    ranked = sorted(((score, pos) for pos, score in enumerate(attentions)), reverse=True)
    peaks = []
    for _, pos in ranked:
        if all(abs(pos - peak) > 2 for peak in peaks):
            peaks.append(pos)
            if len(peaks) == 3:
                break
    else:
        raise ValueError('text has fewer than three well-separated attention peaks')

    scores = [0] * len(words)
    for peak in peaks:
        for j in range(max(peak - 1, 0), min(peak + 2, len(words))):
            scores[j] = INTENSITY
    return words, scores

In [None]:
# LaTeX replacements for characters that are special in text mode. Mapping every character
# in a single str.translate pass means inserted backslashes and braces are never re-escaped.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '%': r'\%',
    '&': r'\&',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '$': r'\$',
    '^': r'\textasciicircum{}',
    '~': r'\textasciitilde{}',
})


def clean_word(word_list):
    '''Escape LaTeX special characters in every word of *word_list*.

    Previously a backslash was turned into a double backslash (a LaTeX line break).
    A caret was turned into the circumflex-accent command, and dollar signs and
    tildes were left unescaped.

    :param word_list: iterable of strings
    :return: new list of escaped strings, same order
    '''
    return [word.translate(_LATEX_ESCAPES) for word in word_list]

In [None]:
header = r'''\documentclass[10pt,a4paper]{article}
\usepackage[left=1.00cm, right=1.00cm, top=1.00cm, bottom=2.00cm]{geometry}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n\n'

footer = r'''\end{CJK*}
\end{document}'''

def heatmap(word_list, attention_list, label_list, latex_file, title, batch_size=20, color='blue'):
    '''Routine to generate attention heatmaps for given texts
    ---------------------------------------------------------
    Writes a standalone LaTeX document where every word sits in a colorbox shaded
    ``color!score``. Texts are grouped into subsections of ``batch_size`` comments,
    and the last batch may be shorter. Previously, texts beyond the last full batch
    were silently dropped.

    Input:
    :param word_list: array of texts, each a list of (already LaTeX-safe) words
    :param attention_list: array of attention scores for each text, one per word,
        used directly as the xcolor mixing percentage
    :param label_list: label for each text
    :param latex_file: name of the latex file
    :param title: title of latex file
    :param batch_size: Number of comments in each batch (must be positive)
    :param color: xcolor colour name for the boxes
    :raises ValueError: if batch_size is not positive
    '''
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    with open(latex_file, 'w', encoding='utf-8') as f:
        f.write(header)
        f.write('\\section{%s}\n\n' % title)

        for batch_no, start in enumerate(range(0, len(word_list), batch_size), start=1):
            stop = start + batch_size
            batch_word_list = word_list[start:stop]
            batch_attention_list = attention_list[start:stop]
            batch_label_list = label_list[start:stop]
            f.write('\\subsection{Batch %d}\n\n' % batch_no)
            for j, sentence in enumerate(batch_word_list):
                f.write('\\subsubsection{Comment %d - %s}\n\n' % (j + 1, batch_label_list[j]))
                score = batch_attention_list[j]
                assert len(sentence) == len(score)
                f.write('\\noindent')
                for word, s in zip(sentence, score):
                    f.write('\\colorbox{%s!%s}{' % (color, s) + '\\strut ' + word + '} ')
                f.write('\n\n')

        f.write(footer)

In [None]:
import string

def sanitize(text):
    '''Normalise raw comment text before attention scoring.

    The text is processed in these steps:

    * Lower-case it and map curly single quotes (U+2018/U+2019) to a plain apostrophe.
    * Split it so every ASCII punctuation mark except the apostrophe starts a new token.
    * Drop tokens without any alphanumeric character.
    * Strip all ASCII punctuation from both ends of the remaining tokens. The old
      per-character strip loop could leave e.g. ``''x``.
    * Pass the words through ``clean_word`` and join them with single spaces.

    Non-string input (e.g. a pandas NaN for a missing text) yields ''.

    :param text: raw text
    :return: cleaned, space-separated text
    '''
    if not isinstance(text, str):
        return ''
    text = text.lower().replace('\u2018', '\'').replace('\u2019', '\'')
    # str.split() below already collapses every run of (Unicode) whitespace.
    spaced = ''.join(' ' + ch if ch in string.punctuation and ch != '\'' else ch for ch in text)
    words = [
        token.strip(string.punctuation)
        for token in spaced.split()
        if any(ch.isalnum() for ch in token)
    ]
    return ' '.join(clean_word(words))

In [None]:
# modified implementation of `top_three_tokens` routine
# this returns the word windows around the top tokens instead of (token, score) lists

def top_three_tokens2(text):
    '''Return the word windows around the three strongest, well-separated attention peaks.

    Peak selection matches ``top_three_tokens``. Words are ranked by averaged [CLS]
    attention (ties: later position first). A word is accepted when it is more than
    2 positions away from every accepted peak, until three peaks are found. For each
    peak, the words at positions peak-1, peak and peak+1 (clipped to the text) are
    returned as a tuple.

    :param text: sanitized input text
    :return: list of three word tuples, in peak-rank order
    :raises ValueError: if fewer than three separated peaks exist (e.g. very short text);
        callers rely on an exception here to skip such texts
    '''
    words, attentions = attention_scores(text)
    words = words[1:-1]  # drop [CLS] and [SEP]
    attentions = attentions[1:-1]
    assert len(words) == len(attentions)
    words, attentions = clean_array(words, attentions)
    assert len(words) == len(attentions)

    ranked = sorted(((score, pos) for pos, score in enumerate(attentions)), reverse=True)
    peaks = []
    for _, pos in ranked:
        if all(abs(pos - peak) > 2 for peak in peaks):
            peaks.append(pos)
            if len(peaks) == 3:
                break
    else:
        raise ValueError('text has fewer than three well-separated attention peaks')

    return [tuple(words[max(peak - 1, 0):peak + 2]) for peak in peaks]

# Twitter Data

In [None]:
# loading the saved tweets: one classified CSV per Twitter account
tw_base_addr = '/content/gdrive/MyDrive/DL/Twitter/classified/{}.csv'
tw_pages = ['nytimes', 'npr', 'foxnews', 'breitbart']

df = {page: pd.read_csv(tw_base_addr.format(page)) for page in tw_pages}

In [None]:
# combined df: tweets of all four accounts stacked into one frame
cdf = pd.concat([df['nytimes'], df['npr'], df['foxnews'], df['breitbart']])

In [None]:
len(cdf)

In [None]:
cdf['score'].hist(bins=[x/100 for x in range(101)]) # classification score distribution (100 equal-width bins on [0, 1])

In [None]:
# keep only tweets scored below 0.05 -- the range this notebook treats as ad hominem
cdf = cdf[cdf['score'] < 0.05] # filtering the top ad hominem tweets
len(cdf)

In [None]:
cdf

In [None]:
# preparing trigrams
# these will be used for comparison with Facebook and CreateDebate
# tw_freq maps each word window picked by top_three_tokens2 to the number of tweets it came from.

tw_texts = list(cdf['pptweet'])
tw_freq = dict()
tw_success = 0  # tweets that yielded trigrams

for tw_text in tqdm(tw_texts):
    try:
        trigrams = top_three_tokens2(sanitize(tw_text))
    except Exception:
        # best effort: skip tweets that cannot be scored (e.g. too short for three peaks);
        # a bare `except:` here also swallowed KeyboardInterrupt
        continue
    for tw_tr in trigrams:
        tw_freq[tw_tr] = tw_freq.get(tw_tr, 0) + 1
    tw_success += 1

In [None]:
# saving the computation (useful in case of session crash)

with open('/content/gdrive/MyDrive/Temp/btp_36_tw_freq.pkl', 'wb') as fp:
    pkl.dump(tw_freq, fp)

In [None]:
# reload the saved Twitter trigram counts so later cells can run without recomputing them
# NOTE(review): unpickling can execute arbitrary code -- only load files this notebook wrote
with open('/content/gdrive/MyDrive/Temp/btp_36_tw_freq.pkl', 'rb') as fp:
    tw_freq = pkl.load(fp)

In [None]:
tw_data = list()

for k, v in tw_freq.items():
    tw_data.append((v, k))

tw_data = sorted(tw_data, reverse=True)

In [None]:
for cnt, tr in tw_data[20:40]:
    print(f'{str(tr):40} - {cnt}')

In [None]:
tw_ad_texts = ' '.join(tw_texts)

In [None]:
# extracting top 100 ad hominem tweets from each `tw_page`

tw_vdata = dict()

for tw_page in tw_pages: 
    tw_vdata[tw_page] = list()

    for index, row in df[tw_page].iterrows():
        tw_vdata[tw_page].append((row['score'], row['pptweet']))
    tw_vdata[tw_page].sort()

tw_viz_texts = dict()

for tw_page in tw_pages:
    tw_viz_texts[tw_page] = list()
    for _score, _text in tw_vdata[tw_page][:100]:
        tw_viz_texts[tw_page].append(_text)

In [None]:
# creating visualizations: one LaTeX heatmap file per Twitter account

for tw_page in tw_pages:
    tw_vtexts = list()
    tw_vscores = list()
    for tw_text in tw_viz_texts[tw_page]:
        sent = sanitize(tw_text)
        try:
            tw_t, tw_s = top_three_tokens(sent)
        except Exception:
            # best effort: skip tweets that cannot be scored (e.g. too short for three peaks);
            # a bare `except:` here also swallowed KeyboardInterrupt
            continue
        tw_vtexts.append(tw_t)
        tw_vscores.append(tw_s)
    heatmap(tw_vtexts, tw_vscores, ['Ad hominem'] * len(tw_vtexts), f'{tw_page}.tex', f'Ad hominem tweets: {tw_page}', color='cyan')

# Comparing trigger trigrams in Facebook and CreateDebate data

## Facebook

In [None]:
fb_df = pd.read_csv('/content/gdrive/MyDrive/DL/Facebook/fbscraper/nytimes/2016/2016c.csv')

In [None]:
fb_df

In [None]:
fb_df['score'].hist(bins=[x/100 for x in range(101)])

In [None]:
# considering comments which are classified as ad hominem
# with at least 95% confidence
fb_df = fb_df[fb_df['score'] < 0.05] 

In [None]:
len(fb_df) / 126328 * 100

In [None]:
fb_texts = list(fb_df['processedText'])

In [None]:
#fb_texts = list(fb_df['processedText'])
fb_freq = dict() # dictionary with token trigrams: frequency
fb_success = 0
for fb_text in tqdm(fb_texts): 
    try:
        for fb_tr in top_three_tokens2(sanitize(fb_text)):
            try:
                fb_freq[fb_tr] += 1
            except KeyError: 
                fb_freq[fb_tr] = 1
        fb_success += 1
    except:
        # probably fb_text is too short to contain 3 trigrams
        pass

In [None]:
# saving the computation (useful in case of session crash)

with open('/content/gdrive/MyDrive/Temp/btp_36_fb_freq.pkl', 'wb') as fp:
    pkl.dump(fb_freq, fp)

In [None]:
# reload the saved Facebook trigram counts so later cells can run without recomputing them
# NOTE(review): unpickling can execute arbitrary code -- only load files this notebook wrote
with open('/content/gdrive/MyDrive/Temp/btp_36_fb_freq.pkl', 'rb') as fp:
    fb_freq = pkl.load(fp)

In [None]:
fb_data = list()
for k, v in fb_freq.items():
    fb_data.append((v, k))
fb_data = sorted(fb_data, reverse=True)

In [None]:
for cnt, tr in fb_data[:20]:
    print(f'{str(tr):40} - {cnt}')

In [None]:
len(fb_data)

In [None]:
# all filtered Facebook comments as one space-joined string (token counts are taken from it later)
fb_ad_text = ' '.join(fb_texts)

In [None]:
# persist the joined text so it survives a session crash
with open('/content/gdrive/MyDrive/Temp/btp_36_fb_ad_text.pkl', 'wb') as fp:
    pkl.dump(fb_ad_text, fp)

## CreateDebate

In [None]:
# loading createdebate corpus
# Clone the scraper repo and cd into src/nested/ -- presumably where the `thread` module
# imported next (Thread, Comment) lives; it is needed to unpickle the saved threads.
!git clone https://github.com/utkarsh512/CreateDebateScraper.git
%cd CreateDebateScraper/src/nested/

In [None]:
from thread import Thread, Comment # for CreateDebate corpus
import pickle
from copy import deepcopy

In [None]:
# CreateDebate categories to load; each gets its own (initially empty) list of comment records
cd_categories = ['politics2', 'religion', 'world', 'science', 'law', 'technology']
cd_comments = {cat: [] for cat in cd_categories}

In [None]:
for cat in tqdm(cd_categories):
    # threads.log is a stream of pickled objects written back to back: unpickle until EOF.
    # The `with` block closes the file even if an unpickling error other than EOFError occurs.
    # NOTE(review): pickle.load can execute arbitrary code -- only use trusted logs.
    threads = list()
    with open('/content/gdrive/MyDrive/DL/CreateDebate/' + cat + '/threads.log', 'rb') as fp:
        while True:
            try:
                threads.append(pickle.load(fp))
            except EOFError:
                break
    print(f'{cat} - {len(threads)}')

    # Group comments by author in first-seen order; the score log is consumed in this order.
    authors = dict()
    for thread in threads:
        for comment in thread.comments.values():
            authors.setdefault(comment.author, []).append(comment)

    # One entry per comment, aligned with the author-grouped order above:
    # entry[0] is stored as 'score', entry[1][0] as 'validation'.
    with open('/content/gdrive/MyDrive/DL/CreateDebate/' + cat + '/comments_with_score.log', 'rb') as fp:
        cws = pickle.load(fp)
    ctr = 0
    for author_comments in authors.values():
        for comment in author_comments:
            record = deepcopy(comment.__dict__)
            record['tag'] = cat
            record['score'] = cws[ctr][0]
            record['validation'] = cws[ctr][1][0]
            cd_comments[cat].append(record)
            ctr += 1

In [None]:
cd_comments['law'][0]

In [None]:
# plotting score distribution

cd_scores = list()

for cat in cd_categories:
    for cd_comment in cd_comments[cat]:
        cd_scores.append(cd_comment['score'])

In [None]:
plt.hist(cd_scores, bins=[x/100 for x in range(101)])

In [None]:
len(cd_scores)

In [None]:
# considering only ad hominem comments with 95% classification score

cd_texts = list()
cd_authors = set()

for cat in cd_categories:
    for cd_comment in tqdm(cd_comments[cat]):
        cd_authors.add(cd_comment['author'].lower())
        if cd_comment['score'] < 0.05:
            cd_texts.append(sanitize(cd_comment['body']))

In [None]:
cd_texts = cd_texts[:10000]

In [None]:
# Removing name of authors from the comment text

cd_texts_pp = list()
for cd_text in tqdm(cd_texts):
    pp_tokens = list()
    for token in cd_text.split():
        if token not in cd_authors: 
            pp_tokens.append(token) 
    cd_texts_pp.append(' '.join(pp_tokens))

In [None]:
cd_it = 56
print(f'{cd_texts[cd_it]}\n\n{cd_texts_pp[cd_it]}')

In [None]:
cd_freq = dict()  # word window (trigram) -> number of comments it was picked in
cd_success = 0  # comments that yielded trigrams
for cd_text in tqdm(cd_texts_pp):
    try:
        trigrams = top_three_tokens2(cd_text)  # texts were already sanitized when filtered
    except Exception:
        # best effort: skip comments that cannot be scored (e.g. too short for three peaks);
        # a bare `except:` here also swallowed KeyboardInterrupt
        continue
    for cd_tr in trigrams:
        cd_freq[cd_tr] = cd_freq.get(cd_tr, 0) + 1
    cd_success += 1

In [None]:
# save the CreateDebate trigram counts (useful in case of session crash)
with open('/content/gdrive/MyDrive/Temp/btp_36_cd_freq.pkl', 'wb') as fp:
    pkl.dump(cd_freq, fp)

In [None]:
# reload them so later cells can run without recomputing
# NOTE(review): unpickling can execute arbitrary code -- only load files this notebook wrote
with open('/content/gdrive/MyDrive/Temp/btp_36_cd_freq.pkl', 'rb') as fp:
    cd_freq = pkl.load(fp)

In [None]:
cd_data = list()
for k, v in cd_freq.items():
    cd_data.append((v, k))  
cd_data = sorted(cd_data, reverse=True)  

In [None]:
for cd_s, cd_t in cd_data[:20]:   
    print(f'{str(cd_t):30} - {cd_s}')

In [None]:
# all author-scrubbed CreateDebate comments as one space-joined string (token counts are taken from it later)
cd_ad_text = ' '.join(cd_texts_pp)

In [None]:
# persist the joined text so it survives a session crash
with open('/content/gdrive/MyDrive/Temp/btp_36_cd_ad_text.pkl', 'wb') as fp:
    pkl.dump(cd_ad_text, fp)

## Comparison

In [None]:
!pip install shifterator

In [None]:
import shifterator as sh

In [None]:
fb_set = set(fb_freq.keys())
cd_set = set(cd_freq.keys())

In [None]:
for _x in fb_set[:100]:
    print(_x)

In [None]:
# jacquard overlap
len(fb_set & cd_set) / len(fb_set | cd_set) * 100

In [None]:
from collections import Counter


def _count_tokens(text):
    '''Return a {token: count} dict (first-seen order) for the whitespace tokens of *text*.'''
    # str.split() without arguments already strips and collapses whitespace, so the
    # extra per-token .strip() of the previous three hand-written loops was a no-op.
    return dict(Counter(text.split()))


# NOTE(review): tw_ad_texts joins the raw `pptweet` strings and fb_ad_text the raw
# `processedText`, while cd_ad_text is built from sanitize()d text -- confirm the three
# token streams are normalised the same way before comparing them.
fb_tokens = _count_tokens(fb_ad_text)
cd_tokens = _count_tokens(cd_ad_text)
tw_tokens = _count_tokens(tw_ad_texts)

In [None]:
# Jensen-Shannon divergence word shift between Twitter (system 1) and Facebook (system 2)
# token frequencies, weighted equally, with base-2 logarithms.
# NOTE(review): alpha=1 -- confirm in the shifterator docs that this selects the plain Shannon form.
jsd_shift_1 = sh.JSDivergenceShift(type2freq_1=tw_tokens,
                                   type2freq_2=fb_tokens,
                                   weight_1=0.5,
                                   weight_2=0.5,
                                   base=2,
                                   alpha=1)

In [None]:
jsd_shift_1.get_shift_graph(title='Jensen-Shannon Divergence Shifts b/w Twitter and Facebook')

In [None]:
len(cd_data)

In [None]:
fr = [1, 2, 5, 10, 20, 50, 100]  # percentages of each ranked trigram list to compare


def _top_fraction(ranked, percent):
    '''Set of trigrams in the first `percent` % of a (count, trigram) list ranked by count.'''
    limit = int(len(ranked) * (percent / 100))
    return {trigram for _, trigram in ranked[:limit]}


def _jaccard_percent(a, b):
    '''Jaccard similarity of two sets in percent; 0.0 when both are empty (was 0/0).'''
    union = a | b
    return len(a & b) / len(union) * 100 if union else 0.0


# Build each prefix set once (the previous nested loop rebuilt the Twitter set for every j)
# and keep them local instead of overwriting the global fb_set.
tw_tops = [_top_fraction(tw_data, p) for p in fr]
fb_tops = [_top_fraction(fb_data, p) for p in fr]
# jac_matrix[i][j]: overlap of the top fr[i] % Twitter and top fr[j] % Facebook trigrams
jac_matrix = [[_jaccard_percent(tw_top, fb_top) for fb_top in fb_tops] for tw_top in tw_tops]

In [None]:
fr_str = list()
for _x in fr:
    fr_str.append(str(_x))

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(jac_matrix, interpolation='nearest')
fig.colorbar(cax)
ax.set_xticks(np.arange(len(fr)))
ax.set_yticks(np.arange(len(fr)))
ax.set_xticklabels(fr_str)
ax.set_yticklabels(fr_str)
ax.set_ylabel("% Top trigrams in Twitter", rotation='vertical')
ax.set_xlabel("% Top trigrams in Facebook")
plt.setp(ax.get_xticklabels(), rotation=90)
plt.show()

In [None]:
len(tw_data)