In [None]:
import os

import scanpy as sc
import mudata
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from dcor import energy_distance
from scvi.model import SCVI
from scvi.model import SCVI

## Load the data and inspect the cell-level metadata fields

In [None]:
# Load the input AnnData object from disk.
input_path = "rna_donors_for_totalVI_temp.h5ad"
adata = sc.read_h5ad(input_path)

In [None]:
# Inspect the loaded AnnData object.
adata

In [None]:
# Cell-level metadata used in this analysis, with what each field encodes.
obs_fields = {
    "library": "assignment to library (watch out for batch effects)",
    "timepoint": "assignment to timepoint",
    "treatment": "assignment to perturbation",
    "donor": "assignment to donor (watch out for batch effects)",
    "pathway": "assignment to drug pathway (a summary of the target of the drug)",
    "target": "assignment to drug target",
}
# A notebook cell only displays its last expression, so select every field at once.
# This also raises KeyError if any of them is missing from adata.obs.
adata.obs[list(obs_fields)].head()

## Train scVI and keep only drugs with a large RNA effect (energy distance to Vehicle)

In [None]:
# scVI batch covariate: one label per donor x library combination.
# Labels are the two values concatenated without a separator. The format is kept
# as-is because the saved scVI models (set up with batch_key="batch") were built
# from these labels; presumably a different format would no longer match them.
adata.obs["batch"] = adata.obs["donor"].astype(str) + adata.obs["library"].astype(str)
# Separator-less concatenation can merge distinct pairs (e.g. "1" + "12" == "11" + "2"),
# which would silently pool two batches. Fail loudly if that happens.
n_donor_library_pairs = len(adata.obs[["donor", "library"]].astype(str).drop_duplicates())
assert adata.obs["batch"].nunique() == n_donor_library_pairs, (
    "donor/library labels collide when concatenated; add a separator to 'batch'"
)

In [None]:
# Train scVI on all cells, or reload a previously trained model.
# Set `train = True` to force retraining. If no saved model exists yet, the cell
# trains regardless, so the notebook also runs end-to-end on a fresh checkout.
train = False
save_dir = "model_saved"
if train or not os.path.isdir(save_dir):
    # "batch" (donor + library) is the batch covariate corrected by scVI.
    SCVI.setup_anndata(adata, layer="counts", batch_key="batch")
    model = SCVI(adata, n_hidden=256, n_latent=50, n_layers=2, gene_likelihood="nb")
    model.train(use_gpu=True, early_stopping=True)
    model.save(save_dir, overwrite=True)
else:
    # setup_anndata is not called on this path; presumably SCVI.load re-applies
    # the setup saved with the model.
    model = SCVI.load(save_dir, adata, use_gpu=True)

In [None]:
# get latent space
adata.obsm["X_scVI"] = model.get_latent_representation()

In [None]:
# Effect size of each treatment: the energy distance, in the scVI latent space
# (obsm["X_scVI"]), between a reference set of Vehicle cells and the treated cells.
# Treatments with too few cells are skipped. Both groups are capped in size to
# bound the cost of energy_distance.
MIN_CELLS_PER_TREATMENT = 200
MAX_CELLS_PER_GROUP = 1000
# Seeded random subsampling instead of taking the first N rows: the row order of
# the file may follow library/donor, which would bias the sample.
rng = np.random.default_rng(0)


def _subsample(indices, max_n, rng):
    """Return at most ``max_n`` of ``indices``, drawn uniformly without replacement."""
    if len(indices) <= max_n:
        return indices
    return np.sort(rng.choice(indices, size=max_n, replace=False))


latent = adata.obsm["X_scVI"]
# The Vehicle reference does not depend on the treatment, so build it once, outside the loop.
# NOTE(review): this filters on the "compound" column while treatments use "treatment";
# confirm "compound" exists and that "Vehicle" marks the vehicle controls.
index_control = _subsample(
    np.where(adata.obs["compound"] == "Vehicle")[0], MAX_CELLS_PER_GROUP, rng
)
assert len(index_control) > 0, "no Vehicle control cells found"

total_effect = {}
for treatment in tqdm(adata.obs["treatment"].unique()):
    index_condition = np.where(adata.obs["treatment"] == treatment)[0]
    if len(index_condition) <= MIN_CELLS_PER_TREATMENT:
        continue
    index_condition = _subsample(index_condition, MAX_CELLS_PER_GROUP, rng)
    total_effect[treatment] = energy_distance(latent[index_control], latent[index_condition])

In [None]:
# Treatments ranked by effect size (energy distance to Vehicle), largest first.
# ascending=False keeps any NaN effect at the bottom (a reversed ascending sort
# would put it on top).
series = pd.Series(total_effect).sort_values(ascending=False)
# Treatments above the cutoff of 1, the same cutoff used to build large_effects.
series[series > 1]

In [None]:
# Names of the treatments whose effect size exceeds 1, kept in `series` order.
large_effects = [name for name, effect in series.items() if effect > 1]

In [None]:
# Large-effect treatments plus the Vehicle and "No stim" controls at every dose.
control_labels = [
    f"{control}_{dose}"
    for control in ("Vehicle", "No stim")
    for dose in ("100nM", "1uM", "10uM")
]
to_keep = large_effects + control_labels

In [None]:
# Keep only cells whose treatment is in `to_keep`; .copy() gives an independent
# object for the in-place steps that follow.
keep_mask = adata.obs["treatment"].isin(to_keep)
adata_filtered = adata[keep_mask].copy()

In [None]:
# Inspect the filtered AnnData (cells restricted to the treatments in `to_keep`).
adata_filtered

## Retrain scVI on the filtered data and visualize it with UMAP

In [None]:
# Subset to 6000 highly variable genes, selected on the "counts" layer with the
# seurat_v3 flavor.
# NOTE(review): genes are selected per "library", while scVI below uses the combined
# "batch" (donor + library) key; confirm this difference is intended.
sc.pp.highly_variable_genes(
    adata_filtered,
    n_top_genes=6000,
    flavor="seurat_v3",
    layer="counts",
    batch_key="library",
    subset=True,
)

In [None]:
# Retrain scVI on the filtered, HVG-subset data, or reload the saved model.
# Set `train = True` to force retraining. If no saved model exists yet, the cell
# trains regardless, so the notebook also runs end-to-end on a fresh checkout.
train = False
save_dir = "model_filtered_saved"
if train or not os.path.isdir(save_dir):
    SCVI.setup_anndata(adata_filtered, layer="counts", batch_key="batch")
    model = SCVI(adata_filtered, n_hidden=128, n_latent=30, n_layers=2, gene_likelihood="nb")
    model.train(max_epochs=100, use_gpu=True, early_stopping=True)
    # Validation ELBO per epoch, to eyeball convergence.
    model.history["elbo_validation"].plot()
    model.save(save_dir, overwrite=True)
else:
    model = SCVI.load(save_dir, adata_filtered, use_gpu=True)

In [None]:
# Store the filtered model's latent space, build the neighbor graph on it and embed with UMAP.
latent_filtered = model.get_latent_representation()
adata_filtered.obsm["X_filt_scVI"] = latent_filtered
sc.pp.neighbors(adata_filtered, use_rep="X_filt_scVI")
sc.tl.umap(adata_filtered)
# Keep a copy under a model-specific key so a later UMAP run cannot overwrite it.
umap_coords = adata_filtered.obsm["X_umap"]
adata_filtered.obsm["X_filt_scVI_umap"] = umap_coords.copy()

In [None]:
# UMAP of the filtered data colored by technical covariates (library, plate, donor, HTO)
# and the score_s / score_g2m columns; look for mixing across libraries and donors.
technical_keys = ["library", "Plate#", "score_s", "score_g2m", "donor", "hto_label"]
sc.pl.embedding(
    adata_filtered,
    basis="X_filt_scVI_umap",
    color=technical_keys,
    ncols=2,
)

In [None]:
# UMAP of the filtered data colored by timepoint and by the drug's pathway / target.
biology_keys = ["timepoint", "pathway", "target"]
sc.pl.embedding(
    adata_filtered,
    basis="X_filt_scVI_umap",
    color=biology_keys,
    ncols=2,
)

## Save the filtered dataset

In [None]:
# Inspect the filtered AnnData before stripping fields for export.
adata_filtered

In [None]:
# Make the "counts" layer the main matrix X, then drop every layer.
# The copy must be taken before the layers are deleted. Re-running this cell
# fails because "counts" no longer exists.
adata_filtered.X = adata_filtered.layers["counts"].copy()
del adata_filtered.layers

In [None]:
# Drop unstructured metadata, pairwise matrices (e.g. the neighbor graph) and .raw
# before export; delattr invokes the same deleters as `del obj.attr`.
for attr_name in ("uns", "obsp", "raw"):
    delattr(adata_filtered, attr_name)

In [None]:
# Drop embeddings that should not be exported: the latent space of the unfiltered
# model and the adt / proteins / tsb matrices (presumably protein-derived; the
# exported file keeps only X_filt_scVI and the UMAP coordinates).
# pop(key, None) keeps the cell idempotent: re-running it no longer raises KeyError.
obsm_keys_to_drop = (
    "X_scVI",
    "adt",
    "adt_norm",
    "adt_select",
    "proteins",
    "proteins_norm",
    "tsb",
    "tsb_norm",
    "tsb_select",
)
for obsm_key in obsm_keys_to_drop:
    adata_filtered.obsm.pop(obsm_key, None)

In [None]:
# Write the filtered dataset to disk.
output_path = "icCITE-plex_filtered_top_drugs.h5ad"
adata_filtered.write_h5ad(output_path)