In [1]:
import numpy as np
import pickle
import os
from tqdm.notebook import tqdm

# NOTE(review): `seed` is not referenced by any cell visible in this notebook; the
# hyperparameter searches take their seeds from the `seeds` list instead — confirm
# whether this value is still needed (e.g. by the %run script).
seed = 2023

In [2]:
import torch

# Choose the compute device: the GPU when torch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
from nlpsig_networks.scripts.swnu_network_functions import (
    swnu_network_hyperparameter_search,
)

In [4]:
# Directory that receives the CSV files written by the hyperparameter searches
# (it is interpolated into every `results_output=` path below).
output_dir = "client_talk_type_output"
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [5]:
# Run the shared AnnoMI loading script so that its top-level names become
# available in this notebook.
# NOTE(review): later cells use `anno_mi`, `y_data_client`, `output_dim_client`,
# `client_index` and `client_transcript_id` without defining them here, so they are
# presumably created by this script — confirm against load_anno_mi.py.
%run ../load_anno_mi.py

In [6]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-10-01 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-10-01 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-10-01 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-10-01 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-10-01 00:00:34,-1


In [7]:
# Load the pre-computed SBERT sentence embeddings from disk.
# NOTE(security): pickle.load can execute arbitrary code while deserialising, so
# only load this file if it was produced by a trusted pipeline.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# Sanity check of the array size; presumably one row per row of `anno_mi` — TODO confirm.
sbert_embeddings.shape

(9699, 384)

# SWNU Network

## Obtaining path by looking at post history

We can obtain a path by looking at the history of each post. Here we look at the last $w$ posts (and pad with vectors of zeros if there are fewer than $w$ posts), including the current post. The sections below use $w=5$ and $w=11$.

We only want to consider paths that correspond to a client's utterance as we want to model a change in mood at that time. Their history will still contain the therapist's utterances too.

In [8]:
# Extra per-utterance features handed to every search via `features=`.
features = ["timeline_index", "speaker"]
# One entry per feature; None for both (presumably "no standardisation" — TODO confirm).
standardise_method = [None, None]
# Forwarded unchanged as `include_features_in_path=` / `include_features_in_input=`.
include_features_in_path = True
include_features_in_input = False

In [9]:
# Shared settings and search grid passed to every
# `swnu_network_hyperparameter_search` call in this notebook.
num_epochs = 100
# Target dimension(s) for the dimensionality-reduction step (`dimensions=`).
dimensions = [15]
# Pairs of (SWNU hidden dimension sizes, signature depth), searched together.
swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 4)]
# Candidate hidden-layer sizes for the feed-forward network.
ffn_hidden_dim_sizes = [[256,256],[512,512]]
dropout_rates = [0.1, 0.2]
learning_rates = [1e-3, 1e-4, 5e-4]
# Seeds the search is repeated over.
seeds = [1, 12, 123]
loss = "focal"
# Presumably the focal-loss focusing parameter — TODO confirm; also used in the results filenames.
gamma = 2
# Presumably the metric monitored on the validation split, with `patience`
# controlling early stopping — TODO confirm against the library.
validation_metric = "f1"
patience = 5

# w=5

In [22]:
# History window length w: passed as `history_lengths=[size]` and embedded in the
# results filenames of the searches in this section.
size = 5

## UMAP

In [11]:
# K-fold hyperparameter search for the SWNU network on UMAP-reduced SBERT
# embeddings with history length `size`.
# The first return value holds the results of every configuration, the second the
# results of the best model; both are also written to CSV (`results_output` and a
# sibling `*_best_model.csv`, as reported in the log output).
# The two remaining return values are not used in this notebook. They are bound to
# named throwaways rather than `_` / `__`, because in IPython those names hold the
# most recent cell outputs and assigning to them stops IPython from updating them.
(
    swnu_network_umap_kfold_5,
    best_swnu_network_umap_kfold_5,
    _unused_1,
    _unused_2,
) = swnu_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    dimensions=dimensions,
    log_signature=True,
    swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    BiLSTM=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    include_features_in_input=include_features_in_input,
    # restrict the modelled paths / split ids to the client subset
    # (names presumably provided by load_anno_mi.py — TODO confirm)
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_5_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_5_kfold_best_model.csv


In [12]:
# Average every numeric column per hyperparameter combination (i.e. over the
# seeds of each configuration).
# numeric_only=True makes the exclusion of non-numeric columns explicit: relying on
# the implicit drop triggers a FutureWarning on pandas 1.5 and raises on pandas 2.x.
swnu_network_umap_kfold_5.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_5.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.636207,0.51712,0.509553,0.529213,0.659602,0.551338,0.546421,0.558437,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.641124,0.519248,0.513344,0.534045,0.665007,0.556678,0.553971,0.566355,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.636015,0.514764,0.51087,0.533548,0.661314,0.554335,0.553353,0.56809,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.633908,0.514416,0.506947,0.526355,0.660726,0.551654,0.546938,0.558689,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.637037,0.520095,0.512287,0.53658,0.662599,0.557632,0.552358,0.569053,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.633269,0.512613,0.508013,0.531401,0.664954,0.558017,0.556513,0.569694,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.629119,0.512011,0.503888,0.526364,0.659174,0.551885,0.546492,0.56018,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.634547,0.51661,0.511586,0.537764,0.659923,0.555643,0.552993,0.570199,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.633014,0.514058,0.508267,0.532817,0.660351,0.556983,0.552752,0.570734,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.632375,0.513389,0.506048,0.526618,0.658478,0.550487,0.545497,0.558398,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
best_swnu_network_umap_kfold_5

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.637165,0.519572,"[0.7667568379921852, 0.44970414201183434, 0.34...",0.512834,"[0.7981852315394243, 0.4610051993067591, 0.279...",0.539491,"[0.7377096587622903, 0.4389438943894389, 0.441...",,0.666024,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.635057,0.524985,"[0.7604435667628741, 0.45958145260566263, 0.35...",0.515919,"[0.80096, 0.45714285714285713, 0.2896551724137...",0.548019,"[0.7238288027761712, 0.46204620462046203, 0.45...",,0.657354,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.633525,0.514677,"[0.761904761904762, 0.43298969072164945, 0.349...",0.508335,"[0.7847946045370938, 0.45161290322580644, 0.28...",0.532657,"[0.7403123192596877, 0.4158415841584158, 0.441...",,0.666185,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
best_swnu_network_umap_kfold_5["f1"].mean()

0.5197446090694444

In [15]:
best_swnu_network_umap_kfold_5["precision"].mean()

0.512362765312797

In [16]:
best_swnu_network_umap_kfold_5["recall"].mean()

0.5400556273094054

In [17]:
np.stack(best_swnu_network_umap_kfold_5["f1_scores"]).mean(axis=0)

array([0.76303506, 0.4474251 , 0.34877368])

In [18]:
np.stack(best_swnu_network_umap_kfold_5["precision_scores"]).mean(axis=0)

array([0.79464661, 0.45658699, 0.2858547 ])

In [19]:
np.stack(best_swnu_network_umap_kfold_5["recall_scores"]).mean(axis=0)

array([0.73395026, 0.43894389, 0.44727273])

## GRP

In [23]:
# K-fold hyperparameter search for the SWNU network on SBERT embeddings reduced
# with Gaussian random projection, using history length `size`.
# The first return value holds the results of every configuration, the second the
# results of the best model; both are also written to CSV (`results_output` and a
# sibling `*_best_model.csv`, as reported in the log output).
# The two remaining return values are not used in this notebook. They are bound to
# named throwaways rather than `_` / `__`, because in IPython those names hold the
# most recent cell outputs and assigning to them stops IPython from updating them.
(
    swnu_network_grp_kfold_5,
    best_swnu_network_grp_kfold_5,
    _unused_1,
    _unused_2,
) = swnu_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    log_signature=True,
    swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    BiLSTM=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    include_features_in_input=include_features_in_input,
    # restrict the modelled paths / split ids to the client subset
    # (names presumably provided by load_anno_mi.py — TODO confirm)
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_5_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_5_kfold_best_model.csv


In [24]:
# Average every numeric column per hyperparameter combination (i.e. over the
# seeds of each configuration).
# numeric_only=True makes the exclusion of non-numeric columns explicit: relying on
# the implicit drop triggers a FutureWarning on pandas 1.5 and raises on pandas 2.x.
swnu_network_grp_kfold_5.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_5.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.637931,0.519559,0.512261,0.533513,0.667362,0.562064,0.557537,0.57069,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.634674,0.520334,0.51266,0.539509,0.660458,0.559003,0.553552,0.572639,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.637867,0.522015,0.516169,0.542813,0.660619,0.555634,0.552482,0.568887,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.636782,0.518263,0.510996,0.532297,0.665329,0.557206,0.5535,0.564858,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.6341,0.521276,0.513364,0.541796,0.659709,0.55946,0.553524,0.574741,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631737,0.520523,0.513127,0.544242,0.657033,0.55464,0.550091,0.570381,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.635824,0.518346,0.511019,0.532982,0.666024,0.56231,0.557239,0.571738,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.638314,0.520946,0.516589,0.540605,0.664205,0.558337,0.556891,0.571032,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.63014,0.515548,0.507976,0.535613,0.656765,0.554958,0.549925,0.568716,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.635441,0.520424,0.511906,0.536442,0.666881,0.563577,0.558097,0.573069,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [25]:
# NOTE(review): this cell repeats the aggregation of the preceding cell verbatim;
# one of the two can be removed.
# Average every numeric column per hyperparameter combination (i.e. over the
# seeds of each configuration). numeric_only=True makes the exclusion of
# non-numeric columns explicit: relying on the implicit drop triggers a
# FutureWarning on pandas 1.5 and raises on pandas 2.x.
swnu_network_grp_kfold_5.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_5.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.637931,0.519559,0.512261,0.533513,0.667362,0.562064,0.557537,0.57069,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.634674,0.520334,0.51266,0.539509,0.660458,0.559003,0.553552,0.572639,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.637867,0.522015,0.516169,0.542813,0.660619,0.555634,0.552482,0.568887,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.636782,0.518263,0.510996,0.532297,0.665329,0.557206,0.5535,0.564858,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.6341,0.521276,0.513364,0.541796,0.659709,0.55946,0.553524,0.574741,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631737,0.520523,0.513127,0.544242,0.657033,0.55464,0.550091,0.570381,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.635824,0.518346,0.511019,0.532982,0.666024,0.56231,0.557239,0.571738,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.638314,0.520946,0.516589,0.540605,0.664205,0.558337,0.556891,0.571032,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.63014,0.515548,0.507976,0.535613,0.656765,0.554958,0.549925,0.568716,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.635441,0.520424,0.511906,0.536442,0.666881,0.563577,0.558097,0.573069,5.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [26]:
best_swnu_network_grp_kfold_5

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.629885,0.512292,"[0.7563527653213753, 0.43732018084669133, 0.34...",0.504387,"[0.7827970297029703, 0.4357084357084357, 0.294...",0.527163,"[0.7316367842683632, 0.4389438943894389, 0.410...",,0.668112,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.631226,0.51651,"[0.7575757575757577, 0.4415584415584416, 0.350...",0.508643,"[0.7870947630922693, 0.44851063829787235, 0.29...",0.535609,"[0.7301908617698092, 0.43481848184818483, 0.44...",,0.662492,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.633908,0.514286,"[0.7594074074074074, 0.4338756779307468, 0.349...",0.507182,"[0.778554070473876, 0.4388185654008439, 0.3041...",0.527044,"[0.7411798727588201, 0.429042904290429, 0.4109...",,0.670199,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [27]:
best_swnu_network_grp_kfold_5["f1"].mean()

0.5143628505176907

In [28]:
best_swnu_network_grp_kfold_5["precision"].mean()

0.5067369839920379

In [29]:
best_swnu_network_grp_kfold_5["recall"].mean()

0.529938795884601

In [30]:
np.stack(best_swnu_network_grp_kfold_5["f1_scores"]).mean(axis=0)

array([0.75777864, 0.43758477, 0.34772514])

In [31]:
np.stack(best_swnu_network_grp_kfold_5["precision_scores"]).mean(axis=0)

array([0.78281529, 0.44101255, 0.29638312])

In [32]:
np.stack(best_swnu_network_grp_kfold_5["recall_scores"]).mean(axis=0)

array([0.73433584, 0.43426843, 0.42121212])

# w=11

In [33]:
# History window length w: passed as `history_lengths=[size]` and embedded in the
# results filenames of the searches in this section.
size = 11

## UMAP

In [34]:
# K-fold hyperparameter search for the SWNU network on UMAP-reduced SBERT
# embeddings with history length `size`.
# The first return value holds the results of every configuration, the second the
# results of the best model; both are also written to CSV (`results_output` and a
# sibling `*_best_model.csv`, as reported in the log output).
# The two remaining return values are not used in this notebook. They are bound to
# named throwaways rather than `_` / `__`, because in IPython those names hold the
# most recent cell outputs and assigning to them stops IPython from updating them.
(
    swnu_network_umap_kfold_11,
    best_swnu_network_umap_kfold_11,
    _unused_1,
    _unused_2,
) = swnu_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    dimensions=dimensions,
    log_signature=True,
    swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    BiLSTM=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    include_features_in_input=include_features_in_input,
    # restrict the modelled paths / split ids to the client subset
    # (names presumably provided by load_anno_mi.py — TODO confirm)
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_11_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_11_kfold_best_model.csv


In [35]:
# Average every numeric column per hyperparameter combination (i.e. over the
# seeds of each configuration).
# numeric_only=True makes the exclusion of non-numeric columns explicit: relying on
# the implicit drop triggers a FutureWarning on pandas 1.5 and raises on pandas 2.x.
swnu_network_umap_kfold_11.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_11.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.633653,0.511955,0.505871,0.521421,0.656872,0.545975,0.541768,0.55189,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.628225,0.51076,0.503805,0.529144,0.65623,0.554562,0.548353,0.567675,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.630013,0.513454,0.509199,0.537231,0.661368,0.560047,0.557334,0.577786,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.634227,0.510878,0.504871,0.520673,0.657996,0.547947,0.543574,0.55414,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.632248,0.51248,0.5059,0.529214,0.659548,0.555254,0.550383,0.567566,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.62931,0.516421,0.510382,0.540317,0.653393,0.552089,0.547796,0.569875,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.638633,0.515204,0.509742,0.525045,0.662813,0.551695,0.548703,0.556687,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.625223,0.51404,0.50613,0.536438,0.653768,0.553459,0.547154,0.568607,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.631481,0.512974,0.506853,0.531552,0.660779,0.557081,0.552816,0.570339,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.634994,0.512026,0.505242,0.52338,0.66201,0.550918,0.547089,0.55644,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [36]:
best_swnu_network_umap_kfold_11

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.632375,0.513024,"[0.7635711081202332, 0.4328037784456848, 0.342...",0.507011,"[0.7906472592133788, 0.4512085944494181, 0.279...",0.532589,"[0.738288027761712, 0.4158415841584158, 0.4436...",,0.659602,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.623755,0.510287,"[0.7518796992481204, 0.42736486486486486, 0.35...",0.502541,"[0.7832080200501254, 0.43771626297577854, 0.28...",0.531666,"[0.7229612492770388, 0.4174917491749175, 0.454...",,0.662171,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.634291,0.509492,"[0.7609841827768014, 0.42397914856646396, 0.34...",0.504893,"[0.7709198813056379, 0.44770642201834865, 0.29...",0.521011,"[0.7513013302486987, 0.40264026402640263, 0.40...",,0.668433,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [37]:
best_swnu_network_umap_kfold_11["f1"].mean()

0.5109342557602629

In [38]:
best_swnu_network_umap_kfold_11["precision"].mean()

0.5048147245190061

In [39]:
best_swnu_network_umap_kfold_11["recall"].mean()

0.5284218813244347

In [40]:
np.stack(best_swnu_network_umap_kfold_11["f1_scores"]).mean(axis=0)

array([0.75881166, 0.42804926, 0.34594184])

In [41]:
np.stack(best_swnu_network_umap_kfold_11["precision_scores"]).mean(axis=0)

array([0.78159172, 0.44554376, 0.28730869])

In [42]:
np.stack(best_swnu_network_umap_kfold_11["recall_scores"]).mean(axis=0)

array([0.73751687, 0.4119912 , 0.43575758])

## GRP

In [43]:
# K-fold hyperparameter search for the SWNU network on SBERT embeddings reduced
# with Gaussian random projection, using history length `size`.
# The first return value holds the results of every configuration, the second the
# results of the best model; both are also written to CSV (`results_output` and a
# sibling `*_best_model.csv`, as reported in the log output).
# The two remaining return values are not used in this notebook. They are bound to
# named throwaways rather than `_` / `__`, because in IPython those names hold the
# most recent cell outputs and assigning to them stops IPython from updating them.
(
    swnu_network_grp_kfold_11,
    best_swnu_network_grp_kfold_11,
    _unused_1,
    _unused_2,
) = swnu_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    log_signature=True,
    swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    BiLSTM=True,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    device=device,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    include_features_in_input=include_features_in_input,
    # restrict the modelled paths / split ids to the client subset
    # (names presumably provided by load_anno_mi.py — TODO confirm)
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_11_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_11_kfold_best_model.csv


In [44]:
# Mean metrics per hyperparameter configuration, averaged over the seeds.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_grp_kfold_11.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_11.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.635951,0.517082,0.509585,0.530314,0.662653,0.554108,0.550072,0.560642,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.635888,0.518681,0.511607,0.535058,0.662117,0.558223,0.554039,0.569593,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.636271,0.519882,0.514296,0.540819,0.65912,0.556866,0.553305,0.572849,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.634483,0.516197,0.508696,0.528811,0.662974,0.554354,0.550255,0.560835,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.635568,0.521268,0.513372,0.539572,0.660565,0.556545,0.551676,0.568575,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.636462,0.518802,0.513983,0.538383,0.661529,0.556358,0.554792,0.571427,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.637995,0.520074,0.512482,0.534144,0.663081,0.558038,0.552941,0.567183,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.637676,0.519678,0.515185,0.539011,0.66383,0.558126,0.557314,0.571906,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.633078,0.51491,0.508914,0.533292,0.658531,0.554274,0.550083,0.567195,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.636079,0.51882,0.510904,0.532671,0.665382,0.560106,0.554865,0.568625,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [45]:
# NOTE(review): this cell repeats the aggregation of the preceding cell and
# could be removed.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_grp_kfold_11.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_11.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.635951,0.517082,0.509585,0.530314,0.662653,0.554108,0.550072,0.560642,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.635888,0.518681,0.511607,0.535058,0.662117,0.558223,0.554039,0.569593,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.636271,0.519882,0.514296,0.540819,0.65912,0.556866,0.553305,0.572849,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.634483,0.516197,0.508696,0.528811,0.662974,0.554354,0.550255,0.560835,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.635568,0.521268,0.513372,0.539572,0.660565,0.556545,0.551676,0.568575,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.636462,0.518802,0.513983,0.538383,0.661529,0.556358,0.554792,0.571427,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.637995,0.520074,0.512482,0.534144,0.663081,0.558038,0.552941,0.567183,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.637676,0.519678,0.515185,0.539011,0.66383,0.558126,0.557314,0.571906,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.633078,0.51491,0.508914,0.533292,0.658531,0.554274,0.550083,0.567195,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.636079,0.51882,0.510904,0.532671,0.665382,0.560106,0.554865,0.568625,11.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [46]:
# Show the results stored for the best model of this search (one row per seed).
best_swnu_network_grp_kfold_11

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.633333,0.516964,"[0.7592203898050974, 0.44256348246674726, 0.34...",0.508461,"[0.788293897882939, 0.4326241134751773, 0.3044...",0.531425,"[0.7322151532677849, 0.452970297029703, 0.4090...",,0.670199,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.649808,0.52387,"[0.7722368037328668, 0.4507042253521127, 0.348...",0.519722,"[0.7788235294117647, 0.46684350132625996, 0.31...",0.531377,"[0.7657605552342395, 0.43564356435643564, 0.39...",,0.675498,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.640613,0.51904,"[0.7656250000000001, 0.438921651221567, 0.3525...",0.512642,"[0.780817799158148, 0.4483648881239243, 0.3087...",0.530596,"[0.7510121457489879, 0.4298679867986799, 0.410...",,0.671484,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [47]:
# Overall F1 of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_11.loc[:, "f1"].mean()

0.5199579746606222

In [48]:
# Overall precision of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_11.loc[:, "precision"].mean()

0.5136082461799552

In [49]:
# Overall recall of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_11.loc[:, "recall"].mean()

0.5311329972403448

In [50]:
# Per-class F1: stack the per-row score vectors and average element-wise
# across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_11["f1_scores"]), axis=0)

array([0.76569406, 0.44406312, 0.35011674])

In [None]:
# Per-class precision: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_11["precision_scores"]), axis=0)

array([0.78264508, 0.4492775 , 0.30890216])

In [None]:
# Per-class recall: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_11["recall_scores"]), axis=0)

array([0.74966262, 0.43949395, 0.40424242])

# w=20

In [10]:
# Window length w for this section: passed as `history_lengths=[size]` and
# embedded in the results file names of the searches below.
size = 20

## UMAP

In [11]:
# Hyperparameter search for the SWNU network on UMAP-reduced embeddings at the
# current window length `size`. Only the first two return values (all results
# and the best-model results) are kept; the other two go to `_` and `__`.
swnu_network_umap_kfold_20, best_swnu_network_umap_kfold_20, _, __ = (
    swnu_network_hyperparameter_search(
        # data and targets
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        # path construction and extra features
        history_lengths=[size],
        dim_reduce_methods=["umap"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        path_indices=client_index,
        # hyperparameter grid
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # training setup
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        patience=patience,
        validation_metric=validation_metric,
        # data splitting
        split_ids=client_transcript_id,
        k_fold=True,
        # output
        results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_20_kfold_best_model.csv


In [12]:
# Mean metrics per hyperparameter configuration, averaged over the seeds.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_umap_kfold_20.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_20.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.62765,0.494926,0.496842,0.498241,0.66383,0.543002,0.546849,0.542859,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.634994,0.514723,0.508342,0.530955,0.662599,0.55591,0.552442,0.566035,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.633078,0.514975,0.51031,0.53554,0.662546,0.559394,0.556475,0.575048,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.632631,0.496811,0.497164,0.499495,0.665917,0.546658,0.549894,0.545738,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.626054,0.510704,0.503491,0.53083,0.661047,0.558644,0.553538,0.572924,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.623946,0.510259,0.505439,0.535344,0.658103,0.558495,0.555287,0.579128,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.625734,0.50202,0.49642,0.511061,0.66292,0.552958,0.549165,0.557962,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.622542,0.508704,0.501521,0.530302,0.657782,0.559348,0.553183,0.576977,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.63493,0.513492,0.507185,0.528069,0.666078,0.563568,0.559303,0.575521,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.625862,0.500076,0.495277,0.507922,0.662492,0.551654,0.548896,0.556041,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# Show the results stored for the best model of this search (one row per seed).
best_swnu_network_umap_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.64387,0.511869,"[0.7680988816951148, 0.45691699604743086, 0.31...",0.509062,"[0.7819053325344517, 0.4385432473444613, 0.306...",0.515405,"[0.7547715442452284, 0.4768976897689769, 0.314...",,0.671965,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.631609,0.514124,"[0.7606056063558688, 0.4337448559670783, 0.348...",0.505864,"[0.7896047307812014, 0.43267651888341546, 0.29...",0.530705,"[0.7336610757663389, 0.43481848184818483, 0.42...",,0.667951,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.645211,0.518232,"[0.7710668025032746, 0.42753313381787084, 0.35...",0.513955,"[0.7761500146498681, 0.44365572315882873, 0.32...",0.525591,"[0.7660497397339503, 0.41254125412541254, 0.39...",,0.678227,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
# Overall F1 of the best model, averaged over its rows (one row per seed).
best_swnu_network_umap_kfold_20.loc[:, "f1"].mean()

0.5147419120105536

In [15]:
# Overall precision of the best model, averaged over its rows (one row per seed).
best_swnu_network_umap_kfold_20.loc[:, "precision"].mean()

0.509626944353249

In [16]:
# Overall recall of the best model, averaged over its rows (one row per seed).
best_swnu_network_umap_kfold_20.loc[:, "recall"].mean()

0.5239003802057476

In [17]:
# Per-class F1: stack the per-row score vectors and average element-wise
# across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_umap_kfold_20["f1_scores"]), axis=0)

array([0.76659043, 0.43939833, 0.33823698])

In [18]:
# Per-class precision: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_umap_kfold_20["precision_scores"]), axis=0)

array([0.78255336, 0.43829183, 0.30803564])

In [19]:
# Per-class recall: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_umap_kfold_20["recall_scores"]), axis=0)

array([0.75149412, 0.44141914, 0.37878788])

## GRP

In [11]:
# Hyperparameter search for the SWNU network on embeddings reduced with
# Gaussian random projection at the current window length `size`. Only the
# first two return values (all results and the best-model results) are kept;
# the other two go to `_` and `__`.
swnu_network_grp_kfold_20, best_swnu_network_grp_kfold_20, _, __ = (
    swnu_network_hyperparameter_search(
        # data and targets
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        # path construction and extra features
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        path_indices=client_index,
        # hyperparameter grid
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # training setup
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        patience=patience,
        validation_metric=validation_metric,
        # data splitting
        split_ids=client_transcript_id,
        k_fold=True,
        # output
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_20_kfold_best_model.csv


In [12]:
# Mean metrics per hyperparameter configuration, averaged over the seeds.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_grp_kfold_20.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_20.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.636845,0.512512,0.506699,0.521145,0.667041,0.559155,0.555785,0.56346,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.631928,0.519179,0.510761,0.539115,0.659548,0.558707,0.552736,0.572985,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.62963,0.519158,0.511065,0.542875,0.655106,0.556204,0.549974,0.572997,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.633269,0.509656,0.503574,0.520169,0.667095,0.560308,0.555909,0.566991,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.637165,0.521925,0.51413,0.538394,0.659655,0.554998,0.550223,0.56677,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631992,0.518273,0.512666,0.541663,0.657407,0.555902,0.552361,0.572543,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.630843,0.510786,0.50359,0.522578,0.665489,0.560293,0.554989,0.567638,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.633269,0.518884,0.511691,0.539491,0.662546,0.561188,0.556455,0.576718,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.635249,0.515574,0.510029,0.532973,0.662385,0.557933,0.554339,0.570091,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.629949,0.512507,0.504665,0.526545,0.664365,0.561216,0.555227,0.570017,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# NOTE(review): this cell repeats the aggregation of the preceding cell and
# could be removed.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_grp_kfold_20.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_20.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.636845,0.512512,0.506699,0.521145,0.667041,0.559155,0.555785,0.56346,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.631928,0.519179,0.510761,0.539115,0.659548,0.558707,0.552736,0.572985,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.62963,0.519158,0.511065,0.542875,0.655106,0.556204,0.549974,0.572997,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.633269,0.509656,0.503574,0.520169,0.667095,0.560308,0.555909,0.566991,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.637165,0.521925,0.51413,0.538394,0.659655,0.554998,0.550223,0.56677,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631992,0.518273,0.512666,0.541663,0.657407,0.555902,0.552361,0.572543,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.630843,0.510786,0.50359,0.522578,0.665489,0.560293,0.554989,0.567638,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.633269,0.518884,0.511691,0.539491,0.662546,0.561188,0.556455,0.576718,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.635249,0.515574,0.510029,0.532973,0.662385,0.557933,0.554339,0.570091,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.629949,0.512507,0.504665,0.526545,0.664365,0.561216,0.555227,0.570017,20.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [14]:
# Show the results stored for the best model of this search (one row per seed).
best_swnu_network_grp_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.638506,0.514383,"[0.7664425606547792, 0.42052221713238663, 0.35...",0.512954,"[0.774822695035461, 0.4727085478887745, 0.2913...",0.531712,"[0.7582417582417582, 0.3787128712871287, 0.458...",,0.662331,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.62931,0.520225,"[0.7538438118435073, 0.4538893344025661, 0.352...",0.510405,"[0.7958855673416908, 0.4414976599063963, 0.293...",0.541612,"[0.7160208212839791, 0.466996699669967, 0.4418...",,0.664579,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.642146,0.519327,"[0.7673874926943309, 0.4283185840707965, 0.362...",0.515089,"[0.7755463673951565, 0.4618320610687023, 0.307...",0.532913,"[0.7593984962406015, 0.39933993399339934, 0.44]",,0.676782,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [15]:
# Overall F1 of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_20.loc[:, "f1"].mean()

0.5179782635504212

In [16]:
# Overall precision of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_20.loc[:, "precision"].mean()

0.5128159501021253

In [17]:
# Overall recall of the best model, averaged over its rows (one row per seed).
best_swnu_network_grp_kfold_20.loc[:, "recall"].mean()

0.5354122867463149

In [18]:
# Per-class F1: stack the per-row score vectors and average element-wise
# across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_20["f1_scores"]), axis=0)

array([0.76255796, 0.43424338, 0.35713346])

In [19]:
# Per-class precision: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_20["precision_scores"]), axis=0)

array([0.78208488, 0.45867942, 0.29768355])

In [20]:
# Per-class recall: stack the per-row score vectors and average
# element-wise across rows, giving one value per class.
np.mean(np.stack(best_swnu_network_grp_kfold_20["recall_scores"]), axis=0)

array([0.74455369, 0.4150165 , 0.44666667])

# w=35

In [10]:
# Window length w for this section: passed as `history_lengths=[size]` and
# embedded in the results file names of the searches below.
size = 35

## UMAP

In [22]:
# Hyperparameter search for the SWNU network on UMAP-reduced embeddings at the
# current window length `size`. Only the first two return values (all results
# and the best-model results) are kept; the other two go to `_` and `__`.
swnu_network_umap_kfold_35, best_swnu_network_umap_kfold_35, _, __ = (
    swnu_network_hyperparameter_search(
        # data and targets
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        # path construction and extra features
        history_lengths=[size],
        dim_reduce_methods=["umap"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        path_indices=client_index,
        # hyperparameter grid
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # training setup
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        patience=patience,
        validation_metric=validation_metric,
        # data splitting
        split_ids=client_transcript_id,
        k_fold=True,
        # output
        results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_35_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_35_kfold_best_model.csv


In [23]:
# Mean metrics per hyperparameter configuration, averaged over the seeds.
# `numeric_only=True` is explicit: relying on pandas to silently drop any
# non-numeric columns emits a FutureWarning in pandas 1.5 and raises a
# TypeError from pandas 2.0 onwards.
swnu_network_umap_kfold_35.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_35.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.618327,0.471522,0.475986,0.471715,0.656979,0.536934,0.54302,0.534891,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.62567,0.507319,0.500092,0.521969,0.662653,0.563354,0.556849,0.575947,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.633461,0.516621,0.50964,0.536429,0.662599,0.5599,0.556433,0.575022,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.619413,0.473589,0.476565,0.4736,0.658371,0.54029,0.543985,0.538434,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.62401,0.506852,0.499889,0.524577,0.657836,0.559451,0.553031,0.574786,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631737,0.517368,0.509418,0.53774,0.658103,0.555295,0.550645,0.570391,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.622605,0.483077,0.484155,0.483795,0.660726,0.543971,0.546316,0.542763,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.619987,0.503479,0.495443,0.520437,0.663241,0.564158,0.557554,0.578794,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.631034,0.511467,0.505596,0.528255,0.663348,0.561231,0.556816,0.574798,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.617944,0.483983,0.483,0.488053,0.659388,0.551257,0.550347,0.554296,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [24]:
# Best-model results table from the UMAP search with w=35 (one row per seed).
best_swnu_network_umap_kfold_35

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.628736,0.493059,"[0.7616099071207432, 0.42324649298597194, 0.29...",0.489263,"[0.7768421052631579, 0.41153546375681993, 0.27...",0.497839,"[0.7469635627530364, 0.43564356435643564, 0.31...",,0.689467,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.617816,0.48408,"[0.754733727810651, 0.39775910364145656, 0.299...",0.479746,"[0.7725620835857057, 0.3861693861693862, 0.280...",0.489865,"[0.7377096587622903, 0.41006600660066006, 0.32...",,0.668593,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.625479,0.489244,"[0.7591155366817983, 0.4062116877809563, 0.302...",0.485995,"[0.7689113022841887, 0.40242914979757083, 0.28...",0.493211,"[0.7495662232504338, 0.41006600660066006, 0.32]",,0.675819,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [25]:
# Mean of the `f1` column over the per-seed rows of the best UMAP (w=35) model.
best_swnu_network_umap_kfold_35.loc[:, "f1"].mean()

0.4887942299979727

In [26]:
# Mean of the `precision` column over the per-seed rows of the best UMAP (w=35) model.
best_swnu_network_umap_kfold_35.loc[:, "precision"].mean()

0.4850014820266688

In [27]:
# Mean of the `recall` column over the per-seed rows of the best UMAP (w=35) model.
best_swnu_network_umap_kfold_35.loc[:, "recall"].mean()

0.49363803278342105

In [28]:
# Element-wise mean of the per-seed `f1_scores` vectors (best UMAP model, w=35):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_35["f1_scores"]), axis=0)

array([0.75848639, 0.40907243, 0.29882387])

In [29]:
# Element-wise mean of the per-seed `precision_scores` vectors (best UMAP model, w=35):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_35["precision_scores"]), axis=0)

array([0.77277183, 0.40004467, 0.28218795])

In [30]:
# Element-wise mean of the per-seed `recall_scores` vectors (best UMAP model, w=35):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_35["recall_scores"]), axis=0)

array([0.74474648, 0.41859186, 0.31757576])

## GRP

In [11]:
# Hyperparameter search for the SWNU network using Gaussian random projection
# as the dimension-reduction method, evaluated with k-fold splitting.
# Only the full results table and the best-model table are kept; the other two
# return values are discarded.
swnu_network_grp_kfold_35, best_swnu_network_grp_kfold_35, _, __ = (
    swnu_network_hyperparameter_search(
        # --- data, labels and which rows / ids to use ---
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        path_indices=client_index,
        split_ids=client_transcript_id,
        # --- path construction ---
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        # --- model grid ---
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # --- training and validation ---
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        k_fold=True,
        patience=patience,
        validation_metric=validation_metric,
        # --- output ---
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_35_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_35_kfold_best_model.csv


In [12]:
# Average the numeric result columns over the repeated runs (seeds) of each
# hyperparameter configuration. `numeric_only=True` is requested explicitly:
# implicitly dropping non-numeric columns in `GroupBy.mean` is deprecated in
# pandas 1.5 and raises a TypeError from pandas 2.0.
swnu_network_grp_kfold_35.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_35.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.624521,0.497917,0.491986,0.507318,0.668326,0.562935,0.558536,0.569261,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.623308,0.506423,0.499172,0.521456,0.666667,0.569588,0.563301,0.583915,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.629821,0.514509,0.509264,0.535659,0.662438,0.560027,0.556452,0.574728,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.623819,0.496021,0.490486,0.504672,0.667416,0.562382,0.558087,0.568799,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.625415,0.504233,0.498186,0.517154,0.669824,0.570722,0.565329,0.582683,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631034,0.514572,0.50856,0.534387,0.663884,0.56504,0.559901,0.580935,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.625543,0.504555,0.497271,0.517535,0.670253,0.570833,0.56477,0.581059,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.630843,0.514675,0.508531,0.534875,0.664954,0.566212,0.561464,0.58215,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.637995,0.512485,0.510613,0.527942,0.668005,0.562703,0.56284,0.575392,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.627714,0.504768,0.498006,0.516682,0.670413,0.569954,0.564339,0.579504,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# NOTE(review): this cell repeats the aggregation of the preceding cell and
# could be removed.
# Average the numeric result columns over the repeated runs (seeds) of each
# hyperparameter configuration. `numeric_only=True` is requested explicitly:
# implicitly dropping non-numeric columns in `GroupBy.mean` is deprecated in
# pandas 1.5 and raises a TypeError from pandas 2.0.
swnu_network_grp_kfold_35.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_35.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.624521,0.497917,0.491986,0.507318,0.668326,0.562935,0.558536,0.569261,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.623308,0.506423,0.499172,0.521456,0.666667,0.569588,0.563301,0.583915,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.629821,0.514509,0.509264,0.535659,0.662438,0.560027,0.556452,0.574728,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.623819,0.496021,0.490486,0.504672,0.667416,0.562382,0.558087,0.568799,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.625415,0.504233,0.498186,0.517154,0.669824,0.570722,0.565329,0.582683,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.631034,0.514572,0.50856,0.534387,0.663884,0.56504,0.559901,0.580935,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.625543,0.504555,0.497271,0.517535,0.670253,0.570833,0.56477,0.581059,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.630843,0.514675,0.508531,0.534875,0.664954,0.566212,0.561464,0.58215,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.637995,0.512485,0.510613,0.527942,0.668005,0.562703,0.56284,0.575392,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.627714,0.504768,0.498006,0.516682,0.670413,0.569954,0.564339,0.579504,35.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [14]:
# Best-model results table from the GRP search with w=35 (one row per seed).
best_swnu_network_grp_kfold_35

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.629694,0.508746,"[0.7544642857142857, 0.44112903225806455, 0.33...",0.501873,"[0.7771305947271613, 0.4313880126182965, 0.297...",0.519043,"[0.7330827067669173, 0.4513201320132013, 0.372...",,0.673732,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.632567,0.514838,"[0.7555886736214605, 0.4442662389735365, 0.344...",0.507384,"[0.7795202952029521, 0.43213728549141966, 0.31...",0.525817,"[0.7330827067669173, 0.4570957095709571, 0.387...",,0.67614,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.63659,0.508435,"[0.7655474452554745, 0.4192982456140351, 0.340...",0.504661,"[0.7729952830188679, 0.44756554307116103, 0.29...",0.519362,"[0.7582417582417582, 0.3943894389438944, 0.405...",,0.679351,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [15]:
# Mean of the `f1` column over the per-seed rows of the best GRP (w=35) model.
best_swnu_network_grp_kfold_35.loc[:, "f1"].mean()

0.5106730324632348

In [16]:
# Mean of the `precision` column over the per-seed rows of the best GRP (w=35) model.
best_swnu_network_grp_kfold_35.loc[:, "precision"].mean()

0.5046394603176619

In [17]:
# Mean of the `recall` column over the per-seed rows of the best GRP (w=35) model.
best_swnu_network_grp_kfold_35.loc[:, "recall"].mean()

0.5214074441953546

In [18]:
# Element-wise mean of the per-seed `f1_scores` vectors (best GRP model, w=35):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_grp_kfold_35["f1_scores"]), axis=0)

array([0.75853347, 0.43489784, 0.33858779])

In [19]:
# Element-wise mean of the per-seed `precision_scores` vectors (best GRP model, w=35):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_grp_kfold_35["precision_scores"]), axis=0)

array([0.77654872, 0.43703028, 0.30033938])

# w=80

In [10]:
# History length w for this section: passed to the searches below as
# `history_lengths=[size]` and embedded in their results file names.
size = 80

## UMAP

In [11]:
# Hyperparameter search for the SWNU network using UMAP as the
# dimension-reduction method, evaluated with k-fold splitting.
# Only the full results table and the best-model table are kept; the other two
# return values are discarded.
swnu_network_umap_kfold_80, best_swnu_network_umap_kfold_80, _, __ = (
    swnu_network_hyperparameter_search(
        # --- data, labels and which rows / ids to use ---
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        path_indices=client_index,
        split_ids=client_transcript_id,
        # --- path construction ---
        history_lengths=[size],
        dim_reduce_methods=["umap"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        # --- model grid ---
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # --- training and validation ---
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        k_fold=True,
        patience=patience,
        validation_metric=validation_metric,
        # --- output ---
        results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_80_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_80_kfold_best_model.csv


In [12]:
# Average the numeric result columns over the repeated runs (seeds) of each
# hyperparameter configuration. `numeric_only=True` is requested explicitly:
# implicitly dropping non-numeric columns in `GroupBy.mean` is deprecated in
# pandas 1.5 and raises a TypeError from pandas 2.0.
swnu_network_umap_kfold_80.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_80.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.619476,0.462949,0.466907,0.461751,0.665543,0.553903,0.554666,0.554367,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.618966,0.498448,0.491258,0.513866,0.673892,0.578897,0.57102,0.593756,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.624074,0.504912,0.498198,0.521646,0.668593,0.571478,0.565266,0.586163,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.622861,0.465501,0.468959,0.463727,0.671858,0.558338,0.560106,0.557628,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.613729,0.48836,0.482406,0.501618,0.674106,0.580653,0.57207,0.598106,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.624904,0.50926,0.50214,0.53078,0.672554,0.574532,0.570338,0.591537,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.610983,0.46972,0.46713,0.474699,0.665971,0.565315,0.559287,0.574478,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.61258,0.501213,0.492755,0.522174,0.6649,0.576381,0.565679,0.595829,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.623819,0.510067,0.501638,0.529151,0.66656,0.570916,0.563091,0.585767,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.61341,0.467833,0.466802,0.470689,0.668379,0.562972,0.559376,0.568892,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# Best-model results table from the UMAP search with w=80 (one row per seed).
best_swnu_network_umap_kfold_80

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.609387,0.481852,"[0.7443530291697831, 0.40855874041178847, 0.29...",0.475954,"[0.7709947319491788, 0.4, 0.25686813186813184]",0.492328,"[0.719491035280509, 0.4174917491749175, 0.34]",,0.696371,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.616667,0.499949,"[0.7484343974339392, 0.4305239179954442, 0.320...",0.49226,"[0.7931369375202331, 0.3987341772151899, 0.284...",0.514532,"[0.708502024291498, 0.46782178217821785, 0.367...",,0.685132,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.633333,0.498982,"[0.7644742598804142, 0.4147465437788018, 0.317...",0.495501,"[0.7711091497499265, 0.42127659574468085, 0.29...",0.503941,"[0.7579525737420474, 0.4084158415841584, 0.345...",,0.687861,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [None]:
# Mean of the `f1` column over the per-seed rows of the best UMAP (w=80) model.
best_swnu_network_umap_kfold_80.loc[:, "f1"].mean()

0.49359455483651365

In [15]:
# Mean of the `precision` column over the per-seed rows of the best UMAP (w=80) model.
best_swnu_network_umap_kfold_80.loc[:, "precision"].mean()

0.48790507696509494

In [16]:
# Mean of the `recall` column over the per-seed rows of the best UMAP (w=80) model.
best_swnu_network_umap_kfold_80.loc[:, "recall"].mean()

0.5036002532198468

In [17]:
# Element-wise mean of the per-seed `f1_scores` vectors (best UMAP model, w=80):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_80["f1_scores"]), axis=0)

array([0.75242056, 0.41794307, 0.31042003])

In [18]:
# Element-wise mean of the per-seed `precision_scores` vectors (best UMAP model, w=80):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_80["precision_scores"]), axis=0)

array([0.77841361, 0.40667026, 0.27863137])

In [19]:
# Element-wise mean of the per-seed `recall_scores` vectors (best UMAP model, w=80):
# stack them into a 2-D array, then average down the seed axis.
np.mean(np.stack(best_swnu_network_umap_kfold_80["recall_scores"]), axis=0)

array([0.72864854, 0.43124312, 0.35090909])

## GRP

In [11]:
# Hyperparameter search for the SWNU network using Gaussian random projection
# as the dimension-reduction method, evaluated with k-fold splitting.
# Only the full results table and the best-model table are kept; the other two
# return values are discarded.
swnu_network_grp_kfold_80, best_swnu_network_grp_kfold_80, _, __ = (
    swnu_network_hyperparameter_search(
        # --- data, labels and which rows / ids to use ---
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        path_indices=client_index,
        split_ids=client_transcript_id,
        # --- path construction ---
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        # --- model grid ---
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # --- training and validation ---
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        k_fold=True,
        patience=patience,
        validation_metric=validation_metric,
        # --- output ---
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_80_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_80_kfold_best_model.csv


In [12]:
# Average the numeric result columns over the repeated runs (seeds) of each
# hyperparameter configuration. `numeric_only=True` is requested explicitly:
# implicitly dropping non-numeric columns in `GroupBy.mean` is deprecated in
# pandas 1.5 and raises a TypeError from pandas 2.0.
swnu_network_grp_kfold_80.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_80.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.603321,0.481278,0.474794,0.492,0.682349,0.588798,0.580977,0.601596,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.610728,0.498871,0.49058,0.517422,0.673732,0.58273,0.573855,0.599987,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.61507,0.501945,0.49384,0.521522,0.666078,0.573661,0.5662,0.592945,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.609451,0.48222,0.476693,0.490717,0.685239,0.590269,0.583577,0.600894,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.60977,0.499958,0.491147,0.521193,0.67143,0.580846,0.571757,0.601226,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.616092,0.502482,0.494491,0.521764,0.664205,0.568649,0.561899,0.585776,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.612899,0.489557,0.482983,0.500594,0.679779,0.586137,0.578376,0.599931,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.608046,0.497007,0.488407,0.516223,0.669129,0.576239,0.567458,0.593379,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.626373,0.5121,0.504238,0.530966,0.666399,0.569357,0.563629,0.584749,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.607982,0.486714,0.479825,0.499264,0.681813,0.588589,0.580896,0.602835,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# NOTE(review): this cell repeats the aggregation of the preceding cell and
# could be removed.
# Average the numeric result columns over the repeated runs (seeds) of each
# hyperparameter configuration. `numeric_only=True` is requested explicitly:
# implicitly dropping non-numeric columns in `GroupBy.mean` is deprecated in
# pandas 1.5 and raises a TypeError from pandas 2.0.
swnu_network_grp_kfold_80.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_80.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.603321,0.481278,0.474794,0.492,0.682349,0.588798,0.580977,0.601596,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.610728,0.498871,0.49058,0.517422,0.673732,0.58273,0.573855,0.599987,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.61507,0.501945,0.49384,0.521522,0.666078,0.573661,0.5662,0.592945,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.609451,0.48222,0.476693,0.490717,0.685239,0.590269,0.583577,0.600894,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.60977,0.499958,0.491147,0.521193,0.67143,0.580846,0.571757,0.601226,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.616092,0.502482,0.494491,0.521764,0.664205,0.568649,0.561899,0.585776,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.612899,0.489557,0.482983,0.500594,0.679779,0.586137,0.578376,0.599931,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.608046,0.497007,0.488407,0.516223,0.669129,0.576239,0.567458,0.593379,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.626373,0.5121,0.504238,0.530966,0.666399,0.569357,0.563629,0.584749,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.607982,0.486714,0.479825,0.499264,0.681813,0.588589,0.580896,0.602835,80.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [14]:
# Best-model results table from the GRP search with w=80 (one row per seed).
best_swnu_network_grp_kfold_80

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.614368,0.484163,"[0.7479650732573628, 0.40032089851584435, 0.30...",0.479447,"[0.7659896938466202, 0.38953942232630756, 0.28...",0.490525,"[0.7307692307692307, 0.41171617161716173, 0.32...",,0.686737,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.608812,0.484133,"[0.7404546129846977, 0.39710028191703584, 0.31...",0.478256,"[0.7613809960281087, 0.3878835562549174, 0.285...",0.492774,"[0.7206477732793523, 0.4067656765676568, 0.350...",,0.68754,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.605172,0.478363,"[0.7408181546730367, 0.395131845841785, 0.2991...",0.472377,"[0.7657407407407407, 0.38866719872306466, 0.26...",0.488852,"[0.7174667437825333, 0.4018151815181518, 0.347...",,0.681439,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [15]:
# Mean of the `f1` column over the per-seed rows of the best GRP (w=80) model.
best_swnu_network_grp_kfold_80.loc[:, "f1"].mean()

0.4822195753822131

In [16]:
# Mean of the `precision` column over the per-seed rows of the best GRP (w=80) model.
best_swnu_network_grp_kfold_80.loc[:, "precision"].mean()

0.47669339864668175

In [17]:
# Mean of the `recall` column over the per-seed rows of the best GRP (w=80) model.
best_swnu_network_grp_kfold_80.loc[:, "recall"].mean()

0.49071705608964605

In [18]:
# Stack the per-row `f1_scores` vectors into a 2-D array and average over rows,
# giving one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_80["f1_scores"]), axis=0)

array([0.74307928, 0.39751768, 0.30606177])

In [19]:
# Stack the per-row `precision_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_80["precision_scores"]), axis=0)

array([0.76437048, 0.38869673, 0.27701299])

In [20]:
# Stack the per-row `recall_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_80["recall_scores"]), axis=0)

array([0.72296125, 0.40676568, 0.34242424])

# w=110

In [10]:
# History length for this section: passed as `history_lengths=[size]` to the
# searches below and embedded in their results CSV filenames.
size = 110

## UMAP

In [11]:
# Hyperparameter search for the SWNU network on UMAP-reduced embeddings with
# history length `size`. The first two return values (the full results table
# and the best-model table) are kept; the last two are discarded.
# Keyword arguments are grouped by apparent role, judged from their names only.
swnu_network_umap_kfold_110, best_swnu_network_umap_kfold_110, _, __ = (
    swnu_network_hyperparameter_search(
        # data, labels and row selection
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        path_indices=client_index,
        split_ids=client_transcript_id,
        # path construction
        history_lengths=[size],
        dim_reduce_methods=["umap"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        # model hyperparameter grid
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # training and evaluation
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        k_fold=True,
        patience=patience,
        validation_metric=validation_metric,
        # output
        results_output=f"{output_dir}/swnu_network_umap_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_110_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_umap_focal_2_110_kfold_best_model.csv


In [12]:
# Average every numeric metric over the rows that share a hyperparameter
# setting. `numeric_only=True` is passed explicitly because the results table
# appears to carry non-numeric columns too (e.g. the list-valued `f1_scores`
# and string `loss_function` columns seen in the best-model table). Relying on
# pandas to drop them silently emits a FutureWarning on pandas 1.5 and raises
# a TypeError from pandas 2.0 onwards.
swnu_network_umap_kfold_110.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_umap_kfold_110.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.611622,0.458722,0.462186,0.457373,0.667791,0.559401,0.557944,0.562353,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.620115,0.497577,0.49157,0.510138,0.682402,0.58734,0.579641,0.600913,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.627714,0.509374,0.502451,0.527453,0.673089,0.574862,0.568636,0.590026,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.616475,0.458668,0.463728,0.456256,0.673571,0.563054,0.562608,0.564045,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.610345,0.485802,0.480109,0.496973,0.68283,0.587274,0.579728,0.599704,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.616028,0.501352,0.494522,0.522308,0.674213,0.581246,0.573647,0.600261,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.619285,0.457244,0.463114,0.455991,0.672447,0.563747,0.561931,0.567775,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.617178,0.498676,0.492575,0.514447,0.676194,0.585441,0.577282,0.605066,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.618391,0.504427,0.497134,0.522544,0.66656,0.569044,0.562979,0.582788,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.614176,0.463065,0.466444,0.464862,0.670841,0.570035,0.564482,0.579735,110.0,4.0,...,384.0,0.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [13]:
# Display the best-model table returned by the UMAP search; the rendered
# output has one row per seed.
best_swnu_network_umap_kfold_110

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.601724,0.481847,"[0.7352006056018168, 0.4131227217496963, 0.297...",0.475344,"[0.7715284397839212, 0.40572792362768495, 0.24...",0.497341,"[0.70213996529786, 0.4207920792079208, 0.36909...",,0.694766,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.615326,0.502466,"[0.7475360121304019, 0.4117887842816209, 0.348...",0.493593,"[0.7857825948358305, 0.40861088545897645, 0.28...",0.523831,"[0.7128397917871602, 0.415016501650165, 0.4436...",,0.686577,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.61954,0.497559,"[0.7500742942050521, 0.4087530966143683, 0.333...",0.490604,"[0.7713936430317848, 0.4090909090909091, 0.291...",0.509742,"[0.7299016772700984, 0.4084158415841584, 0.390...",,0.682081,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [14]:
# Mean of the scalar `f1` column over the rows of the best-model table.
best_swnu_network_umap_kfold_110.loc[:, "f1"].mean()

0.4939576423556507

In [15]:
# Mean of the scalar `precision` column over the rows of the best-model table.
best_swnu_network_umap_kfold_110.loc[:, "precision"].mean()

0.4865135328264425

In [16]:
# Mean of the scalar `recall` column over the rows of the best-model table.
best_swnu_network_umap_kfold_110.loc[:, "recall"].mean()

0.5103046911593029

In [17]:
# Stack the per-row `f1_scores` vectors into a 2-D array and average over rows,
# giving one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_umap_kfold_110["f1_scores"]), axis=0)

array([0.7442703 , 0.41122153, 0.32638109])

In [18]:
# Stack the per-row `precision_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_umap_kfold_110["precision_scores"]), axis=0)

array([0.77623489, 0.40780991, 0.2754958 ])

In [19]:
# Stack the per-row `recall_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_umap_kfold_110["recall_scores"]), axis=0)

array([0.71496048, 0.41474147, 0.40121212])

## GRP

In [None]:
# Hyperparameter search for the SWNU network on embeddings reduced with
# Gaussian random projection, with history length `size`. The first two return
# values are kept (used below as the full results table and the best-model
# table); the last two are discarded.
# Keyword arguments are grouped by apparent role, judged from their names only.
swnu_network_grp_kfold_110, best_swnu_network_grp_kfold_110, _, __ = (
    swnu_network_hyperparameter_search(
        # data, labels and row selection
        df=anno_mi,
        id_column="transcript_id",
        label_column="client_talk_type",
        embeddings=sbert_embeddings,
        y_data=y_data_client,
        output_dim=output_dim_client,
        path_indices=client_index,
        split_ids=client_transcript_id,
        # path construction
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        dimensions=dimensions,
        log_signature=True,
        features=features,
        standardise_method=standardise_method,
        include_features_in_path=include_features_in_path,
        include_features_in_input=include_features_in_input,
        # model hyperparameter grid
        swnu_hidden_dim_sizes_and_sig_depths=swnu_hidden_dim_sizes_and_sig_depths,
        ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
        dropout_rates=dropout_rates,
        learning_rates=learning_rates,
        BiLSTM=True,
        # training and evaluation
        num_epochs=num_epochs,
        seeds=seeds,
        loss=loss,
        gamma=gamma,
        device=device,
        k_fold=True,
        patience=patience,
        validation_metric=validation_metric,
        # output
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        verbose=False,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Average every numeric metric over the rows that share a hyperparameter
# setting. `numeric_only=True` is passed explicitly so that non-numeric result
# columns (e.g. list-valued score columns or string settings) are excluded
# deliberately. Relying on pandas to drop them silently emits a FutureWarning
# on pandas 1.5 and raises a TypeError from pandas 2.0 onwards.
swnu_network_grp_kfold_110.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# NOTE(review): this cell repeats the groupby in the cell directly above;
# consider deleting one of the two.
# Average every numeric metric over the rows that share a hyperparameter
# setting. `numeric_only=True` is passed explicitly so that non-numeric result
# columns are excluded deliberately. Relying on pandas to drop them silently
# emits a FutureWarning on pandas 1.5 and raises a TypeError from pandas 2.0
# onwards.
swnu_network_grp_kfold_110.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# Display the best-model table returned by the Gaussian-random-projection search.
best_swnu_network_grp_kfold_110

In [None]:
# Mean of the scalar `f1` column over the rows of the best-model table.
best_swnu_network_grp_kfold_110.loc[:, "f1"].mean()

In [None]:
# Mean of the scalar `precision` column over the rows of the best-model table.
best_swnu_network_grp_kfold_110.loc[:, "precision"].mean()

In [None]:
# Mean of the scalar `recall` column over the rows of the best-model table.
best_swnu_network_grp_kfold_110.loc[:, "recall"].mean()

In [None]:
# Stack the per-row `f1_scores` vectors into a 2-D array and average over rows,
# giving one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_110["f1_scores"]), axis=0)

In [None]:
# Stack the per-row `precision_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_110["precision_scores"]), axis=0)

In [None]:
# Stack the per-row `recall_scores` vectors and average over rows, giving
# one mean value per vector position (presumably one per class).
np.mean(np.stack(best_swnu_network_grp_kfold_110["recall_scores"]), axis=0)