In [1]:
import pandas as pd
import numpy as np
import torch
import os
import evaluate

import nlpsig

from nlpsig.classification_utils import DataSplits

from sklearn import metrics

from tqdm.notebook import tqdm

# Global seed constant.
# NOTE(review): nothing in the cells visible here reads `seed`; the hyperparameter
# searches are given the `seeds` list instead - confirm whether this is still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
    histories_baseline_hyperparameter_search
)
from nlpsig_networks.scripts.swnu_network_functions import (
    swnu_network_hyperparameter_search
)

In [3]:
import signatory

In [4]:
# Directory that receives the cached embeddings and the hyperparameter-search CSVs.
output_dir = "client_talk_type"
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race
# of a separate os.path.isdir() test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [5]:
# Load the AnnoMI utterance table, parse the `timestamp` strings into a `datetime`
# column and drop the `video_title` / `video_url` columns.
# NOTE(review): the timestamps shown below are time-of-day only ("00:00:13") and the
# parsed values carry a calendar date (2023-06-19 in the saved output), presumably the
# day the cell was run - confirm that nothing downstream depends on the absolute date.
anno_mi = (
    pd.read_csv("AnnoMI-full.csv")
    .assign(datetime=lambda frame: pd.to_datetime(frame["timestamp"]))
    .drop(columns=["video_title", "video_url"])
)
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-06-19 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-06-19 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-06-19 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-06-19 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-06-19 00:00:34


In [6]:
# Number of utterances per interlocutor (therapist / client).
anno_mi["interlocutor"].value_counts()

therapist    6826
client       6725
Name: interlocutor, dtype: int64

In [7]:
# Share of therapist utterances carrying each `main_therapist_behaviour` label:
# per-label counts divided by the total number of therapist utterances.
anno_mi["main_therapist_behaviour"].value_counts().div(
    anno_mi["interlocutor"].value_counts().loc["therapist"]
)

other              0.313947
question           0.286258
reflection         0.251538
therapist_input    0.148257
Name: main_therapist_behaviour, dtype: float64

In [8]:
# Share of client utterances carrying each `client_talk_type` label:
# per-label counts divided by the total number of client utterances.
anno_mi["client_talk_type"].value_counts().div(
    anno_mi["interlocutor"].value_counts().loc["client"]
)

neutral    0.627063
change     0.248030
sustain    0.124907
Name: client_talk_type, dtype: float64

In [9]:
# Number of utterances per interlocutor.
# NOTE(review): identical to the earlier `interlocutor` value_counts cell; one of the
# two could be removed.
anno_mi["interlocutor"].value_counts()

therapist    6826
client       6725
Name: interlocutor, dtype: int64

In [10]:
# Number of utterances per conversation topic.
# NOTE(review): the saved output lists "smoking cessation" twice, which suggests two
# distinct raw strings (e.g. differing whitespace) - consider normalising `topic`
# before relying on these counts.
anno_mi["topic"].value_counts()

reducing alcohol consumption                                                          2326
more exercise / increasing activity                                                   2034
reducing recidivism                                                                   1303
reducing drug use                                                                     1104
diabetes management                                                                    948
smoking cessation                                                                      923
smoking cessation                                                                      541
taking medicine / following medical procedure                                          448
asthma management                                                                      431
avoiding DOI                                                                           394
changing approach to disease                                                           315

In [11]:
# Number of distinct transcripts (dropna=False counts a missing id as its own value,
# as taking the length of .unique() does).
anno_mi["transcript_id"].nunique(dropna=False)

133

## Only considering client for now...

In [12]:
# Boolean mask over the rows of `anno_mi`: True where `client_talk_type` holds a string
# label, False otherwise (a missing value is not a str, so it fails the check).
# Kept as a plain Python list of bools.
client_index = (
    anno_mi["client_talk_type"].map(lambda label: isinstance(label, str)).tolist()
)
# number of labelled utterances
sum(client_index)

6725

In [13]:
# Keep only the talk-type labels of the rows selected by `client_index`.
y_data = anno_mi.loc[client_index, "client_talk_type"]
y_data.shape

(6725,)

In [14]:
# Peek at the first 20 labels; in the saved output they are still strings and the index
# shows the original row numbers in `anno_mi`.
y_data[0:20]

1     neutral
3     neutral
5     neutral
7     neutral
9     neutral
11    neutral
13    neutral
15    neutral
17    neutral
19    neutral
21    neutral
23    neutral
25    neutral
27    neutral
29    neutral
31    neutral
33    neutral
35     change
37     change
39     change
Name: client_talk_type, dtype: object

In [15]:
# Map each talk-type label to an integer id, numbered in order of first appearance in
# `y_data` (Series.unique() preserves that order), and build the inverse lookup.
# unique() is evaluated once here instead of once per label plus once for the length.
label_to_id = {label: idx for idx, label in enumerate(y_data.unique())}
id_to_label = {idx: label for label, idx in label_to_id.items()}

In [16]:
# Show the label -> id mapping.
label_to_id

{'neutral': 0, 'change': 1, 'sustain': 2}

In [17]:
# Show the inverse id -> label mapping.
id_to_label

{0: 'neutral', 1: 'change', 2: 'sustain'}

In [18]:
# Encode the string labels as integer ids.
# The labels are re-read from `anno_mi` rather than from `y_data` itself, so re-running
# this cell does not try to look up already-encoded integers in `label_to_id`
# (which would raise KeyError).
y_data = [label_to_id[label] for label in anno_mi["client_talk_type"][client_index]]
y_data[0:20]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]

In [19]:
# One output unit per distinct talk-type label.
output_dim = len(label_to_id)
output_dim

3

## Obtaining SBERT Embeddings

We can use the `SentenceEncoder` class within `nlpsig` to obtain sentence embeddings from a model. This class uses the [`sentence-transformers`](https://www.sbert.net/docs/package_reference/SentenceTransformer.html) package and here we use the pre-trained `all-MiniLM-L12-v2` model by passing its name as a string to the class - alternative models can be found [here](https://www.sbert.net/docs/pretrained_models.html).

We can pass our dataframe and the column name which stores the sentences that we wish to encode along with the model name into the constructor of the class to initialise our sentence encoder as follows:

In [22]:
# Set up the sentence encoder over the `utterance_text` column of `anno_mi`,
# naming the pre-trained model to use.
sentence_encoder = nlpsig.SentenceEncoder(
    df=anno_mi,
    feature_name="utterance_text",
    model_name="all-MiniLM-L12-v2",
)

We use the `.load_pretrained_model()` method to load the pre-trained model - this may require downloading the model if this is the first time you run the notebook.

In [23]:
# Load the pre-trained model into the encoder.
sentence_encoder.load_pretrained_model()

We can then obtain embeddings via the `.obtain_embeddings()` method.

In [24]:
# Encode the utterances; the log below reports 13551 sentences, i.e. therapist and
# client rows alike.
sbert_embeddings = sentence_encoder.obtain_embeddings()

[INFO] number of sentences to encode: 13551


Batches:   0%|          | 0/212 [00:00<?, ?it/s]

In [25]:
# (number of utterances, embedding dimension)
sbert_embeddings.shape

(13551, 384)

We can save our embeddings for use later:

In [26]:
# Cache the embeddings on disk; np.save appends the ".npy" extension to this path.
# NOTE(review): despite "client" in the file name, the array holds all 13551 utterances
# (therapist and client), not only the client ones.
np.save(f"{output_dir}/anno_mi_client_sentence_embeddings_384",
        sbert_embeddings)

In [20]:
# Reload the cached embeddings so the encoding cells above do not have to be re-run.
sbert_embeddings = np.load(f"{output_dir}/anno_mi_client_sentence_embeddings_384.npy")
# The embeddings are indexed positionally with `client_index`, so a stale or truncated
# cache must fail loudly here instead of silently misaligning rows with `anno_mi`.
assert sbert_embeddings.shape[0] == len(anno_mi), (
    f"cached embeddings have {sbert_embeddings.shape[0]} rows "
    f"but anno_mi has {len(anno_mi)}"
)

# Baseline: FFN baseline

Using the embeddings for the sentences directly in a FFN to predict the client talk type.

We try out three variations of the network: 1, 2 and 3 hidden layers, each of size 100.

In [21]:
# Search space and training settings shared by the hyperparameter searches below.
num_epochs = 100
# one, two and three hidden layers, each 100 units wide
hidden_dim_sizes = [[100] * depth for depth in (1, 2, 3)]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [5e-3, 1e-3, 5e-4, 1e-4, 1e-5]
# each configuration is run once per seed
seeds = [0, 1, 12, 123, 1234]
# loss name and its `gamma` setting, both forwarded to the search functions
loss, gamma = "focal", 2
# forwarded to the search functions as `validation_metric`
validation_metric = "f1"

In [22]:
# Show the hidden-layer configurations that will be searched.
hidden_dim_sizes

[[100], [100, 100], [100, 100, 100]]

In [23]:
# Show the learning rates that will be searched.
learning_rates

[0.005, 0.001, 0.0005, 0.0001, 1e-05]

In [39]:
# Hyperparameter search for the FFN on the current-utterance embeddings only, with
# k_fold=False. Only the first two returned values are kept: the full results table and
# the results for the best configuration.
ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search(
    # data: embeddings and labels of the labelled (client) utterances
    x_data=sbert_embeddings[client_index],
    y_data=y_data,
    output_dim=output_dim,
    # search grid
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # training and model selection
    num_epochs=num_epochs,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    k_fold=False,
    # where the results CSV is written, and logging
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False,
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in ffn_current_focal_2_best_model.csv


In [40]:
# Full search results: one row per (configuration, seed); rows that share a `model_id`
# share the same hidden_dim / dropout_rate / learning_rate.
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,model_id
0,focal,0.684015,0.618947,"[0.7808471454880296, 0.5281899109792284, 0.547...",0.564228,0.694238,0.607537,"[0.7920646583394563, 0.5475285171102662, 0.483...",[100],0.5,0.00500,0,2,False,0
0,focal,0.692937,0.605870,"[0.7964912280701756, 0.5169628432956381, 0.504...",0.574431,0.726766,0.623467,"[0.821629213483146, 0.5737704918032787, 0.475]",[100],0.5,0.00500,1,2,False,0
0,focal,0.695911,0.619387,"[0.7940652818991099, 0.529032258064516, 0.5350...",0.668241,0.706320,0.613244,"[0.804332129963899, 0.5559999999999999, 0.4794...",[100],0.5,0.00500,12,2,False,0
0,focal,0.698141,0.611627,"[0.79976717112922, 0.5218855218855218, 0.51322...",0.542071,0.713755,0.611268,"[0.8147622427253371, 0.5546218487394958, 0.464...",[100],0.5,0.00500,123,2,False,0
0,focal,0.696654,0.618950,"[0.793575252825699, 0.5481239804241436, 0.5151...",0.681275,0.716543,0.625144,"[0.8127259580621836, 0.5720000000000001, 0.490...",[100],0.5,0.00500,1234,2,False,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.645353,0.531895,"[0.7853403141361256, 0.3394833948339484, 0.470...",0.592595,0.650558,0.522043,"[0.7971530249110321, 0.3333333333333333, 0.435...","[100, 100, 100]",0.1,0.00001,0,2,False,44
0,focal,0.596283,0.387553,"[0.7388016288539849, 0.42385786802030456, 0.0]",0.758998,0.602230,0.386229,"[0.7382840663302092, 0.4204018547140649, 0.0]","[100, 100, 100]",0.1,0.00001,1,2,False,44
0,focal,0.647584,0.418361,"[0.7968036529680366, 0.45827814569536424, 0.0]",0.643662,0.671933,0.431289,"[0.8022284122562674, 0.4916387959866221, 0.0]","[100, 100, 100]",0.1,0.00001,12,2,False,44
0,focal,0.645353,0.413831,"[0.796149490373726, 0.44534412955465585, 0.0]",0.639678,0.675651,0.429970,"[0.8079834824501033, 0.48192771084337344, 0.0]","[100, 100, 100]",0.1,0.00001,123,2,False,44


In [41]:
# Results for the selected configuration, one row per seed.
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold
0,focal,0.699628,0.616654,"[0.7990654205607477, 0.525963149078727, 0.5249...",0.632668,0.722119,0.627415,"[0.8181174805378627, 0.5546218487394958, 0.509...",[100],0.5,0.001,0,2,False
0,focal,0.704833,0.627111,"[0.7983490566037734, 0.5552050473186121, 0.527...",0.681163,0.714684,0.622294,"[0.8117394416607016, 0.5490196078431373, 0.506...",[100],0.5,0.001,1,2,False
0,focal,0.704089,0.625025,"[0.8009395184967704, 0.5366666666666666, 0.537...",0.568209,0.726766,0.635061,"[0.8222698072805139, 0.5684210526315789, 0.514...",[100],0.5,0.001,12,2,False
0,focal,0.694424,0.618451,"[0.7903130537507383, 0.5288461538461539, 0.536...",0.608452,0.724907,0.633952,"[0.8160919540229885, 0.5896414342629483, 0.496...",[100],0.5,0.001,123,2,False
0,focal,0.698885,0.624533,"[0.7961859356376638, 0.5333333333333333, 0.544...",0.541399,0.716543,0.627569,"[0.8112798264642083, 0.5714285714285714, 0.5]",[100],0.5,0.001,1234,2,False


In [42]:
# Mean of the `f1` column over the seed runs of the best configuration.
best_ffn_current.loc[:, "f1"].mean()

0.6223547220370722

In [68]:
# Element-wise mean of the `f1_scores` lists over the runs of the best configuration
# (presumably one entry per class, ordered by label id - confirm).
np.mean(np.stack(best_ffn_current["f1_scores"]), axis=0)

array([0.7969706 , 0.53600287, 0.5340907 ])

In [43]:
# Same FFN search on the current-utterance embeddings, this time with k_fold=True.
# Only the first two returned values are kept: the full results table and the results
# for the best configuration.
ffn_current_kfold, best_ffn_current_kfold, _, __ = ffn_hyperparameter_search(
    # data: embeddings and labels of the labelled (client) utterances
    x_data=sbert_embeddings[client_index],
    y_data=y_data,
    output_dim=output_dim,
    # search grid
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # training and model selection
    num_epochs=num_epochs,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    k_fold=True,
    # where the results CSV is written, and logging
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False,
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in ffn_current_focal_2_kfold_best_model.csv


In [44]:
# Full k-fold search results; the repeating 0-4 index in the output appears to be the
# fold number within each (configuration, seed) - confirm.
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,model_id
0,focal,0.681784,0.594099,"[0.7857559836544076, 0.505226480836237, 0.4913...",0.545414,0.686374,0.583469,"[0.7960155911650064, 0.5026455026455026, 0.451...",[100],0.5,0.00500,0,2,True,0
1,focal,0.686989,0.582050,"[0.7977272727272727, 0.4909747292418772, 0.457...",0.534665,0.692005,0.589859,"[0.7959183673469388, 0.49382716049382713, 0.47...",[100],0.5,0.00500,0,2,True,0
2,focal,0.682528,0.575796,"[0.788538681948424, 0.5342237061769616, 0.4046...",0.552925,0.700450,0.604532,"[0.7991323210412147, 0.5532467532467532, 0.461...",[100],0.5,0.00500,0,2,True,0
3,focal,0.706320,0.613436,"[0.805491990846682, 0.5451263537906137, 0.4896...",0.577055,0.701014,0.605842,"[0.800520381613183, 0.5459317585301838, 0.4710...",[100],0.5,0.00500,0,2,True,0
4,focal,0.672862,0.596099,"[0.7738814993954052, 0.5225225225225226, 0.491...",0.548656,0.679617,0.605924,"[0.775735294117647, 0.5420353982300885, 0.5]",[100],0.5,0.00500,0,2,True,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.653532,0.552813,"[0.7872340425531915, 0.3870967741935484, 0.484...",0.625054,0.626126,0.505284,"[0.7734855136084283, 0.3289817232375979, 0.413...","[100, 100, 100]",0.1,0.00001,1234,2,True,44
1,focal,0.640149,0.521224,"[0.7870370370370371, 0.3321917808219178, 0.444...",0.619201,0.636824,0.524574,"[0.7733450241122314, 0.37055214723926383, 0.42...","[100, 100, 100]",0.1,0.00001,1234,2,True,44
2,focal,0.626766,0.499733,"[0.7731188971855256, 0.34596375617792424, 0.38...",0.617926,0.627815,0.508627,"[0.771168041684759, 0.31865284974093266, 0.436...","[100, 100, 100]",0.1,0.00001,1234,2,True,44
3,focal,0.610409,0.487177,"[0.7679719462302746, 0.2956810631229236, 0.397...",0.638740,0.630068,0.509474,"[0.7740511915269197, 0.3742405832320777, 0.380...","[100, 100, 100]",0.1,0.00001,1234,2,True,44


In [45]:
# Results for the selected configuration under k-fold: several rows per seed (index 0-4
# in the output, presumably one per fold - confirm).
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold
0,focal,0.684758,0.60716,"[0.7886815171583383, 0.5196078431372549, 0.513...",0.55306,0.68018,0.593967,"[0.7824529991047449, 0.5286236297198538, 0.470...",[100],0.5,0.0001,0,2,True
1,focal,0.677323,0.586921,"[0.7866108786610878, 0.5255023183925811, 0.448...",0.531541,0.688626,0.606592,"[0.7841726618705036, 0.5529953917050692, 0.482...",[100],0.5,0.0001,0,2,True
2,focal,0.685502,0.579362,"[0.79133409350057, 0.5272108843537414, 0.41954...",0.550432,0.690315,0.59825,"[0.791083916083916, 0.530718954248366, 0.47294...",[100],0.5,0.0001,0,2,True
3,focal,0.692193,0.602759,"[0.7938931297709922, 0.5475409836065575, 0.466...",0.582319,0.683559,0.5917,"[0.7836153161175423, 0.5513126491646777, 0.440...",[100],0.5,0.0001,0,2,True
4,focal,0.665428,0.588839,"[0.7696019300361883, 0.4969135802469136, 0.5]",0.547304,0.675113,0.592574,"[0.7785356980445657, 0.5266821345707655, 0.472...",[100],0.5,0.0001,0,2,True
0,focal,0.681784,0.600001,"[0.7883124627310675, 0.5033783783783784, 0.508...",0.560043,0.67455,0.58056,"[0.7842786958463601, 0.5160493827160494, 0.441...",[100],0.5,0.0001,1,2,True
1,focal,0.686989,0.596568,"[0.7936132465996452, 0.5281250000000001, 0.467...",0.535979,0.69482,0.612573,"[0.7891699955614737, 0.5463071512309496, 0.502...",[100],0.5,0.0001,1,2,True
2,focal,0.680297,0.573057,"[0.7853042479908152, 0.5385878489326765, 0.395...",0.552261,0.693131,0.605103,"[0.7892888498683055, 0.548469387755102, 0.4775...",[100],0.5,0.0001,1,2,True
3,focal,0.689963,0.602343,"[0.7903699354081034, 0.5454545454545455, 0.471...",0.581351,0.682432,0.5879,"[0.7850799289520426, 0.5445783132530121, 0.434...",[100],0.5,0.0001,1,2,True
4,focal,0.664684,0.58702,"[0.7681246255242661, 0.4929356357927786, 0.5]",0.547436,0.685248,0.604926,"[0.7860696517412936, 0.5317647058823529, 0.496...",[100],0.5,0.0001,1,2,True


In [46]:
# Mean of the `f1` column over all rows of the best k-fold configuration.
best_ffn_current_kfold.loc[:, "f1"].mean()

0.5924049773807214

In [67]:
# Element-wise mean of the `f1_scores` lists over all rows of the best k-fold
# configuration (presumably one entry per class, ordered by label id - confirm).
np.mean(np.stack(best_ffn_current_kfold["f1_scores"]), axis=0)

array([0.78498169, 0.52167665, 0.47055659])

# Baseline: Averaging history and use FFN

Here, we will use `nlpsig` to construct paths of embeddings for the last $k$ utterances, which we average and feed into an FFN. We also concatenate the current utterance embedding to the FFN input - all of this is done in the `obtain_mean_history` function in `nlpsig_networks.scripts.ffn_baseline_functions`, the module from which we imported `histories_baseline_hyperparameter_search` earlier on.

Here, we run the hyperparameter search for the FFN with the same parameters as above. For this baseline, we can also see how the model performance is affected by the size of the history window $k$.

In [24]:
# History window sizes k to search over (number of previous utterances per history).
window_sizes = [5, 10, 20, 50]

In [28]:
# Hyperparameter search for the mean-history FFN baseline over every window size in
# `window_sizes`, with k_fold=False. Only the first two returned values are kept: the
# full results table and the results for the best configuration.
ffn_mean_history, best_ffn_mean_history, _, __ = histories_baseline_hyperparameter_search(
    # data: the full utterance table and embeddings, with `path_indices` marking the
    # labelled (client) rows that `y_data` belongs to
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    path_indices=client_index,
    output_dim=output_dim,
    # history construction (use_signatures=False: presumably the averaged history is
    # used rather than path signatures - confirm)
    window_sizes=window_sizes,
    use_signatures=False,
    # search grid
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # training and model selection
    num_epochs=num_epochs,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    k_fold=False,
    # where the results CSV is written, and logging
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}.csv",
    verbose=False,
)

  0%|          | 0/4 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type/ffn_mean_history_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type/ffn_mean_history_focal_2_best_model.csv


In [29]:
# Full search results for the mean-history baseline; `k` is the window size and
# `input_dim` the size of the FFN input for that row.
ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,model_id,k,input_dim
0,focal,0.713011,0.614313,"[0.8114157806379407, 0.5272727272727272, 0.504...",0.697381,0.725836,0.625205,"[0.818368745716244, 0.5365853658536585, 0.5206...",[100],0.5,0.00500,0,2,False,0.00,5,766
0,focal,0.710037,0.620470,"[0.8068181818181818, 0.5251798561151079, 0.529...",0.502735,0.717472,0.611772,"[0.8177408177408177, 0.5253863134657837, 0.492...",[100],0.5,0.00500,1,2,False,0.00,5,766
0,focal,0.667658,0.595767,"[0.7693259121830551, 0.5345622119815667, 0.483...",0.602109,0.700743,0.621087,"[0.7975921745673438, 0.5708955223880597, 0.494...",[100],0.5,0.00500,12,2,False,0.00,5,766
0,focal,0.697398,0.627495,"[0.7888161808447353, 0.538961038961039, 0.5547...",0.534903,0.716543,0.626636,"[0.812545587162655, 0.5786407766990291, 0.4887...",[100],0.5,0.00500,123,2,False,0.00,5,766
0,focal,0.709294,0.610480,"[0.8129251700680273, 0.5212765957446809, 0.497...",0.517703,0.731413,0.629395,"[0.8272033310201249, 0.552915766738661, 0.5080...",[100],0.5,0.00500,1234,2,False,0.00,5,766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.617100,0.265933,"[0.7633163501621121, 0.03448275862068966, 0.0]",0.756498,0.640335,0.267512,"[0.7808764940239044, 0.021660649819494584, 0.0]","[100, 100, 100]",0.1,0.00001,0,2,False,3.44,50,766
0,focal,0.667658,0.566020,"[0.7959905660377359, 0.43717728055077454, 0.46...",0.560264,0.693309,0.585954,"[0.8141720896601591, 0.48945147679324896, 0.45...","[100, 100, 100]",0.1,0.00001,1,2,False,3.44,50,766
0,focal,0.617100,0.254406,"[0.7632183908045977, 0.0, 0.0]",0.788310,0.639405,0.260163,"[0.7804878048780487, 0.0, 0.0]","[100, 100, 100]",0.1,0.00001,12,2,False,3.44,50,766
0,focal,0.683271,0.577781,"[0.801648028251913, 0.5021645021645021, 0.4295...",0.594105,0.697955,0.576923,"[0.8155619596541788, 0.51520572450805, 0.4]","[100, 100, 100]",0.1,0.00001,123,2,False,3.44,50,766


In [30]:
# Results for the selected configuration (including its window size `k`), one row per seed.
best_ffn_mean_history

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,k,input_dim
0,focal,0.707807,0.631739,"[0.798329355608592, 0.5847589424572317, 0.5121...",0.71838,0.736989,0.648432,"[0.8244719592134013, 0.6247619047619047, 0.496...",[100],0.5,0.001,0,2,False,50,766
0,focal,0.70855,0.635621,"[0.8004807692307692, 0.5795275590551181, 0.526...",0.549926,0.732342,0.639982,"[0.8253968253968255, 0.5984251968503937, 0.496...",[100],0.5,0.001,1,2,False,50,766
0,focal,0.709294,0.633258,"[0.7988269794721409, 0.5654281098546042, 0.535...",0.615128,0.732342,0.644632,"[0.8202166064981948, 0.6078431372549019, 0.505...",[100],0.5,0.001,12,2,False,50,766
0,focal,0.710037,0.632813,"[0.8046783625730994, 0.5382059800664452, 0.555...",0.541037,0.736989,0.641692,"[0.8297567954220315, 0.6012024048096193, 0.494...",[100],0.5,0.001,123,2,False,50,766
0,focal,0.699628,0.631056,"[0.7921760391198045, 0.5714285714285714, 0.529...",0.55633,0.719331,0.6402,"[0.8095952023988006, 0.6074074074074074, 0.503...",[100],0.5,0.001,1234,2,False,50,766


In [31]:
# Mean of the `f1` column over the seed runs of the best configuration.
best_ffn_mean_history.loc[:, "f1"].mean()

0.6328974621408594

In [69]:
# Element-wise mean of the `f1_scores` lists over the runs of the best configuration
# (presumably one entry per class, ordered by label id - confirm).
np.mean(np.stack(best_ffn_mean_history["f1_scores"]), axis=0)

array([0.7988983 , 0.56786983, 0.53192425])

In [32]:
# Same mean-history search over every window size in `window_sizes`, this time with
# k_fold=True. Only the first two returned values are kept: the full results table and
# the results for the best configuration.
ffn_mean_history_kfold, best_ffn_mean_history_kfold, _, __ = histories_baseline_hyperparameter_search(
    # data: the full utterance table and embeddings, with `path_indices` marking the
    # labelled (client) rows that `y_data` belongs to
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    path_indices=client_index,
    output_dim=output_dim,
    # history construction (use_signatures=False: presumably the averaged history is
    # used rather than path signatures - confirm)
    window_sizes=window_sizes,
    use_signatures=False,
    # search grid
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # training and model selection
    num_epochs=num_epochs,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    k_fold=True,
    # where the results CSV is written, and logging
    results_output=f"{output_dir}/ffn_mean_history_focal_{gamma}_kfold.csv",
    verbose=False,
)

  0%|          | 0/4 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type/ffn_mean_history_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type/ffn_mean_history_focal_2_kfold_best_model.csv


In [33]:
# Full k-fold search results for the mean-history baseline; the repeating 0-4 index in
# the output appears to be the fold number - confirm.
ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,model_id,k,input_dim
0,focal,0.672862,0.601383,"[0.7743119266055046, 0.5278276481149013, 0.502...",0.536458,0.664977,0.581384,"[0.772706732941708, 0.4993141289437586, 0.4721...",[100],0.5,0.00500,0,2,True,0.00,5,766
1,focal,0.689219,0.590228,"[0.7997678467788741, 0.5151515151515151, 0.455...",0.528168,0.699887,0.612343,"[0.7982456140350876, 0.5367088607594938, 0.502...",[100],0.5,0.00500,0,2,True,0.00,5,766
2,focal,0.687732,0.565361,"[0.791759465478842, 0.5346869712351945, 0.3696...",0.533283,0.695946,0.592782,"[0.7955706984667802, 0.5286458333333333, 0.454...",[100],0.5,0.00500,0,2,True,0.00,5,766
3,focal,0.687732,0.585681,"[0.7934537246049661, 0.4970873786407767, 0.466...",0.567295,0.701577,0.594639,"[0.8050739957716702, 0.5205091937765206, 0.458...",[100],0.5,0.00500,0,2,True,0.00,5,766
4,focal,0.661710,0.585731,"[0.7608173076923077, 0.5103011093502379, 0.486...",0.545385,0.689752,0.612503,"[0.7840858292355833, 0.5374233128834355, 0.516]",[100],0.5,0.00500,0,2,True,0.00,5,766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.678067,0.581589,"[0.7988165680473372, 0.47694753577106525, 0.46...",0.603217,0.668919,0.557512,"[0.7909854175872735, 0.48687350835322196, 0.39...","[100, 100, 100]",0.1,0.00001,1234,2,True,3.44,50,766
1,focal,0.673606,0.563899,"[0.8027923211169284, 0.4462809917355372, 0.442...",0.598215,0.673423,0.574711,"[0.7879592740150508, 0.4847058823529411, 0.451...","[100, 100, 100]",0.1,0.00001,1234,2,True,3.44,50,766
2,focal,0.642379,0.518177,"[0.7772020725388602, 0.4147909967845659, 0.362...",0.603113,0.662162,0.542741,"[0.7956427015250545, 0.42493638676844786, 0.40...","[100, 100, 100]",0.1,0.00001,1234,2,True,3.44,50,766
3,focal,0.640149,0.538257,"[0.7726190476190476, 0.41627543035993736, 0.42...",0.621480,0.650338,0.533928,"[0.7822222222222224, 0.4493597206053551, 0.370...","[100, 100, 100]",0.1,0.00001,1234,2,True,3.44,50,766


In [34]:
best_ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,k,input_dim
0,focal,0.699628,0.629848,"[0.7965998785670917, 0.5466034755134281, 0.546...",0.535128,0.685811,0.601814,"[0.7854578096947935, 0.543030303030303, 0.4769...",[100],0.5,0.0001,0,2,True,20,766
1,focal,0.681041,0.59661,"[0.7869249394673123, 0.5405405405405405, 0.462...",0.507187,0.696509,0.624477,"[0.7877401646843549, 0.5691609977324262, 0.516...",[100],0.5,0.0001,0,2,True,20,766
2,focal,0.689219,0.591243,"[0.789838337182448, 0.5564924114671164, 0.4273...",0.534393,0.692005,0.600451,"[0.7948831054256728, 0.537966537966538, 0.4685...",[100],0.5,0.0001,0,2,True,20,766
3,focal,0.680297,0.601317,"[0.7790487658037327, 0.5457364341085271, 0.479...",0.557449,0.685248,0.597495,"[0.7831541218637994, 0.5563380281690141, 0.452...",[100],0.5,0.0001,0,2,True,20,766
4,focal,0.667658,0.60025,"[0.7642770352369381, 0.5081723625557206, 0.528...",0.528262,0.681869,0.606892,"[0.77910174152154, 0.5415730337078652, 0.5]",[100],0.5,0.0001,0,2,True,20,766
0,focal,0.689219,0.615304,"[0.7893462469733656, 0.5358851674641149, 0.520...",0.531559,0.680743,0.59429,"[0.7827260458839406, 0.5410628019323671, 0.459...",[100],0.5,0.0001,1,2,True,20,766
1,focal,0.680297,0.597075,"[0.7845503922751961, 0.5407066052227342, 0.465...",0.509566,0.685811,0.613716,"[0.7775239835541344, 0.5565819861431871, 0.507...",[100],0.5,0.0001,1,2,True,20,766
2,focal,0.675836,0.578892,"[0.7780373831775701, 0.5499181669394436, 0.408...",0.530567,0.685811,0.598425,"[0.7876344086021505, 0.5445665445665446, 0.463...",[100],0.5,0.0001,1,2,True,20,766
3,focal,0.684758,0.604041,"[0.7838651414810355, 0.5588697017268446, 0.469...",0.553035,0.681306,0.594199,"[0.7785778577857785, 0.5623529411764706, 0.441...",[100],0.5,0.0001,1,2,True,20,766
4,focal,0.667658,0.601386,"[0.7639060568603213, 0.5289747399702823, 0.511...",0.529611,0.67455,0.603511,"[0.7723880597014926, 0.5430167597765363, 0.495...",[100],0.5,0.0001,1,2,True,20,766


In [35]:
best_ffn_mean_history_kfold["f1"].mean()

0.6024489022123928

In [70]:
np.stack(best_ffn_mean_history_kfold["f1_scores"]).mean(axis=0)

array([0.78204441, 0.54354018, 0.48176212])

# Baseline: FFN using signatures

First, we reduce the dimension of the embeddings and then compute path signatures of the resulting paths. We use the path signature as input to the FFN for classification.

We want to choose a dimension and signature depth such that the number of terms in the signature is _roughly_ 384 so that it is comparable to the number of features that we used for the previous baseline where we computed the mean of the history. Again, we are concatenating the features we obtain with the current utterance embedding.

In [71]:
dim_reduce_methods = ["gaussian_random_projection", "umap"]

## Using signature

In [131]:
signature_dimensions_and_sig_depths = [(19, 2), (7, 3), (4, 4), (3, 5)]
[signatory.signature_channels(channels, depth)
 for (channels, depth) in signature_dimensions_and_sig_depths]

[380, 399, 340, 363]

### Using UMAP

In [None]:
# Hyperparameter search for the FFN baseline fed with path signatures
# (use_signatures=True, log_signature=False) of UMAP-reduced embeddings,
# evaluated on a single data split (k_fold=False).
# Only the first two returned values are used below; the last two are discarded.
ffn_signature_umap, best_ffn_signature_umap, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=False,
    dim_reduce_methods=["umap"],
    dimension_and_sig_depths=signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=False,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_signature_umap_focal_{gamma}.csv",
    verbose=False
)

In [None]:
best_ffn_signature_umap["f1"].mean()

In [None]:
np.stack(best_ffn_signature_umap["f1_scores"]).mean(axis=0)

In [None]:
# Same search as the single-split UMAP signature baseline, but with
# k-fold evaluation (k_fold=True) and a separate "_kfold" results file.
# Only the first two returned values are used below; the last two are discarded.
ffn_signature_umap_kfold, best_ffn_signature_umap_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=False,
    dim_reduce_methods=["umap"],
    dimension_and_sig_depths=signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_signature_umap_focal_{gamma}_kfold.csv",
    verbose=False
)

In [None]:
best_ffn_signature_umap_kfold["f1"].mean()

In [None]:
np.stack(best_ffn_signature_umap_kfold["f1_scores"]).mean(axis=0)

### Using random projections

In [None]:
# Hyperparameter search for the FFN baseline fed with path signatures
# (use_signatures=True, log_signature=False) of embeddings reduced by
# Gaussian random projection, evaluated on a single data split (k_fold=False).
# Only the first two returned values are used below; the last two are discarded.
ffn_signature_grp, best_ffn_signature_grp, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=False,
    dim_reduce_methods=["gaussian_random_projection"],
    dimension_and_sig_depths=signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=False,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_signature_grp_focal_{gamma}.csv",
    verbose=False
)

In [None]:
best_ffn_signature_grp["f1"].mean()

In [None]:
np.stack(best_ffn_signature_grp["f1_scores"]).mean(axis=0)

In [None]:
# Same search as the single-split Gaussian-random-projection signature
# baseline, but with k-fold evaluation (k_fold=True) and a separate
# "_kfold" results file.
# Only the first two returned values are used below; the last two are discarded.
ffn_signature_grp_kfold, best_ffn_signature_grp_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=False,
    dim_reduce_methods=["gaussian_random_projection"],
    dimension_and_sig_depths=signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_signature_grp_focal_{gamma}_kfold.csv",
    verbose=False
)

In [None]:
best_ffn_signature_grp_kfold["f1"].mean()

In [None]:
np.stack(best_ffn_signature_grp_kfold["f1_scores"]).mean(axis=0)

## Using log signature

In [125]:
log_signature_dimensions_and_sig_depths = [(28, 2), (10, 3), (6, 4), (4, 5)]
[signatory.logsignature_channels(channels, depth)
 for (channels, depth) in log_signature_dimensions_and_sig_depths]

[406, 385, 406, 294]

### Using UMAP

In [None]:
# Hyperparameter search for the FFN baseline fed with LOG signatures
# (use_signatures=True, log_signature=True) of UMAP-reduced embeddings,
# evaluated on a single data split (k_fold=False).
# The (dimension, depth) pairs must be the log-signature ones: they were chosen
# with `signatory.logsignature_channels` so the feature count is roughly 384.
# Passing the plain-signature pairs here would give far fewer log-signature terms.
ffn_logsignature_umap, best_ffn_logsignature_umap, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=True,
    dim_reduce_methods=["umap"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=False,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_umap_focal_{gamma}.csv",
    verbose=False
)

In [None]:
best_ffn_logsignature_umap["f1"].mean()

In [None]:
np.stack(best_ffn_logsignature_umap["f1_scores"]).mean(axis=0)

In [None]:
# Hyperparameter search for the FFN baseline fed with LOG signatures
# (use_signatures=True, log_signature=True) of UMAP-reduced embeddings,
# with k-fold evaluation (k_fold=True).
# The (dimension, depth) pairs must be the log-signature ones: they were chosen
# with `signatory.logsignature_channels` so the feature count is roughly 384.
# Passing the plain-signature pairs here would give far fewer log-signature terms.
ffn_logsignature_umap_kfold, best_ffn_logsignature_umap_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=True,
    dim_reduce_methods=["umap"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_umap_focal_{gamma}_kfold.csv",
    verbose=False
)

In [None]:
best_ffn_logsignature_umap_kfold["f1"].mean()

In [None]:
np.stack(best_ffn_logsignature_umap_kfold["f1_scores"]).mean(axis=0)

### Using random projections

In [None]:
# Hyperparameter search for the FFN baseline fed with LOG signatures
# (use_signatures=True, log_signature=True) of embeddings reduced by
# Gaussian random projection, evaluated on a single data split (k_fold=False).
# The (dimension, depth) pairs must be the log-signature ones: they were chosen
# with `signatory.logsignature_channels` so the feature count is roughly 384.
# Passing the plain-signature pairs here would give far fewer log-signature terms.
ffn_logsignature_grp, best_ffn_logsignature_grp, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=True,
    dim_reduce_methods=["gaussian_random_projection"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=False,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_grp_focal_{gamma}.csv",
    verbose=False
)

In [None]:
best_ffn_logsignature_grp["f1"].mean()

In [None]:
np.stack(best_ffn_logsignature_grp["f1_scores"]).mean(axis=0)

In [None]:
# Hyperparameter search for the FFN baseline fed with LOG signatures
# (use_signatures=True, log_signature=True) of embeddings reduced by
# Gaussian random projection, with k-fold evaluation (k_fold=True).
# The (dimension, depth) pairs must be the log-signature ones: they were chosen
# with `signatory.logsignature_channels` so the feature count is roughly 384.
# Passing the plain-signature pairs here would give far fewer log-signature terms.
ffn_logsignature_grp_kfold, best_ffn_logsignature_grp_kfold, _, __ = histories_baseline_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    output_dim=output_dim,
    window_sizes=window_sizes,
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    use_signatures=True,
    log_signature=True,
    dim_reduce_methods=["gaussian_random_projection"],
    dimension_and_sig_depths=log_signature_dimensions_and_sig_depths,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    path_indices=client_index,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_logsignature_grp_focal_{gamma}_kfold.csv",
    verbose=False
)

In [None]:
best_ffn_logsignature_grp_kfold["f1"].mean()

In [None]:
np.stack(best_ffn_logsignature_grp_kfold["f1_scores"]).mean(axis=0)

# Baseline: LSTM classification

# Baseline: Fine-tune BERT for classification

In [21]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding
)

In [22]:
pretrained_model = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model,
    num_labels=output_dim,
    id2label=id_to_label,
    label2id=label_to_id
)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

We need to add a `label` column (the integer class id of `client_talk_type`) to the client-only subset of `anno_mi`:

In [23]:
# Keep only the rows selected by `client_index` and encode each talk type
# as its integer class id via `label_to_id`.
df = anno_mi[client_index].reset_index(drop=True)
df["label"] = [label_to_id[talk_type] for talk_type in df["client_talk_type"]]
df.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime,label
0,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-06-09 00:00:24,0
1,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-06-09 00:00:34,0
2,high,0,reducing alcohol consumption,5,client,00:00:39,Usually three drinks and glasses of wine.,3,,,,,,,,neutral,2023-06-09 00:00:39,0
3,high,0,reducing alcohol consumption,7,client,00:00:46,Something like that.,3,,,,,,,,neutral,2023-06-09 00:00:46,0
4,high,0,reducing alcohol consumption,9,client,00:01:03,Okay.,3,,,,,,,,neutral,2023-06-09 00:01:03,0


In [24]:
fine_tune_bert = nlpsig.TextEncoder(df=df,
                                    feature_name="utterance_text",
                                    model=model,
                                    tokenizer=tokenizer,
                                    data_collator=data_collator)

In [25]:
fine_tune_bert.tokenize_text()

[INFO] Setting return_special_tokens_mask=True
[INFO] Tokenizing the dataset...


Map:   0%|          | 0/6725 [00:00<?, ? examples/s]

[INFO] Saving the tokenized text for each sentence into `.df['tokens']`...


Map:   0%|          | 0/6725 [00:00<?, ? examples/s]

[INFO] Creating tokenized dataframe and setting in `.tokenized_df` attribute...
[INFO] Note: 'text_id' is the column name for denoting the corresponding text id


Dataset({
    features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
    num_rows: 6725
})

In [26]:
splits = DataSplits(x_data=sbert_embeddings[client_index],
                    y_data=torch.tensor(y_data))

In [27]:
type(splits.indices)

tuple

In [28]:
len(splits.indices)

3

In [29]:
label_to_id.keys()

dict_keys(['neutral', 'change', 'sustain'])

In [30]:
fine_tune_bert.split_dataset(indices=splits.indices)

DatasetDict({
    train: Dataset({
        features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
        num_rows: 3604
    })
    test: Dataset({
        features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
        num_rows: 1345
    })
    validation: Dataset({
        f

In [31]:
fine_tune_bert.dataset_split["train"]

Dataset({
    features: ['mi_quality', 'transcript_id', 'topic', 'utterance_id', 'interlocutor', 'timestamp', 'utterance_text', 'annotator_id', 'therapist_input_exists', 'therapist_input_subtype', 'reflection_exists', 'reflection_subtype', 'question_exists', 'question_subtype', 'main_therapist_behaviour', 'client_talk_type', 'datetime', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'tokens'],
    num_rows: 3604
})

In [32]:
model_name = "fine-tuned-bert-anno-mi-client"
fine_tune_bert.set_up_training_args(output_dir=model_name,
                                    num_train_epochs=600,
                                    per_device_train_batch_size=128,
                                    disable_tqdm=False,
                                    save_strategy="steps",
                                    save_steps=10000,
                                    seed=seed)

[INFO] Setting up TrainingArguments object and saving to `.training_args`.


TrainingArguments(
_n_gpu=0,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ign

In [33]:

def compute_metrics(eval_pred):
    """
    Compute accuracy and macro-averaged F1 for a Trainer evaluation step.

    Parameters
    ----------
    eval_pred
        Object exposing `.predictions` (per-class scores; the arg-max over
        axis 1 is taken as the predicted class) and `.label_ids`
        (the reference class ids).

    Returns
    -------
    dict
        `{"accuracy": ..., "f1": ...}` where `f1` is the macro average
        over classes.
    """
    accuracy_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1")
    predictions = np.argmax(eval_pred.predictions, axis=1)
    accuracy = accuracy_metric.compute(
        predictions=predictions, references=eval_pred.label_ids
    )["accuracy"]
    # The task has more than two classes, so the metric's default binary
    # averaging is not applicable; use the macro average, which is also what
    # the rest of the notebook reports as "f1".
    f1 = f1_metric.compute(
        predictions=predictions,
        references=eval_pred.label_ids,
        average="macro",
    )["f1"]
    return {"accuracy": accuracy, "f1": f1}

In [34]:
fine_tune_bert.set_up_trainer(data_collator=data_collator,
                              compute_metrics=compute_metrics)

[INFO] Setting up Trainer object, and saving to `.trainer`.


<transformers.trainer.Trainer at 0x2d6c3dbb0>

In [None]:
fine_tune_bert.fit_transformer_with_trainer_api()

In [45]:
def compute_classification_accuracy(model,
                                    test_dataset,
                                    feature_name):
    """
    Classify every text in `test_dataset` with `model` and report metrics.

    Each text is tokenized with the notebook-level `tokenizer` and classified
    one at a time; the predicted class is the arg-max of the model's logits.

    Parameters
    ----------
    model
        Sequence classification model whose output exposes `.logits`.
    test_dataset
        Dataset indexable by column name; must contain `feature_name`
        (the texts) and "label" (the integer class ids).
    feature_name
        Name of the column holding the text to classify.

    Returns
    -------
    dict
        "accuracy" (float), "f1" (mean of the per-class F1 scores) and
        "f1_scores" (per-class F1 scores from `metrics.f1_score`).
    """
    # fetch the columns once rather than re-reading the text column
    # for every single example inside the loop
    texts = test_dataset[feature_name]
    labels = torch.tensor(test_dataset["label"])

    # run on whichever device the model lives on, in evaluation mode
    # (dropout disabled); the previous mode is restored afterwards
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    # loop through test set and make prediction from model
    predicted = []
    try:
        for text in tqdm(texts):
            # truncate so that over-long texts do not exceed the model's
            # maximum input length
            inputs = tokenizer(text,
                               return_tensors="pt",
                               truncation=True).to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            predicted.append(logits.argmax(dim=-1).item())
    finally:
        model.train(was_training)

    # convert to torch tensor
    predicted = torch.tensor(predicted)

    # compute accuracy
    accuracy = ((predicted == labels).sum() / len(labels)).item()
    # compute F1 (per class, then the unweighted mean, i.e. macro F1)
    f1_scores = metrics.f1_score(labels, predicted, average=None)
    f1 = sum(f1_scores)/len(f1_scores)

    # print evaluation metrics
    print(
        f"Accuracy on dataset of size {len(labels)}: "
        f"{100 * accuracy} %."
    )
    print(f"- f1: {f1_scores}")
    print(f"- f1 (macro): {f1}")

    return {"accuracy": accuracy,
            "f1": f1,
            "f1_scores": f1_scores}

In [46]:
compute_classification_accuracy(model=fine_tune_bert.model,
                                test_dataset=fine_tune_bert.dataset_split["test"],
                                feature_name="utterance_text")

  0%|          | 0/1345 [00:00<?, ?it/s]

Accuracy on dataset of size 1345: 33.382898569107056 %.
- f1: [0.37770898 0.35405872 0.        ]
- f1 (macro): 0.24392256675418103


{'accuracy': 0.33382898569107056,
 'f1': 0.24392256675418103,
 'f1_scores': array([0.37770898, 0.35405872, 0.        ])}

'neutral'

In [145]:
fine_tune_bert.dataset_split["test"]["client_talk_type"]

['neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'change',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'change',
 'change',
 'neutral',
 'change',
 'change',
 'change',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'change',
 'neutral',
 'neutral',
 'neutral',
 'neutral',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'change',
 'neutral',
 'neutral',
 'change',
 'neutral',
 'change',
 'neutral',
 'neutral',
 'sustain',
 'neutral',
 'sustain',
 'change',
 'neutral',
 'change',
 'neutral',
 'change',
 'neutral',
 'neutral',
 'change',
 'change',
 'change',
 'neutral',
 'neutral',
 'change',
 'neutral',
 'sustain',
 'neutral',
 'change',
 'change',
 'change',
 'change',
 'neutral',
 'neutral',
 'change',
 'neutral',
 'neutral',
 'n

# SWNU Network

## Obtaining path by looking at post history

We can obtain a path by looking at the history of each post. Here we look at the last 10 posts (and pad with vectors of zeros if there are fewer than 10 posts) including the current post.

We only want to consider paths that correspond to a client's utterance as we want to model a change in mood at that time. Their history will still contain the therapist's utterances too.

In [51]:
time_features = ["time_encoding", "timeline_index"]
standardise_method = ["minmax", None]

In [26]:
sig_depths = [2,3]
dim_reduce_methods = ["gaussian_random_projection", "umap"]

In [27]:
output_dim = len(label_to_id)
lstm_hidden_dims = [[8,8], [12,12,8]]
num_time_features = len(time_features)
conv_output_channels = [20, 10, 5]
learning_rate = 1e-4

In [28]:
embedding_dim = 384
dimensions = [embedding_dim, 100, 50, 30]

In [29]:
signatory.signature_channels(10, 2)

110

In [30]:
signatory.logsignature_channels(12, 3)

650

In [31]:
import math

# Halve (rounding up) the channel counts (19, 7) while keeping the depths (2, 3).
# NOTE: this rebinds `signature_dimensions_and_sig_depths`, which an earlier
# cell of this notebook defines with different values.
signature_dimensions_and_sig_depths = [
    (math.ceil(channels / 2), depth) for channels, depth in [(19, 2), (7, 3)]
]

In [32]:
signature_dimensions_and_sig_depths

[(10, 2), (4, 3)]

In [33]:
# (channels, depth) pairs for the log-signature variants.
log_signature_dimensions_and_sig_depths = [(28, 2), (10, 3), (6, 4)]
# BiLSTM variant: same depths, half the channels per pair.
bilstm_log_signature_dimensions_and_sig_depths = [
    (channels // 2, depth)
    for channels, depth in log_signature_dimensions_and_sig_depths
]

In [34]:
log_signature_dimensions_and_sig_depths

[(28, 2), (10, 3), (6, 4)]

In [35]:
bilstm_log_signature_dimensions_and_sig_depths

[(14, 2), (5, 3), (3, 4)]

In [52]:
from __future__ import annotations

import nlpsig
from nlpsig.classification_utils import DataSplits, Folds
from nlpsig_networks.pytorch_utils import SaveBestModel, training_pytorch, testing_pytorch, set_seed, KFold_pytorch
from nlpsig_networks.snwu_network import SWNUNetwork
from nlpsig_networks.focal_loss import FocalLoss
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os


def obtain_SWNUNetwork_input(
    method: str,
    dimension: int,
    df: pd.DataFrame,
    id_column: str,
    label_column: str,
    embeddings: np.array,
    k: int,
    time_feature: list[str] | str | None = None,
    standardise_method: list[str] | str | None = None,
    seed: int = 42,
    path_indices: np.array | None = None
) -> tuple[torch.tensor, int]:
    """
    Build the SWNUNetwork input from a dataframe and its embeddings.

    The embeddings are reduced to `dimension` components with `method`
    (skipped when `dimension` already equals the embedding width), padded
    paths are built with `nlpsig.PrepareData.pad` using the last `k` items
    of the history, optionally restricted to `path_indices`, and the result
    of `get_torch_path_for_SWNUNetwork` is returned.

    Parameters
    ----------
    method : dimension reduction method passed to `nlpsig.DimReduce`.
    dimension : number of components to reduce the embeddings to.
    df, id_column, label_column : forwarded to `nlpsig.PrepareData`.
    embeddings : full embeddings, one row per row of `df`
        (presumably 2-D, since `embeddings.shape[1]` is read).
    k : window size forwarded to `.pad` as `k` with `method="k_last"`.
    time_feature, standardise_method : forwarded to `.pad`.
    seed : `random_state` given to the dimension reduction.
    path_indices : optional index used to slice the padded array and both
        embedding arrays before the torch input is built.

    Returns
    -------
    Whatever `PrepareData.get_torch_path_for_SWNUNetwork` returns
    (annotated as a tensor and an int; callers unpack it as the input
    data and the number of input channels).
    """
    # dimension reduction is only needed when the target width differs
    if dimension != embeddings.shape[1]:
        reducer = nlpsig.DimReduce(method=method,
                                   n_components=dimension)
        embeddings_reduced = reducer.fit_transform(embeddings,
                                                   random_state=seed)
    else:
        embeddings_reduced = embeddings

    # construct the padded paths from the (reduced) embeddings
    prepared = nlpsig.PrepareData(df,
                                  id_column=id_column,
                                  label_column=label_column,
                                  embeddings=embeddings,
                                  embeddings_reduced=embeddings_reduced)
    prepared.pad(pad_by="history",
                 zero_padding=True,
                 method="k_last",
                 k=k,
                 time_feature=time_feature,
                 standardise_method=standardise_method,
                 embeddings="dim_reduced",
                 include_current_embedding=True)

    # optionally keep only the requested paths
    if path_indices is not None:
        prepared.array_padded = prepared.array_padded[path_indices]
        prepared.embeddings = prepared.embeddings[path_indices]
        prepared.embeddings_reduced = prepared.embeddings_reduced[path_indices]

    return prepared.get_torch_path_for_SWNUNetwork(
        include_time_features_in_path=True,
        include_time_features_in_input=True,
        include_embedding_in_input=True,
        reduced_embeddings=False
    )
    
def implement_swnu_network(
    num_epochs: int,
    x_data: torch.tensor | np.array,
    y_data: torch.tensor | np.array,
    input_channels: int,
    output_channels: int,
    num_time_features: int,
    embedding_dim: int,
    log_signature: bool,
    sig_depth: int,
    lstm_hidden_dim: list[int] | int,
    ffn_hidden_dim: list[int] | int,
    output_dim: int,
    BiLSTM: bool,
    dropout_rate: float,
    learning_rate: float,
    seed: int,
    loss: str,
    gamma: float = 0.0,
    augmentation_type: str = "Conv1d",
    comb_method: str = "concatenation",
    data_split_seed: int = 0,
    k_fold: bool = False,
    n_splits: int = 5,
    verbose_training: bool = True,
    verbose_results: bool = True,
    verbose_model: bool = False,
) -> tuple[SWNUNetwork, pd.DataFrame]:
    """
    Build an SWNUNetwork, train it and collect its evaluation results.

    Parameters
    ----------
    num_epochs : number of training epochs passed to the training routine.
    x_data, y_data : inputs and labels; converted to torch tensors if needed
        (`x_data` is additionally cast to float).
    input_channels, output_channels, num_time_features, embedding_dim,
    log_signature, sig_depth, output_dim, dropout_rate, augmentation_type,
    BiLSTM, comb_method : forwarded unchanged to `SWNUNetwork`.
    lstm_hidden_dim : forwarded to `SWNUNetwork` as `hidden_dim_swnu`.
    ffn_hidden_dim : forwarded to `SWNUNetwork` as `hidden_dim_ffn`.
    learning_rate : learning rate of the Adam optimizer.
    seed : seed for `set_seed` and for the training routine.
    loss : "focal" or "cross_entropy".
    gamma : `gamma` of the focal loss (only used when `loss == "focal"`).
    data_split_seed : `random_state` of the data split / folds.
    k_fold : if True evaluate with `KFold_pytorch` on `n_splits` folds,
        otherwise train once on a single `DataSplits` split.
    n_splits : number of folds when `k_fold` is True.
    verbose_training, verbose_results, verbose_model : toggle printing of
        training progress, the results dataframe and the model.

    Returns
    -------
    tuple[SWNUNetwork, pd.DataFrame]
        The model and a results dataframe. For the single split the dataframe
        has exactly one row with columns loss / accuracy / f1 / f1_scores for
        the test set and the same columns prefixed with `valid_` for the
        validation set. For k-fold it is whatever `KFold_pytorch` returns.

    Raises
    ------
    ValueError
        If `loss` is neither "focal" nor "cross_entropy".
    """
    # fail fast on an unsupported loss, before any model/data work is done
    if loss not in ("focal", "cross_entropy"):
        raise ValueError("criterion must be either 'focal' or 'cross_entropy'")

    # set seed
    set_seed(seed)

    # initialise SWNUNetwork
    SWNUNetwork_args = {
        "input_channels": input_channels,
        "output_channels": output_channels,
        "num_time_features": num_time_features,
        "embedding_dim": embedding_dim,
        "log_signature": log_signature,
        "sig_depth": sig_depth,
        "hidden_dim_swnu": lstm_hidden_dim,
        "hidden_dim_ffn": ffn_hidden_dim,
        "output_dim": output_dim,
        "dropout_rate": dropout_rate,
        "augmentation_type": augmentation_type,
        "BiLSTM": BiLSTM,
        "comb_method": comb_method
    }
    swnu_network_model = SWNUNetwork(**SWNUNetwork_args)

    if verbose_model:
        print(swnu_network_model)

    # convert data to torch tensors
    if not isinstance(x_data, torch.Tensor):
        x_data = torch.tensor(x_data)
    if not isinstance(y_data, torch.Tensor):
        y_data = torch.tensor(y_data)
    x_data = x_data.float()

    # set some variables for training
    save_best = True
    early_stopping = True
    model_output = "best_model.pkl"
    validation_metric = "f1"
    patience = 10

    # define loss (already validated above) and optimizer once for both modes
    if loss == "focal":
        criterion = FocalLoss(gamma = gamma)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(swnu_network_model.parameters(), lr=learning_rate)

    if k_fold:
        # perform KFold evaluation and return the performance on validation and test sets
        # split dataset
        folds = Folds(x_data=x_data,
                      y_data=y_data,
                      n_splits=n_splits,
                      shuffle=True,
                      random_state=data_split_seed)

        # perform k-fold evaluation which returns a dataframe with columns for the
        # loss, accuracy, f1 (macro) and individual f1-scores for each fold
        # (for both validation and test set)
        results = KFold_pytorch(folds=folds,
                                model=swnu_network_model,
                                criterion=criterion,
                                optimizer=optimizer,
                                num_epochs=num_epochs,
                                seed=seed,
                                save_best=save_best,
                                early_stopping=early_stopping,
                                validation_metric=validation_metric,
                                patience=patience,
                                verbose=verbose_training)
    else:
        # split dataset
        split_data = DataSplits(x_data=x_data,
                                y_data=y_data,
                                train_size=0.8,
                                valid_size=0.2,
                                shuffle=True,
                                random_state=data_split_seed)
        train, valid, test = split_data.get_splits(as_DataLoader=True)

        if loss == "focal":
            # set the focal-loss alpha from the labels of the split
            # (element 1 of the non-DataLoader splits, as in the original code)
            y_train = split_data.get_splits(as_DataLoader=False)[1]
            criterion.set_alpha_from_y(y=y_train)

        # train the SWNU network
        swnu_network_model = training_pytorch(model=swnu_network_model,
                                      train_loader=train,
                                      criterion=criterion,
                                      optimizer=optimizer,
                                      num_epochs=num_epochs,
                                      valid_loader=valid,
                                      seed=seed,
                                      save_best=save_best,
                                      output=model_output,
                                      early_stopping=early_stopping,
                                      validation_metric=validation_metric,
                                      patience=patience,
                                      verbose=verbose_training)

        # evaluate on test
        test_results = testing_pytorch(model=swnu_network_model,
                                       test_loader=test,
                                       criterion=criterion,
                                       verbose=False)

        # evaluate on validation
        valid_results = testing_pytorch(model=swnu_network_model,
                                        test_loader=valid,
                                        criterion=criterion)

        # one row per run: the per-class f1 scores are wrapped in a list so
        # that they are stored as a single cell instead of being spread over
        # one row per class (which would also repeat every scalar metric)
        results = pd.DataFrame({"loss": test_results["loss"],
                                "accuracy": test_results["accuracy"],
                                "f1": test_results["f1"],
                                "f1_scores": [test_results["f1_scores"]],
                                "valid_loss": valid_results["loss"],
                                "valid_accuracy": valid_results["accuracy"],
                                "valid_f1": valid_results["f1"],
                                "valid_f1_scores": [valid_results["f1_scores"]]})

    if verbose_results:
        with pd.option_context('display.precision', 3):
            print(results)

    # remove any models that have been saved
    if os.path.exists(model_output):
        os.remove(model_output)

    return swnu_network_model, results


def swnu_network_hyperparameter_search(
    num_epochs: int,
    df: pd.DataFrame,
    id_column: str,
    label_column: str,
    embeddings: np.array,
    y_data: np.array,
    embedding_dim: int,
    output_dim: int,
    window_sizes: list[int],
    dim_reduce_methods: list[str],
    dimensions: list[int],
    sig_depths: list[int],
    log_signature: bool,
    conv_output_channels: list[int],
    swnu_hidden_dim_sizes: list[int] | list[list[int]],
    ffn_hidden_dim_sizes: list[int] | list[list[int]],
    dropout_rates: list[float],
    learning_rates: list[float],
    BiLSTM,
    seeds : list[int],
    loss: str,
    gamma: float = 0.0,
    time_feature: list[str] | str | None = None,
    standardise_method: list[str] | str | None = None,
    augmentation_type: str = "Conv1d",
    comb_method: str = "concatenation",
    path_indices: np.array | None = None,
    data_split_seed: int = 0,
    k_fold: bool = False,
    n_splits: int = 5,
    validation_metric: str = "f1",
    results_output: str | None = None,
    verbose: bool = True
):
    """
    Grid search over SWNU-network hyperparameters, followed by a re-run of
    the best configuration to report its test performance.

    Every combination of window size, dimension, dimension-reduction method,
    SWNU hidden size, FFN hidden size, signature depth, convolution output
    channels, dropout rate and learning rate is trained once per seed via
    `implement_swnu_network`. The `valid_<validation_metric>` score is averaged
    over the seeds and handed to `SaveBestModel`, which tracks the best
    configuration. The best configuration is then re-trained for every seed
    and its (test) `validation_metric` is averaged.

    Parameters
    ----------
    num_epochs : int
        Number of epochs passed to `implement_swnu_network`.
    df, id_column, label_column, embeddings, path_indices
        Forwarded unchanged to `obtain_SWNUNetwork_input` to build the model input.
    y_data : np.array
        Labels forwarded to `implement_swnu_network`.
    embedding_dim, output_dim, log_signature, BiLSTM, loss, gamma,
    augmentation_type, comb_method, data_split_seed, k_fold, n_splits
        Forwarded unchanged to `implement_swnu_network`.
    window_sizes, dim_reduce_methods, dimensions, sig_depths,
    conv_output_channels, swnu_hidden_dim_sizes, ffn_hidden_dim_sizes,
    dropout_rates, learning_rates
        The grids that are searched over (all combinations are tried).
    seeds : list[int]
        Seeds each configuration is trained with; scores are averaged over them.
    time_feature : list[str] | str | None
        Time feature(s) forwarded to `obtain_SWNUNetwork_input`. The number of
        time features given to the network is 0 for None, 1 for a single
        string, and `len(time_feature)` otherwise.
    standardise_method : list[str] | str | None
        Forwarded to `obtain_SWNUNetwork_input`.
    validation_metric : str
        One of "loss", "accuracy" or "f1".
    results_output : str | None
        If not None, the search results dataframe is written to this CSV path.
    verbose : bool
        Whether to print progress information.

    Returns
    -------
    tuple
        `(results_df, test_results_df, best_valid_metric, extra_info)` where
        `results_df` holds one block of rows per (configuration, seed),
        `test_results_df` holds the re-run of the best configuration,
        `best_valid_metric` is `save_best_model.best_valid_metric` and
        `extra_info` is the hyperparameter dictionary of the best configuration.

    Raises
    ------
    ValueError
        If `validation_metric` is not one of the supported metrics.
    """
    if validation_metric not in ["loss", "accuracy", "f1"]:
        raise ValueError("validation_metric must be either 'loss', 'accuracy' or 'f1'")

    # number of time features appended to the path:
    # `time_feature` may be None, a single column name, or a list of names
    if time_feature is None:
        num_time_features = 0
    elif isinstance(time_feature, str):
        num_time_features = 1
    else:
        num_time_features = len(time_feature)

    # initialise SaveBestModel class
    # (must be a plain string path: it is later passed to torch.load / os.remove)
    model_output = "best_swnu_network_model.pkl"
    save_best_model = SaveBestModel(metric=validation_metric,
                                    output=model_output,
                                    verbose=verbose)

    # collect per-run frames and concatenate once at the end
    # (avoids re-copying the growing dataframe on every iteration)
    results_frames = []
    model_id = 0

    for k in tqdm(window_sizes):
        if verbose:
            print("\n" + "-" * 50)
            print(f"k: {k}")
        for dimension in tqdm(dimensions):
            for method in tqdm(dim_reduce_methods):
                print("\n" + "#" * 50)
                print(f"dimension: {dimension} | "
                      f"method: {method}")
                # the model input only depends on (k, dimension, method),
                # so it is built once here and reused by the inner loops
                x_data, input_channels = obtain_SWNUNetwork_input(
                    method=method,
                    dimension=dimension,
                    df=df,
                    id_column=id_column,
                    label_column=label_column,
                    embeddings=embeddings,
                    k=k,
                    time_feature=time_feature,
                    standardise_method=standardise_method,
                    path_indices=path_indices
                )

                for lstm_hidden_dim in tqdm(swnu_hidden_dim_sizes):
                    for ffn_hidden_dim in tqdm(ffn_hidden_dim_sizes):
                        for sig_depth in sig_depths:
                            for output_channels in tqdm(conv_output_channels):
                                for dropout in tqdm(dropout_rates):
                                    for lr in tqdm(learning_rates):
                                        if verbose:
                                            print("\n" + "!" * 50)
                                            print(f"lstm_hidden_dim: {lstm_hidden_dim} | "
                                                  f"ffn_hidden_dim: {ffn_hidden_dim} | "
                                                  f"sig_depth: {sig_depth} | "
                                                  f"output_channels: {output_channels} | "
                                                  f"dropout: {dropout} | "
                                                  f"learning_rate: {lr}")
                                        scores = []
                                        # only print the model for the first seed
                                        verbose_model = verbose
                                        for seed in seeds:
                                            _, results = implement_swnu_network(
                                                num_epochs=num_epochs,
                                                x_data=x_data,
                                                y_data=y_data,
                                                input_channels=input_channels,
                                                output_channels=output_channels,
                                                num_time_features=num_time_features,
                                                embedding_dim=embedding_dim,
                                                log_signature=log_signature,
                                                sig_depth=sig_depth,
                                                lstm_hidden_dim=lstm_hidden_dim,
                                                ffn_hidden_dim=ffn_hidden_dim,
                                                output_dim=output_dim,
                                                BiLSTM=BiLSTM,
                                                dropout_rate=dropout,
                                                learning_rate=lr,
                                                seed=seed,
                                                loss=loss,
                                                gamma=gamma,
                                                augmentation_type=augmentation_type,
                                                comb_method=comb_method,
                                                data_split_seed=data_split_seed,
                                                k_fold=k_fold,
                                                n_splits=n_splits,
                                                verbose_training=False,
                                                verbose_results=verbose,
                                                verbose_model=verbose_model
                                            )
                                            # save metric that we want to validate on
                                            # taking the mean over the performance on the folds for the seed
                                            # if k_fold=False, .mean() just returns the performance for the seed
                                            scores.append(results[f"valid_{validation_metric}"].mean())

                                            # annotate the rows with the configuration that produced them
                                            results["k"] = k
                                            results["dimensions"] = dimension
                                            results["sig_depth"] = sig_depth
                                            results["method"] = method
                                            results["input_channels"] = input_channels
                                            results["output_channels"] = output_channels
                                            results["num_time_features"] = num_time_features
                                            results["embedding_dim"] = embedding_dim
                                            results["log_signature"] = log_signature
                                            # hidden dims may themselves be lists, so repeat
                                            # them explicitly (one entry per row)
                                            results["lstm_hidden_dim"] = [lstm_hidden_dim for _ in range(len(results.index))]
                                            results["ffn_hidden_dim"] = [ffn_hidden_dim for _ in range(len(results.index))]
                                            results["dropout_rate"] = dropout
                                            results["learning_rate"] = lr
                                            results["seed"] = seed
                                            results["BiLSTM"] = BiLSTM
                                            # NOTE(review): if `results` already carries a numeric
                                            # "loss" column, this replaces it with the loss name
                                            # - confirm that is intended
                                            results["loss"] = loss
                                            results["gamma"] = gamma
                                            results["k_fold"] = k_fold
                                            results["augmentation_type"] = augmentation_type
                                            results["comb_method"] = comb_method
                                            results["model_id"] = model_id
                                            results_frames.append(results)

                                            # don't continue printing out the model
                                            verbose_model = False

                                        model_id += 1
                                        scores_mean = sum(scores)/len(scores)

                                        print(f"- average{' (kfold)' if k_fold else ''} "
                                              f"(validation) metric score: {scores_mean}")
                                        print(f"scores for the different seeds: {scores}")
                                        # save best model according to averaged metric over the different seeds
                                        save_best_model(current_valid_metric=scores_mean,
                                                        extra_info={
                                                            "k": k,
                                                            "dimensions": dimension,
                                                            "sig_depth": sig_depth,
                                                            "method": method,
                                                            "input_channels": input_channels,
                                                            "output_channels": output_channels,
                                                            "num_time_features": num_time_features,
                                                            "embedding_dim": embedding_dim,
                                                            "log_signature": log_signature,
                                                            "lstm_hidden_dim": lstm_hidden_dim,
                                                            "ffn_hidden_dim": ffn_hidden_dim,
                                                            "dropout_rate": dropout,
                                                            "learning_rate": lr,
                                                            "BiLSTM": BiLSTM,
                                                            "loss": loss,
                                                            "gamma": gamma,
                                                            "augmentation_type": augmentation_type,
                                                            "comb_method": comb_method
                                                        })

    results_df = pd.concat(results_frames) if results_frames else pd.DataFrame()

    # presumably written by SaveBestModel to `model_output`; read back here
    # to recover the hyperparameters of the best configuration
    checkpoint = torch.load(f=model_output)
    best_params = checkpoint["extra_info"]
    if verbose:
        print("*" * 50)
        print("The best model had the following parameters:")
        print(best_params)

    # rebuild the input exactly as it was built during the search for the
    # best (k, dimension, method), including time features and standardisation
    x_data, input_channels = obtain_SWNUNetwork_input(method=best_params["method"],
                                               dimension=best_params["dimensions"],
                                               df=df,
                                               id_column=id_column,
                                               label_column=label_column,
                                               embeddings=embeddings,
                                               k=best_params["k"],
                                               time_feature=time_feature,
                                               standardise_method=standardise_method,
                                               path_indices=path_indices)

    test_scores = []
    test_frames = []
    for seed in seeds:
        _, test_results = implement_swnu_network(
            num_epochs=num_epochs,
            x_data=x_data,
            y_data=y_data,
            sig_depth=best_params["sig_depth"],
            input_channels=best_params["input_channels"],
            output_channels=best_params["output_channels"],
            num_time_features=num_time_features,
            embedding_dim=embedding_dim,
            log_signature=best_params["log_signature"],
            output_dim=output_dim,
            lstm_hidden_dim=best_params["lstm_hidden_dim"],
            ffn_hidden_dim=best_params["ffn_hidden_dim"],
            BiLSTM=best_params["BiLSTM"],
            dropout_rate=best_params["dropout_rate"],
            learning_rate=best_params["learning_rate"],
            seed=seed,
            loss=best_params["loss"],
            gamma=best_params["gamma"],
            augmentation_type=best_params["augmentation_type"],
            comb_method=best_params["comb_method"],
            data_split_seed=data_split_seed,
            k_fold=k_fold,
            n_splits=n_splits,
            verbose_training=False,
            verbose_results=False,
            verbose_model=False
        )

        # save metric that we want to validate on
        # taking the mean over the performance on the folds for the seed
        # if k_fold=False, .mean() just returns the performance for the seed
        test_scores.append(test_results[validation_metric].mean())

        # annotate the rows with the best configuration
        test_results["k"] = best_params["k"]
        test_results["dimensions"] = best_params["dimensions"]
        test_results["sig_depth"] = best_params["sig_depth"]
        test_results["method"] = best_params["method"]
        test_results["input_channels"] = best_params["input_channels"]
        test_results["output_channels"] = best_params["output_channels"]
        test_results["num_time_features"] = num_time_features
        test_results["embedding_dim"] = embedding_dim
        test_results["log_signature"] = best_params["log_signature"]
        test_results["lstm_hidden_dim"] = [best_params["lstm_hidden_dim"]
                                           for _ in range(len(test_results.index))]
        test_results["ffn_hidden_dim"] = [best_params["ffn_hidden_dim"]
                                          for _ in range(len(test_results.index))]
        test_results["dropout_rate"] = best_params["dropout_rate"]
        test_results["learning_rate"] = best_params["learning_rate"]
        test_results["seed"] = seed
        test_results["BiLSTM"] = best_params["BiLSTM"]
        # NOTE(review): same "loss" column overwrite as in the search loop above
        test_results["loss"] = best_params["loss"]
        test_results["gamma"] = best_params["gamma"]
        test_results["k_fold"] = k_fold
        test_results["augmentation_type"] = best_params["augmentation_type"]
        test_results["comb_method"] = best_params["comb_method"]
        test_frames.append(test_results)

    test_results_df = pd.concat(test_frames) if test_frames else pd.DataFrame()

    test_scores_mean = sum(test_scores)/len(test_scores)
    if verbose:
        print(f"best validation score: {save_best_model.best_valid_metric}")
        print(f"- Best model: average (test) metric score: {test_scores_mean}")
        print(f"scores for the different seeds: {test_scores}")

    if results_output is not None:
        print("saving results dataframe to CSV for this "
            f"hyperparameter search in {results_output}")
        results_df.to_csv(results_output)

    # remove any models that have been saved
    if os.path.exists(model_output):
        os.remove(model_output)

    return results_df, test_results_df, save_best_model.best_valid_metric, best_params

In [53]:
from __future__ import annotations
import signatory
import torch
import torch.nn as nn
from nlpsig_networks.swnu import SWNU


class SWNUNetwork(nn.Module):
    """
    Stacked Deep Signature Neural Network for classification.

    The path part of the input is augmented (1D convolution or
    `signatory.Augment`), passed through a SWNU to obtain signature features,
    optionally combined with time features and an embedding taken from the
    input, and finally classified by a feed-forward network.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        num_time_features: int,
        embedding_dim: int,
        log_signature: bool,
        sig_depth: int,
        hidden_dim_swnu: list[int] | int,
        hidden_dim_ffn: list[int] | int,
        output_dim: int,
        dropout_rate: float,
        augmentation_type: str = "Conv1d",
        augmentation_args: dict | None = None,
        hidden_dim_aug: list[int] | int | None = None,
        BiLSTM: bool = False,
        comb_method: str = "gated_addition",
    ):
        """
        SWNU network for classification.

        Parameters
        ----------
        input_channels : int
            Dimension of the embeddings that will be passed in.
        output_channels : int
            Requested dimension of the embeddings after convolution layer.
        num_time_features : int
            Number of time features to add to FFN input. If none, set to zero.
        embedding_dim : int
            Dimension of embedding to add to FFN input. If none, set to zero.
        log_signature : bool
            Whether or not to use the log signature or standard signature.
        sig_depth : int
            The depth to truncate the path signature at.
        hidden_dim_swnu : list[int] | int
            Dimensions of the hidden layers in the SWNU blocks.
        hidden_dim_ffn : list[int] | int
            Dimension of the hidden layers in the FFN.
        output_dim : int
            Dimension of the output layer in the FFN.
        dropout_rate : float
            Dropout rate in the FFN.
        augmentation_type : str, optional
            Method of augmenting the path, by default "Conv1d".
            Options are:
            - "Conv1d": passes path through 1D convolution layer.
            - "signatory": passes path through `Augment` layer from `signatory` package.
        augmentation_args : dict | None, optional
            Arguments to pass into `torch.Conv1d` or `signatory.Augment`, by default None.
            If None, by default will set `kernel_size=3`, `stride=1`, `padding=1`.
            Note that the same arguments are passed to both layers, as both
            are always constructed.
        hidden_dim_aug : list[int] | int | None
            Dimensions of the hidden layers in the augmentation layer.
            Passed into `Augment` class from `signatory` package if
            `augmentation_type='signatory'`, by default None.
        BiLSTM : bool, optional
            Whether or not a bidirectional LSTM is used,
            by default False (unidirectional LSTM is used in this case).
        comb_method : str, optional
            Determines how to combine the path signature and embeddings,
            by default "gated_addition".
            Options are:
            - concatenation: concatenation of path signature and embedding vector
            - gated_addition: element-wise addition of path signature and embedding vector

        Raises
        ------
        ValueError
            If `comb_method` or `augmentation_type` is not one of the
            supported options.
        """
        super(SWNUNetwork, self).__init__()

        # validate the string options up front, before any layer is built
        if comb_method not in ["concatenation", "gated_addition"]:
            raise ValueError(
                "`comb_method` must be either 'concatenation' or 'gated_addition'."
            )
        if augmentation_type not in ["Conv1d", "signatory"]:
            raise ValueError("`augmentation_type` must be 'Conv1d' or 'signatory'.")

        # dimensionality reduction on the input prior to SWNU
        self.input_channels = input_channels
        self.augmentation_type = augmentation_type
        if isinstance(hidden_dim_aug, int):
            hidden_dim_aug = [hidden_dim_aug]
        elif hidden_dim_aug is None:
            hidden_dim_aug = []
        self.hidden_dim_aug = hidden_dim_aug
        if augmentation_args is None:
            augmentation_args = {"kernel_size": 3,
                                 "stride": 1,
                                 "padding": 1}
        # both augmentation layers are always created (only the one selected
        # by `augmentation_type` is used in `forward`); this keeps the set of
        # registered parameters independent of `augmentation_type`
        # convolution
        self.conv = nn.Conv1d(
            in_channels=input_channels,
            out_channels=output_channels,
            **augmentation_args,
        )
        # alternative to convolution: using Augment from signatory
        self.augment = signatory.Augment(
            in_channels=input_channels,
            layer_sizes=self.hidden_dim_aug + [output_channels],
            include_original=False,
            include_time=False,
            **augmentation_args,
        )
        # non-linearity
        self.tanh1 = nn.Tanh()

        # signature window network unit to obtain feature set for FFN
        if isinstance(hidden_dim_swnu, int):
            hidden_dim_swnu = [hidden_dim_swnu]

        self.swnu = SWNU(input_size=output_channels,
                         hidden_dim=hidden_dim_swnu,
                         log_signature=log_signature,
                         sig_depth=sig_depth,
                         BiLSTM=BiLSTM)

        # number of channels of the final signature (without lift) produced
        # by the SWNU; a BiLSTM doubles the width of the last LSTM output
        mult = 2 if BiLSTM else 1
        if log_signature:
            signature_output_channels = signatory.logsignature_channels(
                in_channels=mult * hidden_dim_swnu[-1], depth=sig_depth
            )
        else:
            signature_output_channels = signatory.signature_channels(
                channels=mult * hidden_dim_swnu[-1], depth=sig_depth
            )

        # determining how to concatenate features to the SWNU features
        self.embedding_dim = embedding_dim
        self.num_time_features = num_time_features
        self.comb_method = comb_method

        # find dimension of features to pass through FFN
        if self.comb_method == "concatenation":
            input_dim = (
                signature_output_channels
                + self.num_time_features
                + self.embedding_dim
            )
        elif self.comb_method == "gated_addition":
            input_gated_linear = (
                signature_output_channels
                + self.num_time_features
            )
            # the gated features are projected onto the embedding dimension so
            # they can be added to the embedding; without an embedding they
            # keep their own dimension
            if self.embedding_dim > 0:
                gated_output_dim = self.embedding_dim
            else:
                gated_output_dim = input_gated_linear
            self.fc_scale = nn.Linear(input_gated_linear, gated_output_dim)
            self.scaler = torch.nn.Parameter(torch.zeros(1, gated_output_dim))
            # the FFN receives whatever the gating layer outputs
            input_dim = gated_output_dim
            # non-linearity
            self.tanh2 = nn.Tanh()

        # FFN for classification
        if isinstance(hidden_dim_ffn, int):
            hidden_dim_ffn = [hidden_dim_ffn]
        self.hidden_dim_ffn = hidden_dim_ffn

        # FFN: input layer
        self.ffn_input_layer = nn.Linear(input_dim, self.hidden_dim_ffn[0])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        input_dim = self.hidden_dim_ffn[0]

        # FFN: hidden layers
        # NOTE(review): the loop starts at index 0, so a layer of size
        # hidden_dim_ffn[0] -> hidden_dim_ffn[0] is added on top of the input
        # layer. Kept as is so existing checkpoints still load - confirm intent.
        self.ffn_linear_layers = []
        self.ffn_non_linear_layers = []
        self.dropout_layers = []
        for l in range(len(self.hidden_dim_ffn)):
            self.ffn_linear_layers.append(nn.Linear(input_dim, self.hidden_dim_ffn[l]))
            self.ffn_non_linear_layers.append(nn.ReLU())
            self.dropout_layers.append(nn.Dropout(dropout_rate))
            input_dim = self.hidden_dim_ffn[l]

        self.ffn_linear_layers = nn.ModuleList(self.ffn_linear_layers)
        self.ffn_non_linear_layers = nn.ModuleList(self.ffn_non_linear_layers)
        self.dropout_layers = nn.ModuleList(self.dropout_layers)

        # FFN: readout
        self.ffn_final_layer = nn.Linear(input_dim, output_dim)

    def _latest_time_features(self, x: torch.Tensor):
        # time features sit right after the path channels;
        # take the maximum over the signal length for the latest time
        start = self.input_channels
        stop = self.input_channels + self.num_time_features
        return x[:, :, start:stop].max(1)[0]

    def forward(self, x: torch.Tensor):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Tensor with dimensions [batch, length of signal, channels].
            The first `input_channels` channels are the path, followed by
            `num_time_features` time features, followed by any remaining
            channels, which are treated as the embedding to combine with.

        Returns
        -------
        torch.Tensor
            Output of the final FFN layer.
        """
        # x has dimensions [batch, length of signal, channels]

        # convolution
        if self.augmentation_type == "Conv1d":
            # input has dimensions [batch, length of signal, channels]
            # swap dimensions to get [batch, channels, length of signal]
            # (nn.Conv1d expects this)
            out = torch.transpose(x, 1, 2)
            # get only the path information
            out = self.conv(out[:, : self.input_channels, :])
            out = self.tanh1(out)
            # make output have dimensions [batch, length of signal, channels]
            out = torch.transpose(out, 1, 2)
        elif self.augmentation_type == "signatory":
            # input has dimensions [batch, length of signal, channels]
            # (signatory.Augment expects this)
            # and get only the path information
            # output has dimensions [batch, length of signal, channels]
            out = self.augment(x[:, :, : self.input_channels])

        # use SWNU to obtain feature set
        out = self.swnu(out)

        # index of the first channel after the path and time features
        embedding_start = self.input_channels + self.num_time_features
        has_embedding = x.shape[2] > embedding_start

        # combine last post embedding
        if x.shape[2] > self.input_channels:
            # we have things to concatenate to the path
            if self.comb_method == "concatenation":
                if self.num_time_features > 0:
                    # concatenate any time features
                    out = torch.cat((out, self._latest_time_features(x)), dim=1)
                if has_embedding:
                    # concatenate current post embedding if provided
                    out = torch.cat((out, x[:, 0, embedding_start:]), dim=1)
            elif self.comb_method == "gated_addition":
                if self.num_time_features > 0:
                    # concatenate any time features
                    out_gated = torch.cat(
                        (out, self._latest_time_features(x)), dim=1
                    )
                else:
                    out_gated = out
                out_gated = self.fc_scale(out_gated.float())
                out_gated = self.tanh2(out_gated)
                out_gated = torch.mul(self.scaler, out_gated)
                if has_embedding:
                    # add the gated features to the current post embedding
                    # (must stay a tensor: it is fed to the FFN below)
                    out = out_gated + x[:, 0, embedding_start:]
                else:
                    out = out_gated

        # FFN: input layer
        out = self.ffn_input_layer(out)
        out = self.relu(out)
        out = self.dropout(out)

        # FFN: hidden layers
        for linear, non_linear, dropout in zip(
            self.ffn_linear_layers, self.ffn_non_linear_layers, self.dropout_layers
        ):
            out = linear(out)
            out = non_linear(out)
            out = dropout(out)

        # FFN: readout
        out = self.ffn_final_layer(out)

        return out


In [54]:
from __future__ import annotations
from signatory import Signature, LogSignature, signature_channels, logsignature_channels
import torch
import torch.nn as nn


class SWNU(nn.Module):
    """
    Signature Window Network Unit: a stack of (streamed signature, LSTM)
    blocks followed by one final, non-streamed signature.
    """

    def __init__(
        self,
        input_size: int,
        log_signature: bool,
        sig_depth: int,
        hidden_dim: list[int] | int,
        BiLSTM: bool = False,
    ):
        """
        Builds a multi-layer Signature Window Network Unit (SWNU) that is
        applied to an input sequence.

        Parameters
        ----------
        input_size : int
            Number of features (channels) expected in the input x.
        log_signature : bool
            If True, log signatures are used; otherwise standard signatures.
        sig_depth : int
            Truncation depth of the path signature.
        hidden_dim : list[int] | int
            Hidden sizes of the LSTM blocks in the SWNU
            (one block per entry; an int is treated as a single block).
        BiLSTM : bool, optional
            Whether the LSTM of the final SWNU block is bidirectional,
            by default False (all LSTMs are unidirectional in this case).
        """
        super(SWNU, self).__init__()

        # keep a record of the constructor arguments
        self.input_size = input_size
        self.log_signature = log_signature
        self.hidden_dim = [hidden_dim] if isinstance(hidden_dim, int) else hidden_dim
        self.BiLSTM = BiLSTM

        # one streamed (expanding window) signature + one LSTM per block
        last_block = len(self.hidden_dim) - 1
        streamed_signatures = []
        recurrent_blocks = []
        for block_idx, block_hidden_size in enumerate(self.hidden_dim):
            # width of the sequence entering this block: the raw input for the
            # first block, the previous LSTM's hidden size afterwards
            incoming_width = input_size if block_idx == 0 else self.hidden_dim[block_idx - 1]

            # streamed signature layer and the resulting LSTM input width
            if self.log_signature:
                streamed_signatures.append(LogSignature(depth=sig_depth, stream=True))
                lstm_input_width = logsignature_channels(in_channels=incoming_width,
                                                         depth=sig_depth)
            else:
                streamed_signatures.append(Signature(depth=sig_depth, stream=True))
                lstm_input_width = signature_channels(channels=incoming_width,
                                                      depth=sig_depth)

            # only the last block is allowed to be bidirectional
            use_bidirectional = self.BiLSTM if block_idx == last_block else False
            recurrent_blocks.append(
                nn.LSTM(
                    input_size=lstm_input_width,
                    hidden_size=block_hidden_size,
                    num_layers=1,
                    batch_first=True,
                    bidirectional=use_bidirectional,
                )
            )

        # register the blocks as sub-modules
        self.signature_layers = nn.ModuleList(streamed_signatures)
        self.lstm_layers = nn.ModuleList(recurrent_blocks)

        # closing signature over the whole sequence (no expanding windows)
        closing_signature = LogSignature if self.log_signature else Signature
        self.signature2 = closing_signature(depth=sig_depth, stream=False)

    def forward(self, x: torch.Tensor):
        # x has dimensions [batch, length of signal, channels]

        # alternate streamed signature and LSTM, block by block
        sequence = x
        for block_idx, _ in enumerate(self.hidden_dim):
            sequence = self.signature_layers[block_idx](sequence)
            sequence, _state = self.lstm_layers[block_idx](sequence)

        # collapse the sequence with the final signature
        return self.signature2(sequence)

In [55]:
from __future__ import annotations

import os
import random

import torch
from sklearn.model_selection import (
    GroupKFold,
    GroupShuffleSplit,
    KFold,
    train_test_split,
)
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import DataLoader


class DataSplits:
    """
    Partition a dataset into train, validation and test subsets.

    The partition is computed once, at construction time, and stored as
    indices in `self.indices`; `get_splits` materialises the subsets.
    """

    def __init__(
        self,
        x_data: torch.Tensor,
        y_data: torch.Tensor,
        train_size: float = 0.8,
        valid_size: float | None = 0.33,
        shuffle: bool = False,
        random_state: int = 42,
    ):
        """
        Compute and store the train / validation / test indices.

        Parameters
        ----------
        x_data : torch.Tensor
            Features for prediction.
        y_data : torch.Tensor
            Variable to predict.
        train_size : float, optional
            Proportion of data kept for training (before any validation
            set is taken out of it), by default 0.8.
        valid_size : float | None, optional
            Proportion of the training portion to hold out as validation
            data, by default 0.33. If None, no validation set is created.
        shuffle : bool, optional
            Whether or not to shuffle the dataset when splitting,
            by default False.
        random_state : int, optional
            Seed number, by default 42. Stored as None when `shuffle`
            is False.

        Raises
        ------
        ValueError
            if `x_data` and `y_data` do not have the same number of samples.
        ValueError
            if `train_size` is below 0 or above 1.
        ValueError
            if `valid_size` is given and is below 0 or above 1.
        """
        if x_data.shape[0] != y_data.shape[0]:
            raise ValueError(
                "x_data and y_data do not have compatible shapes "
                "(need to have same number of samples)"
            )
        if train_size < 0 or train_size > 1:
            raise ValueError("train_size must be between 0 and 1")
        if valid_size is not None and (valid_size < 0 or valid_size > 1):
            raise ValueError("valid_size must be between 0 and 1")

        self.x_data = x_data
        self.y_data = y_data
        self.shuffle = shuffle
        # a seed is only kept when the split is actually shuffled
        self.random_state = random_state if self.shuffle else None

        # both splits below share the same shuffle settings
        split_options = {
            "shuffle": self.shuffle,
            "random_state": self.random_state,
        }

        # stage 1: separate the test set from everything else
        all_positions = range(len(self.y_data))
        train_index, test_index = train_test_split(
            all_positions,
            test_size=(1 - train_size),
            **split_options,
        )

        # stage 2 (optional): take a validation set out of the training part
        valid_index = None
        if valid_size is not None:
            train_index, valid_index = train_test_split(
                train_index,
                test_size=valid_size,
                **split_options,
            )

        # order: (train, valid, test); valid is None if no validation set
        self.indices = (train_index, valid_index, test_index)

    def _subset(self, index):
        """
        Return `(x_data[index], y_data[index])`, or `(None, None)` when
        `index` is None.
        """
        if index is None:
            return None, None
        return self.x_data[index], self.y_data[index]

    def get_splits(
        self, as_DataLoader: bool = False, data_loader_args: dict | None = None
    ):
        """
        Return the train, validation and test sets.

        Parameters
        ----------
        as_DataLoader : bool, optional
            Whether or not to return `torch.utils.data.dataloader.DataLoader`
            objects rather than raw tensors, by default False.
        data_loader_args : dict | None, optional
            Keyword arguments forwarded to every
            `torch.utils.data.dataloader.DataLoader` that is built,
            by default {"batch_size": 64, "shuffle": True}.

        Returns
        -------
        - If `as_DataLoader` is True, a tuple of three items:
          - training `DataLoader`
          - validation `DataLoader` (None if there is no validation set)
          - testing `DataLoader`
        - If `as_DataLoader` is False, a tuple of six items:
          - features for training dataset
          - labels for training dataset
          - features for validation dataset (None if no validation set)
          - labels for validation dataset (None if no validation set)
          - features for testing dataset
          - labels for testing dataset
        """
        loader_options = (
            {"batch_size": 64, "shuffle": True}
            if data_loader_args is None
            else data_loader_args
        )

        train_index, valid_index, test_index = self.indices

        x_valid, y_valid = self._subset(valid_index)
        x_train, y_train = self._subset(train_index)
        x_test, y_test = self._subset(test_index)

        if not as_DataLoader:
            return (x_train, y_train, x_valid, y_valid, x_test, y_test)

        # wrap each subset in a DataLoader; validation stays None if absent
        if x_valid is None:
            valid_loader = None
        else:
            valid_loader = DataLoader(
                dataset=TensorDataset(x_valid, y_valid), **loader_options
            )

        train_set = TensorDataset(x_train, y_train)
        test_set = TensorDataset(x_test, y_test)
        train_loader = DataLoader(dataset=train_set, **loader_options)
        test_loader = DataLoader(dataset=test_set, **loader_options)

        return train_loader, valid_loader, test_loader


class Folds:
    """
    Partition a dataset into several folds, optionally split by groups.

    The fold indices are computed once, at construction time, and stored in
    `self.fold_indices`; `get_splits` materialises the data for one fold.
    """

    def __init__(
        self,
        x_data: torch.Tensor,
        y_data: torch.Tensor,
        groups: torch.Tensor | None = None,
        n_splits: int = 5,
        valid_size: float | None = 0.33,
        shuffle: bool = False,
        random_state: int = 42,
    ):
        """
        Build the splitter and store (train, valid, test) indices per fold.

        Parameters
        ----------
        x_data : torch.Tensor
            Features for prediction.
        y_data : torch.Tensor
            Variable to predict.
        groups : torch.Tensor | None, optional
            Groups to split by, default None. If None, a standard KFold is
            used; otherwise GroupShuffleSplit (if `shuffle` is True) or
            GroupKFold (if `shuffle` is False).
        n_splits : int, optional
            Number of splits / folds, by default 5.
        valid_size : float | None, optional
            Proportion of each fold's training data to hold out as
            validation data, by default 0.33.
            If None, no validation set is created.
        shuffle : bool, optional
            Whether or not to shuffle the dataset, by default False.
        random_state : int, optional
            Seed number, by default 42. Stored as None when `shuffle`
            is False.

        Raises
        ------
        ValueError
            if `n_splits` < 2.
        ValueError
            if `x_data` and `y_data` do not have the same number of records
            (number of rows in `x_data` should equal the length of `y_data`).
        ValueError
            if `x_data` and `groups` do not have the same number of records
            (number of rows in `x_data` should equal the length of `groups`).
        ValueError
            if `valid_size` is given and is below 0 or above 1.
        """
        if n_splits < 2:
            raise ValueError("n_splits should be at least 2")

        n_samples = x_data.shape[0]
        if n_samples != y_data.shape[0]:
            raise ValueError(
                "x_data and y_data do not have compatible shapes "
                "(need to have same number of samples)"
            )
        if groups is not None and n_samples != groups.shape[0]:
            raise ValueError(
                "x_data and groups do not have compatible shapes "
                "(need to have same number of samples)"
            )
        if valid_size is not None and (valid_size < 0 or valid_size > 1):
            raise ValueError("valid_size must be between 0 and 1")

        self.x_data = x_data
        self.y_data = y_data
        self.groups = groups
        self.n_splits = n_splits
        self.shuffle = shuffle
        # a seed is only kept when the split is actually shuffled
        self.random_state = random_state if self.shuffle else None

        # choose the splitter from (groups given?, shuffle?)
        if self.groups is None:
            self.fold = KFold(
                n_splits=self.n_splits,
                shuffle=self.shuffle,
                random_state=self.random_state,
            )
        elif self.shuffle:
            # GroupShuffleSplit does not guarantee that every group is in a test group
            self.fold = GroupShuffleSplit(
                n_splits=self.n_splits, random_state=self.random_state
            )
        else:
            # GroupKFold guarantees that every group is in a test group once
            self.fold = GroupKFold(n_splits=self.n_splits)

        # materialise all (train, test) index pairs up front
        self.fold_indices = list(self.fold.split(X=self.x_data, groups=self.groups))

        # replace each pair by a (train, valid, test) triple
        for k in range(self.n_splits):
            train_part, test_part = self.fold_indices[k]
            train_index = train_part.tolist()
            test_index = test_part.tolist()

            valid_index = None
            if valid_size is not None:
                # take the validation set out of this fold's training part
                train_index, valid_index = train_test_split(
                    train_index,
                    test_size=valid_size,
                    shuffle=self.shuffle,
                    random_state=self.random_state,
                )

            self.fold_indices[k] = (train_index, valid_index, test_index)

    def _subset(self, index):
        """
        Return `(x_data[index], y_data[index])`, or `(None, None)` when
        `index` is None.
        """
        if index is None:
            return None, None
        return self.x_data[index], self.y_data[index]

    def get_splits(
        self,
        fold_index: int,
        as_DataLoader: bool = False,
        data_loader_args: dict | None = None,
    ) -> (
        tuple[DataLoader, DataLoader, DataLoader]
        | tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ]
    ):
        """
        Return the train, validation and test sets of one fold.

        Parameters
        ----------
        fold_index : int
            Which fold to obtain data for.
        as_DataLoader : bool, optional
            Whether or not to return `torch.utils.data.dataloader.DataLoader`
            objects rather than raw tensors, by default False.
        data_loader_args : dict | None, optional
            Keyword arguments forwarded to every
            `torch.utils.data.dataloader.DataLoader` that is built,
            by default {"batch_size": 64, "shuffle": True}.

        Returns
        -------
        - If `as_DataLoader` is True, a tuple of three items:
          - training `DataLoader`
          - validation `DataLoader` (None if there is no validation set)
          - testing `DataLoader`
        - If `as_DataLoader` is False, a tuple of six items:
          - features for training dataset
          - labels for training dataset
          - features for validation dataset (None if no validation set)
          - labels for validation dataset (None if no validation set)
          - features for testing dataset
          - labels for testing dataset

        Raises
        ------
        ValueError
            if the requested `fold_index` is not valid (out of range).
        """
        loader_options = (
            {"batch_size": 64, "shuffle": True}
            if data_loader_args is None
            else data_loader_args
        )

        allowed_folds = list(range(self.n_splits))
        if fold_index not in allowed_folds:
            raise ValueError(
                f"There are {self.n_splits} folds, so "
                f"fold_index must be in {allowed_folds}"
            )

        # each stored entry is a (train, valid, test) triple
        train_index, valid_index, test_index = self.fold_indices[fold_index]

        x_valid, y_valid = self._subset(valid_index)
        x_train, y_train = self._subset(train_index)
        x_test, y_test = self._subset(test_index)

        if not as_DataLoader:
            return (x_train, y_train, x_valid, y_valid, x_test, y_test)

        # wrap each subset in a DataLoader; validation stays None if absent
        if valid_index is None:
            valid_loader = None
        else:
            valid_loader = DataLoader(
                dataset=TensorDataset(x_valid, y_valid), **loader_options
            )

        train_set = TensorDataset(x_train, y_train)
        test_set = TensorDataset(x_test, y_test)
        train_loader = DataLoader(dataset=train_set, **loader_options)
        test_loader = DataLoader(dataset=test_set, **loader_options)

        return train_loader, valid_loader, test_loader


def set_seed(seed: int) -> None:
    """
    Seed Python's `random` module and `torch` (including its CUDA seed)
    for reproducible behaviour. Also stores the seed in the
    `PYTHONHASHSEED` environment variable and switches cuDNN to
    deterministic mode.

    Parameters
    ----------
    seed : int
        Seed number.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # numpy's global seed is deliberately left untouched:
    # np.random.seed(seed) is not needed with numpy generators
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed):
        seed_torch(seed)

    torch.backends.cudnn.deterministic = True


In [56]:
# Hyperparameter search for the SWNU network with log_signature=True.
# The call returns four values; only the first two are kept here
# (bound to `swnu_network_log_signature` and `best_swnu_network_log_signature`),
# the last two are discarded.
# NOTE(review): the output recorded for this cell is a TypeError about an
# unexpected keyword argument 'num_time_features', which this call does not
# pass -- the recorded output looks stale; re-run the cell to confirm.
swnu_network_log_signature, best_swnu_network_log_signature, _, __ = swnu_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data,
    embedding_dim=embedding_dim,
    output_dim=output_dim,
    window_sizes=window_sizes,
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    # second entry of each pair in log_signature_dimensions_and_sig_depths
    sig_depths=[x[1] for x in log_signature_dimensions_and_sig_depths],
    log_signature=True,
    conv_output_channels=conv_output_channels,
    # first entry of each pair in log_signature_dimensions_and_sig_depths
    swnu_hidden_dim_sizes=[x[0] for x in log_signature_dimensions_and_sig_depths],
    ffn_hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    BiLSTM=False,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    time_feature=time_features,
    standardise_method=standardise_method,
    path_indices=client_index,
    k_fold=False,
    validation_metric=validation_metric,
    # NOTE(review): file name starts with "ffn_" although this is the SWNU
    # network search -- confirm it does not clash with another results file.
    results_output=f"{output_dir}/ffn_logsignature_grp_focal_{gamma}.csv",
    verbose=False
)

TypeError: swnu_network_hyperparameter_search() got an unexpected keyword argument 'num_time_features'