In [1]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
import torch

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
from nlpsig_networks.scripts.swmhau_network_functions import (
    swmhau_network_hyperparameter_search,
)

In [4]:
output_dir = "rumours_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## Rumours

In [5]:
%run load_sbert-embeddings.py

In [6]:
df_rumours.head()

Unnamed: 0,id,label,datetime,text,timeline_id,set
0,5.249902e+17,0,2014-10-22 18:26:23,Police have clarified that there were two shoo...,0,train
1,5.249906e+17,0,2014-10-22 18:27:58,"@CTVNews you guys ""confirmed"" there were 3 sho...",0,train
2,5.249908e+17,1,2014-10-22 18:28:46,@CTVNews get it right. http://t.co/GHYxMuzPG9,0,train
3,5.249927e+17,1,2014-10-22 18:36:29,RT @CTVNews Police have clarified that there w...,0,train
4,5.250038e+17,1,2014-10-22 19:20:41,@CTVNews @ctvsaskatoon so what happened at Rid...,0,train


# swmhau Network

In [7]:
features = ["time_encoding", "timeline_index"]
standardise_method = ["z_score", None]
num_features = len(features)
add_time_in_path = True

In [8]:
num_epochs = 100
embedding_dim = 384
dimensions = [15]  # [50, 15]
# define swmhau parameters: (output_channels, sig_depth, num_heads)
swmhau_parameters = [(12, 3, 10), (8, 4, 6), (8, 4, 12)]
num_layers = [1]
ffn_hidden_dim_sizes = [[256, 256], [512, 512]]
dropout_rates = [0.5, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5
split_indices = (
    df_rumours[df_rumours["set"] == "train"].index,
    df_rumours[df_rumours["set"] == "dev"].index,
    df_rumours[df_rumours["set"] == "test"].index,
)

## UMAP

In [12]:
size = 35
(
    swmhau_network_umap,
    best_swmhau_network_umap,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=df_rumours,
    id_column="timeline_id",
    label_column="label",
    embeddings=sbert_embeddings,
    y_data=y_data,
    embedding_dim=embedding_dim,
    output_dim=output_dim,
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    split_indices=split_indices,
    k_fold=False,
    features=features,
    standardise_method=standardise_method,
    add_time_in_path=add_time_in_path,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_umap_focal_2_35.csv
saving the best model results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_umap_focal_2_35_best_model.csv


In [13]:
swmhau_network_umap.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,...,add_time_in_path,num_features,embedding_dim,log_signature,dropout_rate,seed,gamma,k_fold,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.371428,0.632666,0.627786,0.64995,0.639137,0.206843,0.68446,0.635459,0.65499,0.632837,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.316856,0.622974,0.609554,0.659724,0.632674,0.211716,0.661329,0.638678,0.642134,0.644202,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.384461,0.618843,0.610358,0.640353,0.62544,0.216274,0.668446,0.631046,0.642696,0.631162,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.432205,0.629171,0.624358,0.647014,0.635574,0.245185,0.670818,0.635515,0.641283,0.635133,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.365014,0.62504,0.61572,0.651128,0.633175,0.229734,0.654804,0.625325,0.628143,0.627131,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.348443,0.60089,0.589771,0.628147,0.608422,0.223163,0.645314,0.627196,0.631765,0.635849,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.387116,0.622815,0.617594,0.639912,0.629369,0.203498,0.68624,0.644313,0.657086,0.643722,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.342072,0.64347,0.638969,0.661095,0.649935,0.214009,0.661329,0.626907,0.635995,0.628738,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.328106,0.621703,0.61003,0.653909,0.630939,0.219945,0.645314,0.624241,0.626615,0.630226,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.360836,0.637274,0.632549,0.655259,0.643747,0.227177,0.680308,0.636804,0.652388,0.635904,...,1.0,2.0,384.0,1.0,0.3,45.333333,2.0,0.0,64.0,32.5


In [14]:
best_swmhau_network_umap

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,0.370341,0.615825,0.595707,"[0.5055214723926381, 0.6858924395947]",0.668189,"[0.7744360902255639, 0.561941251596424]",0.627614,"[0.37522768670309653, 0.88]",0.220668,0.672598,...,0.0001,1,focal,2,False,,Conv1d,,concatenation,64
0,0.446978,0.626311,0.623477,"[0.5908141962421712, 0.656140350877193]",0.638153,"[0.6919315403422983, 0.584375]",0.631741,"[0.5154826958105647, 0.748]",0.300115,0.661922,...,0.0001,12,focal,2,False,,Conv1d,,concatenation,64
0,0.589764,0.613918,0.613298,"[0.6287809349220898, 0.5978152929493545]",0.613265,"[0.6328413284132841, 0.5936883629191322]",0.613386,"[0.6247723132969034, 0.602]",0.243468,0.72242,...,0.0001,123,focal,2,False,,Conv1d,,concatenation,64


In [15]:
best_swmhau_network_umap[
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
]

Unnamed: 0,dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate
0,15,12,3,10,1,"(512, 512)",0.1,0.0001
0,15,12,3,10,1,"(512, 512)",0.1,0.0001
0,15,12,3,10,1,"(512, 512)",0.1,0.0001


In [16]:
best_swmhau_network_umap["f1"].mean()

0.6108274478296911

In [17]:
best_swmhau_network_umap["precision"].mean()

0.6398689289161171

In [18]:
best_swmhau_network_umap["recall"].mean()

0.6242471159684274

In [19]:
np.stack(best_swmhau_network_umap["f1_scores"]).mean(axis=0)

array([0.57503887, 0.64661603])

In [20]:
np.stack(best_swmhau_network_umap["precision_scores"]).mean(axis=0)

array([0.69973632, 0.58000154])

In [21]:
np.stack(best_swmhau_network_umap["recall_scores"]).mean(axis=0)

array([0.5051609 , 0.74333333])

## Random Projections

In [9]:
size = 35
(
    swmhau_network_grp,
    best_swmhau_network_grp,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=df_rumours,
    id_column="timeline_id",
    label_column="label",
    embeddings=sbert_embeddings,
    y_data=y_data,
    embedding_dim=embedding_dim,
    output_dim=output_dim,
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    split_indices=split_indices,
    k_fold=False,
    features=features,
    standardise_method=standardise_method,
    add_time_in_path=add_time_in_path,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/5568 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_grp_focal_2_35.csv
saving the best model results dataframe to CSV for this hyperparameter search in rumours_output/swmhau_network_grp_focal_2_35_best_model.csv


In [10]:
swmhau_network_grp.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,...,input_channels,add_time_in_path,num_features,embedding_dim,log_signature,seed,gamma,k_fold,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.416057,0.62504,0.62382,0.625326,0.624607,0.210194,0.73191,0.661495,0.740224,0.659537,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.747329,0.619956,0.618705,0.624506,0.6221,0.217605,0.709371,0.685023,0.701513,0.687534,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.478982,0.614554,0.610244,0.62635,0.618129,0.203881,0.706999,0.671359,0.697511,0.671614,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,15.0
15,8,4,6,1,"(256, 256)",0.5,0.0001,0.372697,0.617731,0.615477,0.627024,0.622325,0.200793,0.715303,0.694313,0.699392,0.698516,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.5,0.0005,0.396219,0.642517,0.639383,0.654536,0.646064,0.198674,0.730724,0.706046,0.728692,0.713432,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.5,0.001,0.454991,0.612965,0.597113,0.647118,0.62033,0.208597,0.696323,0.674613,0.689212,0.680104,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,12.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.487213,0.620909,0.619611,0.624314,0.622386,0.191574,0.73191,0.699508,0.712793,0.69679,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,22.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.475595,0.621227,0.612861,0.641456,0.624861,0.202215,0.71293,0.673915,0.725335,0.680487,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,23.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.478557,0.621862,0.620213,0.629558,0.625647,0.192484,0.69395,0.664386,0.667993,0.662778,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,21.0
15,8,4,6,1,"(512, 512)",0.5,0.0001,0.355497,0.629488,0.628022,0.636125,0.632873,0.204633,0.697509,0.650644,0.670582,0.647999,...,17.0,1.0,2.0,384.0,1.0,45.333333,2.0,0.0,64.0,19.0


In [11]:
best_swmhau_network_grp

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,0.338599,0.626311,0.620928,"[0.5757575757575758, 0.666098807495741]",0.644726,"[0.7093333333333334, 0.5801186943620178]",0.633259,"[0.48451730418943534, 0.782]",0.230089,0.683274,...,0.0001,1,focal,2,False,,Conv1d,,concatenation,64
0,0.934016,0.585319,0.57493,"[0.6413849958779885, 0.5084745762711864]",0.585129,"[0.5858433734939759, 0.5844155844155844]",0.579281,"[0.7085610200364298, 0.45]",0.22475,0.814947,...,0.0001,12,focal,2,False,,Conv1d,,concatenation,64
0,1.25599,0.616778,0.614592,"[0.6436170212765957, 0.5855670103092784]",0.615599,"[0.6269430051813472, 0.6042553191489362]",0.614601,"[0.6612021857923497, 0.568]",0.49456,0.765125,...,0.0001,123,focal,2,False,,Conv1d,,concatenation,64


In [12]:
best_swmhau_network_grp.columns

Index(['loss', 'accuracy', 'f1', 'f1_scores', 'precision', 'precision_scores',
       'recall', 'recall_scores', 'valid_loss', 'valid_accuracy', 'valid_f1',
       'valid_f1_scores', 'valid_precision', 'valid_precision_scores',
       'valid_recall', 'valid_recall_scores', 'k', 'dimensions', 'sig_depth',
       'method', 'input_channels', 'output_channels', 'features',
       'standardise_method', 'add_time_in_path', 'num_features',
       'embedding_dim', 'log_signature', 'num_heads', 'num_layers',
       'ffn_hidden_dim', 'dropout_rate', 'learning_rate', 'seed',
       'loss_function', 'gamma', 'k_fold', 'n_splits', 'augmentation_type',
       'hidden_dim_aug', 'comb_method', 'batch_size'],
      dtype='object')

In [13]:
best_swmhau_network_grp[
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
]

Unnamed: 0,dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate
0,15,8,4,12,1,"(512, 512)",0.1,0.0001
0,15,8,4,12,1,"(512, 512)",0.1,0.0001
0,15,8,4,12,1,"(512, 512)",0.1,0.0001


In [14]:
best_swmhau_network_grp["f1"].mean()

0.6034833311647276

In [15]:
best_swmhau_network_grp["precision"].mean()

0.6151515516558658

In [16]:
best_swmhau_network_grp["recall"].mean()

0.6090467516697023

In [17]:
np.stack(best_swmhau_network_grp["f1_scores"]).mean(axis=0)

array([0.6202532 , 0.58671346])

In [18]:
np.stack(best_swmhau_network_grp["precision_scores"]).mean(axis=0)

array([0.64070657, 0.58959653])

In [19]:
np.stack(best_swmhau_network_grp["recall_scores"]).mean(axis=0)

array([0.6180935, 0.6      ])