In [1]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
import torch

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
from nlpsig_networks.scripts.swmhau_network_functions import (
    swmhau_network_hyperparameter_search,
)

In [4]:
output_dir = "client_talk_type_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## AnnoMI

In [5]:
%run ../load_anno_mi.py

In [6]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-10-01 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-10-01 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-10-01 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-10-01 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-10-01 00:00:34,-1


In [7]:
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

sbert_embeddings.shape

(9699, 384)

# swmhau Network

## Obtaining path by looking at post history

We can obtain a path by looking at the history of each post. Here we look at the last 10 posts (and pad with vectors of zeros if there are less than 10 posts) including the current post.

We only want to consider paths that correspond to a client's utterance as we want to model a change in mood at that time. Their history will still contain the therapist's utterances too.

In [8]:
features = ["timeline_index", "speaker"]
standardise_method = [None, None]
include_features_in_path = True
include_features_in_input = False

In [9]:
num_epochs = 100
dimensions = [15]
# define swmhau parameters: (output_channels, sig_depth, num_heads)
swmhau_parameters = [(12, 3, 10), (8, 4, 6)]
num_layers = [1]
ffn_hidden_dim_sizes = [[256, 256], [512, 512]]
dropout_rates = [0.1, 0.2]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5

In [None]:
kwargs = {
    "num_epochs": num_epochs,
    "df": anno_mi,
    "id_column": "transcript_id",
    "label_column": "client_talk_type",
    "embeddings": sbert_embeddings,
    "y_data": y_data_client,
    "output_dim": output_dim_client,
    "dimensions": dimensions,
    "log_signature": True,
    "swmhau_parameters": swmhau_parameters,
    "num_layers": num_layers,
    "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes,
    "dropout_rates": dropout_rates,
    "learning_rates": learning_rates,
    "seeds": seeds,
    "loss": loss,
    "gamma": gamma,
    "device": device,
    "features": features,
    "standardise_method": standardise_method,
    "include_features_in_path": include_features_in_path,
    "include_features_in_input": include_features_in_input,
    "path_indices": client_index,
    "split_ids": client_transcript_id,
    "k_fold": True,
    "patience": patience,
    "validation_metric": validation_metric,
    "verbose": False,
}

# w=5

In [10]:
size = 5

## UMAP

In [11]:
(
    swmhau_network_umap_kfold_5,
    best_swmhau_network_umap_kfold_5,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_5_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_5_kfold_best_model.csv


In [12]:
swmhau_network_umap_kfold_5.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_5.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.640336,0.482679,0.493165,0.476756,0.641221,0.48019,0.49084,0.474597,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.660197,0.550703,0.546765,0.562782,0.65566,0.543267,0.53961,0.554517,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.658467,0.54766,0.545894,0.562311,0.65021,0.537786,0.536211,0.553204,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.643346,0.506706,0.510292,0.505869,0.641116,0.502109,0.504308,0.502324,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.659954,0.549407,0.546009,0.560956,0.654691,0.543142,0.539284,0.55496,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.652723,0.5416,0.538661,0.554723,0.648585,0.53771,0.53442,0.553,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.640648,0.473925,0.490131,0.467486,0.640252,0.471295,0.487253,0.465227,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.66203,0.547727,0.545041,0.555928,0.657862,0.541814,0.539214,0.54991,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.657844,0.548781,0.544864,0.561378,0.651572,0.541747,0.537682,0.556031,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.649609,0.51601,0.519302,0.514667,0.642086,0.50286,0.50479,0.502279,5.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [13]:
best_swmhau_network_umap_kfold_5

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.655595,0.54893,"[0.7752996005326231, 0.5074246924056003, 0.364...",0.541514,"[0.8008940852819807, 0.5072094995759118, 0.316...",0.562501,"[0.7512903225806452, 0.5076400679117148, 0.428...",,0.650943,...,0.001,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.655387,0.547256,"[0.775890231950343, 0.4795321637426901, 0.3863...",0.54265,"[0.78590337524818, 0.5100478468899522, 0.332]",0.560186,"[0.7661290322580645, 0.4524617996604414, 0.461...",,0.652516,...,0.001,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.661823,0.550473,"[0.7834549878345499, 0.4817044928207504, 0.386...",0.548715,"[0.7879282218597063, 0.5300713557594292, 0.328...",0.563282,"[0.7790322580645161, 0.44142614601018676, 0.46...",,0.660849,...,0.001,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_umap_kfold_5["f1"].mean()

0.5488864344199431

In [15]:
best_swmhau_network_umap_kfold_5["precision"].mean()

0.5442931118519999

In [16]:
best_swmhau_network_umap_kfold_5["recall"].mean()

0.5619894905536428

In [17]:
np.stack(best_swmhau_network_umap_kfold_5["f1_scores"]).mean(axis=0)

array([0.77821494, 0.48955378, 0.37889058])

In [18]:
np.stack(best_swmhau_network_umap_kfold_5["precision_scores"]).mean(axis=0)

array([0.79157523, 0.51577623, 0.32552787])

In [19]:
np.stack(best_swmhau_network_umap_kfold_5["recall_scores"]).mean(axis=0)

array([0.76548387, 0.467176  , 0.4533086 ])

## Random Projections

In [20]:
(
    swmhau_network_grp_kfold_5,
    best_swmhau_network_grp_kfold_5,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_5_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_5_kfold_best_model.csv


In [21]:
swmhau_network_grp_kfold_5.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_5.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.651581,0.510848,0.514279,0.509714,0.650629,0.510344,0.514993,0.508165,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.665421,0.552505,0.550572,0.561456,0.659539,0.54306,0.541759,0.551483,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.653034,0.54681,0.542822,0.563689,0.648166,0.540349,0.536329,0.557713,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.652066,0.50748,0.511219,0.504862,0.648742,0.506703,0.511321,0.503527,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.664591,0.550962,0.549184,0.560563,0.658176,0.541883,0.540326,0.550894,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.660577,0.550947,0.547768,0.563603,0.653302,0.543406,0.539697,0.558374,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.663553,0.528451,0.533048,0.525671,0.654769,0.515913,0.521432,0.512261,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.653934,0.547192,0.54182,0.561383,0.64979,0.544153,0.538013,0.560044,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.651235,0.540689,0.539109,0.555534,0.646331,0.535187,0.533607,0.552316,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.664037,0.532037,0.535644,0.530511,0.657862,0.525046,0.529332,0.522942,5.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [22]:
best_swmhau_network_grp_kfold_5

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.665767,0.553378,"[0.7845232313420755, 0.4928057553956835, 0.382...",0.550471,"[0.7874553136171596, 0.5239005736137667, 0.340...",0.561552,"[0.7816129032258065, 0.46519524617996605, 0.43...",,0.658491,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.658086,0.548396,"[0.7786657967705106, 0.48490310950878773, 0.38...",0.5444,"[0.7875288683602771, 0.5168107588856868, 0.328...",0.560417,"[0.77, 0.4567062818336163, 0.45454545454545453]",,0.656447,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.659539,0.553164,"[0.7791270101739416, 0.48851590106007065, 0.39...",0.547113,"[0.7929191716766867, 0.509208103130755, 0.3392...",0.566356,"[0.7658064516129032, 0.4694397283531409, 0.463...",,0.650786,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [23]:
best_swmhau_network_grp_kfold_5["f1"].mean()

0.5516462706171555

In [24]:
best_swmhau_network_grp_kfold_5["precision"].mean()

0.5473280569283129

In [25]:
best_swmhau_network_grp_kfold_5["recall"].mean()

0.5627750916181671

In [26]:
np.stack(best_swmhau_network_grp_kfold_5["f1_scores"]).mean(axis=0)

array([0.78077201, 0.48874159, 0.38542521])

In [27]:
np.stack(best_swmhau_network_grp_kfold_5["precision_scores"]).mean(axis=0)

array([0.78930112, 0.51663981, 0.33604324])

In [28]:
np.stack(best_swmhau_network_grp_kfold_5["recall_scores"]).mean(axis=0)

array([0.77247312, 0.46378042, 0.45207174])

# w=11

In [10]:
size = 11

## UMAP

In [30]:
(
    swmhau_network_umap_kfold_11,
    best_swmhau_network_umap_kfold_11,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_11_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_11_kfold_best_model.csv


In [31]:
swmhau_network_umap_kfold_11.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_11.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.632275,0.441281,0.468892,0.436022,0.628538,0.437348,0.463303,0.432526,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.662653,0.552123,0.548232,0.562828,0.657757,0.544635,0.541017,0.55512,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.660716,0.552117,0.548836,0.56567,0.652699,0.542098,0.538745,0.556755,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.640648,0.491061,0.499184,0.487466,0.634879,0.490746,0.498685,0.487505,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.663103,0.552675,0.549544,0.56281,0.658857,0.546893,0.543389,0.55718,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.660958,0.550473,0.549149,0.56372,0.652568,0.539084,0.537691,0.553023,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.633347,0.429898,0.466953,0.424219,0.631918,0.424719,0.461246,0.419892,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.667324,0.555524,0.553232,0.564608,0.659853,0.544113,0.541934,0.551957,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.662619,0.552202,0.550366,0.564233,0.654769,0.541328,0.539849,0.554478,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.638399,0.489445,0.496294,0.485676,0.635482,0.490903,0.49753,0.487346,11.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [32]:
best_swmhau_network_umap_kfold_11

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.654349,0.555269,"[0.7724321624771101, 0.4979919678714859, 0.395...",0.548833,"[0.7980736154110767, 0.5249294449670743, 0.323...",0.576807,"[0.7483870967741936, 0.47368421052631576, 0.50...",,0.647799,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.672203,0.555992,"[0.7885016862052353, 0.5013309671694764, 0.378...",0.55508,"[0.7850975375759514, 0.525092936802974, 0.3550...",0.558672,"[0.7919354838709678, 0.4796264855687606, 0.404...",,0.667925,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.659539,0.550721,"[0.7789542270728131, 0.4851485148514852, 0.388...",0.546523,"[0.786771964461994, 0.5162835249042146, 0.3365...",0.562367,"[0.7712903225806451, 0.45755517826825126, 0.45...",,0.657233,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [33]:
best_swmhau_network_umap_kfold_11["f1"].mean()

0.553993954974317

In [34]:
best_swmhau_network_umap_kfold_11["precision"].mean()

0.5501449814188412

In [35]:
best_swmhau_network_umap_kfold_11["recall"].mean()

0.5659484768337545

In [36]:
np.stack(best_swmhau_network_umap_kfold_11["f1_scores"]).mean(axis=0)

array([0.77996269, 0.49482382, 0.38719536])

In [37]:
np.stack(best_swmhau_network_umap_kfold_11["precision_scores"]).mean(axis=0)

array([0.78998104, 0.52210197, 0.33835194])

In [38]:
np.stack(best_swmhau_network_umap_kfold_11["recall_scores"]).mean(axis=0)

array([0.77053763, 0.47028862, 0.45701917])

## Random Projections

In [11]:
(
    swmhau_network_grp_kfold_11,
    best_swmhau_network_grp_kfold_11,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_11_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_11_kfold_best_model.csv


In [12]:
swmhau_network_grp_kfold_11.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_11.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.655387,0.505974,0.517411,0.500432,0.648585,0.497945,0.51013,0.492513,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.668189,0.55488,0.552166,0.56213,0.661059,0.544093,0.542153,0.550671,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.655456,0.544671,0.540418,0.556102,0.651625,0.542867,0.537828,0.556852,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.65511,0.490277,0.510192,0.482529,0.651677,0.483062,0.505207,0.475136,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.665905,0.555109,0.55087,0.563691,0.657966,0.544336,0.540211,0.552603,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.655526,0.54675,0.542995,0.560433,0.65,0.541033,0.537215,0.556994,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.65691,0.523064,0.527179,0.521423,0.655608,0.52109,0.526343,0.518567,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.662238,0.551366,0.547431,0.560101,0.653145,0.540101,0.535649,0.548906,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.648675,0.541578,0.538018,0.559116,0.645021,0.539374,0.535605,0.559326,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.657532,0.515893,0.525331,0.511737,0.656918,0.51382,0.52481,0.509244,11.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [13]:
best_swmhau_network_grp_kfold_11

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.675109,0.561004,"[0.7927179814755669, 0.48633626679018066, 0.40...",0.561205,"[0.7849462365591398, 0.5351681957186545, 0.363...",0.566954,"[0.8006451612903226, 0.4456706281833616, 0.454...",,0.66195,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.664106,0.55745,"[0.7810709022433274, 0.4993441189331001, 0.391...",0.551558,"[0.7931493182573994, 0.5148782687105501, 0.346...",0.568303,"[0.7693548387096775, 0.4847198641765705, 0.450...",,0.663208,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.665144,0.555677,"[0.7845953002610966, 0.493368700265252, 0.3890...",0.550649,"[0.7939233817701453, 0.514760147601476, 0.3432...",0.566049,"[0.775483870967742, 0.47368421052631576, 0.448...",,0.65566,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_grp_kfold_11["f1"].mean()

0.5580437121150238

In [15]:
best_swmhau_network_grp_kfold_11["precision"].mean()

0.5544707877626771

In [16]:
best_swmhau_network_grp_kfold_11["recall"].mean()

0.5671020555158319

In [17]:
np.stack(best_swmhau_network_grp_kfold_11["f1_scores"]).mean(axis=0)

array([0.78612806, 0.49301636, 0.39498671])

In [18]:
np.stack(best_swmhau_network_grp_kfold_11["precision_scores"]).mean(axis=0)

array([0.79067298, 0.5216022 , 0.35113718])

In [19]:
np.stack(best_swmhau_network_grp_kfold_11["recall_scores"]).mean(axis=0)

array([0.78182796, 0.4680249 , 0.45145331])

# w=20

In [10]:
size = 20

## UMAP

In [11]:
(
    swmhau_network_umap_kfold_20,
    best_swmhau_network_umap_kfold_20,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_20_kfold_best_model.csv


In [12]:
swmhau_network_umap_kfold_20.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_20.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.628884,0.418432,0.466556,0.415253,0.628643,0.420592,0.465529,0.417535,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.666874,0.554198,0.551239,0.561267,0.661714,0.546086,0.54345,0.552813,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.659954,0.549599,0.54589,0.560774,0.655739,0.544818,0.54096,0.556858,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.637811,0.476903,0.496793,0.469384,0.636426,0.475085,0.493219,0.467965,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.664521,0.55267,0.549122,0.560319,0.658988,0.546128,0.542354,0.554416,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.661165,0.55231,0.549634,0.564984,0.65414,0.541847,0.539714,0.554933,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.628434,0.424375,0.468896,0.420075,0.62783,0.42648,0.468139,0.422095,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.66248,0.55414,0.548844,0.565688,0.659591,0.550466,0.545436,0.561544,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.664175,0.553284,0.55195,0.565085,0.657914,0.545221,0.543737,0.55817,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.637499,0.469893,0.494502,0.461354,0.637631,0.471364,0.493344,0.463307,20.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [13]:
best_swmhau_network_umap_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.669504,0.560952,"[0.7845232313420755, 0.4986642920747996, 0.399...",0.557306,"[0.7874553136171596, 0.5243445692883895, 0.360...",0.568658,"[0.7816129032258065, 0.47538200339558573, 0.44...",,0.668082,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.671165,0.558731,"[0.7872169584069376, 0.4967860422405877, 0.392...",0.558031,"[0.7838183562519987, 0.541, 0.3492753623188406]",0.565674,"[0.7906451612903226, 0.4592529711375212, 0.447...",,0.671541,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.668051,0.55663,"[0.7848509266720387, 0.4884135472370766, 0.396...",0.554022,"[0.784219001610306, 0.5140712945590994, 0.3637...",0.562224,"[0.785483870967742, 0.46519524617996605, 0.435...",,0.661792,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_umap_kfold_20["f1"].mean()

0.5587708818192851

In [15]:
best_swmhau_network_umap_kfold_20["precision"].mean()

0.5564533372275692

In [16]:
best_swmhau_network_umap_kfold_20["recall"].mean()

0.5655187367945068

In [17]:
np.stack(best_swmhau_network_umap_kfold_20["f1_scores"]).mean(axis=0)

array([0.78553037, 0.49462129, 0.39616098])

In [18]:
np.stack(best_swmhau_network_umap_kfold_20["precision_scores"]).mean(axis=0)

array([0.78516422, 0.52647195, 0.35772383])

In [19]:
np.stack(best_swmhau_network_umap_kfold_20["recall_scores"]).mean(axis=0)

array([0.78591398, 0.46661007, 0.44403216])

## Random Projections

In [20]:
(
    swmhau_network_grp_kfold_20,
    best_swmhau_network_grp_kfold_20,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_20_kfold_best_model.csv


In [21]:
swmhau_network_grp_kfold_20.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_20.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.66037,0.517558,0.530409,0.51372,0.655346,0.513355,0.526706,0.509408,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.663137,0.554047,0.549809,0.565347,0.657966,0.546417,0.542262,0.557405,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.656702,0.549415,0.544776,0.563831,0.650314,0.542321,0.537382,0.558114,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.663068,0.51927,0.534689,0.511983,0.660744,0.514401,0.530395,0.50685,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.661269,0.556492,0.550165,0.570295,0.65456,0.546401,0.540834,0.55988,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.659678,0.552867,0.550151,0.569339,0.651834,0.541886,0.539259,0.558618,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.671995,0.544176,0.548555,0.541094,0.665304,0.535854,0.54069,0.53209,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.6612,0.553147,0.548099,0.564522,0.65739,0.549273,0.544474,0.56221,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.658086,0.544675,0.54245,0.55402,0.654665,0.541953,0.538719,0.55273,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.667151,0.541035,0.544557,0.538753,0.661583,0.534074,0.53695,0.53212,20.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [22]:
best_swmhau_network_grp_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.661408,0.554739,"[0.7769028871391077, 0.5072020951549542, 0.380...",0.549049,"[0.7903871829105474, 0.5220125786163522, 0.334...",0.565594,"[0.7638709677419355, 0.4932088285229202, 0.439...",,0.658648,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.662445,0.551283,"[0.7835518860288165, 0.4789254284390922, 0.391...",0.549381,"[0.7864803379915503, 0.527013251783894, 0.3346...",0.563589,"[0.7806451612903226, 0.4388794567062818, 0.471...",,0.666352,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.665352,0.556871,"[0.7831736235179472, 0.4896860986547085, 0.397...",0.552659,"[0.7886817140987896, 0.5190114068441065, 0.350...",0.567117,"[0.7777419354838709, 0.4634974533106961, 0.460...",,0.656918,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [23]:
best_swmhau_network_grp_kfold_20["f1"].mean()

0.5542978044413046

In [24]:
best_swmhau_network_grp_kfold_20["precision"].mean()

0.5503628419136443

In [25]:
best_swmhau_network_grp_kfold_20["recall"].mean()

0.5654334796634092

In [26]:
np.stack(best_swmhau_network_grp_kfold_20["f1_scores"]).mean(axis=0)

array([0.78120947, 0.49193787, 0.38974607])

In [27]:
np.stack(best_swmhau_network_grp_kfold_20["precision_scores"]).mean(axis=0)

array([0.78851641, 0.52267908, 0.33989303])

In [28]:
np.stack(best_swmhau_network_grp_kfold_20["recall_scores"]).mean(axis=0)

array([0.77408602, 0.46519525, 0.45701917])

# w=35

In [10]:
size = 35

## UMAP

In [11]:
(
    swmhau_network_umap_kfold_35,
    best_swmhau_network_umap_kfold_35,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_35_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_35_kfold_best_model.csv


In [12]:
swmhau_network_umap_kfold_35.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_35.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.632794,0.444878,0.487271,0.436974,0.634696,0.442336,0.484459,0.4351,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.664694,0.556212,0.551546,0.568166,0.661268,0.551474,0.546501,0.563114,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.665975,0.554473,0.551622,0.563956,0.660194,0.547334,0.544492,0.557292,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.638987,0.479944,0.502351,0.471731,0.639439,0.47722,0.49877,0.469678,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.671857,0.55847,0.556035,0.564418,0.666483,0.552882,0.549587,0.559797,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.664279,0.550092,0.547635,0.557667,0.657442,0.541153,0.538723,0.549454,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.629368,0.445461,0.484199,0.439802,0.625524,0.437837,0.472952,0.433283,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.670992,0.560352,0.557317,0.567963,0.666693,0.554952,0.551258,0.562876,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.65947,0.549162,0.544601,0.559217,0.657573,0.548265,0.543209,0.559577,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.640821,0.484706,0.505095,0.476638,0.638339,0.478196,0.498009,0.471011,35.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [13]:
best_swmhau_network_umap_kfold_35

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.682167,0.576071,"[0.7941128901827592, 0.5139318885448917, 0.420...",0.572267,"[0.7963023029516705, 0.5364727608494921, 0.384...",0.582989,"[0.7919354838709678, 0.4932088285229202, 0.463...",,0.676887,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.669504,0.562096,"[0.7844546048334421, 0.5109553023663453, 0.390...",0.556908,"[0.7943121693121693, 0.5280797101449275, 0.348...",0.571671,"[0.7748387096774193, 0.49490662139219016, 0.44...",,0.662421,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.66037,0.55853,"[0.7742256087460659, 0.5073252406864797, 0.394...",0.550626,"[0.7957099080694586, 0.500412881915772, 0.3557...",0.569954,"[0.7538709677419355, 0.5144312393887945, 0.441...",,0.659591,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_umap_kfold_35["f1"].mean()

0.5655658685015083

In [15]:
best_swmhau_network_umap_kfold_35["precision"].mean()

0.5599333425757905

In [16]:
best_swmhau_network_umap_kfold_35["recall"].mean()

0.5748712445826197

In [17]:
np.stack(best_swmhau_network_umap_kfold_35["f1_scores"]).mean(axis=0)

array([0.78426437, 0.51073748, 0.40169576])

In [18]:
np.stack(best_swmhau_network_umap_kfold_35["precision_scores"]).mean(axis=0)

array([0.79544146, 0.52165512, 0.36270345])

In [19]:
np.stack(best_swmhau_network_umap_kfold_35["recall_scores"]).mean(axis=0)

array([0.77354839, 0.5008489 , 0.45021645])

## Random Projections

In [11]:
(
    swmhau_network_grp_kfold_35,
    best_swmhau_network_grp_kfold_35,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_35_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_35_kfold_best_model.csv


In [12]:
swmhau_network_grp_kfold_35.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_35.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.66985,0.538008,0.550643,0.530846,0.663574,0.530845,0.541324,0.524913,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.664037,0.550546,0.547861,0.557947,0.661111,0.546672,0.543252,0.554658,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.654557,0.543823,0.541344,0.558082,0.651834,0.540814,0.537611,0.556089,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.669089,0.542697,0.554786,0.53712,0.660377,0.528778,0.538442,0.524579,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.669642,0.55522,0.555044,0.561411,0.664308,0.546156,0.545782,0.552432,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.655733,0.54841,0.5463,0.566345,0.650629,0.542319,0.53978,0.56075,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.675663,0.556465,0.558316,0.555184,0.67065,0.549535,0.552078,0.547573,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.665975,0.548726,0.548732,0.554082,0.659539,0.539476,0.539164,0.545558,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.656633,0.548371,0.544669,0.562666,0.650734,0.539441,0.535934,0.554401,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.680299,0.558312,0.563176,0.555103,0.66782,0.542091,0.545997,0.539456,35.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [13]:
best_swmhau_network_grp_kfold_35

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.662445,0.548837,"[0.7775053130619586, 0.4990067540723084, 0.37]",0.552836,"[0.7882001988730527, 0.4690067214339059, 0.401...",0.547811,"[0.7670967741935484, 0.533106960950764, 0.3432...",,0.668396,...,0.0001,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.674071,0.562239,"[0.7864803379915501, 0.5142385472554686, 0.385...",0.562209,"[0.7924034053700065, 0.5004016064257029, 0.393...",0.562662,"[0.7806451612903226, 0.5288624787775892, 0.378...",,0.672013,...,0.0001,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.683205,0.574145,"[0.7963859309454662, 0.5141843971631206, 0.411...",0.571257,"[0.7966429954809554, 0.5380333951762524, 0.379...",0.579775,"[0.7961290322580645, 0.4923599320882852, 0.450...",,0.675157,...,0.0001,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_grp_kfold_35["f1"].mean()

0.5617404216802162

In [15]:
best_swmhau_network_grp_kfold_35["precision"].mean()

0.5621008220918889

In [16]:
best_swmhau_network_grp_kfold_35["recall"].mean()

0.5634157870587654

In [17]:
np.stack(best_swmhau_network_grp_kfold_35["f1_scores"]).mean(axis=0)

array([0.78679053, 0.50914323, 0.3892875 ])

In [18]:
np.stack(best_swmhau_network_grp_kfold_35["precision_scores"]).mean(axis=0)

array([0.79241553, 0.50248057, 0.39140636])

In [19]:
np.stack(best_swmhau_network_grp_kfold_35["recall_scores"]).mean(axis=0)

array([0.78129032, 0.51810979, 0.39084725])

# w=80

In [10]:
size = 80

## UMAP

In [None]:
(
    swmhau_network_umap_kfold_80,
    best_swmhau_network_umap_kfold_80,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


In [21]:
swmhau_network_umap_kfold_80.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_80.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.644281,0.498125,0.513844,0.490017,0.63522,0.493172,0.505498,0.487687,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.678154,0.572334,0.568054,0.583155,0.674319,0.567209,0.562996,0.578347,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.664556,0.555362,0.551562,0.567121,0.658962,0.548342,0.544464,0.560796,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.652792,0.521615,0.529838,0.516833,0.642112,0.512621,0.517075,0.51053,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.673967,0.567223,0.562548,0.577483,0.670597,0.562947,0.55843,0.573907,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.660162,0.551887,0.548977,0.566218,0.654219,0.546552,0.542768,0.562933,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.640648,0.493679,0.511271,0.485908,0.635613,0.489608,0.504526,0.483769,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.675005,0.573806,0.566652,0.587171,0.670073,0.566336,0.559317,0.579718,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.664418,0.553735,0.549902,0.563561,0.660246,0.548738,0.544807,0.55977,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.647429,0.517836,0.523083,0.515629,0.643213,0.515622,0.51907,0.515059,80.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [22]:
best_swmhau_network_umap_kfold_80

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.683621,0.58303,"[0.7889480692410119, 0.5590964098426786, 0.401...",0.575317,"[0.8149931224209078, 0.5326671790930054, 0.378...",0.593172,"[0.7645161290322581, 0.5882852292020373, 0.426...",,0.68522,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.675524,0.587565,"[0.779539641943734, 0.5592, 0.42395587076438135]",0.574667,"[0.8267631103074141, 0.5287443267776096, 0.368...",0.609957,"[0.7374193548387097, 0.5933786078098472, 0.499...",,0.672327,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.674486,0.580257,"[0.7821881254169446, 0.5377796538623892, 0.420...",0.570112,"[0.8097375690607734, 0.5348446683459278, 0.365...",0.59752,"[0.7564516129032258, 0.5407470288624788, 0.495...",,0.668396,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [23]:
best_swmhau_network_umap_kfold_80["f1"].mean()

0.5836175290083189

In [24]:
best_swmhau_network_umap_kfold_80["precision"].mean()

0.5733651138924793

In [25]:
best_swmhau_network_umap_kfold_80["recall"].mean()

0.6002164712157437

In [26]:
np.stack(best_swmhau_network_umap_kfold_80["f1_scores"]).mean(axis=0)

array([0.78355861, 0.55202535, 0.41526862])

In [27]:
np.stack(best_swmhau_network_umap_kfold_80["precision_scores"]).mean(axis=0)

array([0.8171646 , 0.53208539, 0.37084535])

In [28]:
np.stack(best_swmhau_network_umap_kfold_80["recall_scores"]).mean(axis=0)

array([0.7527957 , 0.57413696, 0.47371676])

## Random Projections

In [11]:
(
    swmhau_network_grp_kfold_80,
    best_swmhau_network_grp_kfold_80,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_80_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_80_kfold_best_model.csv


In [12]:
swmhau_network_grp_kfold_80.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_80.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.683344,0.559982,0.56966,0.553257,0.67804,0.555291,0.562591,0.550402,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.667843,0.555469,0.554287,0.56435,0.66326,0.550932,0.548621,0.560832,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.660162,0.546853,0.54638,0.559374,0.652411,0.537385,0.536401,0.551516,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.69054,0.579643,0.581913,0.577868,0.679245,0.567385,0.567199,0.568047,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.66812,0.557632,0.555359,0.56763,0.66153,0.550069,0.546716,0.560739,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.661961,0.550788,0.54887,0.562912,0.654455,0.541427,0.539714,0.554809,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.692686,0.587255,0.584524,0.59103,0.686164,0.576511,0.574568,0.579997,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.665975,0.554854,0.552237,0.565078,0.65739,0.543266,0.541019,0.553042,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.660785,0.544507,0.543439,0.552603,0.65718,0.541564,0.540004,0.551487,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.691094,0.584235,0.582188,0.587106,0.687579,0.579874,0.57657,0.584165,80.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [13]:
best_swmhau_network_grp_kfold_80

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.706249,0.604602,"[0.809685230024213, 0.5535405872193436, 0.4505...",0.602276,"[0.8103392568659128, 0.5632688927943761, 0.433...",0.607521,"[0.8090322580645162, 0.5441426146010186, 0.469...",,0.695755,...,0.0001,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.696699,0.603275,"[0.7966353290450272, 0.5614035087719298, 0.451...",0.596988,"[0.8150523118461019, 0.5404556166535742, 0.435...",0.61082,"[0.7790322580645161, 0.5840407470288624, 0.469...",,0.684591,...,0.0001,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.718705,0.618968,"[0.8181382723934219, 0.5656292286874153, 0.473...",0.619119,"[0.8099905153335442, 0.6034648700673725, 0.443...",0.621734,"[0.8264516129032258, 0.532258064516129, 0.5064...",,0.702516,...,0.0001,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_grp_kfold_80["f1"].mean()

0.6089481769191823

In [15]:
best_swmhau_network_grp_kfold_80["precision"].mean()

0.6061276878691522

In [16]:
best_swmhau_network_grp_kfold_80["recall"].mean()

0.6133585079862063

In [17]:
np.stack(best_swmhau_network_grp_kfold_80["f1_scores"]).mean(axis=0)

array([0.80815294, 0.56019111, 0.45850048])

In [18]:
np.stack(best_swmhau_network_grp_kfold_80["precision_scores"]).mean(axis=0)

array([0.81179403, 0.56906313, 0.43752591])

In [19]:
np.stack(best_swmhau_network_grp_kfold_80["recall_scores"]).mean(axis=0)

array([0.80483871, 0.55348048, 0.48175634])

## w=110

In [10]:
size = 110

## UMAP

In [11]:
(
    swmhau_network_umap_kfold_110,
    best_swmhau_network_umap_kfold_110,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_110_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_umap_focal_2_110_kfold_best_model.csv


In [12]:
swmhau_network_umap_kfold_110.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_110.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.648606,0.519951,0.52766,0.516412,0.644785,0.517025,0.523248,0.515343,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.679642,0.57798,0.571494,0.588586,0.676205,0.573867,0.566975,0.58533,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.667255,0.556191,0.5532,0.565457,0.664675,0.552637,0.549413,0.562505,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.656218,0.534601,0.53553,0.534669,0.652594,0.532145,0.531769,0.53386,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.680472,0.576095,0.570975,0.587012,0.677175,0.571532,0.566479,0.583149,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.667116,0.554988,0.55404,0.565976,0.660482,0.548046,0.546293,0.560175,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.644973,0.513467,0.520106,0.510071,0.638941,0.509598,0.514362,0.508022,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.682929,0.583782,0.57607,0.596204,0.674476,0.573373,0.565291,0.587259,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.666321,0.557572,0.55444,0.569741,0.659984,0.550382,0.546843,0.5637,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.659747,0.539178,0.540799,0.538998,0.653931,0.533845,0.533562,0.536101,110.0,17.0,...,384.0,0.0,1.0,0.15,45.333333,2.0,1.0,5.0,64.0,32.5


In [13]:
best_swmhau_network_umap_kfold_110

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.687772,0.592274,"[0.78848433530906, 0.5852516582130317, 0.40308...",0.582088,"[0.8299465240641711, 0.5415162454873647, 0.374...",0.607878,"[0.7509677419354839, 0.6366723259762309, 0.435...",,0.681918,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.689018,0.596621,"[0.7953294412010009, 0.556701030927835, 0.4378...",0.586243,"[0.8234887737478411, 0.5634782608695652, 0.371...",0.617195,"[0.7690322580645161, 0.5500848896434635, 0.532...",,0.68978,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.692962,0.587655,"[0.7998700454840807, 0.5415617128463477, 0.421...",0.585355,"[0.805628272251309, 0.5357142857142857, 0.4147...",0.590101,"[0.7941935483870968, 0.5475382003395586, 0.428...",,0.692924,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_umap_kfold_110["f1"].mean()

0.5921835848863831

In [15]:
best_swmhau_network_umap_kfold_110["precision"].mean()

0.5845618201825564

In [16]:
best_swmhau_network_umap_kfold_110["recall"].mean()

0.6050578338038924

In [17]:
np.stack(best_swmhau_network_umap_kfold_110["f1_scores"]).mean(axis=0)

array([0.79456127, 0.56117147, 0.42081801])

In [18]:
np.stack(best_swmhau_network_umap_kfold_110["precision_scores"]).mean(axis=0)

array([0.81968786, 0.54690293, 0.38709467])

In [19]:
np.stack(best_swmhau_network_umap_kfold_110["recall_scores"]).mean(axis=0)

array([0.77139785, 0.57809847, 0.46567718])

## Random Projections

In [11]:
(
    swmhau_network_grp_kfold_110,
    best_swmhau_network_grp_kfold_110,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    **kwargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_110_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swmhau_network_grp_focal_2_110_kfold_best_model.csv


In [12]:
swmhau_network_grp_kfold_110.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_110.groupby([


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,include_features_in_input,embedding_dim,num_features,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.687496,0.58321,0.580269,0.586901,0.681709,0.577239,0.571985,0.584046,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.668812,0.562205,0.557159,0.572997,0.659277,0.548973,0.544104,0.559261,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.664868,0.549473,0.548217,0.557809,0.657285,0.541697,0.540264,0.551199,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(256, 256)",0.2,0.0001,0.684451,0.576915,0.574665,0.580485,0.684486,0.578815,0.575077,0.583804,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.2,0.0005,0.668466,0.5617,0.557885,0.573168,0.664675,0.554977,0.551429,0.566672,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.2,0.001,0.660439,0.547591,0.545932,0.557678,0.654036,0.53965,0.537938,0.550905,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.696007,0.5956,0.589621,0.603255,0.692138,0.590505,0.583702,0.599856,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.66376,0.552982,0.55057,0.564566,0.658229,0.546871,0.543753,0.559158,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,20.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.663622,0.549712,0.547824,0.558054,0.655451,0.538929,0.536629,0.548085,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,18.0
15,8,4,6,1,"(512, 512)",0.2,0.0001,0.701474,0.596601,0.596297,0.597832,0.699633,0.595201,0.592891,0.599354,110.0,17.0,...,0.0,384.0,0.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0


In [13]:
best_swmhau_network_grp_kfold_110

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.701474,0.603205,"[0.8034719947592533, 0.5651991614255765, 0.440...",0.597221,"[0.8160345974717232, 0.5584092792046396, 0.417...",0.610326,"[0.7912903225806451, 0.5721561969439728, 0.467...",,0.700786,...,0.0001,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.687357,0.589001,"[0.7909849749582637, 0.5639396346306592, 0.412...",0.581601,"[0.8197231833910035, 0.5298507462686567, 0.395...",0.599112,"[0.7641935483870967, 0.6027164685908319, 0.430...",,0.679245,...,0.0001,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.713099,0.610666,"[0.8145513338722717, 0.5733964700817908, 0.444...",0.608028,"[0.8165316045380875, 0.5816593886462882, 0.425...",0.613923,"[0.8125806451612904, 0.565365025466893, 0.4638...",,0.70739,...,0.0001,123,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swmhau_network_grp_kfold_110["f1"].mean()

0.6009573708823165

In [15]:
best_swmhau_network_grp_kfold_110["precision"].mean()

0.5956168559694656

In [16]:
best_swmhau_network_grp_kfold_110["recall"].mean()

0.6077870314663912

In [17]:
np.stack(best_swmhau_network_grp_kfold_110["f1_scores"]).mean(axis=0)

array([0.80300277, 0.56751176, 0.43235759])

In [18]:
np.stack(best_swmhau_network_grp_kfold_110["precision_scores"]).mean(axis=0)

array([0.8174298 , 0.5566398 , 0.41278097])

In [19]:
np.stack(best_swmhau_network_grp_kfold_110["recall_scores"]).mean(axis=0)

array([0.78935484, 0.58007923, 0.45392703])