In [6]:
import numpy as np
import pickle
import os

# Random seed for this notebook. It is not consumed by the cells shown here;
# presumably later (training) cells pass it on — TODO confirm.
seed: int = 2023

In [7]:
from nlpsig_networks.scripts.lstm_baseline_functions import (
    obtain_path
)

In [8]:
# Directory that receives this notebook's results.
output_dir = "client_talk_type_output"
# exist_ok=True makes the cell idempotent on re-run and removes the race
# between a separate isdir() check and makedirs(); it still raises
# FileExistsError if a regular (non-directory) file already occupies the path.
os.makedirs(output_dir, exist_ok=True)

In [9]:
# Run the loader script; IPython's %run copies the variables the script defines
# into this notebook's namespace. The next cell uses `anno_mi`, which is
# presumably defined by this script (TODO confirm).
%run ../load_anno_mi.py

In [10]:
# Preview the first five rows of the loaded `anno_mi` dataframe.
anno_mi.head(n=5)

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-03 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-03 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-03 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-03 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-03 00:00:34


In [11]:
# Load the precomputed SBERT sentence embeddings.
# NOTE(review): pickle.load can execute arbitrary code — only load this file if
# it was produced by a trusted local pipeline.
with open("../anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

# obtain_path below receives both `anno_mi` and these embeddings; presumably
# embedding row i belongs to dataframe row i, so the counts must match.
# Fail loudly here instead of building misaligned paths later.
assert sbert_embeddings.shape[0] == len(anno_mi), (
    f"expected {len(anno_mi)} embedding rows, got {sbert_embeddings.shape[0]}"
)

sbert_embeddings.shape

(13551, 384)

In [12]:
# For every row of `anno_mi`, build a path from the utterance embeddings of
# its transcript (grouped by `transcript_id`). The cells below check that each
# path holds the current utterance and its predecessors, zero-padded.
# k=20 presumably caps the path length, and path_indices=None presumably means
# "all rows" — TODO confirm against obtain_path.
x_data = obtain_path(
    df=anno_mi,
    embeddings=sbert_embeddings,
    id_column="transcript_id",
    label_column="client_talk_type",
    path_indices=None,
    k=20,
)

[INFO] Concatenating the embeddings to the dataframe...
[INFO] - columns beginning with 'e' denote the full embddings.
[INFO] Adding time feature columns into dataframe in `.df`.
[INFO] Adding 'time_encoding' and feature...
[INFO] Adding 'time_diff' and feature...
[INFO] Adding 'timeline_index' feature...
[INFO] Padding ids and storing in `.df_padded` and `.array_padded` attributes.


  0%|          | 0/13551 [00:00<?, ?it/s]

The output below shows the transition at index 54: that is where the second transcript (`transcript_id == 1`) starts.

In [13]:
# First rows of transcript 0 (boolean mask via .loc).
anno_mi.loc[anno_mi["transcript_id"] == 0].head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
0,high,0,reducing alcohol consumption,0,therapist,00:00:13,Thanks for filling it out. We give this form t...,3,False,,False,,True,open,question,,2023-07-03 00:00:13
1,high,0,reducing alcohol consumption,1,client,00:00:24,Sure.,3,,,,,,,,neutral,2023-07-03 00:00:24
2,high,0,reducing alcohol consumption,2,therapist,00:00:25,"So, let's see. It looks that you put-- You dri...",3,True,information,False,,False,,therapist_input,,2023-07-03 00:00:25
3,high,0,reducing alcohol consumption,3,client,00:00:34,Mm-hmm.,3,,,,,,,,neutral,2023-07-03 00:00:34
4,high,0,reducing alcohol consumption,4,therapist,00:00:34,-and you usually have three to four drinks whe...,3,True,information,False,,False,,therapist_input,,2023-07-03 00:00:34


In [14]:
# First rows of transcript 1 (boolean mask via .loc).
anno_mi.loc[anno_mi["transcript_id"] == 1].head()

Unnamed: 0,mi_quality,transcript_id,topic,utterance_id,interlocutor,timestamp,utterance_text,annotator_id,therapist_input_exists,therapist_input_subtype,reflection_exists,reflection_subtype,question_exists,question_subtype,main_therapist_behaviour,client_talk_type,datetime
54,high,1,reducing alcohol consumption,0,therapist,00:00:11,"And before you leave, uh, is it okay if we go ...",4,True,negotiation,False,,True,closed,question,,2023-07-03 00:00:11
55,high,1,reducing alcohol consumption,1,client,00:00:17,Sure.,4,,,,,,,,neutral,2023-07-03 00:00:17
56,high,1,reducing alcohol consumption,2,therapist,00:00:18,"Okay. Uh, thanks for filling this out. Uh, loo...",4,False,,True,simple,False,,reflection,,2023-07-03 00:00:18
57,high,1,reducing alcohol consumption,3,client,00:00:25,"Yeah, but only on the weekend.",4,,,,,,,,sustain,2023-07-03 00:00:25
58,high,1,reducing alcohol consumption,4,therapist,00:00:27,"Okay. And then when you do drink alcohol, you ...",4,False,,True,simple,False,,reflection,,2023-07-03 00:00:27


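The first few entries of `x_data` look correct: for these early utterances, each path holds the embeddings of the current utterance and of the utterances before it in the same transcript, followed by zero padding: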

In [15]:
# first utterance of transcript_id=0: only its own embedding, no history (rest is zero padding)
x_data[0]

array([[0.00154884, 0.01095446, 0.04541774, ..., 0.02605489, 0.05162514,
        0.0810306 ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [16]:
# third utterance of transcript_id=0: its two preceding utterances as history plus the current one
x_data[2]

array([[ 0.00154884,  0.01095446,  0.04541774, ...,  0.02605489,
         0.05162514,  0.0810306 ],
       [-0.04369018, -0.01466528,  0.05083121, ...,  0.03477677,
         0.05443462,  0.02002253],
       [ 0.02747428, -0.01277931,  0.02649007, ...,  0.03227151,
        -0.010621  ,  0.04409946],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])

At index 54 (the first utterance of `transcript_id == 1`), we expect the history to restart, with nothing carried over from transcript 0:

In [17]:
# first utterance of transcript_id=1: history restarts, so only its own embedding
x_data[54]

array([[0.00665784, 0.00442537, 0.05556735, ..., 0.03004349, 0.07726749,
        0.06769447],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [18]:
# third utterance of transcript_id=1: should hold its two preceding utterances as history plus the current one
x_data[56]

array([[ 0.00665784,  0.00442537,  0.05556735, ...,  0.03004349,
         0.07726749,  0.06769447],
       [-0.04369018, -0.01466528,  0.05083121, ...,  0.03477677,
         0.05443462,  0.02002253],
       [-0.01000693, -0.00922856,  0.01035947, ..., -0.06271432,
         0.01522191,  0.08085583],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]])