In [1]:
import os
import pickle

import numpy as np

# Base random seed for the notebook; the fine-tuning runs below average over their own `seeds` list.
# NOTE(review): `seed` is not referenced by any cell shown here — confirm it is still needed.
seed = 2023

In [2]:
import torch

# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device

'cuda'

In [3]:
import logging

import transformers

# Only report critical messages, to avoid excessive logging during fine-tuning.
# logging.CRITICAL == 50, the level this cell previously passed as a bare literal.
transformers.utils.logging.set_verbosity(logging.CRITICAL)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that receives the per-run results CSVs written by the fine-tuning cells.
output_dir = "client_talk_type_output"
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [6]:
# NOTE(review): this script is presumably what defines `anno_mi`, `client_index` and
# `client_transcript_id` used below (no other cell defines them) — confirm, since the
# notebook cannot run from a fresh kernel without it.
%run ../load_anno_mi.py

In [7]:
# Preview the first five rows of the loaded AnnoMI frame.
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-18 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-18 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-18 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-18 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-18 00:00:34


In [8]:
# Precomputed SBERT embeddings for AnnoMI (presumably one row per utterance — confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
# NOTE(review): `sbert_embeddings` is not passed to the BERT fine-tuning runs below.
sbert_path = "../anno_mi_sbert.pkl"
with open(sbert_path, "rb") as fh:
    sbert_embeddings = pickle.load(fh)

sbert_embeddings.shape

(9699, 384)

# Baseline: Fine-tuning BERT for client talk-type classification

In [9]:
# Settings shared by every fine-tuning run below: epochs per run, the seeds the
# results are averaged over, and the metric name forwarded as `validation_metric`.
num_epochs, seeds, validation_metric = 10, [1, 12, 123], "f1"

In [None]:
# Keyword arguments shared by both fine-tuning runs; `loss`, `gamma` and
# `results_output` are supplied separately in each run's cell.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",
    label_column="client_talk_type",
    seeds=seeds,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    validation_metric=validation_metric,
    device=device,
    verbose=False,
)

## Focal Loss

In [None]:
# Focal-loss run: `gamma` is forwarded to the fine-tuning helper alongside `loss`.
loss, gamma = "focal", 2

In [10]:
# Fine-tune BERT with the focal-loss settings above; the returned frame has one row per seed.
# NOTE(review): the saved output of this cell reports writing `bert_classifier.csv`, not the
# `bert_classifier_focal.csv` requested here — the cell was edited after it last ran; re-run it.
focal_results_csv = f"{output_dir}/bert_classifier_focal.csv"
bert_classifier = fine_tune_transformer_average_seed(
    results_output=focal_results_csv,
    loss=loss,
    gamma=gamma,
    **kwargs,
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.606866,0.669025,0.525367
2,No log,0.600274,0.642296,0.52572
3,No log,0.691045,0.675314,0.540597
4,No log,0.932919,0.67217,0.540132
5,No log,1.030759,0.632075,0.528385
6,No log,1.115621,0.66195,0.539733
7,No log,1.166045,0.669025,0.548507
8,No log,1.222678,0.670597,0.544779
9,No log,1.28519,0.676101,0.5498
10,No log,1.244305,0.663522,0.550576


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.680067,0.658805,0.457608
2,No log,0.63706,0.647799,0.54001
3,No log,0.690647,0.628931,0.552262
4,No log,0.919274,0.637579,0.526459
5,No log,1.098896,0.652516,0.522528
6,No log,1.152871,0.665094,0.542922
7,No log,1.263884,0.649371,0.503578
8,No log,1.335778,0.646226,0.523546
9,No log,1.432483,0.660377,0.518369
10,No log,1.458723,0.65566,0.517527


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.645491,0.654088,0.511456
2,No log,0.608551,0.639937,0.518774
3,No log,0.651519,0.665881,0.567288
4,No log,0.820383,0.680818,0.563683
5,No log,0.970285,0.691038,0.547191
6,No log,1.086527,0.695755,0.559652
7,No log,1.212684,0.680818,0.531837
8,No log,1.15502,0.687107,0.570321
9,No log,1.192023,0.687107,0.560522
10,No log,1.21004,0.688679,0.561281


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.606264,0.652516,0.501832
2,No log,0.586937,0.670597,0.544494
3,No log,0.717314,0.647799,0.515063
4,No log,0.884184,0.653302,0.54201
5,No log,1.053661,0.680818,0.536112
6,No log,1.210249,0.678459,0.534536
7,No log,1.189986,0.676101,0.545092
8,No log,1.279304,0.676887,0.545799
9,No log,1.292987,0.683962,0.553912
10,No log,1.295929,0.676887,0.554442


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.65383,0.676887,0.471418
2,No log,0.620035,0.622642,0.530974
3,No log,0.661666,0.680031,0.5292
4,No log,0.803726,0.636006,0.534496
5,No log,0.961537,0.64544,0.520602
6,No log,1.130855,0.662736,0.52625
7,No log,1.127969,0.662736,0.529792
8,No log,1.229601,0.67217,0.523147
9,No log,1.271869,0.673742,0.533436
10,No log,1.261186,0.671384,0.535878


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.601597,0.639937,0.524405
2,No log,0.636404,0.631289,0.49248
3,No log,0.755748,0.668239,0.528575
4,No log,0.80716,0.658019,0.550823
5,No log,0.991903,0.621069,0.533957
6,No log,1.121427,0.664308,0.515898
7,No log,1.086296,0.665881,0.544322
8,No log,1.184168,0.658019,0.539548
9,No log,1.233599,0.661164,0.541001
10,No log,1.250637,0.665094,0.542078


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.650595,0.66195,0.502053
2,No log,0.629223,0.654874,0.526018
3,No log,0.690637,0.654088,0.552857
4,No log,0.837607,0.666667,0.551127
5,No log,1.007583,0.660377,0.509937
6,No log,1.055075,0.656447,0.515954
7,No log,1.18307,0.660377,0.520858
8,No log,1.180562,0.665881,0.550803
9,No log,1.25794,0.663522,0.523719
10,No log,1.228732,0.669811,0.544107


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.69934,0.613208,0.459978
2,No log,0.616409,0.610063,0.523355
3,No log,0.621064,0.632075,0.538376
4,No log,0.676209,0.628931,0.550141
5,No log,0.939947,0.693396,0.53125
6,No log,0.891134,0.696541,0.569403
7,No log,0.91485,0.664308,0.563032
8,No log,1.017268,0.676887,0.555046
9,No log,1.013041,0.678459,0.565279
10,No log,1.027291,0.676101,0.563491


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.616306,0.643082,0.503158
2,No log,0.639753,0.631289,0.507622
3,No log,0.684261,0.663522,0.540417
4,No log,0.910455,0.665094,0.5222
5,No log,1.027754,0.666667,0.533702
6,No log,1.096419,0.676101,0.546719
7,No log,1.162876,0.67217,0.530513
8,No log,1.208605,0.681604,0.545099
9,No log,1.250113,0.684748,0.547823
10,No log,1.240073,0.673742,0.538688


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.656904,0.591195,0.500229
2,No log,0.62449,0.633648,0.522044
3,No log,0.695699,0.639151,0.531137
4,No log,0.909873,0.647799,0.524393
5,No log,1.116414,0.67217,0.533537
6,No log,1.174379,0.672956,0.530237
7,No log,1.249399,0.660377,0.512607
8,No log,1.259374,0.657233,0.518364
9,No log,1.322095,0.678459,0.528632
10,No log,1.28224,0.661164,0.527918


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.627458,0.607704,0.515967
2,No log,0.613182,0.632862,0.512982
3,No log,0.691948,0.639937,0.530601
4,No log,0.921213,0.539308,0.464302
5,No log,1.097009,0.676101,0.517794
6,No log,1.181695,0.626572,0.52037
7,No log,1.215317,0.663522,0.539124
8,No log,1.272125,0.675314,0.547269
9,No log,1.276884,0.662736,0.538704
10,No log,1.285431,0.668239,0.542234


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.636363,0.650943,0.524258
2,No log,0.693953,0.674528,0.530489
3,No log,0.737253,0.633648,0.50715
4,No log,0.92963,0.649371,0.502006
5,No log,0.996656,0.631289,0.546219
6,No log,1.135448,0.679245,0.548142
7,No log,1.20262,0.659591,0.540908
8,No log,1.231,0.667453,0.539293
9,No log,1.301063,0.661164,0.546057
10,No log,1.341479,0.671384,0.539933


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.679471,0.68239,0.429383
2,No log,0.587518,0.658019,0.541654
3,No log,0.597252,0.685535,0.566844
4,No log,0.762522,0.691038,0.562226
5,No log,0.834511,0.681604,0.568964
6,No log,0.957951,0.697327,0.584316
7,No log,1.08281,0.703616,0.57704
8,No log,1.079042,0.700472,0.589582
9,No log,1.098026,0.707547,0.592512
10,No log,1.115709,0.706761,0.593334


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.631255,0.631289,0.477173
2,No log,0.607072,0.686321,0.51572
3,No log,0.633498,0.631289,0.538601
4,No log,0.700224,0.666667,0.532478
5,No log,0.990108,0.666667,0.535859
6,No log,1.110501,0.665881,0.530034
7,No log,1.15956,0.666667,0.536044
8,No log,1.239554,0.666667,0.524859
9,No log,1.350425,0.680818,0.529432
10,No log,1.33919,0.667453,0.528805


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.644488,0.646226,0.488801
2,No log,0.612852,0.643082,0.544573
3,No log,0.81941,0.65173,0.493424
4,No log,0.846447,0.644654,0.527276
5,No log,1.052485,0.658019,0.514323
6,No log,1.035218,0.64544,0.521089
7,No log,1.152281,0.669025,0.531147
8,No log,1.160998,0.662736,0.537167
9,No log,1.206446,0.672956,0.533714
10,No log,1.216705,0.667453,0.53217


  0%|          | 0/963 [00:00<?, ?it/s]

saving the results dataframe to CSV in client_talk_type_output/bert_classifier.csv


In [11]:
# Per-seed results of the focal-loss run (one row per seed in `seeds`).
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.648329,0.539967,"[0.7717855968431436, 0.47443946188340813, 0.37...",0.53512,"[0.7870556673373574, 0.502851711026616, 0.3154...",0.554806,"[0.7570967741935484, 0.44906621392190155, 0.45...",1,focal,2,True
0,0.642101,0.528872,"[0.771127918447879, 0.4605026929982047, 0.3549...",0.524773,"[0.7863849765258216, 0.48857142857142855, 0.29...",0.542643,"[0.7564516129032258, 0.43548387096774194, 0.43...",12,focal,2,True
0,0.665559,0.548762,"[0.7854753941710464, 0.473508353221957, 0.3873...",0.551676,"[0.7757156338471217, 0.5408942202835333, 0.338...",0.556409,"[0.795483870967742, 0.42105263157894735, 0.452...",123,focal,2,True


In [12]:
# Mean F1 across the focal-loss seed runs.
bert_classifier.loc[:, "f1"].mean()

0.5392002385817745

In [13]:
# Mean precision across the focal-loss seed runs.
bert_classifier.loc[:, "precision"].mean()

0.5371898824473699

In [14]:
# Mean recall across the focal-loss seed runs.
bert_classifier.loc[:, "recall"].mean()

0.5512859722270346

In [15]:
# Per-class F1 averaged over seeds (class order follows the fine-tuning helper — TODO confirm labels).
np.mean(np.stack(list(bert_classifier["f1_scores"])), axis=0)

array([0.77612964, 0.4694835 , 0.37198758])

In [16]:
# Per-class precision averaged over seeds.
np.mean(np.stack(list(bert_classifier["precision_scores"])), axis=0)

array([0.78305209, 0.51077245, 0.3177451 ])

In [17]:
# Per-class recall averaged over seeds.
np.mean(np.stack(list(bert_classifier["recall_scores"])), axis=0)

array([0.76967742, 0.43520091, 0.44897959])

## Cross-Entropy Loss

In [18]:
# Cross-entropy run: no `gamma` is needed, so it is passed as None.
loss, gamma = "cross_entropy", None

In [19]:
# Fine-tune BERT with the cross-entropy settings above; the returned frame has one row per seed.
ce_results_csv = f"{output_dir}/bert_classifier_ce.csv"
bert_classifier_ce = fine_tune_transformer_average_seed(
    results_output=ce_results_csv,
    loss=loss,
    gamma=gamma,
    **kwargs,
)

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.779004,0.671384,0.377502
2,No log,0.749475,0.683176,0.421864
3,No log,0.807604,0.676101,0.546421
4,No log,0.940113,0.699686,0.544097
5,No log,1.091524,0.654874,0.551821
6,No log,1.22637,0.687107,0.549546
7,No log,1.264411,0.691824,0.553405
8,No log,1.340211,0.678459,0.549437
9,No log,1.411818,0.679245,0.540458
10,No log,1.416991,0.679245,0.548778


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.802116,0.660377,0.408749
2,No log,0.797723,0.665881,0.422781
3,No log,0.896729,0.65566,0.48157
4,No log,1.01897,0.656447,0.551653
5,No log,1.187551,0.67217,0.540182
6,No log,1.304959,0.662736,0.532352
7,No log,1.363464,0.671384,0.524665
8,No log,1.451214,0.674528,0.54291
9,No log,1.48826,0.674528,0.544425
10,No log,1.512795,0.65566,0.523127


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.763045,0.679245,0.414249
2,No log,0.746144,0.687107,0.446927
3,No log,0.753428,0.690252,0.514225
4,No log,0.84587,0.701258,0.535093
5,No log,1.030446,0.684748,0.540583
6,No log,1.051117,0.708333,0.568523
7,No log,1.190214,0.694182,0.561335
8,No log,1.24792,0.691038,0.555907
9,No log,1.29894,0.70283,0.587403
10,No log,1.321749,0.694182,0.560114


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.752148,0.684748,0.371504
2,No log,0.727708,0.694969,0.452008
3,No log,0.863995,0.661164,0.526238
4,No log,0.969219,0.672956,0.520023
5,No log,1.188775,0.660377,0.531924
6,No log,1.296789,0.680031,0.525025
7,No log,1.436814,0.685535,0.53528
8,No log,1.456275,0.675314,0.531488
9,No log,1.527386,0.68239,0.540874
10,No log,1.535101,0.683176,0.539337


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.772607,0.680031,0.401387
2,No log,0.767437,0.665881,0.41831
3,No log,0.837378,0.662736,0.489547
4,No log,0.901249,0.669025,0.535116
5,No log,1.035932,0.67217,0.514277
6,No log,1.193257,0.662736,0.509258
7,No log,1.245109,0.661164,0.521818
8,No log,1.295072,0.659591,0.529544
9,No log,1.354743,0.671384,0.527072
10,No log,1.359551,0.665881,0.526551


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.75878,0.669025,0.428823
2,No log,0.769441,0.662736,0.495501
3,No log,0.830565,0.676887,0.516118
4,No log,0.952494,0.669025,0.529308
5,No log,1.178045,0.648585,0.540099
6,No log,1.250462,0.668239,0.541766
7,No log,1.3991,0.663522,0.555011
8,No log,1.427162,0.670597,0.548411
9,No log,1.496284,0.664308,0.5384
10,No log,1.524389,0.669025,0.544467


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.802861,0.666667,0.398571
2,No log,0.767324,0.680031,0.517611
3,No log,0.837007,0.688679,0.577445
4,No log,1.061572,0.643082,0.563269
5,No log,1.16844,0.661164,0.511061
6,No log,1.279465,0.679245,0.534033
7,No log,1.430802,0.669811,0.518136
8,No log,1.500656,0.661164,0.552074
9,No log,1.555207,0.666667,0.540892
10,No log,1.558535,0.669025,0.545056


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.791687,0.673742,0.341165
2,No log,0.734132,0.696541,0.450991
3,No log,0.738295,0.693396,0.52138
4,No log,0.861476,0.701258,0.585338
5,No log,0.999573,0.707547,0.551143
6,No log,1.138026,0.680818,0.580209
7,No log,1.245262,0.699686,0.555284
8,No log,1.360876,0.679245,0.552219
9,No log,1.372124,0.70283,0.572974
10,No log,1.393855,0.702044,0.575529


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.757509,0.676887,0.387625
2,No log,0.798934,0.647013,0.514152
3,No log,0.795258,0.688679,0.511609
4,No log,0.928171,0.688679,0.514618
5,No log,1.064539,0.678459,0.522102
6,No log,1.207122,0.67217,0.55025
7,No log,1.358213,0.678459,0.54106
8,No log,1.433254,0.663522,0.531324
9,No log,1.479519,0.687893,0.564212
10,No log,1.500348,0.68239,0.556739


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.810828,0.631289,0.416659
2,No log,0.783756,0.669025,0.521958
3,No log,0.854953,0.660377,0.522437
4,No log,1.015548,0.680818,0.510494
5,No log,1.239255,0.685535,0.492021
6,No log,1.289434,0.665881,0.50182
7,No log,1.425067,0.674528,0.511433
8,No log,1.4875,0.66195,0.517555
9,No log,1.531644,0.666667,0.506338
10,No log,1.560167,0.664308,0.520147


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.780113,0.666667,0.400788
2,No log,0.733287,0.678459,0.499723
3,No log,0.831472,0.662736,0.553301
4,No log,0.916299,0.671384,0.524253
5,No log,1.184144,0.658019,0.513303
6,No log,1.336765,0.662736,0.527821
7,No log,1.409262,0.669025,0.521658
8,No log,1.491128,0.658019,0.534802
9,No log,1.561862,0.650157,0.531061
10,No log,1.548291,0.668239,0.536178


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.799151,0.656447,0.420407
2,No log,0.828129,0.675314,0.476711
3,No log,0.888067,0.664308,0.475393
4,No log,1.035264,0.657233,0.481719
5,No log,1.070191,0.654874,0.528785
6,No log,1.265966,0.658019,0.546121
7,No log,1.359896,0.65566,0.52928
8,No log,1.44804,0.653302,0.535684
9,No log,1.513988,0.65173,0.526098
10,No log,1.550179,0.659591,0.523073


  0%|          | 0/964 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.784537,0.676887,0.405467
2,No log,0.755962,0.688679,0.447619
3,No log,0.772917,0.691038,0.509058
4,No log,0.836107,0.704403,0.56764
5,No log,0.972668,0.694969,0.581684
6,No log,1.068029,0.710692,0.557442
7,No log,1.113394,0.709906,0.576966
8,No log,1.231365,0.70283,0.579308
9,No log,1.274859,0.700472,0.580479
10,No log,1.288114,0.704403,0.579455


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.778412,0.673742,0.412936
2,No log,0.746663,0.694969,0.445481
3,No log,0.744533,0.698899,0.521437
4,No log,0.867011,0.678459,0.518368
5,No log,1.216551,0.650157,0.541425
6,No log,1.258691,0.675314,0.545662
7,No log,1.410093,0.663522,0.535548
8,No log,1.403985,0.681604,0.531914
9,No log,1.489538,0.676101,0.524357
10,No log,1.497319,0.673742,0.533103


  0%|          | 0/963 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.781303,0.66195,0.323006
2,No log,0.763113,0.67217,0.526785
3,No log,0.819909,0.676101,0.515895
4,No log,0.913625,0.661164,0.522468
5,No log,1.110329,0.676101,0.480162
6,No log,1.192202,0.669811,0.509133
7,No log,1.323109,0.666667,0.512463
8,No log,1.438628,0.666667,0.50375
9,No log,1.466388,0.660377,0.50785
10,No log,1.481672,0.660377,0.513286


  0%|          | 0/963 [00:00<?, ?it/s]

saving the results dataframe to CSV in client_talk_type_output/bert_classifier_ce.csv


In [20]:
# Per-seed results of the cross-entropy run (one row per seed in `seeds`).
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,seed,loss,gamma,k_fold
0,0.68279,0.447718,"[0.8005641748942174, 0.4305627258647393, 0.112...",0.586895,"[0.7112781954887218, 0.549407114624506, 0.5]",0.444184,"[0.915483870967742, 0.3539898132427844, 0.0630...",1,cross_entropy,,True
0,0.681129,0.488027,"[0.8026736413833189, 0.4374683544303798, 0.223...",0.54596,"[0.7303014278159704, 0.5420326223337516, 0.365...",0.473034,"[0.8909677419354839, 0.366723259762309, 0.1614...",12,cross_entropy,,True
0,0.681752,0.493653,"[0.8026865671641792, 0.4739726027397261, 0.204...",0.543507,"[0.7469444444444444, 0.5128458498023716, 0.370...",0.482999,"[0.8674193548387097, 0.4405772495755518, 0.141...",123,cross_entropy,,True


In [21]:
# Mean F1 across the cross-entropy seed runs.
bert_classifier_ce.loc[:, "f1"].mean()

0.4764659694252144

In [22]:
# Mean precision across the cross-entropy seed runs.
bert_classifier_ce.loc[:, "precision"].mean()

0.5587875089238038

In [23]:
# Mean recall across the cross-entropy seed runs.
bert_classifier_ce.loc[:, "recall"].mean()

0.4667392157253909

In [24]:
# Per-class F1 averaged over seeds (class order follows the fine-tuning helper — TODO confirm labels).
np.mean(np.stack(list(bert_classifier_ce["f1_scores"])), axis=0)

array([0.80197479, 0.44733456, 0.18008855])

In [25]:
# Per-class precision averaged over seeds.
np.mean(np.stack(list(bert_classifier_ce["precision_scores"])), axis=0)

array([0.72950802, 0.53476186, 0.41209264])

In [26]:
# Per-class recall averaged over seeds.
np.mean(np.stack(list(bert_classifier_ce["recall_scores"])), axis=0)

array([0.89129032, 0.38709677, 0.12183055])