In [1]:
from transformers import BertTokenizer
import torch
import torch.nn as nn
from transformers import BertModel
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
# from crfseg import CRF
from torchcrf import CRF
import torch.nn.functional as F
import tqdm as tqdm

In [2]:
class MyDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        
        # iterate through the files in the data directory
        self.txtfiles = []
        self.annfiles = []

        for file in os.listdir(data_dir):
            if file.endswith(".txt"):
                self.txtfiles.append(file)
        
        self.tokeniser = BertTokenizer.from_pretrained('bert-base-uncased')
    
    def tokenise(self , text):
        tokens = []  # List to store tokens
        starting_offsets = []  # List to store starting offsets
        current_token = ''  # Variable to store current token
        offset = 0  # Starting offset

        for char in text:
            if char == ' ':
                if current_token:  # If token is not empty
                    tokens.append(current_token.lower())  # Append token in lowercase
                    starting_offsets.append(offset - len(current_token))  # Store starting offset
                    current_token = ''  # Reset current token
                offset += 1  # Move offset to next character
            else:
                current_token += char  # Append character to current token
                offset += 1  # Move offset to next character

        # Handling the last token if it exists after the loop ends
        if current_token:
            tokens.append(current_token.lower())  # Append token in lowercase
            starting_offsets.append(offset - len(current_token))  # Store starting offset

        return starting_offsets , tokens

    def __len__(self):
        return len(self.txtfiles)

    def __getitem__(self, index):
        txtfile = self.txtfiles[index]
        sampleid = txtfile.split(".")[0]
        
        # read the text file
        with open(os.path.join(self.data_dir, txtfile), 'r') as file:
            txt = file.read()
        
        # read the annotation file
        annfilename = sampleid + ".ann"
        with open(os.path.join(self.data_dir, annfilename), 'r') as file:
            ann = file.read()
        
        offsets , tokenisedtxt = self.tokenise(txt)
        tagslist = np.zeros(len(tokenisedtxt))
        # now iterate through the ann file , in each line , divide into spaces and get the last word 
        # make tagslist[i] = 1 if the word is in the tokenisedtxt
        for line in ann.split('\n'):
            if line == '':
                continue
            words = line.split()
            if words[0][0] != 'T':
                continue

            ssofset = words[2]
            endoffset = words[3]

            # add a 1 to each index of tagslist for indexes where offset is between ssofset and endoffset (including both)
            for i in range(len(offsets)):
                if offsets[i] >= int(ssofset) and offsets[i] <= int(endoffset):
                    tagslist[i] = 1
        
        tokenisedids = self.tokeniser.convert_tokens_to_ids(tokenisedtxt)
        return torch.tensor(tokenisedids), torch.tensor(tagslist)
    
    def collate_fn(self , batch):
        # batch is a list of tuples
        # each tuple has 2 tensors , one for tokenisedids and one for tagslist
        # we need to return a tensor of tokenisedids and a tensor of tagslist
        tokenisedids = []
        tagslist = []
        for tup in batch:
            tokenisedids.append(torch.tensor(tup[0]))
            tagslist.append(torch.tensor(tup[1]))
        
        tokenisedids = torch.nn.utils.rnn.pad_sequence(tokenisedids , batch_first=True , padding_value=0) 
        tagslist = torch.nn.utils.rnn.pad_sequence(tagslist , batch_first=True , padding_value=0)

        tokenisedids = tokenisedids.type(torch.LongTensor)
        tagslist = tagslist.type(torch.LongTensor)
        
        return tokenisedids , tagslist


In [3]:
train_dataset = MyDataset('Data/train2')
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True , collate_fn=train_dataset.collate_fn)

## HELLO SARTHAK

In [5]:
class BiLSTMCRF(nn.Module):
    def __init__(self, num_labels):
        super(BiLSTMCRF, self).__init__()
        
        # Load pre-trained SciBERT embeddings
        self.bert = BertModel.from_pretrained("allenai/scibert_scivocab_uncased")
        
        # BiLSTM layers
        self.bilstm = nn.LSTM(input_size=768, hidden_size=96, bidirectional=True, batch_first=True)
        self.bilstm2 = nn.LSTM(input_size=192, hidden_size=48, bidirectional=True, batch_first=True)
        self.bilstm3 = nn.LSTM(input_size=96, hidden_size=24, bidirectional=True, batch_first=True)
        
        # Linear layer for downsizing
        self.linear = nn.Linear(48, num_labels)
        
        # CRF layer
        self.crf = CRF(num_labels , batch_first=True)
        
    def forward(self, input_ids , gt_tags = None):
        # Get SciBERT embeddings
        bert_outputs = self.bert(input_ids)[0]
        
        # Apply BiLSTM layers
        lstm_out, _ = self.bilstm(bert_outputs)
        lstm_out, _ = self.bilstm2(lstm_out)
        lstm_out, _ = self.bilstm3(lstm_out)
        
        linear_out = self.linear(lstm_out)
        
        # Apply CRF layer
        if gt_tags is not None:
            loss = -self.crf(linear_out , gt_tags)
        else:
            crf_out = self.crf.decode(linear_out)
            return crf_out
        
        return loss

In [6]:
device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [17]:
model = BiLSTMCRF(5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
num_epochs = 20

for epoch in range(num_epochs):
    with tqdm.tqdm(train_dataloader, unit="batch") as tepoch:
        avg_loss = 0
        for tokenisedids , tagslist in tepoch:
            tokenisedids = tokenisedids.to(device)
            tagslist = tagslist.to(device)
            model.train()
            optimizer.zero_grad()
            loss = model(tokenisedids , tagslist)
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() / len(train_dataloader)
            tepoch.set_postfix(loss=avg_loss)

100%|██████████| 175/175 [00:38<00:00,  4.56batch/s, loss=181]
100%|██████████| 175/175 [00:43<00:00,  4.04batch/s, loss=179]
100%|██████████| 175/175 [00:43<00:00,  4.04batch/s, loss=178]
100%|██████████| 175/175 [00:44<00:00,  3.97batch/s, loss=177]
100%|██████████| 175/175 [00:43<00:00,  3.98batch/s, loss=176] 
100%|██████████| 175/175 [00:42<00:00,  4.14batch/s, loss=175]
100%|██████████| 175/175 [00:43<00:00,  4.00batch/s, loss=174] 
100%|██████████| 175/175 [00:43<00:00,  4.05batch/s, loss=173] 
100%|██████████| 175/175 [00:43<00:00,  4.01batch/s, loss=172] 
100%|██████████| 175/175 [00:44<00:00,  3.96batch/s, loss=171] 
100%|██████████| 175/175 [00:42<00:00,  4.08batch/s, loss=170] 
100%|██████████| 175/175 [00:42<00:00,  4.14batch/s, loss=169] 
100%|██████████| 175/175 [00:43<00:00,  4.05batch/s, loss=168] 
100%|██████████| 175/175 [00:43<00:00,  4.01batch/s, loss=169] 
100%|██████████| 175/175 [00:43<00:00,  4.05batch/s, loss=168] 
100%|██████████| 175/175 [00:43<00:00,  3.99b

In [19]:
# save the model
torch.save(model.state_dict(), 'model2.pth')