In [1]:
from transformers import BertTokenizer
import torch
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
# from crfseg import CRF
from torchcrf import CRF
import torch.nn.functional as F
import tqdm as tqdm
from torchtext.vocab import GloVe

# Making Dataloader

In [2]:
class MyDataset(Dataset):
    def __init__(self, data_dir , label_dict , window_size = 2 , max_len_keyphrase = 30):
        self.data_dir = data_dir
        
        # iterate through the files in the data directory
        self.txtfiles = []
        self.annfiles = []
        self.max_len_keyphrase = max_len_keyphrase

        for file in os.listdir(data_dir):
            if file.endswith(".txt"):
                self.txtfiles.append(file)
        
        self.tokeniser = BertTokenizer.from_pretrained('bert-base-uncased')
        self.keyphrases_with_context = [] # it contains all the keyphrase tokens with context tokens
        self.labels = []
        self.label_dict = label_dict

        '''Keyphrases are stored in the ann files in the following format:'''
        for txtfile in self.txtfiles:
            sampleid = txtfile.split(".")[0]
            annfilename = sampleid + ".ann"
            with open(os.path.join(self.data_dir, annfilename), 'r') as file:
                ann = file.read()
                
            with open(os.path.join(self.data_dir, txtfile), 'r') as file:
                txt = file.read()
                
            offsets , tokenisedtxt = self.tokenise(txt)
            token_ids = self.tokeniser.convert_tokens_to_ids(tokenisedtxt)


            for line in ann.split('\n'):
                if line == '':
                    continue
                words = line.split()
                if words[0][0] != 'T':
                    continue
                
                label = self.label_dict[words[1]]
                ssofset = words[2]
                endoffset = words[3]

                '''get the index of first token whose offset is greater than or equal to ssofset'''
                start_index = 0
                for i in range(len(offsets)):
                    if offsets[i] >= int(ssofset):
                        start_index = i
                        break
                
                '''get the index of first token whose offset is greater than or equal to endoffset'''
                end_index = 0
                for i in range(len(offsets)):
                    if offsets[i] >= int(endoffset):
                        end_index = i
                        break
                
                '''get the keyphrase tokens with context tokens'''
                mins = max(0 , start_index - window_size)
                maxs = min(len(token_ids) , end_index + window_size)
                tks = []
                for j in range(mins , maxs):
                    tks.append(token_ids[j])
                
                # pad tks with bert_padding_token to make it of length max_len_keyphrase
                if len(tks) < max_len_keyphrase:
                    tks += [self.tokeniser.pad_token_id] * (max_len_keyphrase - len(tks))

                self.keyphrases_with_context.append(tks)
                self.labels.append(label)

    
    def tokenise(self , text):
        tokens = []  # List to store tokens
        starting_offsets = []  # List to store starting offsets
        current_token = ''  # Variable to store current token
        offset = 0  # Starting offset

        for char in text:
            if char == ' ':
                if current_token:  # If token is not empty
                    tokens.append(current_token.lower())  # Append token in lowercase
                    starting_offsets.append(offset - len(current_token))  # Store starting offset
                    current_token = ''  # Reset current token
                offset += 1  # Move offset to next character
            else:
                current_token += char  # Append character to current token
                offset += 1  # Move offset to next character

        # Handling the last token if it exists after the loop ends
        if current_token:
            tokens.append(current_token.lower())  # Append token in lowercase
            starting_offsets.append(offset - len(current_token))  # Store starting offset

        return starting_offsets , tokens

    def __len__(self):
        return len(self.keyphrases_with_context)

    def __getitem__(self, index):
        return self.keyphrases_with_context[index] , self.labels[index]
    
    def collate_fn(self , batch):
        keyphrases = []
        labels = []
        for keyphrase , label in batch:
            keyphrases.append(torch.tensor(keyphrase))
            labels.append(torch.tensor(label))
        
        keyphrases = torch.stack(keyphrases)
        labels = torch.stack(labels)

        return keyphrases , labels



In [3]:
label_dict = {
    'Process' : 0,
    'Task' : 1,
    'Material' : 2,
}
train_dataset = MyDataset('Data/train2' , label_dict)
print(len(train_dataset))

6732


In [4]:
batch_size = 32
train_dataloader = DataLoader(train_dataset , batch_size = batch_size , shuffle = True , collate_fn = train_dataset.collate_fn)

# Making Model

In [5]:
class KeyphraseClassificationLinear(nn.Module):
    def __init__(self , num_labels , hidden_size = 768):
        super(KeyphraseClassificationLinear , self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        for param in self.bert.parameters():
            param.requires_grad = False

        self.hidden_size = hidden_size
        self.num_labels = num_labels
        self.linear = nn.Linear(hidden_size , num_labels)
    
    def forward(self , input_ids):
        outputs = self.bert(input_ids)[0] # shape : B , 30 , 768
        # convert outputs to shape : B , 768 i.e make sentence level representation
        outputs = torch.mean(outputs , dim = 1)
        logits = self.linear(outputs)
        return logits

In [9]:
class KeyphraseClassificationAttentionLSTM(nn.Module):
    def __init__(self , num_labels , embedding_type = 'bert'):
        '''Model takes in B kayphrase tokens (all of length 30) and outputs probabilities for num_labels classes'''
        '''Input Shape : (B , 30) Output Shape : (B , num_labels)'''
        super(KeyphraseClassificationAttentionLSTM , self).__init__()

        '''First generate Embeddings for the input tokens'''
        self.embedding = BertModel.from_pretrained('bert-base-uncased')
        self.embed_dim = 768

        for param in self.embedding.parameters():
            param.requires_grad = False
                

        '''Then Generate Attentional Vectors for the input tokens'''
        self.conv1 = nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 1 , padding = 0) # (B , 768 , 30) -> (B , 256 , 30)
        self.conv2 = nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 3 , padding=1) # (B , 768 , 30) -> (B , 256 , 30)
        self.conv3 = nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 5 , padding=2) # (B , 768 , 30) -> (B , 256 , 30)
        self.conv4 = nn.Conv1d(in_channels = self.embed_dim , out_channels = 256 , kernel_size = 7 , padding=3) # (B , 768 , 30) -> (B , 256 , 30)
        self.max_over_time_pooling = nn.MaxPool1d(kernel_size=256*4) # (B , 30 , 256*4) -> (B , 30)

        '''Then we use Bi-lstm to generate hidden states for the input tokens'''
        self.bilstm = nn.LSTM(input_size = self.embed_dim  , hidden_size = 256 , num_layers = 2 , bidirectional = True , batch_first = True) # (B , 30 , 768) -> (B , 30 , 512)

        '''Output Layer'''
        self.linear = nn.Linear(512 , num_labels)
    

    def forward(self , input_ids , bert_tokeniser = None):

        embeddings = self.embedding(input_ids)[0] # shape : B , 30 , 768
        

        # All conv outputs have shape : B , 256 , 30
        conv1_out = self.conv1(embeddings.permute(0,2,1)) 
        conv2_out = self.conv2(embeddings.permute(0,2,1))
        conv3_out = self.conv3(embeddings.permute(0,2,1))
        conv4_out = self.conv4(embeddings.permute(0,2,1))

        concat_conv_outs = torch.cat([conv1_out , conv2_out , conv3_out , conv4_out] , dim = 1) # shape : B , 256*4 , 30
        concat_conv_outs = concat_conv_outs.permute(0,2,1) # shape : B , 30 , 256*4
        max_pooled = self.max_over_time_pooling(concat_conv_outs) # shape : B , 30
        max_pooled = F.softmax(max_pooled , dim = 1) # shape : B , 30

        bilstm_outs , _ = self.bilstm(embeddings) # shape : B , 30 , 512
        
        '''Multiply the max_pooled and bilstm_outs to weigh the importance of each token'''
        attentional_vector = torch.mul(max_pooled , bilstm_outs) # shape : B , 30 , 512

        '''Now to get the final output we sum over the 30 tokens'''
        final_output = torch.sum(attentional_vector , dim = 1) # shape : B , 512
        final_output = self.linear(final_output) # shape : B , num_labels

        return final_output

    def convert_input_ids_to_glove_indices(input_ids, bert_tokenizer, glove_word_to_index):
        glove_indices = []
        for batch_input_ids in input_ids:
            glove_indices_batch = []
            for token_id in batch_input_ids:
                # Convert BERT token ID to token
                token = bert_tokenizer.convert_ids_to_tokens(token_id.item())
                # Find the closest matching token in the GloVe vocabulary
                glove_index = glove_word_to_index.get(token.lower(), None)
                if glove_index is not None:
                    glove_indices_batch.append(glove_index)
                else:
                    # Handle out-of-vocabulary tokens (e.g., assign a special index)
                    glove_indices_batch.append(UNKNOWN_INDEX)
            glove_indices.append(glove_indices_batch)
        return torch.tensor(glove_indices)


In [196]:
sample_inp = torch.randint(0 , 30522 , (32 , 30))
# model = KeyphraseClassificationLinear(3)
model = KeyphraseClassificationAttentionLSTM(3)
logits = model(sample_inp)
print(logits.shape)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([32, 3])


# Training

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
# model = KeyphraseClassificationLinear(3).to(device)
model = KeyphraseClassificationAttentionLSTM(3).to(device)
optimizer = torch.optim.Adam(model.parameters() , lr = 0.001)
criterion = nn.CrossEntropyLoss()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [199]:
def get_accuracy(model , data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for keyphrases , labels in data_loader:
            keyphrases = keyphrases.to(device)
            labels = labels.to(device)
            logits = model(keyphrases)
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [200]:
val_dataset = MyDataset('Data/dev' , label_dict)
val_dataloader = DataLoader(val_dataset , batch_size = batch_size , shuffle = False , collate_fn = val_dataset.collate_fn)

In [201]:
num_epochs = 20

for epoch in range(num_epochs):
    with tqdm.tqdm(train_dataloader, unit="batch") as tepoch:
        avg_loss = 0
        for tokenisedids , tagslist in tepoch:
            tokenisedids = tokenisedids.to(device)
            tagslist = tagslist.to(device)
            model.train()
            optimizer.zero_grad()
            out = model(tokenisedids)
            # print(out.size() , tagslist.size())
            loss = criterion(out, tagslist)
        
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() / len(train_dataloader)
            acc = get_accuracy(model , val_dataloader)
            tepoch.set_postfix(loss=avg_loss , val_acc = acc)

100%|██████████| 211/211 [03:05<00:00,  1.14batch/s, loss=0.974, val_acc=0.609]
100%|██████████| 211/211 [02:59<00:00,  1.17batch/s, loss=0.915, val_acc=0.685]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.878, val_acc=0.678]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.857, val_acc=0.698]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.844, val_acc=0.704]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.83, val_acc=0.735] 
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.822, val_acc=0.727]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.812, val_acc=0.72] 
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.795, val_acc=0.717]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.79, val_acc=0.72]  
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.776, val_acc=0.721]
100%|██████████| 211/211 [03:00<00:00,  1.17batch/s, loss=0.77, val_acc=0.711] 
100%|██████████| 211/211 [03:00<00:00,  

In [202]:
'''Save the model'''
torch.save(model.state_dict() , 'model_att_bert.pth')

With Just linear : 66.7 %

With attentional LSTM and BERT : 73 %