# Training a SeqSigNet model

In this notebook, we will go through some helper functions in the library to train and evaluate a SeqSigNet model.

Note that using the other SeqSigNet models, such as the SeqSigNet-Attention-Encoder and SeqSigNet-Attention-BiLSTM models, is very similar: they take the same input (apart from the model initialisation, of course). In particular, you would use the same `obtain_SeqSigNet_input` function, as we will see in this notebook.

In [None]:
import pickle

from load_anno_mi import (
    anno_mi,
    client_index,
    client_transcript_id,
    output_dim_client,
    y_data_client,
    label_to_id_client,
    id_to_label_client,
)

from sig_networks.scripts.seqsignet_functions import (
    obtain_SeqSigNet_input,
    implement_seqsignet,
)

## AnnoMI

In [None]:
anno_mi.head()

Below are some statistics on how many therapist and client utterances there are (after removing duplicates; see the [load_anno_mi.py](load_anno_mi.py) script) and the proportion of each class in the "main_therapist_behaviour" and "client_talk_type" columns.

In [None]:
anno_mi["interlocutor"].value_counts()

In [None]:
anno_mi["main_therapist_behaviour"].value_counts() / anno_mi[
    "interlocutor"
].value_counts()["therapist"]

In [None]:
anno_mi["client_talk_type"].value_counts() / anno_mi["interlocutor"].value_counts()[
    "client"
]
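The cells above divide the class counts by the total number of therapist (or client) utterances. A more direct pandas idiom restricts to the relevant rows and uses `value_counts(normalize=True)`. Below is a minimal sketch on a toy dataframe. The labels `question` and `reflection` are made up for illustration. The two approaches agree only if every therapist row has a label, because `value_counts` drops missing values by default:

```python
import pandas as pd

# Toy data: therapist rows carry a behaviour label, client rows do not
toy = pd.DataFrame({
    "interlocutor": ["therapist", "client", "therapist", "client", "therapist"],
    "main_therapist_behaviour": ["question", None, "reflection", None, "question"],
})

# The notebook's approach: class counts divided by the number of therapist utterances
ratio = toy["main_therapist_behaviour"].value_counts() / toy["interlocutor"].value_counts()["therapist"]

# Alternative: keep only therapist rows, then normalise the counts directly
normalised = toy.loc[toy["interlocutor"] == "therapist", "main_therapist_behaviour"].value_counts(normalize=True)

print(ratio.sort_index())
print(normalised.sort_index())
```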

## Client talk type

In [None]:
sum(client_index)

In [None]:
label_to_id_client

In [None]:
id_to_label_client

In [None]:
output_dim_client

## Obtaining SBERT Embeddings

We can use the `SentenceEncoder` class within `nlpsig` to obtain sentence embeddings from a model. This class uses the [`sentence-transformers`](https://www.sbert.net/docs/package_reference/SentenceTransformer.html) package. Here, we use the pre-trained `all-MiniLM-L12-v2` model (alternative models can be found [here](https://www.sbert.net/docs/pretrained_models.html)).

All of this is done in the `sbert_embeddings.py` script:
- We use the `nlpsig.SentenceEncoder` class, passing in our dataframe, `df=anno_mi`, along with the name of the column that contains our text, `feature_name="utterance_text"`. We also pass in the pre-trained sentence-transformer model we want to use, `model_name="all-MiniLM-L12-v2"` (see the other pre-trained models available in the `sentence-transformers` library [here](https://www.sbert.net/docs/pretrained_models.html)).
- We then obtain embeddings for each item in our dataframe by:
    1. Loading the pre-trained model using `nlpsig.SentenceEncoder.load_pretrained_model()`
    2. Obtaining SBERT embeddings using `nlpsig.SentenceEncoder.obtain_embeddings()`
- Lastly, we pickle the embeddings to `"anno_mi_sbert.pkl"`

We leave the next cell commented out because we have already run it. You will need to run it first if `anno_mi_sbert.pkl` does not exist yet.

In [None]:
# %run ../sbert_embeddings.py

Now we can load the embeddings. Each item in our dataframe has a corresponding SBERT representation.

In [None]:
with open("anno_mi_sbert.pkl", "rb") as f:
    sbert_embeddings = pickle.load(f)

sbert_embeddings.shape

In [None]:
anno_mi.shape
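The shape comparison above suggests that the embeddings correspond to the dataframe rows by position. If you generate embeddings yourself, it can be worth checking this before building paths. Below is a minimal sketch that assumes the embeddings are a 2D NumPy array with one row per dataframe row. `check_embeddings_aligned` is a hypothetical helper and is not part of `nlpsig`:

```python
import numpy as np
import pandas as pd

def check_embeddings_aligned(df: pd.DataFrame, embeddings: np.ndarray) -> None:
    """Hypothetical check: raise if embeddings cannot map one-to-one onto df rows."""
    if embeddings.ndim != 2:
        raise ValueError(f"expected a 2D array, got shape {embeddings.shape}")
    if embeddings.shape[0] != len(df):
        raise ValueError(f"{embeddings.shape[0]} embeddings for {len(df)} rows")

# Toy example: three utterances, each with a 384-dimensional embedding
toy_df = pd.DataFrame({"utterance_text": ["hi", "how are you?", "fine"]})
check_embeddings_aligned(toy_df, np.zeros((3, 384)))
```

In this notebook, the call would be `check_embeddings_aligned(anno_mi, sbert_embeddings)`.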

## Constructing paths and input for SeqSigNet

For the SeqSigNet models, we use a Signature Window Unit to process _units_ of the history. For SeqSigNet, this is the Signature Window Network Unit (SWNU). For the SeqSigNet-Attention-Encoder and SeqSigNet-Attention-BiLSTM models, it is the Signature Window Multi-Head Attention Unit (SWMHAU). We slide a window across the history to cut it into smaller snapshots (the units) and apply the Signature Window Unit to each one. The output for each unit is a fixed-length feature representation of that window of history. We will see this in more detail later on.

To do this, we construct a path of each data point's history and then slide a window over it. This requires three parameters:
1. `shift`: how many data points to shift the window by as we slide through the history
2. `window_size`: the size of the sliding window
3. `n`: the number of units

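These choices determine the length of history, $k$, needed for the sliding window to move exactly through the history. The $n$ windows start $\text{shift}$ data points apart, so the last window starts at position $\text{shift} \times (n - 1)$ and covers $\text{window\_size}$ further points. This gives:
$$
k = \text{shift} \times n + (\text{window\_size} - \text{shift})
$$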

In this example, the history is the path of SBERT representations of the conversation up to the current utterance (data point). We then slide the window over this path to get the units for each snapshot of history. Choosing a shift smaller than the window size makes consecutive units overlap, which we typically want.
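With the values used later in this notebook (`shift=3`, `window_size=5`, `n=6`), the formula gives $k = 3 \times 6 + (5 - 3) = 20$. The sketch below cuts a toy path of that length into units with plain NumPy to show the windowing. The random history and `input_channels = 17` are illustrative assumptions, and this is not how `nlpsig` or `sig_networks` implement it internally:

```python
import numpy as np

shift, window_size, n = 3, 5, 6
k = shift * n + (window_size - shift)  # required history length: 20

input_channels = 17  # illustrative: 15 reduced dimensions + 2 features
history = np.random.default_rng(0).normal(size=(k, input_channels))  # toy path

# Unit i covers rows [i * shift, i * shift + window_size) of the history,
# so consecutive units share window_size - shift = 2 rows
units = np.stack([history[i * shift : i * shift + window_size] for i in range(n)])

print(k, units.shape)  # 20 (6, 5, 17)
```

The last unit ends exactly at row $k$, which is why the history length has to satisfy the formula above.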

For all path construction functionality, we can utilise the `nlpsig` library and in particular the `nlpsig.PrepareData` class. To make this even easier, there are helper functions within the `sig_networks` library: specifically the `obtain_SeqSigNet_input` function from `sig_networks.scripts.seqsignet_functions`.

For this function, we need to pass in:
- the `method` for dimension reduction we wish to use (for full options, see `nlpsig.DimReduce`)
- the `dimension` to reduce down to
- the dataframe, `df`, that contains our data
- the column name which defines the grouping of our texts, `id_column`
    - For this example, this is the transcript/conversation ID, `id_column="transcript_id"`
- the label column, `label_column`
    - For this example, this is the client talk type, `label_column="client_talk_type"`
- the `embeddings` for each item in the dataframe, which we obtained earlier
- the `shift`, `window_size`, and `n` to use (see above)

Optionally, we may want to add other features, such as time variables or other non-textual information. We can pass these into the `features` argument. We can then include them in the path, in the inputs that we concatenate at the end of the network, or in both (via `include_features_in_path` and `include_features_in_input`). Here, we will just add some time features to the path.

Note that an `nlpsig.PrepareData` object creates some time variables when it is constructed, such as:
- the `timeline_index` which is simply an ordered index for each grouping given by the `id_column`
- (if `datetime` is a column in the dataframe) the `time_encoding` which is the date of the data point as a fraction of the year, e.g. 31/01/2014 at 00:00:00 is 31/365 into the year so will be converted to 2014.0849
- (if `datetime` is a column in the dataframe) the `time_encoding_minute` which is the time of the data point (ignoring the date) as a fraction of a minute, e.g. 00:01:30 would be converted to 1.50
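The two worked examples above can be reproduced with the standard library. This is a minimal sketch under stated assumptions, not `nlpsig`'s implementation. The helper names only mirror the column names. `time_encoding` assumes the day of the year divided by 365. `time_encoding_minute` assumes minutes elapsed since midnight. The library's handling of leap years and hours may differ:

```python
from datetime import datetime

def time_encoding(t: datetime) -> float:
    # Hypothetical helper: year plus (day of year / 365); leap years not handled specially
    return t.year + t.timetuple().tm_yday / 365

def time_encoding_minute(t: datetime) -> float:
    # Hypothetical helper: time of day (date ignored) in minutes, assumed since midnight
    return t.hour * 60 + t.minute + t.second / 60

print(round(time_encoding(datetime(2014, 1, 31)), 4))        # 2014.0849
print(time_encoding_minute(datetime(2014, 1, 31, 0, 1, 30)))  # 1.5
```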

See `nlpsig.PrepareData` for all options implemented. We could also add any other feature present in the dataframe, as long as it is numerical (categorical features could first be converted to numerical ones).
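One way to turn a categorical column into a numerical one is to map each category to an integer code with `pandas.factorize`. Below is a minimal sketch on a toy dataframe. The `interlocutor_code` column name is hypothetical. In principle, such a column could then be listed in the `features` argument. For categories without a natural order, a one-hot encoding (e.g. `pandas.get_dummies`) may be a better choice than integer codes:

```python
import pandas as pd

df = pd.DataFrame({"interlocutor": ["therapist", "client", "client", "therapist"]})

# Integer codes are assigned in order of first appearance: therapist -> 0, client -> 1
codes, categories = pd.factorize(df["interlocutor"])
df["interlocutor_code"] = codes  # hypothetical numerical feature column

print(df)
```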

Here we will just use `timeline_index` and `time_encoding_minute`. These refer to the index of the current data point within its conversation and to the time of the utterance as a fraction of a minute. We apply no standardisation to these, but standardisation methods are available via the `standardise_method` argument.

Lastly, in this example we only want to predict the client talk type, but we still want the therapist utterances in the history. We therefore slice the input by passing the indices of interest into the `path_indices` argument:
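The idea behind `path_indices` can be pictured with a boolean mask. Paths are built for every utterance, so each history still contains therapist turns, and only the client rows are kept for prediction. This is a minimal sketch that assumes `client_index` behaves like a boolean mask over the dataframe rows (its use in `sum(client_index)` above suggests this). The toy paths are random placeholders:

```python
import numpy as np

# Toy conversation: who spoke each utterance, in order
interlocutor = np.array(["therapist", "client", "therapist", "client", "client"])

# Placeholder paths for every utterance: [utterances, n_units, window_size, input_channels]
paths = np.random.default_rng(0).normal(size=(len(interlocutor), 6, 5, 17))

# Keep only the data points we want to predict on (client utterances);
# their histories still include the therapist utterances
client_mask = interlocutor == "client"
client_paths = paths[client_mask]

print(client_paths.shape)  # (3, 6, 5, 17)
```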

In [None]:
seqsignet_input = obtain_SeqSigNet_input(
    method="umap",
    dimension=15,
    df=anno_mi,
    id_column="transcript_id",
    label_column="client_talk_type",
    embeddings=sbert_embeddings,
    shift=3,
    window_size=5,
    n=6,
    features=["time_encoding_minute", "timeline_index"],
    include_features_in_path=True,
    include_features_in_input=False,
    seed=42,
    path_indices=client_index,
)

To see how `nlpsig` is used to construct the input, see the source code for the `obtain_SeqSigNet_input` function. Within this function, we have made certain choices about how to construct the path and the input to SeqSigNet. You can customise this function to construct the input differently and adapt the approach to what your task requires.

In [None]:
type(seqsignet_input)

In [None]:
seqsignet_input.keys()

In [None]:
[type(seqsignet_input[key]) for key in seqsignet_input.keys()]

The result is a dictionary with keys:
- `x_data`: this is itself a dictionary with keys:
    - `path`: a four-dimensional tensor containing the batch of units, each of which is a stream of embeddings that gets processed by the SWNU. Its dimensions are `[batch, n_units, window_size, input_channels]`
    - `features`: a two-dimensional tensor containing the batch of embeddings that we concatenate to the output of the global BiLSTM in the network
- `input_channels`: an integer equal to the dimension of the reduced embeddings plus the number of features added to the path (15 + 2 = 17 in our case)
- `embedding_dim`: the dimension of the embeddings we concatenate with (384 in our case)
- `num_features`: the number of features we concatenate in the input (0, since we set `include_features_in_input=False`)
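The values reported by the next cells can be predicted from the arguments we passed in. This is a minimal sketch. It assumes that `input_channels` is the reduced dimension plus the number of path features, and that `embedding_dim` is the output size of `all-MiniLM-L12-v2`, as described above. `expected_shapes` is a hypothetical helper, not part of `sig_networks`:

```python
# Values passed to obtain_SeqSigNet_input above
dimension = 15
features = ["time_encoding_minute", "timeline_index"]
include_features_in_path = True
include_features_in_input = False
window_size, n = 5, 6
sbert_dim = 384  # output size of all-MiniLM-L12-v2

def expected_shapes(batch: int) -> dict:
    """Hypothetical helper: the sizes we expect obtain_SeqSigNet_input to report."""
    input_channels = dimension + (len(features) if include_features_in_path else 0)
    num_features = len(features) if include_features_in_input else 0
    return {
        "input_channels": input_channels,
        "embedding_dim": sbert_dim,
        "num_features": num_features,
        # [batch, n_units, window_size, input_channels]
        "path": (batch, n, window_size, input_channels),
    }

print(expected_shapes(batch=10))
```

In the notebook, you could compare `expected_shapes(seqsignet_input["x_data"]["path"].shape[0])["path"]` with `tuple(seqsignet_input["x_data"]["path"].shape)`.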

In [None]:
seqsignet_input["x_data"].keys()

In [None]:
seqsignet_input["x_data"]["path"].shape

In [None]:
seqsignet_input["x_data"]["features"].shape

In [None]:
(
    seqsignet_input["input_channels"],
    seqsignet_input["embedding_dim"],
    seqsignet_input["num_features"],
)

In the SeqSigNet model, we first apply the SWNU to each unit of each data point in the batch. This gives a three-dimensional tensor with one output vector per unit. A global BiLSTM then processes this sequence of unit outputs. The final representation of the data point is the last hidden state of the BiLSTM, which summarises the history of that point.
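The flow described above can be sketched with plain PyTorch modules. This is an illustrative sketch, not the `sig_networks` implementation. It assumes that the per-unit outputs are already computed (here they are random), that the forward and backward final hidden states of the BiLSTM are concatenated, and that the current utterance embedding is concatenated before a small feed-forward classifier. All sizes are hypothetical apart from the 384-dimensional embedding and the `[32, 32]` hidden layers used in `kwargs` below:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes for illustration only
batch, n_units, unit_dim = 4, 6, 10  # one output vector per unit
lstm_hidden_dim, embedding_dim, output_dim = 100, 384, 3

unit_outputs = torch.randn(batch, n_units, unit_dim)  # stand-in for the unit outputs
current_embedding = torch.randn(batch, embedding_dim)

# Global BiLSTM over the sequence of unit outputs
bilstm = nn.LSTM(unit_dim, lstm_hidden_dim, batch_first=True, bidirectional=True)
_, (h_n, _) = bilstm(unit_outputs)
# h_n has shape [2, batch, lstm_hidden_dim]: join forward and backward final states
history_summary = torch.cat([h_n[0], h_n[1]], dim=-1)  # [batch, 2 * lstm_hidden_dim]

# Concatenate with the current embedding and classify with a small FFN
ffn = nn.Sequential(
    nn.Linear(2 * lstm_hidden_dim + embedding_dim, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, output_dim),
)
logits = ffn(torch.cat([history_summary, current_embedding], dim=-1))
print(logits.shape)  # torch.Size([4, 3])
```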

Once we have the input for the SeqSigNet, we can use the `implement_seqsignet` helper function provided in the library in `sig_networks.scripts.seqsignet_functions`. Here we set up some of the arguments to pass into the function:

In [None]:
kwargs = {
    "num_epochs": 5,
    "x_data": seqsignet_input["x_data"],
    "y_data": y_data_client,
    "input_channels": seqsignet_input["input_channels"],
    "num_features": seqsignet_input["num_features"],
    "embedding_dim": seqsignet_input["embedding_dim"],
    "log_signature": True,
    "sig_depth": 3,
    "pooling": "signature",
    "swnu_hidden_dim": 5,
    "lstm_hidden_dim": 100,
    "ffn_hidden_dim": [32, 32],
    "output_dim": output_dim_client,
    "BiLSTM": True,
    "dropout_rate": 0.1,
    "learning_rate": 3e-4,
    "seed": 0,
    "loss": "focal",
    "gamma": 2,
    "split_ids": client_transcript_id,
    "k_fold": True,
    "patience": 3,
    "verbose_training": True,
    "verbose_results": True,
    "verbose_model": True,
}

**Note**: we only train for a small number of epochs and use fewer parameters than we usually would, so that this notebook runs a bit quicker.

Here, we set most of the arguments of the SeqSigNet model, as well as some training arguments such as the number of epochs, the learning rate and the loss function. We are going to do a $K$-fold analysis over a default of $5$ splits.
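The arguments above select a focal loss with `gamma=2`. Below is a minimal NumPy sketch of the standard multi-class focal loss, $-(1 - p_t)^{\gamma} \log p_t$, averaged over the batch. Here $p_t$ is the predicted probability of the true class. It is an illustration of the technique, not the library's code. The library's version may, for example, also support class weights:

```python
import numpy as np

def focal_loss(logits: np.ndarray, targets: np.ndarray, gamma: float = 2.0) -> float:
    # Numerically stable log-softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability of the true class for each example
    log_p_t = log_probs[np.arange(len(targets)), targets]
    p_t = np.exp(log_p_t)
    # Down-weight well-classified examples by (1 - p_t) ** gamma
    return float(np.mean(-((1.0 - p_t) ** gamma) * log_p_t))

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
targets = np.array([0, 2])
print(focal_loss(logits, targets, gamma=2.0))
```

With `gamma=0`, the expression reduces to ordinary cross-entropy. Larger values of `gamma` shift the training signal towards examples the model gets wrong.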

Since we're doing a $K$-fold analysis, we get back a randomly initialised SeqSigNet model and a results dataframe. If we set `k_fold=False`, we would instead get a model trained on the training split using a validation set.

The data points are grouped by transcript ID, so when the dataset is split we pass in the transcript IDs (`split_ids`). This ensures there is no data contamination between folds and data splits.
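Grouped splitting can be illustrated with scikit-learn's `GroupKFold`, which never places the same group in both the training and test sides of a fold. This is only a sketch of the idea on toy transcript IDs. `grouped_folds` is a hypothetical helper, and the library's own splitting code may work differently:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

def grouped_folds(groups: np.ndarray, n_splits: int = 5):
    """Hypothetical helper: (train_idx, test_idx) pairs with no group on both sides."""
    dummy_X = np.zeros((len(groups), 1))  # only the number of rows matters here
    return list(GroupKFold(n_splits=n_splits).split(dummy_X, groups=groups))

# Toy data: 30 utterances from 10 transcripts (3 utterances each)
transcript_ids = np.repeat(np.arange(10), 3)
for train_idx, test_idx in grouped_folds(transcript_ids):
    print("test transcripts:", sorted(set(transcript_ids[test_idx].tolist())))
```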

In [None]:
swnu_network, results_df = implement_seqsignet(**kwargs)

In [None]:
results_df

The library also has a function for hyperparameter searching. For SeqSigNet, see the `seqsignet_hyperparameter_search` function in `sig_networks.scripts.seqsignet_functions`. Each of the other models has functions analogous to the ones we've seen in this notebook.

For examples using the hyperparameter search functions, see the scripts for running the experiments in the `examples/` folder of the repo.