The model: 

    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

The input be like:

    [cls] [relation] [sep] "text"

(the ner1 and ner2 tags depends on the first occurence orders)

The output be like: (length==512)

    [out] [src-b] [src-in] [tgt-b] [tgt-in] [out]

or

    [out] [out] [out] [out] [out] [out] [out] [out] [out]

In [1]:
import pandas as pd
import torch
import re
from tqdm.notebook import trange, tqdm
from torch import nn
from labels import get_labels
from relations import relations
from datasets import DatasetDict, Dataset

In [2]:
# load labels for bert_w_ner
additional_tokens, labels, id2label, label2id = get_labels(mode='bert_without_ner')
print(additional_tokens, "\n", labels)

{'additional_special_tokens': ['[Association]', '[Bind]', '[Comparison]', '[Conversion]', '[Cotreatment]', '[Drug_Interaction]', '[Negative_Correlation]', '[Positive_Correlation]']} 
 ['[pad]', '[src-b]', '[src-in]', '[tgt-b]', '[tgt-in]', '[out]']


In [3]:
checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Tokenizer

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

for id in [1, 3, 0, 2, 4]:
    print(f"{id}: {tokenizer.decode(id)}")

1: [UNK]
3: [SEP]
0: [PAD]
2: [CLS]
4: [MASK]


In [5]:
# adding new tokens to the tokenizer
# since I haven't load the model so I will resize the embedding of the model later]
num_added_toks = tokenizer.add_special_tokens(additional_tokens)
print('We have added', num_added_toks, 'tokens')

# save the tokenizer
tokenizer.save_pretrained("bert_without_ner/bert_without_ner_tokenizer")

We have added 8 tokens


('bert_without_ner/bert_without_ner_tokenizer/tokenizer_config.json',
 'bert_without_ner/bert_without_ner_tokenizer/special_tokens_map.json',
 'bert_without_ner/bert_without_ner_tokenizer/vocab.txt',
 'bert_without_ner/bert_without_ner_tokenizer/added_tokens.json',
 'bert_without_ner/bert_without_ner_tokenizer/tokenizer.json')

# Data pre-process

In [6]:
from data_preprocessing import make_bert_re_data
from data_preprocessing import bert_w_ner_preprocess_function

In [7]:
# train and valid file paths
train_file_path = 'data/BioRED/processed/train.tsv'
valid_file_path = 'data/BioRED/processed/dev.tsv'

In [8]:
# make bert_re data
train_data_raw = make_bert_re_data(file_path=train_file_path, lower=True, no_ner_input=True)
valid_data_raw = make_bert_re_data(file_path=valid_file_path, lower=True, no_ner_input=True)

In [9]:
# make into Dataset type
train_data_raw = Dataset.from_dict(train_data_raw)
valid_data_raw = Dataset.from_dict(valid_data_raw)

dataset = DatasetDict({
    "train": train_data_raw,
    "valid": valid_data_raw
})

In [10]:
dataset

DatasetDict({
    train: Dataset({
        features: ['pmids', 'input_texts', 'input_relations', 'outputs'],
        num_rows: 28723
    })
    valid: Dataset({
        features: ['pmids', 'input_texts', 'input_relations', 'outputs'],
        num_rows: 7922
    })
})

In [11]:
# tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["inputs", "outputs", "pmids"])
tokenized_datasets = dataset.map(lambda example: bert_w_ner_preprocess_function(example, tokenizer, mode="bert_without_ner"), batched=True, remove_columns=["input_texts", "input_relations", "outputs", "pmids"])

Map:   0%|          | 0/28723 [00:00<?, ? examples/s]

Map:   0%|          | 0/7922 [00:00<?, ? examples/s]

In [12]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 28723
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 7922
    })
})

In [13]:
# to tensor
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Model

In [14]:
# Prints the number of trainable parameters in the model.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


In [15]:
from transformers import AutoModelForTokenClassification

print(labels)
model = AutoModelForTokenClassification.from_pretrained(checkpoint,
                                                        num_labels=len(labels),
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        )



['[pad]', '[src-b]', '[src-in]', '[tgt-b]', '[tgt-in]', '[out]']


Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForToken

In [16]:
model

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

adding special tokens to the embedding layer

In [17]:
model.resize_token_embeddings(len(tokenizer))

Embedding(30530, 768)

In [18]:
print_trainable_parameters(model)

trainable params: 108902406 || all params: 108902406 || trainable%: 100.0


# Evaluate

In [19]:
import evaluate

metric = evaluate.load("seqeval")

In [20]:
import numpy as np

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    # only consider non-padding tokens [:4]
    true_labels = [[id2label[l.item()] for l in label[:509]] for label in labels]
    true_predictions = [
    [id2label[p.item()] for (p, l) in zip(prediction[:509], label[:509])]
    for prediction, label in zip(predictions, labels)
    ]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }

# Trainer

customized loss

In [21]:
from transformers import Trainer, TrainingArguments

tokenized_datasets

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss
        loss_fct = nn.CrossEntropyLoss()
        # only consider first 4 columns of the batched loss and the batched labels
        logits = logits[:, :509, :]
        labels = labels[:, :509]
        loss = loss_fct(logits.reshape(-1, self.model.config.num_labels), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss

collator

In [22]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenized_datasets["train"][i] for i in range(2)])
batch["labels"]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


tensor([[3, 4, 4,  ..., 0, 0, 0],
        [5, 5, 5,  ..., 0, 0, 0]])

wandb

In [23]:
import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project="BERT",
    # notes="PubmedBERT-FT-NER_w_NERin_10epochs",
    name="bert_without_ner_epoch_5_loss2",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33m309439737[0m ([33mtian1995[0m). Use [1m`wandb login --relogin`[0m to force relogin


args

In [24]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    report_to="wandb",
    # per_device_train_batch_size=4,
    # per_device_eval_batch_size=4,
    auto_find_batch_size=True,
    load_best_model_at_end=True,
    # push_to_hub=True,
)

In [25]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

In [26]:
# trainer = CustomTrainer(
#     model=model,
#     args=args,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["valid"],
#     data_collator=data_collator,
#     compute_metrics=compute_metrics,
#     tokenizer=tokenizer,
# )

training

In [27]:
model.to("cuda")
trainer.train()



  0%|          | 0/17955 [00:00<?, ?it/s]

{'loss': 0.0843, 'learning_rate': 1.9443052074631025e-05, 'epoch': 0.14}
{'loss': 0.0352, 'learning_rate': 1.8886104149262045e-05, 'epoch': 0.28}
{'loss': 0.0304, 'learning_rate': 1.8329156223893068e-05, 'epoch': 0.42}
{'loss': 0.0286, 'learning_rate': 1.7772208298524088e-05, 'epoch': 0.56}
{'loss': 0.028, 'learning_rate': 1.721526037315511e-05, 'epoch': 0.7}
{'loss': 0.0257, 'learning_rate': 1.6658312447786135e-05, 'epoch': 0.84}
{'loss': 0.0242, 'learning_rate': 1.6101364522417155e-05, 'epoch': 0.97}


  0%|          | 0/991 [00:00<?, ?it/s]



{'eval_loss': 0.026951752603054047, 'eval_precision': 0.6794690507685992, 'eval_recall': 0.3748270048564455, 'eval_f1': 0.4831344058121433, 'eval_accuracy': 0.991852040697389, 'eval_runtime': 172.0931, 'eval_samples_per_second': 46.033, 'eval_steps_per_second': 5.759, 'epoch': 1.0}
{'loss': 0.0237, 'learning_rate': 1.5544416597048178e-05, 'epoch': 1.11}
{'loss': 0.0232, 'learning_rate': 1.49874686716792e-05, 'epoch': 1.25}
{'loss': 0.0228, 'learning_rate': 1.4430520746310221e-05, 'epoch': 1.39}
{'loss': 0.0227, 'learning_rate': 1.3873572820941243e-05, 'epoch': 1.53}
{'loss': 0.0203, 'learning_rate': 1.3316624895572265e-05, 'epoch': 1.67}
{'loss': 0.0207, 'learning_rate': 1.2759676970203288e-05, 'epoch': 1.81}
{'loss': 0.0219, 'learning_rate': 1.220272904483431e-05, 'epoch': 1.95}


  0%|          | 0/991 [00:00<?, ?it/s]



{'eval_loss': 0.02589261345565319, 'eval_precision': 0.6360639007732465, 'eval_recall': 0.40776528018922525, 'eval_f1': 0.496948695145512, 'eval_accuracy': 0.9918036811763417, 'eval_runtime': 171.9126, 'eval_samples_per_second': 46.082, 'eval_steps_per_second': 5.765, 'epoch': 2.0}
{'loss': 0.019, 'learning_rate': 1.1645781119465331e-05, 'epoch': 2.09}
{'loss': 0.0209, 'learning_rate': 1.1088833194096353e-05, 'epoch': 2.23}
{'loss': 0.0198, 'learning_rate': 1.0531885268727375e-05, 'epoch': 2.37}
{'loss': 0.0194, 'learning_rate': 9.974937343358396e-06, 'epoch': 2.51}
{'loss': 0.0198, 'learning_rate': 9.417989417989418e-06, 'epoch': 2.65}
{'loss': 0.0191, 'learning_rate': 8.86104149262044e-06, 'epoch': 2.78}
{'loss': 0.0202, 'learning_rate': 8.304093567251463e-06, 'epoch': 2.92}


  0%|          | 0/991 [00:00<?, ?it/s]



{'eval_loss': 0.026785697788000107, 'eval_precision': 0.6227006544953496, 'eval_recall': 0.4548702850959966, 'eval_f1': 0.5257157813613297, 'eval_accuracy': 0.9917235779696838, 'eval_runtime': 172.0715, 'eval_samples_per_second': 46.039, 'eval_steps_per_second': 5.759, 'epoch': 3.0}
{'loss': 0.0203, 'learning_rate': 7.747145641882484e-06, 'epoch': 3.06}
{'loss': 0.0185, 'learning_rate': 7.190197716513506e-06, 'epoch': 3.2}
{'loss': 0.018, 'learning_rate': 6.6332497911445286e-06, 'epoch': 3.34}
{'loss': 0.0202, 'learning_rate': 6.07630186577555e-06, 'epoch': 3.48}
{'loss': 0.0191, 'learning_rate': 5.519353940406572e-06, 'epoch': 3.62}
{'loss': 0.0191, 'learning_rate': 4.962406015037594e-06, 'epoch': 3.76}
{'loss': 0.0177, 'learning_rate': 4.405458089668616e-06, 'epoch': 3.9}


  0%|          | 0/991 [00:00<?, ?it/s]



{'eval_loss': 0.025697294622659683, 'eval_precision': 0.694310759711241, 'eval_recall': 0.40658262248056165, 'eval_f1': 0.5128465554726802, 'eval_accuracy': 0.9922165970868224, 'eval_runtime': 172.0331, 'eval_samples_per_second': 46.049, 'eval_steps_per_second': 5.761, 'epoch': 4.0}
{'loss': 0.0164, 'learning_rate': 3.8485101642996384e-06, 'epoch': 4.04}
{'loss': 0.0172, 'learning_rate': 3.29156223893066e-06, 'epoch': 4.18}
{'loss': 0.0172, 'learning_rate': 2.734614313561682e-06, 'epoch': 4.32}
{'loss': 0.0183, 'learning_rate': 2.177666388192704e-06, 'epoch': 4.46}
{'loss': 0.0169, 'learning_rate': 1.620718462823726e-06, 'epoch': 4.59}
{'loss': 0.0183, 'learning_rate': 1.063770537454748e-06, 'epoch': 4.73}
{'loss': 0.0175, 'learning_rate': 5.0682261208577e-07, 'epoch': 4.87}


  0%|          | 0/991 [00:00<?, ?it/s]



{'eval_loss': 0.025351131334900856, 'eval_precision': 0.6678882452614572, 'eval_recall': 0.42648650008807026, 'eval_f1': 0.5205626708436991, 'eval_accuracy': 0.9921191340520964, 'eval_runtime': 171.9237, 'eval_samples_per_second': 46.079, 'eval_steps_per_second': 5.764, 'epoch': 5.0}
{'train_runtime': 6568.8904, 'train_samples_per_second': 21.863, 'train_steps_per_second': 2.733, 'train_loss': 0.02284046025436856, 'epoch': 5.0}


TrainOutput(global_step=17955, training_loss=0.02284046025436856, metrics={'train_runtime': 6568.8904, 'train_samples_per_second': 21.863, 'train_steps_per_second': 2.733, 'train_loss': 0.02284046025436856, 'epoch': 5.0})

In [28]:
import wandb
wandb.finish()
trainer.save_model("bert_without_ner/models/bert_without_ner_epoch_5_loss2")

0,1
eval/accuracy,▃▂▁█▇
eval/f1,▁▃█▆▇
eval/loss,█▃▇▃▁
eval/precision,▇▂▁█▅
eval/recall,▁▄█▄▆
eval/runtime,█▁▇▆▁
eval/samples_per_second,▁█▂▃█
eval/steps_per_second,▁█▁▃▇
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
eval/accuracy,0.99212
eval/f1,0.52056
eval/loss,0.02535
eval/precision,0.66789
eval/recall,0.42649
eval/runtime,171.9237
eval/samples_per_second,46.079
eval/steps_per_second,5.764
train/epoch,5.0
train/global_step,17955.0


# Inference

In [5]:
import pandas as pd
import torch
import re
from tqdm.notebook import trange, tqdm
from torch import nn
from labels import get_labels
from relations import relations
from datasets import DatasetDict, Dataset

from data_preprocessing import make_bert_re_data
from data_preprocessing import bert_w_ner_preprocess_function
additional_tokens, labels, id2label, label2id = get_labels(mode='bert_without_ner')

In [6]:
# model and tokenizer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("bert_without_ner/models/bert_without_ner_epoch_5_loss2",
                                                        num_labels=len(labels),
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        )

tokenizer = AutoTokenizer.from_pretrained("bert_without_ner/bert_without_ner_tokenizer")

In [7]:
# load test data and preprocess

# tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["inputs", "outputs", "pmids"])
# tokenized_datasets = dataset.map(lambda example: bert_w_ner_preprocess_function(example, tokenizer, mode="bert_without_ner"), batched=True, remove_columns=["input_texts", "input_relations", "outputs", "pmids"])
test_file_path = 'data/BioRED/processed/test.tsv'
test_data = make_bert_re_data(file_path=test_file_path, lower=True, no_ner_input=True)

test_dataset_raw = Dataset.from_dict(test_data)
# test_dataset = test_dataset_raw.map(NER_preprocess_function, batched=False)
# with bert only:
test_dataset = test_dataset_raw.map(lambda example: bert_w_ner_preprocess_function(example, tokenizer, mode="bert_without_ner"), batched=True, remove_columns=["input_texts", "input_relations", "outputs", "pmids"])
test_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

Map:   0%|          | 0/8014 [00:00<?, ? examples/s]

In [13]:
import numpy as np
import evaluate

metric = evaluate.load("seqeval")

def compute_metrics_infer(logits, labels):
    predictions = [np.argmax(logit, axis=-1)[0] for logit in logits]

    # only consider non-padding tokens [:4]
    true_labels = [[id2label[l.item()] for l in label if l != label2id['[pad]']] for label in labels]
    true_predictions = [[id2label[p.item()] for (p, l) in zip(prediction, label)]for prediction, label in zip(predictions, true_labels)]
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels, suffix=True)
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }, all_metrics

In [42]:
from tqdm import tqdm
model.eval()
model.to("cuda")
output = []

with torch.no_grad():
    for input_line in tqdm(test_dataset['input_ids']):
    # for n in range(1):
        torch.cuda.empty_cache()
        out = model(input_ids=input_line.unsqueeze(0).to("cuda"))
        # print(f"{n+1} / {len(test_dataset)}")
        output.append(out[0].to("cpu"))
        # output.append(torch.argmax(out[0], dim=-1).squeeze(0))
        # output[-1].to("cpu")
    # print([tag_to_NER_id[i.item()] for i in output[-1]])

100%|██████████| 60720/60720 [14:50<00:00, 68.15it/s]


In [8]:
output = torch.load('bert_without_ner/results/bert_without_ner_epoch_5_loss2_raw.pt')

In [23]:
compute_metrics_infer(output, test_dataset['labels'])

  _warn_prf(average, modifier, msg_start, len(result))


({'precision': 0.5722459270752521,
  'recall': 0.20179878257301143,
  'f1': 0.2983769024624564,
  'accuracy': 0.9890279785690237},
 {'[out': {'precision': 0.6857640906449739,
   'recall': 0.3190247067091961,
   'f1': 0.4354660172680983,
   'number': 18497},
  '[pad': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 0},
  '[src': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 6818},
  '[tgt': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 3927},
  'overall_precision': 0.5722459270752521,
  'overall_recall': 0.20179878257301143,
  'overall_f1': 0.2983769024624564,
  'overall_accuracy': 0.9890279785690237})

In [11]:
import numpy as np
import evaluate

# metric = evaluate.load("seqeval")

def compute_metrics_infer_doclevel(logits, labels):
    predictions = [np.argmax(logit, axis=-1)[0] for logit in logits]


    all_doc = 0
    correct_src = 0
    correct_tgt = 0
    all_doc_w_src_tgt = 0
    pred_src_doc = 0
    pred_tgt_doc = 0
    correct_doc = 0
    # only consider the src and tgt tags
    for prediction, label in zip(predictions, labels):
        all_doc += 1
        temp_dict = {'src': [],
                        'tgt': []}
        doc_has_relation = False
        pred_has = {'src': False,
                        'tgt': False}
        for pred, lab in zip(prediction, label):
            if lab.item() == label2id["[pad]"]:
                break
            else:
                if lab.item() != label2id["[out]"]:
                    doc_has_relation = True
                if pred.item() != label2id["[pad]"] and pred.item() != label2id["[out]"]:
                    pred_has[f'{id2label[pred.item()][1:4]}'] = True

                    if pred.item() == lab.item():
                        temp_dict[f'{id2label[pred.item()][1:4]}'].append(1)

                    else:
                        temp_dict[f'{id2label[pred.item()][1:4]}'].append(0)
        # print(temp_dict)
        if doc_has_relation:
            all_doc_w_src_tgt += 1
        
        if pred_has['src']:
            pred_src_doc += 1
        if pred_has['tgt']:
            pred_tgt_doc += 1

        if 0 in temp_dict['src'] or len(temp_dict['src']) == 0:
            if 0 in temp_dict['tgt'] or len(temp_dict['tgt']) == 0:
                pass
            else:
                correct_tgt += 1
        else:
            correct_src += 1
            if 0 in temp_dict['tgt'] or len(temp_dict['tgt']) == 0:
                pass
            else:
                correct_tgt += 1
                correct_doc == 1
                    

    return {'src precision': correct_src / pred_src_doc,
            'src recall': correct_src / all_doc_w_src_tgt,
            'src f1': 2 * correct_src / (pred_src_doc + all_doc_w_src_tgt),
            'tgt precision': correct_tgt / pred_tgt_doc,
            'tgt recall': correct_tgt / all_doc_w_src_tgt,
            'tgt f1': 2 * correct_tgt / (pred_tgt_doc + all_doc_w_src_tgt),
            'doc precision': correct_doc / all_doc,
            }

In [12]:
compute_metrics_infer_doclevel(output, test_dataset['labels'])

{'src precision': 0.3069306930693069,
 'src recall': 0.026655202063628546,
 'src f1': 0.0490506329113924,
 'tgt precision': 0.3157894736842105,
 'tgt recall': 0.015477214101461736,
 'tgt f1': 0.029508196721311476,
 'doc precision': 0.0}

In [55]:
def compute_metrics_infer_relation(logits, labels):
    predictions = [np.argmax(logit, axis=-1)[0] for logit in logits]


    all_text = 0
    correct_relation = 0
    
    precision_correct_relation = 0
    # only consider the src and tgt tags
    for prediction, label in zip(predictions, labels):
        all_text += 1
        if 1 in label:
            if 1 in prediction:
                correct_relation += 1
        else:
            if 1 not in prediction and 2 not in prediction and 3 not in prediction and 4 not in prediction:
                correct_relation += 1

        if 1 in prediction or 2 in prediction or 3 in prediction or 4 in prediction:
            if 1 in label:
                precision_correct_relation += 1
        else:
            if 1 not in label:
                precision_correct_relation += 1
                    

    return {
        'relation precision': precision_correct_relation / all_text,
        'relation recall': correct_relation / all_text,
        'relation f1': 2 * precision_correct_relation * correct_relation / (precision_correct_relation + correct_relation) / all_text}

In [56]:
compute_metrics_infer_relation(output, test_dataset['labels'])

{'relation precision': 0.8694784127776392,
 'relation recall': 0.8659845270776142,
 'relation f1': 0.8677279529302294}