In [1]:
import numpy as np
import pickle
import os

seed = 2023

In [2]:
from nlpsig_networks.scripts.swmhau_network_functions import (
    swmhau_network_hyperparameter_search,
)

In [3]:
output_dir = "therapist_talk_type_output"
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

## AnnoMI

In [4]:
%run ../load_anno_mi.py

In [5]:
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-21 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-21 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-21 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-21 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-21 00:00:34


In [6]:
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

sbert_embeddings.shape

(9699, 384)

# swmhau Network

In [7]:
features = ["timeline_index"]
standardise_method = [None]
num_features = len(features)
include_features_in_path = True

In [8]:
num_epochs = 100
embedding_dim = 384
dimensions = [15]
# define swmhau parameters: (output_channels, sig_depth, num_heads)
swmhau_parameters = [(12, 3, 10), (8, 4, 6), (8, 4, 12)]
num_layers = [1]
ffn_hidden_dim_sizes = [[256, 256], [512, 512]]
dropout_rates = [0.5, 0.1]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]
loss = "focal"
gamma = 2
validation_metric = "f1"
patience = 5

## UMAP

In [9]:
size = 20
(
    swmhau_network_umap_kfold_20,
    best_swmhau_network_umap_kfold_20,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    embedding_dim=embedding_dim,
    output_dim=output_dim_therapist,
    history_lengths=[size],
    dim_reduce_methods=["umap"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    features=features,
    standardise_method=standardise_method,
    include_features_in_path=include_features_in_path,
    path_indices=therapist_index,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_umap_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: umap
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/swmhau_network_umap_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/swmhau_network_umap_focal_2_20_kfold_best_model.csv


In [10]:
swmhau_network_umap_kfold_20.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "learning_rate",
    ]
).mean()

  swmhau_network_umap_kfold_20.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,...,num_features,embedding_dim,log_signature,dropout_rate,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,learning_rate,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.0001,0.56524,0.513817,0.522566,0.513656,0.561934,0.510605,0.51916,0.510386,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,14.5
15,8,4,6,1,"(256, 256)",0.0005,0.736037,0.702147,0.706176,0.700184,0.737445,0.703123,0.708239,0.70057,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,15.5
15,8,4,6,1,"(256, 256)",0.001,0.732214,0.698647,0.703448,0.697128,0.735764,0.702981,0.708628,0.700651,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,13.5
15,8,4,6,1,"(512, 512)",0.0001,0.573843,0.526182,0.533727,0.525149,0.570468,0.522132,0.531359,0.52085,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,20.5
15,8,4,6,1,"(512, 512)",0.0005,0.727605,0.694506,0.696633,0.694579,0.733101,0.701024,0.703844,0.700445,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,21.5
15,8,4,6,1,"(512, 512)",0.001,0.725352,0.691037,0.695285,0.689939,0.73204,0.698857,0.703941,0.69704,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,19.5
15,8,4,12,1,"(256, 256)",0.0001,0.56145,0.509693,0.518507,0.510839,0.562374,0.510994,0.518718,0.511972,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,26.5
15,8,4,12,1,"(256, 256)",0.0005,0.731326,0.69887,0.699672,0.699532,0.7346,0.701728,0.703352,0.701617,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,27.5
15,8,4,12,1,"(256, 256)",0.001,0.729756,0.697434,0.700644,0.697723,0.731937,0.699867,0.703621,0.69915,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,25.5
15,8,4,12,1,"(512, 512)",0.0001,0.570736,0.523074,0.529294,0.52197,0.570882,0.522441,0.529759,0.521153,20.0,16.0,...,1.0,384.0,1.0,0.3,45.333333,2.0,1.0,5.0,64.0,32.5


In [11]:
best_swmhau_network_umap_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.737403,0.70543,"[0.7460982108869432, 0.5281173594132028, 0.678...",0.711521,"[0.7821229050279329, 0.5409015025041736, 0.622...",0.703583,"[0.7132459970887919, 0.5159235668789809, 0.745...",,0.73623,...,0.0005,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.736174,0.70199,"[0.7535919540229885, 0.5218120805369129, 0.668...",0.706816,"[0.7439716312056738, 0.5514184397163121, 0.637...",0.699202,"[0.7634643377001455, 0.49522292993630573, 0.70...",,0.737161,...,0.0005,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.739246,0.704795,"[0.7590066456803077, 0.5279352226720648, 0.667...",0.706244,"[0.7306397306397306, 0.5370675453047776, 0.674...",0.704212,"[0.789665211062591, 0.5191082802547771, 0.6605...",,0.740884,...,0.0005,123,focal,2,True,5,Conv1d,,concatenation,64


In [12]:
best_swmhau_network_umap_kfold_20["f1"].mean()

0.7040717525292837

In [13]:
best_swmhau_network_umap_kfold_20["precision"].mean()

0.7081935309049031

In [14]:
best_swmhau_network_umap_kfold_20["recall"].mean()

0.7023322823197091

In [15]:
np.stack(best_swmhau_network_umap_kfold_20["f1_scores"]).mean(axis=0)

array([0.75289894, 0.52595489, 0.67122888, 0.86620431])

In [16]:
np.stack(best_swmhau_network_umap_kfold_20["precision_scores"]).mean(axis=0)

array([0.75224476, 0.54312916, 0.64486382, 0.89253639])

In [17]:
np.stack(best_swmhau_network_umap_kfold_20["recall_scores"]).mean(axis=0)

array([0.75545852, 0.51008493, 0.70233513, 0.84145056])

## Random Projections

In [18]:
size = 20
(
    swmhau_network_grp_kfold_20,
    best_swmhau_network_grp_kfold_20,
    _,
    __,
) = swmhau_network_hyperparameter_search(
    num_epochs=num_epochs,
    df=anno_mi,
    id_column="transcript_id",
    label_column="main_therapist_behaviour",
    embeddings=sbert_embeddings,
    y_data=y_data_therapist,
    embedding_dim=embedding_dim,
    output_dim=output_dim_therapist,
    history_lengths=[size],
    dim_reduce_methods=["gaussian_random_projection"],
    dimensions=dimensions,
    log_signature=True,
    swmhau_parameters=swmhau_parameters,
    num_layers=num_layers,
    ffn_hidden_dim_sizes=ffn_hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    features=features,
    standardise_method=standardise_method,
    path_indices=therapist_index,
    k_fold=True,
    patience=patience,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/swmhau_network_grp_focal_{gamma}_{size}_kfold.csv",
    verbose=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]


##################################################
dimension: 15 | method: gaussian_random_projection
[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] - columns beginning with 'd' denote the dimension reduced embeddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/9699 [00:00<?, ?it/s]

[INFO] The path was created for each item in the dataframe, by looking at its history, so to include embeddings in the FFN input, we concatenate the embeddings for each sentence / text.
saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/swmhau_network_grp_focal_2_20_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/swmhau_network_grp_focal_2_20_kfold_best_model.csv


In [19]:
swmhau_network_grp_kfold_20.groupby(
    [
        "dimensions",
        "output_channels",
        "sig_depth",
        "num_heads",
        "num_layers",
        "ffn_hidden_dim",
        "dropout_rate",
        "learning_rate",
    ]
).mean()

  swmhau_network_grp_kfold_20.groupby(["dimensions",


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,k,input_channels,include_features_in_path,num_features,embedding_dim,log_signature,seed,gamma,k_fold,n_splits,batch_size,model_id
dimensions,output_channels,sig_depth,num_heads,num_layers,ffn_hidden_dim,dropout_rate,learning_rate,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
15,8,4,6,1,"(256, 256)",0.1,0.0001,0.402499,0.335054,0.357382,0.34993,0.407189,0.337477,0.367419,0.352832,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,16.0
15,8,4,6,1,"(256, 256)",0.1,0.0005,0.618326,0.574856,0.57912,0.574079,0.614792,0.568544,0.573328,0.567311,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,17.0
15,8,4,6,1,"(256, 256)",0.1,0.001,0.718353,0.687023,0.688068,0.68908,0.718386,0.685818,0.687505,0.687138,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,15.0
15,8,4,6,1,"(256, 256)",0.5,0.0001,0.415404,0.35113,0.376387,0.362012,0.415309,0.351133,0.377927,0.361865,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,13.0
15,8,4,6,1,"(256, 256)",0.5,0.0005,0.648914,0.611596,0.613769,0.611568,0.646961,0.60886,0.611656,0.608346,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,14.0
15,8,4,6,1,"(256, 256)",0.5,0.001,0.715963,0.68218,0.687323,0.679573,0.721179,0.686079,0.692707,0.682684,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,12.0
15,8,4,6,1,"(512, 512)",0.1,0.0001,0.423255,0.362747,0.377996,0.370455,0.421153,0.36212,0.380737,0.368775,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,22.0
15,8,4,6,1,"(512, 512)",0.1,0.0005,0.642018,0.603691,0.607834,0.602804,0.640859,0.601173,0.605784,0.600105,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,23.0
15,8,4,6,1,"(512, 512)",0.1,0.001,0.717329,0.686805,0.688155,0.68924,0.724127,0.693721,0.694979,0.695539,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,21.0
15,8,4,6,1,"(512, 512)",0.5,0.0001,0.411785,0.345938,0.365454,0.357023,0.419395,0.352948,0.378387,0.363638,20.0,15.0,0.0,1.0,384.0,1.0,45.333333,2.0,1.0,5.0,64.0,19.0


In [20]:
best_swmhau_network_grp_kfold_20

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,learning_rate,seed,loss_function,gamma,k_fold,n_splits,augmentation_type,hidden_dim_aug,comb_method,batch_size
0,,0.718353,0.6897,"[0.7304672897196262, 0.5233506300963677, 0.646...",0.689632,"[0.7509607993850884, 0.4895977808599168, 0.621...",0.692572,"[0.7110625909752547, 0.5621019108280255, 0.673...",,0.721955,...,0.001,1,focal,2,True,5,Conv1d,,concatenation,64
0,,0.709545,0.678953,"[0.7385068039720484, 0.49791376912378305, 0.62...",0.677774,"[0.7464684014869889, 0.4419753086419753, 0.630...",0.684938,"[0.7307132459970888, 0.5700636942675159, 0.613...",,0.717766,...,0.001,12,focal,2,True,5,Conv1d,,concatenation,64
0,,0.724088,0.691761,"[0.7343283582089551, 0.5084745762711865, 0.661...",0.697059,"[0.7534456355283308, 0.5155482815057283, 0.610...",0.690211,"[0.7161572052401747, 0.5015923566878981, 0.722...",,0.732661,...,0.001,123,focal,2,True,5,Conv1d,,concatenation,64


In [21]:
best_swmhau_network_grp_kfold_20["f1"].mean()

0.6868047155741359

In [22]:
best_swmhau_network_grp_kfold_20["precision"].mean()

0.6881549643809765

In [23]:
best_swmhau_network_grp_kfold_20["recall"].mean()

0.689240383625818

In [24]:
np.stack(best_swmhau_network_grp_kfold_20["f1_scores"]).mean(axis=0)

array([0.73443415, 0.50991299, 0.64329698, 0.85957474])

In [25]:
np.stack(best_swmhau_network_grp_kfold_20["precision_scores"]).mean(axis=0)

array([0.75029161, 0.48237379, 0.62073661, 0.89921784])

In [26]:
np.stack(best_swmhau_network_grp_kfold_20["recall_scores"]).mean(axis=0)

array([0.71931101, 0.54458599, 0.66974596, 0.82331857])