In [1]:
import os
import pickle

import numpy as np

# Global seed, passed to `text_encoder.set_up_training_args` further down.
seed = 2023

In [2]:
import torch

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [3]:
import logging

import transformers

# Only report critical messages from transformers, to avoid excessive logging.
# logging.CRITICAL == 50, the level previously passed as a bare number.
transformers.utils.logging.set_verbosity(logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Results CSVs and Trainer checkpoints are written here. exist_ok=True makes the
# cell idempotent and avoids the check-then-create race of isdir() + makedirs().
output_dir = "client_talk_type_output"
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

Load the AnnoMI transcripts and the precomputed SBERT embeddings (`../anno_mi_sbert.pkl`).

In [6]:
# Load the AnnoMI data. This script is expected to define `anno_mi`,
# `client_index` and `client_transcript_id`, which later cells use but never define.
%run ../load_anno_mi.py

In [7]:
# Peek at the first five rows of the loaded transcripts.
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-11-01 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-11-01 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-11-01 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-11-01 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-11-01 00:00:34,-1


In [8]:
# Precomputed SBERT embeddings for AnnoMI (presumably one row per utterance — TODO confirm).
# NOTE: pickle.load can execute arbitrary code; only load pickle files you trust.
with open("../anno_mi_sbert.pkl", mode="rb") as f:
    sbert_embeddings = pickle.load(f)

sbert_embeddings.shape

(9699, 384)

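# Baseline: Fine-tune BERT for classification

Fine-tune `bert-base-uncased` to predict the `client_talk_type` of each client utterance from its text. Splits are transcript-grouped k-fold, and results are averaged over several random seeds.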

In [9]:
# Training configuration shared by both loss experiments below.
num_epochs, seeds, validation_metric = 10, [1, 12, 123], "f1"

In [10]:
# Arguments shared by every `fine_tune_transformer_average_seed` call; only the
# loss settings and the results file differ between the runs below.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",
    label_column="client_talk_type",
    seeds=seeds,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    validation_metric=validation_metric,
    device=device,
    verbose=False,
)

## Focal Loss

First, fine-tune with focal loss ($\gamma = 2$).

In [11]:
# Focal-loss run; gamma is forwarded to FocalLoss.
loss, gamma = "focal", 2

In [12]:
from __future__ import annotations

import os
import shutil
from typing import Iterable

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets.arrow_dataset import Dataset
from nlpsig import TextEncoder
from nlpsig.classification_utils import DataSplits, Folds
from sklearn import metrics
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from nlpsig_networks.focal_loss import FocalLoss
from nlpsig_networks.pytorch_utils import set_seed

# Re-create, step by step, the data preparation that
# `fine_tune_transformer_average_seed` presumably performs internally, so the
# intermediate objects can be inspected while debugging the KeyError below.
# TODO confirm these steps still mirror nlpsig_networks' fine-tune script.
df = kwargs["df"]
feature_name = kwargs["feature_name"]
label_column = kwargs["label_column"]
pretrained_model_name = kwargs["pretrained_model_name"]

# Keep only the rows selected by `path_indices` (positional) and re-index them 0..n-1.
df = df.iloc[kwargs["path_indices"]].reset_index(drop=True)
split_ids = kwargs["split_ids"]

# Map each label (stringified) to an integer id, in order of first appearance.
y_data = df[label_column]
label_to_id = {str(label): idx for idx, label in enumerate(y_data.unique())}
id_to_label = {idx: label for label, idx in label_to_id.items()}
output_dim = len(label_to_id)

if loss == "focal":
    criterion = FocalLoss(gamma=gamma)
    # NOTE(review): alpha is set from the labels of *all* selected rows, including
    # those that later land in the test fold. Confirm this matches the script;
    # using only the training-fold labels would avoid leaking test statistics.
    y_train = torch.tensor(y_data.apply(lambda x: label_to_id[str(x)]).values)
    criterion.set_alpha_from_y(y=y_train)
elif loss == "cross_entropy":
    criterion = torch.nn.CrossEntropyLoss()
else:
    raise ValueError("loss must be either 'focal' or 'cross_entropy'")

# Integer class ids used as training targets.
# NOTE(review): likely cause of the `KeyError: 'labels'` raised further down.
# The Hugging Face Trainer drops dataset columns that are not model forward()
# arguments (remove_unused_columns defaults to True). DataCollatorWithPadding only
# renames a `label`/`label_ids` column to `labels`, so `label_as_id` may never
# reach the loss. TODO confirm; naming this column "label" may fix it.
df["label_as_id"] = df[label_column].apply(lambda x: label_to_id[str(x)])

# Initialise the classification model with the label <-> id mappings.
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name,
    num_labels=output_dim,
    id2label=id_to_label,
    label2id=label_to_id,
)

# Folds only needs the number of samples, so x/y are placeholder tensors.
# Grouping by `split_ids` presumably keeps each transcript within one fold.
datasize = len(df.index)
dummy_data = torch.ones(datasize)
folds = Folds(
    x_data=dummy_data,
    y_data=dummy_data,
    groups=split_ids,
    n_splits=5,
    shuffle=True,
    random_state=0,
)

# Set up tokenizer and data collator.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

text_encoder = TextEncoder(
    df=df,
    feature_name=feature_name,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    verbose=False,
)

# Tokenize the text in df[feature_name].
text_encoder.tokenize_text()

# Split the dataset using the indices of the first fold.
text_encoder.split_dataset(indices=folds.fold_indices[0])

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'speaker', 'label_as_id', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
        num_rows: 2631
    })
    test: Dataset({
        features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'speaker', 'label_as_id', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
        num_rows: 889
    })
  

In [13]:
# Set up training arguments: checkpoint every epoch and reload the best checkpoint at the end.
# NOTE(review): `metric_for_best_model` is not passed here, so "best" is
# presumably judged by the Trainer's default rather than `validation_metric` — confirm.
text_encoder.set_up_training_args(
    output_dir=output_dir,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=8,
    disable_tqdm=False,
    save_strategy="epoch",
    load_best_model_at_end=True,
    seed=seed,
)

# Load the metric modules once, instead of on every evaluation call.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")


def _compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for one Trainer evaluation.

    Parameters
    ----------
    eval_pred :
        Presumably a ``transformers.EvalPrediction``. ``predictions`` holds the
        model scores (at least 2-D; classes along axis 1) and ``label_ids`` holds
        the true class ids.

    Returns
    -------
    dict
        Mapping with keys ``"accuracy"`` and ``"f1"`` (macro average).
    """
    # Predicted class = index of the highest score along the class axis.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    accuracy = accuracy_metric.compute(
        predictions=predictions, references=eval_pred.label_ids
    )["accuracy"]
    f1 = f1_metric.compute(
        predictions=predictions, references=eval_pred.label_ids, average="macro"
    )["f1"]
    return {"accuracy": accuracy, "f1": f1}


# Set up the trainer with the custom loss chosen above (focal or cross-entropy).
text_encoder.set_up_trainer(
    data_collator=data_collator,
    compute_metrics=_compute_metrics,
    custom_loss=criterion.forward,
)

<nlpsig.encode_text.TextEncoder.set_up_trainer.<locals>.MyTrainer at 0x1456ea91fee0>

In [14]:
# Inspect the columns of the dataset handed to the trainer (e.g. whether a `labels` column exists).
text_encoder.trainer.train_dataset

Dataset({
    features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'speaker', 'label_as_id', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
    num_rows: 2631
})

In [15]:
# Compare with the train split stored on the encoder.
text_encoder.dataset_split["train"]

Dataset({
    features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'speaker', 'label_as_id', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
    num_rows: 2631
})

In [16]:
# Full fine-tuning run with focal loss; results are also written to `results_output`.
# NOTE(review): this call currently raises `KeyError: 'labels'` (see output below).
bert_classifier = fine_tune_transformer_average_seed(
    **kwargs,
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_focal.csv",
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



KeyError: 'labels'

In [None]:
# Results returned by the focal-loss run.
bert_classifier

In [None]:
# Average of the `f1` column of the results.
np.mean(bert_classifier["f1"])

In [None]:
# Average of the `precision` column of the results.
np.mean(bert_classifier["precision"])

In [None]:
# Average of the `recall` column of the results.
np.mean(bert_classifier["recall"])

In [None]:
# Stack the per-row `f1_scores` arrays and average them element-wise
# (presumably one score per class).
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

In [None]:
# Stack the per-row `precision_scores` arrays and average them element-wise.
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

In [None]:
# Stack the per-row `recall_scores` arrays and average them element-wise.
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

## Using Cross-Entropy Loss

Repeat the same experiment with the standard cross-entropy loss instead of focal loss.

In [None]:
# Cross-entropy run; gamma is only used by the focal loss, so it is unset here.
loss, gamma = "cross_entropy", None

In [None]:
# Same fine-tuning setup as the focal-loss run, but with cross-entropy loss.
bert_classifier_ce = fine_tune_transformer_average_seed(
    **kwargs,
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
)

In [None]:
# Results returned by the cross-entropy run.
bert_classifier_ce

In [None]:
# Average of the `f1` column of the cross-entropy results.
np.mean(bert_classifier_ce["f1"])

In [None]:
# Average of the `precision` column of the cross-entropy results.
np.mean(bert_classifier_ce["precision"])

In [None]:
# Average of the `recall` column of the cross-entropy results.
np.mean(bert_classifier_ce["recall"])

In [None]:
# Stack the per-row `f1_scores` arrays and average them element-wise
# (presumably one score per class).
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

In [None]:
# Stack the per-row `precision_scores` arrays and average them element-wise.
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

In [None]:
# Stack the per-row `recall_scores` arrays and average them element-wise.
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)