##### Parameter Setting

In [None]:
# Batch identifiers, grouped by their `sN` prefix
# (presumably site N / donor M — TODO confirm against the dataset description).
s1_batches = ['s1d1', 's1d2', 's1d3']
s2_batches = ['s2d1', 's2d4', 's2d5']
s3_batches = ['s3d3', 's3d6', 's3d7', 's3d10']
s4_batches = ['s4d1', 's4d8', 's4d9']
all_batches = s1_batches + s2_batches + s3_batches + s4_batches

# Every batch listed here is used once as the held-out test batch by the loops below.
batches_used = all_batches
# Directory containing atac_processed.h5ad and gex_processed.h5ad.
data_path = './data/ATACSeq'
# Arguments forwarded to model.train(): n_epoch, batch_size and the task name.
train_epoch = 50
train_batch = 512
train_task = 'cmp'

# Arguments forwarded to model.finetune().
finetune_epoch = 50
finetune_batch = 5000
finetune_task = 'cmp'

# Arguments forwarded to model.transfer().
transfer_epoch = 10
transfer_batch = 512
transfer_task = 'sgi'

# When True, clean_results() deletes the previous log / model / tensorboard output of a fold
# before that fold is trained again.
clean_previous_results = True
# Passed as `checkpoint=` to train / finetune / transfer
# (presumably a checkpoint interval in epochs — TODO confirm in UnitedNet).
checkpoint = 1

# Presumably the obs column that holds the original batch identifier.
batch_key = 'batch'
# obs column that split_data() fills with BATCH_TRAIN / BATCH_TEST.
modified_batch_key = 'batch_modified'
BATCH_TRAIN = 'train'
BATCH_TEST = 'test'

# The three maps below translate a task name into the schedule suffix that appears in the
# checkpoint directory name, e.g. f'{model_path}/{task}_train_1_{schedule}/best.pt'.
train_task_to_schedule = {
    'cmp': 'translation',
    'sgi': 'classification',
}
finetune_task_to_schedule = {
    'cmp': 'translation(finetune)',
    'sgi': 'classification(finetune)',
}
transfer_task_to_schedule = {
    'cmp': 'translation(transfer)',
    'sgi': 'classification(transfer)',
}

##### Read Data

In [None]:
import anndata as ad
import scanpy as sc

# Load the two preprocessed modalities from data_path; both objects are module-level
# globals that split_data() slices by batch.
adata_atac = sc.read_h5ad(f'{data_path}/atac_processed.h5ad')
adata_gex  = sc.read_h5ad(f'{data_path}/gex_processed.h5ad')
    
def split_data(test_batch):
    """Split both modalities into a training part and a held-out test part.

    Cells whose ``obs[batch_key]`` equals ``test_batch`` form the test split;
    all remaining cells form the training split. Every returned object gets an
    ``obs[modified_batch_key]`` column set to ``BATCH_TRAIN`` or ``BATCH_TEST``.

    Parameters
    ----------
    test_batch
        Batch identifier to hold out.

    Returns
    -------
    tuple of two lists
        ``([atac_train, gex_train], [atac_test, gex_test])``.
    """
    adatas_train, adatas_test = [], []
    for adata in (adata_atac, adata_gex):
        is_test = adata.obs[batch_key] == test_batch

        # Subsetting an AnnData yields a view. Copy it explicitly before writing
        # to .obs instead of relying on the implicit copy-on-write (which emits
        # ImplicitModificationWarning); the global objects are left untouched.
        adata_train = adata[~is_test].copy()
        adata_test = adata[is_test].copy()

        adata_train.obs[modified_batch_key] = BATCH_TRAIN
        adata_test.obs[modified_batch_key] = BATCH_TEST

        adatas_train.append(adata_train)
        adatas_test.append(adata_test)
    return adatas_train, adatas_test

def concat_adatas(adatas_train, adatas_test):
    """Concatenate the train and test AnnData of each modality, pairwise.

    The i-th result is ``ad.concat`` of the i-th training object followed by
    the i-th test object. Pairing uses ``zip``, so surplus items in the longer
    list are ignored.
    """
    merged = []
    for train_part, test_part in zip(adatas_train, adatas_test):
        merged.append(ad.concat([train_part, test_part]))
    return merged

##### Train Model

In [None]:
from src import UnitedNet
import anndata as ad
import os, shutil
import time 

def clean_results(log_path, model_path, tensorboard_path):
    """Delete the output of a previous run so the next run starts clean.

    Each of the three paths may be a regular file or a directory tree; paths
    that do not exist are skipped silently.

    Parameters
    ----------
    log_path, model_path, tensorboard_path : str
        Locations of the log, the saved models and the tensorboard output.
    """
    for stale_path in (log_path, model_path, tensorboard_path):
        if os.path.isdir(stale_path) and not os.path.islink(stale_path):
            # Real directory: remove it together with everything below it.
            shutil.rmtree(stale_path)
        elif os.path.lexists(stale_path):
            # Regular file or symlink: os.remove() is the right call here,
            # shutil.rmtree() would raise on either.
            os.remove(stale_path)


# Leave-one-batch-out pipeline: for every batch in batches_used, hold that batch out,
# then run train -> finetune -> transfer. Elapsed times are printed in minutes
# (time.time() differences divided by 60).
for test_batch in batches_used:
    print('='*20, test_batch)

    start_time = time.time()

    adatas_train, adatas_test = split_data(test_batch)
    data_split_time = time.time()
    print('Split data time', (data_split_time - start_time)/60)

    # Output locations are specific to the held-out batch, so folds do not overwrite each other.
    log_path, model_path, tensorboard_path = f'saved_log_{test_batch}', f'saved_models_{test_batch}', f'saved_tensorboards_{test_batch}'
    if clean_previous_results: clean_results(log_path, model_path, tensorboard_path)
    model = UnitedNet(
        device='cuda:0',
        log_path=log_path,
        model_path=model_path,
        tensorboard_path=tensorboard_path,
        verbose=False,
    )

    # ======================================== Register Data ========================================
    # Only the training split is registered; labels and batch are read from the first
    # AnnData in the list (index 0).
    model.register_anndatas(
        adatas_train, 
        label_index=0, label_key='cell_type', batch_index=0, batch_key=modified_batch_key,
    )
    data_process_time = time.time()
    print('Register data time', (data_process_time - data_split_time)/60)

    # ======================================== Train ========================================
    model.train(
        train_task, n_epoch=train_epoch, batch_size=train_batch, 
        save_best_model=True, checkpoint=checkpoint,
    )
    train_time = time.time()
    print('Train time', (train_time - data_process_time)/60)

    # ======================================== Finetune ========================================
    # Continue from the best checkpoint of the training stage
    # (the `_train_1_` directory name is presumably assigned by UnitedNet — TODO confirm).
    model.load_model(f'{model_path}/{train_task}_train_1_{train_task_to_schedule[train_task]}/best.pt')
    model.finetune(
        finetune_task, n_epoch=finetune_epoch, batch_size=finetune_batch, 
        save_best_model=True, checkpoint=checkpoint,
    )
    finetune_time = time.time()
    print('Finetune time', (finetune_time - train_time)/60)

    # ======================================== Transfer ========================================
    # Continue from the best finetune checkpoint; the held-out split is handed over
    # as the transfer data.
    model.load_model(f'{model_path}/{finetune_task}_finetune_1_{finetune_task_to_schedule[finetune_task]}/best.pt')
    model.transfer(
        transfer_task, n_epoch=transfer_epoch, batch_size=transfer_batch,
        adatas_transfer=adatas_test, 
        label_index_transfer=0, label_key_transfer='cell_type', 
        batch_index_transfer=0, batch_key_transfer=modified_batch_key,
        save_best_model=True, checkpoint=checkpoint,
    )
    transfer_time = time.time()
    print('Transfer time', (transfer_time - finetune_time)/60)

    end_time = time.time()
    print('Total time:', (end_time-start_time)/60, '\n')


##### Evaluate Model

In [None]:
from src import UnitedNet
def evaluate_model(path, test_batch):
    """Evaluate a saved checkpoint on the train, test and combined data of one fold.

    Parameters
    ----------
    path : str
        Checkpoint file handed to ``model.load_model``.
    test_batch
        Held-out batch; forwarded to ``split_data`` to rebuild the fold.

    The function returns nothing; each stage is announced through
    ``model.logger.log_message`` before ``model.evaluate`` is called.
    """
    train_split, test_split = split_data(test_batch)

    net = UnitedNet(
        device='cuda:0',
        log_path=None,
        model_path=None,
        tensorboard_path=None,
    )
    net.load_model(path)

    # Keyword arguments shared by all three evaluate() calls.
    evaluate_kwargs = dict(
        label_index_evaluate=0, label_key_evaluate='cell_type',
        batch_index_evaluate=0, batch_key_evaluate=modified_batch_key,
    )

    # The data of each stage is produced lazily so that the concatenation still
    # happens after the train and test evaluations and after its own log message.
    stages = (
        ('Training Data', lambda: train_split),
        ('Testing Data', lambda: test_split),
        ('All Data', lambda: concat_adatas(train_split, test_split)),
    )
    for title, make_adatas in stages:
        net.logger.log_message(title)
        net.evaluate(make_adatas(), **evaluate_kwargs)

###### Train Model

In [None]:
# Evaluate the best checkpoint of the training stage for every hold-out fold.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()

    # Build the path from this fold's own model directory. The global `model_path`
    # is a leftover of the training loop and names the directory of the LAST trained
    # fold, so using it would evaluate one and the same model for every test batch.
    path = f'saved_models_{test_batch}/{train_task}_train_1_{train_task_to_schedule[train_task]}/best.pt'
    evaluate_model(path, test_batch)

    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)


###### Finetune Model

In [None]:
# Evaluate the best checkpoint of the finetune stage for every hold-out fold.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()

    # Build the path from this fold's own model directory. The global `model_path`
    # is a leftover of the training loop and names the directory of the LAST trained
    # fold, so using it would evaluate one and the same model for every test batch.
    path = f'saved_models_{test_batch}/{finetune_task}_finetune_1_{finetune_task_to_schedule[finetune_task]}/best.pt'
    evaluate_model(path, test_batch)

    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)


###### Transfer Model

In [None]:
# Evaluate the best checkpoint of the transfer stage for every hold-out fold.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()

    # Build the path from this fold's own model directory. The global `model_path`
    # is a leftover of the training loop and names the directory of the LAST trained
    # fold, so using it would evaluate one and the same model for every test batch.
    path = f'saved_models_{test_batch}/{transfer_task}_transfer_3_{transfer_task_to_schedule[transfer_task]}/best.pt'
    evaluate_model(path, test_batch)

    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)


##### Infer Model

In [None]:
import scanpy as sc
from src import UnitedNet
def infer_adatas(path, adatas):
    """Run inference with a saved checkpoint and draw four UMAP plots of the result.

    Parameters
    ----------
    path : str
        Checkpoint file handed to ``model.load_model``.
    adatas : list
        One AnnData per modality; both modalities are declared as provided.
        ``adatas[0].obs`` supplies the annotation columns copied onto the result.

    The plots are coloured, in order, by the train/test flag, the original
    batch, the model's ``predicted_label`` and the ``cell_type`` annotation
    (stored as ``label``). Nothing is returned.

    NOTE(review): ``sc.pl.umap`` is called directly on the object returned by
    ``model.infer``, so that object is expected to already carry UMAP
    coordinates — confirm in UnitedNet. ``modality_sizes`` is hard-coded.
    """
    net = UnitedNet(
        device='cuda:0',
        log_path=None,
        model_path=None,
        tensorboard_path=None,
    )
    net.load_model(path)

    inferred = net.infer(
        adatas,
        modalities_provided=[0, 1],
        batch_index_infer=0, batch_key_infer=modified_batch_key,
        modality_sizes=[13634, 4000]
    )

    # (column on the inferred object, source column in adatas[0].obs);
    # a source of None means the column is already present on the result.
    panels = (
        (modified_batch_key, modified_batch_key),
        ('batch', 'batch'),
        ('predicted_label', None),
        ('label', 'cell_type'),
    )
    for target_column, source_column in panels:
        if source_column is not None:
            inferred.obs[target_column] = list(adatas[0].obs[source_column])
        sc.pl.umap(inferred, color=[target_column])

###### Train Data

In [None]:
from src import UnitedNet
# For every hold-out fold, load that fold's transfer-stage checkpoint and plot the
# inference result for the TRAINING split.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()
    
    adatas_train, adatas_test = split_data(test_batch)

    # The checkpoint directory is specific to the held-out batch
    # (the `_transfer_3_` part is presumably assigned by UnitedNet — TODO confirm).
    path = f'saved_models_{test_batch}/{transfer_task}_transfer_3_{transfer_task_to_schedule[transfer_task]}/best.pt'
    infer_adatas(path, adatas_train)

    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)

###### Test Data

In [None]:
from src import UnitedNet
# For every hold-out fold, load that fold's transfer-stage checkpoint and plot the
# inference result for the held-out TEST split.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()
    
    adatas_train, adatas_test = split_data(test_batch)

    # The checkpoint directory is specific to the held-out batch
    # (the `_transfer_3_` part is presumably assigned by UnitedNet — TODO confirm).
    path = f'saved_models_{test_batch}/{transfer_task}_transfer_3_{transfer_task_to_schedule[transfer_task]}/best.pt'
    infer_adatas(path, adatas_test)

    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)

###### All Data

In [None]:
from src import UnitedNet
# For every hold-out fold, load that fold's transfer-stage checkpoint and plot the
# inference result for the train and test splits concatenated per modality.
for test_batch in batches_used:
    print('='*20, test_batch)
    start_time = time.time()

    adatas_train, adatas_test = split_data(test_batch)
    adatas_all = concat_adatas(adatas_train, adatas_test)

    # The checkpoint directory is specific to the held-out batch
    # (the `_transfer_3_` part is presumably assigned by UnitedNet — TODO confirm).
    path = f'saved_models_{test_batch}/{transfer_task}_transfer_3_{transfer_task_to_schedule[transfer_task]}/best.pt'
    infer_adatas(path, adatas_all)
    
    # Elapsed time in minutes.
    print('Total time', (time.time() - start_time) / 60)