In [26]:
import random
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import muon as mu
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import torch
import yaml
from torch import nn
from tqdm import tqdm

from celldreamer.data.scrnaseq_loader import RNAseqLoader
from celldreamer.eval.eval_utils import normalize_and_compute_metrics
from celldreamer.eval.optimal_transport import wasserstein
from celldreamer.models.base.encoder_model import EncoderModel
from celldreamer.models.base.utils import unsqueeze_right
from celldreamer.models.featurizers.category_featurizer import CategoricalFeaturizer
from celldreamer.models.fm.fm import FM
from celldreamer.paths import DATA_DIR

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Scanpy figure defaults: saved figures go to ./figures, drawn at 80 dpi, 6x6 inches, frameless.
sc.settings.figdir = 'figures'
sc.settings.set_figure_params(dpi=80, frameon=False, figsize=(6, 6))

## Step 1: Initialize data

In [27]:
# Loader / preprocessing settings for the PBMC10k multiome test split.
dataset_config = dict(
    dataset_path=DATA_DIR / "processed" / "atac" / "pbmc" / "pbmc10k_multiome_test.h5mu",
    layer_key="X_counts",
    covariate_keys=["cell_type"],
    conditioning_covariate="cell_type",
    subsample_frac=1,
    encoder_type="learnt_autoencoder",
    target_max=1,
    target_min=-1,
    one_hot_encode_features=False,
    split_rates=[0.80, 0.10, 0.10],
    cov_embedding_dimensions=100,
    multimodal=True,
    is_binarized=True,
)

data_path = dataset_config["dataset_path"]
# Only some config entries are loader arguments; forward exactly those, by name.
dataset = RNAseqLoader(
    data_path=data_path,
    **{
        key: dataset_config[key]
        for key in ("layer_key", "covariate_keys", "subsample_frac",
                    "encoder_type", "multimodal", "is_binarized")
    },
)



## Step 2: Initialize encoder

In [28]:
# Autoencoder architecture; it has to match the checkpoint loaded from `state_dict_path`
# (load_state_dict is strict). Each modality has its own MLP whose last "dims" entry is the
# per-modality latent size, followed by a joint multimodal layer.
encoder_config = dict(
    x0_from_x_kwargs=dict(
        rna=dict(dims=[512, 300, 200], batch_norm=True, dropout=False, dropout_p=0.0),
        atac=dict(dims=[1024, 512, 200], batch_norm=True, dropout=False, dropout_p=0.0),
    ),
    learning_rate=0.001,
    weight_decay=0.00001,
    covariate_specific_theta=False,
    multimodal=True,
    is_binarized=True,
    encoder_multimodal_joint_layers=dict(dims=[200], batch_norm=False, dropout=False, dropout_p=0.0),
)

# Alternative (earlier) run: the commented-out checkpoint below pairs with the same config but a
# 100-dim latent space, i.e. rna dims [512, 300, 100], atac dims [1024, 512, 100], joint dims [100].
state_dict_path = "/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/REBUTTAL_train_autoencoder_pbmc10k_multimodal_joint/e00c2685-94b1-4a0a-b4a0-8e1fef5b8336/checkpoints/epoch_39.ckpt"
# state_dict_path = "/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/train_autoencoder_pbmc10k_multimodal_joint/64584c28-b271-40b4-8279-562f583f649b/checkpoints/last.ckpt"

In [29]:
# Raw per-modality feature counts (number of columns of each modality matrix).
gene_dim = {mod: dataset.X[mod].shape[1] for mod in dataset.X}
modality_list = list(gene_dim.keys())

# Size of the space the FM model works in, per modality: the raw features, or -- when a
# learnt autoencoder is used -- the width of that modality's last encoder layer.
uses_learnt_encoder = dataset_config["encoder_type"] == "learnt_autoencoder"
in_dim = {
    mod: encoder_config["x0_from_x_kwargs"][mod]["dims"][-1] if uses_learnt_encoder else gene_dim[mod]
    for mod in dataset.X
}

# Per-modality log library-size statistics exposed by the data loader.
size_factor_statistics = {
    "mean": {mod: dataset.log_size_factor_mu[mod] for mod in dataset.log_size_factor_mu},
    "sd": {mod: dataset.log_size_factor_sd[mod] for mod in dataset.log_size_factor_sd},
}

# Number of categories of the conditioning covariate. The key is read from the config instead of
# hard-coding "cell_type", so changing the conditioning covariate keeps n_cat consistent.
n_cat = len(dataset.id2cov[dataset_config["conditioning_covariate"]])

In [30]:
# Build the multimodal autoencoder. Careful: the `in_dim` argument here is the raw per-modality
# feature count (`gene_dim`), not the FM-space `in_dim` dict computed in the previous cell.
encoder_model = EncoderModel(in_dim=gene_dim,
                             n_cat=n_cat,
                             conditioning_covariate=dataset_config["conditioning_covariate"],
                             encoder_type=dataset_config["encoder_type"],
                             **encoder_config)

# map_location remaps GPU-saved tensors so the checkpoint also loads when `device` is "cpu".
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may reject a Lightning
# checkpoint holding non-tensor objects -- confirm against the installed torch version.
encoder_model.load_state_dict(torch.load(state_dict_path, map_location=device)["state_dict"])

# Inference mode: BatchNorm layers use their running statistics.
encoder_model.eval()

EncoderModel(
  (x0_from_x): ModuleDict(
    (rna): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=25604, out_features=512, bias=True)
          (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=512, out_features=300, bias=True)
          (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (2): Linear(in_features=300, out_features=200, bias=True)
      )
    )
    (atac): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=40086, out_features=1024, bias=True)
          (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=1024, out_features=512, bias=True)
         

## Step 3: Initialize FM model

In [31]:
# Optimisation / flow-matching settings, forwarded to FM as keyword arguments.
generative_model_config = dict(
    learning_rate=0.0001,
    weight_decay=0.00001,
    antithetic_time_sampling=True,
    sigma=0.0001,
)

In [32]:
# Earlier (pre-rebuttal) FM checkpoint, kept for reference:
# ckpt = torch.load("/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/fm_resnet_autoencoder_pbmc10k_multimodal_joint/bfde6062-7515-4517-8a75-7bfcdfffa3aa/checkpoints/last.ckpt")
# The checkpoint stores CUDA tensors; map_location lets it load on whatever `device` was selected.
# NOTE(review): its hyper_parameters contain pickled nn.Modules, so torch >= 2.6 (weights_only=True
# by default) may need weights_only=False here -- only do that for trusted files.
ckpt = torch.load("/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/REBUTTAL_fm_resnet_autoencoder_pbmc10k_multimodal_joint/85de0ed0-5bb5-4640-a5b5-e6fdd2e92251/checkpoints/last.ckpt",
                  map_location=device)

# The denoising network is stored as a hyper-parameter of the FM checkpoint; reuse it directly.
denoising_model = ckpt["hyper_parameters"]["denoising_model"]
denoising_model

MLPTimeStep(
  (time_embedder): Sequential(
    (0): Linear(in_features=100, out_features=100, bias=True)
    (1): SiLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
  )
  (net_in): Linear(in_features=200, out_features=64, bias=True)
  (blocks): ModuleList(
    (0-2): 3 x ResnetBlock(
      (net1): Sequential(
        (0): SiLU()
        (1): Linear(in_features=64, out_features=64, bias=True)
      )
      (cond_proj): Sequential(
        (0): SiLU()
        (1): Linear(in_features=100, out_features=64, bias=True)
      )
      (net2): Sequential(
        (0): SiLU()
        (1): Linear(in_features=64, out_features=64, bias=True)
      )
    )
  )
  (net_out): Sequential(
    (0): SiLU()
    (1): Linear(in_features=64, out_features=200, bias=True)
  )
)

In [33]:
feature_embeddings = ckpt["hyper_parameters"]["feature_embeddings"]
# Sanity check: print the learned cell-type embedding matrix restored from the checkpoint.
print(feature_embeddings["cell_type"].embeddings.weight)

Parameter containing:
tensor([[-7.3886e-01, -2.8437e+00,  4.9177e-01,  ..., -7.5713e-01,
          5.0665e-01, -8.7012e-02],
        [-1.7139e+00, -1.0276e+00,  3.0598e-01,  ..., -2.8949e-01,
          1.2149e+00, -4.4779e-01],
        [ 5.4839e-01,  3.6341e-02,  3.5934e-01,  ..., -1.5601e+00,
         -1.1446e-03, -1.0338e+00],
        ...,
        [-8.9924e-01, -1.7542e+00,  1.3773e+00,  ..., -2.3005e-01,
         -5.6872e-01,  1.3988e+00],
        [ 1.2023e+00,  1.0147e+00, -2.0126e+00,  ..., -9.8723e-01,
          8.4818e-01,  7.8498e-01],
        [ 2.9876e+00,  4.8126e-01, -8.1359e-01,  ...,  5.7315e-01,
         -1.0720e+00, -1.7925e+00]], device='cuda:0', requires_grad=True)


Instantiate the flow-matching (FM) model and load its trained weights

In [34]:
# Assemble the flow-matching model around the trained autoencoder and the denoiser / covariate
# embeddings restored from the FM checkpoint, then load the full FM state dict on top.
generative_model = FM(
    encoder_model=encoder_model,
    denoising_model=denoising_model,
    feature_embeddings=feature_embeddings,
    plotting_folder=None,
    in_dim=in_dim,  # per-modality size of the space the FM operates in (see the dimension cell)
    size_factor_statistics=size_factor_statistics,
    encoder_type=dataset_config["encoder_type"],
    conditioning_covariate=dataset_config["conditioning_covariate"],
    model_type=denoising_model.model_type,
    multimodal=dataset_config["multimodal"],
    is_binarized=dataset_config["is_binarized"],
    modality_list=modality_list,
    **generative_model_config,  # remaining optimisation / flow-matching hyper-parameters
)

generative_model.load_state_dict(ckpt["state_dict"])
# Honour the device chosen at the top of the notebook instead of hard-coding "cuda", and switch
# every submodule (including the encoder's BatchNorm layers) to inference mode before sampling.
generative_model.to(device)
generative_model.eval()

/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'encoder_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder_model'])`.
/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'denoising_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['denoising_model'])`.


FM(
  (encoder_model): EncoderModel(
    (x0_from_x): ModuleDict(
      (rna): MLP(
        (net): Sequential(
          (0): Sequential(
            (0): Linear(in_features=25604, out_features=512, bias=True)
            (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ELU(alpha=1.0)
          )
          (1): Sequential(
            (0): Linear(in_features=512, out_features=300, bias=True)
            (1): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ELU(alpha=1.0)
          )
          (2): Linear(in_features=300, out_features=200, bias=True)
        )
      )
      (atac): MLP(
        (net): Sequential(
          (0): Sequential(
            (0): Linear(in_features=40086, out_features=1024, bias=True)
            (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ELU(alpha=1.0)
          )
          (1): Sequential(
       

# Generate and save three independently sampled datasets

In [35]:
# Reload the test MuData; real cells' labels and library sizes condition generation below.
adata_original = mu.read(data_path)
adata_rna = adata_original.mod["rna"]
adata_atac = adata_original.mod["atac"]

# Per-cell totals of .X, summed on the sparse matrix without densifying it (`.A` is deprecated
# and later removed for SciPy sparse matrices). The (n, 1) row-sum result is flattened to 1-D.
# NOTE(review): this sums `.X`, whereas generation uses the "X_counts" layer -- confirm .X holds counts.
adata_rna.obs["size_factor"] = np.asarray(adata_rna.X.sum(1)).ravel()
adata_atac.obs["size_factor"] = np.asarray(adata_atac.X.sum(1)).ravel()

# Dense count matrices as torch tensors (`toarray` yields a plain ndarray rather than np.matrix).
X_rna = torch.tensor(adata_rna.layers["X_counts"].toarray())
X_atac = torch.tensor(adata_atac.layers["X_counts"].toarray())



In [36]:
# Show the loaded MuData object as a quick sanity check of the test data.
adata_original

In [37]:
# Generate N_GENERATED_DATASETS independent synthetic datasets and write each modality to disk.
# Each dataset concatenates N_DRAWS_PER_DATASET draws of `num_indices` real test cells; every
# draw's cell labels and RNA library sizes condition the generator, so the synthetic data follows
# the label composition of the test set.
N_GENERATED_DATASETS = 3
N_DRAWS_PER_DATASET = 2
SAMPLING_SEED = 42
OUTPUT_DIR = Path("/home/icb/alessandro.palma/environment/cfgen/project_folder/datasets/generated/pbmc10k_multimodal")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # make sure the target folder exists before writing

index_range = len(adata_rna)
num_indices = 1000
covariate_key = dataset_config["conditioning_covariate"]

# Seed the cell sampler and torch's generator so re-running this cell reproduces the files.
index_rng = random.Random(SAMPLING_SEED)
torch.manual_seed(SAMPLING_SEED)


def _build_generated_adata(counts, cell_labels, label_key, log_normalize):
    """Wrap a (cells x features) CPU tensor of generated counts into an AnnData.

    Stores per-cell totals in ``obs["size_factor"]``, the conditioning labels as a categorical
    ``obs[label_key]`` column and a copy of the counts in ``layers["X_counts"]``; optionally
    normalises ``X`` to 10,000 counts per cell and log1p-transforms it; finally runs PCA.
    """
    adata = sc.AnnData(X=sp.csr_matrix(counts.numpy()))
    # Sparse row sums come back as an (n, 1) matrix; flatten to a 1-D obs column.
    adata.obs["size_factor"] = np.asarray(adata.X.sum(1)).ravel()
    adata.obs[label_key] = cell_labels
    adata.obs[label_key] = adata.obs[label_key].astype("category")
    adata.layers["X_counts"] = adata.X.copy()
    if log_normalize:
        sc.pp.normalize_total(adata, target_sum=10000)
        sc.pp.log1p(adata)
    sc.tl.pca(adata)
    return adata


for it in tqdm(range(N_GENERATED_DATASETS)):
    X_generated_list_rna = []
    X_generated_list_atac = []
    classes_str = []
    for _ in range(N_DRAWS_PER_DATASET):
        # Draw real cells without replacement; their labels and library sizes condition sampling.
        indices = index_rng.sample(range(index_range), num_indices)
        adata_rna_indices = adata_rna[indices]

        batch_labels = list(adata_rna_indices.obs[covariate_key])
        classes_str += batch_labels
        classes = torch.tensor([dataset.id2cov[covariate_key][c] for c in batch_labels])

        # Log RNA library size of each drawn cell, kept as an (n, 1) tensor on the model's device.
        log_size_factors = torch.log(torch.tensor(adata_rna_indices.layers["X_counts"].sum(1))).to(device)

        # NOTE(review): presumably batch_size * repetitions must equal num_indices -- confirm
        # against FM.batched_sample before changing either value.
        X_generated = generative_model.batched_sample(batch_size=100,
                                                      repetitions=10,
                                                      n_sample_steps=2,
                                                      covariate=covariate_key,
                                                      covariate_indices=classes,
                                                      log_size_factor=log_size_factors)

        X_generated_list_rna.append(X_generated["rna"].to("cpu"))
        X_generated_list_atac.append(X_generated["atac"].to("cpu"))

    X_generated_list_rna = torch.cat(X_generated_list_rna, dim=0)
    X_generated_list_atac = torch.cat(X_generated_list_atac, dim=0)
    # Every generated cell must line up with exactly one conditioning label.
    assert len(X_generated_list_rna) == len(X_generated_list_atac) == len(classes_str), \
        "generated cell count does not match the number of conditioning labels"

    # RNA: keep raw counts in a layer, normalise + log1p X before PCA.
    adata_generated_rna = _build_generated_adata(X_generated_list_rna, classes_str, covariate_key,
                                                 log_normalize=True)
    adata_generated_rna.write_h5ad(OUTPUT_DIR / f"generated_cells_{it}_rna.h5ad")

    # ATAC (configured as binarized): PCA directly on the generated matrix, no normalisation.
    adata_generated_atac = _build_generated_adata(X_generated_list_atac, classes_str, covariate_key,
                                                  log_normalize=False)
    adata_generated_atac.write_h5ad(OUTPUT_DIR / f"generated_cells_{it}_atac.h5ad")

100%|██████████| 3/3 [03:29<00:00, 69.96s/it]
