In [1]:
from scHash import *
from util import *
from dataModule import *
import anndata as ad

# Pancreas Dataset

We demonstrate how scHash encodes multiple datasets into hash codes, using six publicly available pancreas datasets.

The raw data for the first five datasets can be obtained from the Harmony repository: https://github.com/immunogenomics/harmony2019/tree/master/data/figure5

The sixth pancreas dataset is publicly available from GEO under accession GSE83139.

We compiled the six datasets into one AnnData object for easy demonstration. The processed data can be downloaded here *(TODO: add the download link)*.

In [6]:
# Path to the combined AnnData file (note: despite the name, this is a file
# path, not a directory).
data_dir = '../../../share_data/Pancreas_Wang/fivepancreas_wang_raw.h5ad'

# Value of `data.obs.dataset` that identifies the query (held-out) dataset.
# Every other dataset is used as the reference. Change this single constant to
# query a different dataset.
# NOTE(review): an earlier comment named "Muraro (celseq)" as the query, but
# the code has always selected 'smartseq' -- confirm which one is intended.
QUERY_DATASET = 'smartseq'

# This file contains both the reference and the query sources.
data = ad.read_h5ad(data_dir)

# Split into reference (train) and query (test) by the `dataset` column.
is_query = data.obs.dataset == QUERY_DATASET
train = data[~is_query]
test = data[is_query]

# Fail early on a mistyped dataset label instead of silently training or
# testing on an empty split.
if train.n_obs == 0 or test.n_obs == 0:
    raise ValueError(
        f"Splitting on dataset == {QUERY_DATASET!r} produced an empty split "
        f"(reference: {train.n_obs} cells, query: {test.n_obs} cells)."
    )

# Set up the datamodule on the reference data.
# TODO: consider wrapping this section in a single `setup_data` helper.
datamodule = Cross_DataModule(train_data = train, cell_type_key='cell_type')
datamodule.setup()
N_CLASS = datamodule.N_CLASS
N_FEATURES = datamodule.N_FEATURES

# Attach the query data; this can also be done after training.
datamodule.test_data = test


# TODO: consider wrapping the callback/trainer construction in a helper function.

# Directory where model checkpoints are written.
checkpointPath = '../checkpoint/'

# Validation metric used to select the best checkpoint (higher is better).
MONITORED_METRIC = 'Val_F1_score_median_CHC_epoch'

# Keep the checkpoint with the highest monitored validation score.
checkpoint_callback = ModelCheckpoint(
    monitor=MONITORED_METRIC,
    dirpath=checkpointPath,
    filename='scHash-{epoch:02d}-{Val_F1_score_median_CHC_epoch:.3f}',
    verbose=True,
    mode='max',
)

# The monitored quantity is an F1 score, so "improvement" means an increase;
# EarlyStopping defaults to mode='min', which would treat every gain as a
# regression.
# NOTE(review): this callback is NOT passed to the Trainer below, so it has no
# effect on training. Add it to `callbacks` (and choose `patience` with
# check_val_every_n_epoch=10 in mind) if early stopping is actually wanted.
early_stopping_callback = EarlyStopping(monitor=MONITORED_METRIC, mode='max')

# `accelerator`/`devices` replace the deprecated `gpus=1`, and the progress-bar
# callback replaces the removed `progress_bar_refresh_rate=50` Trainer argument.
trainer = pl.Trainer(
    max_epochs=200,
    accelerator='gpu',
    devices=1,
    check_val_every_n_epoch=10,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.TQDMProgressBar(refresh_rate=50),
    ],
)
print("Number of Feature: ", N_FEATURES)
model = scHashModel(N_CLASS, N_FEATURES)

# Train, and report the wall-clock time spent in fit().
start = time.time()
trainer.fit(model = model, datamodule = datamodule)
print(f"Training time: {time.time() - start:.1f} s")



# Evaluate the best checkpoint (as selected by `checkpoint_callback`) on the query data.
best_model_path = checkpoint_callback.best_model_path
if not best_model_path:
    # best_model_path stays empty when no checkpoint was written (for example
    # if training ended before the first validation run); fail with a clear
    # message instead of an opaque load error.
    raise RuntimeError(
        "No checkpoint was saved during training, so there is no best model to test."
    )
best_model = scHashModel.load_from_checkpoint(best_model_path, n_class=N_CLASS, n_features=N_FEATURES)

best_model.eval()

# Run the test loop and report its total wall-clock time.
start = time.time()
test_results = trainer.test(model=best_model, datamodule=datamodule)
print(f"Total test time: {time.time() - start:.1f} s")

# Keep the metrics as the cell's last expression so the notebook displays them.
test_results

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Number of Feature:  1000
hparam: l_r = 1.2e-05, lambda = 0.001, beta = 0.9999


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | hash_layer | Sequential | 641 K 
------------------------------------------
641 K     Trainable params
0         Non-trainable params
641 K     Total params
2.567     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 9, global step 9: 'Val_F1_score_median_CHC_epoch' reached 0.00000 (best 0.00000), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=09-Val_F1_score_median_CHC_epoch=0.000.ckpt' as top 1


Epoch: 9, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.000,                     val_labeling_accuracy_CHC:0.255,                    val_F1_score_weighted_average_CHC:0.124,                    val_F1_score_per_class_CHC:['0.000', '0.429', '0.000', '0.000', '0.000', '0.000', '0.000', '0.000', '0.000', '0.035', '0.000', '0.000'],                     val_precision:0.082,                     val_recall:0.255,                     train_F1_score_median_CHC: 0.000


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 19: 'Val_F1_score_median_CHC_epoch' was not in top 1


Epoch: 19, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.000,                     val_labeling_accuracy_CHC:0.259,                    val_F1_score_weighted_average_CHC:0.132,                    val_F1_score_per_class_CHC:['0.012', '0.428', '0.000', '0.000', '0.048', '0.000', '0.000', '0.000', '0.000', '0.036', '0.000', '0.000'],                     val_precision:0.340,                     val_recall:0.259,                     train_F1_score_median_CHC: 0.000


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 29: 'Val_F1_score_median_CHC_epoch' reached 0.02465 (best 0.02465), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=29-Val_F1_score_median_CHC_epoch=0.025.ckpt' as top 1


Epoch: 29, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.025,                     val_labeling_accuracy_CHC:0.269,                    val_F1_score_weighted_average_CHC:0.145,                    val_F1_score_per_class_CHC:['0.041', '0.428', '0.009', '0.000', '0.098', '0.000', '0.000', '0.000', '0.091', '0.051', '0.000', '0.118'],                     val_precision:0.594,                     val_recall:0.269,                     train_F1_score_median_CHC: 0.016


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 39: 'Val_F1_score_median_CHC_epoch' reached 0.03006 (best 0.03006), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=39-Val_F1_score_median_CHC_epoch=0.030.ckpt' as top 1


Epoch: 39, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.030,                     val_labeling_accuracy_CHC:0.284,                    val_F1_score_weighted_average_CHC:0.162,                    val_F1_score_per_class_CHC:['0.041', '0.439', '0.045', '0.000', '0.122', '0.000', '0.000', '0.000', '0.174', '0.154', '0.000', '0.019'],                     val_precision:0.553,                     val_recall:0.284,                     train_F1_score_median_CHC: 0.048


Validation: 0it [00:00, ?it/s]

Epoch 49, global step 49: 'Val_F1_score_median_CHC_epoch' reached 0.05143 (best 0.05143), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=49-Val_F1_score_median_CHC_epoch=0.051.ckpt' as top 1


Epoch: 49, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.051,                     val_labeling_accuracy_CHC:0.317,                    val_F1_score_weighted_average_CHC:0.208,                    val_F1_score_per_class_CHC:['0.074', '0.454', '0.139', '0.000', '0.197', '0.000', '0.000', '0.024', '0.385', '0.370', '0.000', '0.029'],                     val_precision:0.572,                     val_recall:0.317,                     train_F1_score_median_CHC: 0.090


Validation: 0it [00:00, ?it/s]

Epoch 59, global step 59: 'Val_F1_score_median_CHC_epoch' reached 0.08166 (best 0.08166), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=59-Val_F1_score_median_CHC_epoch=0.082.ckpt' as top 1


Epoch: 59, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.082,                     val_labeling_accuracy_CHC:0.342,                    val_F1_score_weighted_average_CHC:0.245,                    val_F1_score_per_class_CHC:['0.112', '0.470', '0.200', '0.000', '0.259', '0.000', '0.000', '0.047', '0.514', '0.500', '0.016', '0.051'],                     val_precision:0.526,                     val_recall:0.342,                     train_F1_score_median_CHC: 0.140


Validation: 0it [00:00, ?it/s]

Epoch 69, global step 69: 'Val_F1_score_median_CHC_epoch' was not in top 1


Epoch: 69, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.078,                     val_labeling_accuracy_CHC:0.325,                    val_F1_score_weighted_average_CHC:0.220,                    val_F1_score_per_class_CHC:['0.203', '0.124', '0.450', '0.000', '0.237', '0.000', '0.000', '0.047', '0.091', '0.556', '0.062', '0.065'],                     val_precision:0.662,                     val_recall:0.325,                     train_F1_score_median_CHC: 0.146


Validation: 0it [00:00, ?it/s]

Epoch 79, global step 79: 'Val_F1_score_median_CHC_epoch' reached 0.09007 (best 0.09007), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=79-Val_F1_score_median_CHC_epoch=0.090.ckpt' as top 1


Epoch: 79, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.090,                     val_labeling_accuracy_CHC:0.344,                    val_F1_score_weighted_average_CHC:0.250,                    val_F1_score_per_class_CHC:['0.391', '0.133', '0.455', '0.000', '0.234', '0.000', '0.000', '0.047', '0.000', '0.556', '0.106', '0.074'],                     val_precision:0.668,                     val_recall:0.344,                     train_F1_score_median_CHC: 0.165


Validation: 0it [00:00, ?it/s]

Epoch 89, global step 89: 'Val_F1_score_median_CHC_epoch' reached 0.10182 (best 0.10182), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=89-Val_F1_score_median_CHC_epoch=0.102.ckpt' as top 1


Epoch: 89, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.102,                     val_labeling_accuracy_CHC:0.377,                    val_F1_score_weighted_average_CHC:0.292,                    val_F1_score_per_class_CHC:['0.644', '0.145', '0.468', '0.000', '0.237', '0.034', '0.000', '0.047', '0.000', '0.556', '0.120', '0.083'],                     val_precision:0.700,                     val_recall:0.377,                     train_F1_score_median_CHC: 0.194


Validation: 0it [00:00, ?it/s]

Epoch 99, global step 99: 'Val_F1_score_median_CHC_epoch' reached 0.11347 (best 0.11347), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=99-Val_F1_score_median_CHC_epoch=0.113.ckpt' as top 1


Epoch: 99, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.113,                     val_labeling_accuracy_CHC:0.413,                    val_F1_score_weighted_average_CHC:0.328,                    val_F1_score_per_class_CHC:['0.828', '0.147', '0.484', '0.000', '0.264', '0.034', '0.000', '0.047', '0.000', '0.588', '0.199', '0.080'],                     val_precision:0.700,                     val_recall:0.413,                     train_F1_score_median_CHC: 0.203


Validation: 0it [00:00, ?it/s]

Epoch 109, global step 109: 'Val_F1_score_median_CHC_epoch' reached 0.11744 (best 0.11744), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=109-Val_F1_score_median_CHC_epoch=0.117.ckpt' as top 1


Epoch: 109, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.117,                     val_labeling_accuracy_CHC:0.428,                    val_F1_score_weighted_average_CHC:0.343,                    val_F1_score_per_class_CHC:['0.888', '0.152', '0.490', '0.000', '0.284', '0.034', '0.000', '0.070', '0.000', '0.588', '0.211', '0.083'],                     val_precision:0.709,                     val_recall:0.428,                     train_F1_score_median_CHC: 0.211


Validation: 0it [00:00, ?it/s]

Epoch 119, global step 119: 'Val_F1_score_median_CHC_epoch' reached 0.12191 (best 0.12191), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=119-Val_F1_score_median_CHC_epoch=0.122.ckpt' as top 1


Epoch: 119, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.122,                     val_labeling_accuracy_CHC:0.444,                    val_F1_score_weighted_average_CHC:0.360,                    val_F1_score_per_class_CHC:['0.931', '0.160', '0.497', '0.000', '0.320', '0.034', '0.000', '0.070', '0.000', '0.588', '0.248', '0.083'],                     val_precision:0.711,                     val_recall:0.444,                     train_F1_score_median_CHC: 0.221


Validation: 0it [00:00, ?it/s]

Epoch 129, global step 129: 'Val_F1_score_median_CHC_epoch' reached 0.12305 (best 0.12305), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=129-Val_F1_score_median_CHC_epoch=0.123.ckpt' as top 1


Epoch: 129, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.123,                     val_labeling_accuracy_CHC:0.453,                    val_F1_score_weighted_average_CHC:0.368,                    val_F1_score_per_class_CHC:['0.954', '0.163', '0.501', '0.000', '0.333', '0.034', '0.000', '0.070', '0.000', '0.588', '0.282', '0.083'],                     val_precision:0.712,                     val_recall:0.453,                     train_F1_score_median_CHC: 0.222


Validation: 0it [00:00, ?it/s]

Epoch 139, global step 139: 'Val_F1_score_median_CHC_epoch' was not in top 1


Epoch: 139, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.123,                     val_labeling_accuracy_CHC:0.460,                    val_F1_score_weighted_average_CHC:0.375,                    val_F1_score_per_class_CHC:['0.967', '0.163', '0.505', '0.000', '0.362', '0.034', '0.000', '0.071', '0.000', '0.500', '0.314', '0.083'],                     val_precision:0.717,                     val_recall:0.460,                     train_F1_score_median_CHC: 0.235


Validation: 0it [00:00, ?it/s]

Epoch 149, global step 149: 'Val_F1_score_median_CHC_epoch' reached 0.12589 (best 0.12589), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=149-Val_F1_score_median_CHC_epoch=0.126.ckpt' as top 1


Epoch: 149, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.126,                     val_labeling_accuracy_CHC:0.460,                    val_F1_score_weighted_average_CHC:0.375,                    val_F1_score_per_class_CHC:['0.969', '0.165', '0.505', '0.000', '0.358', '0.033', '0.000', '0.071', '0.000', '0.500', '0.303', '0.087'],                     val_precision:0.704,                     val_recall:0.460,                     train_F1_score_median_CHC: 0.252


Validation: 0it [00:00, ?it/s]

Epoch 159, global step 159: 'Val_F1_score_median_CHC_epoch' reached 0.13476 (best 0.13476), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=159-Val_F1_score_median_CHC_epoch=0.135.ckpt' as top 1


Epoch: 159, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.135,                     val_labeling_accuracy_CHC:0.467,                    val_F1_score_weighted_average_CHC:0.385,                    val_F1_score_per_class_CHC:['0.967', '0.173', '0.510', '0.000', '0.378', '0.097', '0.000', '0.093', '0.000', '0.429', '0.344', '0.087'],                     val_precision:0.693,                     val_recall:0.467,                     train_F1_score_median_CHC: 0.278


Validation: 0it [00:00, ?it/s]

Epoch: 169, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.141,                     val_labeling_accuracy_CHC:0.472,                    val_F1_score_weighted_average_CHC:0.392,                    val_F1_score_per_class_CHC:['0.967', '0.186', '0.513', '0.000', '0.390', '0.097', '0.000', '0.093', '0.000', '0.429', '0.360', '0.091'],                     val_precision:0.694,                     val_recall:0.472,                     train_F1_score_median_CHC: 0.300


Epoch 169, global step 169: 'Val_F1_score_median_CHC_epoch' reached 0.14138 (best 0.14138), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=169-Val_F1_score_median_CHC_epoch=0.141.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 179, global step 179: 'Val_F1_score_median_CHC_epoch' reached 0.17633 (best 0.17633), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=179-Val_F1_score_median_CHC_epoch=0.176.ckpt' as top 1


Epoch: 179, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.176,                     val_labeling_accuracy_CHC:0.473,                    val_F1_score_weighted_average_CHC:0.396,                    val_F1_score_per_class_CHC:['0.966', '0.196', '0.513', '0.000', '0.378', '0.156', '0.000', '0.093', '0.000', '0.462', '0.370', '0.095'],                     val_precision:0.691,                     val_recall:0.473,                     train_F1_score_median_CHC: 0.320


Validation: 0it [00:00, ?it/s]

Epoch 189, global step 189: 'Val_F1_score_median_CHC_epoch' reached 0.18884 (best 0.18884), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=189-Val_F1_score_median_CHC_epoch=0.189.ckpt' as top 1


Epoch: 189, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.189,                     val_labeling_accuracy_CHC:0.475,                    val_F1_score_weighted_average_CHC:0.400,                    val_F1_score_per_class_CHC:['0.966', '0.221', '0.514', '0.000', '0.358', '0.156', '0.000', '0.093', '0.000', '0.462', '0.360', '0.095'],                     val_precision:0.688,                     val_recall:0.475,                     train_F1_score_median_CHC: 0.331


Validation: 0it [00:00, ?it/s]

Epoch 199, global step 199: 'Val_F1_score_median_CHC_epoch' reached 0.20616 (best 0.20616), saving model to '/project/6061845/shaoc/scDeepHash1/checkpoint/scHash-epoch=199-Val_F1_score_median_CHC_epoch=0.206.ckpt' as top 1


Epoch: 199, Val_loss_epoch: 0.04
val_F1_score_median_CHC:0.206,                     val_labeling_accuracy_CHC:0.476,                    val_F1_score_weighted_average_CHC:0.403,                    val_F1_score_per_class_CHC:['0.964', '0.228', '0.514', '0.000', '0.350', '0.185', '0.000', '0.093', '0.000', '0.462', '0.390', '0.105'],                     val_precision:0.690,                     val_recall:0.476,                     train_F1_score_median_CHC: 0.342
hparam: l_r = 1.2e-05, lambda = 0.001, beta = 0.9999


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

################### compute_result time:  0.8244278430938721
################### cell_assign time:  0.05458855628967285
################### query time:  0.8790163993835449


[{'Test_F1': 0.6214442449078369,
  'Test_F1_Median': 0.45772854596384005,
  'Test_precision': 0.7585122142303424,
  'Test_recall': 0.586411889596603,
  'Test_hashing_time': 0.8244278430938721,
  'Test_cell_assign_time': 0.05458855628967285,
  'Test_query_time': 0.8790163993835449,
  'Test_accuracy': 0.5864118933677673}]