In [1]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    histories_baseline_hyperparameter_search
)

In [3]:
output_dir = "client_talk_type_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## AnnoMI

In [4]:
%run ../load_anno_mi.py

In [5]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-12 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-12 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-12 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-12 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-12 00:00:34


In [6]:
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)
    
sbert_embeddings.shape

(13551, 384)

# Baseline: Averaging history and use FFN

Here, we average the full history of a path and concatenate it the current embedding (the total number of features that are passed into the FFN is `2 * sbert_embeddings.shape[0]`).

Here, we will run the hyperparameter search to implement the FFN with the same parameters as the standard FFN baseline on the sentence embeddings. Going to try out some variations (1 hidden layer, 2 hidden layers and 3 hidden layers - all of size 100).

In [7]:
num_epochs = 100
hidden_dim_sizes = [[128,128],[256,256],[512,512]]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [0, 1, 12, 123, 1234]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5

In [8]:
ffn_mean_history, best_ffn_mean_history, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=False,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}.csv",
    verbose=False
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2_best_model.csv


In [9]:
ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,0.572498,0.721190,0.651794,"[0.8105515587529976, 0.5833333333333333, 0.561...",0.651200,"[0.8076463560334528, 0.5962145110410094, 0.549...",0.652748,"[0.8134777376654633, 0.5709969788519638, 0.573...",0.538772,0.728625,...,0.5,0.0010,0,focal,2,False,,64,0.00,768
0,0.537211,0.732342,0.658555,"[0.8188701223063483, 0.5829307568438004, 0.573...",0.671742,"[0.7934537246049661, 0.6241379310344828, 0.597...",0.648236,"[0.8459687123947052, 0.5468277945619335, 0.551...",0.506889,0.743494,...,0.5,0.0010,1,focal,2,False,,64,0.00,768
0,0.518405,0.732342,0.662321,"[0.8180201541197393, 0.5990491283676703, 0.569...",0.665640,"[0.8060747663551402, 0.63, 0.5608465608465608]",0.660186,"[0.8303249097472925, 0.5709969788519638, 0.579...",0.486352,0.734201,...,0.5,0.0010,12,focal,2,False,,64,0.00,768
0,0.582249,0.707807,0.655892,"[0.7878017789072428, 0.6054054054054053, 0.574...",0.647239,"[0.8344549125168237, 0.5476772616136919, 0.559...",0.670997,"[0.7460890493381468, 0.676737160120846, 0.5901...",0.551687,0.718401,...,0.5,0.0010,123,focal,2,False,,64,0.00,768
0,0.533087,0.729368,0.657557,"[0.81673541543901, 0.5874587458745875, 0.56847...",0.662240,"[0.8002309468822171, 0.6472727272727272, 0.539...",0.657597,"[0.8339350180505415, 0.5377643504531722, 0.601...",0.513729,0.733271,...,0.5,0.0010,1234,focal,2,False,,64,0.00,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.625520,0.725651,0.652992,"[0.8094674556213017, 0.6045801526717558, 0.544...",0.662544,"[0.7962747380675204, 0.6111111111111112, 0.580...",0.644984,"[0.8231046931407943, 0.5981873111782477, 0.513...",0.567449,0.736059,...,0.1,0.0005,0,focal,2,False,,64,0.26,768
0,0.548754,0.729368,0.663040,"[0.81437125748503, 0.596969696969697, 0.577777...",0.665614,"[0.8104886769964244, 0.5987841945288754, 0.587...",0.660588,"[0.8182912154031288, 0.595166163141994, 0.5683...",0.525988,0.732342,...,0.1,0.0005,1,focal,2,False,,64,0.26,768
0,0.560586,0.732342,0.652622,"[0.820393974507532, 0.5888, 0.5486725663716814]",0.671022,"[0.7910614525139665, 0.6258503401360545, 0.596...",0.638691,"[0.851985559566787, 0.5558912386706949, 0.5081...",0.496237,0.747212,...,0.1,0.0005,12,focal,2,False,,64,0.26,768
0,0.539687,0.686989,0.639620,"[0.7679158448389217, 0.592039800995025, 0.5589...",0.636663,"[0.8463768115942029, 0.5031712473572939, 0.560...",0.659726,"[0.7027677496991577, 0.7190332326283988, 0.557...",0.507002,0.711896,...,0.1,0.0005,123,focal,2,False,,64,0.26,768


In [10]:
ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()

  ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
"(128, 128)",0.1,0.0001,0.501766,0.708253,0.642372,0.637302,0.648999,0.495057,0.711338,0.627209,0.617011,0.64175,274.0,2.0,0.0,64.0,0.7,768.0
"(128, 128)",0.1,0.0005,0.50786,0.721784,0.656571,0.653669,0.66092,0.503272,0.726208,0.642483,0.634041,0.655011,274.0,2.0,0.0,64.0,0.8,768.0
"(128, 128)",0.1,0.001,0.610362,0.711375,0.65111,0.646191,0.661664,0.579054,0.722862,0.645543,0.633922,0.667261,274.0,2.0,0.0,64.0,0.6,768.0
"(128, 128)",0.2,0.0001,0.50528,0.710186,0.646385,0.640746,0.65325,0.491654,0.711896,0.627392,0.616037,0.644471,274.0,2.0,0.0,64.0,0.4,768.0
"(128, 128)",0.2,0.0005,0.514852,0.714349,0.649244,0.644351,0.65561,0.505491,0.724721,0.641884,0.631611,0.6563,274.0,2.0,0.0,64.0,0.5,768.0
"(128, 128)",0.2,0.001,0.616134,0.719851,0.647776,0.650246,0.64706,0.580686,0.733271,0.647068,0.641348,0.656821,274.0,2.0,0.0,64.0,0.3,768.0
"(128, 128)",0.5,0.0001,0.53514,0.708253,0.633688,0.636913,0.631331,0.508891,0.713383,0.619475,0.61481,0.626662,274.0,2.0,0.0,64.0,0.1,768.0
"(128, 128)",0.5,0.0005,0.562434,0.720744,0.65383,0.653108,0.655905,0.533762,0.733271,0.649443,0.64265,0.660373,274.0,2.0,0.0,64.0,0.2,768.0
"(128, 128)",0.5,0.001,0.54869,0.72461,0.657224,0.659612,0.657953,0.519486,0.731599,0.647689,0.643573,0.656675,274.0,2.0,0.0,64.0,0.0,768.0
"(256, 256)",0.1,0.0001,0.505269,0.712416,0.648094,0.644604,0.652135,0.498619,0.724164,0.638809,0.630917,0.6493,274.0,2.0,0.0,64.0,0.16,768.0


In [11]:
best_ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,0.547946,0.714498,0.660005,"[0.7942464040025016, 0.6062246278755076, 0.579...",0.659798,"[0.8268229166666666, 0.5490196078431373, 0.603...",0.666085,"[0.7641395908543923, 0.676737160120846, 0.5573...",0.520177,0.722119,...,"(256, 256)",0.5,0.0005,0,focal,2,False,,64,768
0,0.659909,0.724907,0.65391,"[0.8079951544518473, 0.6268221574344024, 0.526...",0.655369,"[0.8134146341463414, 0.6056338028169014, 0.547...",0.653464,"[0.802647412755716, 0.649546827794562, 0.50819...",0.608059,0.752788,...,"(256, 256)",0.5,0.0005,1,focal,2,False,,64,768
0,0.547631,0.736803,0.667662,"[0.822262118491921, 0.6101190476190476, 0.5706...",0.67423,"[0.8178571428571428, 0.6011730205278593, 0.603...",0.662345,"[0.8267148014440433, 0.6193353474320241, 0.540...",0.522817,0.748141,...,"(256, 256)",0.5,0.0005,12,focal,2,False,,64,768
0,0.546686,0.717472,0.650839,"[0.8052884615384615, 0.5812807881773399, 0.565...",0.648429,"[0.8043217286914766, 0.6366906474820144, 0.504...",0.661936,"[0.8062575210589651, 0.5347432024169184, 0.644...",0.510219,0.73513,...,"(256, 256)",0.5,0.0005,123,focal,2,False,,64,768
0,0.542385,0.730112,0.65883,"[0.8172043010752688, 0.6049382716049382, 0.554...",0.660345,"[0.8113879003558719, 0.6182965299684543, 0.551...",0.657542,"[0.8231046931407943, 0.5921450151057401, 0.557...",0.508341,0.743494,...,"(256, 256)",0.5,0.0005,1234,focal,2,False,,64,768


In [12]:
best_ffn_mean_history["f1"].mean()

0.6582492815821922

In [13]:
best_ffn_mean_history["precision"].mean()

0.6596340295302325

In [14]:
best_ffn_mean_history["recall"].mean()

0.6602743161015272

In [15]:
np.stack(best_ffn_mean_history["f1_scores"]).mean(axis=0)

array([0.80939929, 0.60587698, 0.55947158])

In [16]:
np.stack(best_ffn_mean_history["precision_scores"]).mean(axis=0)

array([0.81476086, 0.60216272, 0.5619785 ])

In [17]:
np.stack(best_ffn_mean_history["recall_scores"]).mean(axis=0)

array([0.8045728 , 0.61450151, 0.56174863])

## KFold

We can repeat this but use K-Fold evaluation instead - by default, we have $K=5$ folds.

In [18]:
ffn_mean_history_kfold, best_ffn_mean_history_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}_kfold.csv",
    verbose=False
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2_kfold_best_model.csv


In [19]:
ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,,0.691152,0.623333,"[0.7843962848297213, 0.5783277995774223, 0.507...",0.610416,"[0.8208916537065837, 0.5823708206686931, 0.427...",0.649322,"[0.7510078254683424, 0.5743405275779376, 0.622...",,0.689977,...,0.5,0.0010,0,focal,2,True,5,64,0.00,768
0,,0.697844,0.621761,"[0.794698795180723, 0.5646085295989816, 0.5059...",0.614812,"[0.807739407298555, 0.601763907734057, 0.43493...",0.639536,"[0.7820725634337207, 0.5317745803357314, 0.604...",,0.698649,...,0.5,0.0010,1,focal,2,True,5,64,0.00,768
0,,0.699628,0.626092,"[0.7914503079338244, 0.5781729000613122, 0.508...",0.617496,"[0.8063484251968503, 0.5915934755332497, 0.454...",0.639940,"[0.7770927199430875, 0.565347721822542, 0.5773...",,0.699550,...,0.5,0.0010,12,focal,2,True,5,64,0.00,768
0,,0.700074,0.621861,"[0.7954680977936793, 0.5654819084213897, 0.504...",0.617221,"[0.8001439539347409, 0.6068728522336769, 0.444...",0.634519,"[0.7908465733934076, 0.5293764988009593, 0.583...",,0.699775,...,0.5,0.0010,123,focal,2,True,5,64,0.00,768
0,,0.699777,0.625108,"[0.7943159922928709, 0.5693568726355612, 0.511...",0.617558,"[0.8069488622461464, 0.6003989361702128, 0.445...",0.641543,"[0.7820725634337207, 0.5413669064748201, 0.601...",,0.696847,...,0.5,0.0010,1234,focal,2,True,5,64,0.00,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.684610,0.617443,"[0.7767401319885444, 0.5780652749191415, 0.497...",0.603716,"[0.8177766124803356, 0.5672244662435084, 0.426...",0.642191,"[0.7396253260611809, 0.5893285371702638, 0.597...",,0.685698,...,0.1,0.0005,0,focal,2,True,5,64,0.26,768
0,,0.697100,0.622179,"[0.7908837434840587, 0.5769696969696969, 0.498...",0.613036,"[0.8090277777777778, 0.5833333333333334, 0.446...",0.636188,"[0.7735356888783496, 0.5707434052757794, 0.564...",,0.700676,...,0.1,0.0005,1,focal,2,True,5,64,0.26,768
0,,0.693978,0.621953,"[0.7869090909090909, 0.5693839452395768, 0.509...",0.612749,"[0.804859905777337, 0.591849935316947, 0.44153...",0.640228,"[0.7697415224092957, 0.5485611510791367, 0.602...",,0.700788,...,0.1,0.0005,12,focal,2,True,5,64,0.26,768
0,,0.687435,0.621942,"[0.7777639905695495, 0.5725760183591508, 0.515...",0.608570,"[0.8157209786569495, 0.5489548954895489, 0.461...",0.642009,"[0.7431823571259188, 0.5983213429256595, 0.584...",,0.688514,...,0.1,0.0005,123,focal,2,True,5,64,0.26,768


In [20]:
ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()

  ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(128, 128)",0.1,0.0001,0.69139,0.614987,0.60672,0.627088,0.699797,0.624897,0.61595,0.638185,274.0,2.0,1.0,5.0,64.0,0.7,768.0
"(128, 128)",0.1,0.0005,0.688892,0.61809,0.606822,0.635875,0.693446,0.622956,0.610953,0.64187,274.0,2.0,1.0,5.0,64.0,0.8,768.0
"(128, 128)",0.1,0.001,0.688,0.616879,0.606436,0.637181,0.691509,0.620321,0.609165,0.642442,274.0,2.0,1.0,5.0,64.0,0.6,768.0
"(128, 128)",0.2,0.0001,0.692104,0.615656,0.607199,0.628456,0.699189,0.624622,0.615427,0.638823,274.0,2.0,1.0,5.0,64.0,0.4,768.0
"(128, 128)",0.2,0.0005,0.689487,0.618439,0.607484,0.637501,0.69482,0.624219,0.612581,0.64461,274.0,2.0,1.0,5.0,64.0,0.5,768.0
"(128, 128)",0.2,0.001,0.688506,0.617959,0.607217,0.636164,0.693423,0.6226,0.611015,0.642424,274.0,2.0,1.0,5.0,64.0,0.3,768.0
"(128, 128)",0.5,0.0001,0.694007,0.618233,0.609635,0.631957,0.700788,0.627256,0.617649,0.642425,274.0,2.0,1.0,5.0,64.0,0.1,768.0
"(128, 128)",0.5,0.0005,0.694483,0.621495,0.612236,0.638983,0.698401,0.626295,0.615963,0.645451,274.0,2.0,1.0,5.0,64.0,0.2,768.0
"(128, 128)",0.5,0.001,0.697695,0.623631,0.615501,0.640972,0.696959,0.623206,0.614146,0.642586,274.0,2.0,1.0,5.0,64.0,0.0,768.0
"(256, 256)",0.1,0.0001,0.69487,0.619909,0.611405,0.632189,0.702027,0.628708,0.619329,0.642508,274.0,2.0,1.0,5.0,64.0,0.16,768.0


In [21]:
best_ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,,0.701115,0.624976,"[0.7930868939030244, 0.5785932721712539, 0.503...",0.618246,"[0.8029161603888214, 0.5905118601747815, 0.461...",0.634738,"[0.7834953758596158, 0.5671462829736211, 0.553...",,0.705743,...,"(512, 512)",0.5,0.0001,0,focal,2,True,5,64,768
0,,0.70171,0.625143,"[0.7948133029175171, 0.5782312925170068, 0.502...",0.618249,"[0.8049610894941635, 0.5970625798212005, 0.452...",0.636585,"[0.7849181882855111, 0.5605515587529976, 0.564...",,0.705743,...,"(512, 512)",0.5,0.0001,1,focal,2,True,5,64,768
0,,0.698885,0.622428,"[0.7917973462002413, 0.5798065296251511, 0.495...",0.614703,"[0.8057942548490057, 0.5847560975609756, 0.453...",0.633216,"[0.7782783969646668, 0.5749400479616307, 0.546...",,0.706306,...,"(512, 512)",0.5,0.0001,12,focal,2,True,5,64,768
0,,0.699033,0.622961,"[0.7923484119345523, 0.5754775107825015, 0.501...",0.615537,"[0.8041514041514042, 0.5918884664131813, 0.450...",0.635042,"[0.7808868864121413, 0.5599520383693045, 0.564...",,0.705405,...,"(512, 512)",0.5,0.0001,123,focal,2,True,5,64,768
0,,0.698736,0.623802,"[0.7921795800144821, 0.5744941753525444, 0.504...",0.615464,"[0.8065863848611452, 0.5878293601003765, 0.451...",0.637153,"[0.7782783969646668, 0.5617505995203836, 0.571...",,0.705518,...,"(512, 512)",0.5,0.0001,1234,focal,2,True,5,64,768


In [22]:
best_ffn_mean_history_kfold["f1"].mean()

0.6238619397044939

In [23]:
best_ffn_mean_history_kfold["precision"].mean()

0.6164396200166706

In [24]:
best_ffn_mean_history_kfold["recall"].mean()

0.6353465181376359

In [25]:
np.stack(best_ffn_mean_history_kfold["f1_scores"]).mean(axis=0)

array([0.79284511, 0.57732056, 0.50142016])

In [26]:
np.stack(best_ffn_mean_history_kfold["precision_scores"]).mean(axis=0)

array([0.80488186, 0.59040967, 0.45402733])

In [27]:
np.stack(best_ffn_mean_history_kfold["recall_scores"]).mean(axis=0)

array([0.78117145, 0.56486811, 0.56      ])