In [1]:
import warnings
warnings.filterwarnings("ignore")

In [22]:
import spacy
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertModel
from torch.nn.utils.rnn import pad_sequence

In [3]:
def load_dataset(file_path):
    sentences = []
    labels = []
    
    with open(file_path, 'r') as f:
        lines = f.readlines()
        sentence = []
        
        for line in lines:
            line = line.strip() # remove leading/trailing whitespaces
            if line:
                if not line.startswith('ARG') and not line.startswith('REL') and not line.startswith('NONE'):
                    sentence = line
                else:
                    current_label = line
                    sentences.append(sentence)
                    labels.append(current_label)
                    
    return sentences, labels

In [4]:
# load the dataset into a pandas dataframe
sentences, labels = load_dataset('./Dataset/original_cleaned')
df = pd.DataFrame({
    'Sentence': sentences,
    'Labels': labels
})

In [5]:
df.shape

(166889, 2)

In [6]:
df

Unnamed: 0,Sentence,Labels
0,Simon is quoted as saying `` if you 'd ever se...,ARG1 REL REL ARG2 ARG2 ARG2 ARG2 ARG2 ARG2 ARG...
1,Simon is quoted as saying `` if you 'd ever se...,NONE NONE NONE NONE NONE NONE NONE NONE NONE N...
2,Simon is quoted as saying `` if you 'd ever se...,NONE NONE NONE NONE NONE NONE NONE ARG1 REL TI...
3,Simon is quoted as saying `` if you 'd ever se...,ARG1 NONE NONE REL REL NONE NONE NONE NONE NON...
4,The couple had no children .,ARG1 ARG1 REL ARG2 ARG2 NONE
...,...,...
166884,TIME ARG1 REL NONE NONE ARG2 ARG2 ARG2 ARG2 AR...,NONE NONE NONE NONE NONE ARG1 REL ARG2 ARG2 AR...
166885,TIME ARG1 REL NONE NONE ARG2 ARG2 ARG2 ARG2 AR...,NONE NONE NONE NONE NONE NONE NONE NONE NONE N...
166886,This was the time when Yang Luchan made the Ch...,ARG1 REL ARG2 ARG2 ARG2 ARG2 ARG2 ARG2 ARG2 AR...
166887,This was the time when Yang Luchan made the Ch...,NONE NONE TIME TIME NONE ARG1 ARG1 REL ARG2 AR...


In [7]:
# change this later on !!!!
df = df[:50]

In [8]:
# convert all sentences to lower case
df['Sentence'] = df['Sentence'].str.lower()

In [9]:
def remerge_sent(sent):
    # merges tokens which are not separated by white-space
    # does this recursively until no further changes
    changed = True
    while changed:
        changed = False
        i = 0
        while i < sent.__len__() - 1:
            tok = sent[i]
            if not tok.whitespace_:
                ntok = sent[i + 1]
                # in-place operation.
                with sent.retokenize() as retokenizer:
                    retokenizer.merge(sent[i: i + 2])
                changed = True
            i += 1
    return sent

In [10]:
# Tokenize sentences using spacy
nlp = spacy.load('en_core_web_sm')

In [11]:
def check_token_label_length(row):
    doc = nlp(row['Sentence'])
    spacy_sentence = remerge_sent(doc)
    tokens = [token.text for token in spacy_sentence]
    labels = row['Labels'].split()

    is_match = len(tokens) == len(labels)
    return is_match, len(tokens), len(labels), tokens

In [12]:
df[['Token_Label_Match', 'Num_Tokens', 'Num_Labels', 'Tokens']] = df.apply(check_token_label_length, axis=1, result_type="expand")

In [13]:
df.head()

Unnamed: 0,Sentence,Labels,Token_Label_Match,Num_Tokens,Num_Labels,Tokens
0,simon is quoted as saying `` if you 'd ever se...,ARG1 REL REL ARG2 ARG2 ARG2 ARG2 ARG2 ARG2 ARG...,True,32,32,"[simon, is, quoted, as, saying, ``, if, you, '..."
1,simon is quoted as saying `` if you 'd ever se...,NONE NONE NONE NONE NONE NONE NONE NONE NONE N...,True,32,32,"[simon, is, quoted, as, saying, ``, if, you, '..."
2,simon is quoted as saying `` if you 'd ever se...,NONE NONE NONE NONE NONE NONE NONE ARG1 REL TI...,True,32,32,"[simon, is, quoted, as, saying, ``, if, you, '..."
3,simon is quoted as saying `` if you 'd ever se...,ARG1 NONE NONE REL REL NONE NONE NONE NONE NON...,True,32,32,"[simon, is, quoted, as, saying, ``, if, you, '..."
4,the couple had no children .,ARG1 ARG1 REL ARG2 ARG2 NONE,True,6,6,"[the, couple, had, no, children, .]"


In [14]:
# Load the pre-trained BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model = bert_model.to(device)

In [15]:
# Function to generate BERT embeddings for each token in a sentence.
def get_bert_embeddings(tokens):
    embeddings = []

    for token in tokens:
        # Tokenize the individual token
        inputs = tokenizer(token, return_tensors='pt', padding=True, truncation=True)

        # Get BERT embeddings from the model
        with torch.no_grad():  # Disable gradient computation for efficiency
            outputs = bert_model(**inputs)

        # outputs.last_hidden_state gives the embeddings for each token
        # Taking the mean of the embeddings if there are multiple subword tokens
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)  # Shape: (hidden_size,)

        # Append the embedding
        embeddings.append(embedding)

    # Convert list of tensors to a single tensor (if needed)
    embeddings = torch.stack(embeddings)  # Shape: (num_tokens, hidden_size)

    return embeddings  # Shape: (num_tokens, hidden_size)

In [16]:
# Function to generate BERT embeddings for each sentence in the DataFrame
def generate_embeddings(df):
    embeddings_list = []
    for index, row in df.iterrows():
        tokenized_sentence = row['Tokens']
        embeddings = get_bert_embeddings(tokenized_sentence)
        embeddings_list.append(embeddings)
    return embeddings_list

In [17]:
embeddings_list = generate_embeddings(df)

In [65]:
# pad the embeddings to ensure uniformity across dataset as model will be trained in batches
padded_embeddings = pad_sequence(embeddings_list, batch_first=True)
df['Embeddings'] = [padded_embeddings[i] for i in range(padded_embeddings.shape[0])]
print(padded_embeddings.shape)

torch.Size([50, 63, 768])


In [66]:
# encode the labels
label_encoder = LabelEncoder()
labels_list = ['ARG1', 'ARG2', 'REL', 'TIME', 'LOC', 'NONE']
label_encoder.fit(labels_list)

df['Encoded_Labels'] = df['Labels'].apply(lambda x: label_encoder.transform(x.split()))

In [67]:
# pad the labels
def pad_labels(labels, max_len):
    padded_labels = torch.full((max_len,), label_encoder.transform(['NONE'])[0])  # Assuming 'NONE' is the padding label
    padded_labels[:len(labels)] = torch.tensor(labels)
    return padded_labels

In [68]:
encoded_labels = df['Encoded_Labels'].tolist()
max_len = max(len(label) for label in encoded_labels)

padded_labels = [pad_labels(label, max_len) for label in encoded_labels]
df['Padded_Labels'] = padded_labels

In [69]:
# Split the data into training and testing sets
train_embeddings, test_embeddings, train_labels, test_labels = train_test_split(
    df['Embeddings'].tolist(),  
    df['Padded_Labels'].tolist(),
    test_size=0.2,
    random_state=42
)

In [70]:
# Further split the training set into training and validation sets
train_embeddings, val_embeddings, train_labels, val_labels = train_test_split(
    train_embeddings,
    train_labels,
    test_size=0.1,  # 10% of training set for validation
    random_state=42
)

In [71]:
# Create new DataFrames for the splits
train_df = pd.DataFrame({'Embeddings': train_embeddings, 'Labels': train_labels})
val_df = pd.DataFrame({'Embeddings': val_embeddings, 'Labels': val_labels})
test_df = pd.DataFrame({'Embeddings': test_embeddings, 'Labels': test_labels})

train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

In [77]:
train_df.head

<bound method NDFrame.head of                                            Embeddings  \
0   [[tensor(0.0701), tensor(0.1535), tensor(-0.16...   
1   [[tensor(0.1237), tensor(0.3573), tensor(-0.11...   
2   [[tensor(0.3723), tensor(-0.0874), tensor(-0.1...   
3   [[tensor(0.3782), tensor(-0.1690), tensor(0.13...   
4   [[tensor(-0.0435), tensor(-0.1900), tensor(-0....   
5   [[tensor(-0.0435), tensor(-0.1900), tensor(-0....   
6   [[tensor(0.0701), tensor(0.1535), tensor(-0.16...   
7   [[tensor(0.3723), tensor(-0.0874), tensor(-0.1...   
8   [[tensor(0.3782), tensor(-0.1690), tensor(0.13...   
9   [[tensor(0.0092), tensor(-0.0992), tensor(-0.3...   
10  [[tensor(0.1542), tensor(0.1303), tensor(-0.17...   
11  [[tensor(0.5746), tensor(0.7125), tensor(0.388...   
12  [[tensor(0.1196), tensor(-0.0966), tensor(-0.1...   
13  [[tensor(0.1542), tensor(0.1303), tensor(-0.17...   
14  [[tensor(0.1612), tensor(-0.0972), tensor(-0.0...   
15  [[tensor(0.2098), tensor(-0.0714), tensor(0.07...   
1