In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` appears unused by the cells in this notebook — the
# hyperparameter searches below are seeded via the `seeds` list instead.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    histories_baseline_hyperparameter_search
)

In [3]:
# Directory that receives the hyperparameter-search result CSVs.
output_dir = "therapist_talk_type_output"
# exist_ok=True keeps this cell idempotent on re-runs and avoids the race
# between checking for the directory and creating it.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Run the AnnoMI loading script in this namespace. Later cells use `anno_mi`
# from it, and presumably also `y_data_therapist`, `output_dim_therapist`,
# `therapist_index` and `therapist_transcript_id`. None of these are defined
# elsewhere in this notebook — confirm in ../load_anno_mi.py.
%run ../load_anno_mi.py

In [5]:
# Preview the first rows of the AnnoMI utterance table.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-19 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-19 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-19 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:34


In [6]:
# Precomputed SBERT sentence embeddings for the AnnoMI utterances.
# NOTE: pickle.load can execute arbitrary code — only load a file you
# produced yourself or otherwise trust.
embeddings_file = "../anno_mi_sbert.pkl"
with open(embeddings_file, "rb") as embeddings_handle:
    sbert_embeddings = pickle.load(embeddings_handle)

# Expected to be (n_utterances, embedding_dim).
sbert_embeddings.shape

(9699, 384)

# Baseline: Averaging the history and using an FFN

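Here, we average the full history of a path and concatenate it with the current embedding. The number of features passed into the FFN is therefore `2 * sbert_embeddings.shape[1]`, i.e. twice the embedding dimension ($2 \times 384 = 768$, matching the `input_dim` column in the results).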

Here, we run a hyperparameter search for the FFN using the same settings as the standard FFN baseline on the sentence embeddings. Every candidate network has two hidden layers of equal width (64, 128, 256 or 512 units). Each width is combined with the dropout rates and learning rates defined below.

In [7]:
# Search space and training settings shared by both hyperparameter searches.
num_epochs = 100

# Each candidate network has two hidden layers of equal width.
hidden_dim_sizes = [[width, width] for width in (64, 128, 256, 512)]
dropout_rates = [0.5, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
# Every configuration is trained once per seed.
seeds = [1, 12, 123]

# Focal loss; `gamma` also appears in the result file names.
loss = "focal"
gamma = 2

# Metric monitored on the validation set, and the number of epochs without
# improvement tolerated (presumably for early stopping — confirm in the
# search function).
validation_metric = "f1"
patience = 5

In [9]:
# Hyperparameter search for the mean-history FFN baseline on a single split
# (k_fold=False). Results for every configuration go to `results_output`; the
# best configuration's results go to a matching "_best_model" CSV.
# `path_indices`/`split_ids` presumably restrict the targets to therapist
# utterances and group the splits by transcript — confirm in the search function.
#
# Only the first two return values are used. They are not unpacked into
# `_, __`: assigning to `_`/`__` in IPython shadows the output-history
# variables and stops them updating for the rest of the session.
ffn_mean_history, best_ffn_mean_history, *_unused_outputs = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    output_dim=output_dim_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=therapist_index,
    split_ids=therapist_transcript_id,
    k_fold=False,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}.csv",
    verbose=False
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_best_model.csv


In [10]:
# Per-run results of the single-split search: one row per hyperparameter
# configuration and seed.
ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,0.666719,0.752222,0.713572,"[0.7882136279926335, 0.5244444444444445, 0.680...",0.715044,"[0.7867647058823529, 0.5315315315315315, 0.638...",0.714503,"[0.7896678966789668, 0.5175438596491229, 0.727...",0.765459,0.712139,...,0.5,0.0010,1,focal,2,False,,64,0.00,768
0,0.702486,0.755556,0.718542,"[0.776173285198556, 0.5533980582524273, 0.6845...",0.727660,"[0.7597173144876325, 0.6195652173913043, 0.642...",0.714687,"[0.7933579335793358, 0.5, 0.7320574162679426, ...",0.776475,0.710983,...,0.5,0.0010,12,focal,2,False,,64,0.00,768
0,0.713564,0.768889,0.740267,"[0.7956600361663653, 0.6007905138339921, 0.695...",0.737263,"[0.7801418439716312, 0.5467625899280576, 0.702...",0.747750,"[0.8118081180811808, 0.6666666666666666, 0.688...",0.779774,0.707514,...,0.5,0.0010,123,focal,2,False,,64,0.00,768
0,0.727447,0.764444,0.719929,"[0.800718132854578, 0.5268817204301075, 0.6824...",0.745177,"[0.7797202797202797, 0.6805555555555556, 0.618...",0.713334,"[0.8228782287822878, 0.4298245614035088, 0.760...",0.770657,0.712139,...,0.5,0.0001,1,focal,2,False,,64,0.10,768
0,0.731906,0.764444,0.722639,"[0.797153024911032, 0.5396825396825397, 0.6853...",0.745921,"[0.7697594501718213, 0.68, 0.6235294117647059,...",0.716192,"[0.8265682656826568, 0.4473684210526316, 0.760...",0.772632,0.715607,...,0.5,0.0001,12,focal,2,False,,64,0.10,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.701245,0.752222,0.715184,"[0.781132075471698, 0.5428571428571428, 0.6723...",0.724057,"[0.7992277992277992, 0.59375, 0.60536398467432...",0.713288,"[0.7638376383763837, 0.5, 0.7559808612440191, ...",0.763299,0.715607,...,0.1,0.0001,12,focal,2,False,,64,0.22,768
0,0.701796,0.746667,0.700341,"[0.7850799289520426, 0.4830917874396135, 0.672...",0.708185,"[0.7568493150684932, 0.5376344086021505, 0.629...",0.698394,"[0.8154981549815498, 0.43859649122807015, 0.72...",0.793200,0.724855,...,0.1,0.0001,123,focal,2,False,,64,0.22,768
0,0.717127,0.746667,0.713703,"[0.7806691449814127, 0.5478260869565217, 0.666...",0.716587,"[0.7865168539325843, 0.5431034482758621, 0.608...",0.716259,"[0.7749077490774908, 0.5526315789473685, 0.736...",0.737720,0.722543,...,0.1,0.0005,1,focal,2,False,,64,0.23,768
0,0.716846,0.765556,0.736553,"[0.7857142857142857, 0.5975103734439833, 0.696...",0.735797,"[0.8497854077253219, 0.5669291338582677, 0.652...",0.742024,"[0.7306273062730627, 0.631578947368421, 0.7464...",0.775155,0.708670,...,0.1,0.0005,12,focal,2,False,,64,0.23,768


In [11]:
# Average every numeric result column over the seeds for each
# (hidden_dim, dropout_rate, learning_rate) combination.
# numeric_only=True is required: the frame also contains non-numeric columns
# (e.g. `loss_function` strings and the list-valued `*_scores` columns). Older
# pandas warns about them, and pandas >= 2.0 raises a TypeError.
ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_mean_history.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
"(64, 64)",0.1,0.0001,0.736051,0.758889,0.716342,0.739139,0.711677,0.778089,0.717534,0.685286,0.688111,0.683793,45.333333,2.0,0.0,64.0,0.4,768.0
"(64, 64)",0.1,0.0005,0.723697,0.756667,0.720999,0.727708,0.719801,0.753201,0.717534,0.688972,0.695609,0.686436,45.333333,2.0,0.0,64.0,0.5,768.0
"(64, 64)",0.1,0.001,0.692843,0.751852,0.720381,0.723154,0.724669,0.752821,0.711368,0.682705,0.68381,0.686062,45.333333,2.0,0.0,64.0,0.3,768.0
"(64, 64)",0.5,0.0001,0.762052,0.764074,0.722295,0.746265,0.715512,0.796692,0.713295,0.679974,0.684739,0.676975,45.333333,2.0,0.0,64.0,0.1,768.0
"(64, 64)",0.5,0.0005,0.67478,0.757037,0.719438,0.722556,0.719577,0.753116,0.716763,0.686759,0.689402,0.686522,45.333333,2.0,0.0,64.0,0.2,768.0
"(64, 64)",0.5,0.001,0.694256,0.758889,0.724127,0.726656,0.725647,0.773903,0.710212,0.680604,0.683714,0.682488,45.333333,2.0,0.0,64.0,0.0,768.0
"(128, 128)",0.1,0.0001,0.704139,0.764444,0.729751,0.741797,0.726294,0.758332,0.719461,0.689698,0.690779,0.690096,45.333333,2.0,0.0,64.0,0.1,768.0
"(128, 128)",0.1,0.0005,0.713912,0.758889,0.724404,0.729996,0.723803,0.741361,0.716378,0.687681,0.69058,0.688119,45.333333,2.0,0.0,64.0,0.11,768.0
"(128, 128)",0.1,0.001,0.686555,0.75963,0.725475,0.731929,0.724041,0.744976,0.717148,0.686932,0.694707,0.684028,45.333333,2.0,0.0,64.0,0.9,768.0
"(128, 128)",0.5,0.0001,0.778704,0.755185,0.710627,0.746454,0.702935,0.817088,0.716763,0.685966,0.690965,0.68281,45.333333,2.0,0.0,64.0,0.7,768.0


In [12]:
# Results of the best configuration found by the single-split search, one row
# per seed. Presumably selected on `validation_metric` — confirm in the
# search function.
best_ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,0.686164,0.747778,0.711873,"[0.7701149425287356, 0.5399999999999999, 0.669...",0.732002,"[0.8007968127490039, 0.627906976744186, 0.5733...",0.709051,"[0.7416974169741697, 0.47368421052631576, 0.80...",0.746844,0.724855,...,"(512, 512)",0.1,0.0001,1,focal,2,False,,64,768
0,0.701245,0.752222,0.715184,"[0.781132075471698, 0.5428571428571428, 0.6723...",0.724057,"[0.7992277992277992, 0.59375, 0.60536398467432...",0.713288,"[0.7638376383763837, 0.5, 0.7559808612440191, ...",0.763299,0.715607,...,"(512, 512)",0.1,0.0001,12,focal,2,False,,64,768
0,0.701796,0.746667,0.700341,"[0.7850799289520426, 0.4830917874396135, 0.672...",0.708185,"[0.7568493150684932, 0.5376344086021505, 0.629...",0.698394,"[0.8154981549815498, 0.43859649122807015, 0.72...",0.7932,0.724855,...,"(512, 512)",0.1,0.0001,123,focal,2,False,,64,768


In [13]:
# F1 of the best configuration, averaged across its seeds.
best_ffn_mean_history.loc[:, "f1"].mean()

0.7091326946451284

In [14]:
# Precision of the best configuration, averaged across its seeds.
best_ffn_mean_history.loc[:, "precision"].mean()

0.7214149135236309

In [15]:
# Recall of the best configuration, averaged across its seeds.
best_ffn_mean_history.loc[:, "recall"].mean()

0.706910902020175

In [16]:
# Per-class F1 averaged across seeds: stack the per-seed score vectors into
# one row each and take the column means.
np.mean(np.stack(list(best_ffn_mean_history["f1_scores"])), axis=0)

array([0.77877565, 0.52198298, 0.67142298, 0.86434918])

In [17]:
# Per-class precision averaged across seeds (column means of the stacked
# per-seed score vectors).
np.mean(np.stack(list(best_ffn_mean_history["precision_scores"])), axis=0)

array([0.78562464, 0.58643046, 0.6026365 , 0.91096805])

In [18]:
# Per-class recall averaged across seeds (column means of the stacked
# per-seed score vectors).
np.mean(np.stack(list(best_ffn_mean_history["recall_scores"])), axis=0)

array([0.77367774, 0.47076023, 0.76076555, 0.82244009])

## KFold

We now repeat the same search using $K$-fold evaluation instead of a single split; by default, $K=5$ folds are used.

In [19]:
# Same mean-history FFN search as the single-split run, but evaluated with
# K-fold cross-validation (k_fold=True). Results for every configuration go to
# `results_output`; the best configuration's results go to a matching
# "_best_model" CSV.
#
# Only the first two return values are used. They are not unpacked into
# `_, __`: assigning to `_`/`__` in IPython shadows the output-history
# variables and stops them updating for the rest of the session.
ffn_mean_history_kfold, best_ffn_mean_history_kfold, *_unused_outputs = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    output_dim=output_dim_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=False,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=therapist_index,
    split_ids=therapist_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}_kfold.csv",
    verbose=False
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_mean_history_focal_2_kfold_best_model.csv


In [29]:
# Per-run results of the K-fold search: one row per hyperparameter
# configuration and seed.
ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,,0.728653,0.695516,"[0.7532381268681501, 0.5130742049469964, 0.653...",0.694326,"[0.7490092470277411, 0.487248322147651, 0.6503...",0.697892,"[0.7575150300601202, 0.5417910447761194, 0.656...",,0.769462,...,0.5,0.0010,1,focal,2,True,5,64,0.00,768
0,,0.728463,0.696962,"[0.751592356687898, 0.513675783855904, 0.65560...",0.695332,"[0.7543741588156124, 0.4644149577804584, 0.660...",0.702204,"[0.7488309953239813, 0.5746268656716418, 0.651...",,0.762025,...,0.5,0.0010,12,focal,2,True,5,64,0.00,768
0,,0.721632,0.691130,"[0.7432206226983595, 0.5131578947368421, 0.641...",0.689678,"[0.7449664429530202, 0.4588235294117647, 0.648...",0.697121,"[0.7414829659318637, 0.582089552238806, 0.6353...",,0.762025,...,0.5,0.0010,123,focal,2,True,5,64,0.00,768
0,,0.734725,0.697533,"[0.7647252385653175, 0.5003825554705432, 0.661...",0.699873,"[0.7535667963683528, 0.5133437990580848, 0.641...",0.696200,"[0.7762191048764195, 0.4880597014925373, 0.683...",,0.766930,...,0.5,0.0001,1,focal,2,True,5,64,0.10,768
0,,0.734725,0.697353,"[0.7621550591327202, 0.49809596344249807, 0.66...",0.699490,"[0.7498383968972204, 0.5085536547433903, 0.644...",0.696224,"[0.7748830995323981, 0.4880597014925373, 0.684...",,0.765348,...,0.5,0.0001,12,focal,2,True,5,64,0.10,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.725617,0.694610,"[0.747523061154766, 0.5134408602150538, 0.6534...",0.693322,"[0.765034965034965, 0.4669926650366748, 0.6445...",0.699294,"[0.7307949231796927, 0.5701492537313433, 0.662...",,0.758386,...,0.1,0.0001,12,focal,2,True,5,64,0.22,768
0,,0.727135,0.695862,"[0.7460047602856172, 0.5194444444444445, 0.653...",0.694766,"[0.7596952908587258, 0.4857142857142857, 0.639...",0.698975,"[0.7327989311957248, 0.5582089552238806, 0.668...",,0.763449,...,0.1,0.0001,123,focal,2,True,5,64,0.22,768
0,,0.725427,0.693686,"[0.7432478632478632, 0.5159500693481276, 0.655...",0.692302,"[0.761204481792717, 0.48186528497409326, 0.644...",0.696933,"[0.7261189044756179, 0.5552238805970149, 0.666...",,0.762342,...,0.1,0.0005,1,focal,2,True,5,64,0.23,768
0,,0.721442,0.690078,"[0.7397820163487739, 0.505685618729097, 0.6503...",0.688978,"[0.7546907574704657, 0.4581818181818182, 0.642...",0.694806,"[0.7254509018036072, 0.564179104477612, 0.6583...",,0.753639,...,0.1,0.0005,12,focal,2,True,5,64,0.23,768


In [21]:
# Average every numeric result column over the seeds for each
# (hidden_dim, dropout_rate, learning_rate) combination of the K-fold search.
# numeric_only=True is required: the frame also contains non-numeric columns
# (e.g. `loss_function` strings and the list-valued `*_scores` columns). Older
# pandas warns about them, and pandas >= 2.0 raises a TypeError.
ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(64, 64)",0.1,0.0001,0.729412,0.696168,0.696054,0.697734,0.764821,0.73833,0.736015,0.741982,45.333333,2.0,1.0,5.0,64.0,0.4,768.0
"(64, 64)",0.1,0.0005,0.726565,0.694422,0.693497,0.697372,0.764821,0.737945,0.735427,0.742204,45.333333,2.0,1.0,5.0,64.0,0.5,768.0
"(64, 64)",0.1,0.001,0.724668,0.692276,0.690929,0.696103,0.760232,0.732821,0.730138,0.737897,45.333333,2.0,1.0,5.0,64.0,0.3,768.0
"(64, 64)",0.5,0.0001,0.733586,0.697451,0.698681,0.697108,0.766667,0.738258,0.737441,0.739742,45.333333,2.0,1.0,5.0,64.0,0.1,768.0
"(64, 64)",0.5,0.0005,0.727451,0.694279,0.693637,0.696445,0.766297,0.73868,0.736925,0.741676,45.333333,2.0,1.0,5.0,64.0,0.2,768.0
"(64, 64)",0.5,0.001,0.726249,0.694536,0.693112,0.699072,0.764504,0.737228,0.734614,0.742729,45.333333,2.0,1.0,5.0,64.0,0.0,768.0
"(128, 128)",0.1,0.0001,0.72783,0.696032,0.695336,0.698548,0.763924,0.737821,0.735375,0.741657,45.333333,2.0,1.0,5.0,64.0,0.1,768.0
"(128, 128)",0.1,0.0005,0.722201,0.691629,0.690969,0.696395,0.760759,0.734409,0.731748,0.739993,45.333333,2.0,1.0,5.0,64.0,0.11,768.0
"(128, 128)",0.1,0.001,0.724162,0.692818,0.69146,0.698105,0.759019,0.732602,0.729685,0.740362,45.333333,2.0,1.0,5.0,64.0,0.9,768.0
"(128, 128)",0.5,0.0001,0.730108,0.697059,0.695902,0.699489,0.76635,0.739802,0.737065,0.744174,45.333333,2.0,1.0,5.0,64.0,0.7,768.0


In [22]:
# Results of the best configuration found by the K-fold search, one row per
# seed. Presumably selected on `validation_metric` — confirm in the search
# function.
best_ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,,0.731309,0.697849,"[0.7557788944723617, 0.5106382978723404, 0.662...",0.696962,"[0.7580645161290323, 0.4864864864864865, 0.652...",0.699922,"[0.7535070140280561, 0.5373134328358209, 0.674...",,0.765981,...,"(128, 128)",0.5,0.0001,1,focal,2,True,5,64,768
0,,0.729981,0.696762,"[0.7532641446267156, 0.510938602681722, 0.6607...",0.695711,"[0.7550335570469798, 0.48460508701472554, 0.65...",0.699093,"[0.751503006012024, 0.5402985074626866, 0.6697...",,0.767405,...,"(128, 128)",0.5,0.0001,12,focal,2,True,5,64,768
0,,0.729032,0.696567,"[0.7529411764705882, 0.516445066480056, 0.6564...",0.695034,"[0.7577807848443843, 0.48616600790513836, 0.65...",0.699453,"[0.7481629926519706, 0.5507462686567164, 0.661...",,0.765665,...,"(128, 128)",0.5,0.0001,123,focal,2,True,5,64,768


In [23]:
# F1 of the best K-fold configuration, averaged across its seeds.
best_ffn_mean_history_kfold.loc[:, "f1"].mean()

0.6970593651089759

In [24]:
# Precision of the best K-fold configuration, averaged across its seeds.
best_ffn_mean_history_kfold.loc[:, "precision"].mean()

0.6959021903264374

In [25]:
# Recall of the best K-fold configuration, averaged across its seeds.
best_ffn_mean_history_kfold.loc[:, "recall"].mean()

0.6994890280677256

In [26]:
# Per-class F1 of the best K-fold configuration, averaged across seeds
# (column means of the stacked per-seed score vectors).
np.mean(np.stack(list(best_ffn_mean_history_kfold["f1_scores"])), axis=0)

array([0.75399474, 0.51267399, 0.66005304, 0.86151569])

In [27]:
# Per-class precision of the best K-fold configuration, averaged across seeds
# (column means of the stacked per-seed score vectors).
np.mean(np.stack(list(best_ffn_mean_history_kfold["precision_scores"])), axis=0)

array([0.75695962, 0.48575253, 0.65199071, 0.88890591])

In [28]:
# Per-class recall of the best K-fold configuration, averaged across seeds
# (column means of the stacked per-seed score vectors).
np.mean(np.stack(list(best_ffn_mean_history_kfold["recall_scores"])), axis=0)

array([0.75105767, 0.54278607, 0.66833811, 0.83577426])