In [1]:
# Scratch assignment; `x` is not referenced by any other cell visible in this notebook section.
x = 1

In [2]:
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd

import os.path

import multigrate as mtg
import multimil as mtm
import scvi
import torch

from sklearn.model_selection import LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

from matplotlib import pyplot as plt
import seaborn as sns

import warnings
# Install a blanket 'ignore' filter: no category or module is given, so every
# warning raised after this point (deprecations included) is hidden from the output.
warnings.filterwarnings('ignore')

[rank: 0] Global seed set to 0


In [3]:
# Input .h5ad files read later with `sc.read_h5ad`: `input1` is the "_rna" file,
# `input2` the "_adt" file of the same dataset.
input1 = '/lustre/groups/ml01/projects/2022_multigrate_anastasia.litinetskaya/multimil_reproducibility/pipeline/data/pp/pbmc_3_cond_balanced_rna.h5ad'
input2 = '/lustre/groups/ml01/projects/2022_multigrate_anastasia.litinetskaya/multimil_reproducibility/pipeline/data/pp/pbmc_3_cond_balanced_adt.h5ad'

# Names of the `.obs` columns used by the analysis.
sample_key = 'patient_id'
condition_key = 'Status_on_day_collection_summary'
batch_key = 'Site'
label_key = 'initial_clustering'

# Short aliases used by the model-evaluation loop.
donor, condition = sample_key, condition_key

# Number of splits to iterate over (`.obs` columns "split0" ... "split{n_splits - 1}").
n_splits = 5

In [4]:
# Read both input files into memory, in order: `input1` -> `adata1`, `input2` -> `adata2`.
adata1, adata2 = (sc.read_h5ad(path) for path in (input1, input2))

In [5]:
# Keyword arguments forwarded verbatim to `MultiVAE_MIL.setup_anndata`.
setup_params = dict(
    rna_indices_end=2000,
    categorical_covariate_keys=[
        'patient_id',
        'Site',
        'Status_on_day_collection_summary',
    ],
)

In [6]:
# Keyword arguments forwarded verbatim to the `MultiVAE_MIL` constructor.
# `class_loss_coef` is only a placeholder here: the evaluation loop overwrites
# it with the coefficient of the run being evaluated.
model_params = dict(
    z_dim=16,
    attn_dim=16,
    class_loss_coef=None,
    cond_dim=16,
)

In [7]:
# Scalar settings shared by all runs. `batch_size` is passed to `get_model_output`
# and `kl` is used as the 'kl' loss coefficient when the model is built.
lr, batch_size, kl = 1e-4, 256, 1e-6

# Fix the scvi-tools global seed.
seed = 0
scvi.settings.seed = seed

[rank: 0] Global seed set to 0


In [8]:
# One entry per run; `hashes[k]` is paired positionally with `coefs[k]`.
hashes = [
    '52e3c74810',
    'dfbda3bb3e',
    '478fe63f05',
    '44ddcb1ef8',
]
coefs = [
    0.1,
    1.0,
    10,
    100,
]

# One entry per split; position `i` belongs to split `i`.
# `ckpts` names the training checkpoint, `q_ckpts` the query checkpoint.
ckpts = [
    'epoch=49-step=36000',
    'epoch=49-step=37350',
    'epoch=49-step=34850',
    'epoch=49-step=33750',
    'epoch=49-step=35650',
]
q_ckpts = [
    'epoch=19-step=3440',
    'epoch=19-step=2840',
    'epoch=19-step=3900',
    'epoch=19-step=4340',
    'epoch=19-step=3540',
]

In [9]:
# Evaluate every (hash, classification-loss coefficient) run on every split.
#
# For each split the loop
#   1. rebuilds the multimodal AnnData and divides it into reference ("train")
#      and query ("val") cells using the `.obs["split{i}"]` column,
#   2. restores the reference model from its training checkpoint,
#   3. builds the query model, restores the query checkpoint on top of the
#      reference weights, and computes model outputs for the query cells,
#   4. embeds reference + query jointly (neighbors + UMAP on the 'latent' representation),
#   5. stores the per-split results on `adata1`.
# After all splits of a run, the per-split cell attention is averaged and
# `adata1` is written to disk under the run's hash.

# Directory that contains one sub-directory per run hash.
results_root = '/lustre/groups/ml01/projects/2022_multigrate_anastasia.litinetskaya/multimil_reproducibility/pipeline/data/multigrate/pbmc_3_cond_balanced_end2end'

for h, coef in zip(hashes, coefs):
    print(f'Hash {h}, coef {coef}...')

    model_params['class_loss_coef'] = coef
    for i, best_ckpt, best_q_ckpt in zip(range(n_splits), ckpts, q_ckpts):
        print(f'Split {i}...')

        print('Organizing multimodal anndatas...')
        adata = mtg.data.organize_multiome_anndatas(
            adatas=[[adata1], [adata2]],
        )
        losses = ['nb', 'mse']

        query = adata[adata.obs[f"split{i}"] == "val"].copy()
        adata = adata[adata.obs[f"split{i}"] == "train"].copy()

        # Order reference cells by donor.
        # NOTE(review): presumably required so that cells of one donor are contiguous — confirm.
        idx = adata.obs[donor].sort_values().index
        adata = adata[idx].copy()

        print('Setting up anndata...')
        mtm.model.MultiVAE_MIL.setup_anndata(
            adata,
            **setup_params
        )

        print('Initializing the model...')

        mil = mtm.model.MultiVAE_MIL(
            adata,
            patient_label=donor,
            losses=losses,
            loss_coefs={
                'kl': kl,
            },
            classification=[condition],
            **model_params,
        )

        path_to_train_checkpoints = f'{results_root}/{h}/{i}/checkpoints/'

        # Checkpoint keys carry a 'module.' component that `mil.module` does not
        # expect; strip it before loading.
        train_checkpoint = torch.load(path_to_train_checkpoints + f'{best_ckpt}.ckpt')
        train_state_dict = {
            key.replace('module.', ''): value
            for key, value in train_checkpoint['state_dict'].items()
        }

        mil.module.load_state_dict(train_state_dict)

        mil.is_trained_ = True
        mil.get_model_output(adata, batch_size=batch_size)

        # Same donor ordering for the query cells.
        idx = query.obs[donor].sort_values().index
        query = query[idx].copy()

        new_model = mtm.model.MultiVAE_MIL.load_query_data(query, use_prediction_labels=False, reference_model=mil)

        path_to_query_checkpoints = f'{results_root}/{h}/{i}/query_checkpoints/{best_ckpt}/'

        # Query checkpoint keys are re-homed under the 'vae.' prefix after the
        # 'module.' component is stripped.
        query_checkpoint = torch.load(path_to_query_checkpoints + f'{best_q_ckpt}.ckpt')
        query_state_dict = {
            f"vae.{key.replace('module.', '')}": value
            for key, value in query_checkpoint['state_dict'].items()
        }

        # FIX: the merged weights were previously computed but never loaded, so the
        # query checkpoint had no effect on `new_model`. Query entries override the
        # reference entries with the same key; all other reference weights are kept.
        new_model.module.load_state_dict({**train_state_dict, **query_state_dict})

        new_model.is_trained_ = True
        new_model.get_model_output(query, batch_size=batch_size)

        adata.obs['reference'] = 'reference'
        query.obs['reference'] = 'query'
        adata_both = ad.concat([adata, query])

        sc.pp.neighbors(adata_both, use_rep='latent')
        sc.tl.umap(adata_both)

        # FIX: `adata_both` holds donor-sorted reference cells followed by
        # donor-sorted query cells, which is not the row order of `adata1`.
        # `.obsm` arrays are assigned positionally, so reorder them to match
        # `adata1.obs_names` first (the `.obs` Series below align by index).
        order = adata_both.obs_names.get_indexer(adata1.obs_names)
        if (order < 0).any():
            raise ValueError(
                f'Split {i}: {int((order < 0).sum())} cells of adata1 are missing from the '
                'combined reference/query AnnData; cannot align embeddings.'
            )

        adata1.obsm[f'X_umap_{i}'] = adata_both.obsm['X_umap'][order]
        adata1.obsm[f'latent_{i}'] = adata_both.obsm['latent'][order]
        adata1.obs[f'cell_attn_{i}'] = adata_both.obs['cell_attn']
        adata1.obs[f'reference_{i}'] = adata_both.obs['reference']

    # Average the per-split attention scores cell-wise and persist the run.
    adata1.obs['cell_attn'] = np.mean([adata1.obs[f'cell_attn_{i}'] for i in range(n_splits)], axis=0)
    adata1.write(f'{results_root}/{h}_adata_both_full.h5ad')

Hash 52e3c74810, coef 0.1...
Split 0...
Organizing multimodal anndatas...
Setting up anndata...


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Initializing the model...
Split 1...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 2...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 3...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 4...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Hash dfbda3bb3e, coef 1.0...
Split 0...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 1...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 2...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 3...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Split 4...
Organizing multimodal anndatas...
Setting up anndata...
Initializing the model...
Hash 478fe63f05, coef 10...
Split 0...
Organizing multimodal anndatas...
Setting up anndata...
Initializing 