In [1]:
import numpy as np
import pickle
import os
from tqdm.notebook import tqdm

# NOTE(review): `seed` is not referenced in the visible cells — the searches
# below receive the `seeds` list instead. Confirm it is used elsewhere before
# keeping it.
seed = 2023

In [2]:
import torch

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cuda')

In [3]:
from nlpsig_networks.scripts.swnu_network_functions import (
    swnu_network_hyperparameter_search,
)

In [4]:
# Directory that receives the CSV results written by the hyperparameter
# searches below (it is interpolated into each `results_output` path).
output_dir = "client_talk_type_output"
# `exist_ok=True` makes the cell idempotent on re-run and avoids the
# check-then-create race of a separate `os.path.isdir` test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [5]:
# Execute the shared AnnoMI loading script; `%run` makes the script's
# top-level names available in this notebook's namespace.
# NOTE(review): `anno_mi`, `y_data_client`, `output_dim_client`, `client_index`
# and `client_transcript_id` are used below but never assigned in the visible
# cells — presumably they are defined by this script; confirm.
%run ../load_anno_mi.py

In [6]:
# Preview the first rows of `anno_mi` as a sanity check on the loaded data.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,speaker
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-11-02 00:00:13,-1
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-11-02 00:00:24,1
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-11-02 00:00:25,-1
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-11-02 00:00:34,1
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-11-02 00:00:34,-1


In [7]:
# Load the precomputed embeddings stored in ../anno_mi_sbert.pkl.
# NOTE: `pickle.load` can execute arbitrary code while unpickling — only load
# this file if it comes from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# Display the array shape as a quick check (presumably one row per utterance
# in `anno_mi` — confirm).
sbert_embeddings.shape

(9699, 384)

# SWNU Network

## Obtaining path by looking at post history

We can obtain a path by looking at the history of each post. Here we look at the last `w` posts, including the current post (padding with vectors of zeros if there are fewer than `w` posts); the window size `w` is set separately in each section below (`w=5`, `w=11`, `w=20`, `w=35`).

We only want to consider paths that correspond to a client's utterance as we want to model a change in mood at that time. Their history will still contain the therapist's utterances too.

In [8]:
# Extra features forwarded to the search through `kwargs`.
# NOTE(review): `features` and `standardise_method` look positionally paired
# ("time_encoding" -> "z_score", "timeline_index" -> None, i.e. presumably no
# standardisation) — confirm against the nlpsig documentation.
features = ["time_encoding", "timeline_index"]
standardise_method = ["z_score", None]
# Flags forwarded under the same names; presumably they control whether the
# features are added to the path and/or to the network input — confirm.
include_features_in_path = True
include_features_in_input = True

In [9]:
# Search grid and training settings; every name here is forwarded to
# `swnu_network_hyperparameter_search` through the `kwargs` dict built below.
num_epochs = 100
dimensions = [15]  # presumably the reduced embedding dimension(s) to try
# Each entry is a (hidden-dim sizes, signature depth) pair.
swnu_hidden_dim_sizes_and_sig_depths = [([12], 3), ([10], 4)]
ffn_hidden_dim_sizes = [[256, 256], [512, 512]]
dropout_rates = [0.1, 0.2]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]
loss = "focal"
gamma = 2  # also interpolated into the results CSV filenames below
validation_metric = "f1"
patience = 5  # presumably early-stopping patience in epochs — confirm

In [10]:
# Keyword arguments shared by every `swnu_network_hyperparameter_search` call
# below; each call adds only `history_lengths`, `dim_reduce_methods` and
# `results_output`.
# NOTE(review): `anno_mi`, `y_data_client`, `output_dim_client`, `client_index`
# and `client_transcript_id` are not assigned in the visible cells — presumably
# they come from `%run ../load_anno_mi.py`; confirm.
kwargs = {
    "num_epochs": num_epochs,
    "df": anno_mi,
    "id_column": "transcript_id",
    "label_column": "client_talk_type",
    "embeddings": sbert_embeddings,
    "y_data": y_data_client,
    "output_dim": output_dim_client,
    "dimensions": dimensions,
    "log_signature": True,
    "pooling": "signature",
    "swnu_hidden_dim_sizes_and_sig_depths": swnu_hidden_dim_sizes_and_sig_depths,
    "ffn_hidden_dim_sizes": ffn_hidden_dim_sizes,
    "dropout_rates": dropout_rates,
    "learning_rates": learning_rates,
    "BiLSTM": True,
    "seeds": seeds,
    "loss": loss,
    "gamma": gamma,
    "device": device,
    "features": features,
    "standardise_method": standardise_method,
    "include_features_in_path": include_features_in_path,
    "include_features_in_input": include_features_in_input,
    "path_indices": client_index,
    "split_ids": client_transcript_id,
    "k_fold": True,
    "patience": patience,
    "validation_metric": validation_metric,
    "verbose": False,
}

# w=5

In [11]:
# History window length for this section; passed as `history_lengths=[size]`
# and interpolated into the results CSV filename in the search cell below.
size = 5

## GRP

In [12]:
# Hyperparameter search for the current `size`, using Gaussian random
# projection for dimension reduction. The first two return values (the full
# results frame and the best-model results frame) are kept; the last two are
# discarded. Results are also written to the CSV path in `results_output`.
# NOTE: binding `_` / `__` shadows IPython's output-history shortcuts.
swnu_network_grp_kfold_5, best_swnu_network_grp_kfold_5, _, __ = (
    swnu_network_hyperparameter_search(
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        **kwargs,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_5_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_5_kfold_best_model.csv


In [13]:
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_5.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_5.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.642593,0.517939,0.512941,0.52584,0.669878,0.560531,0.557392,0.564785,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.616475,0.505092,0.495921,0.524211,0.661314,0.563003,0.554992,0.57474,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.637739,0.508486,0.507559,0.516383,0.662385,0.549764,0.548589,0.553927,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.639336,0.519453,0.513397,0.530996,0.668486,0.562807,0.558131,0.569886,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.630715,0.517751,0.509796,0.533069,0.658906,0.556884,0.551283,0.566271,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.619796,0.498481,0.493499,0.508683,0.661796,0.553292,0.550249,0.560066,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.63825,0.517278,0.51159,0.525571,0.671591,0.564863,0.561295,0.569275,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.620434,0.50369,0.497558,0.515885,0.663081,0.560578,0.555221,0.569755,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.619221,0.495556,0.492274,0.501213,0.660833,0.548433,0.546763,0.551875,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.644381,0.515087,0.511767,0.520157,0.669664,0.555656,0.554508,0.557307,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [14]:
# NOTE(review): this cell repeats the previous cell verbatim; consider
# removing one of the two.
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_5.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_5.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.642593,0.517939,0.512941,0.52584,0.669878,0.560531,0.557392,0.564785,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.616475,0.505092,0.495921,0.524211,0.661314,0.563003,0.554992,0.57474,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.637739,0.508486,0.507559,0.516383,0.662385,0.549764,0.548589,0.553927,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.639336,0.519453,0.513397,0.530996,0.668486,0.562807,0.558131,0.569886,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.630715,0.517751,0.509796,0.533069,0.658906,0.556884,0.551283,0.566271,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.619796,0.498481,0.493499,0.508683,0.661796,0.553292,0.550249,0.560066,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.63825,0.517278,0.51159,0.525571,0.671591,0.564863,0.561295,0.569275,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.620434,0.50369,0.497558,0.515885,0.663081,0.560578,0.555221,0.569755,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.619221,0.495556,0.492274,0.501213,0.660833,0.548433,0.546763,0.551875,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.644381,0.515087,0.511767,0.520157,0.669664,0.555656,0.554508,0.557307,5.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [15]:
# Display the best-model results frame (second value returned by the w=5 search).
best_swnu_network_grp_kfold_5

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.631609,0.517952,"[0.7513562386980108, 0.45882803523554194, 0.34...",0.510479,"[0.7844556324732537, 0.4281629735525375, 0.318...",0.529296,"[0.7209369577790631, 0.49422442244224424, 0.37...",,0.67325,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.637165,0.510735,"[0.7649904706054831, 0.4227234753550543, 0.344...",0.505492,"[0.7757954207552781, 0.428087986463621, 0.3125...",0.518537,"[0.7544823597455177, 0.4174917491749175, 0.383...",,0.669075,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.645977,0.523149,"[0.7655601659751037, 0.46760343481654953, 0.33...",0.518799,"[0.7851063829787234, 0.4437037037037037, 0.327...",0.528881,"[0.7469635627530364, 0.49422442244224424, 0.34...",,0.672447,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


in path (with speaker): 0.5143628505176907

in path (no speaker): 0.522911065406166

in input (no speaker): 0.5125295634373418

both (no speaker): 0.5151677550731781

(w/ time_encoding) in path (with speaker): 0.5186878521492474

(w/ time_encoding) in path (no speaker): 0.5191477428145815

(w/ time_encoding) in input (no speaker): 0.5126337587709865

(w/ time_encoding) both (no speaker):  0.5197361095991764

In [16]:
# Mean of the `f1` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_5["f1"].mean()

0.5172784687824037

In [17]:
# Mean of the `precision` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_5["precision"].mean()

0.5115898822057466

In [18]:
# Mean of the `recall` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_5["recall"].mean()

0.5255712951283561

In [19]:
# Element-wise mean of the `f1_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_5["f1_scores"]), axis=0)

array([0.76063563, 0.44971832, 0.34148147])

In [20]:
# Element-wise mean of the `precision_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_5["precision_scores"]), axis=0)

array([0.78178581, 0.43331822, 0.31966561])

In [21]:
# Element-wise mean of the `recall_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_5["recall_scores"]), axis=0)

array([0.74079429, 0.46864686, 0.36727273])

# w=11

In [22]:
# History window length for this section; passed as `history_lengths=[size]`
# and interpolated into the results CSV filename in the search cell below.
size = 11

## GRP

In [23]:
# Hyperparameter search for the current `size`, using Gaussian random
# projection for dimension reduction. The first two return values (the full
# results frame and the best-model results frame) are kept; the last two are
# discarded. Results are also written to the CSV path in `results_output`.
# NOTE: binding `_` / `__` shadows IPython's output-history shortcuts.
swnu_network_grp_kfold_11, best_swnu_network_grp_kfold_11, _, __ = (
    swnu_network_hyperparameter_search(
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        **kwargs,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_11_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_11_kfold_best_model.csv


In [None]:
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_11.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_11.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.63742,0.512353,0.50842,0.518756,0.665115,0.548997,0.54814,0.550756,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.619923,0.502744,0.495789,0.517242,0.657033,0.550762,0.545835,0.559927,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.628225,0.503931,0.50157,0.513363,0.655159,0.539334,0.540025,0.542939,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.641507,0.515027,0.511085,0.521455,0.663937,0.544758,0.543959,0.546407,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.610281,0.492446,0.485544,0.507381,0.655695,0.552948,0.545999,0.563295,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.618902,0.494763,0.491671,0.502298,0.656765,0.546311,0.545068,0.551287,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.636207,0.513508,0.506957,0.524587,0.662706,0.551843,0.548269,0.55661,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.627203,0.50848,0.502378,0.521326,0.656872,0.552282,0.547282,0.560852,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.622095,0.499258,0.494313,0.509312,0.652965,0.541744,0.538616,0.546918,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.638378,0.49983,0.499416,0.502763,0.669557,0.550157,0.551088,0.550363,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# NOTE(review): this cell repeats the previous cell verbatim; consider
# removing one of the two.
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_11.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_11.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.63742,0.512353,0.50842,0.518756,0.665115,0.548997,0.54814,0.550756,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.619923,0.502744,0.495789,0.517242,0.657033,0.550762,0.545835,0.559927,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.628225,0.503931,0.50157,0.513363,0.655159,0.539334,0.540025,0.542939,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.641507,0.515027,0.511085,0.521455,0.663937,0.544758,0.543959,0.546407,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.610281,0.492446,0.485544,0.507381,0.655695,0.552948,0.545999,0.563295,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.618902,0.494763,0.491671,0.502298,0.656765,0.546311,0.545068,0.551287,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.636207,0.513508,0.506957,0.524587,0.662706,0.551843,0.548269,0.55661,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.627203,0.50848,0.502378,0.521326,0.656872,0.552282,0.547282,0.560852,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.622095,0.499258,0.494313,0.509312,0.652965,0.541744,0.538616,0.546918,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.638378,0.49983,0.499416,0.502763,0.669557,0.550157,0.551088,0.550363,11.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# Display the best-model results frame (second value returned by the w=11 search).
best_swnu_network_grp_kfold_11

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.604789,0.489501,"[0.7350375019133628, 0.42152466367713004, 0.31...",0.482666,"[0.7808130081300813, 0.38524590163934425, 0.28...",0.502923,"[0.694331983805668, 0.46534653465346537, 0.349...",,0.650771,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.618199,0.495678,"[0.7487073422957601, 0.39537763103590584, 0.34...",0.489153,"[0.7653276955602537, 0.3955408753096614, 0.306...",0.5057,"[0.7327935222672065, 0.3952145214521452, 0.389...",,0.662331,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.607854,0.49216,"[0.7438041660331459, 0.40560593569661996, 0.32...",0.484812,"[0.7842257133696697, 0.4052718286655684, 0.264...",0.51352,"[0.7073452862926547, 0.40594059405940597, 0.42...",,0.653982,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [None]:
# Mean of the `f1` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_11["f1"].mean()

0.49244641720932986

In [None]:
# Mean of the `precision` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_11["precision"].mean()

0.48554351108678334

In [None]:
# Mean of the `recall` column across the rows of the best-model frame.
best_swnu_network_grp_kfold_11["recall"].mean()

0.5073807764427879

In [None]:
# Element-wise mean of the `f1_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_11["f1_scores"]), axis=0)

array([0.74251634, 0.40750274, 0.32732017])

In [None]:
# Element-wise mean of the `precision_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_11["precision_scores"]), axis=0)

array([0.77678881, 0.39535287, 0.28448886])

In [None]:
# Element-wise mean of the `recall_scores` vectors across the best-model rows.
np.mean(np.stack(best_swnu_network_grp_kfold_11["recall_scores"]), axis=0)

array([0.71149026, 0.42216722, 0.38848485])

# w=20

In [None]:
# History window length for this section; passed as `history_lengths=[size]`
# and interpolated into the results CSV filename in the search cell below.
size = 20

## GRP

In [None]:
# Hyperparameter search for the current `size`, using Gaussian random
# projection for dimension reduction. The first two return values (the full
# results frame and the best-model results frame) are kept; the last two are
# discarded. Results are also written to the CSV path in `results_output`.
# NOTE: binding `_` / `__` shadows IPython's output-history shortcuts.
swnu_network_grp_kfold_20, best_swnu_network_grp_kfold_20, _, __ = (
    swnu_network_hyperparameter_search(
        history_lengths=[size],
        dim_reduce_methods=["gaussian_random_projection"],
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        **kwargs,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_20_kfold_best_model.csv


In [None]:
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_20.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_20.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.636782,0.504137,0.505088,0.507193,0.665329,0.54937,0.552315,0.54906,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.621775,0.510201,0.504318,0.524393,0.657086,0.55501,0.550891,0.563368,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.617816,0.494936,0.489517,0.506276,0.662813,0.56019,0.554493,0.569537,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.627714,0.494296,0.497169,0.496744,0.662867,0.545428,0.54945,0.545085,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.62848,0.504217,0.499629,0.511051,0.665007,0.562756,0.558817,0.568243,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.607982,0.495091,0.488045,0.512289,0.659174,0.5647,0.557056,0.579909,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.6341,0.501079,0.497555,0.506311,0.66581,0.55212,0.551477,0.553499,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.619349,0.507674,0.500819,0.521912,0.657996,0.557318,0.55194,0.567408,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.628161,0.497982,0.495578,0.500979,0.667416,0.55612,0.554743,0.558369,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.643423,0.505739,0.506331,0.508003,0.663348,0.542448,0.546129,0.541131,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# NOTE(review): this cell repeats the previous cell verbatim; consider
# removing one of the two.
# Average the numeric result columns over all runs that share a hyperparameter
# configuration. `numeric_only=True` is passed explicitly: the frame also
# holds non-numeric columns, and relying on the default emits a FutureWarning
# on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_20.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_20.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.636782,0.504137,0.505088,0.507193,0.665329,0.54937,0.552315,0.54906,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.621775,0.510201,0.504318,0.524393,0.657086,0.55501,0.550891,0.563368,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.617816,0.494936,0.489517,0.506276,0.662813,0.56019,0.554493,0.569537,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.627714,0.494296,0.497169,0.496744,0.662867,0.545428,0.54945,0.545085,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.62848,0.504217,0.499629,0.511051,0.665007,0.562756,0.558817,0.568243,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.607982,0.495091,0.488045,0.512289,0.659174,0.5647,0.557056,0.579909,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.6341,0.501079,0.497555,0.506311,0.66581,0.55212,0.551477,0.553499,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.619349,0.507674,0.500819,0.521912,0.657996,0.557318,0.55194,0.567408,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.628161,0.497982,0.495578,0.500979,0.667416,0.55612,0.554743,0.558369,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.643423,0.505739,0.506331,0.508003,0.663348,0.542448,0.546129,0.541131,20.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# Display the best-model results frame for history length 20
# (presumably produced by the w=20 hyperparameter search earlier in the notebook).
best_swnu_network_grp_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.604789,0.485531,"[0.7364599092284418, 0.40747806180846996, 0.31...",0.479333,"[0.7722081218274112, 0.37899219304471254, 0.28...",0.496035,"[0.7038750722961249, 0.4405940594059406, 0.343...",,0.663134,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.614176,0.496981,"[0.7414823670053796, 0.41017488076311603, 0.33...",0.489773,"[0.7671614100185529, 0.39570552147239263, 0.30...",0.507736,"[0.7174667437825333, 0.42574257425742573, 0.38]",,0.662492,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.627778,0.513566,"[0.7476918419857727, 0.45454545454545453, 0.33...",0.510095,"[0.7843759923785328, 0.40897097625329815, 0.33...",0.521946,"[0.7142857142857143, 0.5115511551155115, 0.34]",,0.663937,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [None]:
# Average the `f1` column over the rows of the w=20 best-model results.
best_swnu_network_grp_kfold_20.loc[:, "f1"].mean()

0.49869276165916915

In [None]:
# Average the `precision` column over the rows of the w=20 best-model results.
best_swnu_network_grp_kfold_20.loc[:, "precision"].mean()

0.4930667715438048

In [None]:
# Average the `recall` column over the rows of the w=20 best-model results.
best_swnu_network_grp_kfold_20.loc[:, "recall"].mean()

0.5085724091977348

In [None]:
# Stack the per-row `f1_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_20.loc[:, "f1_scores"]), axis=0)

array([0.74187804, 0.42406613, 0.33013411])

In [None]:
# Stack the per-row `precision_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_20.loc[:, "precision_scores"]), axis=0)

array([0.77458184, 0.39455623, 0.31006224])

In [None]:
# Stack the per-row `recall_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_20.loc[:, "recall_scores"]), axis=0)

array([0.71187584, 0.45929593, 0.35454545])

# w=35

In [None]:
# History length for this section. The search cell below reads this global:
# it is passed as `history_lengths=[size]` and interpolated into the results CSV name.
size = 35

## GRP

In [None]:
# Run the SWNU hyperparameter search for the current `size` using Gaussian
# random projection as the only dimension-reduction method. The first two
# return values are kept (all results, best-model results); the remaining two
# are unused and bound to the throwaway names `_` and `__`.
# `kwargs`, `gamma` and `output_dir` come from earlier cells.
swnu_network_grp_kfold_35, best_swnu_network_grp_kfold_35, _, __ = (
    swnu_network_hyperparameter_search(
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        dim_reduce_methods=["gaussian_random_projection"],
        history_lengths=[size],
        **kwargs,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_35_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/swnu_network_grp_focal_2_35_kfold_best_model.csv


In [None]:
# Average every numeric column over the rows that share one hyperparameter
# configuration (the five grouping keys below).
# `numeric_only=True` is explicit on purpose: the results frame also holds
# non-numeric columns, and relying on the implicit default emits a
# FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_35.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_35.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.637867,0.494666,0.495343,0.494772,0.674641,0.563686,0.563922,0.563609,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.624904,0.501972,0.497012,0.513891,0.66656,0.568308,0.561949,0.580366,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.61705,0.499904,0.493868,0.512743,0.666667,0.571172,0.564378,0.581939,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.62401,0.482987,0.48146,0.485027,0.67127,0.558963,0.559136,0.559257,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.62037,0.49422,0.489289,0.505946,0.672072,0.573535,0.568259,0.582405,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.608876,0.48296,0.478771,0.491663,0.666881,0.565122,0.560047,0.572791,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.624777,0.492122,0.489063,0.496526,0.667951,0.561482,0.558399,0.565314,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.618072,0.493868,0.487605,0.504164,0.665971,0.565311,0.558906,0.574502,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.615517,0.482851,0.480007,0.488973,0.664954,0.558742,0.555127,0.564621,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.626692,0.488337,0.486495,0.490732,0.669129,0.560318,0.559479,0.561783,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# NOTE(review): this cell repeats the aggregation of the previous cell verbatim;
# consider removing one of the two.
# Average every numeric column per hyperparameter configuration.
# `numeric_only=True` is explicit on purpose: the results frame also holds
# non-numeric columns, and relying on the implicit default emits a
# FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0.
swnu_network_grp_kfold_35.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  swnu_network_grp_kfold_35.groupby(


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,sig_depth,...,embedding_dim,num_features,log_signature,seed,BiLSTM,gamma,k_fold,n_splits,batch_size,model_id
dimensions,swnu_hidden_dim,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
15,"(10,)","(256, 256)",0.1,0.0001,0.637867,0.494666,0.495343,0.494772,0.674641,0.563686,0.563922,0.563609,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,13.0
15,"(10,)","(256, 256)",0.1,0.0005,0.624904,0.501972,0.497012,0.513891,0.66656,0.568308,0.561949,0.580366,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,14.0
15,"(10,)","(256, 256)",0.1,0.001,0.61705,0.499904,0.493868,0.512743,0.666667,0.571172,0.564378,0.581939,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,12.0
15,"(10,)","(256, 256)",0.2,0.0001,0.62401,0.482987,0.48146,0.485027,0.67127,0.558963,0.559136,0.559257,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,16.0
15,"(10,)","(256, 256)",0.2,0.0005,0.62037,0.49422,0.489289,0.505946,0.672072,0.573535,0.568259,0.582405,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,17.0
15,"(10,)","(256, 256)",0.2,0.001,0.608876,0.48296,0.478771,0.491663,0.666881,0.565122,0.560047,0.572791,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,15.0
15,"(10,)","(512, 512)",0.1,0.0001,0.624777,0.492122,0.489063,0.496526,0.667951,0.561482,0.558399,0.565314,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,19.0
15,"(10,)","(512, 512)",0.1,0.0005,0.618072,0.493868,0.487605,0.504164,0.665971,0.565311,0.558906,0.574502,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,20.0
15,"(10,)","(512, 512)",0.1,0.001,0.615517,0.482851,0.480007,0.488973,0.664954,0.558742,0.555127,0.564621,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,18.0
15,"(10,)","(512, 512)",0.2,0.0001,0.626692,0.488337,0.486495,0.490732,0.669129,0.560318,0.559479,0.561783,35.0,4.0,...,384.0,2.0,1.0,45.333333,1.0,2.0,1.0,5.0,64.0,22.0


In [None]:
# Display the best-model results frame (second return value of the w=35 search).
best_swnu_network_grp_kfold_35

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,seed,BiLSTM,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.629502,0.480773,"[0.7627361190612477, 0.3998261625380269, 0.279...",0.48175,"[0.7549575070821529, 0.42240587695133147, 0.26...",0.480981,"[0.7706766917293233, 0.3795379537953795, 0.292...",,0.680315,...,1,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.601341,0.475843,"[0.7312117137307635, 0.4026123703419132, 0.293...",0.471983,"[0.7564142194744977, 0.3767074047447879, 0.282...",0.481811,"[0.7076344707923655, 0.43234323432343236, 0.30...",,0.672447,...,12,True,focal,2,True,5,Conv1d,,concatenation,64
0,,0.615709,0.487494,"[0.7492204899777282, 0.4083640836408364, 0.304...",0.481778,"[0.7699115044247787, 0.4058679706601467, 0.269...",0.497138,"[0.7296124927703875, 0.41089108910891087, 0.35...",,0.67341,...,123,True,focal,2,True,5,Conv1d,,concatenation,64


In [None]:
# Average the `f1` column over the rows of the w=35 best-model results.
best_swnu_network_grp_kfold_35.loc[:, "f1"].mean()

0.4813701422942455

In [None]:
# Average the `precision` column over the rows of the w=35 best-model results.
best_swnu_network_grp_kfold_35.loc[:, "precision"].mean()

0.47850363267032686

In [None]:
# Average the `recall` column over the rows of the w=35 best-model results.
best_swnu_network_grp_kfold_35.loc[:, "recall"].mean()

0.48664298240118975

In [None]:
# Stack the per-row `f1_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_35.loc[:, "f1_scores"]), axis=0)

array([0.74772277, 0.40360087, 0.29278678])

In [None]:
# Stack the per-row `precision_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_35.loc[:, "precision_scores"]), axis=0)

array([0.76042774, 0.40166042, 0.27342274])

# w=80

In [None]:
# History length for this section. The search cell below reads this global:
# it is passed as `history_lengths=[size]` and interpolated into the results CSV name.
size = 80

## GRP

In [None]:
# Run the SWNU hyperparameter search for the current `size` using Gaussian
# random projection as the only dimension-reduction method. The first two
# return values are kept (all results, best-model results); the remaining two
# are unused and bound to the throwaway names `_` and `__`.
# `kwargs`, `gamma` and `output_dir` come from earlier cells.
swnu_network_grp_kfold_80, best_swnu_network_grp_kfold_80, _, __ = (
    swnu_network_hyperparameter_search(
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        dim_reduce_methods=["gaussian_random_projection"],
        history_lengths=[size],
        **kwargs,
    )
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Average every numeric column over the rows that share one hyperparameter
# configuration (the five grouping keys below).
# `numeric_only=True` is explicit on purpose: relying on the implicit default
# emits a FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0
# when the frame contains non-numeric columns.
swnu_network_grp_kfold_80.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# NOTE(review): this cell repeats the aggregation of the previous cell verbatim;
# consider removing one of the two.
# Average every numeric column per hyperparameter configuration.
# `numeric_only=True` is explicit on purpose: relying on the implicit default
# emits a FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0
# when the frame contains non-numeric columns.
swnu_network_grp_kfold_80.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# Display the best-model results frame (second return value of the w=80 search).
best_swnu_network_grp_kfold_80

In [None]:
# Average the `f1` column over the rows of the w=80 best-model results.
best_swnu_network_grp_kfold_80.loc[:, "f1"].mean()

In [None]:
# Average the `precision` column over the rows of the w=80 best-model results.
best_swnu_network_grp_kfold_80.loc[:, "precision"].mean()

In [None]:
# Average the `recall` column over the rows of the w=80 best-model results.
best_swnu_network_grp_kfold_80.loc[:, "recall"].mean()

In [None]:
# Stack the per-row `f1_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_80.loc[:, "f1_scores"]), axis=0)

In [None]:
# Stack the per-row `precision_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_80.loc[:, "precision_scores"]), axis=0)

In [None]:
# Stack the per-row `recall_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_80.loc[:, "recall_scores"]), axis=0)

# w=110

In [None]:
# History length for this section. The search cell below reads this global:
# it is passed as `history_lengths=[size]` and interpolated into the results CSV name.
size = 110

## GRP

In [None]:
# Run the SWNU hyperparameter search for the current `size` using Gaussian
# random projection as the only dimension-reduction method. The first two
# return values are kept (all results, best-model results); the remaining two
# are unused and bound to the throwaway names `_` and `__`.
# `kwargs`, `gamma` and `output_dir` come from earlier cells.
swnu_network_grp_kfold_110, best_swnu_network_grp_kfold_110, _, __ = (
    swnu_network_hyperparameter_search(
        results_output=f"{output_dir}/swnu_network_grp_focal_{gamma}_{size}_kfold.csv",
        dim_reduce_methods=["gaussian_random_projection"],
        history_lengths=[size],
        **kwargs,
    )
)

In [None]:
# Average every numeric column over the rows that share one hyperparameter
# configuration (the five grouping keys below).
# `numeric_only=True` is explicit on purpose: relying on the implicit default
# emits a FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0
# when the frame contains non-numeric columns.
swnu_network_grp_kfold_110.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# NOTE(review): this cell repeats the aggregation of the previous cell verbatim;
# consider removing one of the two.
# Average every numeric column per hyperparameter configuration.
# `numeric_only=True` is explicit on purpose: relying on the implicit default
# emits a FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0
# when the frame contains non-numeric columns.
swnu_network_grp_kfold_110.groupby(
    ["dimensions", "swnu_hidden_dim", "ffn_hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

In [None]:
# Display the best-model results frame (second return value of the w=110 search).
best_swnu_network_grp_kfold_110

In [None]:
# Average the `f1` column over the rows of the w=110 best-model results.
best_swnu_network_grp_kfold_110.loc[:, "f1"].mean()

In [None]:
# Average the `precision` column over the rows of the w=110 best-model results.
best_swnu_network_grp_kfold_110.loc[:, "precision"].mean()

In [None]:
# Average the `recall` column over the rows of the w=110 best-model results.
best_swnu_network_grp_kfold_110.loc[:, "recall"].mean()

In [None]:
# Stack the per-row `f1_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_110.loc[:, "f1_scores"]), axis=0)

In [None]:
# Stack the per-row `precision_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_110.loc[:, "precision_scores"]), axis=0)

In [None]:
# Stack the per-row `recall_scores` vectors into one 2-D array and average
# element-wise over the rows (axis 0).
np.mean(np.stack(best_swnu_network_grp_kfold_110.loc[:, "recall_scores"]), axis=0)