In [4]:
pip install transformers==3.0.2 sentence_transformers==0.3.3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import re
import os
import json
import pickle
import numpy as np
import pandas as pd
import random
import torch 
from torch import nn
import seaborn as sns
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
cd /content/drive/MyDrive/Assignments/capstone/phrase-bert-topic-model-master/phrase-topic-model/

/content/drive/MyDrive/Assignments/capstone/phrase-bert-topic-model-master/phrase-topic-model


In [6]:
from model.dae_model import DictionaryAutoencoder
from model_utils import run_epoch, text_to_topic, rank_topics_by_percentage

In [7]:
cd /content/drive/MyDrive/Assignments/capstone/

/content/drive/MyDrive/Assignments/capstone


In [8]:
ls

 2017_headline_results.csv         pntm_headlines.ipynb
 FINAL_semeval2010t8_train.csv     pntm_semeval.ipynb
 headline_topic_df_100.csv         pntm_semeval_labels.ipynb
 phrase_bert_similarity.ipynb     [0m[01;34m'pooled_context_para_triples_p=0.8'[0m/
 [01;34mphrase-bert-topic-model-master[0m/   semeval2010t8_test.csv
 [01;34mpntm[0m/                             semeval2010t8_train.csv


In [9]:
semeval_train = pd.read_csv('semeval2010t8_train.csv')
semeval_test = pd.read_csv('semeval2010t8_train.csv')
semeval_label = pd.concat([semeval_train, semeval_test])

In [10]:
semeval_label.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
0,semeval2010t8,train.json,0,0,semeval2010t8_train.json_0_0,The system as described above has its greatest...,The system as described above has its greatest...,0,0,,1
1,semeval2010t8,train.json,1,0,semeval2010t8_train.json_1_0,The child was carefully wrapped and bound into...,The <ARG1>child</ARG1> was carefully wrapped a...,0,0,,1
2,semeval2010t8,train.json,2,0,semeval2010t8_train.json_2_0,The author of a keygen uses a disassembler to ...,The <ARG1>author</ARG1> of a keygen uses a <AR...,0,0,,1
3,semeval2010t8,train.json,3,0,semeval2010t8_train.json_3_0,A misty ridge uprises from the surge .,A misty <ARG1>ridge</ARG1> uprises from the <A...,0,0,,1
4,semeval2010t8,train.json,4,0,semeval2010t8_train.json_4_0,The student association is the voice of the un...,The <ARG0>student</ARG0> <ARG1>association</AR...,0,0,,1


In [11]:
# filter the dataframe by pair_label for only causal texts
semeval_label_causal = semeval_label[semeval_label['pair_label'] == 1] 
len(semeval_label_causal), len(semeval_label)

(2006, 16000)

In [12]:
semeval_label_causal.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
6,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is that the chronic <ARG1>inf...,1,1,,1
13,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,The <ARG1>burst</ARG1> has been caused by wate...,1,1,,1
22,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","The <ARG0>singer</ARG0> , who performed three ...",1,1,,1
26,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0>Suicide</ARG0> is one of the leading cau...,1,1,,1
31,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,He had chest pains and <ARG1>headaches</ARG1> ...,1,1,,1


In [13]:
# textwpairs column
semeval_label_textwpairs = semeval_label_causal['text_w_pairs']

In [14]:
semeval_label_textwpairs.head(5)

6     The current view is that the chronic <ARG1>inf...
13    The <ARG1>burst</ARG1> has been caused by wate...
22    The <ARG0>singer</ARG0> , who performed three ...
26    <ARG0>Suicide</ARG0> is one of the leading cau...
31    He had chest pains and <ARG1>headaches</ARG1> ...
Name: text_w_pairs, dtype: object

In [15]:
def extract_args(dataset):
    arg0s = []
    arg1s = []
    for textwpair in dataset:
        arg0 = re.findall(r"<ARG0>(.*?)</ARG0>", textwpair) # list of all argument0s in string textwpair
        arg1 = re.findall(r"<ARG1>(.*?)</ARG1>", textwpair) # list of all argument1s in string textwpair
        if len(arg0) != 0:
            # unpack the list of argument0s and append them one by one
            for arg in arg0:
                arg0s.append(arg)
        if len(arg1) != 0:
            # unpack the list of argument1s and append them one by one
            for arg in arg1:
                arg1s.append(arg)
    return arg0s, arg1s

In [16]:
# list of arg0s and arg1s for SemEval labels
semeval_label_arg0s, semeval_label_arg1s = extract_args(semeval_label_textwpairs)
# list of all args for SemEval labels
semeval_label_args = semeval_label_arg0s + semeval_label_arg1s
len(semeval_label_args)

4012

In [17]:
# construct text_list
semeval_label_text = semeval_label_causal['text']
semeval_label_text_list = semeval_label_text.tolist()

In [18]:
# construct dictionaries of word2id
semeval_label_word2id = {val : idx for idx, val in enumerate(set(semeval_label_args))}
len(semeval_label_word2id.keys()), len(semeval_label_args) # duplicates

(1131, 4012)

In [19]:
# construct dictionaries of id2word
semeval_label_id2word = {val: key for key, val in semeval_label_word2id.items()}

In [20]:
# contruct dictionaries of id2freq
semeval_label_id2freq = semeval_label_id2word.copy()
semeval_label_freq = [(arg, semeval_label_args.count(arg)) for arg in set(semeval_label_args)]
i = 0
for key, val in semeval_label_id2freq.items():
    semeval_label_id2freq[key] = semeval_label_freq[i][1]
    i += 1

In [21]:
# load the Phrase-BERT model through the sentence-BERT interface
model_path = "/content/drive/MyDrive/Assignments/capstone/pooled_context_para_triples_p=0.8/"
model = SentenceTransformer(model_path)

In [22]:
"""
# commented out because results are already saved and can be loaded from files
# compute phrase embeddings using Phrase-BERT
semeval_label_embs = model.encode(set(semeval_label_args), batch_size=8, show_progress_bar=True)
semeval_label_embs = np.asarray(semeval_label_embs)
"""

'\n# commented out because results are already saved and can be loaded from files\n# compute phrase embeddings using Phrase-BERT\nsemeval_label_embs = model.encode(set(semeval_label_args), batch_size=8, show_progress_bar=True)\nsemeval_label_embs = np.asarray(semeval_label_embs)\n'

In [23]:
"""
# save the results
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"
np.save(os.path.join(topic_model_data_path, 'semeval_label_embs_matrix_np'), semeval_label_embs)
"""

'\n# save the results\ntopic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"\nnp.save(os.path.join(topic_model_data_path, \'semeval_label_embs_matrix_np\'), semeval_label_embs)\n'

In [24]:
# set seed
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f91bfd62f50>

In [25]:
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"

In [26]:
semeval_label_embs_matrix_np = np.load(os.path.join(topic_model_data_path, f"semeval_label_embs_matrix_np.npy"))
print(f"Loaded semeval labels word embedding from {topic_model_data_path}")
print(f"Loaded vocab size of {len(semeval_label_word2id)} (including phrases)")

Loaded semeval labels word embedding from /content/drive/MyDrive/Assignments/capstone/pntm/
Loaded vocab size of 1131 (including phrases)


In [27]:
len(semeval_label_embs_matrix_np)

1131

Below code is adapted from Phrase-Bert: https://github.com/sf-wa-326/phrase-bert-topic-model

In [28]:
# word frequency and filter info

# compute the length (in n-grams)
# setting word_threshould really high to include every phrase, reset to lower value to remove longer phrases
word_threshold = 100
semeval_label_len_words = [0] * len(semeval_label_id2word)
for (id, word) in semeval_label_id2word.items():
    semeval_label_len_words[id] = len(word.split(' '))
semeval_label_indices_to_remove_based_on_len = [id 
                                                for id, word_len 
                                                in enumerate(semeval_label_len_words) 
                                                if (word_len > word_threshold)]

print(len(semeval_label_indices_to_remove_based_on_len)) # 0

0


In [29]:
# keeping every token / word, reset freq_threshold to remove lower frequency words
freq_threshold = 0

semeval_label_sorted_ids = [k for k, v in sorted(semeval_label_id2freq.items(), key=lambda item: item[1])]
semeval_label_sorted_ids.reverse()
semeval_label_indices_to_remove_based_on_freq = [k for k, v in semeval_label_id2freq.items() if v <= freq_threshold ]
semeval_label_to_be_removed = list(set(semeval_label_indices_to_remove_based_on_freq 
                                       + semeval_label_indices_to_remove_based_on_len))

In [30]:
# encode the text_lists
semeval_label_text_rep_list = model.encode(semeval_label_text, batch_size = 8, show_progress_bar = True)

Batches:   0%|          | 0/251 [00:00<?, ?it/s]

In [31]:
emb_model = "phrase-bert"
print(f"Building sentence model by using {emb_model} as embedding model")

semeval_label_uid_input_vector_list = [(i, semeval_label_text_rep_list[i]) for i in range(len(semeval_label_text_rep_list))]
print(f"Computed {len(semeval_label_uid_input_vector_list)} positive examples")

Building sentence model by using phrase-bert as embedding model
Computed 2006 positive examples


In [32]:
# setting the argument num_negative_samples for negative sampling
num_neg_samples = 5 # default in the original model

semeval_label_uid_input_vector_list_neg = []
indices = list(range(len(semeval_label_uid_input_vector_list)))
for idx in range(len(semeval_label_uid_input_vector_list)):
    indices_candidate = indices
    neg_indices = random.sample(indices_candidate, num_neg_samples)
    neg_samples = [semeval_label_uid_input_vector_list[neg_i][1] for neg_i in neg_indices]
    neg_vector = np.mean(neg_samples, axis=0)
    semeval_label_uid_input_vector_list_neg.append(neg_vector)
print(f"Computed {len(semeval_label_uid_input_vector_list_neg)} negative examples")

Computed 2006 negative examples


In [33]:
# set up hyperparameters
semeval_label_net_params = {}
semeval_label_net_params["mode"] = "bert"
semeval_label_net_params["embedding"] = semeval_label_embs_matrix_np
semeval_label_net_params["d_hid"] = 100
semeval_label_net_params["num_rows"] = 50  # number of topics
semeval_label_net_params["num_sub_topics"] = 0
semeval_label_net_params["word_dropout_prob"] = 0.2
semeval_label_net_params["vrev"] = semeval_label_id2word  # idx to word map
semeval_label_net_params["device"] = 'cuda'
semeval_label_net_params["pred_world"] = False

In [34]:
semeval_label_net = DictionaryAutoencoder(net_params=semeval_label_net_params)
semeval_label_net.to('cuda')

DictionaryAutoencoder(
  (embeddings): Embedding(1131, 768)
  (W_proj): Linear(in_features=768, out_features=100, bias=True)
  (act): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (W_att): Linear(in_features=100, out_features=768, bias=True)
  (W_out): Linear(in_features=768, out_features=1131, bias=True)
)

In [35]:
# training specs (default from original code)
num_epochs = 300
batch_size = 100
ortho_weight = 1e-5
world_clas_weight = 0.0
semeval_label_optim = torch.optim.Adam(semeval_label_net.parameters(), lr=1e-4)
interpret_interval = int(np.ceil(num_epochs / 10))
h_model = 2

In [36]:
# iterating through batches
semeval_label_batch_intervals = [
    (start, start + batch_size)
    for start in range(0, len(semeval_label_uid_input_vector_list), batch_size)]
    # batch_intervals = batch_intervals[:100]
semeval_label_split = int(np.ceil(len(semeval_label_batch_intervals) * 0.9))
semeval_label_batch_intervals_train = semeval_label_batch_intervals[:semeval_label_split]
semeval_label_batch_intervals_valid = semeval_label_batch_intervals[semeval_label_split:]

In [37]:
import argparse
parser = argparse.ArgumentParser()
"""parser.add_argument("--lr", type=float, default=1e-4)
"""
args = parser.parse_args(args=[])
args.device = 'cuda:' + '0'
args.triplet_loss_margin = 1.0
args.triplet_loss_weight = 1.0
args.ortho_weight = 1e-5
args.neighbour_loss_weight = 1e-7
args.offset_loss_weight = 1e-4

In [38]:
# semeval label training
print("\n" + "=" * 70)
for epoch in range(num_epochs):
    # training
    semeval_label_net.train()
    train_mode = True
    print(f"Epoch {epoch}")
    run_epoch(semeval_label_net, semeval_label_optim, semeval_label_batch_intervals_train,
              semeval_label_uid_input_vector_list, semeval_label_uid_input_vector_list_neg,
              args, train_mode, h_model, epoch, 200)

    # validation
    semeval_label_net.eval()
    train_mode = False
    with torch.no_grad():
        run_epoch(
                semeval_label_net,
                semeval_label_optim,
                semeval_label_batch_intervals_valid,
                semeval_label_uid_input_vector_list,
                semeval_label_uid_input_vector_list_neg,
                args, 
                train_mode,
                h_model,
                epoch,
                200
        )

    if (epoch + 1) % interpret_interval == 0:
        print("Topics with probability argmax")
        topics_print_list = semeval_label_net.rank_vocab_for_topics(
            word_embedding_matrix=semeval_label_embs_matrix_np,
            to_be_removed=semeval_label_to_be_removed)
        print("=" * 70)

    print()
    print()
    print()
    print("=" * 70)


Epoch 0
[TRAIN] loss: 0.9694, 0.9694, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.30 s
[VALID] loss: 0.9551, 0.9551, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 1
[TRAIN] loss: 0.9127, 0.9127, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.12 s
[VALID] loss: 0.9060, 0.9060, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.01 s



Epoch 2
[TRAIN] loss: 0.8702, 0.8702, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.13 s
[VALID] loss: 0.8664, 0.8664, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 3
[TRAIN] loss: 0.8363, 0.8363, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.13 s
[VALID] loss: 0.8363, 0.8363, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 4
[TRAIN] loss: 0.8085, 0.8085, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.12 s
[VALID] loss: 0.8113, 0.8113, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 5
[TRAIN] loss: 0.

In [39]:
# semeval label print topic list
print("Finally after training")
semeval_label_net.eval()

print("Topics with probability argmax")
prob_over_vocab_np, topics_print_list = semeval_label_net.rank_vocab_for_topics(
            word_embedding_matrix=semeval_label_embs_matrix_np, to_be_removed=semeval_label_to_be_removed
    )
print("=" * 70)

Finally after training
Topics with probability argmax
50
50
[372 797 845 941 688 582 576 846 970 432]
topic 0 : temperature, Production, Addiction, influenza, shoulder problems, deluge, scandals, pie, consumption, Sadness
[481 419 664  54  40 979 300 638 887 903]
topic 1 : competition, increased pressure, surplus, flood, bite, pathogens, guy, harassment, Poverty, relaxation
[ 350 1061  497  948 1094  316  259    8  194 1032]
topic 2 : Inhibition, Sodium, malfunction, celebs, collapse, nerves, neglect, abuse, gash, discharge
[ 694 1128 1056  400  652  675    1  629  727 1093]
topic 3 : resignation, hormonal changes, legs, substance, Trauma, books, rising unemployment rate, anticipation, symptoms, profit
[ 251  438   86  769  293 1019  422 1086  912  770]
topic 4 : rubbing, toll, demolition, bike-accident, conditions, cancellation, condition, separation field, Bed sores, dams
[ 264  378 1024  634  610  913  335  398  337  471]
topic 5 : firing, Boils, particles, extinction, Osteoporosis,

In [40]:
# after training we evaluate all the topics percentage in the dataset and rank the topics by percentage
uid_list, vector_list = zip(*semeval_label_uid_input_vector_list)
topic_pred_list = text_to_topic(vector_list, semeval_label_net, 'cuda')

topic_id_ranked, topic_percentage_ranked = rank_topics_by_percentage( topic_pred_list )

for rank, (topic_id, topic_percentage) in enumerate( zip(topic_id_ranked, topic_percentage_ranked)):
    print(
            f"Rank: {rank}, Topic_id: {topic_id}, Topic Words: {topics_print_list[topic_id]}, \
            Topic Percentage: {topic_percentage}"
        )


100%|██████████| 6/6 [00:00<00:00, 488.83it/s]

Rank: 0, Topic_id: 48, Topic Words: topic 48 : power outage, stress, vocals, fungi, screen, transmitter, hitting, lack, alarm, consumption,             Topic Percentage: 9.67
Rank: 1, Topic_id: 31, Topic Words: topic 31 : inspiration, competition, properties, Preeclampsia, science, militancy, child abuse, bite, Properties, pyrotechnics,             Topic Percentage: 7.58
Rank: 2, Topic_id: 26, Topic Words: topic 26 : scandals, cancer, Birth defects, pie, havoc, emergency, discharge, abuse, gash, clot,             Topic Percentage: 7.38
Rank: 3, Topic_id: 22, Topic Words: topic 22 : electricity, drill, convergence, shingles, arterial blood pressure, progress, disorders, reboot, passage, pinkeye,             Topic Percentage: 7.08
Rank: 4, Topic_id: 15, Topic Words: topic 15 : slowdown, burn, fungi, substance, lack, alarm, flooding, women, press, dumping,             Topic Percentage: 6.28
Rank: 5, Topic_id: 38, Topic Words: topic 38 : Dehydration, gases, winds, process, toll, burn, roug




In [41]:
prob_over_vocab_df = pd.DataFrame(prob_over_vocab_np) # shape: 50 x 1131
prob_over_vocab_df.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1121,1122,1123,1124,1125,1126,1127,1128,1129,1130
0,0.000874,0.001002,0.001406,0.001139,0.000745,0.000712,0.000837,0.000529,0.001343,0.000847,...,0.000684,0.000562,0.000787,0.000835,0.000704,0.001145,0.000804,0.001376,0.000634,0.00093
1,0.000965,0.000786,0.000933,0.00101,0.000944,0.001202,0.000856,0.000855,0.000639,0.00077,...,0.001037,0.000811,0.000794,0.000997,0.000994,0.000813,0.000796,0.000787,0.000685,0.000951
2,0.00093,0.000553,0.000979,0.000761,0.00102,0.00092,0.000899,0.000889,0.001199,0.000974,...,0.000785,0.00079,0.000975,0.001023,0.000793,0.000977,0.000774,0.000817,0.000863,0.000932
3,0.000906,0.001484,0.000851,0.001184,0.0007,0.000715,0.001018,0.000902,0.00092,0.000879,...,0.00088,0.000874,0.000552,0.000576,0.000884,0.000881,0.000667,0.001408,0.000664,0.001035
4,0.000783,0.00098,0.000763,0.000896,0.001073,0.00087,0.000911,0.000952,0.000758,0.000918,...,0.000552,0.000962,0.000895,0.001043,0.000766,0.000925,0.000817,0.000656,0.000901,0.001124


In [42]:
prob_over_vocab_df.idxmax() # for each column, find the row number of the max

0       43
1        3
2       22
3        9
4       38
        ..
1126    22
1127    39
1128     9
1129    42
1130    22
Length: 1131, dtype: int64

In [44]:
df = pd.DataFrame(prob_over_vocab_df.idxmax())
df = df.reset_index()
df

Unnamed: 0,index,0
0,0,43
1,1,3
2,2,22
3,3,9
4,4,38
...,...,...
1126,1126,22
1127,1127,39
1128,1128,9
1129,1129,42


In [45]:
# construct a dictionary with arg ids as keys and topic ids as values
argid2topic_dict = df[0].to_dict()

In [46]:
# construct a dictionary with topic ids as keys and arg ids as values
# where only one topic is assigned to each argument
topic2argid_dict = df.groupby(0)['index'].apply(list).to_dict()

In [47]:
topic2argid_dict

{0: [12,
  48,
  95,
  131,
  138,
  155,
  178,
  198,
  219,
  246,
  251,
  296,
  305,
  311,
  346,
  372,
  373,
  378,
  399,
  432,
  434,
  452,
  502,
  529,
  530,
  576,
  582,
  601,
  614,
  624,
  640,
  654,
  655,
  688,
  694,
  699,
  740,
  749,
  778,
  797,
  803,
  834,
  841,
  845,
  846,
  857,
  863,
  882,
  890,
  893,
  910,
  920,
  926,
  941,
  954,
  968,
  969,
  970,
  971,
  975,
  1051,
  1064,
  1095,
  1118],
 1: [5, 132, 392, 419, 520, 638, 664, 806, 887, 903, 979],
 2: [316, 391, 508, 1094],
 3: [1,
  125,
  142,
  159,
  215,
  231,
  237,
  278,
  307,
  416,
  493,
  583,
  629,
  652,
  657,
  675,
  702,
  727,
  741,
  755,
  801,
  884,
  997,
  1012,
  1024,
  1056,
  1093],
 4: [27, 86, 609, 722, 767, 799, 809, 1052, 1086, 1117],
 5: [20,
  21,
  102,
  105,
  169,
  182,
  225,
  226,
  264,
  335,
  337,
  354,
  387,
  398,
  471,
  585,
  610,
  634,
  660,
  708,
  721,
  750,
  830,
  913,
  1043],
 6: [78,
  91,
  97,
  111,
  1

In [48]:
argid2topic_dict

{0: 43,
 1: 3,
 2: 22,
 3: 9,
 4: 38,
 5: 1,
 6: 39,
 7: 39,
 8: 26,
 9: 19,
 10: 33,
 11: 43,
 12: 0,
 13: 9,
 14: 39,
 15: 15,
 16: 39,
 17: 30,
 18: 36,
 19: 17,
 20: 5,
 21: 5,
 22: 36,
 23: 22,
 24: 34,
 25: 34,
 26: 35,
 27: 4,
 28: 43,
 29: 36,
 30: 11,
 31: 26,
 32: 13,
 33: 19,
 34: 18,
 35: 10,
 36: 43,
 37: 8,
 38: 45,
 39: 38,
 40: 31,
 41: 34,
 42: 14,
 43: 33,
 44: 31,
 45: 33,
 46: 26,
 47: 44,
 48: 0,
 49: 31,
 50: 10,
 51: 44,
 52: 16,
 53: 40,
 54: 31,
 55: 38,
 56: 22,
 57: 46,
 58: 18,
 59: 48,
 60: 10,
 61: 46,
 62: 30,
 63: 38,
 64: 38,
 65: 22,
 66: 15,
 67: 15,
 68: 45,
 69: 35,
 70: 16,
 71: 30,
 72: 30,
 73: 21,
 74: 10,
 75: 32,
 76: 33,
 77: 34,
 78: 6,
 79: 16,
 80: 39,
 81: 11,
 82: 20,
 83: 21,
 84: 34,
 85: 46,
 86: 4,
 87: 11,
 88: 49,
 89: 38,
 90: 10,
 91: 6,
 92: 46,
 93: 16,
 94: 26,
 95: 0,
 96: 10,
 97: 6,
 98: 36,
 99: 33,
 100: 45,
 101: 43,
 102: 5,
 103: 15,
 104: 16,
 105: 5,
 106: 31,
 107: 14,
 108: 18,
 109: 16,
 110: 48,
 111: 6,
 112: 43

Construct a new dataframe with columns: sentence, arg0, arg1, arg0id, arg1id, arg0topicid, arg1topicid

In [49]:
semeval_label_topic_df = semeval_label_causal.copy()
semeval_label_topic_df.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
6,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is that the chronic <ARG1>inf...,1,1,,1
13,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,The <ARG1>burst</ARG1> has been caused by wate...,1,1,,1
22,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","The <ARG0>singer</ARG0> , who performed three ...",1,1,,1
26,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0>Suicide</ARG0> is one of the leading cau...,1,1,,1
31,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,He had chest pains and <ARG1>headaches</ARG1> ...,1,1,,1


In [50]:
# modify the original extract_args function to differentiate which text arguments belong to
def extract_args2(dataset):
    arg0s = []
    arg1s = []
    for textwpair in dataset:
        arg0 = re.findall(r"<ARG0>(.*?)</ARG0>", textwpair) # list of all argument0s in string textwpair
        arg1 = re.findall(r"<ARG1>(.*?)</ARG1>", textwpair) # list of all argument1s in string textwpair
        arg0s.append(arg0)
        arg1s.append(arg1)
    return arg0s, arg1s

In [51]:
# create new columns for lists of arg0s and arg1s for each text
semeval_label_topic_df['arg0'], semeval_label_topic_df['arg1'] = extract_args2(semeval_label_textwpairs)

In [52]:
# create arg0_id and arg1_id columns
semeval_label_arg0id_list = []
for arg0s in semeval_label_topic_df['arg0']:
    temp = []
    for arg0 in arg0s:
        temp.append(semeval_label_word2id[arg0])
    semeval_label_arg0id_list.append(temp)
semeval_label_topic_df['arg0_id'] = semeval_label_arg0id_list

semeval_label_arg1id_list = []
for arg1s in semeval_label_topic_df['arg1']:
    temp = []
    for arg1 in arg1s:
        temp.append(semeval_label_word2id[arg1])
    semeval_label_arg1id_list.append(temp)
semeval_label_topic_df['arg1_id'] = semeval_label_arg1id_list

In [53]:
# create columns for arg0_topicid and arg1_topicid
semeval_label_arg0topic_list = []
for arg0s in semeval_label_arg0id_list:
    temp = []
    for arg0 in arg0s:
        temp.append(argid2topic_dict[arg0])
    semeval_label_arg0topic_list.append(temp)
semeval_label_topic_df['arg0_topicid'] = semeval_label_arg0topic_list

semeval_label_arg1topic_list = []
for arg1s in semeval_label_arg1id_list:
    temp = []
    for arg1 in arg1s:
        temp.append(argid2topic_dict[arg1])
    semeval_label_arg1topic_list.append(temp)
semeval_label_topic_df['arg1_topicid'] = semeval_label_arg1topic_list

In [54]:
semeval_label_topic_df.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents,arg0,arg1,arg0_id,arg1_id,arg0_topicid,arg1_topicid
6,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is that the chronic <ARG1>inf...,1,1,,1,[infection],[inflammation],[351],[406],[38],[48]
13,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,The <ARG1>burst</ARG1> has been caused by wate...,1,1,,1,[pressure],[burst],[269],[461],[45],[10]
22,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","The <ARG0>singer</ARG0> , who performed three ...",1,1,,1,[singer],[commotion],[1099],[163],[46],[7]
26,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0>Suicide</ARG0> is one of the leading cau...,1,1,,1,[Suicide],[death],[348],[630],[36],[48]
31,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,He had chest pains and <ARG1>headaches</ARG1> ...,1,1,,1,[mold],[headaches],[1049],[472],[16],[21]


In [None]:
"""
# save the dataframe as a csv file
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"
np.save(os.path.join(topic_model_data_path, 'semeval_label_topic_df'), semeval_label_topic_df)
"""

In [None]:
"""
with open( os.path.join(topic_model_data_path, 'semeval_label_topic_model_50.pt'), "wb") as f:
    torch.save(semeval_label_net, f)
    print(f"Saved model at { os.path.join(topic_model_data_path) }")
"""

Saved model at /content/drive/MyDrive/Assignments/capstone/pntm/


construct a table to keep track of how many times each topic of cause has caused each topic of effect

In [97]:
semeval_label_topic_count_df = semeval_label_topic_df.copy()
semeval_label_topic_count_df.tail(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents,arg0,arg1,arg0_id,arg1_id,arg0_topicid,arg1_topicid
7983,semeval2010t8,train.json,7958,0,semeval2010t8_train.json_7958_0,Hand creams counteract dryness from exposure t...,Hand creams counteract <ARG1>dryness</ARG1> fr...,1,1,,1,[exposure],[dryness],[936],[454],[32],[9]
7984,semeval2010t8,train.json,7959,0,semeval2010t8_train.json_7959_0,Eye discomfort from this staring effect is exa...,Eye <ARG1>discomfort</ARG1> from this <ARG0>st...,1,1,,1,[staring effect],[discomfort],[807],[848],[39],[15]
7986,semeval2010t8,train.json,7961,0,semeval2010t8_train.json_7961_0,The transmitter emits a constant radio signal ...,The <ARG0>transmitter</ARG0> emits a constant ...,1,1,,1,[transmitter],[signal],[988],[219],[48],[0]
7987,semeval2010t8,train.json,7962,0,semeval2010t8_train.json_7962_0,Parents also experience anxiety from fear of t...,Parents also experience <ARG1>anxiety</ARG1> f...,1,1,,1,[fear],[anxiety],[399],[842],[0],[14]
7992,semeval2010t8,train.json,7967,0,semeval2010t8_train.json_7967_0,In chemical lasers the inversion is produced b...,In chemical lasers the <ARG1>inversion</ARG1> ...,1,1,,1,[reaction],[inversion],[4],[446],[38],[6]


In [98]:
semeval_label_topic_count_df = semeval_label_topic_count_df[['arg0_topicid', 'arg1_topicid']]
semeval_label_topic_count_df

Unnamed: 0,arg0_topicid,arg1_topicid
6,[38],[48]
13,[45],[10]
22,[46],[7]
26,[36],[48]
31,[16],[21]
...,...,...
7983,[32],[9]
7984,[39],[15]
7986,[48],[0]
7987,[0],[14]


In [99]:
semeval_label_topic_count_df = semeval_label_topic_count_df.explode('arg0_topicid')
semeval_label_topic_count_df = semeval_label_topic_count_df.explode('arg1_topicid')
semeval_label_topic_count_df['index'] = semeval_label_topic_count_df.index
semeval_label_topic_count_df

Unnamed: 0,arg0_topicid,arg1_topicid,index
6,38,48,6
13,45,10,13
22,46,7,22
26,36,48,26
31,16,21,31
...,...,...,...
7983,32,9,7983
7984,39,15,7984
7986,48,0,7986
7987,0,14,7987


In [100]:
semeval_label_topic_count_df = semeval_label_topic_count_df.groupby(['arg0_topicid', 'arg1_topicid'])['index'].count().unstack(fill_value=0)
semeval_label_topic_count_df

arg1_topicid,0,1,2,3,4,5,6,7,8,9,...,39,40,42,43,44,45,46,47,48,49
arg0_topicid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,4,0,0,2,2,0,12,0,0,4,...,2,0,2,4,2,10,4,4,2,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,2,2,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,2,2,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,4,0,0,2,...,2,2,0,0,0,0,0,0,4,0
4,0,0,0,0,0,0,2,0,0,0,...,0,0,0,0,0,0,0,0,0,0
5,0,0,4,2,0,2,4,0,0,2,...,2,0,0,0,0,2,0,0,8,0
6,0,2,2,0,0,4,8,0,0,6,...,8,0,2,2,0,2,0,2,4,0
7,2,0,0,0,0,0,0,0,0,2,...,0,0,0,0,0,0,0,0,0,0
8,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9,4,0,0,2,0,0,6,0,0,8,...,2,0,0,2,0,0,4,2,4,0
