zhijian-yang/SmileGAN

Smile-GAN

Smile-GAN is a semi-supervised clustering method designed to identify disease-related heterogeneity within the patient group. The model avoids clustering on variations present among the normal control (CN) group and clusters patients based only on disease-related variations. Semi-supervised clustering in Smile-GAN is achieved through joint training of the mapping and clustering functions, where the mapping function maps CN subjects along different directions depending on disease-related variations.


License

Copyright (c) 2016 University of Pennsylvania. All rights reserved. See https://www.cbica.upenn.edu/sbia/software/license.html

Installation

We highly recommend that users install Anaconda3 on their machines. After installing Anaconda3, Smile-GAN can be installed and used as follows:

Create a conda virtual environment:

$ conda create --name smilegan python=3.8

Activate the virtual environment:

$ conda activate smilegan

Install SmileGAN from PyPI:

$ pip install SmileGAN

Input structure

The main functions of SmileGAN take two pandas DataFrames as data inputs: data and covariate (optional). Columns named 'participant_id' and 'diagnosis' must exist in both DataFrames. Conventions for the group label/diagnosis: -1 represents healthy control (CN) and 1 represents patient (PT). Categorical variables, such as sex, should be encoded as numbers (e.g., 0 for Female and 1 for Male).

Example for data:

participant_id    diagnosis    ROI1     ROI2    ...
subject-1            -1        325.4    603.4
subject-2             1        260.5    580.3
subject-3            -1        326.5    623.4
subject-4             1        301.7    590.5
subject-5             1        293.1    595.1
subject-6             1        287.8    608.9

Example for covariate:

participant_id    diagnosis    age     sex    ...
subject-1            -1        57.3     0
subject-2             1        43.5     1
subject-3            -1        53.8     1
subject-4             1        56.0     0
subject-5             1        60.0     1
subject-6             1        62.5     0
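The two input DataFrames above can be constructed directly in pandas; a minimal sketch with values mirroring the toy rows:

```python
import pandas as pd

# Feature dataframe: 'participant_id' and 'diagnosis', then one column per ROI.
data = pd.DataFrame({
    'participant_id': ['subject-1', 'subject-2', 'subject-3'],
    'diagnosis': [-1, 1, -1],          # -1 = healthy control (CN), 1 = patient (PT)
    'ROI1': [325.4, 260.5, 326.5],
    'ROI2': [603.4, 580.3, 623.4],
})

# Optional covariate dataframe; it must share the same two label columns.
covariate = pd.DataFrame({
    'participant_id': ['subject-1', 'subject-2', 'subject-3'],
    'diagnosis': [-1, 1, -1],
    'age': [57.3, 43.5, 53.8],
    'sex': [0, 1, 1],                  # 0 = Female, 1 = Male
})

# Both dataframes must contain 'participant_id' and 'diagnosis'.
assert {'participant_id', 'diagnosis'} <= set(data.columns)
assert {'participant_id', 'diagnosis'} <= set(covariate.columns)
```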

Example

We offer a toy dataset in the folder "SmileGAN/dataset".

Running SmileGAN for clustering CN vs Subtype1 vs Subtype2 vs ...

import pandas as pd
from SmileGAN.Smile_GAN_clustering import single_model_clustering, cross_validated_clustering, clustering_result

train_data = pd.read_csv('train_roi.csv')
covariate = pd.read_csv('train_cov.csv')

output_dir = "PATH_OUTPUT_DIR"
ncluster = 3
start_saving_epoch = 9000
max_epoch = 14000

## three parameters for stopping threshold
WD = 0.10
AQ = 20
cluster_loss = 0.0015

## one parameter for consensus method
consensus_type = "highest_matching_clustering"

When using the package, WD, AQ, cluster_loss, and consensus_type need to be chosen empirically:

WD: Wasserstein Distance measures the distance between generated PT data along each direction and real PT data. (Recommended value: 0.11-0.14)

AQ: Alteration Quantity measures the number of participants who change cluster labels during the last three training epochs. A low AQ implies convergence of training. (Recommended value: 1/20 of the PT sample size)

cluster_loss: Cluster loss measures how well the clustering function reconstructs the sampled Z variable. (Recommended value: 0.0015-0.002)

consensus_type: Needs to be chosen from "consensus_clustering" and "highest_matching_clustering". It determines how the final consensus result is derived from the k clustering results obtained through the k-fold hold-out CV procedure. "highest_matching_clustering" is recommended if the Adjusted Rand Index (ARI) among the k clustering results is greater than 0.3; otherwise, "consensus_clustering" may give more reliable consensus results. Users can always rederive results with a different consensus_type by passing the trained models to the function clustering_result, without retraining.

Some other parameters (lam, mu, and batch_size) have default values but need to be changed in some cases:

batch_size: Size of the batch for each training epoch. (Default: 25.) It should be reset to 1/10-1/20 of the PT sample size.

lam: Coefficient controlling the relative importance of cluster_loss in the training objective function. (Default: 9.)

mu: Coefficient controlling the relative importance of change_loss in the training objective function. (Default: 5.) It is necessary to try different values of mu (1-7) and choose the value leading to the highest ARI (Adjusted Rand Index).
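Since mu is selected by maximizing ARI, one way to compare clustering results from different runs is scikit-learn's adjusted_rand_score (assuming scikit-learn is installed; the labels below are illustrative, not real SmileGAN output):

```python
from sklearn.metrics import adjusted_rand_score

# Cluster labels assigned to the same six patients by two different runs.
labels_run1 = [0, 0, 1, 1, 2, 2]
labels_run2 = [1, 1, 0, 0, 2, 2]   # same partition, permuted label names

# ARI is invariant to label permutation: identical partitions score 1.0.
print(adjusted_rand_score(labels_run1, labels_run2))
```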

single_model_clustering(train_data, ncluster, start_saving_epoch, max_epoch,
                        output_dir, WD, AQ, cluster_loss, covariate=covariate)

single_model_clustering performs clustering without cross validation. Since only one model is trained with this function, the model may not be representative or reproducible. Therefore, this function is not recommended. The function automatically saves a CSV file with clustering results and returns the same dataframe.

fold_number = 10    # number of folds for leave-out CV
data_fraction = 0.8 # fraction of data used in each fold
cross_validated_clustering(train_data, ncluster, fold_number, data_fraction, start_saving_epoch, max_epoch,
                           output_dir, WD, AQ, cluster_loss, consensus_type, covariate=covariate)

cross_validated_clustering performs clustering with leave-out cross validation and is the recommended function for clustering. Since the CV process may take a long time on a normal desktop computer, the function supports early stopping and later resumption: users can set stop_fold as an early stopping point and start_fold to resume from the previous stopping point. The function automatically saves a CSV file with clustering results and the mean ARI value.
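What each CV fold sees can be pictured as random subsampling controlled by fold_number and data_fraction; a pure-Python illustration of that idea (not the package's internal code):

```python
import random

def make_fold_subsets(subject_ids, fold_number=10, data_fraction=0.8, seed=0):
    """Illustrative leave-out CV: each fold trains on a random
    data_fraction of the subjects (not SmileGAN's internal code)."""
    rng = random.Random(seed)
    n_keep = int(round(data_fraction * len(subject_ids)))
    return [rng.sample(subject_ids, n_keep) for _ in range(fold_number)]

ids = ['subject-%d' % i for i in range(1, 101)]   # 100 hypothetical PT subjects
folds = make_fold_subsets(ids, fold_number=10, data_fraction=0.8)
assert len(folds) == 10                # one random subset per fold
assert all(len(f) == 80 for f in folds)  # each fold uses 80% of the subjects
```

Each fold produces one trained model; the k resulting clusterings are then combined according to consensus_type.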

model_dirs = ['PATH_TO_CHECKPOINT1','PATH_TO_CHECKPOINT2',...]  # list of paths to previously saved checkpoints (named 'converged_model_foldk' after the CV process)
cluster_label, cluster_probabilities, _, _ = clustering_result(model_dirs, 'highest_matching_clustering', train_data, covariate)

clustering_result is a function for clustering patient data using previously saved models. Input data and covariate (optional) should be pandas DataFrames in the same format shown above. Only the PT data (which can be inside or outside of the training set) for which the user wants to derive cluster memberships needs to be provided, with diagnosis set to 1. The function returns cluster labels of PT data following the order of PT rows in the provided DataFrame. If consensus_type is 'highest_matching_clustering', probabilities of each cluster are also returned.
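Because the returned labels follow the order of PT rows in the input DataFrame, they can be attached back directly. A sketch using placeholder values standing in for clustering_result's actual return:

```python
import pandas as pd

# Placeholder outputs standing in for clustering_result's return values.
cluster_label = [0, 1, 1, 0]
cluster_probabilities = [[0.9, 0.1], [0.2, 0.8], [0.3, 0.7], [0.8, 0.2]]

# PT-only rows, in the same order they were passed to clustering_result.
pt_data = pd.DataFrame({
    'participant_id': ['subject-2', 'subject-4', 'subject-5', 'subject-6'],
    'diagnosis': [1, 1, 1, 1],
})

# Attach each subject's label and the winning cluster's probability.
pt_data['cluster'] = cluster_label
pt_data['max_probability'] = [max(p) for p in cluster_probabilities]
print(pt_data)
```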

Citation

If you use this package for research, please cite the following paper:

@article{yang2021BrainHeterogeneity,
author = {Yang, Zhijian and Nasrallah, Ilya M. and Shou, Haochang and Wen, Junhao and Doshi, Jimit and Habes, Mohamad and Erus, Guray and Abdulkadir, Ahmed and Resnick, Susan M. and Albert, Marilyn S. and Maruff, Paul and Fripp, Jurgen and Morris, John C. and Wolk, David A. and Davatzikos, Christos and {iSTAGING Consortium} and {Baltimore Longitudinal Study of Aging (BLSA)} and {Alzheimer’s Disease Neuroimaging Initiative (ADNI)}},
year = {2021},
month = {12},
pages = {},
title = {A deep learning framework identifies dimensional representations of Alzheimer’s Disease from brain structure},
volume = {12},
journal = {Nature Communications},
doi = {10.1038/s41467-021-26703-z}
}
