In [8]:
import pertpy as pt
import scanpy as sc
import pandas as pd
import pandas as pd
import pandas as pd
import numpy as np
import scvi
from _quasiSCVI import QuasiSCVI
from _quasivae import QuasiVAE
from anndata import AnnData

In [2]:
# Load the Papalexi 2021 sample dataset shipped with pertpy. Later cells read
# the "rna" and "gdo" modalities from the returned object.
mdata = pt.dt.papalexi_2021()



In [3]:
# Pull the "gdo" modality out of the multimodal container.
gdo = mdata.mod["gdo"]

# Its raw matrix; densified below via .toarray().
guide_count_matrix = gdo.X

# Dense table labelled with the modality's own obs (rows) and var (columns) indices.
guide_count_df = pd.DataFrame(
    data=guide_count_matrix.toarray(),
    index=gdo.obs.index,
    columns=gdo.var.index,
)

In [4]:
# Re-wrap the dense guide table in a fresh AnnData, reusing its row/column labels.
guide_adata = sc.AnnData(X=guide_count_df.to_numpy())
guide_adata.obs.index = guide_count_df.index
guide_adata.var.index = guide_count_df.columns

# Copy per-cell annotations from the RNA modality (target column <- source column).
# Note that the "guide" column is filled from the RNA column named "NT".
rna_obs = mdata["rna"].obs
for target_col, source_col in (
    ("replicate", "replicate"),
    ("Phase", "Phase"),
    ("perturbation", "perturbation"),
    ("guide", "NT"),
):
    guide_adata.obs[target_col] = rna_obs[source_col]

# Show a summary of the assembled object.
print(guide_adata)

AnnData object with n_obs × n_vars = 20729 × 111
    obs: 'replicate', 'Phase', 'perturbation', 'guide'


In [5]:
# Register guide_adata with scvi-tools before building the SCVI model on it.
# No batch_key / labels_key is passed in this call.
scvi.model.SCVI.setup_anndata(guide_adata)

In [6]:
# Pin the scvi-tools global seed before the model is constructed so that weight
# initialisation and training give the same result on every Restart & Run All
# (and on every re-run of this cell).
scvi.settings.seed = 0

# SCVI model over the guide counts: layer norm on both encoder and decoder,
# no batch norm, two hidden layers, 20% dropout.
scvi_ref = scvi.model.SCVI(
    guide_adata,
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)

# Train with the library defaults (no max_epochs / accelerator override).
scvi_ref.train()

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 386/386: 100%|███████████████████████████████████████████| 386/386 [04:07<00:00,  1.57it/s, v_num=1, train_loss_step=144, train_loss_epoch=146]

`Trainer.fit` stopped: `max_epochs=386` reached.


Epoch 386/386: 100%|███████████████████████████████████████████| 386/386 [04:07<00:00,  1.56it/s, v_num=1, train_loss_step=144, train_loss_epoch=146]


In [9]:
# Per-cell latent representation of the guide counts from the trained SCVI model.
guide_embeddings = scvi_ref.get_latent_representation()
guide_embeddings_df = pd.DataFrame(guide_embeddings, index=guide_adata.obs.index)
adata_ref = mdata["rna"].copy()

print("Shape of adata_ref.X:", adata_ref.X.shape)
print("Shape of guide_embeddings_df.values:", guide_embeddings_df.values.shape)

# np.concatenate needs a dense array; densify the expression matrix if it is sparse.
if not isinstance(adata_ref.X, np.ndarray):
    adata_ref.X = adata_ref.X.toarray()

# np.concatenate stacks purely by position, so line the embedding rows up with
# the RNA cells by label first. .loc raises KeyError if any RNA cell has no
# embedding, instead of silently pairing the wrong rows.
guide_embeddings_array = guide_embeddings_df.loc[adata_ref.obs_names].to_numpy()

# Concatenate gene expression data (X) with guide embeddings along columns.
combined_matrix = np.concatenate([adata_ref.X, guide_embeddings_array], axis=1)

# One var name per column: the original gene names followed by one synthetic
# name per latent dimension.
latent_names = [f"guide_latent_{i}" for i in range(guide_embeddings_array.shape[1])]
combined_var = pd.DataFrame(index=list(adata_ref.var_names) + latent_names)

# The following cells pass combined_X as `adata=` with batch_key/labels_key and
# write to combined_X.obsm, so it has to be an AnnData, not a bare ndarray.
# Carry the RNA .obs along so 'replicate', 'perturbation' and 'Phase' are available.
# NOTE(review): the appended latent columns are presumably real-valued rather
# than counts — confirm QuasiSCVI's likelihood is meant to accept them.
combined_X = AnnData(X=combined_matrix, obs=adata_ref.obs.copy(), var=combined_var)
print(combined_X.shape)

Shape of adata_ref.X: (20729, 18649)
Shape of guide_embeddings_df.values: (20729, 10)
(20729, 18659)


In [None]:
# Register the combined data with QuasiSCVI, using 'replicate' as the batch key
# and 'perturbation' as the labels key.
# NOTE(review): both keys are presumably looked up as .obs columns, so the object
# passed as `adata` has to be an AnnData that carries them — confirm.
QuasiSCVI.setup_anndata(adata=combined_X, batch_key='replicate', labels_key='perturbation')

In [None]:
# Architecture settings for the QuasiSCVI model on the combined data.
quasi_model_kwargs = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)
guide_model = QuasiSCVI(combined_X, **quasi_model_kwargs)

# Fit for 100 epochs with learning rate 1e-4, requesting the GPU accelerator.
training_plan = {"lr": 1e-4}
guide_model.train(max_epochs=100, plan_kwargs=training_plan, accelerator="gpu")

In [None]:
# .obsm slot under which the model's latent representation is stored.
SCVI_LATENT_KEY = "X_scVI"

latent = guide_model.get_latent_representation()
combined_X.obsm[SCVI_LATENT_KEY] = latent

# Neighbour graph on the stored latent representation, then Leiden and UMAP.
sc.pp.neighbors(combined_X, use_rep=SCVI_LATENT_KEY)
sc.tl.leiden(combined_X)
sc.tl.umap(combined_X)

# One UMAP panel per annotation, stacked in a single column.
umap_colors = ["replicate", "Phase", "perturbation"]
sc.pl.umap(combined_X, color=umap_colors, frameon=False, ncols=1)