In [None]:
import os
import sys

# Make the parent of the current working directory (normally the notebook's
# folder) importable so the project-local `data` and `utils` packages resolve.
ROOT_DIR = os.path.abspath(os.pardir)
# Guard so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Silence wandb's console output.
os.environ["WANDB_SILENT"] = "true"

In [None]:
from datasets import Dataset, DatasetDict
from scipy.special import softmax
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import wandb

from data.dataloader import NoReCDataLoader
from utils.utils import init_run

In [None]:
# Pretrained checkpoint to fine-tune. Alternatives that can be swapped in:
#   NorBERT 3:     "ltg/norbert3-small", "ltg/norbert3-base", "ltg/norbert3-large"
#   NB-BERT:       "NbAiLab/nb-bert-base", "NbAiLab/nb-bert-large"
#   Multilingual:  "bert-base-multilingual-cased"
MODEL_NAME = "ltg/norbert3-xs"

# `init_run` is a project helper (utils.utils); presumably it loads the "bert"
# config and starts the tracked run under this name -- confirm in utils.
config = init_run(config_name="bert", run_name=f"Binary{MODEL_NAME}")

# Loading and processing data

In [None]:
# Load the binary NoReC splits and convert each one to a Hugging Face `Dataset`,
# keeping only the columns the classifier needs.
train_df, val_df, test_df = NoReCDataLoader(**config.dataloader).load_binary_dataset()

split_frames = {"train": train_df, "val": val_df, "test": test_df}
# `from_pandas` is the documented DataFrame entry point. `preserve_index=False`
# stops a non-default pandas index from being stored as an extra
# `__index_level_0__` column.
norec_dataset = DatasetDict({
    split: Dataset.from_pandas(frame[["text", "label"]], preserve_index=False)
    for split, frame in split_frames.items()
})
norec_dataset

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess_function(examples):
    """Tokenize a batch of reviews into fixed-length model inputs.

    Applied through ``Dataset.map(..., batched=True)``, so ``examples["text"]``
    is a list of strings. Every sequence is truncated and padded to exactly
    ``config.model.max_seq_length`` tokens.
    """
    encode_options = dict(
        truncation=True,
        max_length=config.model.max_seq_length,
        padding="max_length",
    )
    return tokenizer(examples["text"], **encode_options)

norec_dataset = norec_dataset.map(preprocess_function, batched=True)

# Modeling and Training

In [None]:
# Binary label mapping, passed to the model config so outputs carry readable
# class names. The reverse map is derived so the two can never disagree.
id2label = {0: "negative", 1: "positive"}
label2id = {name: idx for idx, name in id2label.items()}

# trust_remote_code: presumably required because the ltg/norbert3-* checkpoints
# ship custom modelling code -- confirm when switching MODEL_NAME.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    trust_remote_code=True,
)

In [None]:
def compute_metrics(pred):
    """Compute evaluation metrics for the binary sentiment classifier.

    Called by ``Trainer`` at each evaluation.

    Args:
        pred: object exposing ``predictions`` (the logits, or a tuple whose
            first element is the logits) and ``label_ids`` (gold labels).

    Returns:
        dict with:
            ``accuracy``: accuracy of the argmax class predictions.
            ``f1_score``: F1 of the argmax class predictions.
            ``auc``: ROC AUC of the positive-class probability.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Trainer may hand back a tuple when the model returns more than the
    # logits; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    preds = logits.argmax(-1)
    # ROC AUC needs a continuous score. On hard 0/1 predictions it collapses to
    # balanced accuracy and ignores ranking. Column 1 is "positive" per id2label.
    positive_scores = softmax(logits, axis=-1)[:, 1]

    return {
        'accuracy': accuracy_score(labels, preds),
        'auc': roc_auc_score(labels, positive_scores),
        'f1_score': f1_score(labels, preds),
    }

In [None]:
# Log, evaluate and checkpoint on one shared cadence. save_steps defaults to
# 500, so with eval every 250 steps half of the evaluated steps would have no
# checkpoint for load_best_model_at_end / early stopping to restore.
EVAL_STEPS = 250

training_args = TrainingArguments(
    num_train_epochs=config.general.max_epochs,
    per_device_train_batch_size=config.general.batch_size,
    per_device_eval_batch_size=config.general.batch_size,
    weight_decay=0.01,
    logging_steps=EVAL_STEPS,
    output_dir=config.general.log_dir,
    fp16=False,
    seed=config.general.seed,
    data_seed=config.general.data_seed,
    report_to="wandb",
    # Renamed to `eval_strategy` in newer transformers releases.
    evaluation_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_strategy="steps",
    save_steps=EVAL_STEPS,
    # Keep disk use bounded now that we checkpoint twice as often; the best
    # checkpoint is always retained when load_best_model_at_end is set.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="auc",
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=norec_dataset["train"],
    eval_dataset=norec_dataset["val"],
    compute_metrics=compute_metrics,
    # Stop after 3 consecutive evaluations without improvement in "auc".
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()
trainer.evaluate()

# Testing

In [None]:
# Run the fine-tuned model (best checkpoint, via load_best_model_at_end) on the
# held-out test split.
predictions = trainer.predict(norec_dataset["test"])

test_logits = predictions.predictions
# May be a tuple when the model returns more than the logits; logits come first.
if isinstance(test_logits, tuple):
    test_logits = test_logits[0]

# Gold labels as collected by the Trainer, row-aligned with the logits.
y_test = predictions.label_ids
y_preds = test_logits.argmax(-1)
# AUC is computed from P(positive) (column 1 per id2label), not from the hard
# argmax labels, which would reduce it to balanced accuracy.
y_scores = softmax(test_logits, axis=-1)[:, 1]

auc = roc_auc_score(y_test, y_scores)
accuracy = accuracy_score(y_test, y_preds)
f1 = f1_score(y_test, y_preds)

In [None]:
# Record the test-set metrics on the run summary, log a confusion matrix, and
# close the run.
wandb.run.summary['test_auc'] = auc
wandb.run.summary['test_accuracy'] = accuracy
wandb.run.summary['test_f1'] = f1
# y_preds / y_test are NumPy arrays or lists, which have no `.numpy()` method
# (that is a torch.Tensor API). Pass plain Python ints instead.
wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
    preds=[int(p) for p in y_preds],
    y_true=[int(t) for t in y_test],
    class_names=["negative", "positive"],
)})
wandb.finish()