In [16]:
import scanpy as sc
import multigrate as mtg

In [17]:
# Load the preprocessed RNA modality. The path is relative to the notebook,
# matching the ADT load and the final write further down, instead of a
# user-specific absolute cluster path that only resolves on one machine.
# NOTE(review): assumes `../../../pipeline/data/pp/` is the same directory the
# absolute path pointed to — confirm before re-running.
rna = sc.read('../../../pipeline/data/pp/pbmc_healthy_covid_rna.h5ad')
rna

AnnData object with n_obs × n_vars = 624325 × 2000
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'split0', 'split1', 'split2', 'split3', 'split4'
    var: 'feature_types', 'means', 'variances', 'residual_variances', 'highly_variable_rank', 'highly_variable_nbatches', 'highly_variable_intersection', 'highly_variable'
    uns: 'hvg', 'leiden', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap'
    layers: 'counts'

In [3]:
# Inspect the stored values of the RNA matrix. The displayed result is a flat
# float32 array of whole numbers, which suggests X is sparse and holds raw
# counts — TODO confirm, since a count-based ("nb") loss is used later.
rna.X.data

array([1., 9., 2., ..., 7., 3., 7.], dtype=float32)

In [4]:
# Read the preprocessed ADT modality from the shared `pp` data folder and
# display its summary (same cells as the RNA object, 192 features).
adt = sc.read(
    '../../../pipeline/data/pp/pbmc_healthy_covid_adt.h5ad'
)
adt

AnnData object with n_obs × n_vars = 624325 × 192
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'split0', 'split1', 'split2', 'split3', 'split4'
    var: 'feature_types'
    uns: 'hvg', 'leiden', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_pca_harmony', 'X_umap'
    layers: 'counts'

In [5]:
# Inspect the stored values of the ADT matrix. The displayed values are
# non-integer floats, i.e. X here looks already normalized rather than raw
# counts — TODO confirm which normalization was applied upstream.
adt.X.data

array([0.16182648, 0.4232845 , 0.80181926, ..., 1.3309513 , 1.3929683 ,
       1.9186018 ], dtype=float32)

In [6]:
# Combine the paired RNA and ADT objects into a single AnnData for the model.
# The displayed result has 2000 + 192 = 2192 features over the same cells and
# gains a 'modality' var column, 'modality_lengths' in uns and a 'group' obs
# column.
adata = mtg.data.organize_multiome_anndatas(adatas=[[rna], [adt]])
adata

AnnData object with n_obs × n_vars = 624325 × 2192
    obs: 'sample_id', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'full_clustering', 'initial_clustering', 'Resample', 'Collection_Day', 'Sex', 'Age_interval', 'Swab_result', 'Status', 'Smoker', 'Status_on_day_collection', 'Status_on_day_collection_summary', 'Days_from_onset', 'Site', 'time_after_LPS', 'Worst_Clinical_Status', 'Outcome', 'patient_id', 'split0', 'split1', 'split2', 'split3', 'split4', 'group'
    var: 'modality'
    uns: 'modality_lengths'
    layers: 'counts'

In [7]:
# Register the merged AnnData with the MultiVAE model class.
# `rna_indices_end=2000` equals the number of RNA features loaded above, and
# the two obs columns are supplied as categorical covariates.
mtg.model.MultiVAE.setup_anndata(
    adata,
    rna_indices_end=2000,
    categorical_covariate_keys=["Site", "patient_id"],
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [8]:
# Build the model on the registered AnnData.
# NOTE(review): the two entries in `losses` presumably map to the modalities in
# the order they were merged (RNA, then ADT) — confirm against the multigrate
# API. The KL term is weighted by 0.01; latent and condition embeddings are
# both 16-dimensional.
vae = mtg.model.MultiVAE(
    adata,
    losses=["nb", "mse"],
    loss_coefs={"kl": 0.01},
    cond_dim=16,
    z_dim=16,
)

In [None]:
# Train the model with a learning rate of 1e-4; all other training settings
# are left at the library defaults. The captured log shows a 200-epoch run on
# a CUDA device at roughly 23 s per epoch, so this is the expensive cell.
vae.train(lr=1e-4)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 3g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-830c8b75-b7fb-562e-94eb-933eb50d368a]


Epoch 125/200:  62%|██████▏   | 124/200 [1:02:12<29:43, 23.47s/it, loss=830, v_num=1] 

In [None]:
# Plot the recorded training losses to check how the run converged.
vae.plot_losses()

In [None]:
# Compute the latent embedding for all cells. The return value is not
# captured; the next cell reads a representation named 'latent' from `adata`,
# so this call presumably stores it on `adata` — TODO confirm against the
# multigrate API.
vae.get_latent_representation()
# Display the AnnData summary to check what the call above added.
adata

In [None]:
# Build the neighbor graph on the 'latent' representation rather than on X.
sc.pp.neighbors(adata, use_rep='latent')
# Compute the UMAP embedding; it must run after `sc.pp.neighbors` because it
# relies on the neighbor graph.
sc.tl.umap(adata)

In [None]:
# Draw one UMAP panel per obs column (disease status summary, collection
# site, initial clustering), stacked in a single column without axis frames.
sc.pl.umap(
    adata,
    color=["Status_on_day_collection_summary", "Site", "initial_clustering"],
    ncols=1,
    frameon=False,
)

In [None]:
# Persist the merged AnnData (with whatever the cells above added to it) next
# to the preprocessed inputs in the shared `pp` data folder.
adata.write('../../../pipeline/data/pp/pbmc_healthy_covid_multigrate.h5ad')

In [None]:
# Save the trained model into a local `model/` directory, replacing any
# previously saved model there (`overwrite=True`).
vae.save('model/', overwrite=True)