# Attentive State Space Modeling of Cystic Fibrosis Trajectories

In this notebook, we provide a step-by-step guide for conducting the experiments in Section 4 of our paper "*Attentive State Space Modeling of Disease Progression*" by Ahmed M. Alaa and Mihaela van der Schaar (submitted to **NeurIPS 2019**). The notebook explains the API for our **TensorFlow** implementation of the model in order to facilitate its application to new datasets. A **Pyro** implementation of the model using probabilistic programming will be released soon.


In the following experiments, we use the attentive state-space framework to model cystic fibrosis (CF) progression trajectories. CF is a life-shortening disease that causes lung dysfunction, and is the most common genetic disease in Caucasian populations [1]. Experimental details are given below.

Before starting the experiment, we need to import the necessary libraries:

In [None]:
from models.SeqModels import attentive_state_space_model

from data.CF_dataset_processing import data_loader

Here, **models.SeqModels** is a module containing the **attentive_state_space_model** class, whereas **data.CF_dataset_processing** contains helper functions for processing the CF data. We start by loading the data and explaining the format a dataset must have in order to be fed into our model.

## Loading the CF dataset

To load the CF data, we simply need to call the **data_loader()** function as follows: 

In [None]:
X_observations, feature_names = data_loader('.')

The variable **X_observations** is formatted to be compatible with the API of attentive state space models. The format accepted by our model is **a list of numpy arrays** of the following form:

*X_observations = [Seq_1, Seq_2, ...., Seq_N],*

where *Seq_n* is the $n^{th}$ "data point" in the dataset. Here, a "data point" is actually a multidimensional time-series of the form:

*Seq_n* = numpy array of shape *(sequence_length_n, number_of_features)*.

Note that each "time series" *Seq_n* is allowed to have a distinct length *sequence_length_n* and so the sizes of the numpy arrays in *X_observations* vary from one time-series to another. The *number_of_features* has to be the same for all sequences though.
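Because a format error only surfaces once training starts, it can help to validate the list before passing it to the model. The sketch below checks exactly the constraints stated above (a list of 2-D arrays, variable lengths, a shared feature dimension); `check_sequence_list` is a hypothetical helper, not part of the model's API.

```python
import numpy as np

def check_sequence_list(X_observations):
    """Validate the list-of-arrays format: [Seq_1, ..., Seq_N], each (sequence_length_n, number_of_features).

    Returns (number_of_sequences, number_of_features, sequence_lengths).
    """
    if not isinstance(X_observations, list) or len(X_observations) == 0:
        raise ValueError("X_observations must be a non-empty list of numpy arrays")

    num_features = None
    lengths = []
    for n, seq in enumerate(X_observations):
        seq = np.asarray(seq)
        # every sequence must be a 2-D array: rows are time steps, columns are features
        if seq.ndim != 2:
            raise ValueError(f"sequence {n} is not 2-D (shape {seq.shape})")
        # lengths may differ across sequences, but the feature dimension may not
        if num_features is None:
            num_features = seq.shape[1]
        elif seq.shape[1] != num_features:
            raise ValueError(f"sequence {n} has {seq.shape[1]} features, expected {num_features}")
        lengths.append(seq.shape[0])
    return len(X_observations), num_features, lengths

# toy data: three "patients" with 2, 4 and 3 visits, 5 features each
rng = np.random.default_rng(0)
toy = [rng.normal(size=(T, 5)) for T in (2, 4, 3)]
summary = check_sequence_list(toy)  # (3, 5, [2, 4, 3])
```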

Now let us look at the shape of the first sequence in *X_observations*:

In [None]:
X_observations[0].shape

This is a sequence of 7 observations, each of which has 90 dimensions. Each dimension is a distinct feature; let us look at the values of these features:

In [None]:
X_observations[0]

As we can see, the values can be either continuous or binary. No additional preprocessing is needed before supplying the data to our model; however, missing values must be imputed before feeding **X_observations** to the attentive state-space model object.
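The notebook does not prescribe an imputation method. As one possible choice, the sketch below assumes missing entries are encoded as `NaN`, carries the last observed value forward within each patient's sequence, and fills anything still missing (e.g., at the first visit) with the feature mean pooled over all sequences. `impute_sequences` is a hypothetical helper, and whether last-observation-carried-forward is appropriate for each clinical feature should be checked case by case.

```python
import numpy as np

def impute_sequences(X_observations):
    """Impute NaNs in a list of (T_n, D) arrays (one assumed strategy, not the notebook's).

    1. carry the last observed value forward within each sequence;
    2. fill what is still missing with the feature mean pooled over all sequences.
    """
    stacked = np.vstack(X_observations).astype(float)
    observed = ~np.isnan(stacked)
    counts = observed.sum(axis=0)
    sums = np.where(observed, stacked, 0.0).sum(axis=0)
    # pooled per-feature means; a feature that is never observed falls back to 0
    feature_means = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)

    imputed = []
    for seq in X_observations:
        seq = np.array(seq, dtype=float)  # copy, so the caller's arrays are untouched
        for t in range(1, seq.shape[0]):
            missing = np.isnan(seq[t])
            seq[t, missing] = seq[t - 1, missing]  # last observation carried forward
        rows, cols = np.nonzero(np.isnan(seq))
        seq[rows, cols] = feature_means[cols]  # e.g. values missing at the first visit
        imputed.append(seq)
    return imputed

# toy data: NaN marks a missing value
toy = [np.array([[np.nan, 1.0], [2.0, np.nan]]),
       np.array([[4.0, 3.0]])]
imputed = impute_sequences(toy)  # imputed[0] -> [[3., 1.], [2., 1.]]
```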

The second variable returned by the **data_loader()** function is a list with the names of the 90 features recorded at every time step of every sequence. Let's take a look at what these features actually correspond to:

In [None]:
import numpy as np

np.array(feature_names)

The feature names above are ordered according to their column position in each numpy array of *X_observations*. The following helper function recovers the location (column index) of a feature from its name:

In [None]:
def get_feature_loc(feature_names, required_feature_name):
    # column index of the first feature whose name matches exactly
    return np.where(np.array(feature_names) == required_feature_name)[0][0]

The CF dataset used in our experiment was extracted from a cohort of patients enrolled in the UK CF Registry, a database held by the UK Cystic Fibrosis Trust [3]. The dataset records annual follow-ups for 10,263 patients over the period from 2008 to 2015, with a total of 60,218 hospital visits/encounters. The 90 features above (associated with each patient) include information on 36 possible treatments, diagnoses for 31 possible comorbidities and 16 possible infections, in addition to biomarkers and demographic information. The FEV1 biomarker (a measure of lung function) is the main measure of illness severity in CF patients. For more information on the meaning and significance of every feature, please refer to our medical study in [2].

Now let us check the total number of patients involved in the dataset:

In [None]:
len(X_observations)

To sum up, every element in the list **X_observations** corresponds to a single patient's sequential data in the form of a numpy array; in each such array, every column corresponds to a different feature and every row to a different time step. To use our model on any new dataset, we only need to put the data in the same format as **X_observations**.
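Registry data often arrives in "long" format, with one row per visit. A minimal sketch of converting such a table into the list-of-arrays format, assuming you already have a patient-ID column, a sortable visit-time column and a numeric feature matrix; `to_sequence_list` is a hypothetical helper, and the toy IDs, years and values are made up.

```python
import numpy as np

def to_sequence_list(patient_ids, visit_times, features):
    """Group a long-format table (one row per visit) into a list of per-patient arrays.

    patient_ids : (num_rows,) patient identifiers
    visit_times : (num_rows,) values used to order visits within a patient
    features    : (num_rows, number_of_features) observations
    """
    patient_ids = np.asarray(patient_ids)
    visit_times = np.asarray(visit_times)
    features = np.asarray(features, dtype=float)

    # np.lexsort uses the LAST key as the primary key: sort by patient, then by time
    order = np.lexsort((visit_times, patient_ids))
    ids_sorted = patient_ids[order]
    feats_sorted = features[order]

    # start a new sequence wherever the patient identifier changes
    boundaries = np.flatnonzero(ids_sorted[1:] != ids_sorted[:-1]) + 1
    return np.split(feats_sorted, boundaries)

# toy table: two patients, rows deliberately out of order
ids   = np.array([2, 1, 2, 1, 1])
years = np.array([2009, 2010, 2008, 2008, 2009])
feats = np.array([[2.2], [1.3], [2.1], [1.1], [1.2]])
seqs = to_sequence_list(ids, years, feats)  # patient 1: 3 visits, patient 2: 2 visits
```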

## Instantiating an attentive state-space model 

Now that we have loaded the data, we are ready to instantiate an attentive state-space model object. We can do so as follows:

In [None]:
model = attentive_state_space_model(num_states=3,
                                    maximum_seq_length=7, 
                                    input_dim=90, 
                                    inference_network='Seq2SeqAttention', 
                                    rnn_type='LSTM',
                                    num_iterations=50, 
                                    num_epochs=10, 
                                    batch_size=100, 
                                    learning_rate=5*1e-4, 
                                    num_rnn_hidden=100, 
                                    num_out_hidden=100)

As we can see above, the model takes a number of arguments, which we explain below:

- **num_states**: Number of states in the state space.
- **maximum_seq_length**: Maximum allowable number of time steps for any sequence in the training data. Since our data spans the years 2008 to 2015, we set it to 7. 
- **input_dim**: Dimension of the observation at each time step, i.e., the number of features. Here we have 90 features. 
- **inference_network**: The type of inference network used. This must be set to 'Seq2SeqAttention' to run our own model, but other inference network architectures can also be used. 
- **rnn_type**: The type of RNN cells used in the inference networks. The default is 'LSTM'; it can also be set to 'GRU' or 'RNN', or to 'PhasedLSTM' for continuous-time sequences.
- **num_iterations, num_epochs, batch_size, learning_rate**: Parameters of the learning algorithm.
- **num_rnn_hidden, num_out_hidden**: Complexity control for the inference network.
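Since **maximum_seq_length** and **input_dim** are both properties of the data, they can be derived from **X_observations** rather than hard-coded, which avoids mismatches when moving to a new dataset. A small sketch; `infer_model_dims` is a hypothetical helper, and the commented call only restates the constructor arguments shown above.

```python
import numpy as np

def infer_model_dims(X_observations):
    """Return (maximum_seq_length, input_dim) implied by a list of (T_n, D) arrays."""
    maximum_seq_length = max(seq.shape[0] for seq in X_observations)
    input_dim = X_observations[0].shape[1]
    return maximum_seq_length, input_dim

toy = [np.zeros((2, 90)), np.zeros((7, 90)), np.zeros((5, 90))]
max_len, dim = infer_model_dims(toy)  # (7, 90)

# usage with the real data (other arguments as in the cell above):
# max_len, dim = infer_model_dims(X_observations)
# model = attentive_state_space_model(num_states=3, maximum_seq_length=max_len, input_dim=dim, ...)
```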


We provide a very simple scikit-learn-style API for our model with three main methods:

- **fit**: fits the model parameters to the observed sequences in an unsupervised fashion.
- **predict**: predicts the future trajectory (and attention weights) for a new sequence.
- **sample**: samples a synthetic trajectory from the learned model.

Because our model is **unsupervised**, it does not need to be provided with any labels. To train the model, we only need to pass the list of observations **X_observations** to the **fit** method as follows:

In [None]:
model.fit(X_observations)

The model is now trained on the whole dataset! The model's log-likelihood at each iteration of the inference algorithm is stored in the attribute **model._Losses**. We can visualize the model's loss trajectory as follows:

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

# rescale so the y-axis shows the log-likelihood in units of 10^6
plt.plot(np.array(model._Losses)/10**6)
plt.xlabel('Number of iterations')
plt.ylabel('Log-likelihood / 1e6')
plt.show()

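Using stochastic variational inference, the learning algorithm seeks the parameters

$\hat{\theta} = \arg \max_{\theta} \log\left(P\left(X_{observations} \,\big|\, \theta\right)\right)$

in practice by maximizing a variational lower bound on this log-likelihood.

As we can see, the more iterations we apply, the better the model likelihood gets: in our run, it rose from about $-4 \times 10^{6}$ in the initial iterations to about $-8 \times 10^{5}$ after training was completed. For purely discrete data, the log-likelihood is bounded above by 0, which is achieved when the data likelihood given the model is 1; since our data also contains continuous features modeled by densities, no such bound applies here, and it is the improvement over iterations that matters.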
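For reference, the quantity that stochastic variational inference maximizes in place of the intractable log-likelihood is the standard evidence lower bound (ELBO). The notation here is ours and not necessarily the paper's: $X$ stands for $X_{observations}$, $Z$ for the hidden state trajectories, and $q_{\phi}(Z \mid X)$ for the approximate posterior produced by an inference network with parameters $\phi$. This notebook does not state whether the implementation adds further terms to this objective.

$$
\log P\left(X \,\big|\, \theta\right) \;\geq\; \mathbb{E}_{q_{\phi}(Z \mid X)}\left[\log P\left(X, Z \,\big|\, \theta\right) - \log q_{\phi}(Z \mid X)\right] \;=\; \mathcal{L}(\theta, \phi)
$$

$$
\log P\left(X \,\big|\, \theta\right) - \mathcal{L}(\theta, \phi) \;=\; \mathrm{KL}\left(q_{\phi}(Z \mid X) \,\big\|\, P\left(Z \mid X, \theta\right)\right) \;\geq\; 0
$$

The gap is zero exactly when the approximate posterior matches the true posterior over state trajectories, so raising $\mathcal{L}$ with respect to both $\theta$ and $\phi$ pushes up the log-likelihood while tightening the bound.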

## Understanding CF progression mechanisms: What did the model learn? 

The main strength of a probabilistic model resides in its ability to extract knowledge from data through intuitively interpretable representations. In the following, we discuss two forms of knowledge about the CF progression mechanisms that can be directly derived by merely inspecting the model parameters: **population-level phenotyping**, which is an overall description of the disease progression mechanism from a population perspective, and **individual-level contextualized diagnosis**, which is a context-based explanation of the disease mechanism for a particular individual.

### Population-level phenotyping

By setting **num_states** to 3, our model learned a phenotypic representation of three CF progression stages (Stages 1, 2 and 3) in an unsupervised fashion, i.e., each stage is a realization of a hidden state in the model. The learned baseline transition matrix describing the transition rates among states 1 to 3 is:

In [None]:
model.transition_matrix

As we can see, patients in any state at a given time step are more likely to stay in that state than to transition to any other state at the next time step. Note, though, that transitions in our model are neither Markovian nor stationary, and hence they cannot be fully understood through the baseline transition matrix alone. 

The initial state probabilities are:

In [None]:
model.initial_probabilities

So patients are much more likely to start in state 2. But what do these states/stages really mean? To check whether the different states correspond to different levels of disease severity, let us inspect the emission distribution of each state. One would expect state 2 to be the least severe, since most patients start in it; let us sanity-check this by looking at the observations associated with every state.

To inspect the distribution associated with each stage, we need to look at the attributes **.state_means** and **.state_covars**, e.g., 

In [None]:
model.state_means[1]

As mentioned earlier, FEV1 % (forced expiratory volume in one second) is the standard measure of lung function. Severely ill patients have an FEV1 at or below 30 %, and these are the patients who may need to undergo a lung transplant. Healthy patients have an FEV1 close to 80 %, which corresponds to fairly normal lung function. Thus, we check the mean FEV1 in states 1, 2 and 3 to infer their corresponding levels of severity. (The FEV1 feature is number 15 in the **feature_names** list.) 

In [None]:
FEV1_loc = get_feature_loc(feature_names, 'FEV1')

model.state_means[0][FEV1_loc], model.state_means[1][FEV1_loc], model.state_means[2][FEV1_loc]

As we can see, the FEV1 values match our expectation. State 2, which is the most likely initial state, corresponds to healthy patients whose average FEV1 is above 79 %. State 1 corresponds to a moderate decline in lung function, with an average FEV1 of 68.23 %. Finally, State 3, which is the least likely initial state, corresponds to an average FEV1 of 53.36 %. Thus, the inferred states can be thought of as progression stages! This is important because it means that our model learned a clinically meaningful representation of the CF progression trajectories without any supervision. Based on this representation, clinical guidelines can be defined by recommending different treatments/actions to patients in different states. **(Note that the state indices are not ordered by severity, since the model has no notion of which stage is actually more severe; one therefore needs to relabel states 1-3 according to their respective levels of severity.)**

For notational convenience, let us define the following mapping from model states to disease stages based on their level of severity:

- **State 1 ⬄ Stage 2** 
- **State 2 ⬄ Stage 1**
- **State 3 ⬄ Stage 3** 

Since the state indices are arbitrary, this state-to-stage mapping has to be redefined manually every time the model is retrained in this notebook, so that the stages make clinical sense.

In [None]:
stage_1_index = 1 
stage_2_index = 0
stage_3_index = 2
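If we accept FEV1 as the sole ordering criterion, the manual step can be automated by sorting the states by their mean FEV1, highest first. This is a sketch under that assumption: `stage_indices_by_fev1` is a hypothetical helper, it ignores the other features that also characterize severity, and the toy means below are made up.

```python
import numpy as np

def stage_indices_by_fev1(state_means, fev1_loc):
    """Order model states from least to most severe by their mean FEV1.

    Returns [stage_1_index, stage_2_index, ...]: entry k is the model state
    index assigned to Stage k+1.
    """
    fev1_means = np.array([state_means[s][fev1_loc] for s in range(len(state_means))])
    # higher FEV1 means better lung function, so sort in descending order
    return [int(s) for s in np.argsort(-fev1_means)]

# toy emission means for three states, with FEV1 stored in column 0
toy_means = [np.array([68.0, 0.1]), np.array([80.0, 0.0]), np.array([53.0, 0.3])]
toy_order = stage_indices_by_fev1(toy_means, fev1_loc=0)  # [1, 0, 2]

# usage with the trained model:
# stage_1_index, stage_2_index, stage_3_index = stage_indices_by_fev1(model.state_means, FEV1_loc)
```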

We can visualize the FEV1 distribution associated with each stage as follows:

In [None]:
from scipy.stats import norm

linecolors   = ['blue', 'purple', 'red']

sigma_values = [np.sqrt(model.state_covars[stage_1_index][FEV1_loc][FEV1_loc]), 
                np.sqrt(model.state_covars[stage_2_index][FEV1_loc][FEV1_loc]), 
                np.sqrt(model.state_covars[stage_3_index][FEV1_loc][FEV1_loc])]

mu_values    = [model.state_means[stage_1_index][FEV1_loc], 
                model.state_means[stage_2_index][FEV1_loc], 
                model.state_means[stage_3_index][FEV1_loc]]

FEV1_range = np.linspace(0, 120, 1000)
stage_indx = [1, 2, 3]


fig, ax = plt.subplots(figsize=(5, 3.75))

for sigma, mu, clr, stage in zip(sigma_values, mu_values, linecolors, stage_indx):

    dist = norm(mu, sigma)

    # black outline for each density; the stage-colored fill carries the legend label
    plt.plot(FEV1_range, dist.pdf(FEV1_range), color='black', linewidth=3)
    plt.fill_between(FEV1_range, dist.pdf(FEV1_range), color=clr, alpha=0.5, label='Stage %d' % stage)

plt.xlabel('FEV1 % Predicted')
plt.ylabel('Distribution')

plt.legend(loc='upper left', fontsize=12, frameon=True, fancybox=True)
plt.show()

But stages are not just about the FEV1 indicator. All measures of a patient's health should also contribute to the definition of a disease state. One particular measure of the health of a CF patient is the prevalence of comorbidities and infections. Let us now examine the risks of different types of comorbidities and infections for patients in the different CF progression stages. We consider the following comorbidities:

- **Diabetes** (including CF-related diabetes, which is prevalent among CF patients)
- **Asthma**
- **ABPA**
- **Hypertension**
- **Depression** (very common among CF patients)

We can use the **get_feature_loc** function to recover the positions of these comorbidities:

In [None]:
comorbidities     = ['Diabetes', 'Asthma', 'ABPA', 'Hypertension', 'Depression']
comorbidities_loc = [get_feature_loc(feature_names, name) for name in comorbidities]

Moreover, we are also interested in the following lung infections:

- **Pseudomonas Aeruginosa**
- **Haemophilus Influenza**
- **Klebsiella Pneumoniae**
- **Ecoli**
- **Aspergillus**
- **Staphylococcus Aureus**

In [None]:
infections     = ['Pseudomonas Aeruginosa', 'Haemophilus Influenza', 'Klebsiella Pneumoniae', 'Ecoli', 'Aspergillus', 'Staphylococcus Aureus']
infections_loc = [get_feature_loc(feature_names, name) for name in infections]

How do the risks of comorbidities and infections above vary between stages? Let us investigate their corresponding emission distributions:

In [None]:
for comorb_name, comorb_loc  in zip(comorbidities, comorbidities_loc):
    
    print(comorb_name)
    print("Stage 1: %0.2f | Stage 2: %0.2f | Stage 3: %0.2f" % (model.state_means[stage_1_index][comorb_loc], 
                                                                model.state_means[stage_2_index][comorb_loc], 
                                                                model.state_means[stage_3_index][comorb_loc]))
    print('---------------------------------------------')

Again we see that Stage 1 is the least severe, not only in terms of the FEV1 % biomarker, but also in the sense that patients in this stage have a very low risk of developing comorbid conditions such as diabetes, asthma, ABPA, hypertension and depression. We also notice that depression is quite prevalent in Stage 3, the terminal stage of the disease. Moreover, there is a huge difference in the risk of diabetes and hypertension between patients in Stage 1 and those in Stages 2 and 3. In fact, Stages 2 and 3 have a 40 % risk of developing hypertension, whereas Stage 1 has an almost zero risk!

Now let us look at the emission distributions for infections:

In [None]:
for infect_name, infect_loc  in zip(infections, infections_loc):
    
    print(infect_name)
    print("Stage 1: %0.2f | Stage 2: %0.2f | Stage 3: %0.2f" % (model.state_means[stage_1_index][infect_loc], 
                                                                model.state_means[stage_2_index][infect_loc], 
                                                                model.state_means[stage_3_index][infect_loc]))
    print('---------------------------------------------')

As we can see, infection risks differ less across stages, since infections arise from external/environmental factors rather than occurring endogenously due to deteriorating health. Hence, we expect predicting infections over time to be a much harder task, and we expect infections to act as triggers for disease exacerbation rather than as manifestations of a more severe disease state.
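One simple way to put a number on "less discriminated" is the spread (max minus min) of each feature's emission mean across the three stages. The sketch below ranks features by that spread; `stage_spread` is a hypothetical helper, the toy values are made up, and the spread ignores the emission variances, so it is only a crude summary of the printouts above.

```python
import numpy as np

def stage_spread(state_means, state_indices, feature_locs, feature_labels):
    """Rank features by how much their emission mean varies across stages.

    state_means   : per-state mean vectors (like model.state_means)
    state_indices : model state index for each stage, e.g. [stage_1_index, stage_2_index, stage_3_index]
    Returns (label, max - min) pairs, most stage-dependent feature first.
    """
    spreads = []
    for label, loc in zip(feature_labels, feature_locs):
        values = [state_means[s][loc] for s in state_indices]
        spreads.append((label, float(max(values) - min(values))))
    return sorted(spreads, key=lambda pair: pair[1], reverse=True)

# toy emission means: three states, two binary features
toy_means = [np.array([0.40, 0.20]), np.array([0.00, 0.25]), np.array([0.45, 0.30])]
ranking = stage_spread(toy_means, [1, 0, 2], [0, 1], ['Hypertension', 'Infection X'])

# usage with the trained model:
# stage_spread(model.state_means, [stage_1_index, stage_2_index, stage_3_index],
#              comorbidities_loc + infections_loc, comorbidities + infections)
```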

Now let us collect the different emission distributions in a grouped bar chart:

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 2.5))

width = 0.15  # bar width; each group holds three bars (Stages 1, 2 and 3)

# plot comorbidities

stage_1_comorb_means = tuple([model.state_means[stage_1_index][comorb_loc] for comorb_loc in comorbidities_loc])
stage_2_comorb_means = tuple([model.state_means[stage_2_index][comorb_loc] for comorb_loc in comorbidities_loc])
stage_3_comorb_means = tuple([model.state_means[stage_3_index][comorb_loc] for comorb_loc in comorbidities_loc])

ind = np.arange(len(stage_1_comorb_means))  # one group of bars per comorbidity

ax1.bar(ind, stage_1_comorb_means, width, label='Stage 1', color='blue')
ax1.bar(ind + width, stage_2_comorb_means, width, label='Stage 2', color='purple')
ax1.bar(ind + 2 * width, stage_3_comorb_means, width, label='Stage 3', color='red')

ax1.set_ylabel('Comorbidity Risk')
ax1.set_xticks(ind + width)  # center each tick under the middle bar of its group
ax1.set_xticklabels(tuple(comorbidities), rotation=40)
ax1.legend(loc='upper left', fontsize=10, frameon=True, fancybox=True)

# plot infections

stage_1_infect_means = tuple([model.state_means[stage_1_index][infect_loc] for infect_loc in infections_loc])
stage_2_infect_means = tuple([model.state_means[stage_2_index][infect_loc] for infect_loc in infections_loc])
stage_3_infect_means = tuple([model.state_means[stage_3_index][infect_loc] for infect_loc in infections_loc])

ind = np.arange(len(stage_1_infect_means))  # one group of bars per infection

ax2.bar(ind, stage_1_infect_means, width, label='Stage 1', color='blue')
ax2.bar(ind + width, stage_2_infect_means, width, label='Stage 2', color='purple')
ax2.bar(ind + 2 * width, stage_3_infect_means, width, label='Stage 3', color='red')

infections_short_names = ['P. Aeruginosa', 'H. Influenza', 'K. Pneumoniae', 
                          'Ecoli', 'Aspergillus', 'S. Aureus']

ax2.set_ylabel('Infection Risk')
ax2.set_xticks(ind + width)
ax2.set_xticklabels(tuple(infections_short_names), rotation=40)
ax2.legend(loc='upper left', fontsize=10, frameon=True, fancybox=True)

plt.show()


The above statistics describe the "population-level" aspects of the model: the emission distributions attached to each state tell us what a patient in Stage 1 looks like compared to a patient in Stage 3. But what about the individual-level aspects of the model? In the next section, we investigate the model's outputs for individual patient trajectories, and show how the model can be used to make predictions and inferences and to distill knowledge about the individual-level behavior of the disease.

### Individualized contextual diagnosis

Population-level modeling of disease stages can already be obtained with simpler probabilistic models, but our model also captures more complex dynamics that are specific to individuals. To retrieve the model outputs for individual trajectories, we use the **predict** method as follows:

In [None]:
patient_number = 1

individual_predictions, individual_observations, individual_attentions = model.predict([X_observations[patient_number]])

The **predict** method returns three outputs. The first is a prediction of the next hidden state at every time step, the second is the predicted (expected) observation at the next step, and the third contains the attention weights assigned to all previous time steps, which reflect their relative contributions to the inference at the current time step. For a sequence of length **seq_len**, the outputs have the following shapes and contents:

- **individual_predictions**: a list (one entry per input sequence) of *seq_len* $\times$ *num_states* numpy arrays, giving the posterior predictive distribution of the next state at each of the *seq_len* time steps.
- **individual_observations**: a list of *seq_len* $\times$ *input_dim* numpy arrays, giving the predicted (expected) next observation at each of the *seq_len* time steps.
- **individual_attentions**: a list of lists. Each inner list contains numpy arrays of growing length, one per time step, holding the attention weights assigned to all previous time steps for the inference at that step.

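For this particular patient, let us visualize her state trajectory and FEV1 trajectory. We first compute the most probable stage at each time step (a per-step maximum a posteriori (MAP) estimate), translating model states into stages with the mapping defined earlier: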

In [None]:
Stage_state_map = [stage_1_index, stage_2_index, stage_3_index]

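def get_MAP_estimates(Stage_state_map, preds):
    # invert the stage -> state map: Stage_state_map[k] is the model state of Stage k+1
    state_to_stage = {state: stage + 1 for stage, state in enumerate(Stage_state_map)}
    # most probable state at each time step, translated into a 1-based stage label
    return [state_to_stage[state] for state in np.argmax(preds, axis=1)]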


MAP_states =  get_MAP_estimates(Stage_state_map, individual_predictions[0])

Also, we extract the predicted values of the FEV1 indicator as follows:

In [None]:
FEV1_predictions = individual_observations[0][:, FEV1_loc]

In [None]:
fig, axs = plt.subplots()

t = list(range(len(MAP_states)))

color = 'tab:red'
axs.set_xlabel('Time step')
axs.set_ylabel('Predicted FEV1 Biomarker', color=color)
axs.plot(t, FEV1_predictions, color=color, linewidth=5, marker='o', markersize=10)
axs.tick_params(axis='y', labelcolor=color)

axs2 = axs.twinx()  

color = 'tab:blue'
axs2.set_ylabel('Predicted Stage', color=color)  
axs2.step(t, MAP_states, color=color, linewidth=5)
axs2.tick_params(axis='y', labelcolor=color)
axs2.set_yticks([1, 2, 3])
axs2.set_ylim([0.5,3.5])

fig.tight_layout()  
plt.show()

As we can see, at time step 2 the model predicts a transition to Stage 3. This prediction preceded an unanticipated sudden drop in the FEV1 biomarker. Had we known one year earlier that the patient was on the verge of Stage 3, preventive interventions might have been possible!

So how do the attention weights look for this patient? What event made us able to predict the onset of Stage 3? To examine this, we plot the attention weights and their evolution over time as follows.

In [None]:
import seaborn as sns

Attention_weights = []

for w in range(len(individual_attentions[0])):
    
    Attention_weights.append(np.vstack((individual_attentions[0][w], np.zeros((len(individual_attentions[0][-1]) - len(individual_attentions[0][w]),1)))))


Attention_weights = np.array(Attention_weights).reshape((len(individual_attentions[0][-1]), len(individual_attentions[0][-1])))[:individual_predictions[0].shape[0], :individual_predictions[0].shape[0]]

mask = np.zeros_like(Attention_weights)

mask[np.triu_indices_from(mask)] = True
mask[np.diag_indices_from(mask)] = False

with sns.axes_style("white"):
    ax = sns.heatmap(Attention_weights, mask=mask, vmin=0, vmax=1, square=True, cmap='RdYlBu_r')
    ax.set_ylabel('Chronological time')
    ax.set_xlabel('Previous time steps')
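The padding-and-reshape step in the cell above builds a lower-triangular matrix whose row t holds the attention weights used at step t. The same matrix can be filled in directly, row by row. This sketch assumes, as that padding logic implies, that the attention vector at step t has t+1 entries (shape `(t+1,)` or `(t+1, 1)`); `attention_matrix` is a hypothetical helper and the toy weights are made up.

```python
import numpy as np

def attention_matrix(attention_list):
    """Stack per-step attention vectors into a lower-triangular (T, T) matrix.

    attention_list[t] holds the weights over time steps 0..t. Row t of the
    result is the attention used at step t; entries above the diagonal stay 0.
    """
    T = len(attention_list)
    A = np.zeros((T, T))
    for t, weights in enumerate(attention_list):
        weights = np.ravel(weights)  # accept (t+1,) or (t+1, 1)
        A[t, :len(weights)] = weights
    return A

toy = [np.array([[1.0]]), np.array([[0.3], [0.7]]), np.array([[0.2], [0.5], [0.3]])]
A = attention_matrix(toy)

# usage for the patient above:
# Attention_weights = attention_matrix(individual_attentions[0])
```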

As we can see, the patient's state trajectory behaves in a Markovian fashion (the current state takes all the weight) only at its edges: at the first and the last time steps, the only thing that matters for prediction is the patient's current state. At the first time step the patient has no history, whereas at the final step she is already in the most severe state, so her current health deterioration overrides all past clinical events.

Before the onset of Stage 3, however, we find substantial weight allocated to previous time steps. This suggests that some events that happened in time steps 1 and 2 may have triggered the onset of Stage 3. Let us check the events in those two time steps: 

In [None]:
events_in_step_0 = np.where(X_observations[patient_number][0,:]==1)[0]
events_in_step_1 = np.where(X_observations[patient_number][1,:]==1)[0]
events_in_step_2 = np.where(X_observations[patient_number][2,:]==1)[0]

We only care about the events that happened in steps 1 and 2 and did not occur in step 0. Those can be extracted as follows:

In [None]:
events_exclusive_in_1 = set(events_in_step_1) - set(events_in_step_0)
events_exclusive_in_1

In [None]:
events_exclusive_in_2 = set(events_in_step_2) - set(events_in_step_0)
events_exclusive_in_2

What are these events? We can check that out using the **feature_names** list:

In [None]:
np.array(feature_names)[list(events_exclusive_in_1)]

In [None]:
np.array(feature_names)[list(events_exclusive_in_2)]

So, the initiation of these two particular treatments, together with the patient's hospitalization, makes her presence in Stage 2 highly relevant and suggests that she is very likely to progress to Stage 3, even though her FEV1 at prediction time was ostensibly stable.
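The set differences and name lookups above can be wrapped in one reusable function. A sketch: `new_binary_events` is a hypothetical helper and the feature names in the toy example are made up. Like the cells above, it treats any feature equal to exactly 1 as an active event, so a continuous feature that happens to equal 1 would also be picked up.

```python
import numpy as np

def new_binary_events(sequence, feature_names, step, reference_step=0):
    """Names of features equal to 1 at `step` but not at `reference_step`."""
    active_now = set(np.flatnonzero(sequence[step, :] == 1))
    active_before = set(np.flatnonzero(sequence[reference_step, :] == 1))
    return [feature_names[k] for k in sorted(active_now - active_before)]

# toy patient: two visits, four features (one continuous)
names = ['Hospitalization', 'Treatment A', 'FEV1', 'Treatment B']
toy_seq = np.array([[0, 1, 75.0, 0],
                    [1, 1, 74.0, 1]])
events = new_binary_events(toy_seq, names, step=1)

# usage for the patient above:
# new_binary_events(X_observations[patient_number], feature_names, step=1)
```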

But this is just one patient. How do states evolve across the whole population? To visualize typical state trajectories, we first collect the predictions for all patients in the dataset:  

In [None]:
all_prediction, all_observation, all_attention = model.predict(X_observations)

Let us first examine the fraction of patients in each stage and how this changes over time:

In [None]:
all_map_estimates         = [np.array(get_MAP_estimates(Stage_state_map, all_prediction[k])) for k in range(len(all_prediction))]


def get_state_fraction_time_step(map_est, state, time_step):
    
    refined_map_est = [map_est[k] for k in range(len(map_est)) if len(map_est[k]) > time_step]
    state_true      = [(refined_map_est[k][time_step]==state)*1 for k in range(len(refined_map_est))]
    
    state_fraction  = np.sum(np.array(state_true))/len(state_true) 
    
    return state_fraction

In [None]:
max_seq_length            = 7
num_states                = 3

stage_1_dist_per_timestep = np.zeros(max_seq_length)
stage_2_dist_per_timestep = np.zeros(max_seq_length)
stage_3_dist_per_timestep = np.zeros(max_seq_length)

for time_step in range(max_seq_length):
    
    stage_1_dist_per_timestep[time_step] = get_state_fraction_time_step(all_map_estimates, 1, time_step)
    stage_2_dist_per_timestep[time_step] = get_state_fraction_time_step(all_map_estimates, 2, time_step)
    stage_3_dist_per_timestep[time_step] = get_state_fraction_time_step(all_map_estimates, 3, time_step)

    
All_states_occupancy = np.vstack((stage_1_dist_per_timestep, stage_2_dist_per_timestep, stage_3_dist_per_timestep))    

In [None]:
with sns.axes_style("white"):
    ax = sns.heatmap(All_states_occupancy, vmin=np.min(All_states_occupancy), 
                     vmax=np.max(All_states_occupancy), square=True, cmap='RdYlBu_r')
    
    ax.set_xlabel('Time steps')
    ax.set_yticklabels(['Stage 1', 'Stage 2', 'Stage 3'], rotation=45)

As we can see from the heatmap above, at time step 0 most patients are in Stage 1, and more and more patients transition into the more severe stages as time progresses. By the final time step, patients are more evenly distributed across the disease severity stages. This makes sense because the disease progresses over time and, unfortunately, there is currently no cure for cystic fibrosis. (Note that the fraction at each time step is computed only over patients whose sequences extend to that step.)
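The loop that builds **All_states_occupancy** calls **get_state_fraction_time_step** once per stage and time step, re-scanning every patient each time. The same table can be computed in a single pass over the MAP estimates. A sketch, assuming stage labels 1..num_stages as returned by **get_MAP_estimates**; `stage_occupancy` is a hypothetical helper and the toy trajectories are made up.

```python
import numpy as np

def stage_occupancy(map_estimates, num_stages, max_seq_length):
    """Fraction of patients in each stage at each time step.

    map_estimates : list of 1-D sequences of stage labels in 1..num_stages, one per patient
    Returns a (num_stages, max_seq_length) array. Column t is computed only over
    patients whose sequence is longer than t, and is NaN when there are none.
    """
    counts = np.zeros((num_stages, max_seq_length))
    for stages in map_estimates:
        for t, stage in enumerate(stages[:max_seq_length]):
            counts[stage - 1, t] += 1
    totals = counts.sum(axis=0)
    with np.errstate(invalid='ignore'):  # 0/0 -> NaN for time steps nobody reaches
        return counts / totals

toy_map = [np.array([1, 1, 2]), np.array([1, 2]), np.array([2])]
occupancy = stage_occupancy(toy_map, num_stages=3, max_seq_length=4)

# usage with the full dataset:
# All_states_occupancy = stage_occupancy(all_map_estimates, num_stages=3, max_seq_length=7)
```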

Now let us look at the attention weights and how their behavior changes from one time step to another. We first group all attention vectors by the time step at which they were computed: 

In [None]:
import itertools

def complete_sequence(seq_in, max_seq_len):
    
    return np.vstack((np.zeros((max_seq_len - len(seq_in),1)), seq_in))
    

merged_attention  = list(itertools.chain.from_iterable(all_attention))    
    
attention_at_step = {"Step 0": [], "Step 1": [], "Step 2": [],
                     "Step 3": [], "Step 4": [], "Step 5": [],
                     "Step 6": []}

step_name          = ["Step 0", "Step 1", "Step 2", "Step 3", "Step 4", "Step 5", "Step 6"]

attention_lengths          = np.array([merged_attention[k].shape[0] for k in range(len(merged_attention))])
average_attention_pre_step = []


for _ in range(len(step_name)):
    
    locs = np.where(attention_lengths==_ + 1)[0]
    
    attention_at_step[step_name[_]] = [complete_sequence(merged_attention[k], max_seq_length) for k in locs]
    
    average_attention_pre_step.append(np.mean(np.array( attention_at_step[step_name[_]]).reshape((-1, max_seq_length)), axis=0))
    

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions


fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(projection='3d')

X    = np.arange(0, 7, 1) 
Y    = np.arange(0, 7, 1)
X, Y = np.meshgrid(X, Y)
Z    = np.array(average_attention_pre_step)

# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap="RdYlBu_r",
                       linewidth=0, antialiased=False)

# Customize the z axis.
ax.set_zlim(0, 1.01)
ax.view_init(elev=35, azim=100)
ax.set_ylabel('Chronological time')
ax.set_xlabel('Previous time steps')
ax.set_zlabel('Attention weights')

# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()

As we showed earlier for an individual patient, the attention weights averaged over the entire population confirm that state trajectories behave in a Markovian fashion at the edges, i.e., for very healthy patients early in the disease and very sick patients in its terminal stage. This means that at the time steps most relevant to clinical decision-making, non-Markovian models are required!

## Sampling synthetic trajectories

Because our model is generative, we can also sample new (synthetic) trajectories from it, which means it can be used to generate synthetic data. Unlike implicit generative models, our model is an explicit generative model that does not memorize any individual trajectory in the original dataset, which mitigates privacy concerns. To sample from the trained model, we can use the **sample** method as follows:

In [None]:
sampled_states, sampled_observations = model.sample(trajectory_length=7)

This way, a whole new synthetic dataset can be created by repeatedly sampling from the model! This can be very useful for sharing datasets needed for medical studies to conduct preliminary analyses without jeopardizing the patients' privacy.
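The notebook does not show the exact output format of **sample**, so the sketch below assumes, based only on the call above, that each call returns a `(states, observations)` pair for a single trajectory. `build_synthetic_dataset` is a hypothetical helper that accepts any such sampling function, which also lets us exercise it with a stand-in sampler instead of a trained model.

```python
import numpy as np

def build_synthetic_dataset(sample_fn, num_sequences, trajectory_length):
    """Call sample_fn repeatedly and keep the observations.

    sample_fn(trajectory_length) is assumed to return (states, observations);
    the result uses the same list-of-arrays format as X_observations.
    """
    synthetic = []
    for _ in range(num_sequences):
        _states, observations = sample_fn(trajectory_length)
        synthetic.append(np.asarray(observations))
    return synthetic

# stand-in sampler with the assumed return format (random numbers, not a trained model)
rng = np.random.default_rng(1)
def toy_sampler(trajectory_length):
    states = rng.integers(0, 3, size=trajectory_length)
    observations = rng.normal(size=(trajectory_length, 90))
    return states, observations

synthetic = build_synthetic_dataset(toy_sampler, num_sequences=5, trajectory_length=7)

# with the trained model (assuming one trajectory per call):
# synthetic = build_synthetic_dataset(lambda L: model.sample(trajectory_length=L), 1000, 7)
```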

## References

[1] Rhonda D. Szczesniak, Dan Li, Weiji Su, Cole Brokamp, John Pestian, Michael Seid, and John P. Clancy. Phenotypes of rapid cystic fibrosis lung disease progression during adolescence and young adulthood. *American Journal of Respiratory and Critical Care Medicine*, vol. 196, no. 4, pp. 471-478, 2017.

[2] Ahmed M. Alaa and Mihaela van der Schaar. Prognostication and risk factors for cystic fibrosis via automated machine learning. *Scientific Reports*, vol. 8, no. 1, 2018.

[3] UK Cystic Fibrosis Registry, Cystic Fibrosis Trust. [https://www.cysticfibrosis.org.uk/the-work-we-do/uk-cf-registry/](https://www.cysticfibrosis.org.uk/the-work-we-do/uk-cf-registry/)