In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced by any later cell visible here (the
# hyperparameter searches take their own `seeds` list) — confirm whether it is
# still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
)

In [3]:
# Directory that receives the CSV files written by the hyperparameter searches
# below (they build their `results_output` paths from it).
output_dir = "therapist_talk_type_output"
# `exist_ok=True` makes directory creation idempotent and avoids the
# check-then-create race of `if not os.path.isdir(...): os.makedirs(...)`.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Runs the shared AnnoMI loading script. The cells below use `anno_mi`,
# `output_dim_therapist`, `label_to_id_therapist`, `id_to_label_therapist`,
# `therapist_index`, `y_data_therapist` and `therapist_transcript_id` without
# defining them, so they are presumably created by this script — confirm
# against load_anno_mi.py.
%run ../load_anno_mi.py

In [5]:
# Preview the first rows of the AnnoMI utterance table as a quick load check.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-08-19 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-08-19 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-08-19 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-08-19 00:00:34


In [6]:
# SECURITY: `pickle.load` can execute arbitrary code while deserialising; only
# load this file if it comes from a trusted source.
# NOTE(review): presumably one embedding row per AnnoMI utterance, since the
# array is later indexed with `therapist_index` — confirm row alignment with
# `anno_mi`.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)
    
# Display the shape as a sanity check on the loaded embeddings.
sbert_embeddings.shape

(9699, 384)

In [7]:
# Number of target classes; passed to the search below as `output_dim`.
output_dim_therapist

4

In [8]:
# Mapping from therapist behaviour label to integer class id.
label_to_id_therapist

{'question': 0, 'therapist_input': 1, 'reflection': 2, 'other': 3}

In [9]:
# Inverse mapping: integer class id back to therapist behaviour label.
id_to_label_therapist

{0: 'question', 1: 'therapist_input', 2: 'reflection', 3: 'other'}

# Baseline: feed-forward network (FFN)

We feed the SBERT sentence embeddings directly into a feed-forward network (FFN) to predict the therapist talk type.

In [10]:
# --- Search grid: hidden sizes x dropout rates x learning rates, each run per seed ---
# Two hidden layers of equal width for every candidate architecture.
hidden_dim_sizes = [[width, width] for width in (64, 128, 256)]
dropout_rates = [0.5, 0.1]
learning_rates = [0.001, 0.0001, 0.0005]
seeds = [1, 12, 123]

# --- Training setup shared by every run ---
# `patience` is presumably the early-stopping patience in epochs — TODO confirm
# in `ffn_hyperparameter_search`.
num_epochs, patience = 100, 5
# `gamma` is passed together with loss="focal"; presumably the focal-loss
# focusing parameter — TODO confirm.
loss, gamma = "focal", 2
# Metric name handed to the search as `validation_metric`.
validation_metric = "f1"

We use the `ffn_hyperparameter_search` function which loops through the different hidden dimensions, dropout rates and learning rates to find the best model for the validation set. We evaluate the model on several seeds and average the performance over the seeds.

In [11]:
# Hyperparameter search without K-fold evaluation (`k_fold=False`), using only
# the therapist rows of the embeddings.
# Returns four values: the per-run results table, the rows for the best
# configuration, and two values that are discarded here.
# NOTE(review): `_` and `__` are also IPython's output-history variables;
# assigning to them shadows that — harmless, but other throwaway names would
# avoid the overlap.
# NOTE(review): `split_ids` presumably keeps utterances from the same
# transcript in the same split — confirm in `ffn_hyperparameter_search`.
ffn_current, best_ffn_current, _, __ = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[therapist_index],
    y_data=y_data_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_therapist,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=False,
    patience=patience,
    split_ids=therapist_transcript_id,
    validation_metric=validation_metric,
    # The function reports saving this CSV plus a `_best_model.csv` sibling.
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_best_model.csv


In [12]:
# Per-run results: one row per (hyperparameter configuration, seed), with the
# metric columns, their `valid_*` counterparts and the settings used.
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,0.6884,0.775556,0.738975,"[0.815742397137746, 0.5784313725490198, 0.6997...",0.749626,"[0.7916666666666666, 0.6555555555555556, 0.662...",0.734275,"[0.8413284132841329, 0.5175438596491229, 0.741...",0.751872,0.724855,...,"(64, 64)",0.5,0.001,1,focal,2,False,,64,0
0,0.698119,0.764444,0.725872,"[0.8066298342541437, 0.5520833333333333, 0.687...",0.747548,"[0.8051470588235294, 0.6794871794871795, 0.611...",0.720312,"[0.8081180811808119, 0.4649122807017544, 0.784...",0.763452,0.728324,...,"(64, 64)",0.5,0.001,12,focal,2,False,,64,0
0,0.695803,0.764444,0.729739,"[0.7977736549165121, 0.5688073394495413, 0.690...",0.733279,"[0.8022388059701493, 0.5961538461538461, 0.645...",0.728861,"[0.7933579335793358, 0.543859649122807, 0.7416...",0.764869,0.717919,...,"(64, 64)",0.5,0.001,123,focal,2,False,,64,0
0,0.737284,0.772222,0.73721,"[0.8051948051948052, 0.5771144278606964, 0.694...",0.752662,"[0.8097014925373134, 0.6666666666666666, 0.620...",0.733079,"[0.8007380073800738, 0.5087719298245614, 0.789...",0.774946,0.72948,...,"(64, 64)",0.5,0.0001,1,focal,2,False,,64,1
0,0.705236,0.764444,0.724082,"[0.8115942028985507, 0.5454545454545454, 0.678...",0.738973,"[0.797153024911032, 0.6428571428571429, 0.6215...",0.719182,"[0.8265682656826568, 0.47368421052631576, 0.74...",0.752087,0.731792,...,"(64, 64)",0.5,0.0001,12,focal,2,False,,64,1
0,0.725045,0.766667,0.726373,"[0.8094373865698729, 0.5492227979274612, 0.685...",0.745613,"[0.7964285714285714, 0.6708860759493671, 0.623...",0.720472,"[0.8228782287822878, 0.4649122807017544, 0.760...",0.772833,0.726012,...,"(64, 64)",0.5,0.0001,123,focal,2,False,,64,1
0,0.732368,0.772222,0.735728,"[0.8075471698113207, 0.5685279187817258, 0.695...",0.754755,"[0.8262548262548263, 0.6746987951807228, 0.613...",0.731148,"[0.7896678966789668, 0.49122807017543857, 0.80...",0.7757,0.726012,...,"(64, 64)",0.5,0.0005,1,focal,2,False,,64,2
0,0.68539,0.767778,0.731942,"[0.8051001821493625, 0.5757575757575758, 0.682...",0.748368,"[0.7949640287769785, 0.6785714285714286, 0.622...",0.726203,"[0.8154981549815498, 0.5, 0.7559808612440191, ...",0.765139,0.726012,...,"(64, 64)",0.5,0.0005,12,focal,2,False,,64,2
0,0.656865,0.771111,0.733758,"[0.8138686131386861, 0.5670103092783505, 0.692...",0.753762,"[0.8050541516245487, 0.6875, 0.622137404580152...",0.728009,"[0.8228782287822878, 0.4824561403508772, 0.779...",0.753781,0.727168,...,"(64, 64)",0.5,0.0005,123,focal,2,False,,64,2
0,0.734442,0.771111,0.734854,"[0.8044280442804429, 0.5714285714285715, 0.694...",0.753782,"[0.8044280442804428, 0.6829268292682927, 0.620...",0.729616,"[0.8044280442804428, 0.49122807017543857, 0.78...",0.777416,0.727168,...,"(64, 64)",0.1,0.001,1,focal,2,False,,64,3


In [13]:
# Average the numeric metrics over seeds for each hyperparameter configuration.
# `numeric_only=True` is explicit because the frame also holds non-numeric
# columns (e.g. `f1_scores`, `loss_function`): relying on the default emits a
# FutureWarning on pandas 1.5 and raises a TypeError on pandas >= 2.0.
(
    ffn_current
    .groupby(["hidden_dim", "dropout_rate", "learning_rate"])
    .mean(numeric_only=True)
)

  ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,accuracy,f1,precision,recall,valid_loss,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
"(64, 64)",0.1,0.0001,0.686953,0.767778,0.731774,0.746623,0.727195,0.755143,0.725241,0.694131,0.695423,0.694578,45.333333,2.0,0.0,64.0,4.0
"(64, 64)",0.1,0.0005,0.706603,0.765556,0.727976,0.742016,0.724126,0.762475,0.724856,0.69208,0.693398,0.69284,45.333333,2.0,0.0,64.0,5.0
"(64, 64)",0.1,0.001,0.701142,0.764815,0.730015,0.74349,0.727845,0.765919,0.717919,0.688437,0.691545,0.691091,45.333333,2.0,0.0,64.0,3.0
"(64, 64)",0.5,0.0001,0.722522,0.767778,0.729222,0.745749,0.724245,0.766622,0.729094,0.697413,0.700755,0.69572,45.333333,2.0,0.0,64.0,1.0
"(64, 64)",0.5,0.0005,0.691541,0.77037,0.73381,0.752295,0.728453,0.764874,0.726397,0.694321,0.696544,0.69389,45.333333,2.0,0.0,64.0,2.0
"(64, 64)",0.5,0.001,0.694107,0.768148,0.731529,0.743484,0.727816,0.760064,0.723699,0.689815,0.691251,0.689757,45.333333,2.0,0.0,64.0,0.0
"(128, 128)",0.1,0.0001,0.687858,0.768519,0.736023,0.750983,0.731518,0.749309,0.727553,0.69788,0.699886,0.697434,45.333333,2.0,0.0,64.0,10.0
"(128, 128)",0.1,0.0005,0.719532,0.765926,0.732388,0.744541,0.729708,0.74613,0.723699,0.692557,0.696448,0.691734,45.333333,2.0,0.0,64.0,11.0
"(128, 128)",0.1,0.001,0.69167,0.767037,0.730921,0.743851,0.727956,0.748847,0.724856,0.693135,0.696535,0.692102,45.333333,2.0,0.0,64.0,9.0
"(128, 128)",0.5,0.0001,0.734703,0.766296,0.729548,0.749738,0.723915,0.766852,0.728709,0.699479,0.701919,0.698736,45.333333,2.0,0.0,64.0,7.0


In [14]:
# Rows for the configuration selected by the search: one row per seed.
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,0.757341,0.767778,0.731562,"[0.8051470588235293, 0.5700000000000001, 0.686...",0.746361,"[0.8021978021978022, 0.6627906976744186, 0.622...",0.72675,"[0.8081180811808119, 0.5, 0.7655502392344498, ...",0.758058,0.731792,...,"[0.7559322033898305, 0.5294117647058824, 0.695...","(128, 128)",0.5,0.0001,1,focal,2,False,,64
0,0.709618,0.764444,0.728221,"[0.7992700729927008, 0.5671641791044776, 0.682...",0.742136,"[0.7906137184115524, 0.6551724137931034, 0.622...",0.723541,"[0.8081180811808119, 0.5, 0.7559808612440191, ...",0.751756,0.728324,...,"[0.752542372881356, 0.5392156862745098, 0.6826...","(128, 128)",0.5,0.0001,12,focal,2,False,,64
0,0.737151,0.766667,0.72886,"[0.7992495309568479, 0.5621621621621622, 0.684...",0.760718,"[0.8129770992366412, 0.7323943661971831, 0.595...",0.721454,"[0.7859778597785978, 0.45614035087719296, 0.80...",0.790741,0.726012,...,"[0.7186440677966102, 0.5098039215686274, 0.730...","(128, 128)",0.5,0.0001,123,focal,2,False,,64


In [15]:
# Mean of the `f1` column over the per-seed rows of the best configuration.
best_ffn_current["f1"].mean()

0.7295480092797342

In [16]:
# Mean of the `precision` column over the per-seed rows of the best configuration.
best_ffn_current["precision"].mean()

0.7497380871340166

In [17]:
# Mean of the `recall` column over the per-seed rows of the best configuration.
best_ffn_current["recall"].mean()

0.7239150998790221

In [18]:
# Stack the per-seed `f1_scores` vectors row-wise and average over seeds
# (axis 0), giving one F1 value per class.
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current["f1_scores"]).mean(axis=0)

array([0.80122222, 0.56644211, 0.68450613, 0.86602157])

In [19]:
# Per-class precision averaged over seeds (axis 0 of the stacked per-seed vectors).
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current["precision_scores"]).mean(axis=0)

array([0.80192954, 0.68345249, 0.61345334, 0.90011698])

In [20]:
# Per-class recall averaged over seeds (axis 0 of the stacked per-seed vectors).
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current["recall_scores"]).mean(axis=0)

array([0.80073801, 0.48538012, 0.77511962, 0.83442266])

## KFold

We can repeat the search using K-fold evaluation instead (`k_fold=True`); by default there are $K=5$ folds.

In [21]:
# Same hyperparameter search as the non-K-fold run, but with K-fold evaluation
# enabled (`k_fold=True`); results are written to a separate `_kfold` CSV.
# Returns four values: the per-run results table, the rows for the best
# configuration, and two values that are discarded here.
# NOTE(review): `_` and `__` are also IPython's output-history variables;
# assigning to them shadows that — harmless, but other throwaway names would
# avoid the overlap.
# NOTE(review): `split_ids` presumably keeps utterances from the same
# transcript in the same fold — confirm in `ffn_hyperparameter_search`.
ffn_current_kfold, best_ffn_current_kfold, _, __ = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[therapist_index],
    y_data=y_data_therapist,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_therapist,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=True,
    patience=patience,
    split_ids=therapist_transcript_id,
    validation_metric=validation_metric,
    # The function reports saving this CSV plus a `_best_model.csv` sibling.
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False
)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in therapist_talk_type_output/ffn_current_focal_2_kfold_best_model.csv


In [22]:
# Per-run K-fold results: one row per (hyperparameter configuration, seed).
# The `loss` / `valid_loss` columns are NaN in the displayed rows.
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size,model_id
0,,0.734725,0.70163,"[0.7579222476314931, 0.5231884057971014, 0.662...",0.701194,"[0.7416879795396419, 0.5084507042253521, 0.660...",0.702985,"[0.7748830995323981, 0.5388059701492537, 0.664...",,0.755063,...,"(64, 64)",0.5,0.001,1,focal,2,True,5,64,0
0,,0.736812,0.702193,"[0.767816091954023, 0.5121056493030082, 0.6645...",0.702707,"[0.7551679586563308, 0.5036075036075036, 0.650...",0.702746,"[0.7808951235804943, 0.5208955223880597, 0.679...",,0.755854,...,"(64, 64)",0.5,0.001,12,focal,2,True,5,64,0
0,,0.733207,0.70008,"[0.7615635179153094, 0.5177304964539007, 0.658...",0.699226,"[0.7431659249841068, 0.49324324324324326, 0.66...",0.702514,"[0.7808951235804943, 0.5447761194029851, 0.655...",,0.754589,...,"(64, 64)",0.5,0.001,123,focal,2,True,5,64,0
0,,0.736433,0.69978,"[0.7629678266579121, 0.5085271317829457, 0.660...",0.7031,"[0.750161394448031, 0.5290322580645161, 0.6379...",0.697778,"[0.7762191048764195, 0.48955223880597015, 0.68...",,0.756487,...,"(64, 64)",0.5,0.0001,1,focal,2,True,5,64,1
0,,0.736433,0.697915,"[0.7658743080429827, 0.49842271293375395, 0.66...",0.702077,"[0.747141041931385, 0.5284280936454849, 0.6451...",0.695508,"[0.7855711422845691, 0.4716417910447761, 0.687...",,0.759652,...,"(64, 64)",0.5,0.0001,12,focal,2,True,5,64,1
0,,0.735863,0.69342,"[0.7710371819960861, 0.4789915966386555, 0.660...",0.703673,"[0.7533460803059273, 0.5480769230769231, 0.622...",0.68885,"[0.7895791583166333, 0.4253731343283582, 0.702...",,0.75538,...,"(64, 64)",0.5,0.0001,123,focal,2,True,5,64,1
0,,0.735294,0.70178,"[0.7599738391105297, 0.5217391304347826, 0.661...",0.701258,"[0.7443946188340808, 0.5070422535211268, 0.660...",0.703141,"[0.7762191048764195, 0.5373134328358209, 0.661...",,0.755538,...,"(64, 64)",0.5,0.0005,1,focal,2,True,5,64,2
0,,0.734915,0.700256,"[0.7641478573765129, 0.5119760479041916, 0.663...",0.701895,"[0.7487179487179487, 0.5135135135135135, 0.645...",0.69987,"[0.7802271209084837, 0.5104477611940299, 0.683...",,0.756171,...,"(64, 64)",0.5,0.0005,12,focal,2,True,5,64,2
0,,0.738899,0.704191,"[0.7674190382728164, 0.5173176123802505, 0.669...",0.704327,"[0.7519230769230769, 0.5109170305676856, 0.663...",0.70473,"[0.7835671342685371, 0.5238805970149254, 0.675...",,0.758544,...,"(64, 64)",0.5,0.0005,123,focal,2,True,5,64,2
0,,0.732638,0.69967,"[0.7585067723818963, 0.5185724690458849, 0.662...",0.699748,"[0.7503267973856209, 0.5064011379800853, 0.648...",0.700582,"[0.7668670674682698, 0.5313432835820896, 0.676...",,0.753165,...,"(64, 64)",0.1,0.001,1,focal,2,True,5,64,3


In [23]:
# Average the numeric K-fold metrics over seeds for each hyperparameter
# configuration. `numeric_only=True` is explicit because the frame also holds
# non-numeric columns (e.g. `f1_scores`, `loss_function`): relying on the
# default emits a FutureWarning on pandas 1.5 and raises a TypeError on
# pandas >= 2.0.
(
    ffn_current_kfold
    .groupby(["hidden_dim", "dropout_rate", "learning_rate"])
    .mean(numeric_only=True)
)

  ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,precision,recall,valid_accuracy,valid_f1,valid_precision,valid_recall,seed,gamma,k_fold,n_splits,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
"(64, 64)",0.1,0.0001,0.733903,0.69914,0.699985,0.699055,0.755274,0.725999,0.724266,0.729179,45.333333,2.0,1.0,5.0,64.0,4.0
"(64, 64)",0.1,0.0005,0.733333,0.699665,0.6996,0.700784,0.754905,0.725671,0.72374,0.729527,45.333333,2.0,1.0,5.0,64.0,5.0
"(64, 64)",0.1,0.001,0.734092,0.70128,0.7013,0.702279,0.752004,0.723014,0.721072,0.72692,45.333333,2.0,1.0,5.0,64.0,3.0
"(64, 64)",0.5,0.0001,0.736243,0.697038,0.70295,0.694045,0.757173,0.724455,0.726269,0.723763,45.333333,2.0,1.0,5.0,64.0,1.0
"(64, 64)",0.5,0.0005,0.736369,0.702076,0.702494,0.70258,0.756751,0.726481,0.725006,0.729241,45.333333,2.0,1.0,5.0,64.0,2.0
"(64, 64)",0.5,0.001,0.734915,0.701301,0.701042,0.702748,0.755169,0.725324,0.723544,0.728929,45.333333,2.0,1.0,5.0,64.0,0.0
"(128, 128)",0.1,0.0001,0.734029,0.700861,0.70039,0.702181,0.75443,0.725589,0.723368,0.729714,45.333333,2.0,1.0,5.0,64.0,10.0
"(128, 128)",0.1,0.0005,0.732891,0.700889,0.700504,0.702972,0.753587,0.725764,0.723841,0.730474,45.333333,2.0,1.0,5.0,64.0,11.0
"(128, 128)",0.1,0.001,0.729665,0.696912,0.696258,0.699192,0.752373,0.722855,0.720934,0.727293,45.333333,2.0,1.0,5.0,64.0,9.0
"(128, 128)",0.5,0.0001,0.737887,0.704211,0.704931,0.704156,0.757437,0.727819,0.726095,0.730599,45.333333,2.0,1.0,5.0,64.0,7.0


In [24]:
# Rows for the configuration selected by the K-fold search: one row per seed.
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,precision,precision_scores,recall,recall_scores,valid_loss,valid_accuracy,...,valid_recall_scores,hidden_dim,dropout_rate,learning_rate,seed,loss_function,gamma,k_fold,n_splits,batch_size
0,,0.732068,0.699367,"[0.7584589614740368, 0.5145700071073205, 0.661...",0.69897,"[0.760752688172043, 0.49118046132971505, 0.644...",0.701272,"[0.7561790247160989, 0.5402985074626866, 0.679...",,0.755538,...,"[0.7607709750566893, 0.6163366336633663, 0.706...","(256, 256)",0.5,0.0005,1,focal,2,True,5,64
0,,0.735484,0.701811,"[0.7597840755735493, 0.5161290322580646, 0.670...",0.704227,"[0.7675528289025222, 0.5188536953242836, 0.635...",0.701075,"[0.7521710086840347, 0.5134328358208955, 0.709...",,0.759177,...,"[0.7590702947845805, 0.5841584158415841, 0.732...","(256, 256)",0.5,0.0005,12,focal,2,True,5,64
0,,0.731689,0.699606,"[0.7595959595959595, 0.5131034482758621, 0.662...",0.698649,"[0.7657841140529531, 0.47692307692307695, 0.65...",0.702967,"[0.7535070140280561, 0.5552238805970149, 0.674...",,0.754905,...,"[0.7647392290249433, 0.6126237623762376, 0.704...","(256, 256)",0.5,0.0005,123,focal,2,True,5,64


In [25]:
# Mean of the `f1` column over the per-seed rows of the best K-fold configuration.
best_ffn_current_kfold["f1"].mean()

0.7002612172691872

In [26]:
# Mean of the `precision` column over the per-seed rows of the best K-fold configuration.
best_ffn_current_kfold["precision"].mean()

0.7006151039699103

In [27]:
# Mean of the `recall` column over the per-seed rows of the best K-fold configuration.
best_ffn_current_kfold["recall"].mean()

0.7017711936951435

In [28]:
# Stack the per-seed `f1_scores` vectors row-wise and average over seeds
# (axis 0), giving one F1 value per class for the K-fold run.
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current_kfold["f1_scores"]).mean(axis=0)

array([0.75927967, 0.51460083, 0.6646804 , 0.86248397])

In [29]:
# Per-class precision for the K-fold run, averaged over seeds (axis 0).
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current_kfold["precision_scores"]).mean(axis=0)

array([0.76469654, 0.49565241, 0.64347717, 0.89863429])

In [30]:
# Per-class recall for the K-fold run, averaged over seeds (axis 0).
# NOTE(review): class order presumably follows `label_to_id_therapist` — confirm.
np.stack(best_ffn_current_kfold["recall_scores"]).mean(axis=0)

array([0.75395235, 0.53631841, 0.68767908, 0.82913493])