Survival Cluster Analysis (ACM CHIL 2020)

This repository contains the TensorFlow code to replicate the experiments in our paper Survival Cluster Analysis, accepted at the ACM Conference on Health, Inference, and Learning (ACM CHIL) 2020:

  title={Survival Cluster Analysis},
  author={Paidamoyo Chapfuwa and Chunyuan Li and Nikhil Mehta and Lawrence Carin and Ricardo Henao},
  booktitle={ACM Conference on Health, Inference, and Learning},
  year={2020}



Illustration of Survival Cluster Analysis (SCA). The latent space has a mixture-of-distributions structure, illustrated here with three mixture components. An observation x is mapped to its latent representation z via a deterministic encoder, and z is then used to stochastically predict (via sampling) the time-to-event distribution p(t|x).
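The encode-then-sample pipeline in the figure can be sketched in plain Python. This is a toy illustration under stated assumptions, not the repository's TensorFlow model: the weights W_ENC and W_TIME and the helpers encode, sample_time and sample_time_to_event are hypothetical, the encoder is a single tanh layer, and the mixture prior on z is omitted. It only shows the key idea that z is computed deterministically while time-to-event samples come from injected noise.

```python
import math
import random

# Hypothetical toy parameters; SCA itself uses neural networks in TensorFlow.
W_ENC = [[0.5, -0.2], [0.1, 0.3]]  # maps 2-d covariates x to a 2-d latent z
W_TIME = [0.4, -0.3, 0.2]          # maps [z, eps] to a log time


def encode(x):
    """Deterministic encoder: the same x always yields the same z."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W_ENC]


def sample_time(z, rng):
    """Stochastic decoder: draw noise eps and map [z, eps] to a positive time."""
    eps = rng.gauss(0.0, 1.0)
    inputs = z + [eps]
    return math.exp(sum(w * v for w, v in zip(W_TIME, inputs)))


def sample_time_to_event(x, n_samples=200, seed=0):
    """Approximate p(t|x) by decoding the fixed z many times with fresh noise."""
    rng = random.Random(seed)
    z = encode(x)
    return [sample_time(z, rng) for _ in range(n_samples)]


samples = sample_time_to_event([1.0, 2.0])
median = sorted(samples)[len(samples) // 2]
print(f"approx. median time: {median:.3f}")
```

Summaries such as the median of the samples can then serve as a point prediction, while the spread of the samples reflects uncertainty in p(t|x).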


Cluster-specific Kaplan-Meier survival profiles for three clustering methods on the SLEEP dataset. Our model (SCA) can identify high-, medium- and low-risk individuals, demonstrating the need to account for time information via a non-linear transformation of covariates when clustering survival datasets.
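Cluster-specific survival profiles like these come from fitting one Kaplan-Meier estimator per cluster. The stdlib sketch below is illustrative only: the function names kaplan_meier and km_by_cluster and the toy data are hypothetical, and in SCA the cluster labels would come from the learned latent mixture rather than being supplied by hand.

```python
from collections import defaultdict


def kaplan_meier(times, events):
    """Return [(t, S(t))] at each distinct event time.

    times: follow-up times; events: 1 if the event was observed, 0 if censored.
    Subjects censored at time t are still counted as at risk at t.
    """
    data = sorted(zip(times, events))
    n_at_risk = len(data)
    survival = 1.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        deaths = 0
        removed = 0
        # Group all subjects that share the same time t.
        while i < len(data) and data[i][0] == t:
            deaths += data[i][1]
            removed += 1
            i += 1
        if deaths > 0:
            survival *= 1.0 - deaths / n_at_risk
            curve.append((t, survival))
        n_at_risk -= removed
    return curve


def km_by_cluster(times, events, clusters):
    """Fit one Kaplan-Meier curve per cluster label."""
    groups = defaultdict(lambda: ([], []))
    for t, e, c in zip(times, events, clusters):
        groups[c][0].append(t)
        groups[c][1].append(e)
    return {c: kaplan_meier(ts, es) for c, (ts, es) in groups.items()}


profiles = km_by_cluster(
    times=[1, 2, 3, 4, 5, 6],
    events=[1, 1, 0, 1, 1, 0],
    clusters=["high", "high", "high", "low", "low", "low"],
)
for label, curve in profiles.items():
    print(label, [(t, round(s, 3)) for t, s in curve])
```

Well-separated curves across clusters indicate that the clustering captures differences in risk, which is the property the figure compares across methods.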


The code depends on the packages listed in requirements.txt, which can be installed with:

pip install -r requirements.txt


We consider the following datasets:

  • Flchain
  • SUPPORT
  • SEER
  • SLEEP: A subset of the Sleep Heart Health Study (SHHS), a multi-center cohort study implemented by the National Heart, Lung, and Blood Institute to determine the cardiovascular and other consequences of sleep-disordered breathing.
  • Framingham: A subset (Framingham Offspring) of the longitudinal heart-disease study, originally used to predict the 10-year risk of future coronary heart disease (CHD).
  • EHR: A large study from Duke University Health System centered on inpatient visits due to comorbidities in patients with Type-2 diabetes.

For convenience, we provide pre-processing scripts for all datasets except EHR and Framingham. The data directory also contains the downloaded Flchain and SUPPORT datasets.

Model Training

Modify the training arguments as needed:

  • dataset: one of the four public datasets {flchain, support, seer, sleep}; the default is support
  • n_clusters: the upper bound K on the number of clusters; the default is 25
  • gamma_0: the Dirichlet process concentration parameter, selected from {2, 3, 4, 8}; the default is 2
 python --dataset support --n_clusters 25 --gamma_0 2
  • The hyper-parameter settings can be found at
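The three arguments above can be mirrored with argparse, and a truncated stick-breaking draw shows how n_clusters (K) and gamma_0 interact in a Dirichlet process prior. This is a generic sketch under assumptions: build_parser and stick_breaking_weights are hypothetical helpers, the flag names and defaults are taken from the list above, and whether SCA parameterizes its prior with exactly this construction is not confirmed here.

```python
import argparse
import random


def build_parser():
    # Flag names and defaults follow the README; the repository's own parser may differ.
    parser = argparse.ArgumentParser(description="SCA training arguments (sketch)")
    parser.add_argument("--dataset", default="support",
                        choices=["flchain", "support", "seer", "sleep"])
    parser.add_argument("--n_clusters", type=int, default=25,
                        help="upper bound K on the number of clusters")
    parser.add_argument("--gamma_0", type=float, default=2.0,
                        help="Dirichlet process concentration, e.g. 2, 3, 4 or 8")
    return parser


def stick_breaking_weights(n_clusters, gamma_0, seed=0):
    """Draw K mixture weights from a Dirichlet process prior truncated at K."""
    rng = random.Random(seed)
    weights = []
    remaining = 1.0  # length of the stick not yet assigned to a cluster
    for _ in range(n_clusters - 1):
        v = rng.betavariate(1.0, gamma_0)  # break off a Beta(1, gamma_0) fraction
        weights.append(remaining * v)
        remaining *= 1.0 - v
    weights.append(remaining)  # the K-th cluster takes the rest, so weights sum to 1
    return weights


args = build_parser().parse_args(
    ["--dataset", "support", "--n_clusters", "25", "--gamma_0", "2"])
weights = stick_breaking_weights(args.n_clusters, args.gamma_0)
print(args.dataset, len(weights), max(weights))
```

Each break takes on average 1/(1 + gamma_0) of the remaining stick, so a larger gamma_0 spreads prior mass over more clusters, while n_clusters only caps how many components are available.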

Metrics and Visualizations

Once the networks are trained and the results are saved, we extract the following key results:

  • Training and evaluation metrics are logged in model.log
  • Epoch-based cost function plots can be found in the plots directory
  • NumPy files used to generate calibration and cluster plots are saved in the matrix directory
  • Run Calibration.ipynb to generate calibration results and Clustering.ipynb to generate clustering results
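The README does not list which metrics model.log records. As an assumed illustration of a standard survival accuracy metric, the sketch below computes Harrell's concordance index (C-index); the function name concordance_index is hypothetical, tied follow-up times are simply skipped (a simplification), and this is not necessarily how the repository computes its metrics.

```python
from itertools import combinations


def concordance_index(times, events, risk_scores):
    """Harrell's C-index: fraction of comparable pairs ordered correctly by risk.

    A pair is comparable when the subject with the shorter time had an observed
    event. Higher risk should mean a shorter time; tied risks count as 0.5.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        # Order the pair so that `a` has the shorter follow-up time.
        a, b = (i, j) if times[i] < times[j] else (j, i)
        if times[a] == times[b] or not events[a]:
            continue  # tied times or censored earlier subject: not comparable
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5
    return concordant / comparable if comparable else float("nan")


# A model that outputs predicted times (e.g. a summary of sampled p(t|x)) can be
# scored by treating a shorter predicted time as a higher risk.
predicted_times = [1.5, 2.5, 2.0, 5.0]
c = concordance_index([1, 2, 3, 4], [1, 1, 1, 0], [-t for t in predicted_times])
print(f"C-index: {c:.3f}")
```

A value of 1.0 means every comparable pair is ranked correctly and 0.5 corresponds to random ordering.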


This work leverages the calibration framework from SFM and the accuracy objective from DATE. Please contact Paidamoyo with any issues related to this project.
