In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced in any cell visible here (the
# fine-tuning runs take the `seeds` list instead) — confirm it is used
# further down, otherwise drop it.
seed = 2023

In [2]:
import torch

# Prefer the GPU whenever torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
import transformers

# Only report critical errors to avoid excessive logging
# (50 is the numeric value of the standard-library `logging.CRITICAL` level).
transformers.utils.logging.set_verbosity(50)

In [4]:
from nlpsig_networks.scripts.fine_tune_bert_classification import (
    fine_tune_transformer_average_seed,
)

In [5]:
# Directory that receives the results CSV written by each fine-tuning run.
output_dir = "client_talk_type_output"
# `exist_ok=True` keeps this cell idempotent on re-runs without a separate
# (racy) existence check; it still raises if the path exists as a file.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [6]:
# NOTE(review): the later cells use `anno_mi`, `label_to_id_client`,
# `id_to_label_client`, `output_dim_client`, `client_index` and
# `client_transcript_id` without defining them — presumably this script
# creates them; confirm against ../load_anno_mi.py.
%run ../load_anno_mi.py

In [7]:
# Preview the first rows of the loaded dataframe.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-11-09 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-11-09 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-11-09 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-11-09 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-11-09 00:00:34,-1


In [8]:
# Load the pickled object from the parent directory.
# NOTE(review): presumably one SBERT sentence embedding per utterance — the
# displayed shape is (9699, 384); confirm the rows line up with `anno_mi`.
# SECURITY: `pickle.load` can execute arbitrary code contained in the file,
# so only load this from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# Show the dimensions of the loaded object as a quick sanity check.
sbert_embeddings.shape

(9699, 384)

## Baseline: Fine-tune BERT for classification

In [9]:
# Settings shared by every fine-tuning run below (all forwarded via `kwargs`).
validation_metric = "f1"  # name of the validation metric handed to the helper
seeds = [1, 12, 123]  # random seeds handed to the helper
learning_rates = [5e-5, 1e-5, 1e-6]  # candidate learning rates
num_epochs = 5  # training epochs per run

In [10]:
# Show the class-name -> integer-id mapping that is passed as `label_to_id` below.
label_to_id_client

{'neutral': 0, 'change': 1, 'sustain': 2}

In [11]:
# Show the integer-id -> class-name mapping that is passed as `id_to_label` below.
id_to_label_client

{0: 'neutral', 1: 'change', 2: 'sustain'}

In [12]:
# Arguments shared by every fine-tuning run in this notebook; only the loss
# settings and the results path are supplied separately at each call site.
kwargs = dict(
    num_epochs=num_epochs,
    pretrained_model_name="bert-base-uncased",
    df=anno_mi,
    feature_name="utterance_text",
    label_column="client_talk_type",
    label_to_id=label_to_id_client,
    id_to_label=id_to_label_client,
    output_dim=output_dim_client,
    learning_rates=learning_rates,
    seeds=seeds,
    device=device,
    batch_size=8,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    validation_metric=validation_metric,
    verbose=False,
)

## Focal Loss

In [13]:
# Settings for the first experiment: focal loss.
gamma = 2  # forwarded as `gamma`; presumably the focal-loss focusing parameter — confirm
loss = "focal"

In [14]:
# Fine-tune `bert-base-uncased` with the focal-loss settings and the shared
# `kwargs`; the helper reports saving its results dataframe to the CSV path
# given as `results_output`.
# NOTE(review): `loss` and `gamma` are notebook globals that the
# cross-entropy section reassigns, so re-running this cell out of order
# would write non-focal results to the focal CSV.
bert_classifier = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_focal.csv",
    **kwargs,
)

  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.675592,0.659985,0.416835
2,0.669000,0.635145,0.616808,0.522742
3,0.669000,1.108824,0.639938,0.487981
4,0.478000,1.217864,0.66384,0.530595
5,0.186100,1.40129,0.667695,0.528114


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.687119,0.69783,0.366307
2,0.665500,0.641263,0.658598,0.539876
3,0.665500,0.84217,0.658598,0.528115
4,0.315800,1.203762,0.656093,0.505126
5,0.119400,1.294447,0.663606,0.525201


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63342,0.619497,0.528791
2,0.661900,0.656693,0.661164,0.524105
3,0.661900,1.04701,0.656447,0.535441
4,0.367500,1.246493,0.647799,0.54377
5,0.130900,1.408912,0.657233,0.54252


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.650523,0.658023,0.481067
2,0.642000,0.678633,0.689627,0.533854
3,0.642000,0.963991,0.704214,0.570266
4,0.347000,1.175009,0.700972,0.562988
5,0.124300,1.257137,0.699352,0.588171


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.69719,0.659332,0.421528
2,0.684700,0.654941,0.682967,0.558969
3,0.684700,0.848778,0.663407,0.533307
4,0.461100,1.166178,0.672372,0.5365
5,0.203800,1.380184,0.678077,0.529205


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.699687,0.529684,0.387907
2,0.744500,0.66862,0.598304,0.484167
3,0.744500,0.705105,0.649961,0.514369
4,0.673200,0.681279,0.614495,0.534497
5,0.493600,0.824582,0.647648,0.524658


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.64142,0.603506,0.511888
2,0.637400,0.655796,0.660267,0.549939
3,0.637400,0.888682,0.686144,0.554837
4,0.324600,1.083677,0.677796,0.534113
5,0.127000,1.193522,0.682805,0.546013


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.704001,0.628145,0.384065
2,0.724000,0.706054,0.633648,0.513394
3,0.724000,0.805295,0.654874,0.498996
4,0.548200,1.042794,0.65173,0.538704
5,0.238200,1.162586,0.64544,0.550053


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.721823,0.622366,0.327659
2,0.740900,0.683429,0.557536,0.44834
3,0.740900,0.690608,0.619935,0.476753
4,0.664100,0.686055,0.65154,0.507645
5,0.512300,0.822636,0.640194,0.51851


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.651309,0.651997,0.549939
2,0.673600,0.768014,0.638142,0.518442
3,0.673600,0.855569,0.641402,0.536258
4,0.431000,1.099267,0.649552,0.532274
5,0.189400,1.260795,0.664222,0.528924


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.738177,0.615266,0.281539
2,0.743200,0.704778,0.521203,0.364408
3,0.743200,0.727351,0.593678,0.398024
4,0.721600,0.725957,0.553585,0.446664
5,0.712400,0.733838,0.629915,0.408881


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.748195,0.230384,0.12483
2,0.779300,0.721273,0.66611,0.266533
3,0.779300,0.725901,0.66611,0.266533
4,0.781100,0.726037,0.66611,0.266533
5,0.756000,0.717999,0.66611,0.266533


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.687347,0.616352,0.468915
2,0.681100,0.632143,0.67217,0.538031
3,0.681100,0.81402,0.628931,0.539804
4,0.449800,1.022653,0.646226,0.544153
5,0.161100,1.214064,0.659591,0.543577


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.62024,0.670989,0.520408
2,0.623300,0.804502,0.667747,0.531717
3,0.623300,0.992513,0.699352,0.571968
4,0.295300,1.242173,0.700162,0.586186
5,0.122300,1.32668,0.690438,0.562587


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.655853,0.680522,0.391569
2,0.635700,0.679611,0.623472,0.515699
3,0.635700,1.057101,0.661777,0.540569
4,0.376600,1.346776,0.661777,0.525447
5,0.164700,1.454775,0.665037,0.530367


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.643432,0.673863,0.467379
2,0.652800,0.620755,0.660756,0.549695
3,0.652800,0.714669,0.665382,0.526774
4,0.493900,0.732494,0.661527,0.531267
5,0.334600,0.75435,0.658443,0.536988


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.659105,0.687813,0.363731
2,0.655500,0.591152,0.671119,0.530108
3,0.655500,0.609691,0.676127,0.555673
4,0.451500,0.653149,0.685309,0.557292
5,0.310800,0.660845,0.676127,0.549866


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.647446,0.626572,0.524971
2,0.662800,0.620483,0.644654,0.519159
3,0.662800,0.660129,0.64544,0.518461
4,0.502900,0.707653,0.633648,0.533786
5,0.322700,0.741499,0.64544,0.533873


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63461,0.657212,0.492259
2,0.625900,0.651938,0.666126,0.510318
3,0.625900,0.689912,0.672609,0.53549
4,0.465800,0.743043,0.675041,0.553496
5,0.322300,0.795154,0.675851,0.546083


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.655421,0.669112,0.445577
2,0.649800,0.623848,0.677262,0.536496
3,0.649800,0.676974,0.665037,0.539352
4,0.499400,0.762011,0.654442,0.532479
5,0.354200,0.782371,0.669112,0.551987


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.643784,0.657672,0.511586
2,0.646500,0.675383,0.674634,0.53604
3,0.646500,0.80318,0.67926,0.523598
4,0.502700,0.836136,0.659214,0.536644
5,0.338700,0.864016,0.654588,0.539559


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.616539,0.605175,0.500426
2,0.640000,0.642018,0.627713,0.534091
3,0.640000,0.650813,0.65025,0.552643
4,0.456100,0.693179,0.673623,0.558558
5,0.320500,0.705697,0.68197,0.556228


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.689053,0.576258,0.436929
2,0.687400,0.653216,0.647799,0.498788
3,0.687400,0.63962,0.649371,0.529946
4,0.561300,0.657915,0.647799,0.545294
5,0.392600,0.699265,0.637579,0.524918


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.652871,0.642626,0.468498
2,0.631900,0.630971,0.67342,0.548846
3,0.631900,0.716493,0.688817,0.550399
4,0.453300,0.781856,0.688006,0.560101
5,0.300800,0.806577,0.69611,0.565232


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.631908,0.658517,0.523473
2,0.643000,0.662,0.630807,0.510099
3,0.643000,0.685236,0.643032,0.542252
4,0.460200,0.786137,0.629992,0.523836
5,0.326100,0.84118,0.665037,0.542331


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.650914,0.674634,0.469209
2,0.656600,0.591478,0.667695,0.567481
3,0.656600,0.683754,0.685428,0.562171
4,0.496000,0.717567,0.674634,0.567245
5,0.338000,0.773104,0.681573,0.562296


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.607415,0.617696,0.497304
2,0.637000,0.576842,0.679466,0.55819
3,0.637000,0.605349,0.693656,0.576807
4,0.449500,0.656939,0.69783,0.5755
5,0.306700,0.673631,0.699499,0.58182


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.657434,0.649371,0.504376
2,0.662300,0.631021,0.647799,0.536011
3,0.662300,0.678323,0.628145,0.514639
4,0.521100,0.713428,0.654088,0.539118
5,0.370700,0.737338,0.654874,0.532372


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.616408,0.657212,0.528354
2,0.632000,0.682229,0.675851,0.493448
3,0.632000,0.717564,0.679092,0.524122
4,0.440800,0.759655,0.672609,0.544571
5,0.288800,0.815853,0.67423,0.53959


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63585,0.682967,0.503953
2,0.640800,0.625781,0.658517,0.534538
3,0.640800,0.692839,0.668297,0.53148
4,0.465600,0.718743,0.680522,0.54344
5,0.340000,0.759207,0.678077,0.54236


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.726655,0.621434,0.298736
2,0.746700,0.696244,0.64919,0.382975
3,0.746700,0.682474,0.653816,0.391406
4,0.688300,0.674973,0.659985,0.401324
5,0.668200,0.672802,0.660756,0.40242


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.685194,0.657763,0.345246
2,0.737300,0.658102,0.677796,0.358415
3,0.737300,0.647093,0.669449,0.3978
4,0.673700,0.640757,0.671953,0.392056
5,0.664200,0.640063,0.671119,0.399394


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.725858,0.61478,0.318891
2,0.750100,0.703031,0.616352,0.365541
3,0.750100,0.69323,0.623428,0.383569
4,0.705200,0.687169,0.623428,0.383087
5,0.675300,0.686341,0.624214,0.386613


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.709823,0.639384,0.289278
2,0.718100,0.681973,0.653971,0.359644
3,0.718100,0.669709,0.664506,0.381927
4,0.666000,0.664982,0.662885,0.381238
5,0.645700,0.66271,0.661264,0.384506


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.720432,0.649552,0.296535
2,0.723000,0.693829,0.665037,0.384916
3,0.723000,0.678937,0.669112,0.389133
4,0.658400,0.674414,0.663407,0.399343
5,0.638600,0.672643,0.663407,0.400067


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.709354,0.629915,0.369475
2,0.721700,0.695174,0.640709,0.42312
3,0.721700,0.688225,0.636854,0.442494
4,0.680200,0.68431,0.636854,0.456737
5,0.675600,0.682354,0.635312,0.462054


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.677963,0.663606,0.329376
2,0.728100,0.662478,0.66778,0.377205
3,0.728100,0.656089,0.640234,0.439704
4,0.674900,0.652712,0.634391,0.439699
5,0.659100,0.650675,0.646077,0.462419


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.718105,0.620283,0.37528
2,0.735200,0.703361,0.617925,0.419245
3,0.735200,0.693287,0.61478,0.429864
4,0.698000,0.68928,0.607704,0.45275
5,0.676400,0.688646,0.605346,0.448322


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.703088,0.633712,0.29546
2,0.706900,0.684531,0.654781,0.410107
3,0.706900,0.678421,0.658023,0.427489
4,0.659100,0.675387,0.65154,0.443078
5,0.638800,0.674833,0.649919,0.442329


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.69995,0.646292,0.348395
2,0.710700,0.684902,0.653627,0.395463
3,0.710700,0.674583,0.625102,0.4227
4,0.660700,0.669274,0.638957,0.428341
5,0.646900,0.668959,0.633252,0.423884


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.709726,0.622976,0.349387
2,0.758800,0.689079,0.639167,0.406812
3,0.758800,0.679638,0.645335,0.420797
4,0.689800,0.674511,0.659214,0.43902
5,0.677000,0.673369,0.661527,0.440379


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.678883,0.616027,0.333748
2,0.756100,0.656625,0.648581,0.401063
3,0.756100,0.646402,0.657763,0.412548
4,0.687400,0.640911,0.66778,0.421626
5,0.659100,0.638809,0.671119,0.42359


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.719644,0.570755,0.37117
2,0.773600,0.696764,0.596698,0.420218
3,0.773600,0.68961,0.602987,0.445534
4,0.699900,0.686853,0.617138,0.460352
5,0.674700,0.685727,0.613994,0.458871


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.702859,0.589951,0.324771
2,0.757900,0.684171,0.634522,0.343662
3,0.757900,0.676916,0.638574,0.361585
4,0.665800,0.667686,0.65154,0.40983
5,0.639700,0.668219,0.641005,0.402398


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.705185,0.647922,0.37201
2,0.745200,0.685424,0.656072,0.393957
3,0.745200,0.67623,0.665037,0.411984
4,0.667500,0.673082,0.656887,0.412003
5,0.655800,0.67211,0.658517,0.412642


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.643432,0.673863,0.467379
2,0.652800,0.620755,0.660756,0.549695
3,0.652800,0.714669,0.665382,0.526774
4,0.493900,0.732494,0.661527,0.531267
5,0.334600,0.75435,0.658443,0.536988


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.659105,0.687813,0.363731
2,0.655500,0.591152,0.671119,0.530108
3,0.655500,0.609691,0.676127,0.555673
4,0.451500,0.653149,0.685309,0.557292
5,0.310800,0.660845,0.676127,0.549866


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.647446,0.626572,0.524971
2,0.662800,0.620483,0.644654,0.519159
3,0.662800,0.660129,0.64544,0.518461
4,0.502900,0.707653,0.633648,0.533786
5,0.322700,0.741499,0.64544,0.533873


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63461,0.657212,0.492259
2,0.625900,0.651938,0.666126,0.510318
3,0.625900,0.689912,0.672609,0.53549
4,0.465800,0.743043,0.675041,0.553496
5,0.322300,0.795154,0.675851,0.546083


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.655421,0.669112,0.445577
2,0.649800,0.623848,0.677262,0.536496
3,0.649800,0.676974,0.665037,0.539352
4,0.499400,0.762011,0.654442,0.532479
5,0.354200,0.782371,0.669112,0.551987


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.643784,0.657672,0.511586
2,0.646500,0.675383,0.674634,0.53604
3,0.646500,0.80318,0.67926,0.523598
4,0.502700,0.836136,0.659214,0.536644
5,0.338700,0.864016,0.654588,0.539559


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.616539,0.605175,0.500426
2,0.640000,0.642018,0.627713,0.534091
3,0.640000,0.650813,0.65025,0.552643
4,0.456100,0.693179,0.673623,0.558558
5,0.320500,0.705697,0.68197,0.556228


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.689053,0.576258,0.436929
2,0.687400,0.653216,0.647799,0.498788
3,0.687400,0.63962,0.649371,0.529946
4,0.561300,0.657915,0.647799,0.545294
5,0.392600,0.699265,0.637579,0.524918


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.652871,0.642626,0.468498
2,0.631900,0.630971,0.67342,0.548846
3,0.631900,0.716493,0.688817,0.550399
4,0.453300,0.781856,0.688006,0.560101
5,0.300800,0.806577,0.69611,0.565232


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.631908,0.658517,0.523473
2,0.643000,0.662,0.630807,0.510099
3,0.643000,0.685236,0.643032,0.542252
4,0.460200,0.786137,0.629992,0.523836
5,0.326100,0.84118,0.665037,0.542331


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.650914,0.674634,0.469209
2,0.656600,0.591478,0.667695,0.567481
3,0.656600,0.683754,0.685428,0.562171
4,0.496000,0.717567,0.674634,0.567245
5,0.338000,0.773104,0.681573,0.562296


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.607415,0.617696,0.497304
2,0.637000,0.576842,0.679466,0.55819
3,0.637000,0.605349,0.693656,0.576807
4,0.449500,0.656939,0.69783,0.5755
5,0.306700,0.673631,0.699499,0.58182


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.657434,0.649371,0.504376
2,0.662300,0.631021,0.647799,0.536011
3,0.662300,0.678323,0.628145,0.514639
4,0.521100,0.713428,0.654088,0.539118
5,0.370700,0.737338,0.654874,0.532372


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.616408,0.657212,0.528354
2,0.632000,0.682229,0.675851,0.493448
3,0.632000,0.717564,0.679092,0.524122
4,0.440800,0.759655,0.672609,0.544571
5,0.288800,0.815853,0.67423,0.53959


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.63585,0.682967,0.503953
2,0.640800,0.625781,0.658517,0.534538
3,0.640800,0.692839,0.668297,0.53148
4,0.465600,0.718743,0.680522,0.54344
5,0.340000,0.759207,0.678077,0.54236


saving the results dataframe to CSV in client_talk_type_output/bert_classifier_focal.csv


In [15]:
# Display the returned results table (the output shown has one row per seed).
bert_classifier

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_accuracy,valid_f1,valid_f1_scores,...,valid_precision_scores,valid_recall,valid_recall_scores,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.644061,0.516876,"[0.7660937955141275, 0.4453816954871362, 0.339...",0.513228,"[0.7717136150234741, 0.455565142364107, 0.3124...",0.522369,"[0.7605552342394447, 0.43564356435643564, 0.37...",0.666827,0.548846,"[0.7863608183508989, 0.4844290657439447, 0.375...",...,"[0.774169921875, 0.5147058823529411, 0.3652849...",0.547763,"[0.798941798941799, 0.45751633986928103, 0.386...",1e-06,1,focal,2,True,5,8
0,0.64272,0.515995,"[0.764490095377843, 0.4560394412489729, 0.3274...",0.511459,"[0.7759904676794758, 0.4541734860883797, 0.304...",0.521931,"[0.7533256217466744, 0.45792079207920794, 0.35...",0.667148,0.549964,"[0.78615326727706, 0.5013315579227696, 0.36240...",...,"[0.7827172827172827, 0.510854816824966, 0.3573...",0.549801,"[0.7896195515243134, 0.492156862745098, 0.3676...",1e-06,12,focal,2,True,5,8
0,0.650192,0.523013,"[0.7708092485549133, 0.4469863616366036, 0.351...",0.52112,"[0.770363951473137, 0.47879359095193214, 0.314...",0.529526,"[0.771255060728745, 0.41914191419141916, 0.398...",0.674534,0.555698,"[0.7914004914004915, 0.4916399857701885, 0.384...",...,"[0.7722368736514026, 0.5394223263075723, 0.372...",0.553202,"[0.8115394305870496, 0.4516339869281046, 0.396...",1e-06,123,focal,2,True,5,8


In [16]:
# Average the `f1` column over the rows of the results table.
bert_classifier["f1"].mean()

0.5186279622305993

In [17]:
# Average the `precision` column over the rows of the results table.
bert_classifier["precision"].mean()

0.5152689378046481

In [18]:
# Average the `recall` column over the rows of the results table.
bert_classifier["recall"].mean()

0.5246087278864766

In [19]:
# Stack each row's `f1_scores` list into one array and average over rows,
# leaving one value per list position (presumably one per class — confirm).
np.stack(bert_classifier["f1_scores"]).mean(axis=0)

array([0.76713105, 0.44946917, 0.33928367])

In [20]:
# Stack each row's `precision_scores` list into one array and average over
# rows, leaving one value per list position (presumably one per class — confirm).
np.stack(bert_classifier["precision_scores"]).mean(axis=0)

array([0.77268934, 0.46284407, 0.3102734 ])

In [21]:
# Stack each row's `recall_scores` list into one array and average over
# rows, leaving one value per list position (presumably one per class — confirm).
np.stack(bert_classifier["recall_scores"]).mean(axis=0)

array([0.76171197, 0.43756876, 0.37454545])

## Using Cross-Entropy loss

In [22]:
# Settings for the second experiment: cross-entropy loss.
gamma = None  # no gamma value is supplied for this loss
loss = "cross_entropy"

In [23]:
# Repeat the fine-tuning with the cross-entropy settings and the same shared
# `kwargs`, pointing `results_output` at a separate CSV.
# NOTE(review): relies on the notebook globals `loss` and `gamma` holding the
# cross-entropy values at execution time; running cells out of order would
# mislabel the saved results.
bert_classifier_ce = fine_tune_transformer_average_seed(
    loss=loss,
    gamma=gamma,
    results_output=f"{output_dir}/bert_classifier_ce.csv",
    **kwargs,
)

  0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.785765,0.672321,0.432937
2,0.775400,0.776554,0.683115,0.534713
3,0.775400,1.093161,0.667695,0.505132
4,0.489000,1.48509,0.673863,0.551776
5,0.232900,1.669982,0.670008,0.553317


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.789381,0.69616,0.401369
2,0.800200,0.744275,0.682805,0.513599
3,0.800200,0.879966,0.665275,0.530236
4,0.481900,1.401957,0.65192,0.505673
5,0.233900,1.584041,0.656928,0.526282


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.802496,0.638365,0.355265
2,0.800700,0.797021,0.662736,0.509817
3,0.800700,1.102244,0.647013,0.499553
4,0.546500,1.726583,0.630503,0.53042
5,0.247800,1.854083,0.636792,0.538918


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.770768,0.646677,0.305516
2,0.756600,0.782754,0.707455,0.525693
3,0.756600,0.947078,0.691248,0.535889
4,0.493100,1.492437,0.684765,0.539541
5,0.230200,1.606085,0.683144,0.546389


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.891895,0.648737,0.262317
2,0.853000,0.839432,0.653627,0.279867
3,0.853000,0.766023,0.686227,0.524709
4,0.748600,0.876836,0.687857,0.520109
5,0.451500,1.161471,0.669112,0.523948


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.783213,0.674634,0.512431
2,0.775100,0.792804,0.667695,0.529006
3,0.775100,1.123906,0.674634,0.485411
4,0.493700,1.429447,0.672321,0.548251
5,0.229700,1.621011,0.666153,0.548185


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.762562,0.695326,0.510326
2,0.757200,0.851971,0.631052,0.506403
3,0.757200,1.025627,0.66611,0.544655
4,0.422600,1.449614,0.661937,0.543015
5,0.213500,1.59424,0.665275,0.543448


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.851551,0.639937,0.393908
2,0.869800,0.841708,0.636006,0.418776
3,0.869800,0.805911,0.65173,0.421439
4,0.787300,0.870235,0.650157,0.548777
5,0.542500,1.092163,0.650157,0.537419


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.796239,0.632901,0.262506
2,0.778700,0.770839,0.675041,0.52665
3,0.778700,0.879101,0.705835,0.570613
4,0.513900,1.319404,0.696921,0.571077
5,0.249900,1.472384,0.693679,0.562052


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.824308,0.648737,0.262317
2,0.850200,0.845203,0.608802,0.414645
3,0.850200,0.884911,0.648737,0.262317
4,0.855500,0.913317,0.654442,0.360031
5,0.813100,0.802362,0.629177,0.406364


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.888055,0.625289,0.256483
2,0.820200,0.7831,0.664611,0.527706
3,0.820200,0.842394,0.666153,0.528574
4,0.625500,1.385373,0.674634,0.534995
5,0.302800,1.465366,0.665382,0.541456


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.842361,0.597663,0.396045
2,0.787900,0.790428,0.66611,0.501457
3,0.787900,0.844369,0.698664,0.546563
4,0.530700,1.10835,0.707012,0.574858
5,0.272200,1.308376,0.685309,0.570146


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.778711,0.659591,0.421995
2,0.780200,0.786736,0.662736,0.502525
3,0.780200,1.259872,0.654874,0.540866
4,0.496100,1.570108,0.660377,0.547654
5,0.227200,1.75441,0.650157,0.548857


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.784893,0.656402,0.502944
2,0.743100,0.878205,0.688817,0.487492
3,0.743100,1.099589,0.691248,0.547808
4,0.435300,1.447949,0.692869,0.567048
5,0.215400,1.546254,0.694489,0.572177


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.768551,0.683782,0.377543
2,0.741100,0.767284,0.682152,0.545899
3,0.741100,1.167989,0.656887,0.547341
4,0.437500,1.648849,0.661777,0.541317
5,0.213800,1.690377,0.678892,0.536598


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.795504,0.648419,0.330705
2,0.794500,0.758615,0.677718,0.471922
3,0.794500,0.778292,0.673863,0.489276
4,0.642900,0.820034,0.688512,0.533556
5,0.495400,0.811645,0.677718,0.538841


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.766466,0.690317,0.354845
2,0.790900,0.731049,0.69616,0.433927
3,0.790900,0.728944,0.69616,0.469222
4,0.614800,0.769742,0.696995,0.52118
5,0.483400,0.780928,0.695326,0.505102


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.814374,0.642296,0.368148
2,0.799300,0.798408,0.652516,0.416997
3,0.799300,0.826625,0.652516,0.448679
4,0.673800,0.863407,0.649371,0.492335
5,0.511000,0.886232,0.653302,0.486532


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.768361,0.679092,0.408775
2,0.762200,0.757097,0.676661,0.41753
3,0.762200,0.79293,0.679092,0.442678
4,0.630400,0.803323,0.682334,0.501506
5,0.485900,0.83303,0.685575,0.498269


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.769003,0.682967,0.429554
2,0.768400,0.751955,0.682152,0.416198
3,0.768400,0.773745,0.687042,0.490934
4,0.620900,0.804679,0.687857,0.5225
5,0.493300,0.81629,0.680522,0.52927


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.8081,0.651503,0.348467
2,0.793700,0.769169,0.668466,0.446413
3,0.793700,0.783387,0.680031,0.489383
4,0.660300,0.801327,0.683115,0.539667
5,0.508600,0.811124,0.675405,0.540383


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.736608,0.699499,0.398352
2,0.787800,0.790235,0.630217,0.454145
3,0.787800,0.763769,0.673623,0.552511
4,0.598000,0.773034,0.694491,0.557888
5,0.456100,0.789072,0.692821,0.555055


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.822914,0.640723,0.394127
2,0.803600,0.802939,0.650943,0.427155
3,0.803600,0.820721,0.650943,0.483208
4,0.668900,0.852612,0.647799,0.519328
5,0.528800,0.868722,0.65173,0.515893


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.784383,0.658023,0.385748
2,0.760800,0.756946,0.67342,0.499995
3,0.760800,0.776097,0.689627,0.500182
4,0.591200,0.813526,0.692869,0.545182
5,0.428100,0.842254,0.692869,0.539422


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.765541,0.696007,0.421237
2,0.771600,0.789904,0.653627,0.447281
3,0.771600,0.774036,0.678077,0.530452
4,0.605700,0.820459,0.678892,0.523419
5,0.472800,0.837541,0.682152,0.530295


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.813689,0.632228,0.280857
2,0.805000,0.751396,0.67926,0.431569
3,0.805000,0.766461,0.684657,0.512219
4,0.654500,0.811363,0.687741,0.543627
5,0.478000,0.82643,0.68697,0.551325


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.752819,0.691152,0.430856
2,0.789800,0.730065,0.689482,0.465001
3,0.789800,0.729519,0.698664,0.501136
4,0.621500,0.764087,0.698664,0.549835
5,0.470500,0.779557,0.702003,0.5585


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.814382,0.65566,0.435788
2,0.801000,0.815573,0.649371,0.473389
3,0.801000,0.83917,0.644654,0.498834
4,0.665200,0.84739,0.65566,0.504325
5,0.516500,0.868303,0.656447,0.514895


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.763591,0.678282,0.433022
2,0.770600,0.770582,0.677472,0.443032
3,0.770600,0.804936,0.681524,0.49431
4,0.596300,0.822873,0.679903,0.512736
5,0.440100,0.844474,0.684765,0.51946


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.773512,0.686227,0.395474
2,0.766300,0.751092,0.679707,0.432649
3,0.766300,0.764598,0.686227,0.495828
4,0.595600,0.784316,0.692747,0.524185
5,0.481500,0.810852,0.687857,0.518335


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.868328,0.625289,0.297352
2,0.903300,0.826822,0.649961,0.385685
3,0.903300,0.816663,0.656901,0.38579
4,0.817700,0.811488,0.661527,0.395454
5,0.797700,0.811335,0.660756,0.390375


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.837551,0.664441,0.301395
2,0.903200,0.794195,0.681135,0.342971
3,0.903200,0.777538,0.684474,0.392401
4,0.815700,0.771672,0.688648,0.399896
5,0.804000,0.770015,0.686144,0.401633


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.875402,0.618711,0.293021
2,0.907600,0.842556,0.632075,0.380146
3,0.907600,0.836068,0.636792,0.37751
4,0.836200,0.833427,0.640723,0.381025
5,0.803800,0.832731,0.636792,0.384695


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.857475,0.635332,0.267403
2,0.878800,0.818924,0.645867,0.314431
3,0.878800,0.80558,0.660454,0.350625
4,0.801300,0.800674,0.662075,0.365347
5,0.778400,0.799995,0.659643,0.362012


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.846682,0.652812,0.303235
2,0.880500,0.808379,0.666667,0.358032
3,0.880500,0.796812,0.674817,0.394381
4,0.782800,0.794275,0.681337,0.401094
5,0.764500,0.793481,0.682967,0.401584


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.853532,0.62606,0.258747
2,0.894200,0.830634,0.62606,0.258747
3,0.894200,0.82339,0.62606,0.258747
4,0.817300,0.820251,0.629144,0.27129
5,0.809300,0.819491,0.630686,0.275395


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.827982,0.666945,0.269198
2,0.901600,0.794285,0.666945,0.269198
3,0.901600,0.784008,0.66778,0.27454
4,0.819100,0.780638,0.666945,0.269198
5,0.807000,0.779318,0.666945,0.269198


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.86863,0.616352,0.254215
2,0.901200,0.84649,0.616352,0.254215
3,0.901200,0.84115,0.615566,0.254015
4,0.830200,0.837964,0.618711,0.265844
5,0.814500,0.837833,0.621855,0.275454


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.84613,0.632901,0.262919
2,0.878600,0.82039,0.632901,0.260586
3,0.878600,0.812095,0.632901,0.260586
4,0.797000,0.808753,0.632901,0.260586
5,0.776800,0.807973,0.632901,0.260586


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.839523,0.648737,0.262317
2,0.881200,0.811414,0.648737,0.262317
3,0.881200,0.801326,0.648737,0.262317
4,0.790600,0.796473,0.649552,0.264871
5,0.779700,0.795498,0.649552,0.264992


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.865953,0.625289,0.256483
2,0.947500,0.82322,0.629144,0.26931
3,0.947500,0.810863,0.639167,0.310273
4,0.825200,0.805972,0.642251,0.319823
5,0.805300,0.805404,0.640709,0.316467


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.84104,0.665275,0.273912
2,0.947100,0.792642,0.66611,0.276527
3,0.947100,0.77551,0.682805,0.330074
4,0.825700,0.768282,0.686978,0.341878
5,0.795800,0.766259,0.687813,0.347002


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.881719,0.615566,0.259621
2,0.958000,0.847582,0.622642,0.298774
3,0.958000,0.836828,0.641509,0.347883
4,0.836500,0.835165,0.641509,0.353802
5,0.805000,0.83363,0.643082,0.357859


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.870973,0.632901,0.263016
2,0.949500,0.822626,0.632901,0.262919
3,0.949500,0.806914,0.635332,0.273671
4,0.799000,0.798632,0.641005,0.294259
5,0.764900,0.798789,0.641815,0.294554


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.845606,0.647922,0.262117
2,0.931800,0.809834,0.647922,0.262117
3,0.931800,0.79602,0.660962,0.310244
4,0.791300,0.791311,0.662592,0.321257
5,0.779700,0.789521,0.669112,0.34403


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.795504,0.648419,0.330705
2,0.794500,0.758615,0.677718,0.471922
3,0.794500,0.778292,0.673863,0.489276
4,0.642900,0.820034,0.688512,0.533556
5,0.495400,0.811645,0.677718,0.538841


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.766466,0.690317,0.354845
2,0.790900,0.731049,0.69616,0.433927
3,0.790900,0.728944,0.69616,0.469222
4,0.614800,0.769742,0.696995,0.52118
5,0.483400,0.780928,0.695326,0.505102


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.814374,0.642296,0.368148
2,0.799300,0.798408,0.652516,0.416997
3,0.799300,0.826625,0.652516,0.448679
4,0.673800,0.863407,0.649371,0.492335
5,0.511000,0.886232,0.653302,0.486532


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.768361,0.679092,0.408775
2,0.762200,0.757097,0.676661,0.41753
3,0.762200,0.79293,0.679092,0.442678
4,0.630400,0.803323,0.682334,0.501506
5,0.485900,0.83303,0.685575,0.498269


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.769003,0.682967,0.429554
2,0.768400,0.751955,0.682152,0.416198
3,0.768400,0.773745,0.687042,0.490934
4,0.620900,0.804679,0.687857,0.5225
5,0.493300,0.81629,0.680522,0.52927


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.8081,0.651503,0.348467
2,0.793700,0.769169,0.668466,0.446413
3,0.793700,0.783387,0.680031,0.489383
4,0.660300,0.801327,0.683115,0.539667
5,0.508600,0.811124,0.675405,0.540383


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.736608,0.699499,0.398352
2,0.787800,0.790235,0.630217,0.454145
3,0.787800,0.763769,0.673623,0.552511
4,0.598000,0.773034,0.694491,0.557888
5,0.456100,0.789072,0.692821,0.555055


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.822914,0.640723,0.394127
2,0.803600,0.802939,0.650943,0.427155
3,0.803600,0.820721,0.650943,0.483208
4,0.668900,0.852612,0.647799,0.519328
5,0.528800,0.868722,0.65173,0.515893


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.784383,0.658023,0.385748
2,0.760800,0.756946,0.67342,0.499995
3,0.760800,0.776097,0.689627,0.500182
4,0.591200,0.813526,0.692869,0.545182
5,0.428100,0.842254,0.692869,0.539422


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.765541,0.696007,0.421237
2,0.771600,0.789904,0.653627,0.447281
3,0.771600,0.774036,0.678077,0.530452
4,0.605700,0.820459,0.678892,0.523419
5,0.472800,0.837541,0.682152,0.530295


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.813689,0.632228,0.280857
2,0.805000,0.751396,0.67926,0.431569
3,0.805000,0.766461,0.684657,0.512219
4,0.654500,0.811363,0.687741,0.543627
5,0.478000,0.82643,0.68697,0.551325


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.752819,0.691152,0.430856
2,0.789800,0.730065,0.689482,0.465001
3,0.789800,0.729519,0.698664,0.501136
4,0.621500,0.764087,0.698664,0.549835
5,0.470500,0.779557,0.702003,0.5585


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.814382,0.65566,0.435788
2,0.801000,0.815573,0.649371,0.473389
3,0.801000,0.83917,0.644654,0.498834
4,0.665200,0.84739,0.65566,0.504325
5,0.516500,0.868303,0.656447,0.514895


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.763591,0.678282,0.433022
2,0.770600,0.770582,0.677472,0.443032
3,0.770600,0.804936,0.681524,0.49431
4,0.596300,0.822873,0.679903,0.512736
5,0.440100,0.844474,0.684765,0.51946


Map:   0%|          | 0/4817 [00:00<?, ? examples/s]

Map:   0%|          | 0/4817 [00:00<?, ? examples/s]



Epoch,Training Loss,Validation Loss,Validation accuracy,Validation f1
1,No log,0.773512,0.686227,0.395474
2,0.766300,0.751092,0.679707,0.432649
3,0.766300,0.764598,0.686227,0.495828
4,0.595600,0.784316,0.692747,0.524185
5,0.481500,0.810852,0.687857,0.518335


saving the results dataframe to CSV in client_talk_type_output/bert_classifier_ce.csv


In [24]:
# Display the results dataframe of the cross-entropy BERT baseline
# (presumably the same frame that was just saved to
# client_talk_type_output/bert_classifier_ce.csv — confirm).
# NOTE(review): in the saved run the `learning_rate` column reads 1e-06, yet the
# final retraining logs printed above are identical to an earlier grid-search
# block with clearly higher validation f1 than the last block — confirm that the
# column reports the learning rate that was actually selected.
bert_classifier_ce

Unnamed: 0,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_accuracy,valid_f1,valid_f1_scores,...,valid_precision_scores,valid_recall,valid_recall_scores,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.660153,0.492106,"[0.783697315342885, 0.4376330353341848, 0.2549...",0.511266,"[0.7550254623425355, 0.4520668425681618, 0.326...",0.482605,"[0.8146327356853673, 0.4240924092409241, 0.209...",0.677103,0.518548,"[0.7980167630740174, 0.4768632991875662, 0.280...",...,"[0.7507774322523323, 0.5188316679477325, 0.381...",0.505,"[0.8515998992189469, 0.4411764705882353, 0.222...",1e-06,1,cross_entropy,,True,5,8
0,0.656513,0.503736,"[0.7794784580498867, 0.45109612141652616, 0.28...",0.51096,"[0.7643135075041689, 0.46120689655172414, 0.30...",0.498286,"[0.7952573742047426, 0.4414191419141914, 0.258...",0.677425,0.538457,"[0.7971452764001452, 0.49140893470790376, 0.32...",...,"[0.7666356444858073, 0.5181159420289855, 0.38]",0.528066,"[0.8301839254220207, 0.4673202614379085, 0.286...",1e-06,12,cross_entropy,,True,5,8
0,0.668582,0.507366,"[0.7892995035852179, 0.4374152733845459, 0.295...",0.525563,"[0.7543489720611491, 0.4835164835164835, 0.338...",0.496268,"[0.8276460381723539, 0.39933993399339934, 0.26...",0.684329,0.533591,"[0.8031939877876938, 0.4745762711864407, 0.323...",...,"[0.7521442709478777, 0.543918918918919, 0.3983...",0.518066,"[0.8616780045351474, 0.42091503267973857, 0.27...",1e-06,123,cross_entropy,,True,5,8


In [25]:
# Mean of the `f1` column over all rows of the results dataframe
# (presumably the test-set score, as opposed to `valid_f1` — confirm).
bert_classifier_ce.loc[:, "f1"].mean()

0.5010695163433949

In [26]:
# Mean of the `precision` column over all rows of the results dataframe
# (presumably the test-set score, as opposed to `valid_precision` — confirm).
bert_classifier_ce.loc[:, "precision"].mean()

0.5159295051966489

In [27]:
# Mean of the `recall` column over all rows of the results dataframe
# (presumably the test-set score, as opposed to `valid_recall` — confirm).
bert_classifier_ce.loc[:, "recall"].mean()

0.4923865047002098

In [28]:
# Stack the per-row `f1_scores` vectors into a 2-D array and average down the
# rows, giving one mean value per position (presumably one per class, ordered by
# class id as in label_to_id_client — confirm).
np.mean(np.stack(bert_classifier_ce["f1_scores"]), axis=0)

array([0.78415843, 0.44204814, 0.27700198])

In [29]:
# Stack the per-row `precision_scores` vectors into a 2-D array and average down
# the rows, giving one mean value per position (presumably one per class, ordered
# by class id as in label_to_id_client — confirm).
np.mean(np.stack(bert_classifier_ce["precision_scores"]), axis=0)

array([0.75789598, 0.46559674, 0.32429579])

In [30]:
# Stack the per-row `recall_scores` vectors into a 2-D array and average down
# the rows, giving one mean value per position (presumably one per class, ordered
# by class id as in label_to_id_client — confirm).
np.mean(np.stack(bert_classifier_ce["recall_scores"]), axis=0)

array([0.81251205, 0.42161716, 0.2430303 ])