In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced again in the visible notebook cells
# (the hyperparameter search below uses the `seeds` list) — confirm whether it
# is still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
)

In [3]:
# Directory that receives the CSV files written by the hyperparameter searches.
output_dir = "client_talk_type_output"
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of testing os.path.isdir() before calling os.makedirs().
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Execute the AnnoMI loading script and import its variables into this kernel.
# NOTE(review): `anno_mi`, `client_index`, `y_data_client`, `output_dim_client`
# and `client_transcript_id` are used below but not defined in any visible cell,
# so they are presumably created by this script — TODO confirm.
%run ../load_anno_mi.py

In [5]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-14 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-14 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-14 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-14 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-14 00:00:34


In [6]:
# Load the precomputed SBERT sentence embeddings from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file if it comes from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)
    
# Display the shape of the loaded embeddings as a quick sanity check.
sbert_embeddings.shape

(13551, 384)

# Baseline: feed-forward network (FFN)

We feed the SBERT sentence embeddings directly into a feed-forward network (FFN) to predict the client talk type.

In [7]:
# --- Training schedule -------------------------------------------------------
num_epochs = 100
patience = 5  # forwarded as `patience`; presumably early-stopping patience — TODO confirm

# --- Hyperparameter grid explored by the search --------------------------------
# Two hidden layers of equal width for each candidate width.
hidden_dim_sizes = [[width, width] for width in (128, 256)]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [0.001, 0.0001, 0.0005]

# --- Repetitions ---------------------------------------------------------------
# Every configuration is run once per seed.
seeds = [0, 1, 12, 123, 1234]

# --- Objective and model selection --------------------------------------------
loss = "focal"
gamma = 2  # forwarded as `gamma` alongside the loss name; also used in the output file names
validation_metric = "f1"  # metric name forwarded to the search as `validation_metric`

In [8]:
hidden_dim_sizes

[[128, 128], [256, 256]]

In [9]:
learning_rates

[0.001, 0.0001, 0.0005]

We use the `ffn_hyperparameter_search` function which loops through the different hidden dimensions, dropout rates and learning rates to find the best model for the validation set. We evaluate the model on several seeds and average the performance over the seeds.

In [10]:
# Hyperparameter search for the FFN baseline on a single (non K-fold) split.
# Every combination of hidden sizes, dropout rate and learning rate is run once
# per entry in `seeds`; the full results and the best model's results are also
# written to CSV (path given by `results_output`).
#
# Only the first two return values are used in this notebook. The remaining two
# are bound to explicit throwaway names instead of `_` / `__`, because IPython
# uses those names for the most recent cell outputs and stops refreshing them
# once user code assigns to them.
ffn_current, best_ffn_current, _ffn_extra_a, _ffn_extra_b = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_client,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=False,
    patience=patience,
    split_ids=client_transcript_id,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False,
)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_best_model.csv


In [11]:
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,0.637251,0.616200,0.522944,"[0.7274984481688391, 0.5, 0.3413333333333333]",0.513077,"[0.7834224598930482, 0.47262247838616717, 0.28...",0.546434,"[0.6790266512166859, 0.5307443365695793, 0.429...",0.723316,0.551659,...,"(128, 128)",0.5,0.0010,0,focal,2,False,,64,0
0,0.635756,0.636639,0.530705,"[0.7475845410628019, 0.5092024539877301, 0.335...",0.522416,"[0.7805800756620429, 0.4839650145772595, 0.302...",0.543440,"[0.7172653534183082, 0.5372168284789643, 0.375...",0.733269,0.577251,...,"(128, 128)",0.5,0.0010,1,focal,2,False,,64,0
0,0.615039,0.638910,0.539896,"[0.75, 0.4919093851132686, 0.3777777777777778]",0.531070,"[0.7790262172284644, 0.4919093851132686, 0.322...",0.557115,"[0.7230590961761297, 0.4919093851132686, 0.456...",0.711731,0.597156,...,"(128, 128)",0.5,0.0010,12,focal,2,False,,64,0
0,0.611566,0.632854,0.534817,"[0.7428222357971899, 0.5116279069767441, 0.35]",0.525060,"[0.7855297157622739, 0.49107142857142855, 0.29...",0.553773,"[0.7045191193511008, 0.5339805825242718, 0.422...",0.703668,0.579147,...,"(128, 128)",0.5,0.0010,123,focal,2,False,,64,0
0,0.621929,0.632097,0.533083,"[0.741011578305911, 0.5052950075642965, 0.3529...",0.523353,"[0.781491002570694, 0.4744318181818182, 0.3141...",0.549219,"[0.7045191193511008, 0.540453074433657, 0.4026...",0.719943,0.570616,...,"(128, 128)",0.5,0.0010,1234,focal,2,False,,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.622785,0.632854,0.536788,"[0.7463235294117648, 0.5056, 0.35844155844155845]",0.528103,"[0.7919375812743823, 0.5, 0.2923728813559322]",0.560031,"[0.7056778679026651, 0.511326860841424, 0.4630...",0.715886,0.588626,...,"(256, 256)",0.1,0.0005,0,focal,2,False,,64,17
0,0.621736,0.623013,0.529769,"[0.7288449660284126, 0.5088757396449703, 0.351...",0.519056,"[0.7804232804232805, 0.46866485013623976, 0.30...",0.549897,"[0.6836616454229433, 0.5566343042071198, 0.409...",0.720500,0.573460,...,"(256, 256)",0.1,0.0005,1,focal,2,False,,64,17
0,0.632554,0.611658,0.508512,"[0.7301587301587301, 0.4676923076923077, 0.327...",0.500096,"[0.7716129032258064, 0.44574780058651026, 0.28...",0.524701,"[0.6929316338354577, 0.4919093851132686, 0.389...",0.730346,0.571564,...,"(256, 256)",0.1,0.0005,12,focal,2,False,,64,17
0,0.610736,0.653293,0.535385,"[0.7671711292200233, 0.5017064846416383, 0.337...",0.534344,"[0.7707602339181286, 0.5306859205776173, 0.301...",0.540631,"[0.7636152954808807, 0.47572815533980584, 0.38...",0.719530,0.595261,...,"(256, 256)",0.1,0.0005,123,focal,2,False,,64,17


In [12]:
# Average every numeric metric over the seeds for each hyperparameter combination.
# numeric_only=True is passed explicitly because the frame also holds non-numeric
# columns (list-valued per-class scores, `loss_function`, ...): relying on the
# default emits a FutureWarning on pandas 1.5 and raises on pandas >= 2.0.
ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(128, 128)",0.1,0.0001,0.625519,0.627858,0.523414,0.515132,0.53823,0.721451,0.582749,0.506509,0.502218,0.516595,274.0,2.0,0.0,64.0,7.0
"(128, 128)",0.1,0.0005,0.621501,0.62922,0.526017,0.517875,0.541612,0.721479,0.583886,0.506266,0.503091,0.51723,274.0,2.0,0.0,64.0,8.0
"(128, 128)",0.1,0.001,0.624819,0.626646,0.522916,0.515997,0.538216,0.725505,0.577441,0.499081,0.498508,0.510319,274.0,2.0,0.0,64.0,6.0
"(128, 128)",0.2,0.0001,0.624303,0.62922,0.526532,0.518033,0.541426,0.721702,0.583507,0.50458,0.500657,0.514134,274.0,2.0,0.0,64.0,4.0
"(128, 128)",0.2,0.0005,0.625194,0.632551,0.525395,0.518032,0.538174,0.726306,0.583697,0.50432,0.501007,0.513803,274.0,2.0,0.0,64.0,5.0
"(128, 128)",0.2,0.001,0.625463,0.627252,0.524187,0.517152,0.539588,0.72358,0.578768,0.499506,0.498692,0.510902,274.0,2.0,0.0,64.0,3.0
"(128, 128)",0.5,0.0001,0.628589,0.628312,0.520393,0.513558,0.534173,0.717527,0.585592,0.504722,0.501229,0.515303,274.0,2.0,0.0,64.0,1.0
"(128, 128)",0.5,0.0005,0.62379,0.628463,0.530722,0.521539,0.549639,0.71536,0.572512,0.499998,0.496218,0.51525,274.0,2.0,0.0,64.0,2.0
"(128, 128)",0.5,0.001,0.624308,0.63134,0.532289,0.522995,0.549996,0.718385,0.575166,0.50181,0.498213,0.5169,274.0,2.0,0.0,64.0,0.0
"(256, 256)",0.1,0.0001,0.620256,0.628463,0.523982,0.516164,0.538083,0.722609,0.587867,0.509627,0.505583,0.518271,274.0,2.0,0.0,64.0,16.0


In [13]:
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.622476,0.629069,0.526388,"[0.7410551849605822, 0.5, 0.33810888252148996]",0.51765,"[0.77735368956743, 0.48059701492537316, 0.295]",0.541668,"[0.7079953650057937, 0.5210355987055016, 0.395...",0.721607,0.588626,...,"[0.6807817589576547, 0.4805194805194805, 0.413...","(256, 256)",0.2,0.0001,0,focal,2,False,,64
0,0.624013,0.619228,0.5132,"[0.7357013847080072, 0.4715189873417722, 0.332...",0.505655,"[0.7656641604010025, 0.4613003095975232, 0.29]",0.526486,"[0.7079953650057937, 0.48220064724919093, 0.38...",0.717619,0.589573,...,"[0.6938110749185668, 0.4577922077922078, 0.413...","(256, 256)",0.2,0.0001,1,focal,2,False,,64
0,0.62424,0.625284,0.517129,"[0.7432675044883305, 0.4738562091503268, 0.334...",0.510942,"[0.7685643564356436, 0.47854785478547857, 0.28...",0.530508,"[0.7195828505214369, 0.4692556634304207, 0.402...",0.719808,0.591469,...,"[0.7084690553745928, 0.4318181818181818, 0.421...","(256, 256)",0.2,0.0001,12,focal,2,False,,64
0,0.61889,0.630583,0.522057,"[0.7459653317393903, 0.483974358974359, 0.3362...",0.515218,"[0.7703703703703704, 0.4793650793650794, 0.295...",0.533665,"[0.7230590961761297, 0.4886731391585761, 0.389...",0.721754,0.582938,...,"[0.6970684039087948, 0.43506493506493504, 0.39...","(256, 256)",0.2,0.0001,123,focal,2,False,,64
0,0.621568,0.635882,0.532235,"[0.7468277945619335, 0.5015290519877676, 0.348...",0.523628,"[0.7803030303030303, 0.4753623188405797, 0.315...",0.545371,"[0.7161066048667439, 0.5307443365695793, 0.389...",0.729986,0.588626,...,"[0.6856677524429967, 0.4837662337662338, 0.383...","(256, 256)",0.2,0.0001,1234,focal,2,False,,64


In [14]:
best_ffn_current["f1"].mean()

0.5222016656517394

In [15]:
best_ffn_current["precision"].mean()

0.5146185485971388

In [16]:
best_ffn_current["recall"].mean()

0.5355394413139534

In [17]:
np.stack(best_ffn_current["f1_scores"]).mean(axis=0)

array([0.74256344, 0.48617572, 0.33786584])

In [18]:
np.stack(best_ffn_current["precision_scores"]).mean(axis=0)

array([0.77245112, 0.47503452, 0.29637001])

In [19]:
np.stack(best_ffn_current["recall_scores"]).mean(axis=0)

array([0.71494786, 0.49838188, 0.39328859])

## KFold

We can repeat the same search using K-fold evaluation instead; by default, $K=5$ folds are used.

In [20]:
# Same hyperparameter search as the single-split run, but with K-fold evaluation
# enabled (`k_fold=True`). Every combination of hidden sizes, dropout rate and
# learning rate is run once per entry in `seeds`; the full results and the best
# model's results are also written to CSV (path given by `results_output`).
#
# Only the first two return values are used in this notebook. The remaining two
# are bound to explicit throwaway names instead of `_` / `__`, because IPython
# uses those names for the most recent cell outputs and stops refreshing them
# once user code assigns to them.
ffn_current_kfold, best_ffn_current_kfold, _kfold_extra_a, _kfold_extra_b = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_client,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=True,
    patience=patience,
    split_ids=client_transcript_id,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False,
)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv


In [21]:
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,,0.623533,0.522106,"[0.7494803765741533, 0.44414168937329696, 0.37...",0.521112,"[0.7601686507936508, 0.49356548069644207, 0.30...",0.536963,"[0.7390884977091874, 0.4037151702786378, 0.468...",,0.683268,...,"(128, 128)",0.5,0.0010,0,focal,2,True,5,64,0
0,,0.617132,0.515195,"[0.743109151047409, 0.44414535666218036, 0.358...",0.513753,"[0.7552290836653387, 0.4863669859985262, 0.299...",0.528533,"[0.7313720761996624, 0.4086687306501548, 0.445...",,0.679239,...,"(128, 128)",0.5,0.0010,1,focal,2,True,5,64,0
0,,0.624143,0.518046,"[0.7502427184466018, 0.44296197464976655, 0.36...",0.516224,"[0.755191790862448, 0.48011569052783803, 0.313...",0.527345,"[0.7453580901856764, 0.41114551083591333, 0.42...",,0.680918,...,"(128, 128)",0.5,0.0010,12,focal,2,True,5,64,0
0,,0.623685,0.527732,"[0.7443702053947043, 0.47109067017082784, 0.36...",0.524257,"[0.7644218551461245, 0.5017494751574527, 0.306...",0.542877,"[0.7253436218953461, 0.4439628482972136, 0.459...",,0.677448,...,"(128, 128)",0.5,0.0010,123,focal,2,True,5,64,0
0,,0.622161,0.519448,"[0.7481751824817519, 0.43789035392088826, 0.37...",0.520314,"[0.7549717652835748, 0.4980268350434096, 0.307...",0.534267,"[0.7414998794309139, 0.3907120743034056, 0.470...",,0.683044,...,"(128, 128)",0.5,0.0010,1234,focal,2,True,5,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.615455,0.509210,"[0.7442088491236672, 0.44415329184408775, 0.33...",0.506616,"[0.7567298105682951, 0.4714881780250348, 0.291...",0.519139,"[0.7320954907161804, 0.4198142414860681, 0.405...",,0.677112,...,"(256, 256)",0.1,0.0005,0,focal,2,True,5,64,17
0,,0.605853,0.506332,"[0.7327693456083603, 0.4424778761061947, 0.34375]",0.501053,"[0.7568748393729119, 0.45190445448676564, 0.29...",0.518868,"[0.7101519170484688, 0.43343653250773995, 0.41...",,0.674538,...,"(256, 256)",0.1,0.0005,1,focal,2,True,5,64,17
0,,0.617284,0.520226,"[0.741308919955462, 0.4478964401294499, 0.3714...",0.515038,"[0.7611788617886179, 0.46915254237288134, 0.31...",0.534666,"[0.7224499638292742, 0.4284829721362229, 0.453...",,0.684275,...,"(256, 256)",0.1,0.0005,12,focal,2,True,5,64,17
0,,0.614693,0.521874,"[0.7370263848943354, 0.45692405871091263, 0.37...",0.515740,"[0.7654545454545455, 0.47136273864384465, 0.31...",0.539019,"[0.7106341933928141, 0.443343653250774, 0.4630...",,0.675210,...,"(256, 256)",0.1,0.0005,123,focal,2,True,5,64,17


In [22]:
# Average every numeric metric over the seeds for each hyperparameter combination
# of the K-fold search.
# numeric_only=True is passed explicitly because the frame also holds non-numeric
# columns (list-valued per-class scores, `loss_function`, ...): relying on the
# default emits a FutureWarning on pandas 1.5 and raises on pandas >= 2.0.
ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
"(128, 128)",0.1,0.0001,0.624326,0.522265,0.518262,0.531748,0.680425,0.603583,0.600687,0.611739,274.0,2.0,1.0,5.0,64.0,7.0
"(128, 128)",0.1,0.0005,0.616949,0.518128,0.513417,0.530677,0.677068,0.604202,0.599074,0.617787,274.0,2.0,1.0,5.0,64.0,8.0
"(128, 128)",0.1,0.001,0.61576,0.51838,0.513756,0.533041,0.674717,0.602277,0.59634,0.617416,274.0,2.0,1.0,5.0,64.0,6.0
"(128, 128)",0.2,0.0001,0.626703,0.525416,0.521354,0.535419,0.681231,0.605319,0.602135,0.614316,274.0,2.0,1.0,5.0,64.0,4.0
"(128, 128)",0.2,0.0005,0.618839,0.515949,0.512417,0.526095,0.680582,0.606311,0.602377,0.617667,274.0,2.0,1.0,5.0,64.0,5.0
"(128, 128)",0.2,0.001,0.614479,0.51698,0.512375,0.531819,0.675837,0.603931,0.598065,0.619645,274.0,2.0,1.0,5.0,64.0,3.0
"(128, 128)",0.5,0.0001,0.628166,0.526122,0.522625,0.536016,0.682843,0.606357,0.604082,0.614939,274.0,2.0,1.0,5.0,64.0,1.0
"(128, 128)",0.5,0.0005,0.621521,0.518934,0.516869,0.530565,0.684477,0.609299,0.607495,0.62035,274.0,2.0,1.0,5.0,64.0,2.0
"(128, 128)",0.5,0.001,0.622131,0.520505,0.519132,0.533997,0.680783,0.60625,0.603863,0.619974,274.0,2.0,1.0,5.0,64.0,0.0
"(256, 256)",0.1,0.0001,0.62902,0.525246,0.522288,0.533646,0.685574,0.609094,0.607242,0.616003,274.0,2.0,1.0,5.0,64.0,16.0


In [23]:
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,,0.628258,0.522119,"[0.7515195720885, 0.46014257939079717, 0.35469...",0.519319,"[0.7577837705319931, 0.48266485384092456, 0.31...",0.528913,"[0.7453580901856764, 0.43962848297213625, 0.40...",,0.686961,...,"[0.7892636924192964, 0.50355871886121, 0.55754...","(256, 256)",0.5,0.0001,0,focal,2,True,5,64
0,,0.623228,0.516974,"[0.7488417459156304, 0.45150827116445025, 0.35...",0.513888,"[0.7573366214549938, 0.47411444141689374, 0.31...",0.524833,"[0.7405353267422233, 0.4309597523219814, 0.403...",,0.688192,...,"[0.7870874138556402, 0.5111209964412812, 0.562...","(256, 256)",0.5,0.0001,1,focal,2,True,5,64
0,,0.63283,0.532779,"[0.7537295182196136, 0.4653883652908678, 0.379...",0.528592,"[0.7645745472587447, 0.4897400820793434, 0.331...",0.543195,"[0.7431878466361225, 0.443343653250774, 0.4430...",,0.685506,...,"[0.7863619876677548, 0.5048932384341637, 0.557...","(256, 256)",0.5,0.0001,12,focal,2,True,5,64
0,,0.625514,0.517296,"[0.750182437363172, 0.4589107315501128, 0.3427...",0.514641,"[0.7568098159509202, 0.478494623655914, 0.3086...",0.52334,"[0.7436701229804679, 0.4408668730650155, 0.385...",,0.685506,...,"[0.7867247007616975, 0.5071174377224199, 0.551...","(256, 256)",0.5,0.0001,123,focal,2,True,5,64
0,,0.631763,0.527912,"[0.7538311846266116, 0.46208916368369674, 0.36...",0.524769,"[0.7604907975460122, 0.4869684499314129, 0.326...",0.535814,"[0.7472871955630577, 0.43962848297213625, 0.42...",,0.68573,...,"[0.7876314834965542, 0.5084519572953736, 0.546...","(256, 256)",0.5,0.0001,1234,focal,2,True,5,64


In [24]:
best_ffn_current_kfold["f1"].mean()

0.52341592796879

In [25]:
best_ffn_current_kfold["precision"].mean()

0.5202415683845569

In [26]:
best_ffn_current_kfold["recall"].mean()

0.5312188732186052

In [27]:
np.stack(best_ffn_current_kfold["f1_scores"]).mean(axis=0)

array([0.75162089, 0.45960782, 0.35901907])

In [28]:
np.stack(best_ffn_current_kfold["precision_scores"]).mean(axis=0)

array([0.75939911, 0.48239649, 0.3189291 ])

In [29]:
np.stack(best_ffn_current_kfold["recall_scores"]).mean(axis=0)

array([0.74400772, 0.43888545, 0.41076345])