In [1]:

import anndata as ad
import numpy as np
import scvi
import juniper
from typing import Literal, Optional, Dict
import jax
import pandas as pd
import os

def get_preds_from_adata(adata: ad.AnnData) -> Dict[str, jax.Array]:
    d = {}

    for condition in adata.obs["condition"].cat.categories:
        d[condition] = adata[adata.obs["condition"]==condition].X
    return d 

def write_preds_to_adata(predictions: Dict[str, jax.Array]) -> ad.AnnData:
    all_data = []
    conditions = []
    for condition, array in predictions.items():
        all_data.append(array)
        conditions.extend([condition] * array.shape[0])
    all_data_array = np.vstack(all_data)
    obs_data = pd.DataFrame({
        'condition': conditions
    })
    adata_pred = ad.AnnData(X=all_data_array, obs=obs_data)
    return adata_pred


def reconstruct_data(embedding, projection_matrix, mean_to_add):
    """Reconstructs data from projections."""
    return np.matmul(embedding, projection_matrix.T) + mean_to_add


def reconstruct_data_from_vae(model_dir, adata_train, adata):
    """Reconstructs data from projections."""
    model = juniper.latent.model.FactorVI.load(model_dir, adata_train)
    adata.obsm["X_scVI"] = model.get_latent_representation(adata)
    return model.get_reconstructed_expression(adata, give_mean=True)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import scanpy as sc
output_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/combosciplex"
adata_train = sc.read(os.path.join(output_dir, "adata_train_300.h5ad"))
adata_test = sc.read(os.path.join(output_dir, "adata_test_300.h5ad"))
adata_ood = sc.read(os.path.join(output_dir, "adata_ood_300.h5ad")) 




In [3]:
model_dir = "/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test"

In [4]:
import functools

In [5]:
model = juniper.latent.model.FactorVI.load(model_dir, adata_train)

[34mINFO    [0m File [35m/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test/[0m[95mmodel.pt[0m       
         already downloaded                                                                                        


/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
2024-07-08 09:11:14.382151: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environ

Epoch 1/154:   0%|          | 0/154 [00:00<?, ?it/s]

/ictstr01/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:385: You have overridden `transfer_batch_to_device` in `LightningModule` but have passed in a `LightningDataModule`. It will use the implementation from `LightningModule` instance.


Epoch 1/154:   1%|          | 1/154 [00:04<09:47,  3.84s/it, v_num=1, train_loss=578, rec_loss=578, kld_loss=14.1]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 1/154:   1%|          | 1/154 [00:04<11:17,  4.43s/it, v_num=1, train_loss=578, rec_loss=578, kld_loss=14.1]


In [7]:
dir(model.module)

['__annotations__',
 '__call__',
 '__class__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__dataclass_transform__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__match_args__',
 '__module__',
 '__ne__',
 '__new__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_call_wrapped_method',
 '_check_train_state_is_not_none',
 '_compact_name_scope_methods',
 '_customized_dataclass_transform',
 '_find_compact_name_scope_methods',
 '_get_generative_input',
 '_get_inference_input',
 '_id',
 '_initialization_allowed',
 '_module_checks',
 '_name_taken',
 '_parent_ref',
 '_register_submodules',
 '_rngs',
 '_set_rngs',
 '_split_rngs',
 '_state',
 '_try_setup',
 '_validate_setup',
 '_verify_single_or_no_compact',
 '

In [8]:
model.module.params

{'decoder': {'disp': Array([[-0.00901103],
         [ 1.6431851 ],
         [-0.862552  ],
         ...,
         [-0.316092  ],
         [-0.8693874 ],
         [ 0.02939256]], dtype=float32),
  'mlp': {'batchnorms_0': {'bias': Array([-0.1754373 , -0.16873913, -0.09719211, ..., -0.1230128 ,
           -0.11073232, -0.05947183], dtype=float32),
    'scale': Array([0.93593216, 0.9610804 , 1.0540652 , ..., 1.0390935 , 0.95801276,
           0.97524554], dtype=float32)},
   'batchnorms_1': {'bias': Array([-0.06353407, -0.07592379, -0.13063285, ..., -0.06392703,
           -0.10699841, -0.09110859], dtype=float32),
    'scale': Array([0.9887081, 0.972156 , 1.043807 , ..., 1.0458841, 1.0110469,
           1.0576954], dtype=float32)},
   'layers_0': {'bias': Array([ 1.2064066e-09, -3.7738426e-09, -2.4081122e-09, ...,
           -1.6211040e-09, -1.6765181e-09,  1.4856989e-09], dtype=float32),
    'kernel': Array([[ 0.01861437,  0.31134918, -0.18183663, ...,  0.0770072 ,
             0.0288180

In [9]:
model.save("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test2")

In [10]:
model2 = juniper.latent.model.FactorVI.load("/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test2", adata_train)

[34mINFO    [0m File [35m/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test2/[0m[95mmodel.pt[0m      
         already downloaded                                                                                        


/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 

Epoch 1/154:   1%|          | 1/154 [00:03<08:52,  3.48s/it, v_num=1, train_loss=564, rec_loss=564, kld_loss=14.1]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 1/154:   1%|          | 1/154 [00:03<09:02,  3.55s/it, v_num=1, train_loss=564, rec_loss=564, kld_loss=14.1]


In [12]:
model2.module.params

{'decoder': {'disp': Array([[-0.00901103],
         [ 1.6431851 ],
         [-0.862552  ],
         ...,
         [-0.316092  ],
         [-0.8693874 ],
         [ 0.02939256]], dtype=float32),
  'mlp': {'batchnorms_0': {'bias': Array([-0.1754373 , -0.16873913, -0.09719211, ..., -0.1230128 ,
           -0.11073232, -0.05947183], dtype=float32),
    'scale': Array([0.93593216, 0.9610804 , 1.0540652 , ..., 1.0390935 , 0.95801276,
           0.97524554], dtype=float32)},
   'batchnorms_1': {'bias': Array([-0.06353407, -0.07592379, -0.13063285, ..., -0.06392703,
           -0.10699841, -0.09110859], dtype=float32),
    'scale': Array([0.9887081, 0.972156 , 1.043807 , ..., 1.0458841, 1.0110469,
           1.0576954], dtype=float32)},
   'layers_0': {'bias': Array([ 1.2064066e-09, -3.7738426e-09, -2.4081122e-09, ...,
           -1.6211040e-09, -1.6765181e-09,  1.4856989e-09], dtype=float32),
    'kernel': Array([[ 0.01861437,  0.31134918, -0.18183663, ...,  0.0770072 ,
             0.0288180

In [17]:
out_tree = jax.tree_util.tree_map(lambda x,y: x==y, model2.module.params, model.module.params)

In [28]:
reconstruct_data_fn = functools.partial(
            reconstruct_data_from_vae, model_dir=model_dir, adata_train=adata_train
        )

In [29]:
reconstruct_data_fn(adata=adata_train)

[34mINFO    [0m File [35m/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test/[0m[95mmodel.pt[0m       
         already downloaded                                                                                        


/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 

Epoch 1/154:   1%|▏                     | 1/154 [00:01<03:20,  1.31s/it, v_num=1, train_loss=564, rec_loss=564, kld_loss=14]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 1/154:   1%|▏                     | 1/154 [00:01<03:24,  1.34s/it, v_num=1, train_loss=564, rec_loss=564, kld_loss=14]


array([[2.7503736e-02, 1.1914163e-02, 4.1405678e-02, ..., 1.9252802e-04,
        4.2046292e-04, 7.4307516e-04],
       [5.4195523e-01, 2.4819948e-02, 1.4656386e-01, ..., 4.9367110e-04,
        1.6829079e-03, 3.9218913e-04],
       [6.0350452e-02, 1.3270142e-02, 4.5655414e-02, ..., 3.1717151e-04,
        5.0042296e-04, 1.1857266e-03],
       ...,
       [7.0108501e-03, 2.6738483e-03, 8.5345637e-03, ..., 1.3301597e-04,
        3.3469306e-04, 4.7523857e-04],
       [9.3866192e-02, 6.1572436e-03, 1.7388327e-01, ..., 3.4696562e-04,
        2.5106650e-03, 5.5780029e-04],
       [2.3246704e-02, 1.9229881e-02, 5.1356196e-02, ..., 3.5961738e-04,
        3.8658368e-04, 7.7950372e-04]], dtype=float32)

In [None]:
adata_with_preds = write_preds_to_adata(prediction)

In [30]:
out_ood = reconstruct_data_fn(adata=adata_ood)

[34mINFO    [0m File [35m/lustre/groups/ml01/workspace/ot_perturbation/data/sciplex/combosciplex_factorvi_test/[0m[95mmodel.pt[0m       
         already downloaded                                                                                        


/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
/home/icb/dominik.klein/mambaforge/envs/ot_pert_genot_scvi/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/dominik.klein/mambaforge/envs/ot_pert_geno ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 

Epoch 1/154:   1%|▏                   | 1/154 [00:01<03:10,  1.24s/it, v_num=1, train_loss=582, rec_loss=582, kld_loss=13.9]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 1/154:   1%|▏                   | 1/154 [00:01<03:11,  1.25s/it, v_num=1, train_loss=582, rec_loss=582, kld_loss=13.9]
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


  self.validate_field(adata)


In [32]:
out_ood.shape

(8896, 2000)

In [31]:
adata_with_preds = write_preds_to_adata(out_ood)

AttributeError: 'numpy.ndarray' object has no attribute 'items'

In [33]:
adata_train.n_obs

51882

In [None]:



# Load the saved model
loaded_model = juniper.latent.model.FactorVI.load(save_path, adata_train)