In [1]:
# Experiment configuration: every tunable value for this notebook lives here.

# Batch identifiers, matched against the data's `obs['batch']` column by split_data.
all_batches = [
    's1d1', 's1d2', 's1d3',
    's2d1', 's2d4', 's2d5',
    's3d3', 's3d6', 's3d7', 's3d10',
    's4d1', 's4d8', 's4d9',
]
# Each of these is held out in turn: the model is trained on the other batches
# and then transferred to / evaluated on the held-out one.
test_batches = ['s3d7', 's3d10',
    's4d1', 's4d8', 's4d9']
# Catch typos before the expensive training loop starts.
assert len(set(all_batches)) == len(all_batches), 'duplicate entries in all_batches'
assert set(test_batches) <= set(all_batches), (
    f'test_batches not in all_batches: {sorted(set(test_batches) - set(all_batches))}'
)

device = 'cuda:0'              # passed to UnitedNet(device=...)
experiment = 'saved_atacseq'   # root dir; each run saves under f'{experiment}/{test_batch}'
verbose = False                # passed to UnitedNet(verbose=...)
label_key = 'cell_type'        # label column name given to register_anndatas / transfer / evaluate
n_epoch = 50
# Forwarded as `checkpoint=`; the evaluation cell loads `epoch_{epoch}.pt` for
# every epoch, which presumably relies on this being 1 (TODO confirm).
checkpoint = 1
learning_rate = 0.01
train_batch_size = 512
finetune_batch_size = 5000
transfer_batch_size = 512

In [2]:
import scanpy as sc
from scipy import sparse

# Load the processed ATAC and gene-expression AnnData objects.
data_path = 'data/ATACSeq'
adata_atac = sc.read_h5ad(f'{data_path}/atac_processed.h5ad')
# Densify the ATAC matrix (presumably required downstream by UnitedNet; confirm).
# Guarded because a file that already stores a dense array has no `.toarray()`.
if sparse.issparse(adata_atac.X):
    adata_atac.X = adata_atac.X.toarray()
adata_gex = sc.read_h5ad(f'{data_path}/gex_processed.h5ad')

In [3]:
def split_data(test_batch, batch_key='batch'):
    """Split the paired ATAC / GEX data into train and held-out test sets.

    Reads the module-level ``adata_atac`` and ``adata_gex`` objects. Cells
    whose ``obs[batch_key]`` equals ``test_batch`` go to the test split; all
    other cells go to the training split.

    Parameters
    ----------
    test_batch : str
        Batch identifier to hold out.
    batch_key : str, optional
        Name of the ``obs`` column holding batch identifiers (default ``'batch'``).

    Returns
    -------
    tuple of (list, list)
        ``([atac_train, gex_train], [atac_test, gex_test])``.

    Raises
    ------
    ValueError
        If ``test_batch`` matches no cell in one of the modalities. Without
        this check a typo would silently produce an empty held-out split.
    """
    train, test = [], []
    for modality, adata in (('ATAC', adata_atac), ('GEX', adata_gex)):
        # One mask per modality; its complement is the training set.
        is_test = adata.obs[batch_key] == test_batch
        if not is_test.any():
            raise ValueError(
                f'batch {test_batch!r} not found in {modality} obs[{batch_key!r}]'
            )
        train.append(adata[~is_test])
        test.append(adata[is_test])
    return train, test

In [4]:
from src import UnitedNet
# For each held-out batch: train on the remaining batches, fine-tune, then run
# the transfer step on the held-out split. Checkpoints go under
# f'{experiment}/{test_batch}'.
task_name = 'supervised_group_identification'
# Options shared verbatim by the train, finetune and transfer calls.
common_kwargs = {
    'n_epoch': n_epoch,
    'checkpoint': checkpoint,
    'save_best_model': True,
    'learning_rate': learning_rate,
}
for test_batch in test_batches:
    print(test_batch, end=' ')
    save_path = f'{experiment}/{test_batch}'
    model = UnitedNet(device=device, save_path=save_path, verbose=verbose)
    adatas_train, adatas_test = split_data(test_batch)

    # Stage 1: supervised training on every batch except `test_batch`.
    model.register_anndatas(adatas_train, label_index=0, label_key=label_key)
    model.train(task_name, batch_size=train_batch_size, **common_kwargs)
    print('trained', end=' ')

    # Stage 2: fine-tuning (uses its own, larger batch size).
    model.finetune(task_name, batch_size=finetune_batch_size, **common_kwargs)
    print('finetuned', end=' ')

    # Stage 3: transfer onto the held-out split; its labels are passed via
    # label_key_validate.
    model.transfer(
        task_name,
        adatas_transfer=adatas_test,
        label_index_validate=0,
        label_key_validate=label_key,
        batch_size=transfer_batch_size,
        **common_kwargs,
    )
    print('done')

s2d1 trained finetuned done
s2d4 trained finetuned done
s2d5 trained finetuned done
s3d3 trained finetuned done
s3d6 trained finetuned done
s3d7 

KeyboardInterrupt: 

In [None]:
def get_best_epoch_by_batch(model_path, batches=None):
    """For each held-out batch, find the saved epoch with the highest test ARI.

    For every batch, its held-out split (from ``split_data``) is evaluated with
    each checkpoint ``{path}/epoch_{epoch}.pt`` for ``epoch`` in ``1..n_epoch``.
    The ``'ari'`` entry returned by ``UnitedNet.evaluate`` is the score.

    Parameters
    ----------
    model_path : str
        Directory containing the per-epoch checkpoints. A literal
        ``{test_batch}`` placeholder is replaced by each batch name, so every
        batch is scored with the models trained without it. A path without
        the placeholder is used unchanged for every batch (old behavior).
    batches : iterable of str, optional
        Batches to evaluate; defaults to ``all_batches``.

    Returns
    -------
    dict
        ``{test_batch: (best_epoch, best_ari)}`` for every evaluated batch.
        ``best_epoch`` is ``None`` (and ``best_ari`` is ``-inf``) only when
        ``n_epoch`` is 0.
    """
    if batches is None:
        batches = all_batches
    # Created once, outside the loop, so results for every batch are kept.
    best_epoch_by_batch = {}
    for test_batch in batches:
        _, adatas_test = split_data(test_batch)
        # str.replace rather than str.format: other braces in a path stay untouched.
        batch_model_path = model_path.replace('{test_batch}', test_batch)
        # Start below any possible score. An adjusted Rand index can be
        # negative, and a 0 start would report (None, 0) in that case.
        best_ari = float('-inf')
        best_epoch = None
        for epoch in range(1, n_epoch + 1):
            saved_model = UnitedNet(device=device, verbose=verbose, save_path=None)
            saved_model.load_model(f'{batch_model_path}/epoch_{epoch}.pt')
            ari = saved_model.evaluate(
                adatas_test, label_index_evaluate=0, label_key_evaluate=label_key
            )['ari']
            if ari > best_ari:
                best_ari, best_epoch = ari, epoch
        best_epoch_by_batch[test_batch] = (best_epoch, best_ari)
    return best_epoch_by_batch


# Checkpoint root per held-out batch. '{test_batch}' is filled in per batch by
# get_best_epoch_by_batch instead of relying on the loop variable leaked from
# the training cell. Only `test_batches` were trained, so only those are scored.
model_root = experiment + '/{test_batch}/models'
best_train_epoch_by_batch = get_best_epoch_by_batch(
    f'{model_root}/train_1_classification', batches=test_batches)
best_finetune_epoch_by_batch = get_best_epoch_by_batch(
    f'{model_root}/finetune_1_classification(finetune)', batches=test_batches)
best_transfer_epoch_by_batch = get_best_epoch_by_batch(
    f'{model_root}/transfer_1_classification(transfer)', batches=test_batches)