In [6]:
import math
import numpy as np
import pandas as pd
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Linear, MSELoss, Tanh
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_recall_fscore_support,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertModel,
    BertPreTrainedModel,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput

In [19]:
class BertSrcClassifier(BertPreTrainedModel):
    """BertSRC classifier: a BERT encoder followed by a small classification head.

    The head (Linear -> Tanh -> Dropout -> Linear) is fed with hidden states
    taken from the encoder's last layer. ``num_token_layer`` selects which
    token positions are used:

    * ``1`` -- the first token of every sequence (typically ``[CLS]``);
    * ``2`` -- the positions whose input id equals ``mask_token_id``
      (exactly two per sequence), concatenated in sequence order;
    * ``3`` -- the first token plus the two mask positions, concatenated
      in sequence order.

    Args:
        config: BERT configuration; ``num_labels``, ``hidden_size`` and the
            dropout probabilities are read from it.
        mask_token_id: Vocabulary id used to locate the marker tokens in
            ``input_ids``.
        num_token_layer: Number of token states concatenated as classifier
            input (1, 2 or 3).

    Raises:
        ValueError: If ``num_token_layer`` is outside ``[1, 3]``.
    """

    def __init__(self, config, mask_token_id: int, num_token_layer: int = 2):
        super().__init__(config)
        # Fail at construction time instead of on the first forward pass.
        if not 1 <= num_token_layer <= 3:
            raise ValueError(
                f"num_token_layer must be 1, 2 or 3, got {num_token_layer!r}"
            )
        self.mask_token_id = mask_token_id
        self.n_output_layer = num_token_layer
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        # The head input is the concatenation of `num_token_layer` hidden states.
        self.dense = Linear(config.hidden_size * num_token_layer, config.hidden_size)
        self.activation = Tanh()
        self.dropout = Dropout(classifier_dropout)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        head_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
    ):
        """Encode the batch, gather the selected token states and classify them.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` resolves to
            true, otherwise a tuple ``(loss?, logits, *extra encoder outputs)``
            where ``loss`` is present only if ``labels`` was given.

        Raises:
            ValueError: If mask positions are needed (``num_token_layer`` > 1)
                but ``input_ids`` is missing, or if some sequence does not
                contain the expected number of selected positions.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First encoder output: last-layer hidden states, indexed (batch, position, hidden).
        sequence_output = outputs[0]

        assert 1 <= self.n_output_layer <= 3
        if self.n_output_layer == 1:
            # First token of EVERY sequence. Indexing with [0] instead would pick
            # the whole first sequence of the batch and drop all the others.
            output = sequence_output[:, 0]
        else:
            if input_ids is None:
                raise ValueError(
                    "input_ids is required to locate the mask tokens when "
                    "num_token_layer > 1"
                )
            check = input_ids == self.mask_token_id
            if self.n_output_layer == 3:
                check[:, 0] = True
            # The reshape below only lines rows up with examples when every
            # sequence contributes the same, expected number of positions.
            selected_per_row = check.sum(dim=1)
            if bool((selected_per_row != self.n_output_layer).any()):
                raise ValueError(
                    "Every sequence must contribute exactly "
                    f"{self.n_output_layer} selected token positions "
                    "(mask tokens, plus the first token when num_token_layer == 3); "
                    "check that truncation did not remove a mask token."
                )
            output = torch.reshape(
                sequence_output[check],
                (-1, self.n_output_layer * self.config.hidden_size),
            )

        output = self.dense(output)
        output = self.activation(output)
        output = self.dropout(output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                # int32 labels are accepted above, but cross-entropy needs
                # int64 class indices, so cast explicitly.
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1).long()
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class BertsrcDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with per-example labels.

    ``encodings`` is a mapping from field name to an indexable collection of
    per-example values; ``labels`` is an indexable collection with one entry
    per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        # Pick example `idx` out of every encoding field, then attach its label.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx]
        sample['labels'] = self.labels[idx]
        return sample

def compute_metrics(pred):
    """Compute macro-averaged metrics and print a per-class report.

    Args:
        pred: Evaluation result exposing ``label_ids`` (gold class indices)
            and ``predictions`` (per-class scores; the argmax over the last
            axis is taken as the predicted class).

    Returns:
        dict with ``accuracy`` and macro-averaged ``f1``, ``precision`` and
        ``recall``.

    Side effects:
        Prints a classification report whose rows are named after the
        classes of the module-level ``label_encoder``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)

    class_names = label_encoder.classes_
    # Pass the full list of class indices explicitly: without `labels=`, the
    # report infers the classes from the data and raises when a split (and the
    # predictions) do not cover every name in `target_names`.
    print(
        classification_report(
            labels,
            preds,
            labels=list(range(len(class_names))),
            digits=3,
            target_names=class_names,
        )
    )
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [8]:
# Column names supplied to read_csv for the three kinds of ChemProt TSV files;
# the same layout is used for the training, development and test splits.
abstract_columns = ["a_id", "title", "abstract"]
entity_columns = ["a_id", "e_id", "type", "start", "end", "text"]
relation_columns = ["a_id", "relation", "e1", "e2"]

# Training split.
train_abstracts = pd.read_csv(
    "ChemProt_Corpus/chemprot_training/chemprot_training_abstracts.tsv",
    sep="\t",
    names=abstract_columns,
)
train_entities = pd.read_csv(
    "ChemProt_Corpus/chemprot_training/chemprot_training_entities.tsv",
    sep="\t",
    names=entity_columns,
)
train_gs = pd.read_csv(
    "ChemProt_Corpus/chemprot_training/chemprot_training_gold_standard.tsv",
    sep="\t",
    names=relation_columns,
)

# Development split.
dev_abstracts = pd.read_csv(
    "ChemProt_Corpus/chemprot_development/chemprot_development_abstracts.tsv",
    sep="\t",
    names=abstract_columns,
)
dev_entities = pd.read_csv(
    "ChemProt_Corpus/chemprot_development/chemprot_development_entities.tsv",
    sep="\t",
    names=entity_columns,
)
dev_gs = pd.read_csv(
    "ChemProt_Corpus/chemprot_development/chemprot_development_gold_standard.tsv",
    sep="\t",
    names=relation_columns,
)

# Test split.
test_abstracts = pd.read_csv(
    "ChemProt_Corpus/chemprot_test_gs/chemprot_test_abstracts_gs.tsv",
    sep="\t",
    names=abstract_columns,
)
test_entities = pd.read_csv(
    "ChemProt_Corpus/chemprot_test_gs/chemprot_test_entities_gs.tsv",
    sep="\t",
    names=entity_columns,
)
test_gs = pd.read_csv(
    "ChemProt_Corpus/chemprot_test_gs/chemprot_test_gold_standard.tsv",
    sep="\t",
    names=relation_columns,
)

In [9]:
train_x = []
train_y = []

# Lookup tables of the form column -> {index key -> value}.
abs_dict = train_abstracts.set_index("a_id").to_dict()
ent_dict = train_entities.set_index(["a_id", "e_id"]).to_dict()
span_start, span_end = ent_dict["start"], ent_dict["end"]

# Walk the gold-standard relations column-wise, one relation per iteration.
for a_id, arg1, arg2, relation in zip(
    train_gs["a_id"], train_gs["e1"], train_gs["e2"], train_gs["relation"]
):
    # Drop the first five characters of each argument (presumably an
    # "Arg1:"/"Arg2:" style prefix) to obtain the entity id.
    first_key = (a_id, arg1[5:])
    second_key = (a_id, arg2[5:])

    text = abs_dict["title"][a_id] + " " + abs_dict["abstract"][a_id]

    # Two copies of the text, each with one of the two entity spans replaced.
    first_view = text[:span_start[first_key]] + "[MASK]" + text[span_end[first_key]:]
    second_view = text[:span_start[second_key]] + "[MASK]" + text[span_end[second_key]:]

    train_x.append(first_view + " [SEP] " + second_view)
    train_y.append(relation)

In [10]:
dev_x = []
dev_y = []

# Lookup tables of the form column -> {index key -> value}.
abs_dict = dev_abstracts.set_index("a_id").to_dict()
ent_dict = dev_entities.set_index(["a_id", "e_id"]).to_dict()
span_start, span_end = ent_dict["start"], ent_dict["end"]

# Walk the gold-standard relations column-wise, one relation per iteration.
for a_id, arg1, arg2, relation in zip(
    dev_gs["a_id"], dev_gs["e1"], dev_gs["e2"], dev_gs["relation"]
):
    # Drop the first five characters of each argument (presumably an
    # "Arg1:"/"Arg2:" style prefix) to obtain the entity id.
    first_key = (a_id, arg1[5:])
    second_key = (a_id, arg2[5:])

    text = abs_dict["title"][a_id] + " " + abs_dict["abstract"][a_id]

    # Two copies of the text, each with one of the two entity spans replaced.
    first_view = text[:span_start[first_key]] + "[MASK]" + text[span_end[first_key]:]
    second_view = text[:span_start[second_key]] + "[MASK]" + text[span_end[second_key]:]

    dev_x.append(first_view + " [SEP] " + second_view)
    dev_y.append(relation)

In [11]:
test_x = []
test_y = []

# Lookup tables of the form column -> {index key -> value}.
abs_dict = test_abstracts.set_index("a_id").to_dict()
ent_dict = test_entities.set_index(["a_id", "e_id"]).to_dict()
span_start, span_end = ent_dict["start"], ent_dict["end"]

# Walk the gold-standard relations column-wise, one relation per iteration.
for a_id, arg1, arg2, relation in zip(
    test_gs["a_id"], test_gs["e1"], test_gs["e2"], test_gs["relation"]
):
    # Drop the first five characters of each argument (presumably an
    # "Arg1:"/"Arg2:" style prefix) to obtain the entity id.
    first_key = (a_id, arg1[5:])
    second_key = (a_id, arg2[5:])

    text = abs_dict["title"][a_id] + " " + abs_dict["abstract"][a_id]

    # Two copies of the text, each with one of the two entity spans replaced.
    first_view = text[:span_start[first_key]] + "[MASK]" + text[span_end[first_key]:]
    second_view = text[:span_start[second_key]] + "[MASK]" + text[span_end[second_key]:]

    test_x.append(first_view + " [SEP] " + second_view)
    test_y.append(relation)

In [12]:
# Fit the label vocabulary on the training relations only; the dev and test
# labels are mapped with the same encoder.
label_encoder = LabelEncoder()
label_encoder.fit(train_y)

# Class indices are stored as int64 (torch.long), the dtype cross-entropy
# expects for targets. They are kept on the CPU: the Trainer moves each batch
# to the model's device, so hard-coding "cuda" here is unnecessary and would
# fail on a machine without a GPU.
# NOTE: this cell rebinds train_y/dev_y/test_y from lists of strings to
# tensors, so it is meant to run once after the preprocessing cells.
train_y = torch.tensor(label_encoder.transform(train_y), dtype=torch.long)
dev_y = torch.tensor(label_encoder.transform(dev_y), dtype=torch.long)
test_y = torch.tensor(label_encoder.transform(test_y), dtype=torch.long)

In [14]:
# Single source of truth for the checkpoint shared by the tokenizer and the model.
pretrained_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

tokenizer = AutoTokenizer.from_pretrained(pretrained_name)

# One output label per relation class seen in training; the classifier locates
# its input positions via the tokenizer's mask token id.
model = BertSrcClassifier.from_pretrained(
    pretrained_name,
    num_labels=len(label_encoder.classes_),
    mask_token_id=tokenizer.mask_token_id,
    num_token_layer=2,
)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertSrcClassifier: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertSrcClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertSrcClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertSrcClassifier were not initialized 

In [15]:
# Identical tokenization settings for every split.
# NOTE(review): each input holds two copies of the text, so truncating at 512
# tokens may cut off the second "[MASK]"; the classifier gathers hidden states
# at mask positions -- confirm that every encoded example still contains both.
tokenize_options = {"truncation": True, "max_length": 512}

encodings_train = tokenizer(train_x, **tokenize_options)
encodings_dev = tokenizer(dev_x, **tokenize_options)
encodings_test = tokenizer(test_x, **tokenize_options)

In [18]:
# Wrap each split's encodings together with its encoded labels.
train_dataset = BertsrcDataset(encodings=encodings_train, labels=train_y)
dev_dataset = BertsrcDataset(encodings=encodings_dev, labels=dev_y)
test_dataset = BertsrcDataset(encodings=encodings_test, labels=test_y)

In [14]:
# Collator handed to the Trainer as `data_collator`. The tokenizer calls in this
# notebook do not request padding, so batching variable-length examples is left
# to this collator (presumably per-batch padding -- see the transformers docs).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [15]:
# Evaluate, checkpoint and log once per epoch; the checkpoint with the best
# "f1" (as reported by compute_metrics) is reloaded when training ends.
cadence_options = dict(
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

# Optimizer and learning-rate schedule settings.
optimizer_options = dict(
    learning_rate=5e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    max_grad_norm=1,
    lr_scheduler_type="linear",
)

training_args = TrainingArguments(
    output_dir="./checkpoints",
    logging_dir="./logs",
    num_train_epochs=30,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    seed=42,
    **cadence_options,
    **optimizer_options,
)

# Train on the training split and evaluate on the development split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


In [None]:
# Run fine-tuning with the arguments configured above (30 epochs, with
# evaluation, checkpointing and logging set to once per epoch).
trainer.train()

In [32]:
# No dataset is passed here, so this presumably evaluates the `eval_dataset`
# the Trainer was built with (dev_dataset). compute_metrics prints the
# per-class report as a side effect.
# NOTE(review): `pred` is reassigned by the later test-set prediction cell with
# a different kind of result -- consider a distinct name.
pred = trainer.evaluate()

***** Running Prediction *****
  Num examples = 2416
  Batch size = 256
  item['labels'] = torch.tensor(self.labels[idx])


              precision    recall  f1-score   support

       CPR:3      0.886     0.878     0.882       550
       CPR:4      0.934     0.921     0.928      1094
       CPR:5      0.850     0.931     0.889       116
       CPR:6      0.926     0.950     0.938       199
       CPR:9      0.954     0.963     0.959       457

    accuracy                          0.922      2416
   macro avg      0.910     0.929     0.919      2416
weighted avg      0.922     0.922     0.922      2416



In [33]:
# Run the model over the test split; compute_metrics prints the per-class
# report as a side effect. This overwrites the `pred` bound by the evaluation
# cell above.
pred = trainer.predict(test_dataset)

***** Running Prediction *****
  Num examples = 3458
  Batch size = 256
  item['labels'] = torch.tensor(self.labels[idx])


              precision    recall  f1-score   support

       CPR:3      0.890     0.812     0.849       665
       CPR:4      0.929     0.945     0.937      1661
       CPR:5      0.848     0.918     0.882       195
       CPR:6      0.896     0.915     0.905       293
       CPR:9      0.948     0.958     0.953       644

    accuracy                          0.918      3458
   macro avg      0.902     0.910     0.905      3458
weighted avg      0.918     0.918     0.917      3458

