For each sentence , we have 30x768 features .
We will make 1 homonym classifier with 3 classes (forward relation , backward relation and no relation) and a synonym classifier with 2 classes (relation or no relation)

Thus our dataset will be like (2 keyphrases , issynonym , ishyponym)

In [1]:
from transformers import BertTokenizer
import torch
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
# from crfseg import CRF
import torch.nn.functional as F
import tqdm as tqdm


# Defining Dataset and Dataloader

In [2]:
class MyDataset(Dataset):
    def __init__(self, data_dir ,window_size = 2 , max_len_keyphrase = 30):
        self.data_dir = data_dir
        
        # iterate through the files in the data directory
        self.txtfiles = []
        self.annfiles = []
        self.max_len_keyphrase = max_len_keyphrase

        for file in os.listdir(data_dir):
            if file.endswith(".txt"):
                self.txtfiles.append(file)
        
        self.tokeniser = BertTokenizer.from_pretrained('bert-base-uncased')
        
        # self.hyponym_label_dict = hyponym_label_dict

        self.keyphrase_pairs = []
        self.ishyponym = [] # for each keyphrase pair, 1 if it is a hyponym pair, 0 otherwise
        self.issynonym = [] # for each keyphrase pair, 1 if it is a synonym pair, 0 otherwise

        '''Keyphrases are stored in the ann files in the following format:'''
        for txtfile in self.txtfiles:
            sampleid = txtfile.split(".")[0]
            annfilename = sampleid + ".ann"
            with open(os.path.join(self.data_dir, annfilename), 'r') as file:
                ann = file.read()
                
            with open(os.path.join(self.data_dir, txtfile), 'r') as file:
                txt = file.read()
                
            offsets , tokenisedtxt = self.tokenise(txt)
            token_ids = self.tokeniser.convert_tokens_to_ids(tokenisedtxt)
            keyphrases = [] # contains all the keyphrases in the txtfile
            keyphrases_matching_dict = {} # contains the txtfile index(T1,T2 etc) with index of keyphrase in keyphrases
            hyponym_pairs = [] # contains all the hyponym pairs in the txtfile
            synonym_pairs = [] # contains all the synonym pairs in the txtfile

            '''First get all keyphrases'''
            for line in ann.split('\n'):
                if line == '':
                    continue
                words = line.split()
                
                txtindex = words[0]
                ssofset = words[2]
                endoffset = words[3]

                if txtindex[0] == 'T':
                    '''get the index of first token whose offset is greater than or equal to ssofset'''
                    start_index = 0
                    for i in range(len(offsets)):
                        if offsets[i] >= int(ssofset):
                            start_index = i
                            break
                    
                    '''get the index of first token whose offset is greater than or equal to endoffset'''
                    end_index = 0
                    for i in range(len(offsets)):
                        if offsets[i] >= int(endoffset):
                            end_index = i
                            break
                    
                    '''get the keyphrase tokens with context tokens'''
                    mins = max(0 , start_index - window_size)
                    maxs = min(len(token_ids) , end_index + window_size)
                    tks = []
                    for j in range(mins , maxs):
                        tks.append(token_ids[j])
                    
                    # pad tks with bert_padding_token to make it of length max_len_keyphrase
                    if len(tks) < max_len_keyphrase:
                        tks += [self.tokeniser.pad_token_id] * (max_len_keyphrase - len(tks))

                    keyphrases.append(tks)
                    keyphrases_matching_dict[txtindex] = len(keyphrases) - 1

                elif txtindex[0] == 'R':
                    arg1 = words[2].split(':')[1]
                    arg2 = words[3].split(':')[1]

                    # there is a forward relation between arg1 and arg2
                    if arg1[0] == 'T' and arg2[0] == 'T':
                        keyphrase1_index = keyphrases_matching_dict[arg1]
                        keyphrase2_index = keyphrases_matching_dict[arg2]
                        hyponym_pairs.append((keyphrase1_index , keyphrase2_index))
                
                elif txtindex[0] == '*':
                    '''synonym class'''
                    arg1 = words[2]
                    arg2 = words[3]
                    
                    if arg1[0] == 'T' and arg2[0] == 'T':
                        keyphrase1_index = keyphrases_matching_dict[arg1]
                        keyphrase2_index = keyphrases_matching_dict[arg2]
                        synonym_pairs.append((keyphrase1_index , keyphrase2_index))
            
            '''Now make all keyphrase pairs and add them to self.keyphrase_pairs'''
            for i in range(len(keyphrases)):
                for j in range(i+1 , len(keyphrases)):
                    self.keyphrase_pairs.append((keyphrases[i] , keyphrases[j]))
                    if ((i , j) in hyponym_pairs) or ((j , i) in hyponym_pairs):
                        self.ishyponym.append(1)
                    else:
                        self.ishyponym.append(0)
                    
                    if ((i , j) in synonym_pairs) or ((j , i) in synonym_pairs):
                        self.issynonym.append(1)
                    else:
                        self.issynonym.append(0)


    
    def tokenise(self , text):
        tokens = []  # List to store tokens
        starting_offsets = []  # List to store starting offsets
        current_token = ''  # Variable to store current token
        offset = 0  # Starting offset

        for char in text:
            if char == ' ':
                if current_token:  # If token is not empty
                    tokens.append(current_token.lower())  # Append token in lowercase
                    starting_offsets.append(offset - len(current_token))  # Store starting offset
                    current_token = ''  # Reset current token
                offset += 1  # Move offset to next character
            else:
                current_token += char  # Append character to current token
                offset += 1  # Move offset to next character

        # Handling the last token if it exists after the loop ends
        if current_token:
            tokens.append(current_token.lower())  # Append token in lowercase
            starting_offsets.append(offset - len(current_token))  # Store starting offset

        return starting_offsets , tokens

    def __len__(self):
        return len(self.keyphrase_pairs)

    def __getitem__(self, index):
        return self.keyphrase_pairs[index][0] ,self.keyphrase_pairs[index][1] , self.ishyponym[index] , self.issynonym[index]
    
    def collate_fn(self , batch):
        keyphrases1s = []
        keyphrases2s = []
        label1s = []
        label2s = []
        for keyphrase1 , keyphrase2 , label1 , label2 in batch:
            keyphrases1s.append(torch.tensor(keyphrase1))
            keyphrases2s.append(torch.tensor(keyphrase2))
            label1s.append(torch.tensor(label1))
            label2s.append(torch.tensor(label2))
        
        keyphrases1s = torch.stack(keyphrases1s)
        keyphrases2s = torch.stack(keyphrases2s)
        label1s = torch.stack(label1s)
        label2s = torch.stack(label2s)

        return keyphrases1s ,keyphrases2s , label1s , label2s



In [3]:
label_dict = {
    'Process' : 0,
    'Task' : 1,
    'Material' : 2,
}
train_dataset = MyDataset('Data/train2')
print(len(train_dataset))

72230


In [4]:
train_dataloader = DataLoader(train_dataset , batch_size = 32 , shuffle = True , collate_fn = train_dataset.collate_fn)

# Defining Model

In [5]:
class FeatureExtractor(nn.Module):
    def __init__(self , emb_size = 3):
        super(FeatureExtractor , self).__init__()
        self.embedding = BertModel.from_pretrained('bert-base-uncased')
        self.embed_dim = 768

        for param in self.embedding.parameters():
            param.requires_grad = False

        self.conv1 = nn.Sequential(nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 1 , padding = 0) ,
        nn.ReLU()) # (B , 768 , 30) -> (B , 256 , 30)

        self.conv2 = nn.Sequential(nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 3 , padding = 1) , 
        nn.ReLU())

        self.conv3 = nn.Sequential(nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 5 , padding = 2) ,
        nn.ReLU())

        self.conv4 = nn.Sequential(nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 7 , padding = 3) ,
        nn.ReLU())

        self.max_over_time_pooling = nn.MaxPool1d(kernel_size=60) # (B , 256*4 , 30) -> (B , 256*4)
        self.fc = nn.Sequential(
            nn.Linear(256*4 , 512) , nn.ReLU() , nn.Linear(512 , 256) , nn.ReLU() , nn.Linear(256 , emb_size)
        )
    
    def forward(self , x1 , x2):
        x = torch.cat([x1 , x2] , dim = 1) # shape : B , 60
        embeddings  = self.embedding(x)[0] # shape : B , 60 , 768
        embeddings = embeddings.permute(0 , 2 , 1)

        conv1_out = self.conv1(embeddings) # shape : B , 256 , 60
        conv2_out = self.conv2(embeddings) # shape : B , 256 , 60
        conv3_out = self.conv3(embeddings) # shape : B , 256 , 60
        conv4_out = self.conv4(embeddings) # shape : B , 256 , 60

        concat_conv_outs = torch.cat([conv1_out , conv2_out , conv3_out , conv4_out] , dim = 1) # shape : B , 256*4 , 60
        max_pooled = self.max_over_time_pooling(concat_conv_outs)
        max_pooled = max_pooled.squeeze(2) # B , 256*4
        features = self.fc(max_pooled)

        return features


In [6]:
feature_extractor_model = FeatureExtractor()
sample_inp = torch.randint(0 , 30522 , (32 , 30))
sample_inp2 = torch.randint(0 , 30522 , (32 , 30))
print(sample_inp.shape , sample_inp2.shape)
output_emb = feature_extractor_model(sample_inp , sample_inp2)
print(output_emb.shape)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([32, 30]) torch.Size([32, 30])
torch.Size([32, 3])


# Training

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [8]:
model = FeatureExtractor().to(device)
optimiser = torch.optim.Adam(model.parameters() , lr = 0.001)
criterion = nn.CrossEntropyLoss()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
val_dataset = MyDataset('Data/dev')
val_dataloader = DataLoader(val_dataset , batch_size = 32 , shuffle = False , collate_fn = val_dataset.collate_fn)

In [12]:
model.train()
num_epochs = 10

for i in range(num_epochs):
    avg_loss = 0
    with tqdm.tqdm(train_dataloader, unit="batch") as tepoch:
        for keyphrases1 , keyphrases2 , label1 , label2 in tepoch:
            keyphrases1 = keyphrases1.to(device)
            keyphrases2 = keyphrases2.to(device)
            label1 = label1.to(device = device , dtype = torch.float32) # B , 1
            label2 = label2.to(device = device , dtype = torch.float32) # B , 1

            optimiser.zero_grad()
            out = model(keyphrases1 , keyphrases2) # (B , 3)

            hyponym_loss = criterion(out[: , 0] , label1)
            synonym_loss = criterion(out[: , 1] , label2)
            loss = hyponym_loss + synonym_loss
            loss.backward()
            optimiser.step()

            avg_loss += loss.item() / len(train_dataloader)
            tepoch.set_postfix(loss=avg_loss)


100%|██████████| 2258/2258 [01:45<00:00, 21.37batch/s, loss=0.462]
100%|██████████| 2258/2258 [01:54<00:00, 19.71batch/s, loss=0.458]
100%|██████████| 2258/2258 [01:55<00:00, 19.57batch/s, loss=0.452]
100%|██████████| 2258/2258 [01:56<00:00, 19.45batch/s, loss=0.439]
100%|██████████| 2258/2258 [01:57<00:00, 19.20batch/s, loss=0.403] 
100%|██████████| 2258/2258 [01:58<00:00, 19.09batch/s, loss=0.388] 
100%|██████████| 2258/2258 [01:58<00:00, 19.13batch/s, loss=0.375]
100%|██████████| 2258/2258 [01:56<00:00, 19.31batch/s, loss=0.381]
100%|██████████| 2258/2258 [01:56<00:00, 19.37batch/s, loss=0.36]  
100%|██████████| 2258/2258 [01:56<00:00, 19.42batch/s, loss=0.327]
