The model: 

    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

The input be like:

    [cls] [relation] [sep] "text w ner1 and ner2 tags"

(the ner1 and ner2 tags depends on the first occurence orders)

The output be like:

    [src] [none] [tgt] [none]

or

    [src] [ner2] [tgt] [ner1]

or

    [src] [ner1] [tgt] [ner2]

In [1]:
import pandas as pd
import torch
import re
from tqdm.notebook import trange, tqdm
from torch import nn
from labels import get_labels
from relations import relations
from datasets import DatasetDict, Dataset

In [2]:
# load labels for bert_w_ner
additional_tokens, labels, id2label, label2id = get_labels(mode='bert_w_ner')
print(additional_tokens, "\n", labels)

{'additional_special_tokens': ['[ner1]', '[/ner1]', '[ner2]', '[/ner2]', '[Association]', '[Bind]', '[Comparison]', '[Conversion]', '[Cotreatment]', '[Drug_Interaction]', '[Negative_Correlation]', '[Positive_Correlation]']} 
 ['[pad]', '[src]', '[ner1]', '[ner2]', '[tgt]', '[none]']


In [3]:
checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Tokenizer

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

for id in [1, 3, 0, 2, 4]:
    print(f"{id}: {tokenizer.decode(id)}")

1: [UNK]
3: [SEP]
0: [PAD]
2: [CLS]
4: [MASK]


In [5]:
# adding new tokens to the tokenizer
# since I haven't load the model so I will resize the embedding of the model later]
num_added_toks = tokenizer.add_special_tokens(additional_tokens)
print('We have added', num_added_toks, 'tokens')

# save the tokenizer
tokenizer.save_pretrained("bert_w_ner/bert_w_ner_tokenizer")

We have added 12 tokens


('bert_w_ner/bert_w_ner_tokenizer/tokenizer_config.json',
 'bert_w_ner/bert_w_ner_tokenizer/special_tokens_map.json',
 'bert_w_ner/bert_w_ner_tokenizer/vocab.txt',
 'bert_w_ner/bert_w_ner_tokenizer/added_tokens.json',
 'bert_w_ner/bert_w_ner_tokenizer/tokenizer.json')

# Data pre-process

In [6]:
from data_preprocessing import make_bert_re_data
from data_preprocessing import bert_w_ner_preprocess_function

In [7]:
# train and valid file paths
train_file_path = 'data/BioRED/processed/train.tsv'
valid_file_path = 'data/BioRED/processed/dev.tsv'

In [8]:
# make bert_re data
train_data_raw = make_bert_re_data(file_path=train_file_path, lower=True, output_none=True)
valid_data_raw = make_bert_re_data(file_path=valid_file_path, lower=True, output_none=True)

In [9]:
# make into Dataset type
train_data_raw = Dataset.from_dict(train_data_raw)
valid_data_raw = Dataset.from_dict(valid_data_raw)

dataset = DatasetDict({
    "train": train_data_raw,
    "valid": valid_data_raw
})

In [10]:
# tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["inputs", "outputs", "pmids"])
tokenized_datasets = dataset.map(lambda example: bert_w_ner_preprocess_function(example, tokenizer), batched=True, remove_columns=["input_texts", "input_relations", "outputs", "pmids"])

Map:   0%|          | 0/183160 [00:00<?, ? examples/s]

Map:   0%|          | 0/53264 [00:00<?, ? examples/s]

In [11]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 183160
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 53264
    })
})

In [12]:
# to tensor
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Model

In [13]:
# Prints the number of trainable parameters in the model.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


In [14]:
labels

['[pad]', '[src]', '[ner1]', '[ner2]', '[tgt]', '[none]']

In [15]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(checkpoint,
                                                        num_labels=len(labels),
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        )



Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForToken

In [16]:
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

adding special tokens to the embedding layer

In [17]:
model.resize_token_embeddings(len(tokenizer))

Embedding(30534, 768)

In [18]:
print_trainable_parameters(model)

trainable params: 108905478 || all params: 108905478 || trainable%: 100.0


In [19]:
import inspect

print(inspect.getsource(model.forward))

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
        expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""


# Evaluate

In [20]:
import evaluate

metric = evaluate.load("seqeval")

In [21]:
import numpy as np

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # only consider non-padding tokens [:4]
    true_labels = [[id2label[l.item()] for l in label[:4]] for label in labels]
    true_predictions = [
    [id2label[p.item()] for (p, l) in zip(prediction[:4], label[:4])]
    for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

# Trainer

customized loss

In [22]:
# change the feature name "labels" to "labels_" of tokenized_datasets
# to not trigger the default loss function the enssembled in Auto model

# tokenized_datasets = tokenized_datasets.rename_column("labels", "labels_")

In [23]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 183160
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 53264
    })
})

In [24]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenized_datasets["train"][i] for i in range(2)])
batch["labels"]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([[1, 5, 4,  ..., 0, 0, 0],
        [1, 5, 4,  ..., 0, 0, 0]])

In [25]:
from transformers import Trainer, TrainingArguments

tokenized_datasets

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss
        loss_fct = nn.CrossEntropyLoss()
        # only consider first 4 columns of the batched loss and the batched labels
        logits = logits[:, :4, :]
        labels = labels[:, :4]
        loss = loss_fct(logits.reshape(-1, self.model.config.num_labels), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss

collator

wandb

In [26]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="BERT",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="bert_w_ner_epoch_5_loss2",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


args

In [27]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="wandb",
    # per_device_train_batch_size=4,
    # per_device_eval_batch_size=4,
    auto_find_batch_size=True,
    load_best_model_at_end=True,
    # push_to_hub=True,
)

In [28]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

In [28]:
trainer = CustomTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

training

In [29]:
model.to("cuda")
trainer.train()



  0%|          | 0/114475 [00:00<?, ?it/s]

{'loss': 0.0271, 'learning_rate': 1.9912644682245032e-05, 'epoch': 0.02}
{'loss': 0.0008, 'learning_rate': 1.9825289364490066e-05, 'epoch': 0.04}
{'loss': 0.0006, 'learning_rate': 1.9737934046735096e-05, 'epoch': 0.07}
{'loss': 0.0005, 'learning_rate': 1.9650578728980127e-05, 'epoch': 0.09}
{'loss': 0.0005, 'learning_rate': 1.956322341122516e-05, 'epoch': 0.11}
{'loss': 0.0004, 'learning_rate': 1.947586809347019e-05, 'epoch': 0.13}
{'loss': 0.0004, 'learning_rate': 1.938851277571522e-05, 'epoch': 0.15}
{'loss': 0.0003, 'learning_rate': 1.9301157457960255e-05, 'epoch': 0.17}
{'loss': 0.0004, 'learning_rate': 1.9213802140205286e-05, 'epoch': 0.2}
{'loss': 0.0003, 'learning_rate': 1.9126446822450316e-05, 'epoch': 0.22}
{'loss': 0.0003, 'learning_rate': 1.903909150469535e-05, 'epoch': 0.24}
{'loss': 0.0004, 'learning_rate': 1.895173618694038e-05, 'epoch': 0.26}
{'loss': 0.0004, 'learning_rate': 1.8864380869185414e-05, 'epoch': 0.28}
{'loss': 0.0004, 'learning_rate': 1.8777025551430444e-05,

  0%|          | 0/6658 [00:00<?, ?it/s]



{'eval_loss': 0.0002926877059508115, 'eval_precision': 0.9894863322319015, 'eval_recall': 0.9894863322319015, 'eval_f1': 0.9894863322319015, 'eval_accuracy': 0.9894863322319015, 'eval_runtime': 681.9754, 'eval_samples_per_second': 78.103, 'eval_steps_per_second': 9.763, 'epoch': 1.0}
{'loss': 0.0003, 'learning_rate': 1.5981655383271458e-05, 'epoch': 1.0}
{'loss': 0.0003, 'learning_rate': 1.589430006551649e-05, 'epoch': 1.03}
{'loss': 0.0003, 'learning_rate': 1.5806944747761522e-05, 'epoch': 1.05}
{'loss': 0.0002, 'learning_rate': 1.5719589430006553e-05, 'epoch': 1.07}
{'loss': 0.0002, 'learning_rate': 1.5632234112251587e-05, 'epoch': 1.09}
{'loss': 0.0002, 'learning_rate': 1.5544878794496617e-05, 'epoch': 1.11}
{'loss': 0.0003, 'learning_rate': 1.5457523476741647e-05, 'epoch': 1.14}
{'loss': 0.0003, 'learning_rate': 1.537016815898668e-05, 'epoch': 1.16}
{'loss': 0.0002, 'learning_rate': 1.5282812841231712e-05, 'epoch': 1.18}
{'loss': 0.0003, 'learning_rate': 1.5195457523476742e-05, 'ep

  0%|          | 0/6658 [00:00<?, ?it/s]



{'eval_loss': 0.0002620388404466212, 'eval_precision': 0.989373685791529, 'eval_recall': 0.989373685791529, 'eval_f1': 0.989373685791529, 'eval_accuracy': 0.989373685791529, 'eval_runtime': 679.4867, 'eval_samples_per_second': 78.389, 'eval_steps_per_second': 9.799, 'epoch': 2.0}
{'loss': 0.0002, 'learning_rate': 1.1963310766542915e-05, 'epoch': 2.01}
{'loss': 0.0002, 'learning_rate': 1.1875955448787945e-05, 'epoch': 2.03}
{'loss': 0.0002, 'learning_rate': 1.1788600131032977e-05, 'epoch': 2.05}
{'loss': 0.0003, 'learning_rate': 1.170124481327801e-05, 'epoch': 2.07}
{'loss': 0.0003, 'learning_rate': 1.1613889495523042e-05, 'epoch': 2.1}
{'loss': 0.0002, 'learning_rate': 1.1526534177768072e-05, 'epoch': 2.12}
{'loss': 0.0002, 'learning_rate': 1.1439178860013104e-05, 'epoch': 2.14}
{'loss': 0.0002, 'learning_rate': 1.1351823542258136e-05, 'epoch': 2.16}
{'loss': 0.0002, 'learning_rate': 1.1264468224503167e-05, 'epoch': 2.18}
{'loss': 0.0002, 'learning_rate': 1.11771129067482e-05, 'epoch':

  0%|          | 0/6658 [00:00<?, ?it/s]



{'eval_loss': 0.0002686600782908499, 'eval_precision': 0.9890686016821868, 'eval_recall': 0.9890686016821868, 'eval_f1': 0.9890686016821868, 'eval_accuracy': 0.9890686016821868, 'eval_runtime': 679.4823, 'eval_samples_per_second': 78.389, 'eval_steps_per_second': 9.799, 'epoch': 3.0}
{'loss': 0.0002, 'learning_rate': 7.94496614981437e-06, 'epoch': 3.01}
{'loss': 0.0002, 'learning_rate': 7.857610832059402e-06, 'epoch': 3.04}
{'loss': 0.0002, 'learning_rate': 7.770255514304434e-06, 'epoch': 3.06}
{'loss': 0.0002, 'learning_rate': 7.682900196549466e-06, 'epoch': 3.08}
{'loss': 0.0002, 'learning_rate': 7.595544878794497e-06, 'epoch': 3.1}
{'loss': 0.0002, 'learning_rate': 7.508189561039529e-06, 'epoch': 3.12}
{'loss': 0.0002, 'learning_rate': 7.420834243284561e-06, 'epoch': 3.14}
{'loss': 0.0002, 'learning_rate': 7.333478925529593e-06, 'epoch': 3.17}
{'loss': 0.0002, 'learning_rate': 7.246123607774624e-06, 'epoch': 3.19}
{'loss': 0.0002, 'learning_rate': 7.158768290019655e-06, 'epoch': 3.2

  0%|          | 0/6658 [00:00<?, ?it/s]



{'eval_loss': 0.0002841673558577895, 'eval_precision': 0.9874070666866926, 'eval_recall': 0.9874070666866926, 'eval_f1': 0.9874070666866926, 'eval_accuracy': 0.9874070666866926, 'eval_runtime': 681.1798, 'eval_samples_per_second': 78.194, 'eval_steps_per_second': 9.774, 'epoch': 4.0}
{'loss': 0.0002, 'learning_rate': 3.926621533085827e-06, 'epoch': 4.02}
{'loss': 0.0001, 'learning_rate': 3.8392662153308585e-06, 'epoch': 4.04}
{'loss': 0.0001, 'learning_rate': 3.75191089757589e-06, 'epoch': 4.06}
{'loss': 0.0001, 'learning_rate': 3.664555579820922e-06, 'epoch': 4.08}
{'loss': 0.0001, 'learning_rate': 3.5772002620659535e-06, 'epoch': 4.11}
{'loss': 0.0001, 'learning_rate': 3.4898449443109848e-06, 'epoch': 4.13}
{'loss': 0.0002, 'learning_rate': 3.402489626556017e-06, 'epoch': 4.15}
{'loss': 0.0001, 'learning_rate': 3.3151343088010486e-06, 'epoch': 4.17}
{'loss': 0.0001, 'learning_rate': 3.22777899104608e-06, 'epoch': 4.19}
{'loss': 0.0001, 'learning_rate': 3.140423673291112e-06, 'epoch':

  0%|          | 0/6658 [00:00<?, ?it/s]



{'eval_loss': 0.00030054408125579357, 'eval_precision': 0.9884865950735957, 'eval_recall': 0.9884865950735957, 'eval_f1': 0.9884865950735957, 'eval_accuracy': 0.9884865950735957, 'eval_runtime': 678.2311, 'eval_samples_per_second': 78.534, 'eval_steps_per_second': 9.817, 'epoch': 5.0}
{'train_runtime': 39488.3426, 'train_samples_per_second': 23.192, 'train_steps_per_second': 2.899, 'train_loss': 0.0003305545512389137, 'epoch': 5.0}


TrainOutput(global_step=114475, training_loss=0.0003305545512389137, metrics={'train_runtime': 39488.3426, 'train_samples_per_second': 23.192, 'train_steps_per_second': 2.899, 'train_loss': 0.0003305545512389137, 'epoch': 5.0})

In [30]:
import wandb
wandb.finish()
trainer.save_model("bert_w_ner/models/bert_w_ner_epoch_5_loss2")

0,1
eval/accuracy,██▇▁▅
eval/f1,██▇▁▅
eval/loss,▇▁▂▅█
eval/precision,██▇▁▅
eval/recall,██▇▁▅
eval/runtime,█▃▃▇▁
eval/samples_per_second,▁▆▆▂█
eval/steps_per_second,▁▆▆▂█
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
eval/accuracy,0.98849
eval/f1,0.98849
eval/loss,0.0003
eval/precision,0.98849
eval/recall,0.98849
eval/runtime,678.2311
eval/samples_per_second,78.534
eval/steps_per_second,9.817
train/epoch,5.0
train/global_step,114475.0


# Inference

In [1]:
import pandas as pd
import torch
import re
from tqdm.notebook import trange, tqdm
from torch import nn
from labels import get_labels
from relations import relations
from datasets import DatasetDict, Dataset

from data_preprocessing import make_bert_re_data
from data_preprocessing import bert_w_ner_preprocess_function
additional_tokens, labels, id2label, label2id = get_labels(mode='bert_w_ner')

In [2]:
# model and tokenizer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("bert_w_ner/models/bert_w_ner_epoch_5_loss2")

tokenizer = AutoTokenizer.from_pretrained("bert_w_ner/bert_w_ner_tokenizer")

In [41]:
import json
#load test data and preprocess
# test_file_path = 'data/BioRED/processed/test.tsv'
# test_data = make_bert_re_data(file_path=test_file_path , lower=True, output_none=True)
# # turning all of the items in test_data['pmids'] into string
# test_data['pmids'] = [str(pmid) for pmid in test_data['pmids']]
# # save test_data
# with open('bert_w_ner/data/test_data_dict.json', 'w') as f:
#     json.dump(test_data, f)


with open('bert_w_ner/data/test_data_dict.json', 'r') as f:
    test_data= json.load(f)


test_dataset_raw = Dataset.from_dict(test_data)
# test_dataset = test_dataset_raw.map(NER_preprocess_function, batched=False)
# with bert only:
# test_dataset = test_dataset_raw.map(lambda example: bert_w_ner_preprocess_function(example, tokenizer), batched=True, remove_columns=["input_texts", "input_relations", "outputs", "pmids"])

# test_dataset.save_to_disk('bert_w_ner/data/test_tokenized_dataset_w_ner')

# from datasets import load_from_disk

# test_dataset = load_from_disk('GPT_w_ner/data/test_tokenized_dataset_no_ner')
# test_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

In [12]:
test_dataset_raw[0]

{'pmids': '15485686',
 'input_texts': 'a novel scn5a mutation manifests as a malignant form of long qt syndrome with perinatal onset of tachycardia/bradycardia . objective : congenital long qt syndrome ( lqts ) with in utero onset of the rhythm disturbances is associated with a poor prognosis . in this study we investigated a newborn patient with fetal bradycardia , 2:1 [ner1] atrioventricular block [/ner1] and ventricular tachycardia soon after birth . methods : mutational analysis and dna sequencing were conducted in a newborn . the 2:1 [ner1] atrioventricular block [/ner1] improved to 1:1 conduction only after intravenous lidocaine infusion or a high dose of mexiletine , which also controlled the ventricular tachycardia . results : a novel , spontaneous lqts-3 mutation was identified in the transmembrane segment 6 of domain iv of the na(v)1.5 cardiac sodium channel , with a g-->a substitution at codon 1763 , which changed a valine ( gtg ) to a methionine ( atg ) . the proband was he

In [124]:
from tqdm import tqdm
model.eval()
model.to("cuda")
output = []

with torch.no_grad():
    for input_line in tqdm(test_dataset['input_ids']):
    # for n in range(1):
        torch.cuda.empty_cache()
        out = model(input_ids=input_line.unsqueeze(0).to("cuda"))
        # print(f"{n+1} / {len(test_dataset)}")
        output.append(out[0].to("cpu"))
        # output.append(torch.argmax(out[0], dim=-1).squeeze(0))
        # output[-1].to("cpu")
    # print([tag_to_NER_id[i.item()] for i in output[-1]])

100%|██████████| 60720/60720 [15:02<00:00, 67.29it/s]


In [49]:
# save the output logits to a file
# torch.save(output, 'bert_w_ner/results/bert_w_ner_epoch_5_loss2_raw.pt')

output = torch.load('bert_w_ner/results/bert_w_ner_epoch_5_loss2_raw.pt')

In [132]:
# # save the results
# import pickle

# # save the dictionary to a file
# with open('bert_w_ner/results/bert_w_ner_epoch_5_loss2.pickle', 'wb') as handle:
#     pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

analysis

In [74]:
import pickle
# load the dictionary from a file
with open('bert_w_ner/results/bert_w_ner_epoch_5_loss2.pickle', 'rb') as handle:
    results = pickle.load(handle)

In [75]:
import numpy as np

predictions = [np.argmax(logit, axis=-1)[0] for logit in output]

In [76]:
preds = []

for item in predictions:
    preds.append([id2label[i.item()] for i in item[:4]])

In [77]:
src_tp = 0
src_fp = 0
src_fn = 0

tgt_tp = 0
tgt_fp = 0
tgt_fn = 0

tuple_tp = 0
tuple_fp = 0
tuple_fn = 0

for pred, label in zip(preds, test_dataset_raw['outputs']):
    src = False
    relation = False
    if pred[1] == label[1]:
        src_tp += 1
        src = True
    else:
        src_fn += 1
        src_fp += 1
    if pred[2] == label[2]:
        tgt_tp += 1
        tgt = True
        if src:
            tuple_tp += 1
        else:
            tuple_fn += 1
            tuple_fp += 1
    else:
        tgt_fn += 1
        tgt_fn += 1

In [28]:
id2label

{0: '[pad]', 1: '[src]', 2: '[ner1]', 3: '[ner2]', 4: '[tgt]', 5: '[none]'}

In [79]:
# calculate the precision, recall and f1 score

# for src
src_precision = src_tp / (src_tp + src_fp)
src_recall = src_tp / (src_tp + src_fn)
src_f1 = 2 * src_precision * src_recall / (src_precision + src_recall)

# for tgt
tgt_precision = tgt_tp / (tgt_tp + tgt_fp)
tgt_recall = tgt_tp / (tgt_tp + tgt_fn)
tgt_f1 = 2 * tgt_precision * tgt_recall / (tgt_precision + tgt_recall)

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall)

print(f"src_precision: {src_precision}, src_recall: {src_recall}, src_f1: {src_f1}")
print(f"tgt_precision: {tgt_precision}, tgt_recall: {tgt_recall}, tgt_f1: {tgt_f1}")
print(f"tuple_precision: {tuple_precision}, tuple_recall: {tuple_recall}, tuple_f1: {tuple_f1}")

src_precision: 0.9798913043478261, src_recall: 0.9798913043478261, src_f1: 0.9798913043478261
tgt_precision: 1.0, tgt_recall: 0.9499036608863198, tgt_f1: 0.974308300395257
tuple_precision: 0.9816599053414469, tuple_recall: 0.9816599053414469, tuple_f1: 0.9816599053414469


# inference with balanced test data

In [43]:
# """
# {'None': 59557,
#  'Association': 635,
#  'Bind': 9,
#  'Comparison': 6,
#  'Conversion': 1,
#  'Cotreatment': 14,
#  'Drug_Interaction': 2,
#  'Negative_Correlation': 171,
#  'Positive_Correlation': 325}
#  """

# get the index of the None label

none_index = [i for i, example in enumerate(test_dataset_raw) if example['outputs'][1] == '[none]']

# randomly select 200 examples from the none_index
none_index = random.sample(none_index, 59557 - 200)
keep_indices = [i for i in range(len(test_dataset_raw)) if i not in none_index]
print(len(keep_indices))

1363


In [50]:
import numpy as np

predictions = [np.argmax(logit, axis=-1)[0] for logit in output]

preds = []

for item in predictions:
    preds.append([id2label[i.item()] for i in item[:4]])



In [65]:
src_tp = 0
src_fp = 0
src_fn = 0

tgt_tp = 0
tgt_fp = 0
tgt_fn = 0

tuple_tp = 0
tuple_fp = 0
tuple_fn = 0

step = 0
for i, _ in enumerate(preds):
    if i not in keep_indices:
        continue

    pred = preds[i]
    label = test_dataset_raw['outputs'][i]
    if step % 100 == 0:
        print(f"{step} / {len(keep_indices)}")
        print(f"pred: {pred}, \nlabel: {label}")
    step += 1
    src = False
    relation = False
    if pred[1] == label[1]:
        src_tp += 1
        src = True
    else:
        src_fn += 1
        src_fn += 1
    if pred[2] == label[2]:
        tgt_tp += 1
        tgt = True
        if src:
            tuple_tp += 1
        else:
            tuple_fn += 1
            tuple_fp += 1
    else:
        tgt_fn += 1
        tgt_fn += 1

0 / 1363
pred: ['[src]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
100 / 1363
pred: ['[src]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
200 / 1363
pred: ['[src]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
300 / 1363
pred: ['[src]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
400 / 1363
pred: ['[pad]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
500 / 1363
pred: ['[pad]', '[ner2]', '[tgt]', '[ner1]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
600 / 1363
pred: ['[pad]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
700 / 1363
pred: ['[src]', '[none]', '[tgt]', '[none]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
800 / 1363
pred: ['[src]', '[ner2]', '[tgt]', '[ner1]'], 
label: ['[src]', '[ner2]', '[tgt]', '[ner1]']
900 / 1363
pred: ['[none]', '[none]', '[tgt]', '[none]'], 
label: 

In [73]:
# calculate the precision, recall and f1 score

# for src
src_precision = src_tp / (src_tp + src_fp + 1e-8)
src_recall = src_tp / (src_tp + src_fn + 1e-8)
src_f1 = 2 * src_precision * src_recall / (src_precision + src_recall + 1e-8)

# for tgt
tgt_precision = tgt_tp / (tgt_tp + tgt_fp + 1e-8)
tgt_recall = tgt_tp / (tgt_tp + tgt_fn + 1e-8)
tgt_f1 = 2 * tgt_precision * tgt_recall / (tgt_precision + tgt_recall + 1e-8)

# for tuple
tuple_precision = tuple_tp / (tuple_tp + tuple_fp + 1e-8)
tuple_recall = tuple_tp / (tuple_tp + tuple_fn + 1e-8)
tuple_f1 = 2 * tuple_precision * tuple_recall / (tuple_precision + tuple_recall + 1e-8)

print(f"src_precision: {src_precision}, \nsrc_recall: {src_recall}, \nsrc_f1: {src_f1}")
print(f"tgt_precision: {tgt_precision}, \ntgt_recall: {tgt_recall}, \ntgt_f1: {tgt_f1}")
print(f"tuple_precision: {tuple_precision}, \ntuple_recall: {tuple_recall}, \ntuple_f1: {tuple_f1}")

src_precision: 0.9999999999726027, 
src_recall: 0.15459551037630415, 
src_f1: 0.267791633775526
tgt_precision: 0.9999999999921384, 
tgt_recall: 0.8748280605166794, 
tgt_f1: 0.9332355049200628
tuple_precision: 0.28380503144430974, 
tuple_recall: 0.28380503144430974, 
tuple_f1: 0.2838050264443098
