In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced by any later cell visible in this
# notebook; the hyperparameter searches use the `seeds` list instead. Confirm
# whether this value is still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
)

In [3]:
# Directory that receives the CSV result files written by the searches below.
output_dir = "client_talk_type_output"
# exist_ok=True keeps the cell idempotent on re-runs and avoids the
# check-then-create race of testing os.path.isdir() first.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Execute the AnnoMI loading script. Later cells use `anno_mi`, `client_index`,
# `y_data_client`, `output_dim_client` and `client_transcript_id`, none of which
# are defined in this notebook -- presumably they are created by this script
# (TODO confirm against load_anno_mi.py).
%run ../load_anno_mi.py

In [5]:
# Preview the first rows of the AnnoMI dataframe.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-17 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-17 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-17 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-17 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-17 00:00:34


In [6]:
# Load the precomputed SBERT embeddings from disk.
# NOTE: unpickling can execute arbitrary code, so this file must come from a
# trusted source.
with open("../anno_mi_sbert.pkl", mode="rb") as embeddings_file:
    sbert_embeddings = pickle.load(embeddings_file)

# Show the array shape as a quick sanity check.
sbert_embeddings.shape

(9699, 384)

# Baseline: feed-forward network (FFN)

The SBERT sentence embeddings are fed directly into a feed-forward network (FFN) to predict the client talk type.

In [7]:
# Training settings passed to every run of the search.
num_epochs = 100
patience = 5

# Search grid: the search loops over hidden sizes, dropout rates and learning
# rates, and each combination is evaluated on every seed.
hidden_dim_sizes = [[width, width] for width in (64, 128, 256)]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [0.001, 0.0001, 0.0005]
seeds = [1, 12, 123]

# Loss configuration and the metric used to compare models on validation data.
loss = "focal"
gamma = 2
validation_metric = "f1"

In [8]:
# Echo the hidden-layer size grid that the search will iterate over.
hidden_dim_sizes

[[64, 64], [128, 128], [256, 256]]

In [9]:
# Echo the learning-rate grid that the search will iterate over.
learning_rates

[0.001, 0.0001, 0.0005]

We use the `ffn_hyperparameter_search` function, which loops through the different hidden dimensions, dropout rates and learning rates to find the model that performs best on the validation set. Each configuration is evaluated on several seeds, and its performance is averaged over those seeds.

In [10]:
# Hyperparameter search without K-fold evaluation (k_fold=False).
# Only the first two returned values are kept: the full results frame and the
# rows for the best configuration. The last two are discarded.
# NOTE(review): binding `_` / `__` shadows IPython's output-history variables.
ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search(
    # --- data: embeddings restricted to `client_index`, with their labels ---
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    output_dim=output_dim_client,
    split_ids=client_transcript_id,
    k_fold=False,
    # --- search grid ---
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # --- training / model selection ---
    num_epochs=num_epochs,
    patience=patience,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    # --- reporting: results are also written to this CSV ---
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False,
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_best_model.csv


In [11]:
# Full results table of the search: one row per (hyperparameter combination, seed).
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,0.582908,0.628796,0.510905,"[0.7535211267605634, 0.46384039900249374, 0.31...",0.500335,"[0.834307992202729, 0.4246575342465753, 0.2420...",0.550123,"[0.6869983948635634, 0.510989010989011, 0.4523...",0.694246,0.591981,...,"(64, 64)",0.5,0.0010,1,focal,2,False,,64,0
0,0.578808,0.651294,0.510791,"[0.7790008467400508, 0.45430809399477806, 0.29...",0.501121,"[0.8243727598566308, 0.43283582089552236, 0.24...",0.532446,"[0.7383627608346709, 0.47802197802197804, 0.38...",0.712284,0.616745,...,"(64, 64)",0.5,0.0010,12,focal,2,False,,64,0
0,0.586651,0.623172,0.503207,"[0.7495621716287215, 0.4536082474226804, 0.306...",0.494518,"[0.8246628131021194, 0.42718446601941745, 0.23...",0.540965,"[0.6869983948635634, 0.4835164835164835, 0.452...",0.727132,0.588443,...,"(64, 64)",0.5,0.0010,123,focal,2,False,,64,0
0,0.741634,0.493813,0.331468,"[0.614, 0.3804034582132565, 0.0]",0.357379,"[0.8143236074270557, 0.2578125, 0.0]",0.406017,"[0.492776886035313, 0.7252747252747253, 0.0]",0.784088,0.483491,...,"(64, 64)",0.5,0.0001,1,focal,2,False,,64,1
0,0.585608,0.650169,0.519475,"[0.7743589743589744, 0.46965699208443273, 0.31...",0.509402,"[0.8281535648994516, 0.4517766497461929, 0.248...",0.548236,"[0.7271268057784912, 0.489010989010989, 0.4285...",0.698012,0.610849,...,"(64, 64)",0.5,0.0001,12,focal,2,False,,64,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,0.576948,0.633296,0.511622,"[0.7586805555555556, 0.44845360824742264, 0.32...",0.500555,"[0.8260869565217391, 0.4223300970873786, 0.253...",0.547917,"[0.7014446227929374, 0.47802197802197804, 0.46...",0.760577,0.595519,...,"(256, 256)",0.1,0.0001,12,focal,2,False,,64,25
0,0.582413,0.614173,0.492449,"[0.7442680776014109, 0.4352078239608802, 0.297...",0.483230,"[0.8258317025440313, 0.3920704845814978, 0.231...",0.527682,"[0.6773675762439807, 0.489010989010989, 0.4166...",0.719641,0.599057,...,"(256, 256)",0.1,0.0001,123,focal,2,False,,64,25
0,0.572328,0.634421,0.505037,"[0.7651122625215889, 0.450402144772118, 0.2995...",0.498274,"[0.8280373831775701, 0.4397905759162304, 0.226...",0.537697,"[0.7110754414125201, 0.46153846153846156, 0.44...",0.706368,0.600236,...,"(256, 256)",0.1,0.0005,1,focal,2,False,,64,26
0,0.579547,0.638920,0.513555,"[0.7586805555555556, 0.4640371229698376, 0.317...",0.502324,"[0.8260869565217391, 0.40160642570281124, 0.27...",0.539981,"[0.7014446227929374, 0.5494505494505495, 0.369...",0.730387,0.609670,...,"(256, 256)",0.1,0.0005,12,focal,2,False,,64,26


In [12]:
# Average the numeric metrics over the seeds for each hyperparameter combination.
# numeric_only=True is needed because the frame also holds list-valued columns
# (e.g. `f1_scores`) and string columns (e.g. `loss_function`): pandas >= 2.0
# raises a TypeError when asked to average them, and older versions only drop
# them with a FutureWarning.
ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(64, 64)",0.1,0.0001,0.587881,0.630671,0.503897,0.493472,0.536649,0.700988,0.604167,0.486195,0.48593,0.487065,45.333333,2.0,0.0,64.0,7.0
"(64, 64)",0.1,0.0005,0.582073,0.626172,0.504855,0.495513,0.541272,0.715585,0.588443,0.481945,0.478707,0.488473,45.333333,2.0,0.0,64.0,8.0
"(64, 64)",0.1,0.001,0.584537,0.623547,0.50866,0.498198,0.54681,0.717217,0.590409,0.484959,0.483284,0.491069,45.333333,2.0,0.0,64.0,6.0
"(64, 64)",0.2,0.0001,0.587542,0.626922,0.506521,0.49595,0.544325,0.709119,0.599057,0.485869,0.48426,0.488622,45.333333,2.0,0.0,64.0,4.0
"(64, 64)",0.2,0.0005,0.584718,0.611924,0.498983,0.490331,0.542225,0.702964,0.584906,0.485868,0.480491,0.495445,45.333333,2.0,0.0,64.0,5.0
"(64, 64)",0.2,0.001,0.583584,0.621672,0.503762,0.493816,0.541316,0.708372,0.586085,0.482445,0.479052,0.488164,45.333333,2.0,0.0,64.0,3.0
"(64, 64)",0.5,0.0001,0.638109,0.593176,0.45427,0.456074,0.500368,0.735365,0.564072,0.431788,0.431822,0.447262,45.333333,2.0,0.0,64.0,1.0
"(64, 64)",0.5,0.0005,0.585649,0.630296,0.508564,0.498726,0.545242,0.710975,0.600629,0.490215,0.488182,0.494029,45.333333,2.0,0.0,64.0,2.0
"(64, 64)",0.5,0.001,0.582789,0.634421,0.508301,0.498658,0.541178,0.71122,0.599057,0.484404,0.48415,0.487704,45.333333,2.0,0.0,64.0,0.0
"(128, 128)",0.1,0.0001,0.580205,0.628421,0.506438,0.495629,0.541349,0.721606,0.59945,0.487746,0.486118,0.49067,45.333333,2.0,0.0,64.0,16.0


In [13]:
# Rows for the selected best hyperparameter combination (one row per seed).
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.57735,0.64117,0.512673,"[0.7663230240549829, 0.4461152882205514, 0.325...",0.500571,"[0.8243992606284658, 0.41013824884792627, 0.26...",0.540523,"[0.7158908507223114, 0.489010989010989, 0.4166...",0.719902,0.616745,...,"[0.7598499061913696, 0.4027777777777778, 0.313...","(256, 256)",0.2,0.0001,1,focal,2,False,,64
0,0.577855,0.628796,0.506871,"[0.7558644656820157, 0.43701799485861187, 0.32...",0.495913,"[0.8238636363636364, 0.4106280193236715, 0.253...",0.543184,"[0.6982343499197432, 0.46703296703296704, 0.46...",0.758114,0.601415,...,"[0.7410881801125704, 0.37037037037037035, 0.35...","(256, 256)",0.2,0.0001,12,focal,2,False,,64
0,0.583869,0.616423,0.491256,"[0.7480245829675154, 0.4313725490196078, 0.294...",0.482085,"[0.8255813953488372, 0.3893805309734513, 0.231...",0.524022,"[0.6837881219903692, 0.4835164835164835, 0.404...",0.719763,0.602594,...,"[0.7317073170731707, 0.4074074074074074, 0.333...","(256, 256)",0.2,0.0001,123,focal,2,False,,64


In [14]:
# Mean overall F1 of the best configuration, averaged over its rows (seeds).
best_ffn_current["f1"].mean()

0.5036002985512656

In [15]:
# Mean overall precision of the best configuration, averaged over its rows (seeds).
best_ffn_current["precision"].mean()

0.49285621491762543

In [16]:
# Mean overall recall of the best configuration, averaged over its rows (seeds).
best_ffn_current["recall"].mean()

0.5359097831007943

In [17]:
# Stack the per-row F1 score lists into a 2-D array and average down the rows,
# giving one mean F1 per score position (presumably one per class).
np.mean(np.stack(best_ffn_current["f1_scores"]), axis=0)

array([0.75673736, 0.43816861, 0.31589493])

In [18]:
# Stack the per-row precision score lists into a 2-D array and average down the
# rows, giving one mean precision per score position (presumably one per class).
np.mean(np.stack(best_ffn_current["precision_scores"]), axis=0)

array([0.82461476, 0.40338227, 0.25057161])

In [19]:
# Stack the per-row recall score lists into a 2-D array and average down the
# rows, giving one mean recall per score position (presumably one per class).
np.mean(np.stack(best_ffn_current["recall_scores"]), axis=0)

array([0.69930444, 0.47985348, 0.42857143])

## KFold

We can repeat the search using K-fold evaluation instead; by default, $K=5$ folds are used.

In [20]:
# Same hyperparameter search as before, but with K-fold evaluation (k_fold=True).
# Only the first two returned values are kept: the full results frame and the
# rows for the best configuration. The last two are discarded.
# NOTE(review): binding `_` / `__` shadows IPython's output-history variables.
ffn_current_kfold, best_ffn_current_kfold, _, __ = ffn_hyperparameter_search(
    # --- data: embeddings restricted to `client_index`, with their labels ---
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    output_dim=output_dim_client,
    split_ids=client_transcript_id,
    k_fold=True,
    # --- search grid ---
    hidden_dim_sizes=hidden_dim_sizes,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    # --- training / model selection ---
    num_epochs=num_epochs,
    patience=patience,
    loss=loss,
    gamma=gamma,
    validation_metric=validation_metric,
    # --- reporting: results are also written to this CSV ---
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False,
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv


In [21]:
# Full results table of the K-fold search: one row per (hyperparameter combination, seed).
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,,0.639272,0.518299,"[0.7679810958499483, 0.4306732055283103, 0.356...",0.514406,"[0.7847872019317839, 0.4684772065955383, 0.289...",0.537404,"[0.7518796992481203, 0.39851485148514854, 0.46...",,0.661047,...,"(64, 64)",0.5,0.0010,1,focal,2,True,5,64,0
0,,0.645211,0.521731,"[0.7727807172251616, 0.43989314336598395, 0.35...",0.518379,"[0.7857142857142857, 0.47775628626692457, 0.29...",0.537770,"[0.760266049739734, 0.4075907590759076, 0.4454...",,0.658799,...,"(64, 64)",0.5,0.0010,12,focal,2,True,5,64,0
0,,0.629310,0.508992,"[0.761365330153525, 0.4254754533392304, 0.3401...",0.505292,"[0.7856044294063365, 0.45853193517635843, 0.27...",0.529996,"[0.7385772122614228, 0.39686468646864687, 0.45...",,0.649807,...,"(64, 64)",0.5,0.0010,123,focal,2,True,5,64,0
0,,0.643103,0.493964,"[0.7768313458262351, 0.39620081411126185, 0.30...",0.496532,"[0.7629670942554378, 0.43843843843843844, 0.28...",0.495107,"[0.7912087912087912, 0.3613861386138614, 0.332...",,0.656230,...,"(64, 64)",0.5,0.0001,1,focal,2,True,5,64,1
0,,0.646743,0.521010,"[0.7752165614447218, 0.43796255986068783, 0.34...",0.516301,"[0.787354607813898, 0.46359447004608295, 0.297...",0.534033,"[0.763447079236553, 0.415016501650165, 0.42363...",,0.657354,...,"(64, 64)",0.5,0.0001,12,focal,2,True,5,64,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,,0.636782,0.510545,"[0.7648449039881833, 0.4371900826446281, 0.329...",0.504634,"[0.7817028985507246, 0.4379139072847682, 0.294...",0.519904,"[0.7486986697513013, 0.43646864686468645, 0.37...",,0.664419,...,"(256, 256)",0.1,0.0001,12,focal,2,True,5,64,25
0,,0.639464,0.513886,"[0.7655882352941176, 0.43705857914416285, 0.33...",0.508348,"[0.7788749251944943, 0.4401673640167364, 0.306...",0.522247,"[0.7527472527472527, 0.43399339933993397, 0.38]",,0.665543,...,"(256, 256)",0.1,0.0001,123,focal,2,True,5,64,25
0,,0.634674,0.514114,"[0.7663690476190477, 0.42844522968197885, 0.34...",0.509890,"[0.7893930104230533, 0.4610266159695818, 0.279...",0.534938,"[0.7446500867553499, 0.40016501650165015, 0.46]",,0.653179,...,"(256, 256)",0.1,0.0005,1,focal,2,True,5,64,26
0,,0.640421,0.517552,"[0.7701826753093693, 0.42434936038817817, 0.35...",0.512638,"[0.784984984984985, 0.45592417061611373, 0.297...",0.534567,"[0.7559282822440717, 0.39686468646864687, 0.45...",,0.658478,...,"(256, 256)",0.1,0.0005,12,focal,2,True,5,64,26


In [22]:
# Average the numeric metrics over the seeds for each hyperparameter combination.
# numeric_only=True is needed because the frame also holds list-valued columns
# (e.g. `f1_scores`) and string columns (e.g. `loss_function`): pandas >= 2.0
# raises a TypeError when asked to average them, and older versions only drop
# them with a FutureWarning.
ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(numeric_only=True)

  ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
"(64, 64)",0.1,0.0001,0.635824,0.51166,0.505265,0.524039,0.658531,0.546615,0.543247,0.553267,45.333333,2.0,1.0,5.0,64.0,7.0
"(64, 64)",0.1,0.0005,0.63493,0.51291,0.507614,0.529921,0.655855,0.546246,0.543332,0.557253,45.333333,2.0,1.0,5.0,64.0,8.0
"(64, 64)",0.1,0.001,0.630077,0.51039,0.504599,0.528383,0.654089,0.546322,0.542932,0.557513,45.333333,2.0,1.0,5.0,64.0,6.0
"(64, 64)",0.2,0.0001,0.637165,0.511612,0.505694,0.523203,0.658157,0.544536,0.541659,0.550871,45.333333,2.0,1.0,5.0,64.0,4.0
"(64, 64)",0.2,0.0005,0.629374,0.512933,0.505438,0.531583,0.654624,0.548361,0.543237,0.56014,45.333333,2.0,1.0,5.0,64.0,5.0
"(64, 64)",0.2,0.001,0.626117,0.508063,0.501684,0.52769,0.652537,0.545799,0.54165,0.558235,45.333333,2.0,1.0,5.0,64.0,3.0
"(64, 64)",0.5,0.0001,0.645211,0.509257,0.507438,0.516472,0.657675,0.532637,0.535551,0.533961,45.333333,2.0,1.0,5.0,64.0,1.0
"(64, 64)",0.5,0.0005,0.638825,0.515069,0.50971,0.528625,0.659067,0.547164,0.544468,0.554825,45.333333,2.0,1.0,5.0,64.0,2.0
"(64, 64)",0.5,0.001,0.637931,0.51634,0.512692,0.535057,0.656551,0.544456,0.543756,0.556143,45.333333,2.0,1.0,5.0,64.0,0.0
"(128, 128)",0.1,0.0001,0.637803,0.514196,0.507761,0.5261,0.660083,0.548817,0.545523,0.55492,45.333333,2.0,1.0,5.0,64.0,16.0


In [23]:
# Rows for the selected best hyperparameter combination under K-fold (one row per seed).
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,,0.636973,0.511104,"[0.7645756457564574, 0.43752594437525943, 0.33...",0.505237,"[0.7808260476334037, 0.4402673350041771, 0.294...",0.520663,"[0.7489878542510121, 0.43481848184818483, 0.37...",,0.665382,...,"[0.7901234567901234, 0.469281045751634, 0.3978...","(256, 256)",0.1,0.0001,1,focal,2,True,5,64
0,,0.636782,0.510545,"[0.7648449039881833, 0.4371900826446281, 0.329...",0.504634,"[0.7817028985507246, 0.4379139072847682, 0.294...",0.519904,"[0.7486986697513013, 0.43646864686468645, 0.37...",,0.664419,...,"[0.7876039304610734, 0.46797385620915033, 0.40...","(256, 256)",0.1,0.0001,12,focal,2,True,5,64
0,,0.639464,0.513886,"[0.7655882352941176, 0.43705857914416285, 0.33...",0.508348,"[0.7788749251944943, 0.4401673640167364, 0.306...",0.522247,"[0.7527472527472527, 0.43399339933993397, 0.38]",,0.665543,...,"[0.7903754094230284, 0.4627450980392157, 0.411...","(256, 256)",0.1,0.0001,123,focal,2,True,5,64


In [24]:
# Mean overall F1 of the best K-fold configuration, averaged over its rows (seeds).
best_ffn_current_kfold["f1"].mean()

0.5118449028528574

In [25]:
# Mean overall precision of the best K-fold configuration, averaged over its rows (seeds).
best_ffn_current_kfold["precision"].mean()

0.5060731871074536

In [26]:
# Mean overall recall of the best K-fold configuration, averaged over its rows (seeds).
best_ffn_current_kfold["recall"].mean()

0.5209379530588493

In [27]:
# Stack the per-row F1 score lists into a 2-D array and average down the rows,
# giving one mean F1 per score position (presumably one per class).
np.mean(np.stack(best_ffn_current_kfold["f1_scores"]), axis=0)

array([0.76500293, 0.4372582 , 0.33327358])

In [28]:
# Stack the per-row precision score lists into a 2-D array and average down the
# rows, giving one mean precision per score position (presumably one per class).
np.mean(np.stack(best_ffn_current_kfold["precision_scores"]), axis=0)

array([0.78046796, 0.43944954, 0.29830207])

In [29]:
# Stack the per-row recall score lists into a 2-D array and average down the
# rows, giving one mean recall per score position (presumably one per class).
np.mean(np.stack(best_ffn_current_kfold["recall_scores"]), axis=0)

array([0.75014459, 0.43509351, 0.37757576])