In [7]:
import pandas as pd
import torch
import re
from tqdm.notebook import trange, tqdm
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# Labels

In [8]:
from labels import get_labels

_, labels, id2label, label2id = get_labels()


# Tokenizer

In [9]:
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

for id in [1, 3, 0, 2, 4]:
    print(f"{id}: {tokenizer.decode(id)}")

1: [UNK]
3: [SEP]
0: [PAD]
2: [CLS]
4: [MASK]


In [10]:
# relations = [f"[RELATION{i}]" for i in range(1, 10)]
# special_tokens_dict = {'additional_special_tokens': ['[ENTITY]', '[/ENTITY]', '[RELATION]', '[/RELATION]', '[SRC]', '[TGT]'] + relations}
# special_tokens_dict = {'additional_special_tokens': ['[B-Gene]', '[B-Disease]', '[B-Chemical]', '[IN]', '[OUT]']}
# num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# print('We have added', num_added_toks, 'tokens')
# tokenizer.save_pretrained("NER_model_tokenizer")
# bert_model.resize_token_embeddings(len(tokenizer))

# dataset pre-processing

In [11]:
from data_preprocessing import all_line_of_pmid, NER_preprocess_function, make_dataset

In [12]:
train_file_path = 'data/BioRED/processed/train.tsv'
valid_file_path = 'data/BioRED/processed/dev.tsv'

In [13]:
pandas_data = pd.read_csv(train_file_path, delimiter="\t", header=None)
pmid, start, end = all_line_of_pmid(pandas_data, 0)
pandas_data.iloc[end - 5:end + 3, :]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
23,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3630,rs74805019,False,1,Hepatocyte nuclear factor-6 : associations bet...,,
24,10491763,ChemicalEntity,GeneOrGeneProduct,D005947,6927,False,2,Hepatocyte nuclear factor-6 : associations bet...,,
25,10491763,ChemicalEntity,GeneOrGeneProduct,D005947,3630,True,0,Hepatocyte nuclear factor-6 : associations bet...,Positive_Correlation,No
26,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3630,3651,False,2,Hepatocyte nuclear factor-6 : associations bet...,,
27,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3175,3651,False,1,@GeneOrGeneProductSrc$ Hepatocyte nuclear fact...,,
28,10661407,ChemicalEntity,GeneOrGeneProduct,D008358,50489,True,0,@GeneOrGeneProductTgt$ Langerin @/GeneOrGenePr...,Bind,Novel
29,10788334,GeneOrGeneProduct,GeneOrGeneProduct,672,c|DEL|4153|A,False,5,Founder mutations in the @GeneOrGeneProductSrc...,,
30,10788334,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D001943,rs28897672,False,4,Founder mutations in the BRCA1 gene in Polish ...,Positive_Correlation,Novel


In [14]:
pandas_data.iloc[0: 27, :]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,10491763,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D003924,3630,True,0,Hepatocyte nuclear factor-6 : associations bet...,,
1,10491763,ChemicalEntity,GeneOrGeneProduct,D005947,rs74805019,True,0,Hepatocyte nuclear factor-6 : associations bet...,,
2,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3175,3172,False,1,@GeneOrGeneProductSrc$ Hepatocyte nuclear fact...,,
3,10491763,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D003924,6927,True,0,Hepatocyte nuclear factor-6 : associations bet...,,
4,10491763,ChemicalEntity,GeneOrGeneProduct,D005947,3651,False,2,Hepatocyte nuclear factor-6 : associations bet...,,
5,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3172,rs74805019,False,2,Hepatocyte nuclear factor-6 : associations bet...,,
6,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3630,3172,False,2,Hepatocyte nuclear factor-6 : associations bet...,,
7,10491763,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D003924,rs74805019,True,0,Hepatocyte nuclear factor-6 : associations bet...,,
8,10491763,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D003924,3651,True,0,Hepatocyte nuclear factor-6 : associations bet...,,
9,10491763,GeneOrGeneProduct,GeneOrGeneProduct,3175,3630,True,0,@GeneOrGeneProductSrc$ Hepatocyte nuclear fact...,,


In [45]:
# iterate pandas_data and print the column 3 for each row


start = 0
while start < (len(pandas_data) - 1):
    pmid, start, end = all_line_of_pmid(pandas_data, start)

    # for i in range(start, end):
    #     pandas_data.iloc[i, 3]
    # filter pandas_data.iloc[start:end, :] with rows when colunm 8 is not None
    rows = []
    for i in range(start, end):
        if pandas_data.iloc[i, 8] != "None":
            rows.append(i)
    
    if len(rows) > 1:
        for i, row in enumerate(rows):
            try:
                for row2 in rows[i + 1:]:
                    if pandas_data.iloc[row, 3] == pandas_data.iloc[row2, 4]:
                        if pandas_data.iloc[row, 8] == pandas_data.iloc[row2, 8]:
                            print(f"pinch! with same relation but the same entity both as SRC and TGT: row:{row} and {row2}")

                    if pandas_data.iloc[row, 3] == pandas_data.iloc[row2, 3]:
                        if pandas_data.iloc[row, 8] == pandas_data.iloc[row2, 8]:
                            print(f"warning, an entity has the same relation for different entities: row:{row} and {row2}")
            except:
                pass

    # break
   


    start = end

pinch! with same relation but the same entity both as SRC and TGT: row:12 and 20
pinch! with same relation but the same entity both as SRC and TGT: row:588 and 593
pinch! with same relation but the same entity both as SRC and TGT: row:589 and 600
pinch! with same relation but the same entity both as SRC and TGT: row:590 and 600
pinch! with same relation but the same entity both as SRC and TGT: row:635 and 642
pinch! with same relation but the same entity both as SRC and TGT: row:635 and 646
pinch! with same relation but the same entity both as SRC and TGT: row:825 and 873
pinch! with same relation but the same entity both as SRC and TGT: row:828 and 873
pinch! with same relation but the same entity both as SRC and TGT: row:830 and 873
pinch! with same relation but the same entity both as SRC and TGT: row:1394 and 1395
pinch! with same relation but the same entity both as SRC and TGT: row:1612 and 1970
pinch! with same relation but the same entity both as SRC and TGT: row:1925 and 1978


In [17]:
# no such situation:
#     id1  id2  None
#     but
#     id2  id1  a certain relation
start = 0
while start < (len(pandas_data) - 1):
    pmid, start, end = all_line_of_pmid(pandas_data, start)

    # for i in range(start, end):
    #     pandas_data.iloc[i, 3]
    # filter pandas_data.iloc[start:end, :] with rows when colunm 8 is not None
    rows = []
    test_dict = {}
    for i in range(start, end):
        if pandas_data.iloc[i, 3] not in test_dict.keys():
            test_dict[pandas_data.iloc[i, 3]] = {}
        
        test_dict[pandas_data.iloc[i, 3]][pandas_data.iloc[i, 4]] = pandas_data.iloc[i, 8]

    for key in test_dict.keys():
        for key2 in test_dict[key].keys():
            if test_dict[key][key2] != "None":
                try:
                    if test_dict[key2][key] != "None":
                        print(111)

                except:
                    pass
   
    start = end

In [18]:
# No multiple relations for the same entity pair
start = 0
while start < (len(pandas_data) - 1):
    pmid, start, end = all_line_of_pmid(pandas_data, start)

    # for i in range(start, end):
    #     pandas_data.iloc[i, 3]
    # filter pandas_data.iloc[start:end, :] with rows when colunm 8 is not None
    rows = []
    test_dict = {}
    for i in range(start, end):
        if pandas_data.iloc[i, 3] not in test_dict.keys():
            test_dict[pandas_data.iloc[i, 3]] = {}
        
        if pandas_data.iloc[i, 4] not in test_dict[pandas_data.iloc[i, 3]].keys():
            test_dict[pandas_data.iloc[i, 3]][pandas_data.iloc[i, 4]] = pandas_data.iloc[i, 8]
        else:
            print("two ids have multiple reations")

    start = end

In [44]:
pandas_data.iloc[[12, 20], :]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
12,10491763,DiseaseOrPhenotypicFeature,GeneOrGeneProduct,D003924,3175,True,0,@GeneOrGeneProductTgt$ Hepatocyte nuclear fact...,Association,No
20,10491763,ChemicalEntity,DiseaseOrPhenotypicFeature,D005947,D003924,True,0,Hepatocyte nuclear factor-6 : associations bet...,Association,No


In [8]:
train_data = make_dataset(train_file_path, lower=True, ignore_relations=['None', 'Association'], NER=True, NER_in=True)
valid_data = make_dataset(valid_file_path, lower=True, ignore_relations=['None', 'Association'], NER=True, NER_in=True)

In [9]:
from datasets import DatasetDict, Dataset
train_dataset_raw = Dataset.from_dict(train_data)
valid_dataset_raw = Dataset.from_dict(valid_data)

In [10]:
dataset = DatasetDict({
    "train": train_dataset_raw,
    "valid": valid_dataset_raw
})

In [11]:
dataset

DatasetDict({
    train: Dataset({
        features: ['pmids', 'inputs', 'outputs'],
        num_rows: 398
    })
    valid: Dataset({
        features: ['pmids', 'inputs', 'outputs'],
        num_rows: 98
    })
})

In [12]:
label2id

{'[PAD]': 0,
 '[STOP]': 1,
 '[CLS]': 2,
 '[B-Gene]': 3,
 '[B-Disease]': 4,
 '[B-Chemical]': 5,
 '[IN-Gene]': 6,
 '[IN-Disease]': 7,
 '[IN-Chemical]': 8,
 '[OUT]': 9}

In [13]:
id2label

{0: '[PAD]',
 1: '[STOP]',
 2: '[CLS]',
 3: '[B-Gene]',
 4: '[B-Disease]',
 5: '[B-Chemical]',
 6: '[IN-Gene]',
 7: '[IN-Disease]',
 8: '[IN-Chemical]',
 9: '[OUT]'}

In [14]:
# tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["inputs", "outputs", "pmids"])
tokenized_datasets = dataset.map(lambda example: NER_preprocess_function(example, tokenizer=tokenizer, bert=True, NER_in=True), batched=False, remove_columns=["inputs", "outputs", "pmids"])

Map:   0%|          | 0/398 [00:00<?, ? examples/s]

Map:   0%|          | 0/98 [00:00<?, ? examples/s]

In [15]:
len(tokenized_datasets["train"]["input_ids"][0])

512

In [16]:
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'labels'])

In [17]:
tokenized_datasets['train']['labels'][3]

tensor([2, 5, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 4, 7, 7, 7, 9, 5, 8, 8, 8, 8, 9,
        9, 9, 5, 8, 9, 9, 9, 9, 9, 9, 4, 7, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 4, 7, 7, 7, 9, 9, 9, 9, 5, 8, 8, 8,
        8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 5, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 9, 9,
        5, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 5, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 8, 9, 5,
        8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 8, 8, 8, 8,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 8, 8, 8, 8, 9, 9, 9, 9,
        9, 9, 9, 9, 5, 8, 8, 4, 7, 7, 9, 9, 9, 9, 9, 9, 9, 5, 8, 8, 8, 8, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 4, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        5, 8, 8, 8, 8, 9, 9, 9, 4, 7, 7,

# Evaluate

In [18]:
import evaluate

metric = evaluate.load("seqeval")

In [19]:
preds = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
ref = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
metric.compute(predictions=preds, references=ref)

{'MISC': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'PER': {'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'number': 1},
 'overall_precision': 1.0,
 'overall_recall': 1.0,
 'overall_f1': 1.0,
 'overall_accuracy': 1.0}

In [20]:
import numpy as np


def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[id2label[l.item()] for l in label if l.item() != 0] for label in labels]
    true_predictions = [
        [id2label[p.item()] for (p, l) in zip(prediction, label) if l.item() != 0]
        for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

# model

In [None]:
# PEFT
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


In [None]:
class CrossAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_k = nn.Linear(dim, inner_dim , bias=False)
        self.to_v = nn.Linear(dim, inner_dim , bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x_qkv):
        b, n, _, h = *x_qkv.shape, self.heads

        k = self.to_k(x_qkv)
        # print(f"k: {k.shape}")
        k = rearrange(k, 'b n (h d) -> b h n d', h = h)
        # print(f"after rerange k: {k.shape}")

        v = self.to_v(x_qkv)
        v = rearrange(v, 'b n (h d) -> b h n d', h = h)

        q = self.to_q(x_qkv[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)



        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

In [None]:
from transformers.modeling_outputs import Seq2SeqLMOutput
class decoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_lstm_layers, bidirectional=False):
        super(decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_lstm_layers, bidirectional=bidirectional, batch_first=False)
        # if bidirectional:
            # self.linear = nn.Linear(hidden_size * 2, hidden_size)
        # else:
            # self.linear = nn.Linear(hidden_size, hidden_size)
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=2, dropout=0.1, batch_first=False)
        self.norm = nn.LayerNorm(hidden_size)
        for p in self.attention.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for p in self.lstm.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_ids, m):
        x = self.embedding(input_ids)
        x = self.lstm(x)[0]
        # x = self.linear(x)
        x = self.attention(x, m, m)[0]
        x = self.norm(x)
        return x


class Encoder_Decoder_model(nn.Module):
    def __init__(self, encoder_model, vocab_size, hidden_size):
        # input for both encoder and decoder: (batch_size, seq_length <= 512, 768)
        super(Encoder_Decoder_model, self).__init__()
        self.vocab_size = vocab_size
        # encoder input and output: (batch_size, seq_length <= 512, 768)
        self.encoder = encoder_model
        self.decoder = decoder(vocab_size, 768, hidden_size, num_lstm_layers=1, bidirectional=False)
        # copy the embedding weight of the encoder to the decoder
        self.decoder.embedding.weight = self.encoder.get_input_embeddings().weight

        self.proj = nn.Linear(hidden_size, vocab_size)        
        for p in self.proj.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # for p in self.decoder_embedding.parameters():
        #     if p.dim() > 1:
        #         nn.init.xavier_uniform_(p)

    def forward(self, 
                input_ids, 
                attention_mask=None, 
                token_type_ids=None, 
                decoder_input_ids=None,
                labels=None,
                return_dict=None):
        # encoder
        m = self.encoder(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True)[0]
        # decoder
        x = self.decoder(decoder_input_ids, m)
        x = self.proj(x)

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = x
            loss_fct = nn.CrossEntropyLoss()
            # print("logits.view(-1, self.vocab_size): ", logits.view(-1, self.vocab_size).shape)
            # print("labels.view(-1): ", labels.view(-1).shape)
            loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))
            # print("loss: ", loss)
        
        # for i in range(len(labels)):
        #     label = labels[i]
        #     if loss.ndim == 0:
        #         loss += los.item()
        #     else:
        #         loss[label.item()] += los[i].item()

        if not return_dict:
            if loss is not None:
                return tuple((loss, x, m))
            else:
                return tuple((x, m))
            
        return Seq2SeqLMOutput(
            loss=loss,
            logits=x
        )



In [None]:

class decoder(nn.Module):
    def __init__(self, decoder_vocab_size, hidden_size, num_lstm_layers, dropout=0., bidirectional=False):
        super(decoder, self).__init__()
        self.embedding = nn.Embedding(decoder_vocab_size, hidden_size)
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_lstm_layers, bidirectional=bidirectional, batch_first=True)
        if bidirectional:
            self.linear = nn.Linear(hidden_size * 2, hidden_size)
        else:
            self.linear = nn.Linear(hidden_size, hidden_size)
        self.attention = CrossAttention(dim=hidden_size, dropout=dropout)

        # initialization
        for p in self.embedding.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for p in self.attention.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for p in self.lstm.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input_ids, m):
        x = self.embedding(input_ids)
        x = self.lstm(x)
        # x = self.linear(x[0][:, -1, :].unsqueeze(1))
        x = self.linear(x[0])
        # print(f"x: {x.shape}")
        # make a outputs tensor with the same shape as the x but all zeros with float type
        output = torch.zeros(x.shape, dtype=torch.float).to(x.device)
        # using the [cls] token as representation of the sentences
        for n in range(x.shape[1]):
            x_qkv = torch.cat((x[:, n, :].unsqueeze(1), x[:, n, :].unsqueeze(1), m[:, 0, :].unsqueeze(1)), dim=1)
            # print(f"x_qkv: {x_qkv.shape}")
            # the output of self.attention(x_qkv) will be (batch_size, 1, hidden_size), replace the value of output[:, n, :]
            output[:, n, :] = self.attention(x_qkv).squeeze(1)
            # print(output)
        return output

In [None]:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids

In [None]:
from transformers.modeling_outputs import Seq2SeqLMOutput
class NER_Encoder_Decoder_model(nn.Module):
    def __init__(self, encoder_model, decoder_vocab_size, hidden_size):
        # input for both encoder and decoder: (batch_size, seq_length <= 512, 768)
        super(NER_Encoder_Decoder_model, self).__init__()
        self.decoder_vocab_size = decoder_vocab_size
        # encoder input and output: (batch_size, seq_length <= 512, 768)
        self.encoder = encoder_model
        self.decoder = decoder(decoder_vocab_size, hidden_size, num_lstm_layers=1, bidirectional=False, dropout=0.1)
        # copy the embedding weight of the encoder to the decoder
        # self.decoder.embedding.weight = self.encoder.get_input_embeddings().weight

        self.proj = nn.Linear(hidden_size, decoder_vocab_size)
        for p in self.proj.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, 
                input_ids, 
                attention_mask=None, 
                token_type_ids=None, 
                decoder_input_ids=None,
                labels=None,
                return_dict=None):
        # encoder
        m = self.encoder(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True)
        # decoder
        if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, 0, 2)
                # print(f"no decoder inputs, \n decoder_input_ids: {decoder_input_ids}\nlabels: {labels}")
        x = self.decoder(decoder_input_ids, m[0])
        # print("after decoder: ", x.shape)
        x = self.proj(x)

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = x
            loss_fct = nn.CrossEntropyLoss()
            # print("logits.view(-1, self.vocab_size): ", logits.view(-1, self.decoder_vocab_size).shape)
            # print("labels.view(-1): ", labels.view(-1).shape)
            loss = loss_fct(logits.view(-1, self.decoder_vocab_size), labels.view(-1))
            # print("loss: ", loss)
        
        # for i in range(len(labels)):
        #     label = labels[i]
        #     if loss.ndim == 0:
        #         loss += los.item()
        #     else:
        #         loss[label.item()] += los[i].item()

        if not return_dict:
            if loss is not None:
                return tuple((loss, x, m))
            else:
                return tuple((x, m))
            
        return Seq2SeqLMOutput(
            loss=loss,
            logits=x
        )

In [None]:
# from peft import LoraConfig, get_peft_model 


# class Encoder_Decoder_model_peft(nn.Module):
#     def __init__(self, encoder_model, vocab_size, hidden_size):
#         super(Encoder_Decoder_model_peft, self).__init__()
        
#         # encoder input and output: (batch_size, seq_length <= 512, 768)
#         self.encoder = encoder_model
        
#         self.decoder = decoder(vocab_size, hidden_size)

#         self.decoder.embedding.weight = self.encoder.get_input_embeddings().weight

#         # for p in self.decoder_embedding.parameters():
#         #     if p.dim() > 1:
#         #         nn.init.xavier_uniform_(p)

        
#         # initialization of peft

#         for name, param in self.encoder.named_parameters():
#             # print(name)
#             # Freeze the parameters except for embed_tokens and embed_positions
#             if 'word_embeddings' not in name and 'position_embeddings' not in name and 'token_type_embeddings' not in name:
#                 param.requires_grad = False
#             # else:
#                 # print("2")

#             if param.ndim == 1:
#                 # cast the small parameters (e.g. layernorm) to fp32 for stability
#                 param.data = param.data.to(torch.float32)



#         self.encoder.gradient_checkpointing_enable()  # reduce number of stored activations
#         self.encoder.enable_input_require_grads()

#         config = LoraConfig(
#             r=16,
#             lora_alpha=32,
#             target_modules=["query", "value"],
#             lora_dropout=0.02,
#             bias="none",
#             # task_type="SEQ_2_SEQ_LM"
#         )

#         self.encoder = get_peft_model(self.encoder, config)
#         print_trainable_parameters(self.encoder)

#     def forward(self, input_ids, attention_mask, token_type_ids, labels):
#         # encoder
#         m = self.encoder(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)['last_hidden_state']
#         # decoder
#         x = self.decoder(labels, m)[0]
#         return x

In [37]:
bert_model = BertModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
NER_model = NER_Encoder_Decoder_model(bert_model, len(NER_id_to_tag), 768)
print_trainable_parameters(NER_model)

trainable params: 116383496 || all params: 116383496 || trainable%: 100.0


In [38]:
# have a random tensor with long type and for the BERT model
input_ids = torch.randint(0, 1000, (2, 3), dtype=torch.long)
# adding pad token to the input tensor
input_ids_w_pad = torch.cat((input_ids, torch.zeros((2, 512 - 3), dtype=torch.long)), dim=1)
# have a mask tensor with long type and for the BERT model
attention_mask = torch.ones((2, 3), dtype=torch.long)
# adding pad token to the mask tensor
attention_mask_w_pad = torch.cat((attention_mask, torch.zeros((2, 512 - 3), dtype=torch.long)), dim=1)
# decoder_input_ids = torch.randint(0, 8, (2, 4), dtype=torch.long)
label = torch.randint(0, 4, (2, 5), dtype=torch.long)
decoder_input_ids = shift_tokens_right(label, 0, 2)
# output the last hidden state of the BERT model
with torch.no_grad():
    NER_model.eval()
    # encoder_outputs = NER_model.encoder(input_ids_w_pad, attention_mask = attention_mask_w_pad, return_dict=True)
    # decoder_outputs = NER_model.decoder(input_ids=decoder_input_ids, m=encoder_outputs[0])
    outputs = NER_model(input_ids_w_pad, attention_mask = attention_mask_w_pad, labels=label, return_dict=True)


# Classic Trainer

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [40]:
def train_loop(model, df_train, df_val):

    train_dataset = DataSequence(df_train)
    val_dataset = DataSequence(df_val)

    train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, num_workers=4, batch_size=BATCH_SIZE)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    optimizer = nn.SGD(model.parameters(), lr=LEARNING_RATE)

    if use_cuda:
        model = model.cuda()

    best_acc = 0
    best_loss = 1000

    for epoch_num in range(EPOCHS):

        total_acc_train = 0
        total_loss_train = 0

        model.train()

        for train_data, train_label in tqdm(train_dataloader):

            train_label = train_label.to(device)
            mask = train_data['attention_mask'].squeeze(1).to(device)
            input_id = train_data['input_ids'].squeeze(1).to(device)

            optimizer.zero_grad()
            loss, logits = model(input_id, mask, train_label)

            for i in range(logits.shape[0]):

              logits_clean = logits[i][train_label[i] != -100]
              label_clean = train_label[i][train_label[i] != -100]

              predictions = logits_clean.argmax(dim=1)
              acc = (predictions == label_clean).float().mean()
              total_acc_train += acc
              total_loss_train += loss.item()

            loss.backward()
            optimizer.step()

        model.eval()

        total_acc_val = 0
        total_loss_val = 0

        for val_data, val_label in val_dataloader:

            val_label = val_label.to(device)
            mask = val_data['attention_mask'].squeeze(1).to(device)
            input_id = val_data['input_ids'].squeeze(1).to(device)

            loss, logits = model(input_id, mask, val_label)

            for i in range(logits.shape[0]):

              logits_clean = logits[i][val_label[i] != -100]
              label_clean = val_label[i][val_label[i] != -100]

              predictions = logits_clean.argmax(dim=1)
              acc = (predictions == label_clean).float().mean()
              total_acc_val += acc
              total_loss_val += loss.item()

        val_accuracy = total_acc_val / len(df_val)
        val_loss = total_loss_val / len(df_val)

        print(
            f'Epochs: {epoch_num + 1} | Loss: {total_loss_train / len(df_train): .3f} | Accuracy: {total_acc_train / len(df_train): .3f} | Val_Loss: {total_loss_val / len(df_val): .3f} | Accuracy: {total_acc_val / len(df_val): .3f}')

LEARNING_RATE = 5e-3
EPOCHS = 5
BATCH_SIZE = 2

model = BertModel()
train_loop(model, df_train, df_val)

In [None]:
labels.shape

torch.Size([1])

# S2S Trainer

In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import default_data_collator

In [None]:
from transformers import EarlyStoppingCallback

In [None]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="PubmedBERT-CrossAttention-LSTM-NER",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [45]:
NER_model.load_state_dict(torch.load("NER_model/final-afterCrossAttentionModified/pytorch_model.bin"))

<All keys matched successfully>

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="NER_model",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    # per_device_train_batch_size=4,
    # per_device_eval_batch_size=4,
    weight_decay=0.01,
    num_train_epochs=50,
    predict_with_generate=True,
    fp16=False,
    report_to="wandb",
    remove_unused_columns=False,
    save_steps=500,
    # push_to_hub=True,
)
# early_stop = EarlyStoppingCallback(2, 1.0)

trainer = Seq2SeqTrainer(
    model=NER_model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    # callbacks=[early_stop]
    # compute_metrics=compute_metrics,
)

trainer.train()

 10%|█         | 500/5000 [07:10<1:02:40,  1.20it/s]

{'loss': 0.3041, 'learning_rate': 1.8e-05, 'epoch': 5.0}


                                                    
 10%|█         | 500/5000 [07:16<1:02:40,  1.20it/s]

{'eval_loss': 0.3470618724822998, 'eval_runtime': 6.9428, 'eval_samples_per_second': 14.115, 'eval_steps_per_second': 3.601, 'epoch': 5.0}


 20%|██        | 1000/5000 [14:37<55:44,  1.20it/s] 

{'loss': 0.3005, 'learning_rate': 1.6000000000000003e-05, 'epoch': 10.0}


                                                   
 20%|██        | 1000/5000 [14:44<55:44,  1.20it/s]

{'eval_loss': 0.346300333738327, 'eval_runtime': 7.0295, 'eval_samples_per_second': 13.941, 'eval_steps_per_second': 3.556, 'epoch': 10.0}


 30%|███       | 1500/5000 [22:09<49:01,  1.19it/s]  

{'loss': 0.2977, 'learning_rate': 1.4e-05, 'epoch': 15.0}


                                                   
 30%|███       | 1500/5000 [22:16<49:01,  1.19it/s]

{'eval_loss': 0.34617459774017334, 'eval_runtime': 6.9766, 'eval_samples_per_second': 14.047, 'eval_steps_per_second': 3.583, 'epoch': 15.0}


 40%|████      | 2000/5000 [29:44<41:51,  1.19it/s]  

{'loss': 0.2947, 'learning_rate': 1.2e-05, 'epoch': 20.0}


                                                   
 40%|████      | 2000/5000 [29:51<41:51,  1.19it/s]

{'eval_loss': 0.3473012447357178, 'eval_runtime': 6.8936, 'eval_samples_per_second': 14.216, 'eval_steps_per_second': 3.627, 'epoch': 20.0}


 50%|█████     | 2500/5000 [37:14<34:59,  1.19it/s]  

{'loss': 0.292, 'learning_rate': 1e-05, 'epoch': 25.0}


                                                   
 50%|█████     | 2500/5000 [37:21<34:59,  1.19it/s]

{'eval_loss': 0.348152756690979, 'eval_runtime': 6.9325, 'eval_samples_per_second': 14.136, 'eval_steps_per_second': 3.606, 'epoch': 25.0}


 60%|██████    | 3000/5000 [44:47<27:53,  1.20it/s]  

{'loss': 0.2896, 'learning_rate': 8.000000000000001e-06, 'epoch': 30.0}


                                                   
 60%|██████    | 3000/5000 [44:54<27:53,  1.20it/s]

{'eval_loss': 0.34909626841545105, 'eval_runtime': 6.8855, 'eval_samples_per_second': 14.233, 'eval_steps_per_second': 3.631, 'epoch': 30.0}


 70%|███████   | 3500/5000 [52:16<20:50,  1.20it/s]  

{'loss': 0.2869, 'learning_rate': 6e-06, 'epoch': 35.0}


                                                   
 70%|███████   | 3500/5000 [52:23<20:50,  1.20it/s]

{'eval_loss': 0.34761694073677063, 'eval_runtime': 6.895, 'eval_samples_per_second': 14.213, 'eval_steps_per_second': 3.626, 'epoch': 35.0}


 80%|████████  | 4000/5000 [59:46<13:54,  1.20it/s]  

{'loss': 0.2848, 'learning_rate': 4.000000000000001e-06, 'epoch': 40.0}


                                                   
 80%|████████  | 4000/5000 [59:53<13:54,  1.20it/s]

{'eval_loss': 0.3482430875301361, 'eval_runtime': 6.9055, 'eval_samples_per_second': 14.192, 'eval_steps_per_second': 3.62, 'epoch': 40.0}


 90%|█████████ | 4500/5000 [1:07:16<06:59,  1.19it/s]  

{'loss': 0.2833, 'learning_rate': 2.0000000000000003e-06, 'epoch': 45.0}


                                                     
 90%|█████████ | 4500/5000 [1:07:23<06:59,  1.19it/s]

{'eval_loss': 0.35049009323120117, 'eval_runtime': 6.8983, 'eval_samples_per_second': 14.206, 'eval_steps_per_second': 3.624, 'epoch': 45.0}


100%|██████████| 5000/5000 [1:14:46<00:00,  1.20it/s]

{'loss': 0.2822, 'learning_rate': 0.0, 'epoch': 50.0}


                                                     
100%|██████████| 5000/5000 [1:14:52<00:00,  1.20it/s]

{'eval_loss': 0.3497946262359619, 'eval_runtime': 6.8795, 'eval_samples_per_second': 14.245, 'eval_steps_per_second': 3.634, 'epoch': 50.0}


100%|██████████| 5000/5000 [1:15:02<00:00,  1.11it/s]

{'train_runtime': 4502.7448, 'train_samples_per_second': 4.42, 'train_steps_per_second': 1.11, 'train_loss': 0.2915934600830078, 'epoch': 50.0}





TrainOutput(global_step=5000, training_loss=0.2915934600830078, metrics={'train_runtime': 4502.7448, 'train_samples_per_second': 4.42, 'train_steps_per_second': 1.11, 'train_loss': 0.2915934600830078, 'epoch': 50.0})

In [47]:
import wandb
wandb.finish()

0,1
eval/loss,▂▁▁▃▄▆▃▄█▇
eval/runtime,▄█▆▂▃▁▂▂▂▁
eval/samples_per_second,▅▁▃▇▅█▇▇▇█
eval/steps_per_second,▅▁▃▇▅█▇▇▇█
train/epoch,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇███
train/global_step,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇███
train/learning_rate,█▇▆▆▅▄▃▃▂▁
train/loss,█▇▆▅▄▃▃▂▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.34979
eval/runtime,6.8795
eval/samples_per_second,14.245
eval/steps_per_second,3.634
train/epoch,50.0
train/global_step,5000.0
train/learning_rate,0.0
train/loss,0.2822
train/total_flos,0.0
train/train_loss,0.29159


In [48]:
trainer.save_model("NER_model/afterCrossAttentionModified-100epochs")
# tokenizer.save_pretrained("NED_model_tokenizer")

# only BERT

In [65]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)


model.to("cuda")

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForToken

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [21]:
from transformers import AutoModelForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput
class Model(nn.Module):
    def __init__(self, input_size=768, hidden_size=128, num_layers=1, dropout=0.1, num_classes=None):
        super(Model,self).__init__()
        self.num_class = num_classes
        self.bert=BertModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
        for param in self.bert.parameters():
            param.requires_grad = True
        self.lstm=nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.dropout=nn.Dropout(dropout)

        self.fc=nn.Linear(hidden_size*2, num_classes)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, return_dict=None):

        encoder= self.bert(input_ids, attention_mask, token_type_ids, return_dict=True)
        out = self.lstm(encoder[0])
        out = self.dropout(out[0])
        out = self.fc(out)



        loss = None
        if labels is not None:
            logits = out
            loss_fct = nn.CrossEntropyLoss()
            # print("logits.view(-1, self_num_class): ", logits.view(-1, self.decoder_vocab_size).shape)
            # print("labels.view(-1): ", labels.view(-1).shape)
            loss = loss_fct(logits.view(-1, self.num_class), labels.view(-1))
        
        if not return_dict:
            # output = tuple((logits, out))
            return tuple((loss, out)) if loss is not None else out

        return TokenClassifierOutput(
            loss=loss,
            logits=logits
        )

In [22]:
model = Model(num_classes=len(labels))
# model.to("cuda")



In [23]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenized_datasets["train"][i] for i in range(2)])
batch["labels"]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([[2, 3, 6,  ..., 0, 0, 0],
        [2, 3, 6,  ..., 0, 0, 0]])

In [24]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="PubmedBERT-FT-NER",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="NERin-lstm_15epochs",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [25]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="NER_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=15,
    weight_decay=0.01,
    report_to="wandb",
    # per_device_train_batch_size=4,
    # per_device_eval_batch_size=4,
    auto_find_batch_size=True,
    # push_to_hub=True,
)

In [26]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
trainer.train()



  0%|          | 0/750 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.513088583946228, 'eval_precision': 0.14133627019089573, 'eval_recall': 0.062116811874798325, 'eval_f1': 0.08630351939027124, 'eval_accuracy': 0.7910879463902174, 'eval_runtime': 1.9631, 'eval_samples_per_second': 49.922, 'eval_steps_per_second': 6.622, 'epoch': 1.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.3456386625766754, 'eval_precision': 0.4476504534212696, 'eval_recall': 0.3504356243949661, 'eval_f1': 0.3931221719457013, 'eval_accuracy': 0.8495088633695123, 'eval_runtime': 1.9713, 'eval_samples_per_second': 49.713, 'eval_steps_per_second': 6.595, 'epoch': 2.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.25266504287719727, 'eval_precision': 0.691676845227709, 'eval_recall': 0.6395611487576638, 'eval_f1': 0.6645988766870651, 'eval_accuracy': 0.9039777771413844, 'eval_runtime': 1.9893, 'eval_samples_per_second': 49.264, 'eval_steps_per_second': 6.535, 'epoch': 3.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.20617817342281342, 'eval_precision': 0.7571142901092139, 'eval_recall': 0.7941271377863827, 'eval_f1': 0.7751791479644066, 'eval_accuracy': 0.9349064978951287, 'eval_runtime': 1.9755, 'eval_samples_per_second': 49.609, 'eval_steps_per_second': 6.581, 'epoch': 4.0}
{'loss': 0.4008, 'learning_rate': 6.640000000000001e-06, 'epoch': 5.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.1704038679599762, 'eval_precision': 0.7800031051079025, 'eval_recall': 0.8105840593739916, 'eval_f1': 0.7949996043990822, 'eval_accuracy': 0.9493685386179443, 'eval_runtime': 1.9848, 'eval_samples_per_second': 49.376, 'eval_steps_per_second': 6.55, 'epoch': 5.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15812565386295319, 'eval_precision': 0.7809567758806337, 'eval_recall': 0.819135204904808, 'eval_f1': 0.7995905189384991, 'eval_accuracy': 0.9522323090581059, 'eval_runtime': 1.9965, 'eval_samples_per_second': 49.085, 'eval_steps_per_second': 6.511, 'epoch': 6.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15244361758232117, 'eval_precision': 0.7838669950738916, 'eval_recall': 0.8215553404323975, 'eval_f1': 0.8022687884039703, 'eval_accuracy': 0.9530914401901543, 'eval_runtime': 1.9836, 'eval_samples_per_second': 49.404, 'eval_steps_per_second': 6.554, 'epoch': 7.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.9865, 'eval_samples_per_second': 49.332, 'eval_steps_per_second': 6.544, 'epoch': 8.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 2.0013, 'eval_samples_per_second': 48.967, 'eval_steps_per_second': 6.496, 'epoch': 9.0}
{'loss': 0.1093, 'learning_rate': 0.0, 'epoch': 10.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.996, 'eval_samples_per_second': 49.099, 'eval_steps_per_second': 6.513, 'epoch': 10.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.9799, 'eval_samples_per_second': 49.497, 'eval_steps_per_second': 6.566, 'epoch': 11.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.9974, 'eval_samples_per_second': 49.064, 'eval_steps_per_second': 6.509, 'epoch': 12.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.9999, 'eval_samples_per_second': 49.002, 'eval_steps_per_second': 6.5, 'epoch': 13.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 2.0056, 'eval_samples_per_second': 48.863, 'eval_steps_per_second': 6.482, 'epoch': 14.0}
{'loss': 0.104, 'learning_rate': 0.0, 'epoch': 15.0}


  0%|          | 0/13 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 0.15118572115898132, 'eval_precision': 0.7842956120092379, 'eval_recall': 0.8218780251694094, 'eval_f1': 0.8026471283384543, 'eval_accuracy': 0.9537787450957931, 'eval_runtime': 1.9897, 'eval_samples_per_second': 49.254, 'eval_steps_per_second': 6.534, 'epoch': 15.0}
{'train_runtime': 535.233, 'train_samples_per_second': 11.154, 'train_steps_per_second': 2.803, 'train_loss': 0.2046724319458008, 'epoch': 15.0}


TrainOutput(global_step=1500, training_loss=0.2046724319458008, metrics={'train_runtime': 535.233, 'train_samples_per_second': 11.154, 'train_steps_per_second': 2.803, 'train_loss': 0.2046724319458008, 'epoch': 15.0})

In [27]:
import wandb
wandb.finish()
trainer.save_model("NER_model/NERin-lstm_15epochs")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/accuracy,▁▄▆▇███████████
eval/f1,▁▄▇████████████
eval/loss,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁
eval/precision,▁▄▇████████████
eval/recall,▁▄▆████████████
eval/runtime,▁▂▅▃▅▇▄▅▇▆▄▇▇█▅
eval/samples_per_second,█▇▄▆▄▂▅▄▂▃▅▂▂▁▄
eval/steps_per_second,█▇▄▆▄▂▅▄▂▃▅▂▂▁▄
train/epoch,▁▁▂▃▃▃▃▄▅▅▅▅▆▇▇▇███
train/global_step,▁▁▂▃▃▃▃▄▅▅▅▅▆▇▇▇███

0,1
eval/accuracy,0.95378
eval/f1,0.80265
eval/loss,0.15119
eval/precision,0.7843
eval/recall,0.82188
eval/runtime,1.9897
eval/samples_per_second,49.254
eval/steps_per_second,6.534
train/epoch,15.0
train/global_step,1500.0


# Inference

In [34]:
from transformers import AutoTokenizer, BertModel, TokenClassifierOutput
import torch

# tokenizer = AutoTokenizer.from_pretrained("NER_model_tokenizer")
# model.resize_token_embeddings(len(tokenizer))
# bert_model = BertModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
NED_model = trainer.model
# NED_model.to("cpu")
NED_model.eval()

In [54]:
# load test dataset

test_file_path = 'data/BioRED/processed/test.tsv'
test_data = make_dataset(test_file_path, lower=True, ignore_relations=['None', 'Association'], NER=True, NER_in=True)
from datasets import DatasetDict, Dataset
test_dataset_raw = Dataset.from_dict(test_data)

In [56]:
# test_dataset = test_dataset_raw.map(NER_preprocess_function, batched=False)
# with bert only:
test_dataset = test_dataset_raw.map(lambda example: NER_preprocess_function(example, bert=True, NER_in=True, tokenizer=tokenizer), batched=False, remove_columns=["inputs", "outputs", "pmids"])
test_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [71]:
tokenized_datasets['train']['labels'][0]

tensor([2, 3, 6, 6, 6, 6, 9, 9, 9, 9, 9, 9, 4, 7, 7, 9, 9, 9, 9, 9, 9, 9, 3, 9,
        9, 9, 9, 9, 3, 6, 6, 6, 6, 6, 6, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        4, 7, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 3, 6, 6, 9, 9, 9, 9,
        9, 9, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 9, 9, 9, 3, 9, 9, 5, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 3, 6, 6, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 3, 6, 9, 9, 9, 9, 9, 3, 6, 9, 3, 6, 6, 9, 3, 6, 6, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 4, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 3, 6, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 3, 6, 6, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 4,
        7, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 9, 9, 9, 9, 9, 5, 9, 9, 9, 4,
        7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9,

In [70]:
test_dataset['labels'][0]

tensor([2, 9, 9, 3, 6, 9, 9, 9, 9, 9, 9, 9, 4, 7, 7, 9, 9, 9, 9, 4, 7, 7, 9, 9,
        9, 9, 4, 7, 7, 9, 4, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 4, 9, 9, 9, 9, 4, 7, 9, 4, 7, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 4, 7, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 5, 9, 9, 9, 9, 9, 9, 5, 8, 8, 9, 9, 9, 9, 9, 4, 7, 9, 9, 9,
        9, 9, 9, 9, 4, 7, 7, 7, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 3, 6, 6,
        6, 6, 6, 6, 9, 5, 9, 9, 9, 9, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 9, 9, 9, 9,
        3, 6, 6, 6, 6, 6, 6, 6, 6, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 5, 8, 8,
        8, 8, 9, 5, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 3, 6, 6, 6, 9, 9,
        9, 9, 9, 9, 9, 3, 6, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 3, 6, 6, 6,

In [45]:
n = 3
test_dataset['input_ids'][n].unsqueeze(0).shape

torch.Size([1, 512])

In [57]:
from tqdm import tqdm
# BERTonly
NED_model.eval()
NED_model.to("cuda")
output = []

with torch.no_grad():
    for n in tqdm(range(len(test_dataset))):
    # for n in range(1):
        torch.cuda.empty_cache()
        out = NED_model(input_ids=test_dataset['input_ids'][n].unsqueeze(0).to("cuda"))
        # print(f"{n+1} / {len(test_dataset)}")
        output.append(out[0].to("cpu"))
        # output.append(torch.argmax(out[0], dim=-1).squeeze(0))
        # output[-1].to("cpu")
    # print([tag_to_NER_id[i.item()] for i in output[-1]])

100%|██████████| 100/100 [00:02<00:00, 44.60it/s]


In [58]:
len(output)

100

In [64]:
a = np.argmax(output[0], axis=-1)

In [67]:
outputNER = []
labels = []
# for i in range(100):
#     prediction = np.argmax(output[i], axis=-1)
#     outputNER.append([id2label[p.item()] for p in prediction[-1]])
#     labels.append([id2label[p.item()] for p in test_dataset['labels'][i]])

for i in range(100):
    prediction = np.argmax(output[i], axis=-1)
    outputNER.append([id2label[p.item()] for p in prediction])
    labels.append([id2label[p.item()] for p in test_dataset['labels'][i]])

In [68]:
# metric.compute((outputNER, labels))
metric.compute(predictions=outputNER, references=labels)

  _warn_prf(average, modifier, msg_start, len(result))


{'CLS]': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 100},
 'Chemical]': {'precision': 0.8162083936324168,
  'recall': 0.7888111888111888,
  'f1': 0.8022759601706969,
  'number': 715},
 'Disease]': {'precision': 0.7110187110187111,
  'recall': 0.760845383759733,
  'f1': 0.7350886620096722,
  'number': 899},
 'Gene]': {'precision': 0.8409919766593728,
  'recall': 0.870188679245283,
  'f1': 0.8553412462908012,
  'number': 1325},
 'OUT]': {'precision': 0.7050524934383202,
  'recall': 0.7321976149914821,
  'f1': 0.7183687113488216,
  'number': 2935},
 'PAD]': {'precision': 0.21138211382113822,
  'recall': 0.2708333333333333,
  'f1': 0.23744292237442924,
  'number': 96},
 'STOP]': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 100},
 'overall_precision': 0.7386602098466505,
 'overall_recall': 0.7416531604538088,
 'overall_f1': 0.7401536595228467,
 'overall_accuracy': 0.70359375}

In [111]:
NED_model.eval()
NED_model.to("cuda")
output = []

with torch.no_grad():
    # for n in range(len(test_dataset)):
    # for n in range(1):
    torch.cuda.empty_cache()
    decoder_input_ids = torch.tensor([[2, 5, 6, 6]], dtype=torch.long).to("cuda")
    while True:
        out = NED_model(input_ids=test_dataset['input_ids'][n].unsqueeze(0).to("cuda"),
                    decoder_input_ids=decoder_input_ids)
        # out = NED_model.encoder(input_ids=test_dataset['input_ids'][n].unsqueeze(0).to("cuda"), return_dict=True)
        # out = NED_model.decoder(input_ids=decoder_input_ids, m=out[0])
        next_token_id= torch.argmax(out[0][0][-1]).to("cuda")
         # print(next_token_id.item())
        decoder_input_ids = torch.cat((decoder_input_ids, next_token_id.unsqueeze(0).unsqueeze(0)), dim=-1).to("cuda")
        if next_token_id == 1 or next_token_id == 0 or decoder_input_ids.shape[-1]==513:
            break 
        # break
    print(f"{n+1} / {len(test_dataset)}")
    print(decoder_input_ids[0])
    output.append(decoder_input_ids.reshape(-1)[1:])
    output[-1].to("cpu")
    print([tag_to_NER_id[i.item()] for i in output[-1]])

4 / 100
tensor([2, 5, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 

In [101]:
# length of test_dataset['input_ids'][n] when item != 0
len(test_dataset['input_ids'][n][test_dataset['input_ids'][n] != 0])

423

In [102]:
out[0][-1]

tensor([[-4.7586, -4.0376, -5.7821,  0.8483,  2.2826,  5.0164, -3.1285,  5.1057],
        [-4.2715, -5.3063, -5.2710, -2.3930,  1.1866,  1.1315,  7.1215,  5.7219],
        [-6.3509, -5.4186, -5.6808, -2.2958,  2.9140,  0.8766,  7.7758,  6.0267],
        [-6.6171, -5.0198, -5.5981, -1.7730,  3.3793,  1.2843,  6.3436,  5.7047],
        [-6.6452, -4.7902, -5.5513, -1.4868,  3.5267,  1.5100,  5.5078,  5.5153]],
       device='cuda:0')

In [58]:
test_dataset['labels'][3]

tensor([5, 6, 6, 7, 7, 7, 7, 7, 3, 6, 6, 6, 7, 7, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 5, 6, 6, 6, 6, 6,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 3, 6, 6, 6, 7, 7, 7, 7, 7, 5, 6, 6, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 3, 6, 6, 6, 7, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 5, 6, 6, 7, 5, 6, 6, 6, 6, 4, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 6,
        6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 7,
        5, 6, 6, 7, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 6, 6, 6,
        6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7,