In [1]:
pip install transformers==3.0.2 sentence_transformers==0.3.3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==3.0.2
  Downloading transformers-3.0.2-py3-none-any.whl (769 kB)
[K     |████████████████████████████████| 769 kB 14.2 MB/s 
[?25hCollecting sentence_transformers==0.3.3
  Downloading sentence-transformers-0.3.3.tar.gz (65 kB)
[K     |████████████████████████████████| 65 kB 5.2 MB/s 
Collecting tokenizers==0.8.1.rc1
  Downloading tokenizers-0.8.1rc1-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
[K     |████████████████████████████████| 3.0 MB 58.6 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 59.1 MB/s 
Collecting sentencepiece!=0.1.92
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 43.5 MB/s 
Building wheels for collected packages: sentence-transformers, sacremoses
  Building wheel 

In [2]:
import re
import os
import json
import pickle
import numpy as np
import pandas as pd
import random
import torch 
from torch import nn
import seaborn as sns
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
cd /content/drive/MyDrive/Assignments/capstone/phrase-bert-topic-model-master/phrase-topic-model/

/content/drive/MyDrive/Assignments/capstone/phrase-bert-topic-model-master/phrase-topic-model


In [5]:
from model.dae_model import DictionaryAutoencoder
from model_utils import run_epoch, text_to_topic, rank_topics_by_percentage

In [6]:
cd /content/drive/MyDrive/Assignments/capstone/

/content/drive/MyDrive/Assignments/capstone


In [7]:
ls

 2017_headline_results.csv         pntm_semeval.ipynb
 FINAL_semeval2010t8_train.csv     pntm_semeval_labels.ipynb
 phrase_bert_similarity.ipynb     [0m[01;34m'pooled_context_para_triples_p=0.8'[0m/
 [01;34mphrase-bert-topic-model-master[0m/   semeval2010t8_test.csv
 [01;34mpntm[0m/                             semeval2010t8_train.csv
 pntm_headlines.ipynb


In [8]:
semeval = pd.read_csv('FINAL_semeval2010t8_train.csv')

In [9]:
semeval.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
0,semeval2010t8,train.json,3,0,semeval2010t8_train.json_3_0,A misty ridge uprises from the surge .,<ARG1> A misty ridge uprises from the </ARG1> ...,1,0,,1
1,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is <ARG1> that the chronic in...,1,1,,1
2,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,<ARG1> The burst has been </ARG1> caused by <A...,1,1,,1
3,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","<ARG0> The singer , who performed three of the...",1,1,,1
4,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0> Suicide is one of the leading causes of...,1,1,,1


In [10]:
# filter the dataframe by pair_label for only causal texts
semeval_causal = semeval[semeval['pair_label'] == 1] 
len(semeval_causal), len(semeval)

(1106, 1276)

In [11]:
semeval_causal.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
1,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is <ARG1> that the chronic in...,1,1,,1
2,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,<ARG1> The burst has been </ARG1> caused by <A...,1,1,,1
3,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","<ARG0> The singer , who performed three of the...",1,1,,1
4,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0> Suicide is one of the leading causes of...,1,1,,1
5,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,<ARG1> He had chest pains and headaches from <...,1,1,,1


In [12]:
# textwpairs column
semeval_textwpairs = semeval_causal['text_w_pairs']

In [13]:
semeval_textwpairs.head(5)

1    The current view is <ARG1> that the chronic in...
2    <ARG1> The burst has been </ARG1> caused by <A...
3    <ARG0> The singer , who performed three of the...
4    <ARG0> Suicide is one of the leading causes of...
5    <ARG1> He had chest pains and headaches from <...
Name: text_w_pairs, dtype: object

In [14]:
def extract_args(dataset):
    arg0s = []
    arg1s = []
    for textwpair in dataset:
        arg0 = re.findall(r"<ARG0>(.*?)</ARG0>", textwpair) # list of all argument0s in string textwpair
        arg1 = re.findall(r"<ARG1>(.*?)</ARG1>", textwpair) # list of all argument1s in string textwpair
        if len(arg0) != 0:
            # unpack the list of argument0s and append them one by one
            for arg in arg0:
                arg0s.append(arg)
        if len(arg1) != 0:
            # unpack the list of argument1s and append them one by one
            for arg in arg1:
                arg1s.append(arg)
    return arg0s, arg1s

In [15]:
# list of arg0s and arg1s for SemEval
semeval_arg0s, semeval_arg1s = extract_args(semeval_textwpairs)
# list of all args for SemEval
semeval_args = semeval_arg0s + semeval_arg1s

In [16]:
# construct text_list
semeval_text = semeval_causal['text']
semeval_text_list = semeval_text.tolist()

In [17]:
# construct dictionaries of word2id
semeval_word2id = {val : idx for idx, val in enumerate(set(semeval_args))}
len(semeval_word2id.keys()), len(semeval_args) # 299 duplicates

(2704, 3003)

In [18]:
# construct dictionaries of id2word
semeval_id2word = {val: key for key, val in semeval_word2id.items()}

In [19]:
# contruct dictionaries of id2freq
semeval_id2freq = semeval_id2word.copy()
semeval_freq = [(arg, semeval_args.count(arg)) for arg in set(semeval_args)]
i = 0
for key, val in semeval_id2freq.items():
    semeval_id2freq[key] = semeval_freq[i][1]
    i += 1

In [20]:
# load the Phrase-BERT model through the sentence-BERT interface
model_path = "/content/drive/MyDrive/Assignments/capstone/pooled_context_para_triples_p=0.8/"
model = SentenceTransformer(model_path)

In [None]:
"""
# commented out because results are already saved and can be loaded from files
# compute phrase embeddings using Phrase-BERT
semeval_phrase_embs = model.encode(set(semeval_args), batch_size=8, show_progress_bar=True)
semeval_embs = np.asarray(semeval_phrase_embs)
"""

In [None]:
"""
# save the results
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"
np.save(os.path.join(topic_model_data_path, 'semeval_embs_matrix_np'), semeval_embs)
"""

In [21]:
# set seed
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fc6b48be3d0>

In [22]:
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"

In [23]:
semeval_embs_matrix_np = np.load(os.path.join(topic_model_data_path, f"semeval_embs_matrix_np.npy"))
print(f"Loaded semeval word embedding from {topic_model_data_path}")
print(f"Loaded vocab size of {len(semeval_word2id)} (including phrases)")

Loaded semeval word embedding from /content/drive/MyDrive/Assignments/capstone/pntm/
Loaded vocab size of 2704 (including phrases)


In [24]:
len(semeval_embs_matrix_np)

2704

Below code is adapted from Phrase-Bert: https://github.com/sf-wa-326/phrase-bert-topic-model

In [25]:
# word frequency and filter info

# compute the length (in n-grams)
# setting word_threshould really high to include every phrase, reset to lower value to remove longer phrases
word_threshold = 100

semeval_len_words = [0] * len(semeval_id2word)
for (id, word) in semeval_id2word.items():
    semeval_len_words[id] = len(word.split(' '))
# setting word_len to 100 because we
# and args contain empty spaces at the beginning and end of the strings
semeval_indices_to_remove_based_on_len = [id 
                                          for id, word_len 
                                          in enumerate(semeval_len_words) 
                                          if (word_len > word_threshold )]

print(len(semeval_indices_to_remove_based_on_len)) # 0

0


In [26]:
# keeping every token / word, reset freq_threshold to remove lower frequency words
freq_threshold = 0

semeval_sorted_ids = [k for k, v in sorted(semeval_id2freq.items(), key=lambda item: item[1])]
semeval_sorted_ids.reverse()
semeval_indices_to_remove_based_on_freq = [k for k, v in semeval_id2freq.items() if v <= freq_threshold ]
semeval_to_be_removed = list(set(semeval_indices_to_remove_based_on_freq + semeval_indices_to_remove_based_on_len))

In [27]:
# encode the text_lists
semeval_text_rep_list = model.encode(semeval_text, batch_size = 8, show_progress_bar = True)

Batches:   0%|          | 0/139 [00:00<?, ?it/s]

In [28]:
emb_model = "phrase-bert"
print(f"Building sentence model by using {emb_model} as embedding model")

semeval_uid_input_vector_list = [(i, semeval_text_rep_list[i]) for i in range(len(semeval_text_rep_list))]
print(f"Computed {len(semeval_uid_input_vector_list)} positive examples")

Building sentence model by using phrase-bert as embedding model
Computed 1106 positive examples


In [29]:
# setting the argument num_negative_samples for negative sampling
num_neg_samples = 5 # default in the original model
semeval_uid_input_vector_list_neg = []
indices = list(range(len(semeval_uid_input_vector_list)))
for idx in range(len(semeval_uid_input_vector_list)):
    indices_candidate = indices
    neg_indices = random.sample(indices_candidate, num_neg_samples)
    neg_samples = [semeval_uid_input_vector_list[neg_i][1] for neg_i in neg_indices]
    neg_vector = np.mean(neg_samples, axis=0)
    semeval_uid_input_vector_list_neg.append(neg_vector)
print(f"Computed {len(semeval_uid_input_vector_list_neg)} negative examples")

Computed 1106 negative examples


In [30]:
# set up hyperparameters
semeval_net_params = {}
semeval_net_params["mode"] = "bert"
semeval_net_params["embedding"] = semeval_embs_matrix_np
semeval_net_params["d_hid"] = 100
semeval_net_params["num_rows"] = 100  # number of topics
semeval_net_params["num_sub_topics"] = 0
semeval_net_params["word_dropout_prob"] = 0.2
semeval_net_params["vrev"] = semeval_id2word  # idx to word map
semeval_net_params["device"] = 'cuda'
semeval_net_params["pred_world"] = False

In [31]:
semeval_net = DictionaryAutoencoder(net_params=semeval_net_params)
semeval_net.to('cuda')

DictionaryAutoencoder(
  (embeddings): Embedding(2704, 768)
  (W_proj): Linear(in_features=768, out_features=100, bias=True)
  (act): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (W_att): Linear(in_features=100, out_features=768, bias=True)
  (W_out): Linear(in_features=768, out_features=2704, bias=True)
)

In [32]:
# training specs (default from original code)
num_epochs = 300
batch_size = 100
ortho_weight = 1e-5
world_clas_weight = 0.0
semeval_optim = torch.optim.Adam(semeval_net.parameters(), lr=1e-4)
interpret_interval = int(np.ceil(num_epochs / 10))
h_model = 2

In [33]:
# iterating through batches
semeval_batch_intervals = [
    (start, start + batch_size)
    for start in range(0, len(semeval_uid_input_vector_list), batch_size)]
    # batch_intervals = batch_intervals[:100]
semeval_split = int(np.ceil(len(semeval_batch_intervals) * 0.9))
semeval_batch_intervals_train = semeval_batch_intervals[:semeval_split]
semeval_batch_intervals_valid = semeval_batch_intervals[semeval_split:]

In [34]:
import argparse
parser = argparse.ArgumentParser()
"""parser.add_argument("--lr", type=float, default=1e-4)
"""
args = parser.parse_args(args=[])
args.device = 'cuda:' + '0'
args.triplet_loss_margin = 1.0
args.triplet_loss_weight = 1.0
args.ortho_weight = 1e-5
args.neighbour_loss_weight = 1e-7
args.offset_loss_weight = 1e-4

In [35]:
# semeval training
print("\n" + "=" * 70)
for epoch in range(num_epochs):
    # training
    semeval_net.train()
    train_mode = True
    print(f"Epoch {epoch}")
    run_epoch(semeval_net, semeval_optim, semeval_batch_intervals_train,
              semeval_uid_input_vector_list, semeval_uid_input_vector_list_neg,
              args, train_mode, h_model, epoch, 200)

    # validation
    semeval_net.eval()
    train_mode = False
    with torch.no_grad():
        run_epoch(
                semeval_net,
                semeval_optim,
                semeval_batch_intervals_valid,
                semeval_uid_input_vector_list,
                semeval_uid_input_vector_list_neg,
                args, 
                train_mode,
                h_model,
                epoch,
                200
        )

    if (epoch + 1) % interpret_interval == 0:
        print("Topics with probability argmax")
        topics_print_list = semeval_net.rank_vocab_for_topics(
            word_embedding_matrix=semeval_embs_matrix_np,
            to_be_removed=semeval_to_be_removed)
        print("=" * 70)

    print()
    print()
    print()
    print("=" * 70)


Epoch 0
[TRAIN] loss: 0.9679, 0.9679, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.11 s
[VALID] loss: 0.9480, 0.9480, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 1
[TRAIN] loss: 0.9217, 0.9217, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.07 s
[VALID] loss: 0.8997, 0.8997, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 2
[TRAIN] loss: 0.8831, 0.8831, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.07 s
[VALID] loss: 0.8580, 0.8580, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 3
[TRAIN] loss: 0.8532, 0.8532, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.06 s
[VALID] loss: 0.8291, 0.8291, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 4
[TRAIN] loss: 0.8306, 0.8306, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.06 s
[VALID] loss: 0.8026, 0.8026, 0.0000, 0.0000, 0.0000 (all, tri, or, off, nei), time: 0.00 s



Epoch 5
[TRAIN] loss: 0.

In [36]:
# semeval print topic list
print("Finally after training")
semeval_net.eval()

print("Topics with probability argmax")
prob_over_vocab_np, topics_print_list = semeval_net.rank_vocab_for_topics(
            word_embedding_matrix=semeval_embs_matrix_np, to_be_removed=semeval_to_be_removed
    )
print("=" * 70)

Finally after training
Topics with probability argmax
100
100
[1190  256 1054 2249  576 1603  962 1249  830 2573]
topic 0 :  the hepatitis A virus (HAV) . ,  monitoring ,  the parasite entering the body through the skin during bathing or drinking of infested water . ,  The abdominal distention ,  Winds have been blowing ,  Heat , wind and smoke ,  In 1871 , nearly 60 percent of the Italian population farmed for a living , ,  Lymphedema ,  a blood clot . ,  Autologous blood clot is useful 
[1579  714 2406 2045 1158  787  906 2590  754 2588]
topic 1 :  gravitational interactions with the small satellites Prometheus and Pandora . ,  Canine flea infestation is ,  and currently about 90 million patients worldwide are affected by the disease . ,  the acid rains . ,  its turn by ,  The enclosed community ,  loss of ,  the hepatitis B virus . ,  The tides are ,  that NOX5 mediated overproduction of hydrogen peroxide is responsible for increased growth and decreased death of cancer cells . 
[19

In [37]:
# after training we evaluate all the topics percentage in the dataset and rank the topics by percentage
uid_list, vector_list = zip(*semeval_uid_input_vector_list)
topic_pred_list = text_to_topic(vector_list, semeval_net, 'cuda')

topic_id_ranked, topic_percentage_ranked = rank_topics_by_percentage( topic_pred_list )

for rank, (topic_id, topic_percentage) in enumerate( zip(topic_id_ranked, topic_percentage_ranked)):
    print(
            f"Rank: {rank}, Topic_id: {topic_id}, Topic Words: {topics_print_list[topic_id]}, \
            Topic Percentage: {topic_percentage}"
        )


100%|██████████| 3/3 [00:00<00:00, 215.33it/s]

Rank: 0, Topic_id: 72, Topic Words: topic 72 :  exposure to wind , sun , and detergents . ,  The injury ,  The changes now seen in the endometrium are ,  the decoration ,  interim contact with the participant should be made by the investigator . ,  work being carried out by track operator Network Rail . ,  from the amphibolites of Santiago-Ponte Ulla , Spain . ,  snapping tendons and ligaments , and rickety arthritic joints . ,  The landslides ,  to become "Zarthushtis" . ,             Topic Percentage: 6.24
Rank: 1, Topic_id: 32, Topic Words: topic 32 :  fleas , which are small , wingless blood-sucking insects . ,  is an on-going matter that the town and the police continue to address , particularly in the summer months . ,  The court ,  had still not been proven ,  economic crisis ,  clearly justifying European intervention . ,  In reality , however , the drama has been ,  out to my hips and back and then ,  Yoga is fantastic for ,  that private military and security companies are us




In [38]:
prob_over_vocab_df = pd.DataFrame(prob_over_vocab_np) # shape: 50 x 1131
prob_over_vocab_df.head(5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2694,2695,2696,2697,2698,2699,2700,2701,2702,2703
0,0.000377,0.000251,0.000288,0.000381,0.000257,0.000259,0.000311,0.000362,0.000383,0.00027,...,0.000301,0.000317,0.000216,0.000344,0.000433,0.000325,0.000237,0.000384,0.000475,0.000408
1,0.000396,0.0003,0.000522,0.000378,0.000476,0.000483,0.000281,0.000311,0.000351,0.000504,...,0.000371,0.000518,0.000476,0.000339,0.000257,0.000334,0.000279,0.000428,0.000338,0.000496
2,0.000333,0.00043,0.000334,0.00041,0.000341,0.000385,0.000329,0.000427,0.000375,0.000335,...,0.000275,0.000345,0.000391,0.000302,0.000419,0.00034,0.000289,0.000379,0.000403,0.000274
3,0.00028,0.000646,0.000361,0.000308,0.000348,0.000669,0.000366,0.000596,0.000331,0.000433,...,0.000442,0.0004,0.000635,0.000504,0.000336,0.000273,0.000389,0.000503,0.000382,0.000298
4,0.000314,0.000431,0.000481,0.000337,0.000404,0.000447,0.000253,0.00035,0.000311,0.000428,...,0.00038,0.000286,0.000369,0.000419,0.000397,0.000375,0.000409,0.000416,0.000416,0.000261


In [39]:
prob_over_vocab_df.idxmax() # for each column, find the row number of the max

0       33
1       24
2       73
3       33
4       50
        ..
2699    97
2700    71
2701    75
2702    39
2703    61
Length: 2704, dtype: int64

In [40]:
# construct a dictionary with topic ids as keys and arg ids as values
# where only one topic is assigned to each argument
topic2argid_dict = {}
# construct a dictionary with arg ids as keys and topic ids as values
argid2topic_dict = {}
for i in range(100): # number of topics created
    argid = []
    for j in range(len(prob_over_vocab_df.idxmax())):
        if prob_over_vocab_df.idxmax()[j] == i:
            argid.append(j)
            argid2topic_dict[j] = i
    topic2argid_dict[i] = argid

In [41]:
topic2argid_dict

{0: [54,
  80,
  117,
  138,
  180,
  211,
  330,
  486,
  504,
  548,
  576,
  604,
  647,
  746,
  822,
  830,
  956,
  962,
  1037,
  1106,
  1183,
  1190,
  1234,
  1249,
  1277,
  1355,
  1458,
  1466,
  1479,
  1520,
  1534,
  1646,
  1665,
  1795,
  1966,
  2149,
  2224,
  2227,
  2249,
  2365,
  2573,
  2678],
 1: [551, 708, 754, 787, 1842, 2045, 2180, 2259],
 2: [191, 309, 461, 1135, 1358, 1555, 1601, 1625, 1727, 1974, 2272],
 3: [5,
  93,
  234,
  295,
  313,
  369,
  443,
  466,
  505,
  559,
  695,
  711,
  933,
  952,
  968,
  1039,
  1316,
  1367,
  1396,
  1410,
  1542,
  1589,
  1636,
  1654,
  1680,
  1712,
  1714,
  1722,
  1726,
  1739,
  1762,
  1785,
  1791,
  1799,
  1935,
  2057,
  2058,
  2129,
  2182,
  2237,
  2449,
  2461,
  2568,
  2585,
  2612,
  2659,
  2696],
 4: [44,
  68,
  294,
  316,
  338,
  343,
  400,
  453,
  566,
  663,
  908,
  930,
  936,
  1026,
  1073,
  1384,
  1394,
  1405,
  1439,
  1441,
  1462,
  1489,
  1541,
  1686,
  1893,
  1937,
  1

In [42]:
argid2topic_dict

{54: 0,
 80: 0,
 117: 0,
 138: 0,
 180: 0,
 211: 0,
 330: 0,
 486: 0,
 504: 0,
 548: 0,
 576: 0,
 604: 0,
 647: 0,
 746: 0,
 822: 0,
 830: 0,
 956: 0,
 962: 0,
 1037: 0,
 1106: 0,
 1183: 0,
 1190: 0,
 1234: 0,
 1249: 0,
 1277: 0,
 1355: 0,
 1458: 0,
 1466: 0,
 1479: 0,
 1520: 0,
 1534: 0,
 1646: 0,
 1665: 0,
 1795: 0,
 1966: 0,
 2149: 0,
 2224: 0,
 2227: 0,
 2249: 0,
 2365: 0,
 2573: 0,
 2678: 0,
 551: 1,
 708: 1,
 754: 1,
 787: 1,
 1842: 1,
 2045: 1,
 2180: 1,
 2259: 1,
 191: 2,
 309: 2,
 461: 2,
 1135: 2,
 1358: 2,
 1555: 2,
 1601: 2,
 1625: 2,
 1727: 2,
 1974: 2,
 2272: 2,
 5: 3,
 93: 3,
 234: 3,
 295: 3,
 313: 3,
 369: 3,
 443: 3,
 466: 3,
 505: 3,
 559: 3,
 695: 3,
 711: 3,
 933: 3,
 952: 3,
 968: 3,
 1039: 3,
 1316: 3,
 1367: 3,
 1396: 3,
 1410: 3,
 1542: 3,
 1589: 3,
 1636: 3,
 1654: 3,
 1680: 3,
 1712: 3,
 1714: 3,
 1722: 3,
 1726: 3,
 1739: 3,
 1762: 3,
 1785: 3,
 1791: 3,
 1799: 3,
 1935: 3,
 2057: 3,
 2058: 3,
 2129: 3,
 2182: 3,
 2237: 3,
 2449: 3,
 2461: 3,
 2568: 3,
 2585

Construct a new dataframe with columns: sentence, arg0, arg1, arg0id, arg1id, arg0topicid, arg1topicid

In [43]:
semeval_topic_df = semeval_causal.copy()
semeval_topic_df.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents
1,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is <ARG1> that the chronic in...,1,1,,1
2,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,<ARG1> The burst has been </ARG1> caused by <A...,1,1,,1
3,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","<ARG0> The singer , who performed three of the...",1,1,,1
4,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0> Suicide is one of the leading causes of...,1,1,,1
5,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,<ARG1> He had chest pains and headaches from <...,1,1,,1


In [44]:
# modify the original extract_args function to differentiate which text arguments belong to
def extract_args2(dataset):
    arg0s = []
    arg1s = []
    for textwpair in dataset:
        arg0 = re.findall(r"<ARG0>(.*?)</ARG0>", textwpair) # list of all argument0s in string textwpair
        arg1 = re.findall(r"<ARG1>(.*?)</ARG1>", textwpair) # list of all argument1s in string textwpair
        arg0s.append(arg0)
        arg1s.append(arg1)
    return arg0s, arg1s

In [45]:
# create new columns for lists of arg0s and arg1s for each text
semeval_topic_df['arg0'], semeval_topic_df['arg1'] = extract_args2(semeval_textwpairs)

In [46]:
# create arg0_id and arg1_id columns
semeval_arg0id_list = []
for arg0s in semeval_topic_df['arg0']:
    temp = []
    for arg0 in arg0s:
        temp.append(semeval_word2id[arg0])
    semeval_arg0id_list.append(temp)
semeval_topic_df['arg0_id'] = semeval_arg0id_list

semeval_arg1id_list = []
for arg1s in semeval_topic_df['arg1']:
    temp = []
    for arg1 in arg1s:
        temp.append(semeval_word2id[arg1])
    semeval_arg1id_list.append(temp)
semeval_topic_df['arg1_id'] = semeval_arg1id_list

In [47]:
# create columns for arg0_topicid and arg1_topicid
semeval_arg0topic_list = []
for arg0s in semeval_arg0id_list:
    temp = []
    for arg0 in arg0s:
        temp.append(argid2topic_dict[arg0])
    semeval_arg0topic_list.append(temp)
semeval_topic_df['arg0_topicid'] = semeval_arg0topic_list

semeval_arg1topic_list = []
for arg1s in semeval_arg1id_list:
    temp = []
    for arg1 in arg1s:
        temp.append(argid2topic_dict[arg1])
    semeval_arg1topic_list.append(temp)
semeval_topic_df['arg1_topicid'] = semeval_arg1topic_list

In [48]:
semeval_topic_df.head(5)

Unnamed: 0,corpus,doc_id,sent_id,eg_id,index,text,text_w_pairs,seq_label,pair_label,context,num_sents,arg0,arg1,arg0_id,arg1_id,arg0_topicid,arg1_topicid
1,semeval2010t8,train.json,6,0,semeval2010t8_train.json_6_0,The current view is that the chronic inflammat...,The current view is <ARG1> that the chronic in...,1,1,,1,"[ Helicobacter pylori infection , increased a...",[ that the chronic inflammation in the distal ...,"[2047, 130]","[1675, 1835, 878]","[20, 20]","[89, 29, 81]"
2,semeval2010t8,train.json,13,0,semeval2010t8_train.json_13_0,The burst has been caused by water hammer pres...,<ARG1> The burst has been </ARG1> caused by <A...,1,1,,1,[ water hammer pressure . ],[ The burst has been ],[219],[2077],[67],[14]
3,semeval2010t8,train.json,22,0,semeval2010t8_train.json_22_0,"The singer , who performed three of the nomina...","<ARG0> The singer , who performed three of the...",1,1,,1,"[ The singer , who performed three of the nomi...",[ a commotion on the red carpet . ],"[1396, 1045]",[431],"[3, 42]",[36]
4,semeval2010t8,train.json,26,0,semeval2010t8_train.json_26_0,Suicide is one of the leading causes of death ...,<ARG0> Suicide is one of the leading causes of...,1,1,,1,[ Suicide is one of the leading causes of deat...,"[ , , and victims of bullying are at an incre...",[173],"[163, 799]",[59],"[72, 10]"
5,semeval2010t8,train.json,31,0,semeval2010t8_train.json_31_0,He had chest pains and headaches from mold in ...,<ARG1> He had chest pains and headaches from <...,1,1,,1,[ mold in the bedrooms ],"[ He had chest pains and headaches from , . ]",[1326],"[1200, 686]",[84],"[23, 95]"


In [49]:
# save the dataframe as a csv file
topic_model_data_path = "/content/drive/MyDrive/Assignments/capstone/pntm/"
np.save(os.path.join(topic_model_data_path, 'semeval_topic_df'), semeval_topic_df)

In [50]:
with open( os.path.join(topic_model_data_path, 'semeval_topic_model_100.pt'), "wb") as f:
    torch.save(semeval_net, f)
    print(f"Saved model at { os.path.join(topic_model_data_path) }")

Saved model at /content/drive/MyDrive/Assignments/capstone/pntm/
