In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced anywhere else in this notebook; the
# hyperparameter searches below use the `seeds` list instead. Confirm whether it
# is still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    histories_baseline_hyperparameter_search
)

In [3]:
# Directory that receives the hyperparameter-search CSVs written below.
output_dir = "therapist_talk_type_output"
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of a separate `os.path.isdir` test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Load the AnnoMI data. `%run` executes the script and copies its top-level names
# into the notebook namespace.
# NOTE(review): besides `anno_mi`, later cells use `y_data_therapist`,
# `output_dim_therapist`, `therapist_index` and `therapist_transcript_id`, none of
# which are defined in this notebook; presumably this script creates them. Confirm.
%run ../load_anno_mi.py

In [5]:
# Preview the first five rows of the loaded AnnoMI dataframe
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-27 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-27 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-27 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-27 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-27 00:00:34


In [6]:
# NOTE(review): pickle.load can execute arbitrary code. Only load this file if it
# comes from a trusted source (e.g. it was produced locally by this project).
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# The embeddings are passed alongside `anno_mi` to the hyperparameter search, which
# presumably expects one embedding row per dataframe row (the saved outputs show
# 13,551 of each). Fail fast if the two have drifted apart.
assert sbert_embeddings.shape[0] == len(anno_mi), (
    f"expected {len(anno_mi)} embedding rows, got {sbert_embeddings.shape[0]}"
)

sbert_embeddings.shape

(13551, 384)

# Baseline: averaging the history and using an FFN

Here, we average the full history of a path and concatenate it with the current embedding, so the total number of features passed into the FFN is `2 * sbert_embeddings.shape[1]` (twice the embedding dimension).

Here, we run the same hyperparameter search as for the standard FFN baseline on the sentence embeddings. We try networks with two hidden layers of equal width (64, 128, 256 and 512 units; see `hidden_dim_sizes` below).

In [7]:
# Hyperparameter grid shared by the train/validation/test search and the K-fold
# search below. The values match the saved result tables in this notebook, which
# contain dropout 0.2 and seeds 0 and 1234 (mean seed 274 = mean of the list below).
num_epochs = 100
hidden_dim_sizes = [[64, 64], [128, 128], [256, 256], [512, 512]]  # two hidden layers each
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [0, 1, 12, 123, 1234]  # every configuration is trained once per seed
loss = "focal"
gamma = 2  # focal-loss gamma
validation_metric = "f1"  # presumably used to select the best model; confirm
patience = 5  # presumably early-stopping patience in epochs; confirm

In [9]:
# Train/validation/test evaluation of the mean-history FFN baseline.
# NOTE(review): `y_data_therapist`, `output_dim_therapist`, `therapist_index` and
# `therapist_transcript_id` are not defined in this notebook; presumably they are
# created by `%run ../load_anno_mi.py`. Confirm.
# Only the first two return values are used. Collecting the rest in `*_unused`
# (instead of `_, __`) avoids overwriting IPython's `_`/`__` output-history variables.
ffn_mean_history, best_ffn_mean_history, *_unused = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    output_dim=output_dim_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,  # presumably selects the mean-history features rather than path signatures
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=therapist_index,
    split_ids=therapist_transcript_id,
    k_fold=False,
    patience=patience,
    validation_metric=validation_metric,
    # Derive the file name from `loss` so it stays accurate if the loss is changed.
    results_output=f"{output_dir}/ffn_mean_history_{loss}_{gamma}.csv",
    verbose=False,
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_best_model.csv


In [10]:
# Full results of the train/validation/test search: one row per
# (hidden_dim, dropout_rate, learning_rate, seed) run. Pandas truncates the display.
ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,0.703645,0.765101,0.732532,"[0.8179723502304148, 0.5615141955835962, 0.684...",0.748203,"[0.8255813953488372, 0.6544117647058824, 0.620...",0.726812,"[0.8105022831050228, 0.49171270718232046, 0.76...",0.757242,0.719216,...,0.5,0.0010,0,focal,2,False,,64,0.00,768
0,0.755611,0.776286,0.751284,"[0.8149038461538461, 0.608955223880597, 0.7033...",0.762718,"[0.8604060913705583, 0.6623376623376623, 0.626...",0.748607,"[0.773972602739726, 0.56353591160221, 0.801223...",0.822343,0.713619,...,0.5,0.0010,1,focal,2,False,,64,0.00,768
0,0.692211,0.753915,0.712624,"[0.8026315789473685, 0.5032679738562091, 0.679...",0.730674,"[0.7721518987341772, 0.616, 0.6356382978723404...",0.706207,"[0.8356164383561644, 0.425414364640884, 0.7308...",0.720180,0.721082,...,0.5,0.0010,12,focal,2,False,,64,0.00,768
0,0.678555,0.757644,0.723916,"[0.7986270022883296, 0.5313432835820896, 0.700...",0.732329,"[0.8004587155963303, 0.577922077922078, 0.6428...",0.721119,"[0.7968036529680366, 0.49171270718232046, 0.77...",0.722930,0.723881,...,0.5,0.0010,123,focal,2,False,,64,0.00,768
0,0.695901,0.766592,0.732357,"[0.8153153153153153, 0.5555555555555555, 0.689...",0.743155,"[0.8044444444444444, 0.6293706293706294, 0.647...",0.727206,"[0.8264840182648402, 0.4972375690607735, 0.737...",0.741430,0.713619,...,0.5,0.0010,1234,focal,2,False,,64,0.00,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.710750,0.750932,0.716124,"[0.7955555555555555, 0.5286624203821656, 0.680...",0.734327,"[0.7748917748917749, 0.6240601503759399, 0.619...",0.709715,"[0.817351598173516, 0.4585635359116022, 0.7553...",0.748978,0.724813,...,0.1,0.0005,0,focal,2,False,,64,0.26,768
0,0.695115,0.764355,0.735397,"[0.8031145717463848, 0.5779036827195468, 0.696...",0.738695,"[0.7830802603036876, 0.5930232558139535, 0.684...",0.733165,"[0.8242009132420092, 0.56353591160221, 0.70948...",0.750000,0.714552,...,0.1,0.0005,1,focal,2,False,,64,0.26,768
0,0.722081,0.750932,0.721820,"[0.7897934386391252, 0.562691131498471, 0.6837...",0.733938,"[0.8441558441558441, 0.6301369863013698, 0.612...",0.719290,"[0.7420091324200914, 0.5082872928176796, 0.773...",0.749692,0.716418,...,0.1,0.0005,12,focal,2,False,,64,0.26,768
0,0.713407,0.757644,0.720996,"[0.7977011494252874, 0.5206349206349207, 0.698...",0.736296,"[0.8032407407407407, 0.6119402985074627, 0.628...",0.716663,"[0.7922374429223744, 0.4530386740331492, 0.785...",0.748313,0.720149,...,0.1,0.0005,123,focal,2,False,,64,0.26,768


In [11]:
# Average every numeric column over seeds for each
# (hidden_dim, dropout_rate, learning_rate) combination.
# numeric_only=True skips non-numeric columns such as the per-class score lists and
# `loss_function`. Without it, pandas 1.x emits a FutureWarning and pandas >= 2.0
# raises a TypeError.
ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
"(128, 128)",0.1,0.0001,0.722813,0.768829,0.737768,0.756442,0.731786,0.759252,0.7125,0.696398,0.696747,0.704092,274.0,2.0,0.0,64.0,0.7,768.0
"(128, 128)",0.1,0.0005,0.700667,0.753766,0.723807,0.740939,0.717841,0.738476,0.721269,0.703307,0.708046,0.703521,274.0,2.0,0.0,64.0,0.8,768.0
"(128, 128)",0.1,0.001,0.707521,0.755854,0.722674,0.736857,0.719279,0.748745,0.714366,0.696691,0.69795,0.70113,274.0,2.0,0.0,64.0,0.6,768.0
"(128, 128)",0.2,0.0001,0.730573,0.767934,0.737488,0.752905,0.732095,0.771427,0.7125,0.695987,0.697064,0.70395,274.0,2.0,0.0,64.0,0.4,768.0
"(128, 128)",0.2,0.0005,0.698215,0.754959,0.723617,0.737051,0.718941,0.739578,0.721828,0.703219,0.707172,0.705123,274.0,2.0,0.0,64.0,0.5,768.0
"(128, 128)",0.2,0.001,0.699787,0.755854,0.72476,0.734402,0.721868,0.744842,0.719216,0.701665,0.704125,0.704656,274.0,2.0,0.0,64.0,0.3,768.0
"(128, 128)",0.5,0.0001,0.752733,0.764057,0.728949,0.750791,0.722862,0.794679,0.708022,0.691034,0.69247,0.697786,274.0,2.0,0.0,64.0,0.1,768.0
"(128, 128)",0.5,0.0005,0.688535,0.764653,0.733256,0.744679,0.729222,0.735677,0.71903,0.702379,0.703962,0.70695,274.0,2.0,0.0,64.0,0.2,768.0
"(128, 128)",0.5,0.001,0.705185,0.763908,0.730543,0.743416,0.72599,0.752825,0.718284,0.700224,0.702699,0.704937,274.0,2.0,0.0,64.0,0.0,768.0
"(256, 256)",0.1,0.0001,0.705543,0.757196,0.725568,0.738256,0.72057,0.746883,0.71959,0.701516,0.703504,0.705297,274.0,2.0,0.0,64.0,0.16,768.0


In [12]:
# Runs of the best configuration found by the search (presumably chosen by
# `validation_metric`); in the saved output there is one row per seed.
best_ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,0.70106,0.759135,0.727361,"[0.8017334777898157, 0.5660377358490566, 0.679...",0.744021,"[0.7628865979381443, 0.656934306569343, 0.6438...",0.719591,"[0.8447488584474886, 0.4972375690607735, 0.718...",0.739506,0.723881,...,"(512, 512)",0.2,0.0001,0,focal,2,False,,64,768
0,0.697738,0.761372,0.73174,"[0.8133640552995393, 0.56797583081571, 0.68125...",0.743574,"[0.8209302325581396, 0.6266666666666667, 0.616...",0.727381,"[0.8059360730593608, 0.5193370165745856, 0.761...",0.733883,0.720149,...,"(512, 512)",0.2,0.0001,1,focal,2,False,,64,768
0,0.708779,0.746458,0.714027,"[0.7862857142857144, 0.5353846153846153, 0.679...",0.72589,"[0.7871853546910755, 0.6041666666666666, 0.619...",0.709649,"[0.7853881278538812, 0.48066298342541436, 0.75...",0.739779,0.722948,...,"(512, 512)",0.2,0.0001,12,focal,2,False,,64,768
0,0.698313,0.759135,0.727455,"[0.7986577181208054, 0.555223880597015, 0.6916...",0.73546,"[0.7828947368421053, 0.6038961038961039, 0.653...",0.723301,"[0.815068493150685, 0.5138121546961326, 0.7339...",0.758043,0.724813,...,"(512, 512)",0.2,0.0001,123,focal,2,False,,64,768
0,0.72307,0.764355,0.737554,"[0.8088578088578088, 0.5868263473053892, 0.692...",0.74846,"[0.8261904761904761, 0.6405228758169934, 0.624...",0.733937,"[0.7922374429223744, 0.5414364640883977, 0.776...",0.771632,0.717351,...,"(512, 512)",0.2,0.0001,1234,focal,2,False,,64,768


In [13]:
# Overall F1 of the best configuration, averaged across seeds
best_ffn_mean_history["f1"].agg("mean")

0.7276275253985091

In [14]:
# Overall precision of the best configuration, averaged across seeds
best_ffn_mean_history["precision"].agg("mean")

0.7394812350515159

In [15]:
# Overall recall of the best configuration, averaged across seeds
best_ffn_mean_history["recall"].agg("mean")

0.7227720072768336

In [16]:
# Per-class F1 of the best configuration, averaged across seeds (one entry per class)
np.vstack(best_ffn_mean_history["f1_scores"].tolist()).mean(axis=0)

array([0.80177975, 0.56228968, 0.68474961, 0.86169105])

In [17]:
# Per-class precision of the best configuration, averaged across seeds
np.vstack(best_ffn_mean_history["precision_scores"].tolist()).mean(axis=0)

array([0.79601748, 0.62643732, 0.63156984, 0.9039003 ])

In [18]:
# Per-class recall of the best configuration, averaged across seeds
np.vstack(best_ffn_mean_history["recall_scores"].tolist()).mean(axis=0)

array([0.8086758 , 0.51049724, 0.74862385, 0.82329114])

## KFold

We repeat this search using K-fold cross-validation instead; by default, $K=5$ folds are used.

In [19]:
# K-fold evaluation of the same mean-history FFN baseline (the saved results show
# n_splits=5).
# NOTE(review): `y_data_therapist`, `output_dim_therapist`, `therapist_index` and
# `therapist_transcript_id` are not defined in this notebook; presumably they are
# created by `%run ../load_anno_mi.py`. Confirm.
# Only the first two return values are used. Collecting the rest in `*_unused`
# (instead of `_, __`) avoids overwriting IPython's `_`/`__` output-history variables.
ffn_mean_history_kfold, best_ffn_mean_history_kfold, *_unused = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    output_dim=output_dim_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,  # presumably selects the mean-history features rather than path signatures
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=therapist_index,
    split_ids=therapist_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    # Derive the file name from `loss` so it stays accurate if the loss is changed.
    results_output=f"{output_dir}/ffn_mean_history_{loss}_{gamma}_kfold.csv",
    verbose=False,
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_kfold_best_model.csv


In [20]:
# Full results of the K-fold search: one row per
# (hidden_dim, dropout_rate, learning_rate, seed) run. Pandas truncates the display.
ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,,0.727259,0.705419,"[0.7542566709021601, 0.5554445554445554, 0.651...",0.705466,"[0.7669250645994832, 0.5220657276995305, 0.633...",0.707718,"[0.742, 0.5933831376734259, 0.6713368359603036...",,0.789711,...,0.5,0.0010,0,focal,2,True,5,64,0.00,768
0,,0.729522,0.707497,"[0.7501933488012374, 0.5598356445814072, 0.657...",0.708235,"[0.7743480574773816, 0.5396039603960396, 0.626...",0.708754,"[0.7275, 0.5816435432230523, 0.690601284296555...",,0.791143,...,0.5,0.0010,1,focal,2,True,5,64,0.00,768
0,,0.735556,0.711041,"[0.7587421383647798, 0.5562700964630224, 0.665...",0.712851,"[0.7635443037974684, 0.5586652314316469, 0.639...",0.710182,"[0.754, 0.5538954108858057, 0.6935201401050788...",,0.793567,...,0.5,0.0010,12,focal,2,True,5,64,0.00,768
0,,0.726505,0.703808,"[0.7551020408163265, 0.5518630412890231, 0.649...",0.703853,"[0.79198682766191, 0.5224022878932316, 0.62332...",0.706213,"[0.7215, 0.5848452508004269, 0.677174547577349...",,0.786847,...,0.5,0.0010,123,focal,2,True,5,64,0.00,768
0,,0.729069,0.706646,"[0.7531925983841543, 0.5547226386806596, 0.655...",0.706341,"[0.7866086009798585, 0.5216165413533834, 0.632...",0.709459,"[0.7225, 0.5923159018143009, 0.680677174547577...",,0.790372,...,0.5,0.0010,1234,focal,2,True,5,64,0.00,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.726052,0.701506,"[0.7506453278265359, 0.5434782608695651, 0.657...",0.701417,"[0.775880469583778, 0.5276381909547738, 0.6352...",0.702726,"[0.727, 0.5602988260405549, 0.681260945709282,...",,0.783432,...,0.1,0.0005,0,focal,2,True,5,64,0.26,768
0,,0.718811,0.695878,"[0.7376212353241449, 0.5438775510204081, 0.650...",0.695649,"[0.7533889468196038, 0.5210166177908113, 0.631...",0.697445,"[0.7225, 0.5688367129135539, 0.670753064798599...",,0.786407,...,0.1,0.0005,1,focal,2,True,5,64,0.26,768
0,,0.721677,0.700205,"[0.7469688982604112, 0.5463367297428432, 0.649...",0.700464,"[0.7898550724637681, 0.5008896797153025, 0.624...",0.704273,"[0.7085, 0.6008537886872999, 0.676007005253940...",,0.788390,...,0.1,0.0005,12,focal,2,True,5,64,0.26,768
0,,0.718057,0.694801,"[0.7454401268834258, 0.5306122448979592, 0.643...",0.696244,"[0.7908020190689848, 0.49720149253731344, 0.60...",0.697213,"[0.705, 0.5688367129135539, 0.681260945709282,...",,0.786627,...,0.1,0.0005,123,focal,2,True,5,64,0.26,768


In [21]:
# Average every numeric column over seeds for each
# (hidden_dim, dropout_rate, learning_rate) combination of the K-fold search.
# numeric_only=True skips non-numeric columns such as the per-class score lists and
# `loss_function`. Without it, pandas 1.x emits a FutureWarning and pandas >= 2.0
# raises a TypeError.
ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(128, 128)",0.1,0.0001,0.728737,0.70535,0.704977,0.707955,0.785393,0.770517,0.768613,0.773644,274.0,2.0,1.0,5.0,64.0,0.7,768.0
"(128, 128)",0.1,0.0005,0.723096,0.699377,0.69931,0.701633,0.786473,0.77169,0.769884,0.77481,274.0,2.0,1.0,5.0,64.0,0.8,768.0
"(128, 128)",0.1,0.001,0.722432,0.699372,0.699523,0.702018,0.785085,0.770288,0.768644,0.773275,274.0,2.0,1.0,5.0,64.0,0.6,768.0
"(128, 128)",0.2,0.0001,0.72922,0.705507,0.705322,0.707684,0.786385,0.77148,0.769486,0.774649,274.0,2.0,1.0,5.0,64.0,0.4,768.0
"(128, 128)",0.2,0.0005,0.726173,0.702546,0.702109,0.704688,0.788544,0.773445,0.771951,0.775965,274.0,2.0,1.0,5.0,64.0,0.5,768.0
"(128, 128)",0.2,0.001,0.726354,0.703069,0.70339,0.705334,0.786561,0.772048,0.770409,0.775468,274.0,2.0,1.0,5.0,64.0,0.3,768.0
"(128, 128)",0.5,0.0001,0.73281,0.709314,0.708618,0.711629,0.788742,0.773689,0.771723,0.776729,274.0,2.0,1.0,5.0,64.0,0.1,768.0
"(128, 128)",0.5,0.0005,0.729099,0.705419,0.705742,0.707052,0.790923,0.776276,0.774469,0.779417,274.0,2.0,1.0,5.0,64.0,0.2,768.0
"(128, 128)",0.5,0.001,0.729582,0.706882,0.707349,0.708465,0.790328,0.775071,0.773627,0.777571,274.0,2.0,1.0,5.0,64.0,0.0,768.0
"(256, 256)",0.1,0.0001,0.727712,0.704915,0.704552,0.707798,0.785525,0.771043,0.768836,0.774798,274.0,2.0,1.0,5.0,64.0,0.16,768.0


In [22]:
# Runs of the best configuration found by the K-fold search (presumably chosen by
# `validation_metric`); in the saved output there is one row per seed.
best_ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,,0.730276,0.706933,"[0.7533748701973001, 0.5504587155963303, 0.663...",0.707048,"[0.7834773218142549, 0.526829268292683, 0.6366...",0.708786,"[0.7255, 0.576307363927428, 0.6935201401050788...",,0.78894,...,"(128, 128)",0.5,0.0005,0,focal,2,True,5,64,768
0,,0.723789,0.70138,"[0.7528438469493279, 0.5442850074220682, 0.647...",0.701715,"[0.7794432548179872, 0.507380073800738, 0.6239...",0.704066,"[0.728, 0.5869797225186766, 0.6730881494454174...",,0.791143,...,"(128, 128)",0.5,0.0005,1,focal,2,True,5,64,768
0,,0.731634,0.708531,"[0.7595001275184902, 0.5503018108651911, 0.658...",0.708762,"[0.7751171264966163, 0.5204567078972407, 0.637...",0.710444,"[0.7445, 0.583778014941302, 0.681260945709282,...",,0.792575,...,"(128, 128)",0.5,0.0005,12,focal,2,True,5,64,768
0,,0.730729,0.706746,"[0.7581227436823104, 0.547034764826176, 0.6595...",0.706965,"[0.7827476038338658, 0.5250245338567223, 0.634...",0.708222,"[0.735, 0.5709711846318036, 0.6870986573263281...",,0.791033,...,"(128, 128)",0.5,0.0005,123,focal,2,True,5,64,768
0,,0.729069,0.703504,"[0.7569741140990199, 0.5375, 0.655904842820730...",0.704219,"[0.7609903991915109, 0.5249237029501526, 0.636...",0.703743,"[0.753, 0.5506937033084311, 0.6760070052539404...",,0.790923,...,"(128, 128)",0.5,0.0005,1234,focal,2,True,5,64,768


In [23]:
# Overall F1 of the best K-fold configuration, averaged across seeds
best_ffn_mean_history_kfold["f1"].agg("mean")

0.7054186961283366

In [24]:
# Overall precision of the best K-fold configuration, averaged across seeds
best_ffn_mean_history_kfold["precision"].agg("mean")

0.7057416915451535

In [25]:
# Overall recall of the best K-fold configuration, averaged across seeds
best_ffn_mean_history_kfold["recall"].agg("mean")

0.7070522984260954

In [26]:
# Per-class F1 of the best K-fold configuration, averaged across seeds
np.vstack(best_ffn_mean_history_kfold["f1_scores"].tolist()).mean(axis=0)

array([0.75616314, 0.54591606, 0.65706179, 0.86253379])

In [27]:
# Per-class precision of the best K-fold configuration, averaged across seeds
np.vstack(best_ffn_mean_history_kfold["precision_scores"].tolist()).mean(axis=0)

array([0.77635514, 0.52092286, 0.63374112, 0.89194765])

In [28]:
# Per-class recall of the best K-fold configuration, averaged across seeds
np.vstack(best_ffn_mean_history_kfold["recall_scores"].tolist()).mean(axis=0)

array([0.7372    , 0.573746  , 0.68219498, 0.83506822])