# Utils

In [None]:
# Colab-only: mount Google Drive at /content/gdrive. Every data path used in
# this notebook lives under /content/gdrive/MyDrive/.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

In [None]:
# Clone `CreateDebateScraper` library from GitHub and move into src/nested/ so
# that `from thread import Comment, Thread` (imports cell) resolves; those
# classes are needed to unpickle the scraped thread logs.
# NOTE(review): not idempotent -- the clone and the relative %cd run against
# whatever the current directory is, so re-running from src/nested/ nests a
# second clone. Prefer absolute paths if this cell must be re-run.
!git clone https://github.com/utkarsh512/CreateDebateScraper.git
%cd CreateDebateScraper/src/nested/

In [None]:
!pip install transformers

In [None]:
!pip install cpnet

In [None]:
# --- Standard library ---
import json
import pickle
import re
import textwrap
from copy import deepcopy
from itertools import accumulate
from pprint import pprint

# --- Third-party ---
import cpnet
import matplotlib
import networkx as nx
import nltk
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from nltk.tokenize import TweetTokenizer
from scipy import stats
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

# --- Local: CreateDebateScraper/src/nested (cloned above). `Comment` and
# `Thread` must be importable for the pickled thread logs to load. ---
from thread import Comment, Thread

# Optional extras, currently unused:
# import shifterator as sh
# import wordcloud
# import skbio

# NLTK resources: 'punkt' for tokenizers, 'stopwords' for STOP_WORDS below.
nltk.download('punkt')
nltk.download('stopwords')

# Plot defaults.
matplotlib.rcParams.update({'font.size': 18})
matplotlib.rcParams["figure.figsize"] = (12, 5)

STOP_WORDS = list(stopwords.words('english'))

In [None]:
# Custom routine to clean texts scraped from the Web: `clean_text` (below)
# removes hyperlinks and punctuation marks (except apostrophes).

# Module-level tokenizer that `clean_text` uses by default; created once and
# reused across calls.
tknz = TweetTokenizer()

def clean_text(text, tokenizer=None):
    """
    Preprocess a raw web comment into a list of word tokens.

    Steps: lower-case; drop URLs (``http...`` / ``www...``); turn hyphens into
    spaces; collapse whitespace; tokenize; then strip every character that is
    not alphanumeric (or a space) from each token, keeping apostrophes.
    Straight and curly single quotes are all normalised to ``'``. Tokens that
    end up empty, or that are a lone apostrophe, are dropped.

    :param text: raw comment text
    :param tokenizer: object with a ``tokenize(str) -> list[str]`` method;
                      defaults to the module-level ``tknz``
    :return: list of cleaned tokens
    """
    if tokenizer is None:
        tokenizer = tknz
    text = text.lower()
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"www\S+", "", text)
    text = text.replace("-", " ")
    text = re.sub(r"\s+", " ", text)
    # The isalnum() filter below would strip apostrophes, so each one is
    # temporarily replaced by an uppercase 'X'. This cannot collide with real
    # text because the text was lower-cased above.
    for apostrophe in ("\u2018", "\u2019", "'"):
        text = text.replace(apostrophe, "X")
    word_tokens = []
    for token in tokenizer.tokenize(text):
        token = ''.join(ch for ch in token if ch.isalnum() or ch == ' ')
        if token and token != 'X':
            word_tokens.append(token.replace('X', "'"))
    return word_tokens

In [None]:
# Topical forums on CreateDebate. We have scraped comments for all of the
# following forums.
categories = ['business', 'comedy', 'entertainment', 'health', 'law', 'nsfw',
              'politics2', 'religion', 'science', 'shopping', 'sports',
              'technology', 'travel', 'world']

# However, we analyze comments from the selected forums only: each of these
# has at least 10k comments.
categories_selected = ['politics2', 'religion', 'world',
                       'science', 'law', 'technology']

# forum -> list of comment records (filled by the loading cell below).
comments = {forum: [] for forum in categories_selected}

In [None]:
# Loading comments from the selected forums and attaching classification scores.
#
# NOTE(review): pickle.load can execute arbitrary code -- only load these logs
# from a trusted location.

def _load_pickle_stream(path):
    """Return every object pickled back to back in the file at ``path``.

    Objects are read until EOFError. Any other error (for example a corrupt
    record) propagates instead of silently ending the read.
    """
    objects = []
    with open(path, 'rb') as fh:
        while True:
            try:
                objects.append(pickle.load(fh))
            except EOFError:
                break
    return objects


for cat in tqdm(categories_selected):
    base_dir = f'/content/gdrive/MyDrive/DL/CreateDebate/{cat}'

    # Get all the `Thread` objects pickled while scraping.
    threads = _load_pickle_stream(f'{base_dir}/threads.log')

    # While classifying CreateDebate comments, we used comments as per author mode.
    # Hence, using the same mode to attach classification score with the comments:
    # authors in first-seen order, each author's comments in encounter order.
    # This assumes `cws` was produced in exactly that order.
    #
    # score < 0.5 -> ad hominem comment
    #       > 0.5 -> non ad hominem comment
    authors = dict()
    for thread in threads:
        for k, v in thread.comments.items():
            authors.setdefault(v.author, []).append((v, k))

    # Load the classification score of the comments.
    with open(f'{base_dir}/comments_with_score.log', 'rb') as fp:
        cws = pickle.load(fp)

    # Attach classification score with the comments.
    ctr = 0
    for author_comments in authors.values():
        for comment, cid in author_comments:
            record = deepcopy(comment.__dict__)
            record['tag'] = cat
            record['score'] = cws[ctr][0]
            record['validation'] = cws[ctr][1][0]
            # Comment keys presumably carry a 3-character prefix before the
            # numeric id -- confirm against the scraper.
            record['id'] = int(cid[3:])
            comments[cat].append(record)
            ctr += 1

    if ctr != len(cws):
        print(f'WARNING: {cat}: {len(cws)} scores for {ctr} comments; '
              'score alignment may be off')

In [None]:
# comment id -> ad hominem score. Classifier scores below 0.5 mean
# "ad hominem", so 1 - score makes larger values more ad hominem.
ah_score_comments = {
    comment['id']: 1 - comment['score']
    for cat in categories_selected
    for comment in comments[cat]
}

In [None]:
def parse_tstring(tstring):
    """
    Parses comment's time to an integer to enable
    comparison between comments based on their time of posting.

    The first six numeric fields (year, month, day, hour, minute, second) are
    packed as YYYYMMDDhhmmss, e.g. '2012-03-04T12:34:56+00:00' ->
    20120304123456. Fractional seconds and any UTC-offset suffix are ignored,
    so times are compared as written, not converted to UTC.

    :raises ValueError: for 'Not Available' or when fewer than six numeric
                        fields are present
    """
    if tstring == 'Not Available':
        raise ValueError('Invalid posting time for parse_tstring')
    fields = re.findall(r'\d+', tstring)
    if len(fields) < 6:
        raise ValueError('Unrecognized posting time for parse_tstring: %r' % (tstring,))
    year, month, day, hour, minute, second = (int(f) for f in fields[:6])
    return ((((year * 100 + month) * 100 + day) * 100 + hour) * 100 + minute) * 100 + second

In [None]:
# Loading AH score
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.

with open('/content/gdrive/MyDrive/Temp/47-ah-score.pkl', 'rb') as fp:
    ah_score = pickle.load(fp)

# `ah_score` is a dictionary that contains the ah score of the comments written
# by all the users. It is reduced to per-author maxima in `ah_score_max`.

# key: category -> user
# value: list of ah_score for given user for given category

# value > 0.5 --> ad hominem
# value < 0.5 --> non ad hominem

In [None]:
# Loading CreateDebate profile characteristics into a dataframe (JSON Lines:
# one profile record per line).
df = pd.read_json('/content/gdrive/MyDrive/DL/CreateDebate/profile/results.json', lines=True)

# username -> profile characteristic. Scalar columns are taken as-is; for the
# relationship columns (allies / enemies / hostiles) only len() of each entry
# is kept.
usernames = df['username'].tolist()

reward_points_map = dict(zip(usernames, df['reward_points'].tolist()))
efficiency_map    = dict(zip(usernames, df['efficiency'].tolist()))
allies_map, enemies_map, hostiles_map = (
    {name: len(related) for name, related in zip(usernames, df[column].tolist())}
    for column in ('allies', 'enemies', 'hostiles')
)

In [None]:
def profile_characteristics_stats(user_subset):
    """
    Returns average and standard deviation of profile characteristics for 
    given subset of users.

    :param user_subset: Iterable containing usernames
    :return: (avgs, stds), each a list ordered as
             [reward points, efficiency, #allies, #enemies, #hostiles]

    >>> avgs, stds = profile_characteristics_stats(user_subset)
    >>> rewards_avg, efficiency_avg, n_allies_avg, n_enemies_avg, n_hostiles_avg = avgs
    >>> rewards_std, efficiency_std, n_allies_std, n_enemies_std, n_hostiles_std = stds

    Note that profile characteristics for some users might not be present in our
    dataset as some users might have deleted their account when we scraped the
    forum to obtain these characteristics. Such users are skipped, separately
    for each characteristic. A characteristic with no values yields NaN.
    """
    users = list(user_subset)  # iterated once per characteristic
    feature_maps = (reward_points_map, efficiency_map, allies_map, enemies_map, hostiles_map)
    grpd_data = [[fmap[user] for user in users if user in fmap] for fmap in feature_maps]
    avgs = [np.average(x) for x in grpd_data]
    stds = [np.std(x) for x in grpd_data]
    return avgs, stds

In [None]:
# Maximum ah score per category per author
#   key: category -> author
#   value: maximum ah score
ah_score_max = {
    category: {author: np.max(scores) for author, scores in author_data.items()}
    for category, author_data in ah_score.items()
}

In [None]:
# Number of comments each author wrote in each forum
#   key: category -> author
#   value: comment count
comment_count = dict()
for category in categories_selected:
    per_author = comment_count[category] = dict()
    for comment in comments[category]:
        per_author[comment['author']] = per_author.get(comment['author'], 0) + 1

In [None]:
# Every author with at least one comment in any selected forum (deduplicated).
user_list = list({
    comment['author']
    for category in categories_selected
    for comment in comments[category]
})

In [None]:
# Earliest posting time of each author in each forum
#   key: category -> user
#   value: integer time as returned by parse_tstring; comments whose time is
#          'Not Available' are ignored
first_post_time = dict()
for category in categories_selected:
    earliest = first_post_time[category] = dict()
    for comment in comments[category]:
        if comment['time'] == 'Not Available':
            continue
        posted = parse_tstring(comment['time'])
        author = comment['author']
        if author not in earliest or posted < earliest[author]:
            earliest[author] = posted

In [None]:
def get_migrated_users(user_subset, categories_1, categories_2, categories_1_origin=True, require_migration=True,
                       post_times=None, all_categories=None):
    """
    Returns a list of usernames who migrated from categories_1 to categories_2.

    For each user, post_time_1 / post_time_2 is the earliest first-post time
    over the forums in categories_1 / categories_2. Users with no post in
    either group are skipped.

    If categories_1_origin is True, first posts in every other forum of
    ``all_categories`` are folded into post_time_2, so a migrated user's
    first post among those forums must be in categories_1.

    If require_migration is True, a user is kept only when
    post_time_1 < post_time_2. If it is False, that condition is dropped and
    every user active in both groups is returned.

    :param user_subset: iterable of usernames to test
    :param post_times: category -> {user: first post time}; defaults to the
                       global ``first_post_time``
    :param all_categories: forums considered for categories_1_origin;
                           defaults to the global ``categories_selected``
    :return: list of usernames, in ``user_subset`` order
    """
    if post_times is None:
        post_times = first_post_time
    if all_categories is None:
        all_categories = categories_selected

    categories_1 = set(categories_1)
    categories_2 = set(categories_2)
    other_categories = [c for c in all_categories
                        if c not in categories_1 and c not in categories_2]

    def earliest_post(user, categories):
        # None (rather than a date sentinel) when the user never posted in
        # any of `categories`, so posts of any date are handled.
        times = [post_times[c][user] for c in categories
                 if c in post_times and user in post_times[c]]
        return min(times) if times else None

    resultant_list = []
    for user in user_subset:
        post_time_1 = earliest_post(user, categories_1)
        post_time_2 = earliest_post(user, categories_2)
        if post_time_1 is None or post_time_2 is None:
            continue

        if categories_1_origin:
            other_time = earliest_post(user, other_categories)
            if other_time is not None:
                post_time_2 = min(post_time_2, other_time)

        if post_time_1 < post_time_2 or not require_migration:
            resultant_list.append(user)

    return resultant_list

In [None]:
def partition_migrated_users(migration_list, categories_1, categories_2, threshold=0.5, score_max=None):
    """
    Partitions the users into 4 groups, returned in this order:
        AH-AH
        AH-NonAH
        NonAH-AH
        NonAH-NonAH

    Users are classified as AH in a given group of categories if their
    maximum ah score there is strictly greater than ``threshold``, i.e. they
    post at least one ad hominem comment in that group. Otherwise they are
    NonAH; this includes a score of exactly ``threshold`` and having no score.
    This matches the ``> 0.5`` rule used to split forum users into AH / non-AH.

    Note: migration_list should be obtained using get_migrated_users method

    :param threshold: AH cut-off on the maximum ah score
    :param score_max: category -> {user: max ah score}; defaults to the
                      global ``ah_score_max``
    """
    if score_max is None:
        score_max = ah_score_max

    def is_ah(user, categories):
        # Users with no score in a category count as 0 there.
        best = max([0] + [score_max[category].get(user, 0) for category in categories])
        return best > threshold

    ah_ah_list, ah_nonah_list, nonah_ah_list, nonah_nonah_list = [], [], [], []
    groups = {
        (True, True): ah_ah_list,
        (True, False): ah_nonah_list,
        (False, True): nonah_ah_list,
        (False, False): nonah_nonah_list,
    }
    for user in migration_list:
        groups[(is_ah(user, categories_1), is_ah(user, categories_2))].append(user)

    return ah_ah_list, ah_nonah_list, nonah_ah_list, nonah_nonah_list

In [None]:
# Collect every pickled `Thread` from the selected forums. `build_graph` reads
# this global `threads` list to build the user network graph.
# NOTE(review): pickle.load can execute arbitrary code -- trusted files only.

threads = []

for category in categories_selected:
    reader_addr = f'/content/gdrive/MyDrive/DL/CreateDebate/{category}/threads.log'
    # Each log holds several pickled objects written back to back; read until
    # EOF. Only EOFError ends the loop, so a corrupt record or an interrupt is
    # not mistaken for end-of-file.
    with open(reader_addr, 'rb') as reader:
        while True:
            try:
                threads.append(pickle.load(reader))
            except EOFError:
                break

In [None]:
def build_graph(user_subset, n1 = 0, n2 = 0, thread_list=None):
    """
    Builds the directed user reply network from hyper-parameters n1 and n2.

    An author becomes a node when they are in ``user_subset``, wrote at least
    ``n1`` level-1 comments (direct children of a debate side's 'root'), and
    received at least ``n2`` direct replies in total. The weighted edge
    i -> j counts how many times node i directly replied to node j.

    Inputs
    ------
    :param user_subset: iterable of usernames allowed to become nodes
    :param n1: threshold on number of level-1 comments
    :param n2: threshold on number of direct replies
    :param thread_list: iterable of Thread objects to scan; defaults to the
                        global ``threads`` list

    Output
    ------
    (
        author_map: dict,    # author name -> node id
        reverse_map: list,   # node id -> author name
        author_count: int,
        graph: nx.DiGraph,   # only nodes with at least one edge are present
        matrix: list         # author_count x author_count reply counts
    )

    NOTE(review): reply trees are traversed recursively, so extremely deep
    reply chains could hit Python's recursion limit.
    """
    scanned = threads if thread_list is None else list(thread_list)
    allowed = set(user_subset)  # O(1) membership; callers pass long lists

    def reply_trees():
        """Yield (comments_by_id, top_level) for each debate side of each thread."""
        for e in scanned:
            for meta in (e.metaL, e.metaR):
                if 'root' in meta:
                    yield e.comments, meta['root']

    def walk(comments_by_id, children, visit):
        """Pre-order traversal of a nested {cid: {child_cid: {...}}} reply tree."""
        for cid, grandchildren in children.items():
            visit(comments_by_id, cid, grandchildren)
            walk(comments_by_id, grandchildren, visit)

    # author -> number of level-1 comments
    level1_count = {}
    for comments_by_id, top_level in reply_trees():
        for cid in top_level:
            author = comments_by_id[cid].author
            level1_count[author] = level1_count.get(author, 0) + 1
    level1_authors = {a for a, count in level1_count.items() if count >= n1}

    # author -> number of direct replies received
    reply_count = {}

    def count_replies(comments_by_id, cid, children):
        author = comments_by_id[cid].author
        reply_count[author] = reply_count.get(author, 0) + len(children)

    for comments_by_id, top_level in reply_trees():
        walk(comments_by_id, top_level, count_replies)

    # Authors satisfying both constraints, in first-encountered order (this
    # order fixes the node numbering).
    A = [x for x, n_replies in reply_count.items()
         if x in allowed and n_replies >= n2 and x in level1_authors]

    author_count = len(A)
    author_map = {name: i for i, name in enumerate(A)}
    reverse_map = list(A)

    # matrix[i][j] = number of times node i directly replied to node j.
    matrix = [[0] * author_count for _ in range(author_count)]

    def add_reply_edges(comments_by_id, cid, children):
        parent_id = author_map.get(comments_by_id[cid].author)
        if parent_id is None:
            return
        for child_cid in children:
            child_id = author_map.get(comments_by_id[child_cid].author)
            if child_id is not None:
                matrix[child_id][parent_id] += 1

    for comments_by_id, top_level in reply_trees():
        walk(comments_by_id, top_level, add_reply_edges)

    # NetworkX graph for the network statistics provided by nx.
    graph = nx.DiGraph()
    graph.add_weighted_edges_from(
        (i, j, weight)
        for i, row in enumerate(matrix)
        for j, weight in enumerate(row)
        if weight != 0
    )

    return (author_map, reverse_map, author_count, graph, matrix)

In [None]:
# Construct global user network for entire CreateDebate corpus: candidates are
# all authors in `user_list`, with the default thresholds n1 = n2 = 0.
user_map, user_reverse_map, user_count, Graph, Matrix = build_graph(user_list)

In [None]:
def get_reciprocity_stats(user_subset):
    """
    Returns reciprocity for given subset of users in local network.

    The local graph is ``build_graph(user_subset)`` with default thresholds.
    ``nx.reciprocity`` raises ``NetworkXError`` when it is undefined (e.g. a
    graph without edges); that case is reported as None. Any other error
    propagates instead of being silently swallowed.

    >>> r = get_reciprocity_stats(user_subset)
    """
    _, _, _, Graph_, _ = build_graph(user_subset)

    try:
        return nx.algorithms.reciprocity(Graph_)
    except nx.NetworkXError:
        return None

In [None]:
# Get dicts containing centrality value for each node from global network.
# This will be used for computing stats for user subset.
# Keys are node ids of `Graph` (see `user_map`). Nodes without any edge were
# never added to `Graph`, so they have no entry.
centrality_dict = nx.algorithms.centrality.degree_centrality(Graph)

In [None]:
def get_centrality_stats(user_subset, node_values=None, node_ids=None):
    """
    Returns mean and standard deviation of degree centrality for given user 
    subset in the global network.

    Users that are not nodes of the global network are skipped: either they
    are absent from ``node_ids``, or their node has no value (a node without
    edges is never added to ``Graph``). If no user remains, (nan, nan) is
    returned.

    :param user_subset: iterable of usernames
    :param node_values: node id -> value; defaults to the global ``centrality_dict``
    :param node_ids: username -> node id; defaults to the global ``user_map``

    >>> c_avg, c_std = get_centrality_stats(user_subset)
    """
    if node_values is None:
        node_values = centrality_dict
    if node_ids is None:
        node_ids = user_map

    c = [node_values[node_ids[user]] for user in user_subset
         if user in node_ids and node_ids[user] in node_values]
    if not c:
        return np.nan, np.nan
    return np.average(c), np.std(c)

In [None]:
# Get dicts containing clustering coefficient for each node from global network.
# This will be used for computing stats for user subset.
# NOTE(review): `Graph` is a weighted DiGraph but no `weight` argument is
# passed, so this is the unweighted directed clustering coefficient.
clustering_dict = nx.algorithms.cluster.clustering(Graph)

In [None]:
def get_clustering_stats(user_subset, node_values=None, node_ids=None):
    """
    Returns mean and standard deviation of clustering coefficient for given user 
    subset in the global network.

    Users that are not nodes of the global network are skipped: either they
    are absent from ``node_ids``, or their node has no value (a node without
    edges is never added to ``Graph``). If no user remains, (nan, nan) is
    returned.

    :param user_subset: iterable of usernames
    :param node_values: node id -> value; defaults to the global ``clustering_dict``
    :param node_ids: username -> node id; defaults to the global ``user_map``

    >>> c_avg, c_std = get_clustering_stats(user_subset)
    """
    if node_values is None:
        node_values = clustering_dict
    if node_ids is None:
        node_ids = user_map

    c = [node_values[node_ids[user]] for user in user_subset
         if user in node_ids and node_ids[user] in node_values]
    if not c:
        return np.nan, np.nan
    return np.average(c), np.std(c)

In [None]:
def normalize_dict(x):
    """
    Min-max normalize the values of a dictionary:
        element = (element - min_element) / (max_element - min_element)

    Keys are preserved. An empty dict gives ``{}``. When all values are equal
    (zero range), every value maps to 0.0 instead of dividing by zero.
    """
    if not x:
        return {}
    mini = min(x.values())
    maxa = max(x.values())
    span = maxa - mini
    if span == 0:
        return {k: 0.0 for k in x}
    return {k: (v - mini) / span for k, v in x.items()}

In [None]:
def normalize_array(x):
    """
    Min-max normalize the elements of a list or tuple:
        element = (element - min_element) / (max_element - min_element)

    Returns a new list. An empty input gives ``[]``. When all elements are
    equal (zero range), every element maps to 0.0 instead of dividing by zero.

    :raises AssertionError: if ``x`` is not a list or tuple
    """
    assert isinstance(x, (list, tuple)), "Expected a list or tuple"
    if not x:
        return []
    mini = min(x)
    maxa = max(x)
    span = maxa - mini
    if span == 0:
        return [0.0] * len(x)
    return [(e - mini) / span for e in x]

In [None]:
def display_stats(user_subset):
    """
    Print network and profile statistics for a subset of users.

    Reports, in order:
    - the subset size;
    - reciprocity of the local graph built from the subset;
    - degree-centrality and clustering-coefficient mean ± std, taken from the
      global network;
    - mean ± std of each profile characteristic.

    Reciprocity is printed as 'n/a' when ``get_reciprocity_stats`` returns None.

    :param user_subset: collection of usernames
    """
    n                          = len(user_subset)
    r                          = get_reciprocity_stats(user_subset)
    deg_avg, deg_std           = get_centrality_stats(user_subset)
    clu_avg, clu_std           = get_clustering_stats(user_subset)
    user_chr_avg, user_chr_std = profile_characteristics_stats(user_subset)

    print('Size: %d' % n)
    # '%.2f' % None would raise TypeError, so None gets its own rendering.
    print('Graph reciprocity: %s' % ('n/a' if r is None else '%.2f' % r))

    print('Graph degree centrality: %.5f ± %.5f' % (deg_avg, deg_std))

    print('Graph clustering coeff: %.2f ± %.2f' % (clu_avg, clu_std))

    # Labels padded so the colons line up.
    labels = ['Reward points', 'Efficiency   ', '# Allies     ', '# Enemies    ', '# Hostiles   ']
    for label, avg, std in zip(labels, user_chr_avg, user_chr_std):
        print('%s: %.2f ± %.2f' % (label, avg, std))

# Analysis

In [None]:
# Partition the dataset into polar and non-polar comments.
# A comment whose 'polarity' is 'Not Available' is non-polar (`comments_np`);
# every other comment is polar (`comments_p`). Records are deep-copied, so
# the partitions are independent of `comments`.

polar_cids = set()  # NOTE(review): not used by any visible cell

comments_p = {}
comments_np = {}

for x in categories_selected:
    polar_bucket = comments_p[x] = []
    non_polar_bucket = comments_np[x] = []
    for comment in comments[x]:
        target = non_polar_bucket if comment['polarity'] == 'Not Available' else polar_bucket
        target.append(deepcopy(comment))

In [None]:
# Per forum: <forum> - <#polar comments> - <#non-polar comments>
for x in categories_selected:
    print(f'{x} - {len(comments_p[x])} - {len(comments_np[x])}')

In [None]:
# User characteristics: per forum, the set of authors active in polar
# (`user_list_p`) and non-polar (`user_list_np`) debates.
user_list_p = {
    cat: {comment['author'] for comment in comments_p[cat]}
    for cat in categories_selected
}
user_list_np = {
    cat: {comment['author'] for comment in comments_np[cat]}
    for cat in categories_selected
}


In [None]:
# Per forum: <forum> - <#polar users> - <#non-polar users> - <#users in both>
for cat in categories_selected:
    print(f'{cat} - {len(user_list_p[cat])} - {len(user_list_np[cat])} - {len(user_list_p[cat] & user_list_np[cat])}')

In [None]:
def _split_by_ah(users, cat):
    """Split ``users`` into (AH, non-AH) sets for forum ``cat``.

    A user is AH when their maximum ah score in ``cat`` exceeds 0.5. Raises
    KeyError if a user has no entry in ``ah_score_max[cat]``.
    """
    ah_users, nonah_users = set(), set()
    for user in users:
        target = ah_users if ah_score_max[cat][user] > 0.5 else nonah_users
        target.add(user)
    return ah_users, nonah_users


ah_user_list_p, nonah_user_list_p = dict(), dict()
ah_user_list_np, nonah_user_list_np = dict(), dict()

for cat in categories_selected:
    ah_user_list_p[cat], nonah_user_list_p[cat] = _split_by_ah(user_list_p[cat], cat)
    ah_user_list_np[cat], nonah_user_list_np[cat] = _split_by_ah(user_list_np[cat], cat)

In [None]:
# Sanity check for the AH / non-AH split above: in every forum, the two
# groups must be disjoint and together cover exactly the forum's users.
for cat in categories_selected:
    for ah, nonah, everyone in (
        (ah_user_list_p[cat], nonah_user_list_p[cat], user_list_p[cat]),
        (ah_user_list_np[cat], nonah_user_list_np[cat], user_list_np[cat]),
    ):
        assert not (ah & nonah), f'{cat}: users present in both AH and non-AH groups'
        assert (ah | nonah) == set(everyone), f'{cat}: AH / non-AH split does not cover all users'

## User characteristics for topical forums

In [None]:
# User-characteristics for different forums.
# Swap the group (ah_/nonah_ x _p/_np) and the forum key to inspect other subsets.

display_stats(nonah_user_list_np['technology'])

## Migration study

In [None]:
# Migration direction under study: from categories_1 to categories_2.
categories_1, categories_2 = ['politics2'], ['law']

In [None]:
# Candidates: users active in non-polar Politics debates. With
# categories_1_origin=True, their first post among the selected forums must
# be in categories_1.
migration_list = get_migrated_users(user_list_np['politics2'], categories_1, categories_2, categories_1_origin=True, require_migration=True)

In [None]:
# Number of migrating users found above.
len(migration_list)

In [None]:
# Split migrants by AH behaviour in categories_1 (first letter) and
# categories_2 (second letter).
AA, AN, NA, NN = partition_migrated_users(migration_list, categories_1, categories_2)

In [None]:
# Group sizes: first label = behaviour in categories_1, second = in categories_2.
print('AH-AH: %d, AH-NONAH: %d, NONAH-AH: %d, NONAH-NONAH: %d' % (len(AA), len(AN), len(NA), len(NN)))

In [None]:
# Group by behaviour in the origin forum (categories_1):
# A = AH there, N = non-AH there.
A, N = AA + AN, NA + NN

In [None]:
# Statistics for migrants who were non-AH in categories_1.
display_stats(N)

# Top 1000 comments

In [None]:
# (classifier score, body) pairs for polar and non-polar Politics comments.
comments_p_ = [(c['score'], c['body']) for c in comments_p['politics2']]
comments_np_ = [(c['score'], c['body']) for c in comments_np['politics2']]

In [None]:
# Ascending by classifier score (ties broken by body). Scores below 0.5 mean
# ad hominem, so the lowest-scored comments come first.
comments_p_, comments_np_ = sorted(comments_p_), sorted(comments_np_)

In [None]:
N_TOP = 1000  # number of lowest-score (most ad hominem) comments to keep
top_ah_p_ = comments_p_[:N_TOP]
top_ah_np_ = comments_np_[:N_TOP]

In [None]:
# Clean each selected comment body and re-join its tokens with single spaces.
top_ah_p = [' '.join(clean_text(body)) for _score, body in top_ah_p_]
top_ah_np = [' '.join(clean_text(body)) for _score, body in top_ah_np_]

## Visualization

In [None]:
class Visualizer:
    """Wrapper for creating LaTeX heatmaps for documents.

    Every word is wrapped in a colorbox whose intensity (``color!score``)
    reflects the word's score.
    """

    # Characters that LaTeX treats specially, mapped to a literal rendering.
    # Applied one character at a time, so an inserted backslash is never
    # escaped a second time.
    _LATEX_ESCAPES = {
        '\\': r'\textbackslash{}',
        '%': r'\%',
        '&': r'\&',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '$': r'\$',
        '^': r'\textasciicircum{}',
        '~': r'\textasciitilde{}',
    }

    def __init__(self):
        self._header = r'''\documentclass[10pt,a4paper]{article}
\usepackage[left=1.00cm, right=1.00cm, top=1.00cm, bottom=2.00cm]{geometry}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}''' + '\n\n'

        self._footer = r'''\end{CJK*}
\end{document}'''

    def visualize(self,
                  word_list,
                  attention_list,
                  label_list,
                  latex_file,
                  title,
                  batch_size=20,
                  color='blue'):
        """Routine to generate attention heatmaps for given texts
        ---------------------------------------------------------
        Input:
        :param word_list: list of texts (each text is a list of words)
        :param attention_list: scores for each word, dimension same as word_list
        :param label_list: label for each text
        :param latex_file: name of the latex file
        :param title: title of latex file (written verbatim, so LaTeX is allowed)
        :param batch_size: Number of comments in each batch; the last batch may
                           be shorter. A non-positive value puts every comment
                           in a single batch.
        :param color: color used for visualization, can be 'blue', 'red', 'green', etc.
        :raises ValueError: if the three lists differ in length

        Words are LaTeX-escaped before writing; labels and title are not.
        """
        n_examples = len(word_list)
        if len(attention_list) != n_examples or len(label_list) != n_examples:
            raise ValueError('word_list, attention_list and label_list must have the same length')
        if batch_size <= 0:
            batch_size = max(n_examples, 1)

        word_list_processed = [self._clean_word(words) for words in word_list]

        with open(latex_file, 'w', encoding='utf-8') as f:
            f.write(self._header)
            f.write('\\section{%s}\n\n' % title)

            # Step through in batch_size strides so a final partial batch is
            # written too.
            for batch_idx, start in enumerate(range(0, n_examples, batch_size)):
                stop = start + batch_size
                f.write('\\subsection{Batch %d}\n\n' % (batch_idx + 1))
                batch = zip(word_list_processed[start:stop],
                            attention_list[start:stop],
                            label_list[start:stop])
                for j, (sentence, score, label) in enumerate(batch):
                    f.write('\\subsubsection{Comment %d - %s}\n\n' % (j + 1, label))
                    assert len(sentence) == len(score)
                    f.write('\\noindent')
                    for word, value in zip(sentence, score):
                        f.write('\\colorbox{%s!%s}{' % (color, value) + '\\strut ' + word + '} ')
                    f.write('\n\n')

            f.write(self._footer)

    @staticmethod
    def _clean_word(word_list):
        """Return a copy of ``word_list`` with LaTeX special characters escaped."""
        escapes = Visualizer._LATEX_ESCAPES
        return [''.join(escapes.get(ch, ch) for ch in word) for word in word_list]

In [None]:
# Load the BERT checkpoint stored at cnerg-bert-adhominem on Drive, together
# with its tokenizer. output_attentions=True makes the model also return
# per-layer attention maps, which `attention_scores` reads.
model_version = '/content/gdrive/MyDrive/DL/cnerg-bert-adhominem'
do_lower_case = True
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)

In [None]:
# Colour intensity written into `\colorbox{<color>!INTENSITY}` for highlighted
# words (see `top_three_tokens` and `Visualizer`).
INTENSITY = 70

def attention_scores(text, layers=None, heads=None):
    """
    Attention paid by the first token ([CLS]) to every token of ``text``.

    Runs the global BERT ``model`` once. For the chosen layers and heads it
    averages the attention row of query position 0.

    :param text: input string
    :param layers: layer indices to average over (default: every layer of ``model``)
    :param heads: head indices to average over (default: every head of ``model``)
    :return: (tokens, scores) -- the word-piece tokens, special tokens
             included, and one float score per token
    """
    import torch  # BertModel runs on PyTorch; torch is a transformers dependency

    inputs = tokenizer.encode_plus(text, None, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        attention = model(input_ids)[-1]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())  # Batch index 0

    layers = list(range(model.config.num_hidden_layers)) if layers is None else list(layers)
    heads = list(range(model.config.num_attention_heads)) if heads is None else list(heads)

    total = torch.zeros(len(tokens), dtype=torch.float64)
    for layer in layers:
        # Index as [batch 0, chosen heads, query 0, all keys], then sum over heads.
        total += attention[layer][0, heads, 0, :].double().sum(dim=0)
    matrix = (total / (len(layers) * len(heads))).tolist()
    return (tokens, matrix)

In [None]:
def clean_array(w, a, cap=None):
    """
    Merge BERT word pieces back into words, then re-attach apostrophes.

    A piece starting with '##' continues the previous word: its text is
    appended (without '##'), and the word's score becomes the mean of all its
    pieces' scores, each piece weighted equally. A '##' piece with no
    preceding word starts a new word. The result is passed through
    ``clean_apos``.

    :param w: list of word-piece strings
    :param a: list of scores, one per piece
    :param cap: upper bound used by ``clean_apos`` (default: ``INTENSITY``)
    :return: (words, scores)
    :raises ValueError: if ``w`` and ``a`` differ in length
    """
    if len(w) != len(a):
        raise ValueError('w and a must have the same length')
    W = []
    A = []
    n_pieces = []
    for piece, score in zip(w, a):
        if piece.startswith('##') and W:
            W[-1] += piece[2:]
            n_pieces[-1] += 1
            # Incremental mean over all pieces of the word.
            A[-1] += (score - A[-1]) / n_pieces[-1]
        else:
            W.append(piece[2:] if piece.startswith('##') else piece)
            A.append(score)
            n_pieces.append(1)
    return clean_apos(W, A, cap=cap)

def clean_apos(w, a, cap=None):
    """
    Re-attach apostrophe tokens produced by the BERT tokenizer.

    An apostrophe between two tokens is joined with both ("don", "'", "t" ->
    "don't"). The merged score is the sum of the parts, capped at ``cap``.
    - A trailing apostrophe (no next token) is appended to the previous word.
    - A leading apostrophe (no previous word) is prefixed to the next token.
    - A lone apostrophe is kept as is.

    NOTE(review): the cap defaults to INTENSITY (a colour intensity) although
    the scores here are attention averages -- confirm this is intended.

    :param w: list of words
    :param a: list of scores, one per word
    :param cap: upper bound on a merged score (default: ``INTENSITY``)
    :return: (words, scores)
    """
    if cap is None:
        cap = INTENSITY
    W = []
    A = []
    ctr = 0
    n = len(w)
    while ctr < n:
        is_apos = w[ctr] == '\''
        has_prev = bool(W)
        has_next = ctr + 1 < n
        if is_apos and has_prev and has_next:
            W[-1] += w[ctr] + w[ctr + 1]
            A[-1] = min(cap, A[-1] + a[ctr] + a[ctr + 1])
            ctr += 2
        elif is_apos and has_prev:
            # Trailing apostrophe, e.g. "students'" at the end of the text.
            W[-1] += w[ctr]
            A[-1] = min(cap, A[-1] + a[ctr])
            ctr += 1
        elif is_apos and has_next:
            # Leading apostrophe, e.g. "'tis" at the start of the text.
            W.append(w[ctr] + w[ctr + 1])
            A.append(min(cap, a[ctr] + a[ctr + 1]))
            ctr += 2
        else:
            W.append(w[ctr])
            A.append(a[ctr])
            ctr += 1
    return W, A

In [None]:
def spread_peak_indices(attentions, k=3, min_gap=3):
    """Indices of up to ``k`` highest-scoring positions, pairwise >= ``min_gap`` apart.

    Candidates are visited from highest to lowest score; ties go to the
    higher index first. A candidate is kept only if it is at least
    ``min_gap`` positions away from every index already kept. Fewer than
    ``k`` indices are returned when the sequence cannot supply them.
    """
    ranked = sorted(((score, idx) for idx, score in enumerate(attentions)), reverse=True)
    chosen = []
    for _score, idx in ranked:
        if len(chosen) == k:
            break
        if all(abs(idx - prev) >= min_gap for prev in chosen):
            chosen.append(idx)
    return chosen

def top_three_tokens(text, k=3, min_gap=3, radius=1):
    """
    Highlight the most-attended, well-separated words of ``text``.

    Word pieces are merged back into words (``clean_array``). Up to ``k``
    peak words are chosen at least ``min_gap`` positions apart, and each peak
    plus ``radius`` neighbours on either side is highlighted. Short texts
    simply yield fewer peaks.

    :return: (words, scores) -- scores[i] is INTENSITY for highlighted words, else 0
    """
    words, attentions = attention_scores(text)
    words = words[1:-1]  # Remove start and end tags
    attentions = attentions[1:-1]
    assert len(words) == len(attentions)
    words, attentions = clean_array(words, attentions)
    assert len(words) == len(attentions)

    scores = [0] * len(words)
    for centre in spread_peak_indices(attentions, k=k, min_gap=min_gap):
        for j in range(centre - radius, centre + radius + 1):
            if 0 <= j < len(words):
                scores[j] = INTENSITY
    return words, scores

In [None]:
viz = Visualizer()

def create_latex_file( do_polar=True, latex_file='sample.tex'):
    """
    Write a LaTeX attention-heatmap document for the top ad hominem Politics comments.

    :param do_polar: if True, use the polar ('For-against') comments
                     (``top_ah_p``); otherwise the non-polar ('Perspective')
                     ones (``top_ah_np``)
    :param latex_file: output path

    Comments for which ``top_three_tokens`` raises are skipped, and the
    number skipped is printed.
    """
    top_ah_comments = top_ah_p if do_polar else top_ah_np
    words_list = list()
    scores_list = list()
    n_skipped = 0

    for comment in top_ah_comments:
        try:
            words, scores = top_three_tokens(comment)
        except Exception:  # best-effort: one bad comment must not abort the document
            n_skipped += 1
            continue
        words_list.append(words)
        scores_list.append(scores)

    if n_skipped:
        print(f'Skipped {n_skipped} of {len(top_ah_comments)} comments that could not be highlighted')

    label = 'For-against' if do_polar else 'Perspective'
    labels_list = [label] * len(words_list)

    viz.visualize(words_list, scores_list, labels_list,
                  latex_file=latex_file,
                  title=f'Top ad hominem comments from {label} debates in Politics forum',
                  # One batch with every comment; at least 1 so an empty list is safe.
                  batch_size=max(len(words_list), 1),
                  color='cyan')

In [None]:
# Write the heatmap document (sample.tex) for the non-polar ('Perspective')
# Politics comments.
create_latex_file(do_polar=False)