#### Constants

In [1]:
import itertools

data_technique = 'dlpfc'

# Twelve DLPFC slides; each inner list holds the four slides of one donor.
groups = [
        ['151507', '151508', '151509', '151510'],
        ['151669', '151670', '151671', '151672'],
        ['151673', '151674', '151675', '151676']
]

# slide id -> donor name, derived from the position of the slide's group above
# (first group -> 'donor0', second -> 'donor1', ...).
donor = {
    slide: f'donor{donor_idx}'
    for donor_idx, group in enumerate(groups)
    for slide in group
}

# Every slide id, flattened in group order.
all_batches = list(itertools.chain.from_iterable(groups))
# Slides that are held out (one at a time) by split_data during training/transfer.
test_batches = ['151507']

device = 'cuda:0'

data_path = 'data/DLPFC'
batch_key = 'batch'                    # obs column holding the original slide id
modified_batch_key = 'batch_modified'  # obs column holding the train/test marker

#### Parameter Settings

In [2]:
# ---- training schedule ----
train_epoch = 30       # epochs per training stage; also embedded in root_path below
train_batch = 512      # mini-batch size passed to model.train / model.transfer
learning_rate = 0.1
checkpoint = 1         # passed as `checkpoint=`; later cells expect an epoch_{N}.pt for every epoch

# obs column used as the batch covariate when evaluating / inferring
use_batch_key = modified_batch_key

# ---- housekeeping switches ----
# NOTE(review): in the visible cells only `train_model` and `verbose_train` are
# read; the clean_* flags, evaluate_model, infer_model, verbose_evaluation and
# verbose_inference are currently unused (set_model_paths always cleans).
clean_previous_evaluation_results = True
clean_previous_inference_results = True

verbose_train = True
verbose_evaluation = True
verbose_inference = True

train_model = True
evaluate_model = True
infer_model = True

# ---- output locations ----
root_path = f"saved_results/{train_epoch}_epochs"
final_step = 3         # N in the `transfer_{N}_classification` checkpoint folder read by later cells

#### Automatic Setup (derived from the settings above)

##### Generate Parameters

In [3]:
import os

label_key = 'cell_type'              # obs column that stores the ground-truth labels
BATCH_TRAIN, BATCH_TEST = 'train', 'test'   # markers written into obs[modified_batch_key]

# Output layout: logs directly under root_path; models and TensorBoard output in subfolders.
root_log_path = f'{root_path}'
root_model_path = f'{root_path}/models'
root_tensorboard_path = f'{root_path}/tensorboards'
for _output_dir in (root_log_path, root_model_path, root_tensorboard_path):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

def get_train_paths(test_batch):
    """Return ``(log_path, model_path, tensorboard_path)`` for one held-out slide.

    Each entry is the corresponding root directory with ``/{test_batch}`` appended.
    """
    return tuple(
        f'{root}/{test_batch}'
        for root in (root_log_path, root_model_path, root_tensorboard_path)
    )

def print_time(name, start_time):
    """Print the wall-clock minutes elapsed since ``start_time``, labeled with ``name``.

    Args:
        name: label printed before the duration.
        start_time: a ``time.time()`` timestamp taken when the timed step began.
    """
    import time  # local import: this cell must not depend on a later cell's imports

    elapsed_minutes = (time.time() - start_time) / 60
    print(f"{name}: {elapsed_minutes:0.2f} mins")

def clean_results(log_path, model_path, tensorboard_path):
    """Delete artifacts left by a previous run at the given locations.

    Any argument that is None, or that names a path which does not exist, is
    skipped. Existing paths are removed whatever their kind: a real directory is
    deleted recursively, and anything else (regular file, symlink) is unlinked.
    This avoids an OSError when, e.g., the log location is a directory or a
    model location is a plain file.

    Args:
        log_path: log location for the run, or None.
        model_path: checkpoint location for the run, or None.
        tensorboard_path: TensorBoard output location for the run, or None.
    """
    import shutil  # local import: this cell must not depend on a later cell's imports

    for path in (log_path, model_path, tensorboard_path):
        if path is None or not os.path.exists(path):
            continue
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

##### Read Data

In [4]:
import anndata as ad
import numpy as np
import scanpy as sc
import pandas as pd

def generate_adata(data, nonnan_indices, cell_type_label, cols, rows, batch):
    """Build the AnnData for one modality of one slide, keeping only labeled spots.

    Args:
        data: DataFrame with one row per spot (the caller transposes tables as needed).
        nonnan_indices: integer row positions in ``data`` of the spots to keep.
        cell_type_label, cols, rows: per-spot annotations that are already
            restricted to ``nonnan_indices``; same length and order as the kept rows.
        batch: slide id stored in ``obs[batch_key]``.

    Returns:
        AnnData whose X holds the kept rows of ``data``, with obs columns
        ``label_key``, ``'imagecol'``, ``'imagerow'`` and ``batch_key``.
    """
    # Select by position. Label-based `data.loc[data.index[positions]]` would return
    # extra rows if the index contained repeated labels; with a unique index both
    # forms select the same rows.
    data = data.iloc[nonnan_indices]
    # NOTE(review): `obs` receives a plain list of spot ids instead of a DataFrame.
    # This is left as-is because downstream code may rely on the resulting obs
    # layout; confirm how the installed anndata version interprets it.
    adata = ad.AnnData(X=np.array(data), obs=list(data.index))
    adata.obs[label_key] = cell_type_label
    adata.obs['imagecol'] = cols
    adata.obs['imagerow'] = rows
    adata.obs[batch_key] = batch
    return adata

# Load every slide. dlpfc_adatas[batch] becomes [mRNA, morphology, mRNA-niche]
# AnnData objects, all restricted to spots that have a ground-truth label.
dlpfc_adatas = dict()
for batch in all_batches:
    # Spot positions table; header=None, so its columns are addressed by integer label
    # (column 0 becomes the index and is used to look up spots by rna_data's index).
    data_batch      = pd.read_csv(f"{data_path}/{batch}/tissue_positions_list.csv", header=None, index_col=0)
    rna_data        = pd.read_csv(f'{data_path}/{batch}/DLPFC_spatial_simulation_mrna_marker.csv',index_col=0)
    # Morphology and niche tables are transposed so rows become spots
    # (assumes they are stored features x spots on disk — confirm).
    morph_data      = pd.read_csv(f'{data_path}/{batch}/DLPFC_spatial_simulation_morph.csv',index_col=0).T
    mrna_niche_data = pd.read_csv(f'{data_path}/{batch}/DLPFC_spatial_simulation_niche_mrna_mp.csv',index_col=0).T
    # Labels as strings, so a missing label is expected to show up as the string 'nan'.
    cell_type_label = np.array(pd.read_csv(f"{data_path}/{batch}/cluster_labels_{batch}.csv",index_col=0).astype(str)['ground_truth'])

    # Row positions of labeled spots.
    # NOTE(review): these positions are applied to the label array and to all three
    # data tables, which assumes every file lists spots in the same order — verify.
    nonnan_indices = np.where(cell_type_label != 'nan')[0].astype(int)

    # Columns 4 and 5 of the positions table for rna_data's spots; presumably pixel
    # row / pixel column as in the standard Visium tissue_positions_list layout — confirm.
    rows = np.array(data_batch.loc[list(rna_data.index),4])
    cols = np.array(data_batch.loc[list(rna_data.index),5])

    # Restrict per-spot annotations to labeled spots; generate_adata applies the same
    # positions to each data table, so annotations and rows stay aligned.
    cell_type_label = cell_type_label[nonnan_indices]
    rows = rows[nonnan_indices]
    cols = cols[nonnan_indices]

    adata_rna = generate_adata(rna_data, nonnan_indices, cell_type_label, cols, rows, batch)
    adata_morph = generate_adata(morph_data, nonnan_indices, cell_type_label, cols, rows, batch)
    adata_mrna_niche = generate_adata(mrna_niche_data, nonnan_indices, cell_type_label, cols, rows, batch)

    dlpfc_adatas[batch] = [adata_rna, adata_morph, adata_mrna_niche]
    
# Stack each modality across all slides; positions 0/1/2 of dlpfc_adatas[batch]
# hold the mRNA, morphology and mRNA-niche AnnData respectively.
adata_rna_all, adata_morph_all, adata_mrna_niche_all = [
    ad.concat([dlpfc_adatas[slide][modality] for slide in all_batches])
    for modality in range(3)
]

# Only the mRNA modality is preprocessed: total-count normalization, log1p, then
# scaling to zero mean / unit variance with values clipped at max_value=4.
# NOTE(review): these statistics are computed over all slides, including the one
# split_data later holds out as the test slide — confirm this is intended.
sc.pp.normalize_total(adata_rna_all)
sc.pp.log1p(adata_rna_all)
sc.pp.scale(adata_rna_all, max_value=4)

def split_data(test_batch):
    """Split every modality into the training slides and the held-out test slide.

    Args:
        test_batch: slide id (a value of ``obs[batch_key]``) to hold out.

    Returns:
        ``(adatas_train, adatas_test)``: two lists ordered
        [mRNA, morphology, mRNA niche]. Each AnnData is an independent copy whose
        ``obs[modified_batch_key]`` is set to ``BATCH_TRAIN`` or ``BATCH_TEST``.
    """
    adatas_train, adatas_test = [], []
    for adata_all in (adata_rna_all, adata_morph_all, adata_mrna_niche_all):
        slide_ids = adata_all.obs[batch_key]
        # .copy() materializes each subset up front. Writing .obs on a view would
        # make anndata convert the view implicitly, with an ImplicitModificationWarning.
        adata_train = adata_all[slide_ids != test_batch].copy()
        adata_test = adata_all[slide_ids == test_batch].copy()
        adata_train.obs[modified_batch_key] = BATCH_TRAIN
        adata_test.obs[modified_batch_key] = BATCH_TEST
        adatas_train.append(adata_train)
        adatas_test.append(adata_test)
    return adatas_train, adatas_test

def concat_adatas(adatas_train, adatas_test):
    """Concatenate train and test AnnData objects modality by modality.

    Returns a list holding ``ad.concat([train, test])`` for each pair, in input
    order; pairing follows ``zip`` semantics, so extra items in the longer list are ignored.
    """
    merged = []
    for train_part, test_part in zip(adatas_train, adatas_test):
        merged.append(ad.concat([train_part, test_part]))
    return merged

  
  utils.warn_names_duplicates("obs")


#### Train Model

In [6]:
from src import UnitedNet
import anndata as ad
import os, shutil
import time 

def set_model_paths(model, log_path, model_path, tensorboard_path):
    """Remove previous artifacts at the given locations, then point ``model`` at them.

    The locations are cleaned unconditionally before being registered on the model
    as its log, checkpoint and TensorBoard destinations.
    """
    clean_results(log_path, model_path, tensorboard_path)
    for setter, location in (
        (model.set_log_path, log_path),
        (model.set_model_path, model_path),
        (model.set_tensorboard_path, tensorboard_path),
    ):
        setter(location)

# Leave-one-slide-out loop: for each held-out slide, register the remaining slides,
# train on them, evaluate on the held-out slide, then transfer the model to it.
# pretransfer_metrics[test_batch] keeps what model.evaluate returns before transfer.
pretransfer_metrics = {}
if train_model:
    for test_batch in test_batches:
        print('='*20, test_batch)

        start_time = time.time()

        adatas_train, adatas_test = split_data(test_batch)
        print_time('Split data', start_time)

        model = UnitedNet(device=device)

        # ======================================== Register Data ========================================
        checkpoint_time = time.time()
        model.register_anndatas(
            adatas_train,
            label_index=0, label_key=label_key,
            technique=data_technique,
        )
        print_time('Register data', checkpoint_time)

        # ======================================== Train, then Transfer ========================================
        checkpoint_time = time.time()
        set_model_paths(model, *get_train_paths(test_batch))
        model.set_verbose(verbose_train)
        model.train(
            'supervised_group_identification', n_epoch=train_epoch, learning_rate=learning_rate, batch_size=train_batch,
            save_best_model=True, checkpoint=checkpoint,
        )
        # Score the trained (not yet transferred) model on the held-out slide.
        pretransfer_metrics[test_batch] = model.evaluate(
            adatas_test, label_index_evaluate=0, label_key_evaluate=label_key,
        )
        model.transfer(
            'supervised_group_identification', n_epoch=train_epoch, learning_rate=learning_rate, batch_size=train_batch,
            adatas_transfer=adatas_test,
            save_best_model=True, checkpoint=checkpoint,
        )
        print_time('Train and transfer', checkpoint_time)

        print_time(f'Total for batch {test_batch}', start_time)





Split data: 0.01 mins
Register data: 0.12 mins

TRAIN

    (Epoch 1 / 50)
    Losses
        contrastive: 5.871459484100342
        discriminator: 0.48772650957107544
        generator: 0.013832403346896172
        reconstruction: 5.88614559173584
        translation: 6.483118534088135
    Losses
        contrastive: 5.183070182800293
        discriminator: 0.49928200244903564
        generator: 0.0006626107497140765
        reconstruction: 4.028621196746826
        translation: 5.511995792388916
            Saving model to saved_results/50_epochs/models/151507/train_0_translation/best.pt
            Saving model to saved_results/50_epochs/models/151507/train_0_translation/epoch_1.pt
    Losses
        cross_entropy: 3.6536924839019775
    Losses
        cross_entropy: 0.38836532831192017
            Saving model to saved_results/50_epochs/models/151507/train_1_classification/best.pt
            Saving model to saved_results/50_epochs/models/151507/train_1_classification/epoch_1.pt
   

KeyboardInterrupt: 

In [None]:
import scanpy as sc
from src import UnitedNet
def evaluate_adatas(path, adatas):
    """Load the checkpoint at ``path`` and return its ARI on ``adatas``.

    A fresh, non-verbose UnitedNet with no log/model/TensorBoard output is built
    on ``device``. Labels come from ``adatas[0].obs[label_key]`` and the batch
    covariate from ``use_batch_key``. The ``'ari'`` entry of the evaluation
    result is returned.
    """
    evaluator = UnitedNet(
        device=device,
        log_path=None,
        model_path=None,
        tensorboard_path=None,
        verbose=False,
    )
    evaluator.load_model(path)

    metrics = evaluator.evaluate(
        adatas,
        label_index_evaluate=0, label_key_evaluate=label_key,
        batch_index_evaluate=0, batch_key_evaluate=use_batch_key,
    )
    return metrics['ari']

In [None]:
# Score the transferred classifier checkpoint saved at every epoch on the held-out slide.
# test_batch is set explicitly to the last held-out slide (the value the training
# loop leaves behind), so this cell also works on a fresh kernel without training.
test_batch = test_batches[-1]
final_model_path = f'{root_path}/models/{test_batch}/transfer_{final_step}_classification'
aris = []
for epoch in range(1, train_epoch+1):
    # Re-split every epoch so each checkpoint is scored on freshly built AnnData
    # objects, unaffected by anything a previous evaluation may have written.
    adatas_train, adatas_test = split_data(test_batch)

    model_path = f'{final_model_path}/epoch_{epoch}.pt'
    aris.append(evaluate_adatas(model_path, adatas_test))

for epoch, ari in enumerate(aris, start=1):
    print('epoch', epoch, ari)

#### Infer Model

##### Shared Inference Helper

In [None]:
import scanpy as sc
from src import UnitedNet
def infer_adatas(path, adatas, eval_only=True):
    """Load a checkpoint, evaluate it on ``adatas`` and optionally plot UMAPs of its inference.

    Args:
        path: checkpoint file passed to ``UnitedNet.load_model``.
        adatas: per-modality AnnData list; ``adatas[0]`` supplies labels, the batch
            covariate and the obs annotations copied onto the inferred AnnData.
        eval_only: if True, only evaluate. If False, also run ``model.infer`` and
            draw one UMAP each, colored by ``modified_batch_key`` (train/test),
            ``batch_key`` (slide id), ``'predicted_label'`` and ``label_key``.

    Returns:
        Whatever ``model.evaluate`` returns for ``adatas``.
    """
    model = UnitedNet(
        device=device,
        log_path=None,
        model_path=None,
        tensorboard_path=None,
        verbose=True,
    )
    model.load_model(path)

    metrics = model.evaluate(
        adatas,
        label_index_evaluate=0, label_key_evaluate=label_key,
        batch_index_evaluate=0, batch_key_evaluate=use_batch_key,
    )

    if not eval_only:
        adata_inferred = model.infer(
            adatas,
            modalities_provided=list(range(len(adatas))),
            batch_index_infer=0, batch_key_infer=use_batch_key,
            modality_sizes=[adata.shape[1] for adata in adatas],
        )

        # Annotations are copied from the first modality by position (assumes
        # model.infer keeps the input cell order — confirm). 'predicted_label' is
        # not copied: it is expected to be provided on adata_inferred by model.infer.
        # sc.pl.umap also presumes model.infer produced a UMAP embedding — confirm.
        for color_key in (modified_batch_key, batch_key, 'predicted_label', label_key):
            if color_key != 'predicted_label':
                adata_inferred.obs[color_key] = list(adatas[0].obs[color_key])
            sc.pl.umap(adata_inferred, color=[color_key])

    return metrics

##### Inference with the Transferred Model (every epoch, then the final epoch)

In [None]:
# For every epoch's transferred classifier checkpoint: evaluate on the held-out slide,
# then evaluate on train + test slides together and plot their UMAPs.
# test_batch is set explicitly to the last held-out slide (the value the training
# loop leaves behind), so this cell also works on a fresh kernel without training.
test_batch = test_batches[-1]
final_model_path = f'{root_path}/models/{test_batch}/transfer_{final_step}_classification'
for epoch in range(1, train_epoch+1):
    print('='*20, 'epoch', epoch)

    # Re-split every epoch so each checkpoint sees freshly built AnnData objects.
    adatas_train, adatas_test = split_data(test_batch)
    adatas_all = concat_adatas(adatas_train, adatas_test)

    model_path = f'{final_model_path}/epoch_{epoch}.pt'
    infer_adatas(model_path, adatas_test)
    infer_adatas(model_path, adatas_all, eval_only=False)
    

In [None]:
# Same as the per-epoch cell, restricted to the last epoch (epoch == train_epoch).
# test_batch is set explicitly to the last held-out slide (the value the training
# loop leaves behind), so this cell also works on a fresh kernel without training.
test_batch = test_batches[-1]
final_model_path = f'{root_path}/models/{test_batch}/transfer_{final_step}_classification'
epoch = train_epoch
print('='*20, 'epoch', epoch)

adatas_train, adatas_test = split_data(test_batch)
adatas_all = concat_adatas(adatas_train, adatas_test)

model_path = f'{final_model_path}/epoch_{epoch}.pt'
# Results are bound to names so the cell does not echo them as output.
final_test_metrics = infer_adatas(model_path, adatas_test)
final_all_metrics = infer_adatas(model_path, adatas_all, eval_only=False)
    