In [4]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.nn import Linear, Tanh, Dropout, MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
from transformers import (
    AutoTokenizer,
    BertPreTrainedModel,
    BertModel,
    RobertaPreTrainedModel,
    RobertaModel,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_recall_fscore_support,
)


In [5]:
# Human-readable names of the eight relation classes, listed in alphabetical
# order. They are passed as `target_names` to `classification_report`, so
# position i in this list names integer label i.
# NOTE(review): the integer labels come from a LabelEncoder fitted on column
# "16"; this presumably yields the same alphabetical order — confirm that the
# column contains exactly these eight strings.
CLASSES = [
    "Directed Link",
    "Negative Cause",
    "Negative Decrease",
    "Negative Increase",
    "Positive Cause",
    "Positive Decrease",
    "Positive Increase",
    "Undirected Link",
]


class BertSrcClassifier(BertPreTrainedModel):
    """BERT classifier that pools the hidden states found at mask-token positions.

    The pooled representation depends on ``num_token_layers``:

    * ``1`` - the encoder's pooled output (``outputs[1]``) is classified
      directly; ``dense`` and ``activation`` are not used.
    * ``2`` - the hidden states at the two positions whose input id equals
      ``mask_token_id`` are concatenated, then passed through ``dense`` + tanh.
    * ``3`` - like ``2``, but the hidden state at position 0 is included too.

    Args:
        config: model configuration; ``num_labels``, ``hidden_size`` and the
            dropout settings are read from it.
        mask_token_id: vocabulary id that is compared against ``input_ids`` to
            locate the positions to pool.
        num_token_layers: number of token representations concatenated per
            example (1, 2 or 3).

    Raises:
        ValueError: if ``num_token_layers`` is outside ``1..3``.
    """

    def __init__(self, config, mask_token_id, num_token_layers=2):
        super().__init__(config)
        # Validate eagerly with a real exception: an `assert` inside forward()
        # disappears under `python -O` and only fires after the encoder ran.
        if not 1 <= num_token_layers <= 3:
            raise ValueError(
                f"num_token_layers must be 1, 2 or 3, got {num_token_layers}"
            )
        self.mask_token_id = mask_token_id
        self.num_token_layers = num_token_layers
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        # Projects the concatenated token representations back to hidden_size.
        self.dense = Linear(config.hidden_size * num_token_layers, config.hidden_size)
        self.activation = Tanh()
        self.dropout = Dropout(classifier_dropout)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        head_mask: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        labels: torch.Tensor = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
    ):
        """Encode the inputs, pool the selected positions and classify them.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` resolves to
            true; otherwise a tuple ``(loss?, logits, *encoder_extras)`` where
            ``loss`` is present only if ``labels`` was given.

        Raises:
            ValueError: if ``num_token_layers`` is outside ``1..3``, if
                ``input_ids`` is missing while mask positions are needed, or
                if any example does not contain exactly the expected number of
                pooled positions.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Re-checked here because the attribute can be reassigned after __init__.
        if not 1 <= self.num_token_layers <= 3:
            raise ValueError(
                f"num_token_layers must be 1, 2 or 3, got {self.num_token_layers}"
            )
        pool_mask_tokens = self.num_token_layers > 1
        if pool_mask_tokens and input_ids is None:
            # Mask positions are located by comparing ids, so embeddings alone
            # are not enough; fail before spending time in the encoder.
            raise ValueError(
                "input_ids is required when num_token_layers > 1 "
                "(mask positions are located via input_ids)"
            )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not pool_mask_tokens:
            output = outputs[1]
        else:
            check = input_ids == self.mask_token_id
            if self.num_token_layers == 3:
                # Also pool the first position of every sequence.
                check[:, 0] = True

            # Boolean indexing flattens the batch, so the reshape below is only
            # correct if every row selects the same number of positions.
            # Without this check, rows with e.g. 1 and 3 matches would be
            # silently merged into two wrong examples.
            selected_per_row = check.sum(dim=1)
            if not bool((selected_per_row == self.num_token_layers).all()):
                raise ValueError(
                    f"every example must select exactly {self.num_token_layers} "
                    f"positions, got per-example counts {selected_per_row.tolist()}"
                )

            output = torch.reshape(
                outputs[0][check], (-1, self.num_token_layers * self.config.hidden_size)
            )
            output = self.dense(output)
            output = self.activation(output)

        output = self.dropout(output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def compute_metrics(pred):
    """Compute accuracy and macro precision/recall/F1, and print a per-class report.

    Args:
        pred: prediction object exposing ``label_ids`` (true integer labels)
            and ``predictions`` (per-class scores; the argmax over the last
            axis is taken as the predicted label).

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro"
    )
    acc = accuracy_score(labels, preds)
    # Pass the full label set explicitly: without it the report raises when a
    # class is absent from both labels and predictions, because the number of
    # observed classes then no longer matches len(CLASSES).
    print(
        classification_report(
            labels,
            preds,
            labels=list(range(len(CLASSES))),
            target_names=CLASSES,
            digits=3,
            zero_division=0,
        )
    )
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}


class SrcDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing column-wise encodings with per-example labels.

    Args:
        encodings: mapping from field name to an indexable container holding
            one entry per example.
        labels: indexable container of labels; its length defines the dataset
            size.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        """Return the number of examples (the number of labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of every encoding field plus ``"labels"``."""
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = column[idx]
        sample["labels"] = self.labels[idx]
        return sample


In [7]:
# Read the three CSV splits that ship with the project.
train = pd.read_csv("../data/train.csv")
validation = pd.read_csv("../data/validation.csv")
test = pd.read_csv("../data/test.csv")

# Stack the splits into one frame with a fresh 0..n-1 index; the splits are
# re-drawn later by cross-validation.
all_splits = (train, validation, test)
data = pd.concat(all_splits, ignore_index=True)

In [5]:
# Build the two text segments for each example: the text in column "14" with
# the string from column "3" (segment 1) or column "8" (segment 2) replaced by
# the mask token.
# Only the FIRST occurrence is replaced (count=1). The classifier pools a
# fixed number of [MASK] positions per example, so masking every occurrence of
# a repeated string would yield a variable number of masks.
# NOTE(review): this assumes the first occurrence is the annotated mention —
# confirm against the dataset; rows where the string does not occur at all
# still end up with no mask.
x1 = data.apply(lambda row: row["14"].replace(row["3"], "[MASK]", 1), axis=1)
x2 = data.apply(lambda row: row["14"].replace(row["8"], "[MASK]", 1), axis=1)

# Encode the string labels in column "16" as integer class ids.
# NOTE(review): reports elsewhere in this file name classes via CLASSES by
# position, so the encoder's class order must match that list — verify with
# `label_encoder.classes_`.
label_encoder = LabelEncoder()
y = torch.tensor(label_encoder.fit_transform(data["16"]), dtype=torch.long)

In [7]:
# Trainer configuration shared by every cross-validation run below.
# NOTE(review): all runs write to the same output_dir (spelling kept as-is);
# `evaluation_strategy` may be named `eval_strategy` in newer transformers
# releases — confirm against the installed version.
training_args = TrainingArguments(
    # Where checkpoints and logs go, and which tracker receives them.
    output_dir="./chcekpoints",
    logging_dir="./logs",
    report_to="wandb",
    # Evaluate, checkpoint and log once per epoch.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    # Optimisation schedule.
    num_train_epochs=10,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Model selection: reload the checkpoint with the best "f1" metric.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    seed=42,
)

In [8]:
# 5-fold cross-validation of the classifier for each pooling variant
# (num_token_layers = 1, 2, 3). Per-fold metrics and reports are written to
# logs/<num_layer>-layer.

# The checkpoint name and tokenizer depend on neither the fold nor the pooling
# variant, so they are loaded once instead of once per fold.
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = AutoTokenizer.from_pretrained(model_name)

n_splits = 5
# Explicit label ids keep classification_report aligned with CLASSES even when
# a fold contains no example (and no prediction) of some class.
label_ids = list(range(len(CLASSES)))

# Create the results directory before training: otherwise a missing folder
# would only surface in open() after all folds have been trained.
log_dir = Path("logs")
log_dir.mkdir(parents=True, exist_ok=True)

for num_layer in [1, 2, 3]:

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    results_lst = []

    for fold, (train_idx, eval_idx) in enumerate(kf.split(y)):
        print("Starting fold", fold)

        # KFold yields positions, so index the Series positionally.
        train_x1 = x1.iloc[train_idx].to_list()
        train_x2 = x2.iloc[train_idx].to_list()
        train_y = y[train_idx]
        eval_x1 = x1.iloc[eval_idx].to_list()
        eval_x2 = x2.iloc[eval_idx].to_list()
        eval_y = y[eval_idx]

        # A fresh model per fold so no weights leak between folds.
        model = BertSrcClassifier.from_pretrained(
            model_name,
            num_labels=len(CLASSES),
            mask_token_id=tokenizer.mask_token_id,
            num_token_layers=num_layer,
        )

        # First pass: pad to the longest sequence of each split, only to learn
        # the longest (possibly truncated) length across both splits.
        probe_train = tokenizer(
            train_x1, train_x2, return_tensors="pt", padding=True, truncation=True
        )
        probe_val = tokenizer(
            eval_x1, eval_x2, return_tensors="pt", padding=True, truncation=True
        )
        max_sequence_length = max(
            probe_train["input_ids"].shape[1], probe_val["input_ids"].shape[1]
        )

        # Second pass: pad both splits to that common length. truncation=True
        # mirrors the first pass; without it an over-length example would stay
        # longer than max_length and tensor conversion would fail on ragged rows.
        encode_kwargs = dict(
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
        )
        tokenized_train = tokenizer(train_x1, train_x2, **encode_kwargs)
        tokenized_val = tokenizer(eval_x1, eval_x2, **encode_kwargs)
        train_dataset = SrcDataset(tokenized_train, train_y)
        eval_dataset = SrcDataset(tokenized_val, eval_y)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        preds = trainer.predict(eval_dataset)
        report = classification_report(
            eval_y,
            preds.predictions.argmax(axis=1),
            labels=label_ids,
            target_names=CLASSES,
            digits=4,
            zero_division=0,
        )
        results_lst.append((preds.metrics, report))

    # One results file per pooling variant; iterate over what was collected
    # rather than a second hard-coded fold count.
    with open(log_dir / f"{num_layer}-layer", "w", encoding="utf-8") as f:
        for fold, (metrics, report) in enumerate(results_lst):
            f.write(f"fold {fold}\n")
            f.write(f"accuracy: {metrics['test_accuracy']}\n")
            f.write(f"precision: {metrics['test_precision']}\n")
            f.write(f"recall: {metrics['test_recall']}\n")
            f.write(f"f1: {metrics['test_f1']}\n")
            f.write(report + "\n\n")
    

Starting fold 0
