In [30]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    histories_baseline_hyperparameter_search
)

In [3]:
output_dir = "client_talk_type_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## AnnoMI

In [4]:
%run ../load_anno_mi.py

In [5]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-25 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-25 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-25 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-25 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-25 00:00:34


In [6]:
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)
    
sbert_embeddings.shape

(13551, 384)

In [7]:
client_index

[False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 True,
 True,
 False,
 True,
 False,
 True,
 False,
 T

In [8]:
client_transcript_id

tensor([  0,   0,   0,  ..., 131, 131, 131])

# Baseline: FFN using signatures

First, we dimension reduce these and then take signatures. We use the path signature as input to the FFN for classification.

We want to choose a dimension and signature depth such that the number of terms in the signature is _roughly_ 384 so that it is comparable to the number of features that we used for the previous baseline where we computed the mean of the history. Again, we are concatenating the features we obtain with the current utterance embedding.

In [9]:
num_epochs = 100
hidden_dim_sizes = [[128,128],[256,256],[512,512]]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [0, 1, 12, 123, 1234]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5

## Using log signature

In [10]:
log_signature_dimensions_and_sig_depths = [(28, 2), (10, 3), (6, 4)]

In [11]:
import signatory

[signatory.logsignature_channels(channels, depth)
 for (channels, depth) in log_signature_dimensions_and_sig_depths]

[406, 385, 406]

### Using UMAP

In [12]:
ffn_logsignature_umap_kfold, best_ffn_logsignature_umap_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    log_signature=True,
    dim_reduce_methods=["umap"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_umap_focal_{gamma}_kfold.csv",
    verbose=False
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_logsignature_umap_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_logsignature_umap_focal_2_kfold_best_model.csv


In [13]:
ffn_logsignature_umap_kfold.groupby(["dimension", "sig_depth", "hidden_dim", "dropout_rate", "learning_rate"]).mean()

  ffn_logsignature_umap_kfold.groupby(["dimension", "sig_depth", "hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim,log_signature
dimension,sig_depth,hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
6,4,"(128, 128)",0.1,0.0001,0.507727,0.360960,0.361479,0.361194,0.674561,0.605467,0.598644,0.615130,274.0,2.0,1.0,5.0,64.0,2.70,790.0,1.0
6,4,"(128, 128)",0.1,0.0005,0.521780,0.353246,0.356971,0.353992,0.677583,0.603045,0.601896,0.605461,274.0,2.0,1.0,5.0,64.0,2.80,790.0,1.0
6,4,"(128, 128)",0.1,0.0010,0.533547,0.351780,0.356994,0.353643,0.671964,0.590930,0.596559,0.588186,274.0,2.0,1.0,5.0,64.0,2.60,790.0,1.0
6,4,"(128, 128)",0.2,0.0001,0.510837,0.350623,0.353179,0.351986,0.667085,0.593311,0.590125,0.600733,274.0,2.0,1.0,5.0,64.0,2.40,790.0,1.0
6,4,"(128, 128)",0.2,0.0005,0.535985,0.351374,0.357155,0.353632,0.669950,0.579355,0.593786,0.570571,274.0,2.0,1.0,5.0,64.0,2.50,790.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28,2,"(512, 512)",0.2,0.0005,0.558482,0.421841,0.424717,0.422468,0.691214,0.622049,0.613301,0.635576,274.0,2.0,1.0,5.0,64.0,0.23,790.0,1.0
28,2,"(512, 512)",0.2,0.0010,0.552446,0.418218,0.419986,0.418650,0.684947,0.607931,0.603000,0.615136,274.0,2.0,1.0,5.0,64.0,0.21,790.0,1.0
28,2,"(512, 512)",0.5,0.0001,0.571651,0.424104,0.437790,0.418329,0.700727,0.632929,0.625617,0.642152,274.0,2.0,1.0,5.0,64.0,0.19,790.0,1.0
28,2,"(512, 512)",0.5,0.0005,0.569212,0.425719,0.432750,0.422995,0.698041,0.622801,0.619750,0.627353,274.0,2.0,1.0,5.0,64.0,0.20,790.0,1.0


In [14]:
best_ffn_logsignature_umap_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,loss_function,gamma,k_fold,n_splits,batch_size,input_dim,dimension,sig_depth,method,log_signature
0,,0.554031,0.419155,"[0.7003303445021235, 0.30224470439456214, 0.25...",0.42357,"[0.6856086856086856, 0.3087855297157623, 0.276...",0.416073,"[0.7156980950084398, 0.2959752321981424, 0.236...",,0.710241,...,focal,2,True,5,64,790,28,2,umap,True
0,,0.557994,0.420992,"[0.7042889390519187, 0.32575521644347555, 0.23...",0.424066,"[0.6941451990632318, 0.3276942355889724, 0.250...",0.418782,"[0.7147335423197492, 0.3238390092879257, 0.217...",,0.707219,...,focal,2,True,5,64,790,28,2,umap,True
0,,0.567139,0.437518,"[0.7096165191740413, 0.3223418573351279, 0.280...",0.438676,"[0.6947781885397413, 0.35298452468680913, 0.26...",0.438605,"[0.7251024837231734, 0.29659442724458207, 0.29...",,0.709121,...,focal,2,True,5,64,790,28,2,umap,True
0,,0.586344,0.448983,"[0.7226557949912639, 0.3574440480051898, 0.266...",0.456829,"[0.6989634970707526, 0.3753405994550409, 0.296...",0.443997,"[0.7480106100795756, 0.3411764705882353, 0.242...",,0.706659,...,focal,2,True,5,64,790,28,2,umap,True
0,,0.589392,0.433413,"[0.7356192194901872, 0.2915430267062315, 0.273...",0.444828,"[0.6910362364907819, 0.3635522664199815, 0.279...",0.432093,"[0.7863515794550278, 0.243343653250774, 0.2665...",,0.706995,...,focal,2,True,5,64,790,28,2,umap,True


In [15]:
best_ffn_logsignature_umap_kfold["f1"].mean()

0.4320122761881547

In [16]:
best_ffn_logsignature_umap_kfold["precision"].mean()

0.4375938257198291

In [17]:
best_ffn_logsignature_umap_kfold["recall"].mean()

0.4299098254001956

In [18]:
np.stack(best_ffn_logsignature_umap_kfold["f1_scores"]).mean(axis=0)

array([0.71450216, 0.31986577, 0.26166889])

In [19]:
np.stack(best_ffn_logsignature_umap_kfold["precision_scores"]).mean(axis=0)

array([0.69290636, 0.34567143, 0.27420368])

In [20]:
np.stack(best_ffn_logsignature_umap_kfold["recall_scores"]).mean(axis=0)

array([0.73797926, 0.30018576, 0.25156446])

### Using random projections

In [21]:
ffn_logsignature_grp_kfold, best_ffn_logsignature_grp_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    log_signature=True,
    dim_reduce_methods=["gaussian_random_projection"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_grp_focal_{gamma}_kfold.csv",
    verbose=False
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_logsignature_grp_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_logsignature_grp_focal_2_kfold_best_model.csv


In [22]:
ffn_logsignature_grp_kfold.groupby(["dimension", "sig_depth", "hidden_dim", "dropout_rate", "learning_rate"]).mean()

  ffn_logsignature_grp_kfold.groupby(["dimension", "sig_depth", "hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim,log_signature
dimension,sig_depth,hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
6,4,"(128, 128)",0.1,0.0001,0.596738,0.497316,0.492549,0.504843,0.726782,0.666325,0.656439,0.680714,274.0,2.0,1.0,5.0,64.0,2.70,790.0,1.0
6,4,"(128, 128)",0.1,0.0005,0.590520,0.492322,0.487288,0.502508,0.722171,0.664517,0.652301,0.683179,274.0,2.0,1.0,5.0,64.0,2.80,790.0,1.0
6,4,"(128, 128)",0.1,0.0010,0.583783,0.487006,0.482385,0.499384,0.720739,0.664204,0.651099,0.684598,274.0,2.0,1.0,5.0,64.0,2.60,790.0,1.0
6,4,"(128, 128)",0.2,0.0001,0.599726,0.498823,0.494740,0.505544,0.729200,0.667892,0.659051,0.680876,274.0,2.0,1.0,5.0,64.0,2.40,790.0,1.0
6,4,"(128, 128)",0.2,0.0005,0.590642,0.494794,0.489605,0.507964,0.721500,0.664249,0.651206,0.684439,274.0,2.0,1.0,5.0,64.0,2.50,790.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28,2,"(512, 512)",0.2,0.0005,0.602683,0.497789,0.495270,0.502661,0.736363,0.678770,0.668547,0.694509,274.0,2.0,1.0,5.0,64.0,0.23,790.0,1.0
28,2,"(512, 512)",0.2,0.0010,0.590764,0.492928,0.487598,0.506012,0.728774,0.672048,0.659378,0.692117,274.0,2.0,1.0,5.0,64.0,0.21,790.0,1.0
28,2,"(512, 512)",0.5,0.0001,0.612498,0.496083,0.499469,0.493388,0.742899,0.685333,0.675236,0.698494,274.0,2.0,1.0,5.0,64.0,0.19,790.0,1.0
28,2,"(512, 512)",0.5,0.0005,0.600671,0.501886,0.497436,0.508291,0.740168,0.684544,0.672763,0.700381,274.0,2.0,1.0,5.0,64.0,0.20,790.0,1.0


In [23]:
best_ffn_logsignature_grp_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,loss_function,gamma,k_fold,n_splits,batch_size,input_dim,dimension,sig_depth,method,log_signature
0,,0.607682,0.487769,"[0.7375328083989501, 0.4107657801986543, 0.315...",0.488826,"[0.7298701298701299, 0.4256308100929615, 0.310...",0.487137,"[0.7453580901856764, 0.39690402476780184, 0.31...",,0.745495,...,focal,2,True,5,64,790,28,2,gaussian_random_projection,True
0,,0.613931,0.499005,"[0.7391977145577906, 0.4161073825503355, 0.341...",0.500963,"[0.729901269393512, 0.4299867899603699, 0.3430...",0.497419,"[0.7487340245960936, 0.40309597523219814, 0.34...",,0.745831,...,focal,2,True,5,64,790,28,2,gaussian_random_projection,True
0,,0.612102,0.503115,"[0.73437876960193, 0.4318957493018926, 0.34307...",0.50277,"[0.7347332850591359, 0.43283582089552236, 0.34...",0.503472,"[0.7340245960935616, 0.4309597523219814, 0.345...",,0.743928,...,focal,2,True,5,64,790,28,2,gaussian_random_projection,True
0,,0.609511,0.489517,"[0.7385348421679572, 0.4137709137709138, 0.316...",0.490876,"[0.7297551789077212, 0.4306764902880107, 0.312...",0.48869,"[0.7475283337352303, 0.39814241486068114, 0.32...",,0.742026,...,focal,2,True,5,64,790,28,2,gaussian_random_projection,True
0,,0.604481,0.493301,"[0.7306628172741488, 0.40918969264203664, 0.34...",0.493846,"[0.728996639462314, 0.41033623910336237, 0.342...",0.49277,"[0.732336628888353, 0.40804953560371515, 0.337...",,0.74404,...,focal,2,True,5,64,790,28,2,gaussian_random_projection,True


In [24]:
best_ffn_logsignature_grp_kfold["f1"].mean()

0.49454129875840563

In [25]:
best_ffn_logsignature_grp_kfold["precision"].mean()

0.49545604731393744

In [26]:
best_ffn_logsignature_grp_kfold["recall"].mean()

0.4938975025158072

In [27]:
np.stack(best_ffn_logsignature_grp_kfold["f1_scores"]).mean(axis=0)

array([0.73606139, 0.4163459 , 0.3312166 ])

In [28]:
np.stack(best_ffn_logsignature_grp_kfold["precision_scores"]).mean(axis=0)

array([0.7306513 , 0.42589323, 0.32982361])

In [29]:
np.stack(best_ffn_logsignature_grp_kfold["recall_scores"]).mean(axis=0)

array([0.74159633, 0.40743034, 0.33266583])