In [None]:
# Standard library first, then third-party packages.
import os
import pickle

import numpy as np

# Global seed value.
# NOTE(review): `seed` is not referenced by any cell visible here; the
# fine-tuning runs take their seeds from the `seeds` list instead. Confirm
# whether this is still needed.
seed = 2023

In [None]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
import transformers

# Only report critical errors to avoid excessive logging.
# Use the named CRITICAL level (numerically 50, as in stdlib `logging`)
# rather than a magic number.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [None]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [None]:
# Directory that collects the results CSVs written by the fine-tuning runs below.
output_dir = "client_talk_type_output"
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [None]:
# Load the AnnoMI data by running the shared loader script in this namespace.
# NOTE(review): later cells use anno_mi, label_to_id_client, id_to_label_client,
# output_dim_client, client_index and client_transcript_id without defining
# them, so this script presumably defines all of them. Confirm.
%run ../load_anno_mi.py

In [None]:
# Quick look at the loaded data (assumes anno_mi is a pandas DataFrame).
anno_mi.head()

In [None]:
# Load the pickled SBERT embeddings. The file name suggests one embedding per
# AnnoMI utterance (TODO confirm).
# pickle.load can execute arbitrary code, so only point this at a file
# produced by this project.
with open("../anno_mi_sbert.pkl", mode="rb") as pkl_file:
    sbert_embeddings = pickle.load(pkl_file)

sbert_embeddings.shape

## Baseline: Fine-tune BERT for classification

In [None]:
# Search grid shared by every fine-tuning run below.
learning_rates = [5e-5, 1e-5, 1e-6]
seeds = [1, 12, 123]
num_epochs = 5
# Metric name passed to the fine-tuning helper (presumably used for
# validation-based selection).
validation_metric = "f1"

In [None]:
# Show the label -> id mapping used for the client talk-type column
# (presumably set up by ../load_anno_mi.py).
label_to_id_client

In [None]:
# Show the inverse id -> label mapping (presumably set up by ../load_anno_mi.py).
id_to_label_client

In [None]:
# Arguments shared by both fine-tuning runs below. Each run adds its own
# `loss`, `gamma` and `results_output`.
# NOTE(review): batch_size is hard-coded here, while the other
# hyperparameters live in the config cell.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",
    label_column="client_talk_type",
    label_to_id=label_to_id_client,
    id_to_label=id_to_label_client,
    output_dim=output_dim_client,
    learning_rates=learning_rates,
    seeds=seeds,
    device=device,
    batch_size=8,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    validation_metric=validation_metric,
    verbose=False,
)

## Focal Loss

In [None]:
# Focal loss with focusing parameter gamma = 2.
loss, gamma = "focal", 2

In [None]:
# Fine-tune BERT with the focal-loss settings above. `results_output` is
# presumably where the helper saves its results table.
focal_results_path = f"{output_dir}/bert_classifier_focal.csv"
bert_classifier = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=focal_results_path,
    **kwargs,
)

In [None]:
# Results returned by the fine-tuning helper for the focal-loss runs.
# NOTE(review): if this table has one row per (learning rate, seed) run, the
# column means below average over every learning rate rather than only the
# selected one. Confirm against fine_tune_transformer_average_seed.
bert_classifier

In [None]:
# Mean of the "f1" column across all rows of the results.
bert_classifier["f1"].mean()

In [None]:
# Mean of the "precision" column across all rows of the results.
bert_classifier["precision"].mean()

In [None]:
# Mean of the "recall" column across all rows of the results.
bert_classifier["recall"].mean()

In [None]:
# Element-wise mean of the per-run "f1_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

In [None]:
# Element-wise mean of the per-run "precision_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

In [None]:
# Element-wise mean of the per-run "recall_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

## Cross-Entropy Loss

In [None]:
# Plain cross-entropy. gamma parameterises the focal loss only, so it is
# passed as None here (presumably ignored by the helper).
loss, gamma = "cross_entropy", None

In [None]:
# Fine-tune BERT with the cross-entropy settings above, for comparison with
# the focal-loss run. `results_output` is presumably where the helper saves
# its results table.
ce_results_path = f"{output_dir}/bert_classifier_ce.csv"
bert_classifier_ce = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=ce_results_path,
    **kwargs,
)

In [None]:
# Results returned by the fine-tuning helper for the cross-entropy runs.
# NOTE(review): if this table has one row per (learning rate, seed) run, the
# column means below average over every learning rate rather than only the
# selected one. Confirm against fine_tune_transformer_average_seed.
bert_classifier_ce

In [None]:
# Mean of the "f1" column across all rows of the results.
bert_classifier_ce["f1"].mean()

In [None]:
# Mean of the "precision" column across all rows of the results.
bert_classifier_ce["precision"].mean()

In [None]:
# Mean of the "recall" column across all rows of the results.
bert_classifier_ce["recall"].mean()

In [None]:
# Element-wise mean of the per-run "f1_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

In [None]:
# Element-wise mean of the per-run "precision_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

In [None]:
# Element-wise mean of the per-run "recall_scores" arrays (presumably one score per class).
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)