In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced by any cell visible in this part of
# the notebook (the training calls take the `seeds` list instead) — confirm it
# is still needed.
seed = 2023

In [2]:
import torch

# Select the compute device: GPU when CUDA is usable, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
import transformers

# Only report CRITICAL-level messages from transformers to avoid excessive
# logging during fine-tuning. The named constant replaces the bare level
# number 50 so the intent is readable at the call site.
transformers.utils.logging.set_verbosity(transformers.utils.logging.CRITICAL)

In [5]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [6]:
# Directory that receives the results CSVs written by the fine-tuning runs below.
output_dir = "therapist_talk_type_output"
# exist_ok=True makes this cell idempotent on re-runs and removes the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [7]:
# Execute the shared AnnoMI loading script. Later cells use `anno_mi` and
# `therapist_index`, neither of which is defined in this notebook, so they are
# presumably created by this script — TODO confirm against ../load_anno_mi.py.
%run ../load_anno_mi.py

In [8]:
# Preview the first five rows of the loaded data (5 is head()'s default).
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-19 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-19 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-19 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:34


In [9]:
# Load the pickled embeddings object from disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file —
# only load this artefact from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# Shown as a sanity check of what was loaded.
# NOTE(review): `sbert_embeddings` is not referenced by any later cell visible
# here — confirm it is still needed in this notebook.
sbert_embeddings.shape

(9699, 384)

# Baseline: Fine-tune BERT for classification

In [10]:
# Settings forwarded to fine_tune_transformer_average_seed in the cells below.
seeds = [1, 12, 123]
num_epochs = 10
validation_metric = "f1"

# Loss configuration for the baseline run.
# NOTE(review): gamma is presumably the focal-loss focusing parameter — confirm
# against the helper's documentation.
loss = "focal"
gamma = 2

In [11]:
# Fine-tune bert-base-uncased on the utterance text using the loss settings
# configured above. The helper also saves the results dataframe to the CSV
# given by `results_output` (see the message at the end of the cell output).
bert_classifier = fine_tune_transformer_average_seed(
    # --- data ---
    df=anno_mi,
    feature_name="utterance_text",
    label_column="main_therapist_behaviour",
    path_indices=therapist_index,
    k_fold=True,
    # --- model / training ---
    pretrained_model_name="bert-base-uncased",
    num_epochs=num_epochs,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    device=device,
    # --- reporting ---
    results_output=f"{output_dir}/bert_classifier.csv",
    verbose=False,
)

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.603574,0.72149,0.684949
2,No log,0.502027,0.802172,0.771195
3,No log,0.537197,0.803724,0.775029
4,No log,0.630703,0.806827,0.774411
5,No log,0.722502,0.806827,0.770274
6,No log,0.792987,0.805275,0.773378
7,No log,0.790286,0.813033,0.782326
8,No log,0.816728,0.826998,0.796084
9,No log,0.827507,0.81924,0.788582
10,No log,0.822687,0.820016,0.789183


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.580531,0.782002,0.751679
2,No log,0.555349,0.785881,0.760013
3,No log,0.584613,0.794414,0.763507
4,No log,0.63504,0.794414,0.761319
5,No log,0.756904,0.782002,0.750005
6,No log,0.840247,0.791311,0.755552
7,No log,0.863801,0.788208,0.749913
8,No log,0.893537,0.788208,0.750235
9,No log,0.90326,0.79519,0.757662
10,No log,0.903733,0.78976,0.752176


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.581317,0.786656,0.759049
2,No log,0.54337,0.78976,0.761699
3,No log,0.593204,0.8045,0.767395
4,No log,0.693527,0.80993,0.779706
5,No log,0.778945,0.805275,0.768955
6,No log,0.844111,0.806051,0.773817
7,No log,0.88493,0.806051,0.776262
8,No log,0.931037,0.8045,0.772709
9,No log,0.927969,0.8045,0.773905
10,No log,0.940943,0.808379,0.777237


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.612092,0.757176,0.728843
2,No log,0.549517,0.796742,0.771804
3,No log,0.5705,0.792863,0.772221
4,No log,0.732861,0.788984,0.752683
5,No log,0.802737,0.785105,0.755653
6,No log,0.868674,0.788208,0.755633
7,No log,0.905278,0.795966,0.763079
8,No log,0.925299,0.794414,0.764438
9,No log,0.935797,0.793638,0.761481
10,No log,0.940179,0.794414,0.764125


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.58992,0.766486,0.734531
2,No log,0.546177,0.800621,0.771928
3,No log,0.601459,0.796742,0.765942
4,No log,0.677607,0.799845,0.770753
5,No log,0.754981,0.809154,0.780914
6,No log,0.881078,0.802948,0.772853
7,No log,0.915067,0.802948,0.776861
8,No log,0.932339,0.80993,0.782838
9,No log,0.957666,0.80993,0.781979
10,No log,0.953199,0.809154,0.7821


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.566331,0.803724,0.772149
2,No log,0.501513,0.807603,0.78154
3,No log,0.591527,0.816137,0.779194
4,No log,0.689025,0.810706,0.782536
5,No log,0.727251,0.830877,0.80196
6,No log,0.832131,0.82467,0.789889
7,No log,0.871346,0.814585,0.784184
8,No log,0.893299,0.820016,0.786954
9,No log,0.913774,0.81924,0.784797
10,No log,0.89423,0.823894,0.79246


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.578209,0.785105,0.737576
2,No log,0.572943,0.806051,0.77085
3,No log,0.618668,0.785105,0.756293
4,No log,0.715869,0.795966,0.762984
5,No log,0.866909,0.795966,0.755843
6,No log,0.876288,0.784329,0.753243
7,No log,0.956404,0.79519,0.758134
8,No log,0.965522,0.798293,0.764647
9,No log,0.966472,0.798293,0.765863
10,No log,0.976172,0.801396,0.768627


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.557674,0.810706,0.770015
2,No log,0.506408,0.803724,0.769906
3,No log,0.56447,0.8045,0.77181
4,No log,0.731665,0.802172,0.7678
5,No log,0.879339,0.806827,0.771833
6,No log,0.903443,0.806051,0.772955
7,No log,0.978154,0.799069,0.764452
8,No log,1.039621,0.806051,0.770353
9,No log,1.015147,0.806051,0.772694
10,No log,1.020293,0.806051,0.771956


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.554118,0.78976,0.751704
2,No log,0.56995,0.788984,0.756541
3,No log,0.585041,0.799069,0.765939
4,No log,0.665396,0.799069,0.769014
5,No log,0.807161,0.772692,0.745365
6,No log,0.8969,0.774244,0.744916
7,No log,0.929568,0.779674,0.755512
8,No log,0.916783,0.787432,0.759799
9,No log,0.94927,0.786656,0.756386
10,No log,0.957854,0.785881,0.754731


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.589097,0.799845,0.766314
2,No log,0.58739,0.769589,0.736178
3,No log,0.648516,0.802172,0.772313
4,No log,0.783335,0.771916,0.746776
5,No log,0.864061,0.79519,0.769154
6,No log,0.916765,0.79519,0.769838
7,No log,0.97026,0.791311,0.764157
8,No log,1.00428,0.796742,0.768852
9,No log,1.00898,0.79519,0.767286
10,No log,1.012583,0.794414,0.766463


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.554746,0.776571,0.747406
2,No log,0.482562,0.806827,0.778135
3,No log,0.560219,0.810706,0.763763
4,No log,0.615051,0.811482,0.781399
5,No log,0.838336,0.80993,0.757081
6,No log,0.864487,0.813809,0.773886
7,No log,0.866696,0.812258,0.765651
8,No log,0.857248,0.811482,0.776286
9,No log,0.866368,0.820791,0.782245
10,No log,0.879106,0.821567,0.7839


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.612473,0.782002,0.742652
2,No log,0.579143,0.799069,0.758661
3,No log,0.59251,0.802948,0.776339
4,No log,0.763249,0.8045,0.761768
5,No log,0.818862,0.799845,0.762762
6,No log,0.84983,0.782777,0.751258
7,No log,0.907592,0.788984,0.753561
8,No log,0.943972,0.792863,0.759281
9,No log,0.957059,0.785881,0.752174
10,No log,0.964815,0.788984,0.753811


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.568041,0.797517,0.746084
2,No log,0.558345,0.785105,0.753482
3,No log,0.571158,0.816137,0.782984
4,No log,0.717265,0.809154,0.778027
5,No log,0.747605,0.799069,0.769232
6,No log,0.8439,0.808379,0.775781
7,No log,0.929724,0.806827,0.771678
8,No log,0.914328,0.811482,0.780817
9,No log,0.935449,0.808379,0.775527
10,No log,0.942199,0.810706,0.776688


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.585967,0.783553,0.745145
2,No log,0.547059,0.792863,0.763585
3,No log,0.577428,0.791311,0.756643
4,No log,0.655783,0.798293,0.772935
5,No log,0.82819,0.775795,0.737338
6,No log,0.864692,0.792087,0.751504
7,No log,0.869808,0.798293,0.765175
8,No log,0.892493,0.787432,0.753352
9,No log,0.919991,0.785881,0.750383
10,No log,0.933322,0.793638,0.758239


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.592034,0.757176,0.730966
2,No log,0.565699,0.808379,0.769705
3,No log,0.668686,0.802172,0.762128
4,No log,0.800825,0.799845,0.768161
5,No log,0.8541,0.795966,0.768295
6,No log,0.923353,0.788984,0.763255
7,No log,0.982475,0.813033,0.778853
8,No log,0.97697,0.799069,0.767478
9,No log,1.001308,0.8045,0.772923
10,No log,1.005452,0.806827,0.775062


  0%|          | 0/976 [00:00<?, ?it/s]

saving the results dataframe to CSV in therapist_talk_type_output/bert_classifier.csv


In [12]:
# Mean of the `f1` column over the rows of the baseline results frame.
bert_classifier.loc[:, "f1"].mean()

0.7662553535255058

In [13]:
# Mean of the `precision` column over the rows of the baseline results frame.
bert_classifier.loc[:, "precision"].mean()

0.7647903751315419

In [14]:
# Mean of the `recall` column over the rows of the baseline results frame.
bert_classifier.loc[:, "recall"].mean()

0.770116666013689

In [15]:
# Stack the per-row `f1_scores` vectors and average across rows (axis 0),
# leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier["f1_scores"]), axis=0)

array([0.84674532, 0.60189101, 0.73780547, 0.87857961])

In [16]:
# Stack the per-row `precision_scores` vectors and average across rows
# (axis 0), leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier["precision_scores"]), axis=0)

array([0.81775941, 0.5757255 , 0.75650451, 0.90917208])

In [17]:
# Stack the per-row `recall_scores` vectors and average across rows (axis 0),
# leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier["recall_scores"]), axis=0)

array([0.87797186, 0.63163482, 0.72055427, 0.85030571])

## Using Cross-Entropy loss

In [18]:
# Switch the loss for the runs below. This rebinds the `loss` and `gamma`
# names that the earlier baseline cell also reads, so re-running that cell
# after this one would pick up these values instead of the focal settings.
gamma = None
loss = "cross_entropy"

In [19]:
# This cell reads the notebook-level `loss` / `gamma` names, which are rebound
# between sections. Guard against out-of-order execution so results trained
# with a different loss can never be saved under the cross-entropy file name.
assert loss == "cross_entropy" and gamma is None, (
    "expected loss='cross_entropy' and gamma=None; "
    "run the loss-configuration cell of this section first"
)

# Same fine-tuning setup as the baseline run, but with cross-entropy loss.
# The helper also saves the results dataframe to `results_output`.
bert_classifier_ce = fine_tune_transformer_average_seed(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",
    label_column="main_therapist_behaviour",
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=therapist_index,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    device=device,
    verbose=False,
)

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.571431,0.802948,0.773919
2,No log,0.525881,0.810706,0.780152
3,No log,0.551556,0.812258,0.785215
4,No log,0.628722,0.811482,0.77983
5,No log,0.694601,0.807603,0.769082
6,No log,0.765525,0.813033,0.778111
7,No log,0.768635,0.812258,0.783637
8,No log,0.78576,0.837083,0.807748
9,No log,0.837983,0.820791,0.791406
10,No log,0.833992,0.82467,0.794468


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.672801,0.749418,0.665244
2,No log,0.571106,0.784329,0.755737
3,No log,0.585277,0.80993,0.772153
4,No log,0.683377,0.782002,0.746544
5,No log,0.75903,0.798293,0.763324
6,No log,0.812454,0.794414,0.763951
7,No log,0.838543,0.801396,0.765577
8,No log,0.914788,0.792087,0.760998
9,No log,0.898033,0.801396,0.768979
10,No log,0.914749,0.802172,0.767947


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.575646,0.788208,0.754933
2,No log,0.546322,0.808379,0.779296
3,No log,0.547506,0.816137,0.782865
4,No log,0.623707,0.813809,0.781326
5,No log,0.770313,0.797517,0.755722
6,No log,0.822703,0.816137,0.782782
7,No log,0.874362,0.808379,0.77794
8,No log,0.912547,0.813809,0.783118
9,No log,0.947362,0.806827,0.771324
10,No log,0.947264,0.812258,0.779021


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.611094,0.781226,0.755188
2,No log,0.566692,0.809154,0.78039
3,No log,0.615885,0.799069,0.779293
4,No log,0.731298,0.786656,0.760541
5,No log,0.789069,0.794414,0.767548
6,No log,0.880517,0.791311,0.760647
7,No log,0.899338,0.792863,0.761648
8,No log,0.957103,0.78976,0.759474
9,No log,0.976084,0.792087,0.763446
10,No log,0.977211,0.79519,0.76711


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.60983,0.773468,0.732524
2,No log,0.544184,0.80993,0.777227
3,No log,0.56938,0.805275,0.774877
4,No log,0.64426,0.810706,0.780823
5,No log,0.69703,0.80993,0.784039
6,No log,0.783188,0.801396,0.773332
7,No log,0.868979,0.793638,0.762193
8,No log,0.874963,0.802172,0.774339
9,No log,0.91729,0.802172,0.772533
10,No log,0.92087,0.800621,0.772214


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.549557,0.801396,0.757078
2,No log,0.523886,0.813809,0.785406
3,No log,0.537359,0.822343,0.790614
4,No log,0.675796,0.797517,0.767345
5,No log,0.731317,0.807603,0.77515
6,No log,0.783278,0.821567,0.785782
7,No log,0.828693,0.822343,0.788623
8,No log,0.875078,0.813809,0.776078
9,No log,0.912663,0.814585,0.77333
10,No log,0.910601,0.811482,0.773593


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.612235,0.776571,0.710071
2,No log,0.576154,0.80993,0.769217
3,No log,0.602178,0.803724,0.774922
4,No log,0.690159,0.792863,0.759508
5,No log,0.802309,0.796742,0.763009
6,No log,0.855965,0.798293,0.767155
7,No log,0.887094,0.792863,0.761735
8,No log,0.93455,0.798293,0.768752
9,No log,0.958049,0.806827,0.774608
10,No log,0.960206,0.8045,0.77252


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.570986,0.798293,0.735772
2,No log,0.5349,0.814585,0.781476
3,No log,0.564424,0.80993,0.778861
4,No log,0.632018,0.811482,0.778538
5,No log,0.732604,0.814585,0.781445
6,No log,0.804715,0.817688,0.781169
7,No log,0.865484,0.806051,0.775515
8,No log,0.902722,0.809154,0.778513
9,No log,0.930885,0.80993,0.776791
10,No log,0.937648,0.80993,0.77809


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.592903,0.78045,0.743533
2,No log,0.53987,0.806051,0.772917
3,No log,0.599936,0.801396,0.770617
4,No log,0.65218,0.802172,0.77008
5,No log,0.802161,0.78976,0.762603
6,No log,0.834728,0.799845,0.773002
7,No log,0.911067,0.788984,0.760599
8,No log,0.932312,0.802172,0.772985
9,No log,0.953257,0.799845,0.769191
10,No log,0.953218,0.800621,0.769996


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.601212,0.781226,0.719939
2,No log,0.60166,0.794414,0.759147
3,No log,0.623297,0.802172,0.77135
4,No log,0.762082,0.783553,0.756882
5,No log,0.840345,0.797517,0.768193
6,No log,0.94385,0.786656,0.756304
7,No log,1.008557,0.795966,0.765063
8,No log,1.054051,0.78976,0.75962
9,No log,1.055455,0.791311,0.762634
10,No log,1.059902,0.788208,0.758235


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.59911,0.776571,0.69157
2,No log,0.540649,0.800621,0.761744
3,No log,0.555174,0.808379,0.768258
4,No log,0.596796,0.80993,0.776714
5,No log,0.680358,0.812258,0.774219
6,No log,0.750684,0.813809,0.770801
7,No log,0.813711,0.808379,0.761408
8,No log,0.842134,0.809154,0.769408
9,No log,0.837644,0.811482,0.772908
10,No log,0.839934,0.813809,0.773972


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.640096,0.770365,0.689762
2,No log,0.590022,0.799845,0.757744
3,No log,0.606813,0.806827,0.772541
4,No log,0.677364,0.801396,0.763291
5,No log,0.746529,0.79519,0.762659
6,No log,0.845345,0.788208,0.747921
7,No log,0.901476,0.798293,0.761125
8,No log,0.962051,0.796742,0.759777
9,No log,0.956004,0.798293,0.764795
10,No log,0.961381,0.795966,0.759997


  0%|          | 0/977 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.559275,0.788208,0.746285
2,No log,0.547478,0.812258,0.781765
3,No log,0.614745,0.799845,0.750032
4,No log,0.756831,0.796742,0.734059
5,No log,0.753551,0.802948,0.756428
6,No log,0.855881,0.800621,0.758568
7,No log,0.89811,0.806051,0.770263
8,No log,0.945837,0.8045,0.770863
9,No log,0.985822,0.803724,0.764268
10,No log,0.981278,0.8045,0.766597


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.577875,0.785881,0.747455
2,No log,0.576929,0.782777,0.745989
3,No log,0.584726,0.807603,0.773918
4,No log,0.659858,0.79519,0.768495
5,No log,0.769042,0.785105,0.748296
6,No log,0.885678,0.777347,0.738142
7,No log,0.943399,0.775019,0.740026
8,No log,0.906213,0.788984,0.754527
9,No log,0.940739,0.785105,0.752156
10,No log,0.950711,0.78976,0.752872


  0%|          | 0/976 [00:00<?, ?it/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]

Map:   0%|          | 0/4882 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.622848,0.769589,0.737511
2,No log,0.551993,0.810706,0.783647
3,No log,0.63345,0.801396,0.755649
4,No log,0.688104,0.809154,0.778295
5,No log,0.749379,0.806827,0.773329
6,No log,0.876056,0.8045,0.775701
7,No log,0.938582,0.799069,0.766476
8,No log,0.980589,0.796742,0.762669
9,No log,0.994611,0.805275,0.772066
10,No log,0.994199,0.807603,0.774182


  0%|          | 0/976 [00:00<?, ?it/s]

saving the results dataframe to CSV in therapist_talk_type_output/bert_classifier_ce.csv


In [20]:
# Display the cross-entropy results frame (the output below shows one row per
# seed, with overall and per-position metric columns).
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.806432,0.777498,"[0.852482269503546, 0.6143896523848019, 0.7585...",0.779359,"[0.8312586445366529, 0.6239737274220033, 0.735...",0.777315,"[0.8748180494905385, 0.6050955414012739, 0.783...",1,cross_entropy,,True
0,0.802335,0.766754,"[0.8471337579617835, 0.5731497418244407, 0.760...",0.773974,"[0.8243801652892562, 0.6235955056179775, 0.722...",0.763614,"[0.87117903930131, 0.5302547770700637, 0.80292...",12,cross_entropy,,True
0,0.799672,0.764325,"[0.8514851485148515, 0.5827702702702703, 0.744...",0.769375,"[0.8019292604501608, 0.6205035971223022, 0.761...",0.762299,"[0.9075691411935953, 0.5493630573248408, 0.728...",123,cross_entropy,,True


In [21]:
# Mean of the `f1` column over the rows of the cross-entropy results frame.
bert_classifier_ce.loc[:, "f1"].mean()

0.7695258089451796

In [22]:
# Mean of the `precision` column over the rows of the cross-entropy results frame.
bert_classifier_ce.loc[:, "precision"].mean()

0.7742359595241934

In [23]:
# Mean of the `recall` column over the rows of the cross-entropy results frame.
bert_classifier_ce.loc[:, "recall"].mean()

0.7677424570460639

In [24]:
# Stack the per-row `f1_scores` vectors and average across rows (axis 0),
# leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

array([0.85036706, 0.59010322, 0.75454642, 0.88308653])

In [25]:
# Stack the per-row `precision_scores` vectors and average across rows
# (axis 0), leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

array([0.81918936, 0.62269094, 0.73966398, 0.91539956])

In [26]:
# Stack the per-row `recall_scores` vectors and average across rows (axis 0),
# leaving one mean per position of the score vector.
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)

array([0.88452208, 0.56157113, 0.77161919, 0.85325743])