In [4]:
import numpy as np
import pickle
import os

seed = 2023  # NOTE(review): not referenced by any cell shown here (training seeds come from `seeds` below) — confirm it is still needed

In [None]:
import torch

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Echo the selected device as the cell output.
device

In [5]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    histories_baseline_hyperparameter_search,
)

In [6]:
# Directory that receives the hyperparameter-search CSVs written later in the notebook.
# exist_ok=True makes the cell safe to re-run and avoids the race between an
# explicit isdir() check and the directory creation.
output_dir = "client_talk_type_output"
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [7]:
# Load the AnnoMI data into the notebook namespace; the next cell uses `anno_mi`.
# The kwargs cell also uses y_data_client, output_dim_client, client_index and
# client_transcript_id, which no visible cell defines — presumably this script does (TODO confirm).
%run ../load_anno_mi.py

In [8]:
# Preview the first five rows of the loaded AnnoMI dataframe.
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-18 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-18 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-18 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-18 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-18 00:00:34


In [9]:
# Precomputed SBERT sentence embeddings; the hyperparameter search later concatenates
# them onto `anno_mi`, so they must line up row-for-row with the dataframe.
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# Fail fast on a stale or mismatched embeddings file instead of silently misaligning rows.
assert sbert_embeddings.shape[0] == len(anno_mi), (
    f"expected {len(anno_mi)} embeddings (one per row of anno_mi), "
    f"got {sbert_embeddings.shape[0]}"
)

sbert_embeddings.shape

(9699, 384)

# Baseline: averaging the history and using an FFN

Here, we average the full history of a path and concatenate it with the current embedding, so the number of features passed into the FFN is `2 * sbert_embeddings.shape[1]` (i.e. 768 for the 384-dimensional SBERT embeddings).

Here, we run the hyperparameter search for the FFN with the same parameters as the standard FFN baseline on the sentence embeddings. We try networks with two hidden layers of equal width (64, 128, 256 or 512 units), combined with the dropout rates, learning rates and seeds configured in the next cell.

In [10]:
# --- Training schedule ---
num_epochs = 100
# Passed to the search together with `validation_metric`
# (presumably early-stopping patience in epochs — TODO confirm).
patience = 5
validation_metric = "f1"

# --- Hyperparameter grid explored by the search ---
hidden_dim_sizes = [[64, 64], [128, 128], [256, 256], [512, 512]]
# NOTE(review): the saved results below report dropout_rate 0.5 and 0.1, not 0.2 —
# the outputs are stale relative to this grid; re-run (or restore 0.5) before citing them.
dropout_rates = [0.1, 0.2]
learning_rates = [1e-3, 1e-4, 5e-4]
seeds = [1, 12, 123]

# --- Loss (gamma also appears in the results CSV file name) ---
loss = "focal"
gamma = 2

In [None]:
# Keyword arguments shared by the hyperparameter-search call(s) below.
# y_data_client, output_dim_client, client_index and client_transcript_id are not
# defined in the visible cells — presumably they come from ../load_anno_mi.py (TODO confirm).
kwargs = dict(
    num_epochs=num_epochs,
    # data and column names
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    y_data=y_data_client,
    output_dim=output_dim_client,
    # search grid
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # loss and hardware
    loss=loss,
    gamma=gamma,
    device=device,
    # which rows are scored and how folds are split
    path_indices=client_index,
    split_ids=client_transcript_id,
    k_fold=True,
    # model selection
    patience=patience,
    validation_metric=validation_metric,
    verbose=False,
)

In [21]:
# Averaged-history FFN baseline (no signatures), evaluated with k-fold splits.
# The search writes all runs and the best-model runs to CSV files under `output_dir`.
mean_history_csv = f"{output_dir}/ffn_mean_history_focal_{gamma}_kfold.csv"
ffn_mean_history_kfold, best_ffn_mean_history_kfold, _, __ = histories_baseline_hyperparameter_search(
    use_signatures=False,
    results_output=mean_history_csv,
    **kwargs,
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' feature...
[INFO] Adding 'time_diff' feature...
[INFO] Adding 'timeline_index' feature...
Computing the mean history for each item in the dataframe


  0%|          | 0/9699 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_mean_history_focal_2_kfold_best_model.csv


In [22]:
# All runs from the search; rows differ by hidden_dim, dropout_rate, learning_rate and seed.
ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id,input_dim
0,,0.621073,0.511649,"[0.7466990438609805, 0.44222776392352453, 0.34...",0.503527,"[0.7856914723730437, 0.4455611390284757, 0.279...",0.534961,"[0.7113938692886062, 0.4389438943894389, 0.454...",,0.676622,...,0.5,0.0010,1,focal,2,True,5,64,0.00,768
0,,0.619732,0.509749,"[0.7479501973883997, 0.4378748928877464, 0.343...",0.503971,"[0.7874040920716112, 0.4554367201426025, 0.269...",0.536141,"[0.7122614227877386, 0.42161716171617164, 0.47...",,0.679672,...,0.5,0.0010,12,focal,2,True,5,64,0.00,768
0,,0.610536,0.509567,"[0.7353579175704988, 0.4476151651039544, 0.345...",0.501341,"[0.7920560747663551, 0.44238517324738114, 0.26...",0.540341,"[0.6862348178137652, 0.452970297029703, 0.4818...",,0.667470,...,0.5,0.0010,123,focal,2,True,5,64,0.00,768
0,,0.637739,0.518219,"[0.7610435813815595, 0.446546052631579, 0.3470...",0.511128,"[0.7807177615571776, 0.44508196721311477, 0.30...",0.529513,"[0.7423366107576634, 0.44801980198019803, 0.39...",,0.687861,...,0.5,0.0001,1,focal,2,True,5,64,0.10,768
0,,0.640613,0.519792,"[0.7647668393782383, 0.44935389745727383, 0.34...",0.513052,"[0.7834394904458599, 0.45408593091828137, 0.30...",0.531773,"[0.7469635627530364, 0.44471947194719474, 0.40...",,0.690591,...,0.5,0.0001,12,focal,2,True,5,64,0.10,768
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.635057,0.514847,"[0.7581409117821197, 0.4446280991735537, 0.341...",0.508139,"[0.7765312310491207, 0.445364238410596, 0.3025...",0.525741,"[0.7406015037593985, 0.4438943894389439, 0.392...",,0.692999,...,0.1,0.0001,12,focal,2,True,5,64,0.22,768
0,,0.633142,0.521705,"[0.7546658639373871, 0.4541800643086817, 0.356...",0.512353,"[0.7868801004394225, 0.4427899686520376, 0.307...",0.538265,"[0.7249855407750144, 0.4661716171617162, 0.423...",,0.686095,...,0.1,0.0001,123,focal,2,True,5,64,0.22,768
0,,0.620306,0.512478,"[0.7439229475615349, 0.4476449980537173, 0.345...",0.502589,"[0.7891663963671748, 0.423728813559322, 0.2948...",0.532063,"[0.7035858877964141, 0.4744224422442244, 0.418...",,0.682884,...,0.1,0.0005,1,focal,2,True,5,64,0.23,768
0,,0.627586,0.514185,"[0.7502246181491465, 0.44352844187963725, 0.34...",0.505850,"[0.7779503105590062, 0.443163097199341, 0.2964...",0.530646,"[0.7244071717755928, 0.4438943894389439, 0.423...",,0.687861,...,0.1,0.0005,12,focal,2,True,5,64,0.23,768


In [23]:
# Average every numeric column over the seeds of each
# (hidden_dim, dropout_rate, learning_rate) configuration.
# numeric_only=True explicitly skips list/string columns such as `f1_scores` and
# `loss_function`. Without it, pandas >= 2.0 raises on them instead of dropping them,
# and older versions emit the deprecation warning seen in the saved output.
ffn_mean_history_kfold.groupby(
    ["hidden_dim", "dropout_rate", "learning_rate"]
).mean(numeric_only=True)

  ffn_mean_history_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id,input_dim
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(64, 64)",0.1,0.0001,0.636909,0.519518,0.511822,0.532172,0.688825,0.591687,0.584999,0.601366,45.333333,2.0,1.0,5.0,64.0,0.4,768.0
"(64, 64)",0.1,0.0005,0.628033,0.518918,0.509948,0.539582,0.681385,0.589699,0.580559,0.606701,45.333333,2.0,1.0,5.0,64.0,0.5,768.0
"(64, 64)",0.1,0.001,0.626181,0.517932,0.508816,0.540365,0.683687,0.592703,0.583401,0.610247,45.333333,2.0,1.0,5.0,64.0,0.3,768.0
"(64, 64)",0.5,0.0001,0.638825,0.520243,0.512853,0.532788,0.690698,0.591499,0.585948,0.59984,45.333333,2.0,1.0,5.0,64.0,0.1,768.0
"(64, 64)",0.5,0.0005,0.630843,0.515559,0.508288,0.53353,0.687058,0.591491,0.584674,0.604698,45.333333,2.0,1.0,5.0,64.0,0.2,768.0
"(64, 64)",0.5,0.001,0.617114,0.510322,0.502947,0.537148,0.674588,0.585919,0.576055,0.610059,45.333333,2.0,1.0,5.0,64.0,0.0,768.0
"(128, 128)",0.1,0.0001,0.63461,0.518088,0.510114,0.530949,0.688022,0.590833,0.584252,0.600106,45.333333,2.0,1.0,5.0,64.0,0.1,768.0
"(128, 128)",0.1,0.0005,0.626373,0.514391,0.505707,0.532139,0.68101,0.588465,0.579661,0.603157,45.333333,2.0,1.0,5.0,64.0,0.11,768.0
"(128, 128)",0.1,0.001,0.624202,0.515967,0.507439,0.540247,0.680368,0.58851,0.579656,0.607578,45.333333,2.0,1.0,5.0,64.0,0.9,768.0
"(128, 128)",0.5,0.0001,0.635057,0.520459,0.512132,0.535379,0.689413,0.593805,0.586554,0.60502,45.333333,2.0,1.0,5.0,64.0,0.7,768.0


In [24]:
# Runs of the configuration the search reported as best (one row per seed in the saved output).
best_ffn_mean_history_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,input_dim
0,,0.624904,0.512535,"[0.7473684210526316, 0.4439951475940154, 0.346...",0.503713,"[0.7785087719298246, 0.43536875495638383, 0.29...",0.528713,"[0.7186234817813765, 0.452970297029703, 0.4145...",,0.691073,...,"(512, 512)",0.5,0.0001,1,focal,2,True,5,64,768
0,,0.631801,0.520683,"[0.7528666264333133, 0.45522682445759366, 0.35...",0.511354,"[0.7870662460567823, 0.436130007558579, 0.3108...",0.536166,"[0.7215153267784846, 0.4760726072607261, 0.410...",,0.687219,...,"(512, 512)",0.5,0.0001,12,focal,2,True,5,64,768
0,,0.634291,0.517746,"[0.7558538404175988, 0.4524484014569, 0.344936...",0.509913,"[0.7804126886356637, 0.44400317712470216, 0.30...",0.530126,"[0.7327935222672065, 0.4612211221122112, 0.396...",,0.692036,...,"(512, 512)",0.5,0.0001,123,focal,2,True,5,64,768


In [25]:
# Overall (macro) F1 of the best configuration, averaged over its seeds.
best_ffn_mean_history_kfold.loc[:, "f1"].mean()

0.5169880010201018

In [26]:
# Overall (macro) precision of the best configuration, averaged over its seeds.
best_ffn_mean_history_kfold.loc[:, "precision"].mean()

0.5083267122281124

In [27]:
# Overall (macro) recall of the best configuration, averaged over its seeds.
best_ffn_mean_history_kfold.loc[:, "recall"].mean()

0.5316682821164322

In [28]:
# Per-class F1 (one entry per class) averaged over the seeds of the best configuration.
np.mean(np.stack(best_ffn_mean_history_kfold["f1_scores"]), axis=0)

array([0.75202963, 0.45055679, 0.34837758])

In [29]:
# Per-class precision (one entry per class) averaged over the seeds of the best configuration.
np.mean(np.stack(best_ffn_mean_history_kfold["precision_scores"]), axis=0)

array([0.7819959 , 0.43850065, 0.30448359])

In [30]:
# Per-class recall (one entry per class) averaged over the seeds of the best configuration.
np.mean(np.stack(best_ffn_mean_history_kfold["recall_scores"]), axis=0)

array([0.72431078, 0.46342134, 0.40727273])