In [1]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
import torch

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
from nlpsig_networks.scripts.swmhau_network_functions import (
    swmhau_network_hyperparameter_search
)

In [4]:
output_dir = "rumours_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## Rumours

In [5]:
%run load_sbert-embeddings.py

In [6]:
df_rumours.head()

Unnamed: 0,id,label,datetime,text,timeline_id,set
0,5.249902e+17,0,2014-10-22 18:26:23,Police have clarified that there were two shoo...,0,train
1,5.249906e+17,0,2014-10-22 18:27:58,"@CTVNews you guys ""confirmed"" there were 3 sho...",0,train
2,5.249908e+17,1,2014-10-22 18:28:46,@CTVNews get it right. http://t.co/GHYxMuzPG9,0,train
3,5.249927e+17,1,2014-10-22 18:36:29,RT @CTVNews Police have clarified that there w...,0,train
4,5.250038e+17,1,2014-10-22 19:20:41,@CTVNews @ctvsaskatoon so what happened at Rid...,0,train


# swmhau Network

In [7]:
features = ["time_encoding", "timeline_index"]
standardise_method = ["z_score", None]
include_features_in_path = True

In [8]:
num_epochs = 100
dimensions = [15] # [50, 15]
# define swmhau parameters: (output_channels, sig_depth, num_heads)
swmhau_parameters = [(12, 3, 10), (8, 4, 6), (8, 4, 12)]
num_layers = [1]
ffn_hidden_dim_sizes = [[128,128],[256,256],[512,512]]
dropout_rates = [0.5, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5
split_indices = (df_rumours[df_rumours['set']=='train'].index,
                 df_rumours[df_rumours['set']=='dev'].index,
                 df_rumours[df_rumours['set']=='test'].index)

## UMAP

In [9]:
size = 35
swmhau_network_umap, best_swmhau_network_umap, _, __ = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=df_rumours,
    id_column="timeline_id",
    label_column="label",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    split_indices=split_indices,
    k_fold=False,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv",
    verbose=False
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_umap_focal_2_35.csv
saving the best model results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_umap_focal_2_35_best_model.csv


In [10]:
swmhau_network_umap.groupby(["dimensions",
                             "output_channels",
                             "sig_depth",
                             "num_heads",
                             "num_layers",
                             "ffn_hidden_dim",
                             "learning_rate"]).mean()

  swmhau_network_umap.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,...,include_features_in_path,embedding_dim,num_features,log_signature,dropout_rate,seed,gamma,k_fold,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(128, 128)",0.0001,0.416103,0.632189,0.628933,0.645383,0.637804,0.236027,0.681495,0.629497,0.651081,0.626292,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,20.5
15,8,4,6,1,"(128, 128)",0.0005,0.322329,0.633302,0.629238,0.64815,0.639,0.218794,0.650059,0.617723,0.619684,0.61673,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,21.5
15,8,4,6,1,"(128, 128)",0.001,0.357057,0.626152,0.61707,0.65225,0.632735,0.213779,0.642349,0.603629,0.621554,0.613138,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,19.5
15,8,4,6,1,"(256, 256)",0.0001,0.484244,0.62504,0.623034,0.633478,0.628742,0.34472,0.68446,0.627711,0.659305,0.627565,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,26.5
15,8,4,6,1,"(256, 256)",0.0005,0.300911,0.632984,0.625748,0.655165,0.64011,0.230354,0.650652,0.617901,0.622846,0.619655,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,27.5
15,8,4,6,1,"(256, 256)",0.001,0.323981,0.609628,0.59965,0.630403,0.615446,0.222513,0.641162,0.603624,0.611238,0.606935,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,25.5
15,8,4,6,1,"(512, 512)",0.0001,0.475924,0.614871,0.61101,0.627559,0.620426,0.271898,0.674377,0.63419,0.644391,0.631248,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,32.5
15,8,4,6,1,"(512, 512)",0.0005,0.328637,0.626946,0.619229,0.648602,0.634015,0.217221,0.642349,0.610127,0.615115,0.613841,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,33.5
15,8,4,6,1,"(512, 512)",0.001,0.328253,0.616301,0.607464,0.633657,0.620512,0.223254,0.63879,0.584005,0.601594,0.586446,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,31.5
15,8,4,12,1,"(128, 128)",0.0001,0.410756,0.637432,0.632105,0.654323,0.643572,0.241258,0.67497,0.639747,0.647775,0.63804,...,1.0,384.0,2.0,1.0,0.3,45.333333,2.0,0.0,64.0,38.5


In [11]:
best_swmhau_network_umap

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,0.360839,0.628217,0.626408,"[0.6004098360655737, 0.6524064171122995]",0.637304,"[0.6861826697892272, 0.5884244372990354]",0.632849,"[0.5336976320582878, 0.732]",0.213752,0.701068,...,0.0005,1,focal,2,False,,Conv1d,,concatenation,64
0,0.648381,0.604385,0.60375,"[0.5878848063555113, 0.6196150320806599]",0.6091,"[0.6462882096069869, 0.571912013536379]",0.607581,"[0.5391621129326047, 0.676]",0.243665,0.772242,...,0.0005,12,focal,2,False,,Conv1d,,concatenation,64
0,0.831684,0.589133,0.589109,"[0.5859750240153698, 0.5922421948912014]",0.590929,"[0.6199186991869918, 0.5619389587073609]",0.590778,"[0.5555555555555556, 0.626]",0.639707,0.697509,...,0.0005,123,focal,2,False,,Conv1d,,concatenation,64


In [12]:
best_swmhau_network_umap[["dimensions",
                          "output_channels",
                          "sig_depth",
                          "num_heads",
                          "num_layers",
                          "ffn_hidden_dim",
                          "dropout_rate",
                          "learning_rate"]]

Unnamed: 0,dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate
0,15,12,3,10,1,"(512, 512)",0.1,0.0005
0,15,12,3,10,1,"(512, 512)",0.1,0.0005
0,15,12,3,10,1,"(512, 512)",0.1,0.0005


In [13]:
best_swmhau_network_umap["f1"].mean()

0.6064222184201026

In [14]:
best_swmhau_network_umap["precision"].mean()

0.6124441646876636

In [15]:
best_swmhau_network_umap["recall"].mean()

0.6104025500910747

In [16]:
np.stack(best_swmhau_network_umap["f1_scores"]).mean(axis=0)

array([0.59142322, 0.62142121])

In [17]:
np.stack(best_swmhau_network_umap["precision_scores"]).mean(axis=0)

array([0.65079653, 0.5740918 ])

In [18]:
np.stack(best_swmhau_network_umap["recall_scores"]).mean(axis=0)

array([0.5428051, 0.678    ])

## Random Projections

In [19]:
size = 35
swmhau_network_grp, best_swmhau_network_grp, _, __ = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=df_rumours,
    id_column="timeline_id",
    label_column="label",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    split_indices=split_indices,
    k_fold=False,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}.csv",
    verbose=False
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_grp_focal_2_35.csv
saving the best model results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_grp_focal_2_35_best_model.csv


In [20]:
swmhau_network_grp.groupby(["dimensions",
                            "output_channels",
                            "sig_depth",
                            "num_heads",
                            "num_layers",
                            "ffn_hidden_dim",
                            "dropout_rate",
                            "learning_rate"]).mean()

  swmhau_network_grp.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,...,input_channels,include_features_in_path,embedding_dim,num_features,log_signature,seed,gamma,k_fold,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(128, 128)",0.1,0.0001,0.329374,0.615507,0.610961,0.629069,0.619426,0.219909,0.733096,0.687649,0.723538,0.680852,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,22.0
15,8,4,6,1,"(128, 128)",0.1,0.0005,0.439792,0.616142,0.610939,0.61784,0.614321,0.197619,0.730724,0.690838,0.712682,0.68391,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,23.0
15,8,4,6,1,"(128, 128)",0.1,0.001,0.315147,0.630759,0.628612,0.630653,0.629327,0.203969,0.702254,0.637147,0.706016,0.639774,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,21.0
15,8,4,6,1,"(128, 128)",0.5,0.0001,0.463705,0.591992,0.590513,0.593252,0.59226,0.197053,0.786477,0.754924,0.776396,0.74665,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,19.0
15,8,4,6,1,"(128, 128)",0.5,0.0005,0.350249,0.633619,0.631604,0.63948,0.635838,0.194894,0.715302,0.674717,0.716573,0.673212,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,20.0
15,8,4,6,1,"(128, 128)",0.5,0.001,0.374076,0.619638,0.618097,0.622038,0.620427,0.210467,0.714116,0.662627,0.714472,0.656817,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,18.0
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.483176,0.609152,0.606477,0.617357,0.611629,0.232993,0.739027,0.705235,0.722628,0.69816,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,28.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.368917,0.636797,0.636025,0.637384,0.63694,0.201707,0.710558,0.681055,0.687547,0.677922,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,29.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.354118,0.607563,0.598573,0.623242,0.610825,0.227521,0.644128,0.597177,0.605787,0.601882,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,27.0
15,8,4,6,1,"(256, 256)",0.5,0.0001,0.580154,0.615507,0.610027,0.633532,0.621211,0.270629,0.735469,0.702825,0.719001,0.696069,...,17.0,1.0,384.0,2.0,1.0,45.333333,2.0,0.0,64.0,25.0


In [21]:
best_swmhau_network_grp

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,0.358502,0.617731,0.617563,"[0.6095423563777994, 0.6255835667600375]",0.620751,"[0.6548117154811716, 0.5866900175131349]",0.620064,"[0.5701275045537341, 0.67]",0.209083,0.718861,...,0.0001,1,focal,2,False,,Conv1d,,concatenation,64
0,0.413685,0.594852,0.590673,"[0.632034632034632, 0.5493107104984094]",0.59348,"[0.6023102310231023, 0.5846501128668171]",0.591423,"[0.6648451730418944, 0.518]",0.168258,0.811388,...,0.0001,12,focal,2,False,,Conv1d,,concatenation,64
0,0.618928,0.563394,0.563304,"[0.5570599613152806, 0.5695488721804511]",0.565524,"[0.5938144329896907, 0.5372340425531915]",0.565295,"[0.5245901639344263, 0.606]",0.213818,0.829181,...,0.0001,123,focal,2,False,,Conv1d,,concatenation,64


In [22]:
best_swmhau_network_grp.columns

Index(['loss', 'accuracy', 'f1', 'f1_scores', 'precision', 'precision_scores',
       'recall', 'recall_scores', 'valid_loss', 'valid_accuracy', 'valid_f1',
       'valid_f1_scores', 'valid_precision', 'valid_precision_scores',
       'valid_recall', 'valid_recall_scores', 'k', 'dimensions', 'sig_depth',
       'method', 'input_channels', 'output_channels', 'features',
       'standardise_method', 'include_features_in_path', 'embedding_dim',
       'num_features', 'log_signature', 'num_heads', 'num_layers',
       'ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',
       'loss_function', 'gamma', 'k_fold', 'n_splits', 'augmentation_type',
       'hidden_dim_aug', 'comb_method', 'batch_size'],
      dtype='object')

In [23]:
best_swmhau_network_grp[["dimensions",
                         "output_channels",
                         "sig_depth",
                         "num_heads",
                         "num_layers",
                         "ffn_hidden_dim",
                         "dropout_rate",
                         "learning_rate"]]

Unnamed: 0,dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate
0,15,8,4,6,1,"(128, 128)",0.5,0.0001
0,15,8,4,6,1,"(128, 128)",0.5,0.0001
0,15,8,4,6,1,"(128, 128)",0.5,0.0001


In [24]:
best_swmhau_network_grp["f1"].mean()

0.5905133498611016

In [25]:
best_swmhau_network_grp["precision"].mean()

0.5932517587378513

In [26]:
best_swmhau_network_grp["recall"].mean()

0.5922604735883424

In [27]:
np.stack(best_swmhau_network_grp["f1_scores"]).mean(axis=0)

array([0.59954565, 0.58148105])

In [28]:
np.stack(best_swmhau_network_grp["precision_scores"]).mean(axis=0)

array([0.61697879, 0.56952472])

In [29]:
np.stack(best_swmhau_network_grp["recall_scores"]).mean(axis=0)

array([0.58652095, 0.598     ])