In [None]:
on_colab: bool = False  # True when running in Google Colab; gates the Drive-mount and install cells below

##### Colab Specifics

In [None]:
if on_colab:
    # Mount Google Drive at /UnitedNet, then switch into the project folder on the drive
    # so later relative paths (data_path, saved_log/, saved_models/, ...) resolve there.
    from google.colab import drive
    drive.mount('/UnitedNet')
    %cd /UnitedNet/MyDrive/UnitedNet/

In [None]:
if on_colab:
    !pip install pydantic scanpy ipython-autotime
    %load_ext autotime

##### Tunable Parameters

In [None]:
# Held-out batches: every train/evaluate/infer loop below iterates this list,
# so leaving it empty turns those cells into no-ops. Available ids, by prefix:
#   s1: 's1d1', 's1d2', 's1d3'
#   s2: 's2d1', 's2d4', 's2d5'
#   s3: 's3d3', 's3d6', 's3d7', 's3d10'
#   s4: 's4d1', 's4d8', 's4d9'
test_batches = []

device = 'cuda:0'
data_path = './data/ATACSeq'

# Per stage: (epochs, mini-batch size, task key). Task keys are mapped to schedule
# names by the *_task_to_schedule dicts in the next cell.
train_epoch, train_batch, train_task = 50, 512, 'cmp'
finetune_epoch, finetune_batch, finetune_task = 50, 5000, 'cmp'
transfer_epoch, transfer_batch, transfer_task = 10, 512, 'sgi'

# Delete a batch's previous logs / checkpoints / tensorboard dirs before re-running it.
clean_previous_results = True
checkpoint = 1

verbose_training, verbose_evaluation, verbose_inference = False, True, False

# Stage switches.
train_model, finetune_model, transfer_model = True, True, True
evaluate_model = True
infer_model = True

##### Auto-Generated Parameters

In [None]:
import os

# obs column holding the original batch label, and the column split_data adds to tag
# every cell as part of the training split or the held-out split.
batch_key = 'batch'
modified_batch_key = 'batch_modified'
BATCH_TRAIN, BATCH_TEST = 'train', 'test'

# Task key -> schedule name for each stage; schedule names appear in checkpoint paths.
train_task_to_schedule = {
    'cmp': 'translation',
    'sgi': 'classification',
}
finetune_task_to_schedule = {
    'cmp': 'translation(finetune)',
    'sgi': 'classification(finetune)',
}
transfer_task_to_schedule = {
    'cmp': 'translation(transfer)',
    'sgi': 'classification(transfer)',
}

# Output roots (relative to the working directory); created up front if missing.
root_log_path = 'saved_log'
root_model_path = 'saved_models'
root_tensorboard_path = 'saved_tensorboards'
for _output_root in (root_log_path, root_model_path, root_tensorboard_path):
    os.makedirs(_output_root, exist_ok=True)
del _output_root

def get_training_paths(test_batch):
    """Return the (log, model, tensorboard) output paths for the run that holds out `test_batch`."""
    return tuple(
        f'{root}/{test_batch}'
        for root in (root_log_path, root_model_path, root_tensorboard_path)
    )

def get_evaluation_paths(test_batch):
    """Return (log, model, tensorboard) paths for evaluating `test_batch`; only the log path is set."""
    eval_log_path = f'{root_log_path}/{test_batch}_eval'
    return eval_log_path, None, None

def get_inference_paths(test_batch):
    """Return (log, model, tensorboard) paths for inference: none are used, so all are None.

    `test_batch` is accepted only for signature parity with the other *_paths helpers.
    """
    return (None,) * 3

def train_model_path(test_batch):
    """Expected path of the best checkpoint saved by the training stage for held-out `test_batch`.

    The run-directory layout presumably mirrors how UnitedNet names its saved runs — confirm.
    """
    schedule = train_task_to_schedule[train_task]
    run_dir = f'{train_task}_train_1_{schedule}'
    return f'{root_model_path}/{test_batch}/{run_dir}/best.pt'

def finetune_model_path(test_batch):
    """Expected path of the best checkpoint saved by the finetune stage for held-out `test_batch`.

    The run-directory layout presumably mirrors how UnitedNet names its saved runs — confirm.
    """
    schedule = finetune_task_to_schedule[finetune_task]
    run_dir = f'{finetune_task}_finetune_1_{schedule}'
    return f'{root_model_path}/{test_batch}/{run_dir}/best.pt'

def transfer_model_path(test_batch):
    """Expected path of the best checkpoint saved by the transfer stage for held-out `test_batch`.

    Note the `_3_` segment (train/finetune use `_1_`); presumably it reflects UnitedNet's
    numbering for the transfer schedule — confirm before changing.
    """
    schedule = transfer_task_to_schedule[transfer_task]
    run_dir = f'{transfer_task}_transfer_3_{schedule}'
    return f'{root_model_path}/{test_batch}/{run_dir}/best.pt'

def print_time(name, start_time):
    """Print the minutes elapsed since `start_time` (a time.time() timestamp), labelled `name`.

    Output format: "<name>: <minutes with 2 decimals> mins".
    """
    # Imported locally: this cell does not import `time`, and the notebook's module-level
    # `import time` only happens in a later cell, so the helper must not rely on it.
    import time
    elapsed_minutes = (time.time() - start_time) / 60
    print(f"{name}: {elapsed_minutes:0.2f} mins")

##### Read Data

In [None]:
import anndata as ad
import scanpy as sc

# Load the processed ATAC and GEX AnnData files once; split_data slices them per held-out batch.
adata_atac, adata_gex = (
    sc.read_h5ad(f'{data_path}/atac_processed.h5ad'),
    sc.read_h5ad(f'{data_path}/gex_processed.h5ad'),
)
    
def split_data(test_batch):
    """Split the global ATAC and GEX AnnData objects into training and held-out parts.

    Cells whose `batch_key` obs value equals `test_batch` form the held-out (test) part;
    all other cells form the training part. Every returned object is an independent copy
    carrying a `modified_batch_key` obs column set to BATCH_TRAIN or BATCH_TEST.

    Returns:
        ([atac_train, gex_train], [atac_test, gex_test])
    """
    def _split_one(adata):
        # Mask computed once; `~is_test` selects exactly the cells the `!=` test selected.
        is_test = adata.obs[batch_key] == test_batch
        # Explicit .copy(): assigning to .obs of an AnnData *view* makes anndata emit an
        # ImplicitModificationWarning and silently convert the view into a copy anyway.
        part_train = adata[~is_test].copy()
        part_test = adata[is_test].copy()
        part_train.obs[modified_batch_key] = BATCH_TRAIN
        part_test.obs[modified_batch_key] = BATCH_TEST
        return part_train, part_test

    adata_atac_train, adata_atac_test = _split_one(adata_atac)
    adata_gex_train, adata_gex_test = _split_one(adata_gex)
    return [adata_atac_train, adata_gex_train], [adata_atac_test, adata_gex_test]

def concat_adatas(adatas_train, adatas_test):
    """Concatenate train and test AnnData objects modality by modality.

    Entry i of the result is ad.concat([adatas_train[i], adatas_test[i]]); as with zip,
    surplus entries of the longer list are ignored.
    """
    combined = []
    for train_part, test_part in zip(adatas_train, adatas_test):
        combined.append(ad.concat([train_part, test_part]))
    return combined

##### Train, Finetune and then Transfer Model

In [None]:
from src import UnitedNet
import anndata as ad
import os, shutil
import time 

def clean_results(log_path, model_path, tensorboard_path):
    """Delete artifacts left by a previous run so the next run starts clean.

    Any argument may be None: get_evaluation_paths and get_inference_paths return None
    for unused paths, and os.path.exists(None) raises TypeError. Such entries are
    skipped. An existing path is removed whether it is a file or a directory tree.
    """
    for path in (log_path, model_path, tensorboard_path):
        if path is None or not os.path.exists(path):
            continue
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

for test_batch in test_batches:
    # Leave-one-batch-out: the remaining batches are used for train/finetune; the
    # held-out cells (adatas_test) are only passed to the transfer stage below.
    print('='*20, test_batch)

    start_time = time.time()

    adatas_train, adatas_test = split_data(test_batch)
    print_time('Split data', start_time)

    log_path, model_path, tensorboard_path = get_training_paths(test_batch)
    # NOTE: this removes the whole model_path tree, including the train/finetune checkpoints
    # that the finetune/transfer stages load below. If a stage is switched off but a later
    # one is on, set clean_previous_results = False or the load will fail.
    if clean_previous_results:
        clean_results(log_path, model_path, tensorboard_path)
    model = UnitedNet(
        device=device,
        log_path=log_path,
        model_path=model_path,
        tensorboard_path=tensorboard_path,
        verbose=verbose_training,
    )

    # ======================================== Register Data ========================================
    checkpoint_time = time.time()
    model.register_anndatas(
        adatas_train,
        label_index=0, label_key='cell_type', batch_index=0, batch_key=modified_batch_key,
    )
    print_time('Register data', checkpoint_time)

    # ======================================== Train ========================================
    if train_model:
        checkpoint_time = time.time()
        model.train(
            train_task, n_epoch=train_epoch, batch_size=train_batch,
            save_best_model=True, checkpoint=checkpoint,
        )
        print_time('Train model', checkpoint_time)

    # ======================================== Finetune ========================================
    if finetune_model:
        checkpoint_time = time.time()
        # Start from the best checkpoint of the training stage.
        model.load_model(train_model_path(test_batch))
        model.finetune(
            finetune_task, n_epoch=finetune_epoch, batch_size=finetune_batch,
            save_best_model=True, checkpoint=checkpoint,
        )
        print_time('Finetune model', checkpoint_time)

    # ======================================== Transfer ========================================
    if transfer_model:
        checkpoint_time = time.time()
        # Start from the best checkpoint of the finetune stage.
        model.load_model(finetune_model_path(test_batch))
        model.transfer(
            transfer_task, n_epoch=transfer_epoch, batch_size=transfer_batch,
            adatas_transfer=adatas_test,
            label_index_transfer=0, label_key_transfer='cell_type',
            batch_index_transfer=0, batch_key_transfer=modified_batch_key,
            save_best_model=True, checkpoint=checkpoint,
        )
        print_time('Transfer model', checkpoint_time)

    # Total wall time for this batch: measured from start_time, not from the last stage's checkpoint.
    print_time(f'Total for batch {test_batch}', start_time)

#### Evaluate Model

##### Evaluate Commons

In [None]:
from src import UnitedNet
def evaluate_batch(path, test_batch):
    """Load the checkpoint at `path` and evaluate it on the train, test and combined splits.

    The splits come from split_data(test_batch). Each evaluation is preceded by a log
    message ('Training Data', 'Testing Data', 'All Data') written through the model's logger.
    """
    adatas_train, adatas_test = split_data(test_batch)

    eval_log, eval_models, eval_tensorboard = get_evaluation_paths(test_batch)
    if clean_previous_results:
        clean_results(eval_log, eval_models, eval_tensorboard)

    model = UnitedNet(
        device=device,
        log_path=eval_log,
        model_path=eval_models,
        tensorboard_path=eval_tensorboard,
        verbose=verbose_evaluation,
    )
    model.load_model(path)

    def _evaluate_split(title, adatas):
        # Same label/batch conventions that the training loop registers the data with.
        model.logger.log_message(title)
        model.evaluate(
            adatas,
            label_index_evaluate=0, label_key_evaluate='cell_type',
            batch_index_evaluate=0, batch_key_evaluate=modified_batch_key,
        )

    _evaluate_split('Training Data', adatas_train)
    _evaluate_split('Testing Data', adatas_test)
    _evaluate_split('All Data', concat_adatas(adatas_train, adatas_test))

##### Evaluate Best from Train

In [None]:
if evaluate_model:
    # Best checkpoint of the training stage, evaluated for every held-out batch.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        evaluate_batch(train_model_path(test_batch), test_batch)
        print_time('Evaluate train model', start_time)

##### Evaluate Best from Finetune

In [None]:
if evaluate_model:
    # Best checkpoint of the finetune stage, evaluated for every held-out batch.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        evaluate_batch(finetune_model_path(test_batch), test_batch)
        print_time('Evaluate finetune model', start_time)

##### Evaluate Best from Transfer

In [None]:
if evaluate_model:
    # Best checkpoint of the transfer stage, evaluated for every held-out batch.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        evaluate_batch(transfer_model_path(test_batch), test_batch)
        print_time('Evaluate transfer model', start_time)

#### Infer with Final Model (Best from Transfer)

##### Infer Commons

In [None]:
import scanpy as sc
from src import UnitedNet
def infer_adatas(path, adatas, test_batch=None, modality_sizes=None):
    """Load the checkpoint at `path`, run inference on `adatas`, and plot UMAPs of the result.

    Args:
        path: checkpoint file loaded into a freshly constructed UnitedNet.
        adatas: AnnData list in split_data's order ([atac, gex]); the obs of adatas[0]
            supplies the annotations used for colouring.
        test_batch: forwarded to get_inference_paths. The previous body read the
            notebook-global `test_batch` implicitly; get_inference_paths ignores its
            argument, so the None default changes nothing.
        modality_sizes: forwarded to model.infer; defaults to [13634, 4000].

    Shows four UMAPs of the inferred AnnData, coloured by train/test tag, original batch,
    predicted label and annotated cell type. Returns None.
    """
    if modality_sizes is None:
        # NOTE(review): values carried over unchanged from the original call; presumably the
        # per-modality feature counts of `adatas` — confirm against the processed data.
        modality_sizes = [13634, 4000]

    paths = get_inference_paths(test_batch)
    # Nothing to clean when every path is None; skipping also avoids passing None to os.path calls.
    if clean_previous_results and any(p is not None for p in paths):
        clean_results(*paths)
    log_path, model_path, tensorboard_path = paths
    model = UnitedNet(
        device=device,
        log_path=log_path,
        model_path=model_path,
        tensorboard_path=tensorboard_path,
        verbose=verbose_inference,
    )
    model.load_model(path)

    adata_inferred = model.infer(
        adatas,
        modalities_provided=[0, 1],
        batch_index_infer=0, batch_key_infer=modified_batch_key,
        modality_sizes=modality_sizes,
    )

    # Annotations are copied positionally from adatas[0]; assumes model.infer returns one
    # row per cell in the same order — TODO confirm.
    adata_inferred.obs[modified_batch_key] = list(adatas[0].obs[modified_batch_key])
    sc.pl.umap(adata_inferred, color=[modified_batch_key])

    adata_inferred.obs[batch_key] = list(adatas[0].obs[batch_key])
    sc.pl.umap(adata_inferred, color=[batch_key])

    # presumably 'predicted_label' is populated by model.infer
    sc.pl.umap(adata_inferred, color=['predicted_label'])

    adata_inferred.obs['label'] = list(adatas[0].obs['cell_type'])
    sc.pl.umap(adata_inferred, color=['label'])

##### Infer on the final model with train data

In [None]:
if infer_model:
    # Final (transfer) checkpoint applied to the training split of each held-out batch.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        adatas_train, adatas_test = split_data(test_batch)
        infer_adatas(transfer_model_path(test_batch), adatas_train)
        print_time('Infer final model on train data', start_time)

##### Infer on the final model with test data

In [None]:
if infer_model:
    # Final (transfer) checkpoint applied to the held-out split of each batch.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        adatas_train, adatas_test = split_data(test_batch)
        infer_adatas(transfer_model_path(test_batch), adatas_test)
        print_time('Infer final model on test data', start_time)

##### Infer on the final model with all data

In [None]:
if infer_model:
    # Final (transfer) checkpoint applied to train + held-out cells combined.
    for test_batch in test_batches:
        start_time = time.time()
        print('=' * 20, test_batch)
        adatas_train, adatas_test = split_data(test_batch)
        adatas_all = concat_adatas(adatas_train, adatas_test)
        infer_adatas(transfer_model_path(test_batch), adatas_all)
        print_time('Infer final model on all data', start_time)