In [1]:
import numpy as np
import pickle
import os

# NOTE(review): `seed` is not referenced by any cell visible in this notebook
# (the hyperparameter search takes the `seeds` list instead) — confirm whether
# it is still needed.
seed = 2023

In [2]:
from nlpsig_networks.scripts.ffn_baseline_functions import (
    ffn_hyperparameter_search,
)

In [3]:
# Directory that receives the CSV results written by the hyperparameter searches.
output_dir = "client_talk_type_output"
# `exist_ok=True` creates the directory only when it is missing, without the
# check-then-create race of a separate `os.path.isdir` test.
os.makedirs(output_dir, exist_ok=True)

## AnnoMI

In [4]:
# Execute the AnnoMI loading script in this kernel. `anno_mi` (previewed in the
# next cell) is not defined anywhere else in this notebook, so it comes from here;
# presumably `client_index`, `y_data_client` and `output_dim_client` do too —
# TODO confirm against load_anno_mi.py.
%run ../load_anno_mi.py

In [5]:
# Preview the first rows of the AnnoMI frame.
anno_mi.head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-05 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-05 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-05 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-05 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-05 00:00:34


In [6]:
# Load the precomputed embeddings from a local pickle file.
# NOTE(review): `pickle.load` can execute arbitrary code while deserialising —
# only run this against a file from a trusted source.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)
    
# Show the array shape as a quick sanity check on what was loaded.
sbert_embeddings.shape

(13551, 384)

# Baseline: feedforward network (FFN)

We feed the sentence embeddings directly into a feedforward network (FFN) to predict the client talk type.

In [7]:
# --- Training budget -------------------------------------------------------
num_epochs = 100

# --- Search grid -----------------------------------------------------------
# Two hidden layers of equal width for each candidate size.
hidden_dim_sizes = [[width, width] for width in (32, 64, 128, 256)]
dropout_rates = [0.5, 0.2, 0.1]
learning_rates = [0.001, 0.0001, 0.0005]

# --- Repetitions -----------------------------------------------------------
# Every configuration is run once per seed.
seeds = [0, 1, 12, 123, 1234]

# --- Objective and model selection -----------------------------------------
loss = "focal"
gamma = 2  # passed to the search alongside the focal loss
validation_metric = "f1"

In [8]:
# Echo the hidden-layer size grid that will be searched.
hidden_dim_sizes

[[32, 32], [64, 64], [128, 128], [256, 256]]

In [9]:
# Echo the learning-rate grid that will be searched.
learning_rates

[0.001, 0.0001, 0.0005]

We use the `ffn_hyperparameter_search` function which loops through the different hidden dimensions, dropout rates and learning rates to find the best model for the validation set. We evaluate the model on several seeds and average the performance over the seeds.

In [10]:
# Hyperparameter search over hidden sizes, dropout rates and learning rates,
# repeated for every seed, without K-fold evaluation (k_fold=False). The search
# writes its results to CSV files under `output_dir`.
# `client_index`, `y_data_client` and `output_dim_client` are not defined in the
# visible cells of this notebook — presumably they come from load_anno_mi.py
# (TODO confirm).
# The last two return values are unused. They are bound to explicit throwaway
# names rather than `_` / `__`, because assigning those names in IPython stops
# the automatic last-output variables from being updated.
(
    ffn_current,
    best_ffn_current,
    _ffn_current_extra_1,
    _ffn_current_extra_2,
) = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_client,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=False,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}.csv",
    verbose=False,
)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_best_model.csv


In [11]:
# Full search results: one row per (hyperparameter configuration, seed).
ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,batch_size,model_id
0,focal,0.701859,0.629756,"[0.7978339350180507, 0.5500794912559618, 0.541...",0.529493,0.711896,0.618425,"[0.8179148311306902, 0.5536062378167641, 0.483...","(32, 32)",0.5,0.0010,0,2,False,64,0
0,focal,0.689219,0.612203,"[0.7883472057074911, 0.5224625623960066, 0.525...",0.591756,0.705390,0.615503,"[0.8057971014492752, 0.5424430641821946, 0.498...","(32, 32)",0.5,0.0010,1,2,False,64,0
0,focal,0.689219,0.612550,"[0.7896323086196504, 0.5419354838709678, 0.506...",0.531607,0.697955,0.605220,"[0.8017556693489393, 0.5411764705882353, 0.472...","(32, 32)",0.5,0.0010,12,2,False,64,0
0,focal,0.701859,0.625198,"[0.7988200589970502, 0.5295109612141653, 0.547...",0.530302,0.712825,0.620700,"[0.8146551724137931, 0.532520325203252, 0.5149...","(32, 32)",0.5,0.0010,123,2,False,64,0
0,focal,0.686245,0.617221,"[0.7821297429620564, 0.5457227138643067, 0.523...",0.526035,0.705390,0.618059,"[0.8029850746268656, 0.5750000000000001, 0.476...","(32, 32)",0.5,0.0010,1234,2,False,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.686989,0.616372,"[0.7854984894259818, 0.5229357798165138, 0.540...",0.551644,0.712825,0.628322,"[0.8055353241077932, 0.5676190476190476, 0.511...","(256, 256)",0.1,0.0005,0,2,False,64,35
0,focal,0.707807,0.621480,"[0.8038658328595794, 0.5130890052356021, 0.547...",0.561550,0.729554,0.628370,"[0.8250000000000001, 0.5521739130434783, 0.507...","(256, 256)",0.1,0.0005,1,2,False,64,35
0,focal,0.695167,0.624620,"[0.7874481941977501, 0.5302325581395348, 0.556...",0.551401,0.716543,0.627053,"[0.8115107913669066, 0.5596868884540117, 0.509...","(256, 256)",0.1,0.0005,12,2,False,64,35
0,focal,0.689219,0.619736,"[0.7851941747572816, 0.5390749601275916, 0.534...",0.526658,0.703532,0.616673,"[0.8029629629629629, 0.5708661417322834, 0.476...","(256, 256)",0.1,0.0005,123,2,False,64,35


In [12]:
# Average the numeric metrics over seeds for each hyperparameter configuration.
# `numeric_only=True` is required because the frame also holds non-numeric
# columns (the `loss` name and the per-class score lists); without it pandas
# emits a deprecation warning on older versions and raises TypeError on >= 2.0.
ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(
    numeric_only=True
)

  ffn_current.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,valid_loss,valid_accuracy,valid_f1,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
"(32, 32)",0.1,0.0001,0.683123,0.605064,0.532173,0.693123,0.598255,274.0,2.0,0.0,64.0,7.0
"(32, 32)",0.1,0.0005,0.691896,0.620601,0.556023,0.704089,0.615632,274.0,2.0,0.0,64.0,8.0
"(32, 32)",0.1,0.001,0.686394,0.614432,0.561662,0.702974,0.617199,274.0,2.0,0.0,64.0,6.0
"(32, 32)",0.2,0.0001,0.683123,0.603386,0.537029,0.69461,0.598255,274.0,2.0,0.0,64.0,4.0
"(32, 32)",0.2,0.0005,0.693532,0.619365,0.530991,0.70316,0.611552,274.0,2.0,0.0,64.0,5.0
"(32, 32)",0.2,0.001,0.691004,0.619794,0.567272,0.704275,0.616431,274.0,2.0,0.0,64.0,3.0
"(32, 32)",0.5,0.0001,0.685502,0.601271,0.551921,0.696283,0.594701,274.0,2.0,0.0,64.0,1.0
"(32, 32)",0.5,0.0005,0.696208,0.621173,0.53965,0.709108,0.615575,274.0,2.0,0.0,64.0,2.0
"(32, 32)",0.5,0.001,0.69368,0.619385,0.541839,0.706691,0.615581,274.0,2.0,0.0,64.0,0.0
"(64, 64)",0.1,0.0001,0.687286,0.608652,0.525517,0.699628,0.604779,274.0,2.0,0.0,64.0,16.0


In [13]:
# Results of the selected (best) configuration: one row per seed.
best_ffn_current

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,batch_size
0,focal,0.697398,0.618733,"[0.7929203539823009, 0.5412844036697249, 0.521...",0.566841,0.716543,0.630701,"[0.80835734870317, 0.5627376425855513, 0.52100...","(256, 256)",0.5,0.0005,0,2,False,64
0,focal,0.702602,0.622152,"[0.7979033197437391, 0.5348460291734198, 0.533...",0.593544,0.727695,0.633698,"[0.8212765957446809, 0.5696969696969697, 0.510...","(256, 256)",0.5,0.0005,1,2,False,64
0,focal,0.689219,0.62375,"[0.7839805825242719, 0.5279034690799397, 0.559...",0.532678,0.710037,0.628391,"[0.8026509572901325, 0.5708955223880597, 0.511...","(256, 256)",0.5,0.0005,12,2,False,64
0,focal,0.697398,0.613688,"[0.7932636469221833, 0.5371621621621621, 0.510...",0.624325,0.72119,0.624889,"[0.8176638176638177, 0.5732484076433122, 0.483...","(256, 256)",0.5,0.0005,123,2,False,64
0,focal,0.69145,0.618324,"[0.790332326283988, 0.5364341085271318, 0.5282...",0.52628,0.708178,0.629914,"[0.7999999999999999, 0.5736137667304015, 0.516...","(256, 256)",0.5,0.0005,1234,2,False,64


In [14]:
# Mean of the `f1` column over the seeds of the selected configuration.
best_ffn_current.loc[:, "f1"].mean()

0.6193295055219763

In [15]:
# Per-class F1 averaged over seeds: stack the per-seed `f1_scores` entries into
# one 2-D array, then take the mean along the seed axis.
np.mean(np.stack(best_ffn_current["f1_scores"]), axis=0)

array([0.79168005, 0.53552603, 0.53078244])

## KFold

We can repeat the search using K-fold evaluation instead; by default, $K=5$ folds are used.

In [16]:
# Same hyperparameter search as the non-K-fold run, but with K-fold evaluation
# enabled (k_fold=True); results are written to separate CSV files under
# `output_dir`.
# `client_index`, `y_data_client` and `output_dim_client` are not defined in the
# visible cells of this notebook — presumably they come from load_anno_mi.py
# (TODO confirm).
# The last two return values are unused. They are bound to explicit throwaway
# names rather than `_` / `__`, because assigning those names in IPython stops
# the automatic last-output variables from being updated.
(
    ffn_current_kfold,
    best_ffn_current_kfold,
    _ffn_kfold_extra_1,
    _ffn_kfold_extra_2,
) = ffn_hyperparameter_search(
    num_epochs=num_epochs,
    x_data=sbert_embeddings[client_index],
    y_data=y_data_client,
    hidden_dim_sizes=hidden_dim_sizes,
    output_dim=output_dim_client,
    dropout_rates=dropout_rates,
    learning_rates=learning_rates,
    seeds=seeds,
    loss=loss,
    gamma=gamma,
    k_fold=True,
    validation_metric=validation_metric,
    results_output=f"{output_dir}/ffn_current_focal_{gamma}_kfold.csv",
    verbose=False,
)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

saving results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold.csv
saving the best model results dataframe to CSV for this hyperparameter search in client_talk_type_output/ffn_current_focal_2_kfold_best_model.csv


In [17]:
# Full K-fold search results: one row per (hyperparameter configuration, seed).
ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,batch_size,model_id
0,focal,0.677770,0.585480,"[0.7866822429906541, 0.49722991689750695, 0.47...",,0.677770,0.585480,"[0.7866822429906541, 0.49722991689750695, 0.47...","(32, 32)",0.5,0.0010,0,2,True,64,0
0,focal,0.673011,0.591745,"[0.7761802060867482, 0.5219638242894057, 0.477...",,0.673011,0.591745,"[0.7761802060867482, 0.5219638242894057, 0.477...","(32, 32)",0.5,0.0010,1,2,True,64,0
0,focal,0.672119,0.583928,"[0.7812388961269691, 0.5020491803278688, 0.468...",,0.672119,0.583928,"[0.7812388961269691, 0.5020491803278688, 0.468...","(32, 32)",0.5,0.0010,12,2,True,64,0
0,focal,0.678216,0.593687,"[0.7825776508389861, 0.5238718116415959, 0.474...",,0.678216,0.593687,"[0.7825776508389861, 0.5238718116415959, 0.474...","(32, 32)",0.5,0.0010,123,2,True,64,0
0,focal,0.674201,0.587137,"[0.7798099762470309, 0.5202987983111401, 0.461...",,0.674201,0.587137,"[0.7798099762470309, 0.5202987983111401, 0.461...","(32, 32)",0.5,0.0010,1234,2,True,64,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,focal,0.683717,0.593502,"[0.7881266799111838, 0.518, 0.4743792921288959]",,0.683717,0.593502,"[0.7881266799111838, 0.518, 0.4743792921288959]","(256, 256)",0.1,0.0005,0,2,True,64,35
0,focal,0.681636,0.597157,"[0.7838836571701038, 0.5352904434728295, 0.472...",,0.681636,0.597157,"[0.7838836571701038, 0.5352904434728295, 0.472...","(256, 256)",0.1,0.0005,1,2,True,64,35
0,focal,0.680595,0.591522,"[0.7858486130700516, 0.5178273908697101, 0.470...",,0.680595,0.591522,"[0.7858486130700516, 0.5178273908697101, 0.470...","(256, 256)",0.1,0.0005,12,2,True,64,35
0,focal,0.676729,0.591386,"[0.7811792733770101, 0.5245379222434672, 0.468...",,0.676729,0.591386,"[0.7811792733770101, 0.5245379222434672, 0.468...","(256, 256)",0.1,0.0005,123,2,True,64,35


In [18]:
# Average the numeric metrics over seeds for each hyperparameter configuration.
# `numeric_only=True` is required because the frame also holds non-numeric
# columns (the `loss` name and the per-class score lists); without it pandas
# emits a deprecation warning on older versions and raises TypeError on >= 2.0.
ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean(
    numeric_only=True
)

  ffn_current_kfold.groupby(["hidden_dim", "dropout_rate", "learning_rate"]).mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,accuracy,f1,valid_accuracy,valid_f1,seed,gamma,k_fold,batch_size,model_id
hidden_dim,dropout_rate,learning_rate,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
"(32, 32)",0.1,0.0001,0.675033,0.587862,0.675033,0.587862,274.0,2.0,1.0,64.0,7.0
"(32, 32)",0.1,0.0005,0.671435,0.587135,0.671435,0.587135,274.0,2.0,1.0,64.0,8.0
"(32, 32)",0.1,0.001,0.671048,0.588461,0.671048,0.588461,274.0,2.0,1.0,64.0,6.0
"(32, 32)",0.2,0.0001,0.675123,0.588488,0.675123,0.588488,274.0,2.0,1.0,64.0,4.0
"(32, 32)",0.2,0.0005,0.671554,0.587692,0.671554,0.587692,274.0,2.0,1.0,64.0,5.0
"(32, 32)",0.2,0.001,0.67545,0.590716,0.67545,0.590716,274.0,2.0,1.0,64.0,3.0
"(32, 32)",0.5,0.0001,0.677591,0.586015,0.677591,0.586015,274.0,2.0,1.0,64.0,1.0
"(32, 32)",0.5,0.0005,0.678275,0.590044,0.678275,0.590044,274.0,2.0,1.0,64.0,2.0
"(32, 32)",0.5,0.001,0.675063,0.588395,0.675063,0.588395,274.0,2.0,1.0,64.0,0.0
"(64, 64)",0.1,0.0001,0.676996,0.589598,0.676996,0.589598,274.0,2.0,1.0,64.0,16.0


In [19]:
# K-fold results of the selected (best) configuration: one row per seed.
best_ffn_current_kfold

Unnamed: 0,loss,accuracy,f1,f1_scores,valid_loss,valid_accuracy,valid_f1,valid_f1_scores,hidden_dim,dropout_rate,learning_rate,seed,gamma,k_fold,batch_size
0,focal,0.678364,0.594226,"[0.7819333253674274, 0.5322683706070288, 0.468...",,0.678364,0.594226,"[0.7819333253674274, 0.5322683706070288, 0.468...","(256, 256)",0.5,0.0005,0,2,True,64
0,focal,0.684461,0.600941,"[0.7856719250385483, 0.5348534201954398, 0.482...",,0.684461,0.600941,"[0.7856719250385483, 0.5348534201954398, 0.482...","(256, 256)",0.5,0.0005,1,2,True,64
0,focal,0.683866,0.598593,"[0.7867298578199052, 0.5310367240818981, 0.478...",,0.683866,0.598593,"[0.7867298578199052, 0.5310367240818981, 0.478...","(256, 256)",0.5,0.0005,12,2,True,64
0,focal,0.678364,0.595845,"[0.7795153396203891, 0.5346912794398472, 0.473...",,0.678364,0.595845,"[0.7795153396203891, 0.5346912794398472, 0.473...","(256, 256)",0.5,0.0005,123,2,True,64
0,focal,0.68119,0.593779,"[0.7843971631205673, 0.5283259070655634, 0.468...",,0.68119,0.593779,"[0.7843971631205673, 0.5283259070655634, 0.468...","(256, 256)",0.5,0.0005,1234,2,True,64


In [20]:
# Mean of the `f1` column over the seeds of the selected K-fold configuration.
best_ffn_current_kfold.loc[:, "f1"].mean()

0.596677178724857

In [21]:
# Per-class F1 averaged over seeds for the K-fold run: stack the per-seed
# `f1_scores` entries into one 2-D array, then take the mean along the seed axis.
np.mean(np.stack(best_ffn_current_kfold["f1_scores"]), axis=0)

array([0.78364952, 0.53223514, 0.47414687])