In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced by any cell shown in this notebook;
# the hyperparameter searches below take their seeds from the `seeds` list.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
)

In [3]:
# Directory where the hyperparameter-search CSVs are written.
# exist_ok=True makes this idempotent and avoids the check-then-create race.
output_dir = "therapist_talk_type_output"
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Load the AnnoMI data by executing the shared loader script.
# NOTE(review): later cells use `anno_mi`, `therapist_index`,
# `therapist_transcript_id`, `output_dim_therapist`, `label_to_id_therapist`
# and `id_to_label_therapist`, none of which are defined in this notebook —
# presumably this script defines them; confirm.
%run ../load_anno_mi.py

In [5]:
# Preview the first rows of the AnnoMI utterance table.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-27 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-27 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-27 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-27 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-27 00:00:34


In [6]:
# Precomputed SBERT sentence embeddings, one row per utterance.
# NOTE(review): pickle.load can execute arbitrary code — only load this file
# if it was produced locally / by a trusted source.
embeddings_path = "../anno_mi_sbert.pkl"
with open(embeddings_path, "rb") as handle:
    sbert_embeddings = pickle.load(handle)

sbert_embeddings.shape

(13551, 384)

In [35]:
# Target labels: the main therapist behaviour of each selected utterance.
y_data_therapist = anno_mi["main_therapist_behaviour"][therapist_index]
# Labels and features (sbert_embeddings[therapist_index]) are passed together
# to the search below, so they must line up row for row.
assert len(y_data_therapist) == len(sbert_embeddings[therapist_index]), (
    "label / embedding row counts differ"
)
y_data_therapist.unique()

array(['question', 'therapist_input', 'reflection', 'other'], dtype=object)

In [30]:
# Number of therapist-behaviour classes; passed as the FFN output dimension.
output_dim_therapist

4

In [32]:
# Mapping from behaviour label to class id.
label_to_id_therapist

{'question': 0, 'therapist_input': 1, 'reflection': 2, 'other': 3}

In [33]:
# Inverse mapping from class id back to behaviour label.
id_to_label_therapist

{0: 'question', 1: 'therapist_input', 2: 'reflection', 3: 'other'}

# Baseline: feed-forward network (FFN)

We feed the SBERT sentence embeddings directly into a feed-forward network (FFN) to predict the therapist talk type (`main_therapist_behaviour`).

In [7]:
# Hyperparameter grid for the FFN baseline. These values match the results
# recorded below: 2 hidden sizes x 3 dropout rates x 3 learning rates
# = 18 configurations (model_id 0-17), each trained with 5 seeds.
num_epochs = 100
hidden_dim_sizes = [[128, 128], [256, 256]]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [0, 1, 12, 123, 1234]
loss = "focal"
gamma = 2  # gamma parameter of the focal loss
validation_metric = "f1"  # metric used to pick the best configuration
patience = 5  # presumably early-stopping patience (epochs) — confirm in ffn_hyperparameter_search

In [8]:
# Hidden-layer sizes included in the search grid.
hidden_dim_sizes

[[128, 128], [256, 256]]

In [9]:
# Learning rates included in the search grid.
learning_rates

[0.001, 0.0001, 0.0005]

We use `ffn_hyperparameter_search`, which loops over every combination of hidden dimensions, dropout rates and learning rates and selects the configuration that performs best on the validation set according to `validation_metric`. Each configuration is trained with every seed in `seeds`, and we report performance averaged over the seeds.

In [10]:
# Hyperparameter search with a single train/validation/test split (k_fold=False).
# split_ids are the transcript ids of the selected utterances (presumably so
# each transcript stays in one split — confirm in ffn_hyperparameter_search).
# Only the first two return values (all per-run results and the best
# configuration's results) are used. The rest go to `_unused` rather than
# `_`/`__`, which IPython reserves for output history.
ffn_current, best_ffn_current, *_unused = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[therapist_index],
    y_data=y_data_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_therapist,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=False,
    patience=patience,
    split_ids=therapist_transcript_id,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False
)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_best_model.csv


In [11]:
# One row per (configuration, seed) run of the single-split search.
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,0.689033,0.774795,0.745391,"[0.8295454545454546, 0.5882352941176471, 0.698...",0.760704,"[0.8257918552036199, 0.6690140845070423, 0.632...",0.739566,"[0.8333333333333334, 0.5248618784530387, 0.779...",0.769741,0.709888,...,"(128, 128)",0.5,0.0010,0,focal,2,False,,64,0
0,0.703176,0.765101,0.736536,"[0.8182857142857143, 0.5802816901408452, 0.684...",0.739443,"[0.8192219679633868, 0.5919540229885057, 0.650...",0.735259,"[0.817351598173516, 0.569060773480663, 0.72171...",0.776954,0.712687,...,"(128, 128)",0.5,0.0010,1,focal,2,False,,64,0
0,0.688772,0.768829,0.739822,"[0.8295964125560539, 0.5833333333333333, 0.687...",0.742537,"[0.8149779735682819, 0.5865921787709497, 0.663...",0.738780,"[0.8447488584474886, 0.580110497237569, 0.7125...",0.768796,0.701493,...,"(128, 128)",0.5,0.0010,12,focal,2,False,,64,0
0,0.700736,0.771812,0.740448,"[0.826879271070615, 0.5773809523809523, 0.6896...",0.747452,"[0.825, 0.6258064516129033, 0.6504065040650406...",0.736681,"[0.8287671232876712, 0.5359116022099447, 0.733...",0.764418,0.724813,...,"(128, 128)",0.5,0.0010,123,focal,2,False,,64,0
0,0.732741,0.768084,0.739555,"[0.8121353558926487, 0.5789473684210528, 0.695...",0.746412,"[0.8305489260143198, 0.6149068322981367, 0.637...",0.737261,"[0.7945205479452054, 0.5469613259668509, 0.764...",0.787570,0.712687,...,"(128, 128)",0.5,0.0010,1234,focal,2,False,,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.700243,0.762864,0.731153,"[0.8154020385050962, 0.5770308123249299, 0.676...",0.732356,"[0.8089887640449438, 0.5852272727272727, 0.695...",0.730472,"[0.821917808219178, 0.569060773480663, 0.65749...",0.763688,0.706157,...,"(256, 256)",0.1,0.0005,0,focal,2,False,,64,17
0,0.718789,0.769575,0.741758,"[0.8147268408551069, 0.5819209039548023, 0.698...",0.745249,"[0.849009900990099, 0.5953757225433526, 0.6450...",0.741700,"[0.7831050228310502, 0.569060773480663, 0.7614...",0.779120,0.709888,...,"(256, 256)",0.1,0.0005,1,focal,2,False,,64,17
0,0.720001,0.771812,0.745727,"[0.8137603795966785, 0.5942857142857143, 0.701...",0.751200,"[0.8469135802469135, 0.6153846153846154, 0.639...",0.745005,"[0.7831050228310502, 0.574585635359116, 0.7767...",0.771752,0.711754,...,"(256, 256)",0.1,0.0005,12,focal,2,False,,64,17
0,0.707225,0.772558,0.747159,"[0.8167664670658683, 0.5971830985915493, 0.705...",0.751829,"[0.8589420654911839, 0.6091954022988506, 0.641...",0.747522,"[0.7785388127853882, 0.585635359116022, 0.7828...",0.768998,0.704291,...,"(256, 256)",0.1,0.0005,123,focal,2,False,,64,17


In [12]:
# Average every numeric metric over seeds for each hyperparameter combination.
# numeric_only=True is required: the frame also holds list-valued (*_scores),
# tuple and string columns. pandas >= 2.0 raises on these, and older pandas
# dropped them with a FutureWarning.
ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(128, 128)",0.1,0.0001,0.715478,0.772558,0.744137,0.747381,0.743038,0.771822,0.713246,0.692085,0.693156,0.699584,274.0,2.0,0.0,64.0,7.0
"(128, 128)",0.1,0.0005,0.703012,0.774348,0.744205,0.756287,0.739831,0.765517,0.712313,0.689565,0.691992,0.693468,274.0,2.0,0.0,64.0,8.0
"(128, 128)",0.1,0.001,0.710868,0.763758,0.734379,0.737796,0.734626,0.784375,0.707836,0.687888,0.690413,0.697686,274.0,2.0,0.0,64.0,6.0
"(128, 128)",0.2,0.0001,0.717742,0.773602,0.744796,0.748411,0.74358,0.774803,0.712313,0.691827,0.692586,0.699375,274.0,2.0,0.0,64.0,4.0
"(128, 128)",0.2,0.0005,0.708361,0.771812,0.742386,0.754753,0.738141,0.766125,0.71194,0.68978,0.692135,0.69397,274.0,2.0,0.0,64.0,5.0
"(128, 128)",0.2,0.001,0.69555,0.772558,0.744317,0.75709,0.739196,0.768687,0.706716,0.68552,0.690614,0.688002,274.0,2.0,0.0,64.0,3.0
"(128, 128)",0.5,0.0001,0.722904,0.774795,0.745171,0.750967,0.742552,0.782547,0.710634,0.689079,0.689735,0.694423,274.0,2.0,0.0,64.0,1.0
"(128, 128)",0.5,0.0005,0.705857,0.775839,0.746163,0.752953,0.743148,0.772111,0.714179,0.692376,0.692935,0.697849,274.0,2.0,0.0,64.0,2.0
"(128, 128)",0.5,0.001,0.702892,0.769724,0.74035,0.74731,0.73751,0.773496,0.712313,0.691833,0.695558,0.695864,274.0,2.0,0.0,64.0,0.0
"(256, 256)",0.1,0.0001,0.712367,0.771961,0.743352,0.747441,0.742385,0.770505,0.710261,0.6905,0.692163,0.698062,274.0,2.0,0.0,64.0,16.0


In [13]:
# Per-seed results of the best configuration from the single-split search.
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.69997,0.776286,0.74688,"[0.8256880733944953, 0.5862068965517242, 0.708...",0.750543,"[0.8294930875576036, 0.6107784431137725, 0.671...",0.745065,"[0.821917808219178, 0.56353591160221, 0.749235...",0.754973,0.712687,...,"[0.7115902964959568, 0.5827338129496403, 0.686...","(256, 256)",0.5,0.0001,0,focal,2,False,,64
0,0.704376,0.774795,0.746627,"[0.8246445497630331, 0.5852272727272727, 0.703...",0.750824,"[0.8571428571428571, 0.6023391812865497, 0.647...",0.746214,"[0.7945205479452054, 0.569060773480663, 0.7706...",0.768462,0.714552,...,"[0.6900269541778976, 0.60431654676259, 0.70648...","(256, 256)",0.5,0.0001,1,focal,2,False,,64
0,0.703381,0.774049,0.744416,"[0.8243559718969554, 0.5838150289017341, 0.706...",0.747987,"[0.8461538461538461, 0.6121212121212121, 0.666...",0.743046,"[0.8036529680365296, 0.5580110497237569, 0.752...",0.76089,0.719216,...,"[0.706199460916442, 0.5755395683453237, 0.7133...","(256, 256)",0.5,0.0001,12,focal,2,False,,64
0,0.710109,0.774795,0.747178,"[0.8193624557260921, 0.5941176470588235, 0.705...",0.754196,"[0.8484107579462102, 0.6352201257861635, 0.646...",0.745043,"[0.7922374429223744, 0.5580110497237569, 0.776...",0.763004,0.711754,...,"[0.6981132075471698, 0.5755395683453237, 0.709...","(256, 256)",0.5,0.0001,123,focal,2,False,,64
0,0.695649,0.773304,0.742814,"[0.8289920724801813, 0.5835694050991501, 0.690...",0.745079,"[0.8224719101123595, 0.5988372093023255, 0.672...",0.741198,"[0.8356164383561644, 0.569060773480663, 0.7094...",0.761174,0.716418,...,"[0.7169811320754716, 0.6115107913669064, 0.672...","(256, 256)",0.5,0.0001,1234,focal,2,False,,64


In [14]:
# Macro F1 of the best configuration, averaged over seeds.
best_ffn_current.loc[:, "f1"].mean()

0.7455830202577547

In [15]:
# Macro precision of the best configuration, averaged over seeds.
best_ffn_current.loc[:, "precision"].mean()

0.7497258632398834

In [16]:
# Macro recall of the best configuration, averaged over seeds.
best_ffn_current.loc[:, "recall"].mean()

0.7441131198762245

In [17]:
# Per-class F1 averaged over seeds (class order presumably follows label_to_id_therapist).
np.stack(best_ffn_current["f1_scores"].tolist()).mean(axis=0)

array([0.82460862, 0.58658725, 0.70298628, 0.86814993])

In [18]:
# Per-class precision averaged over seeds.
np.stack(best_ffn_current["precision_scores"].tolist()).mean(axis=0)

array([0.84073449, 0.61185923, 0.66089773, 0.885412  ])

In [19]:
# Per-class recall averaged over seeds.
np.stack(best_ffn_current["recall_scores"].tolist()).mean(axis=0)

array([0.80958904, 0.56353591, 0.75168196, 0.85164557])

## KFold

We repeat the search using $K$-fold evaluation instead (`k_fold=True`); by default there are $K=5$ folds.

In [20]:
# Same hyperparameter search, but evaluated with K-fold splitting (k_fold=True).
# Only the first two return values (all per-run results and the best
# configuration's results) are used. The rest go to `_unused` rather than
# `_`/`__`, which IPython reserves for output history.
ffn_current_kfold, best_ffn_current_kfold, *_unused = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[therapist_index],
    y_data=y_data_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_therapist,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=True,
    patience=patience,
    split_ids=therapist_transcript_id,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False
)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_kfold_best_model.csv


In [21]:
# One row per (configuration, seed) run of the K-fold search.
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,,0.738120,0.714234,"[0.7706237424547285, 0.5614754098360655, 0.665...",0.713553,"[0.7753036437246964, 0.5399014778325123, 0.656...",0.715724,"[0.766, 0.5848452508004269, 0.6742556917688266...",,0.782331,...,"(128, 128)",0.5,0.0010,0,focal,2,True,5,64,0
0,,0.729522,0.706199,"[0.7623862487360971, 0.5501519756838906, 0.659...",0.706024,"[0.7709611451942741, 0.523625843780135, 0.6438...",0.707932,"[0.754, 0.5795090715048026, 0.676590776415645,...",,0.786517,...,"(128, 128)",0.5,0.0010,1,focal,2,True,5,64,0
0,,0.733896,0.709682,"[0.7669806643529995, 0.5599173553719009, 0.656...",0.709225,"[0.7605703048180924, 0.5425425425425425, 0.655...",0.710736,"[0.7735, 0.5784418356456777, 0.656742556917688...",,0.782441,...,"(128, 128)",0.5,0.0010,12,focal,2,True,5,64,0
0,,0.729220,0.704391,"[0.76049766718507, 0.5412506568575932, 0.66281...",0.706010,"[0.7895586652314317, 0.5331262939958592, 0.627...",0.704683,"[0.7335, 0.5496264674493063, 0.702860478692352...",,0.784644,...,"(128, 128)",0.5,0.0010,123,focal,2,True,5,64,0
0,,0.729069,0.705644,"[0.7603930461073318, 0.5506234413965088, 0.654...",0.704730,"[0.766378872524124, 0.5168539325842697, 0.6499...",0.708227,"[0.7545, 0.5891141942369263, 0.659077641564506...",,0.782661,...,"(128, 128)",0.5,0.0010,1234,focal,2,True,5,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.725449,0.703002,"[0.760398060729778, 0.5499999999999999, 0.6497...",0.702092,"[0.7764460656591975, 0.5086128739800544, 0.641...",0.706465,"[0.745, 0.5987193169690501, 0.658493870402802,...",,0.780128,...,"(256, 256)",0.1,0.0005,0,focal,2,True,5,64,17
0,,0.728013,0.705695,"[0.7573415765069551, 0.5549252191851469, 0.655...",0.707031,"[0.7810839532412327, 0.5369261477045908, 0.622...",0.706585,"[0.735, 0.5741728922091782, 0.6935201401050788...",,0.782331,...,"(256, 256)",0.1,0.0005,1,focal,2,True,5,64,17
0,,0.722281,0.700108,"[0.7539952842546502, 0.5489383738995338, 0.650...",0.702456,"[0.7919647771051184, 0.5331991951710262, 0.608...",0.700879,"[0.7195, 0.5656350053361793, 0.698774080560420...",,0.782331,...,"(256, 256)",0.1,0.0005,12,focal,2,True,5,64,17
0,,0.731483,0.708867,"[0.7667087011349307, 0.5591397849462366, 0.658...",0.709143,"[0.7735368956743003, 0.5374015748031497, 0.639...",0.710097,"[0.76, 0.5827107790821772, 0.6800934033858728,...",,0.783653,...,"(256, 256)",0.1,0.0005,123,focal,2,True,5,64,17


In [22]:
# Average every numeric metric over seeds for each hyperparameter combination.
# numeric_only=True is required: the frame also holds list-valued (*_scores),
# tuple and string columns. pandas >= 2.0 raises on these, and older pandas
# dropped them with a FutureWarning.
ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
"(128, 128)",0.1,0.0001,0.729341,0.705905,0.705622,0.707748,0.781251,0.764851,0.762702,0.768156,274.0,2.0,1.0,5.0,64.0,7.0
"(128, 128)",0.1,0.0005,0.727319,0.703968,0.70399,0.705478,0.780062,0.763443,0.761851,0.76595,274.0,2.0,1.0,5.0,64.0,8.0
"(128, 128)",0.1,0.001,0.726837,0.703925,0.703542,0.70635,0.78059,0.764019,0.762221,0.767332,274.0,2.0,1.0,5.0,64.0,6.0
"(128, 128)",0.2,0.0001,0.731543,0.708124,0.707705,0.709824,0.782661,0.766052,0.764104,0.768963,274.0,2.0,1.0,5.0,64.0,4.0
"(128, 128)",0.2,0.0005,0.727199,0.704134,0.703956,0.706056,0.780635,0.764205,0.762523,0.767108,274.0,2.0,1.0,5.0,64.0,5.0
"(128, 128)",0.2,0.001,0.7256,0.703183,0.702863,0.706178,0.780282,0.7637,0.761886,0.767018,274.0,2.0,1.0,5.0,64.0,3.0
"(128, 128)",0.5,0.0001,0.734349,0.710279,0.709885,0.711632,0.783895,0.767089,0.765347,0.769729,274.0,2.0,1.0,5.0,64.0,1.0
"(128, 128)",0.5,0.0005,0.731604,0.707587,0.707605,0.708647,0.783521,0.766668,0.765468,0.768755,274.0,2.0,1.0,5.0,64.0,2.0
"(128, 128)",0.5,0.001,0.731966,0.70803,0.707908,0.70946,0.783719,0.767037,0.765933,0.76919,274.0,2.0,1.0,5.0,64.0,0.0
"(256, 256)",0.1,0.0001,0.728345,0.705249,0.70495,0.707512,0.78189,0.765562,0.763371,0.769134,274.0,2.0,1.0,5.0,64.0,16.0


In [23]:
# Per-seed results of the best configuration from the K-fold search.
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,,0.737064,0.715041,"[0.7750126968004063, 0.5684210526315789, 0.663...",0.714244,"[0.7874097007223942, 0.5359168241965974, 0.650...",0.717635,"[0.763, 0.6051227321237994, 0.6777583187390543...",,0.78167,...,"[0.8162781509284868, 0.6969253294289898, 0.724...","(256, 256)",0.5,0.0005,0,focal,2,True,5,64
0,,0.733896,0.710191,"[0.7748159431327748, 0.5533435426237877, 0.658...",0.710334,"[0.7870036101083032, 0.5303326810176126, 0.638...",0.711549,"[0.763, 0.5784418356456777, 0.6800934033858728...",,0.783653,...,"[0.8190438561833268, 0.698389458272328, 0.7259...","(256, 256)",0.5,0.0005,1,focal,2,True,5,64
0,,0.72922,0.704538,"[0.7662172475197152, 0.5371200797209766, 0.658...",0.704033,"[0.7799067840497151, 0.5037383177570094, 0.646...",0.706924,"[0.753, 0.575240128068303, 0.670753064798599, ...",,0.783102,...,"[0.8182536546819439, 0.7108345534407028, 0.715...","(256, 256)",0.5,0.0005,12,focal,2,True,5,64
0,,0.728164,0.704507,"[0.760613810741688, 0.5442247658688867, 0.6595...",0.70649,"[0.7785340314136125, 0.5309644670050762, 0.623...",0.704778,"[0.7435, 0.5581643543223053, 0.699357851722125...",,0.785085,...,"[0.8158830501777954, 0.6874084919472914, 0.748...","(256, 256)",0.5,0.0005,123,focal,2,True,5,64
0,,0.728918,0.706336,"[0.7666497203863752, 0.5537231384307846, 0.653...",0.706008,"[0.7797311271975181, 0.5206766917293233, 0.637...",0.708736,"[0.754, 0.5912486659551761, 0.6695855224751898...",,0.785415,...,"[0.817463453180561, 0.6991215226939971, 0.7329...","(256, 256)",0.5,0.0005,1234,focal,2,True,5,64


In [24]:
# Macro F1 of the best K-fold configuration, averaged over seeds.
best_ffn_current_kfold.loc[:, "f1"].mean()

0.7081224732274419

In [25]:
# Macro precision of the best K-fold configuration, averaged over seeds.
best_ffn_current_kfold.loc[:, "precision"].mean()

0.7082219052566595

In [26]:
# Macro recall of the best K-fold configuration, averaged over seeds.
best_ffn_current_kfold.loc[:, "recall"].mean()

0.7099243095262822

In [27]:
# Per-class F1 of the best K-fold configuration, averaged over seeds.
np.stack(best_ffn_current_kfold["f1_scores"].tolist()).mean(axis=0)

array([0.76866188, 0.55136652, 0.65863017, 0.85383132])

In [28]:
# Per-class precision of the best K-fold configuration, averaged over seeds.
np.stack(best_ffn_current_kfold["precision_scores"].tolist()).mean(axis=0)

array([0.78251705, 0.5243258 , 0.63924052, 0.88680426])

In [29]:
# Per-class recall of the best K-fold configuration, averaged over seeds.
np.stack(best_ffn_current_kfold["recall_scores"].tolist()).mean(axis=0)

array([0.7553    , 0.58164354, 0.67950963, 0.82324406])