In [1]:
import numpy as np
import pickle
import os

# NOTE(review): seed is not referenced by any later cell visible here; the
# fine-tuning runs take their seeds from the separate `seeds` list -- confirm
# whether this value is still needed.
seed = 2023

In [2]:
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
import transformers

# Only report critical errors, to avoid excessive logging.
# CRITICAL is the named constant for logging level 50.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that the results CSVs of this notebook are written to.
output_dir = "client_talk_type_output"
# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [6]:
# Executes the AnnoMI loading script. Later cells use names that are not
# defined in any cell visible here (anno_mi, label_to_id_client,
# id_to_label_client, output_dim_client, client_index, client_transcript_id),
# so they are presumably created by this script -- TODO confirm.
%run ../load_anno_mi.py

In [7]:
# Preview the first rows of anno_mi to check the columns used below
# ("utterance_text" as the feature, "client_talk_type" as the label).
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-11-06 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-11-06 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-11-06 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-11-06 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-11-06 00:00:34,-1


In [8]:
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load this file if it comes from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as embeddings_file:
    sbert_embeddings = pickle.load(embeddings_file)

# Presumably (number of utterances, embedding dimension) -- TODO confirm.
sbert_embeddings.shape

(9699, 384)

## Baseline: Fine-tune BERT for classification

In [9]:
# Settings shared by every fine-tuning run in this notebook; they are
# forwarded to the fine-tuning function through the kwargs dictionary.
validation_metric = "f1"
seeds = [0, 1, 12, 123, 1234]
num_epochs = 5
# NOTE(review): the results tables recorded further down list only seeds
# 1, 12 and 123 -- confirm whether those outputs predate this seed list.

In [10]:
# Display the mapping from client talk type label to integer class id.
label_to_id_client

{'neutral': 0, 'change': 1, 'sustain': 2}

In [11]:
# Display the inverse mapping, from integer class id back to label.
id_to_label_client

{0: 'neutral', 1: 'change', 2: 'sustain'}

In [12]:
# Arguments shared by every fine-tuning call below; the loss-specific options
# (loss, gamma, results_output) are supplied separately at each call site.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    # data: text column to classify and the column holding the labels
    df=anno_mi,
    feature_name="utterance_text",
    label_column="client_talk_type",
    label_to_id=label_to_id_client,
    id_to_label=id_to_label_client,
    output_dim=output_dim_client,
    seeds=seeds,
    device=device,
    batch_size=8,
    # presumably restricts the data to client utterances and groups the
    # splits by transcript -- TODO confirm against load_anno_mi.py
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    validation_metric=validation_metric,
    verbose=False,
)

## Using Focal loss

In [13]:
# Loss configuration for this section; gamma is presumably the focal-loss
# focusing parameter -- TODO confirm against the fine-tuning function.
loss, gamma = "focal", 2

In [14]:
# Fine-tune with the shared settings plus the focal-loss options. According to
# the recorded output, the results dataframe is also saved to the CSV path
# given as results_output.
bert_classifier = fine_tune_transformer_average_seed(
    **kwargs,
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_focal.csv",
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.675592,0.659985,0.416835
2,0.669000,0.635145,0.616808,0.522742
3,0.669000,1.108824,0.639938,0.487981
4,0.478000,1.217864,0.66384,0.530595
5,0.186100,1.40129,0.667695,0.528114


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.687119,0.69783,0.366307
2,0.665500,0.641263,0.658598,0.539876
3,0.665500,0.84217,0.658598,0.528115
4,0.315800,1.203762,0.656093,0.505126
5,0.119400,1.294447,0.663606,0.525201


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63342,0.619497,0.528791
2,0.661900,0.656693,0.661164,0.524105
3,0.661900,1.04701,0.656447,0.535441
4,0.367500,1.246493,0.647799,0.54377
5,0.130900,1.408912,0.657233,0.54252


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.650523,0.658023,0.481067
2,0.642000,0.678633,0.689627,0.533854
3,0.642000,0.963991,0.704214,0.570266
4,0.347000,1.175009,0.700972,0.562988
5,0.124300,1.257137,0.699352,0.588171


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.69719,0.659332,0.421528
2,0.684700,0.654941,0.682967,0.558969
3,0.684700,0.848778,0.663407,0.533307
4,0.461100,1.166178,0.672372,0.5365
5,0.203800,1.380184,0.678077,0.529205


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.699687,0.529684,0.387907
2,0.744500,0.66862,0.598304,0.484167
3,0.744500,0.705105,0.649961,0.514369
4,0.673200,0.681279,0.614495,0.534497
5,0.493600,0.824582,0.647648,0.524658


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.64142,0.603506,0.511888
2,0.637400,0.655796,0.660267,0.549939
3,0.637400,0.888682,0.686144,0.554837
4,0.324600,1.083677,0.677796,0.534113
5,0.127000,1.193522,0.682805,0.546013


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.704001,0.628145,0.384065
2,0.724000,0.706054,0.633648,0.513394
3,0.724000,0.805295,0.654874,0.498996
4,0.548200,1.042794,0.65173,0.538704
5,0.238200,1.162586,0.64544,0.550053


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.721823,0.622366,0.327659
2,0.740900,0.683429,0.557536,0.44834
3,0.740900,0.690608,0.619935,0.476753
4,0.664100,0.686055,0.65154,0.507645
5,0.512300,0.822636,0.640194,0.51851


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.651309,0.651997,0.549939
2,0.673600,0.768014,0.638142,0.518442
3,0.673600,0.855569,0.641402,0.536258
4,0.431000,1.099267,0.649552,0.532274
5,0.189400,1.260795,0.664222,0.528924


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.738177,0.615266,0.281539
2,0.743200,0.704778,0.521203,0.364408
3,0.743200,0.727351,0.593678,0.398024
4,0.721600,0.725957,0.553585,0.446664
5,0.712400,0.733838,0.629915,0.408881


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.748195,0.230384,0.12483
2,0.779300,0.721273,0.66611,0.266533
3,0.779300,0.725901,0.66611,0.266533
4,0.781100,0.726037,0.66611,0.266533
5,0.756000,0.717999,0.66611,0.266533


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.687347,0.616352,0.468915
2,0.681100,0.632143,0.67217,0.538031
3,0.681100,0.81402,0.628931,0.539804
4,0.449800,1.022653,0.646226,0.544153
5,0.161100,1.214064,0.659591,0.543577


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.62024,0.670989,0.520408
2,0.623300,0.804502,0.667747,0.531717
3,0.623300,0.992513,0.699352,0.571968
4,0.295300,1.242173,0.700162,0.586186
5,0.122300,1.32668,0.690438,0.562587


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.655853,0.680522,0.391569
2,0.635700,0.679611,0.623472,0.515699
3,0.635700,1.057101,0.661777,0.540569
4,0.376600,1.346776,0.661777,0.525447
5,0.164700,1.454775,0.665037,0.530367


saving the results dataframe to CSV in client_talk_type_output/bert_classifier_focal.csv


In [15]:
# Display the focal-loss results; the recorded table has one row per seed
# (see its "seed" column).
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.620306,0.51027,"[0.7431776834445118, 0.44533120510774143, 0.34...",0.500967,"[0.7810707456978967, 0.4312210200927357, 0.290...",0.528517,"[0.7087912087912088, 0.4603960396039604, 0.416...",1,focal,2,True
0,0.583525,0.468548,"[0.7247805328815647, 0.37351703023344807, 0.30...",0.461697,"[0.7752883031301483, 0.3483226266952177, 0.261...",0.485273,"[0.6804511278195489, 0.40264026402640263, 0.37...",12,focal,2,True
0,0.62069,0.403734,"[0.7567351293678315, 0.31513537505548156, 0.13...",0.462857,"[0.7022277227722772, 0.3410182516810759, 0.345...",0.400198,"[0.8204164256795836, 0.2929042904290429, 0.087...",123,focal,2,True


In [16]:
# Mean of the "f1" column over the rows of the results frame (one row per
# seed in the recorded output). Judging from the displayed rows, "f1" is the
# unweighted mean of the per-class values in "f1_scores".
bert_classifier.loc[:, "f1"].mean()

0.4608508435414917

In [17]:
# Mean of the "precision" column over the rows of the results frame.
bert_classifier.loc[:, "precision"].mean()

0.4751734599965687

In [18]:
# Mean of the "recall" column over the rows of the results frame.
bert_classifier.loc[:, "recall"].mean()

0.4713292214125982

In [19]:
# Stack the per-row "f1_scores" entries and average them element-wise over the
# rows. Entries presumably follow the class ids in label_to_id_client
# (neutral, change, sustain) -- TODO confirm.
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

array([0.74156445, 0.37799454, 0.26299355])

In [20]:
# Element-wise average of the per-row "precision_scores" entries.
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

array([0.75286226, 0.37352063, 0.29913749])

In [21]:
# Element-wise average of the per-row "recall_scores" entries.
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

array([0.73655292, 0.38531353, 0.29212121])

## Using Cross-Entropy loss

In [14]:
# Loss configuration for this section: cross-entropy, with gamma left unset.
loss, gamma = "cross_entropy", None

In [15]:
# Fine-tune with the shared settings plus the cross-entropy options. According
# to the recorded output, the results dataframe is also saved to the CSV path
# given as results_output.
bert_classifier_ce = fine_tune_transformer_average_seed(
    **kwargs,
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.785765,0.672321,0.432937
2,0.775400,0.776554,0.683115,0.534713
3,0.775400,1.093161,0.667695,0.505132
4,0.489000,1.48509,0.673863,0.551776
5,0.232900,1.669982,0.670008,0.553317


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.789381,0.69616,0.401369
2,0.800200,0.744275,0.682805,0.513599
3,0.800200,0.879966,0.665275,0.530236
4,0.481900,1.401957,0.65192,0.505673
5,0.233900,1.584041,0.656928,0.526282


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.802496,0.638365,0.355265
2,0.800700,0.797021,0.662736,0.509817
3,0.800700,1.102244,0.647013,0.499553
4,0.546500,1.726583,0.630503,0.53042
5,0.247800,1.854083,0.636792,0.538918


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.770768,0.646677,0.305516
2,0.756600,0.782754,0.707455,0.525693
3,0.756600,0.947078,0.691248,0.535889
4,0.493100,1.492437,0.684765,0.539541
5,0.230200,1.606085,0.683144,0.546389


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.891895,0.648737,0.262317
2,0.853000,0.839432,0.653627,0.279867
3,0.853000,0.766023,0.686227,0.524709
4,0.748600,0.876836,0.687857,0.520109
5,0.451500,1.161471,0.669112,0.523948


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.783213,0.674634,0.512431
2,0.775100,0.792804,0.667695,0.529006
3,0.775100,1.123906,0.674634,0.485411
4,0.493700,1.429447,0.672321,0.548251
5,0.229700,1.621011,0.666153,0.548185


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.762562,0.695326,0.510326
2,0.757200,0.851971,0.631052,0.506403
3,0.757200,1.025627,0.66611,0.544655
4,0.422600,1.449614,0.661937,0.543015
5,0.213500,1.59424,0.665275,0.543448


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.851551,0.639937,0.393908
2,0.869800,0.841708,0.636006,0.418776
3,0.869800,0.805911,0.65173,0.421439
4,0.787300,0.870235,0.650157,0.548777
5,0.542500,1.092163,0.650157,0.537419


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.796239,0.632901,0.262506
2,0.778700,0.770839,0.675041,0.52665
3,0.778700,0.879101,0.705835,0.570613
4,0.513900,1.319404,0.696921,0.571077
5,0.249900,1.472384,0.693679,0.562052


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.824308,0.648737,0.262317
2,0.850200,0.845203,0.608802,0.414645
3,0.850200,0.884911,0.648737,0.262317
4,0.855500,0.913317,0.654442,0.360031
5,0.813100,0.802362,0.629177,0.406364


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.888055,0.625289,0.256483
2,0.820200,0.7831,0.664611,0.527706
3,0.820200,0.842394,0.666153,0.528574
4,0.625500,1.385373,0.674634,0.534995
5,0.302800,1.465366,0.665382,0.541456


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.842361,0.597663,0.396045
2,0.787900,0.790428,0.66611,0.501457
3,0.787900,0.844369,0.698664,0.546563
4,0.530700,1.10835,0.707012,0.574858
5,0.272200,1.308376,0.685309,0.570146


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.778711,0.659591,0.421995
2,0.780200,0.786736,0.662736,0.502525
3,0.780200,1.259872,0.654874,0.540866
4,0.496100,1.570108,0.660377,0.547654
5,0.227200,1.75441,0.650157,0.548857


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.784893,0.656402,0.502944
2,0.743100,0.878205,0.688817,0.487492
3,0.743100,1.099589,0.691248,0.547808
4,0.435300,1.447949,0.692869,0.567048
5,0.215400,1.546254,0.694489,0.572177


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.768551,0.683782,0.377543
2,0.741100,0.767284,0.682152,0.545899
3,0.741100,1.167989,0.656887,0.547341
4,0.437500,1.648849,0.661777,0.541317
5,0.213800,1.690377,0.678892,0.536598


saving the results dataframe to CSV in client_talk_type_output/bert_classifier_ce.csv


In [16]:
# Display the cross-entropy results; the recorded table has one row per seed
# (see its "seed" column).
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.667816,0.482975,"[0.7897202516396734, 0.39833641404805914, 0.26...",0.533202,"[0.7351108896087715, 0.4527310924369748, 0.411...",0.466538,"[0.8530942741469058, 0.3556105610561056, 0.190...",1,cross_entropy,,True
0,0.653065,0.471073,"[0.7802614930409111, 0.4383346425765907, 0.194...",0.502176,"[0.759233926128591, 0.41829085457271364, 0.329...",0.467022,"[0.8024869866975131, 0.4603960396039604, 0.138...",12,cross_entropy,,True
0,0.658812,0.493624,"[0.7851562500000001, 0.4317984361424848, 0.263...",0.506405,"[0.7584905660377359, 0.4559633027522936, 0.304...",0.485519,"[0.8137651821862348, 0.41006600660066006, 0.23...",123,cross_entropy,,True


In [17]:
# Mean of the "f1" column over the rows of the results frame (one row per
# seed in the recorded output).
bert_classifier_ce.loc[:, "f1"].mean()

0.48255742861862533

In [18]:
# Mean of the "precision" column over the rows of the results frame.
bert_classifier_ce.loc[:, "precision"].mean()

0.513927952353963

In [19]:
# Mean of the "recall" column over the rows of the results frame.
bert_classifier_ce.loc[:, "recall"].mean()

0.4730263591232846

In [20]:
# Stack the per-row "f1_scores" entries and average them element-wise over the
# rows. Entries presumably follow the class ids in label_to_id_client
# (neutral, change, sustain) -- TODO confirm.
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

array([0.785046  , 0.42282316, 0.23980312])

In [21]:
# Element-wise average of the per-row "precision_scores" entries.
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

array([0.75094513, 0.44232842, 0.34851031])

In [22]:
# Element-wise average of the per-row "recall_scores" entries.
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)

array([0.82311548, 0.40869087, 0.18727273])