In [1]:
import os
import pickle

import numpy as np

# NOTE(review): the fine-tuning runs below are seeded from the `seeds` list in
# the configuration cell, not from this value — confirm `seed` is still needed.
seed = 2023

In [2]:
import torch

# Train on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device

'cuda'

In [3]:
import transformers

# Only report critical messages, to avoid excessive logging during training.
# Use the named level instead of the bare magic number 50.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Every results CSV written by this notebook goes under this directory.
output_dir = "client_talk_type_output"
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [6]:
# Load the AnnoMI data into the notebook namespace.
# NOTE(review): later cells use `anno_mi`, `client_index` and
# `client_transcript_id` without defining them, so this script presumably
# defines them — confirm in ../load_anno_mi.py.
%run ../load_anno_mi.py

In [7]:
# Peek at the first rows of the loaded AnnoMI dataframe.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-11-03 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-11-03 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-11-03 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-11-03 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-11-03 00:00:34,-1


In [8]:
# Load the pickled embedding array and show its shape.
# NOTE: pickle.load can execute arbitrary code — only load a file you produced
# or otherwise trust.
# NOTE(review): `sbert_embeddings` is not passed to the BERT fine-tuning runs
# below — confirm it is still needed here.
sbert_path = "../anno_mi_sbert.pkl"
with open(sbert_path, "rb") as pkl_file:
    sbert_embeddings = pickle.load(pkl_file)
sbert_embeddings.shape

(9699, 384)

## Baseline: Fine-tune BERT for classification

In [9]:
# Settings shared by the fine-tuning runs below (forwarded through `kwargs`).
validation_metric = "f1"  # passed to the trainer as `validation_metric`
seeds = [1, 12, 123]  # one full cross-validated run per seed
num_epochs = 5  # training epochs per fold

In [10]:
# Keyword arguments shared by both fine-tuning runs below (focal and CE loss).
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",  # column holding the utterance text
    label_column="client_talk_type",
    seeds=seeds,
    device=device,
    batch_size=8,
    # NOTE(review): `client_index` / `client_transcript_id` come from
    # load_anno_mi.py; presumably they select the client utterances and group
    # folds by transcript respectively — confirm there.
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,  # cross-validated: the logs below show five folds per seed
    validation_metric=validation_metric,
    verbose=False,
)

## Focal Loss

In [11]:
# Focal loss; `gamma` is presumably its focusing parameter.
loss, gamma = "focal", 2

In [12]:
# `loss` and `gamma` are re-bound in the cross-entropy section further down.
# Re-running this cell after that section would silently train with the wrong
# loss and overwrite the focal-loss CSV, so fail loudly instead.
assert loss == "focal", f"expected loss='focal', got {loss!r}; re-run the cell above first"

bert_classifier = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_focal.csv",
    **kwargs,
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.671098,0.639938,0.41401
2,No log,0.623863,0.663069,0.539196
3,No log,0.719978,0.668466,0.520446
4,No log,0.794235,0.661527,0.546534
5,No log,0.868729,0.66384,0.555466


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.644931,0.69783,0.411904
2,No log,0.590828,0.661937,0.521542
3,No log,0.601215,0.636895,0.540292
4,No log,0.622216,0.66778,0.553186
5,No log,0.642634,0.653589,0.55228


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.672268,0.612421,0.453484
2,No log,0.602983,0.66195,0.54407
3,No log,0.635857,0.646226,0.525428
4,No log,0.742127,0.649371,0.56648
5,No log,0.764182,0.649371,0.555552


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.62402,0.672609,0.517554
2,No log,0.62928,0.681524,0.550359
3,No log,0.695616,0.683144,0.57378
4,No log,0.848521,0.677472,0.545707
5,No log,0.86499,0.672609,0.542991


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.648665,0.674817,0.487115
2,No log,0.629796,0.668297,0.507137
3,No log,0.714836,0.632437,0.523588
4,No log,0.839829,0.645477,0.521312
5,No log,0.925295,0.642217,0.515604


  0%|          | 0/1100 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.673213,0.643022,0.429664
2,No log,0.61801,0.659985,0.551187
3,No log,0.633223,0.670008,0.54723
4,No log,0.747485,0.66384,0.547728
5,No log,0.813989,0.652274,0.531969


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.640919,0.631052,0.476538
2,No log,0.582157,0.649416,0.527225
3,No log,0.604475,0.632721,0.537794
4,No log,0.651109,0.649416,0.544209
5,No log,0.676097,0.66611,0.55706


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.674662,0.623428,0.491205
2,No log,0.642652,0.625,0.468333
3,No log,0.644784,0.638365,0.557543
4,No log,0.688729,0.638365,0.556883
5,No log,0.710343,0.636006,0.5421


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.678111,0.597245,0.43339
2,No log,0.624044,0.643436,0.526227
3,No log,0.644048,0.683955,0.564936
4,No log,0.816652,0.675851,0.562017
5,No log,0.844945,0.681524,0.555193


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.617619,0.657702,0.535223
2,No log,0.632466,0.624287,0.531971
3,No log,0.704734,0.660147,0.551028
4,No log,0.896823,0.664222,0.556527
5,No log,0.912896,0.667482,0.541819


  0%|          | 0/1100 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.631426,0.670779,0.536688
2,No log,0.612584,0.658443,0.548238
3,No log,0.682249,0.64148,0.543288
4,No log,0.745676,0.669237,0.555974
5,No log,0.860926,0.643793,0.531495


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.615835,0.617696,0.52115
2,No log,0.574926,0.627713,0.537733
3,No log,0.602418,0.671953,0.55991
4,No log,0.681203,0.671119,0.563488
5,No log,0.715486,0.676962,0.562017


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.675219,0.593553,0.488603
2,No log,0.647784,0.56761,0.493452
3,No log,0.641609,0.601415,0.488276
4,No log,0.695488,0.620283,0.528428
5,No log,0.733338,0.63522,0.53605


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.644892,0.598055,0.503889
2,No log,0.615977,0.665316,0.529856
3,No log,0.656012,0.688006,0.554251
4,No log,0.728878,0.683144,0.548179
5,No log,0.815001,0.688006,0.538691


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.646488,0.607172,0.480245
2,No log,0.639824,0.650367,0.518248
3,No log,0.684399,0.639772,0.543672
4,No log,0.870881,0.667482,0.534701
5,No log,0.89817,0.663407,0.530834


  0%|          | 0/1100 [00:00<?, ?it/s]

saving the results dataframe to CSV in client_talk_type_output/bert_classifier_focal.csv


In [13]:
# Focal-loss results: one row per seed.
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.637165,0.490209,"[0.7681200453001132, 0.3871829105473965, 0.315...",0.493358,"[0.752357182473655, 0.42028985507246375, 0.307...",0.489035,"[0.7845575477154425, 0.3589108910891089, 0.323...",1,focal,2,True
0,0.616667,0.49738,"[0.7516858983965232, 0.4008583690987124, 0.339...",0.490963,"[0.7800933125972006, 0.41771019677996424, 0.27...",0.518075,"[0.7252747252747253, 0.38531353135313534, 0.44...",12,focal,2,True
0,0.627395,0.512491,"[0.7561890472618155, 0.4339869281045752, 0.347...",0.50732,"[0.7857811038353602, 0.4598337950138504, 0.276...",0.535636,"[0.728744939271255, 0.41089108910891087, 0.467...",123,focal,2,True


In [14]:
# Macro F1 (mean of the per-class F1 scores) averaged across seeds.
bert_classifier["f1"].mean()

0.5000266858119725

In [15]:
# Macro precision averaged across seeds.
bert_classifier["precision"].mean()

0.49721340933941444

In [16]:
# Macro recall averaged across seeds.
bert_classifier["recall"].mean()

0.5142486864842258

In [17]:
# Per-class F1, averaged element-wise across seeds (one entry per class).
# NOTE(review): class order follows the library's label encoding — confirm.
np.stack(bert_classifier["f1_scores"]).mean(axis=0)

array([0.758665  , 0.40734274, 0.33407232])

In [18]:
# Per-class precision, averaged element-wise across seeds.
np.stack(bert_classifier["precision_scores"]).mean(axis=0)

array([0.77274387, 0.43261128, 0.28628508])

In [19]:
# Per-class recall, averaged element-wise across seeds.
np.stack(bert_classifier["recall_scores"]).mean(axis=0)

array([0.7461924 , 0.3850385 , 0.41151515])

## Cross-Entropy Loss

In [20]:
# Plain cross-entropy takes no focusing parameter, so gamma is left unset.
loss, gamma = "cross_entropy", None

In [21]:
# `loss` and `gamma` are module-level names that the focal-loss section also
# binds. Guard against running this cell out of order, which would train with
# the wrong loss and write it into the cross-entropy CSV.
assert loss == "cross_entropy", (
    f"expected loss='cross_entropy', got {loss!r}; re-run the cell above first"
)

bert_classifier_ce = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    **kwargs,
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.819342,0.643793,0.31944
2,No log,0.776637,0.660756,0.427186
3,No log,0.78289,0.676947,0.489899
4,No log,0.850129,0.666924,0.510591
5,No log,0.882697,0.661527,0.509059


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.74195,0.709516,0.43371
2,No log,0.730335,0.691987,0.406367
3,No log,0.780205,0.679466,0.528769
4,No log,0.810637,0.679466,0.528717
5,No log,0.82568,0.679466,0.528102


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.804242,0.654088,0.412138
2,No log,0.777305,0.663522,0.435987
3,No log,0.827279,0.661164,0.507753
4,No log,0.91409,0.656447,0.549467
5,No log,0.922612,0.658019,0.526969


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.760374,0.67342,0.420574
2,No log,0.752918,0.686386,0.442084
3,No log,0.770269,0.692869,0.54323
4,No log,0.852209,0.688006,0.542587
5,No log,0.881872,0.683144,0.540007


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.77354,0.682967,0.38728
2,No log,0.756695,0.682967,0.447867
3,No log,0.814335,0.680522,0.474998
4,No log,0.888833,0.669112,0.493192
5,No log,0.903987,0.671557,0.517213


  0%|          | 0/1100 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.788571,0.656901,0.371158
2,No log,0.756821,0.677718,0.530846
3,No log,0.819356,0.670008,0.546983
4,No log,0.974275,0.659214,0.520863
5,No log,0.98687,0.659214,0.55628


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.747666,0.702003,0.394806
2,No log,0.721954,0.710351,0.563304
3,No log,0.797611,0.671953,0.562197
4,No log,0.826615,0.668614,0.531489
5,No log,0.880339,0.671119,0.54088


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.823848,0.647799,0.426662
2,No log,0.81618,0.633648,0.419563
3,No log,0.849242,0.658805,0.521595
4,No log,0.928315,0.646226,0.548166
5,No log,0.961377,0.648585,0.534476


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.784152,0.665316,0.501418
2,No log,0.744745,0.677472,0.550367
3,No log,0.779446,0.685575,0.569343
4,No log,0.858423,0.698541,0.580699
5,No log,0.878396,0.701783,0.573136


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.744402,0.700081,0.429542
2,No log,0.78641,0.649552,0.503039
3,No log,0.802914,0.682967,0.522398
4,No log,0.913845,0.656887,0.531736
5,No log,0.940051,0.657702,0.517722


  0%|          | 0/1100 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.783874,0.653816,0.340749
2,No log,0.766803,0.639938,0.471916
3,No log,0.776369,0.662298,0.520774
4,No log,0.850662,0.676947,0.563799
5,No log,0.904903,0.67926,0.5576


  0%|          | 0/889 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.757677,0.656093,0.422809
2,No log,0.732912,0.676127,0.554099
3,No log,0.758291,0.692821,0.522299
4,No log,0.817564,0.698664,0.56071
5,No log,0.848058,0.698664,0.560541


  0%|          | 0/1188 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.802587,0.652516,0.424081
2,No log,0.795439,0.643082,0.519995
3,No log,0.859947,0.649371,0.527444
4,No log,0.962361,0.633648,0.531442
5,No log,0.999896,0.643082,0.5272


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.758118,0.679903,0.418208
2,No log,0.746508,0.675851,0.521656
3,No log,0.782351,0.692869,0.535841
4,No log,0.872589,0.686386,0.544413
5,No log,0.909315,0.682334,0.543516


  0%|          | 0/1079 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.780773,0.671557,0.406399
2,No log,0.781099,0.655257,0.432979
3,No log,0.786806,0.690302,0.496486
4,No log,0.812554,0.692747,0.536896
5,No log,0.850325,0.673187,0.536569


  0%|          | 0/1100 [00:00<?, ?it/s]

saving the results dataframe to CSV in client_talk_type_output/bert_classifier_ce.csv


In [22]:
# Cross-entropy results: one row per seed.
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.666284,0.411718,"[0.7948684036503108, 0.40261437908496733, 0.03...",0.494171,"[0.7323909334633195, 0.4265927977839335, 0.323...",0.423396,"[0.8689994216310005, 0.3811881188118812, 0.02]",1,cross_entropy,,True
0,0.660728,0.483504,"[0.7848518111964874, 0.4240035041611914, 0.241...",0.509268,"[0.7467362924281984, 0.4519140989729225, 0.329...",0.472439,"[0.8270676691729323, 0.39933993399339934, 0.19...",12,cross_entropy,,True
0,0.646169,0.490161,"[0.7708947885939037, 0.41975308641975306, 0.27...",0.501798,"[0.7495219885277247, 0.4335971855760774, 0.322...",0.48252,"[0.7935222672064778, 0.4067656765676568, 0.247...",123,cross_entropy,,True


In [23]:
# Macro F1 (mean of the per-class F1 scores) averaged across seeds.
bert_classifier_ce["f1"].mean()

0.46179440822554924

In [24]:
# Macro precision averaged across seeds.
bert_classifier_ce["precision"].mean()

0.5017456883387937

In [25]:
# Macro recall averaged across seeds.
bert_classifier_ce["recall"].mean()

0.4594516561739073

In [26]:
# Per-class F1, averaged element-wise across seeds (one entry per class).
# NOTE(review): class order follows the library's label encoding — confirm.
np.stack(bert_classifier_ce["f1_scores"]).mean(axis=0)

array([0.78353833, 0.41545699, 0.1863879 ])

In [27]:
# Per-class precision, averaged element-wise across seeds.
np.stack(bert_classifier_ce["precision_scores"]).mean(axis=0)

array([0.74288307, 0.43736803, 0.32498597])

In [28]:
# Per-class recall, averaged element-wise across seeds.
np.stack(bert_classifier_ce["recall_scores"]).mean(axis=0)

array([0.82986312, 0.39576458, 0.15272727])