In [1]:
import numpy as np
import pandas as pd
import yaml
import torch
from copy import deepcopy
from tqdm import tqdm
import matplotlib.pyplot as plt
from celldreamer.eval.eval_utils import normalize_and_compute_metrics
import scipy.sparse as sparse
import scipy as sp

from torch import nn
import scanpy as sc    

from celldreamer.data.scrnaseq_loader import RNAseqLoader
from celldreamer.models.featurizers.category_featurizer import CategoricalFeaturizer
from celldreamer.models.fm.fm import FM
from celldreamer.eval.optimal_transport import wasserstein
import random
from celldreamer.models.base.encoder_model import EncoderModel
from celldreamer.models.base.utils import unsqueeze_right

from celldreamer.paths import DATA_DIR

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device  = "cuda" if torch.cuda.is_available() else "cpu"

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Step 1: Initialize data

In [2]:
# Settings for the PBMC COVID training split. Only some of these keys are
# consumed in this cell; the remaining ones are kept for reference.
dataset_config = dict(
    dataset_path=DATA_DIR / 'processed' / 'classifier_experiment_pbmc' / 'pbmc_covid_train.h5ad',
    layer_key='X_counts',
    covariate_keys=['cell_type'],
    conditioning_covariate='cell_type',
    subsample_frac=1,
    encoder_type='learnt_autoencoder',
    target_max=1,
    target_min=-1,
    one_hot_encode_features=False,
    split_rates=[0.90, 0.05, 0.05],
    cov_embedding_dimensions=256,
)

data_path = dataset_config["dataset_path"]

# Forward the loader-relevant settings straight from the config.
loader_kwargs = {
    key: dataset_config[key]
    for key in ("layer_key", "covariate_keys", "subsample_frac", "encoder_type")
}
dataset = RNAseqLoader(data_path=data_path, **loader_kwargs)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

# Quantities read off the loaded dataset and reused when building the models.
n_cat = len(dataset.id2cov["cell_type"])  # number of entries in the cell_type mapping
in_dim = dataset.X.shape[1]               # number of columns of the data matrix
size_factor_statistics = dict(
    mean=dataset.log_size_factor_mu,
    sd=dataset.log_size_factor_sd,
)



## Step 2: Initialize encoder

In [3]:
# Keyword arguments for the encoder MLP (layer widths, normalisation, dropout).
x0_from_x_kwargs = dict(
    dims=[512, 256, 50],
    batch_norm=True,
    dropout=False,
    dropout_p=0.0,
)

# Full encoder configuration, unpacked into EncoderModel further down.
encoder_config = dict(
    x0_from_x_kwargs=x0_from_x_kwargs,
    learning_rate=1e-3,
    weight_decay=1e-5,
    covariate_specific_theta=False,
)

# Checkpoint of the pre-trained autoencoder.
state_dict_path = "/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/train_autoencoder_pbmc_covid/21d30ca0-932d-4891-9796-938e9033b267/checkpoints/last.ckpt"

In [4]:
# Rebuild the autoencoder from the dataset dimensions and the encoder config.
encoder_model = EncoderModel(in_dim=in_dim,
                              n_cat=n_cat,
                              conditioning_covariate=dataset_config["conditioning_covariate"], 
                              encoder_type=dataset_config["encoder_type"],
                              **encoder_config)

# map_location="cpu" lets a checkpoint saved on a GPU be read on a CPU-only
# machine; load_state_dict then copies the weights into the module's own
# parameters, wherever they live.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
encoder_checkpoint = torch.load(state_dict_path, map_location="cpu")
encoder_model.load_state_dict(encoder_checkpoint["state_dict"])

# Switch to inference mode; as the last expression this also displays the module.
encoder_model.eval()

EncoderModel(
  (x0_from_x): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=2000, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
      (1): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
      (2): Linear(in_features=256, out_features=50, bias=True)
    )
  )
  (x_from_x0): MLP(
    (net): Sequential(
      (0): Sequential(
        (0): Linear(in_features=50, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
      )
      (1): Sequential(
        (0): Linear(in_features=256, out_features=512, bias=True)
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running

## Initialize FM model

In [5]:
# Extra keyword arguments unpacked into the FM constructor further down.
generative_model_config = dict(
    learning_rate=1e-4,
    weight_decay=1e-5,
    antithetic_time_sampling=True,
    sigma=1e-4,
)

In [6]:
# Load the trained FM checkpoint.
# map_location=device keeps the restored tensors and modules on the GPU when
# one is available and lets the same cell run on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects (whole modules are read
# from "hyper_parameters" below), so only load checkpoints from a trusted source.
ckpt = torch.load("/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/fm_resnet_autoencoder_pbmc_covid_whole_genome/a075cbb3-f9d8-4c34-bc3d-139ba928ba56/checkpoints/last.ckpt", map_location=device)

# The denoising network is stored as a pickled module in the hyper-parameters.
denoising_model = ckpt["hyper_parameters"]["denoising_model"]
denoising_model.multimodal = False

In [7]:
# Set the embed_size_factor flag on the restored denoising model before it is passed to FM.
# NOTE(review): presumably this makes the model condition on the size factor — confirm.
denoising_model.embed_size_factor = True

In [8]:
# Covariate embedding modules stored with the checkpoint's hyper-parameters.
feature_embeddings = ckpt["hyper_parameters"]["feature_embeddings"]
# Print the weight matrix of the cell_type embedding as a sanity check.
print(feature_embeddings["cell_type"].embeddings.weight)

Parameter containing:
tensor([[-2.6352, -0.3629, -0.4067,  ...,  0.2508, -0.0170, -0.7093],
        [ 2.3055, -0.2089,  0.3508,  ...,  0.0629,  1.0793,  1.6488],
        [ 1.5015,  1.1049, -0.8106,  ...,  0.3859, -0.1244,  1.3552],
        ...,
        [ 1.3194,  1.0029, -3.6609,  ...,  0.3767, -0.1811,  1.9171],
        [ 0.7573,  1.5707,  0.3263,  ...,  1.5110,  0.9219, -2.6991],
        [-0.5664,  0.6156, -0.3074,  ...,  1.2680,  0.8071,  1.5441]],
       device='cuda:0', requires_grad=True)


### Instantiate the FM model and load its weights

In [9]:
# Assemble the FM model from the pre-trained encoder, the restored denoising
# network and the restored covariate embeddings.
generative_model = FM(
            encoder_model=encoder_model,
            denoising_model=denoising_model,
            feature_embeddings=feature_embeddings,
            plotting_folder=None,
            in_dim=512,
            size_factor_statistics=size_factor_statistics,
            encoder_type=dataset_config["encoder_type"],
            conditioning_covariate=dataset_config["conditioning_covariate"],
            model_type=denoising_model.model_type, 
            **generative_model_config  # model_kwargs should contain the rest of the arguments
            )

# Restore the trained weights.
generative_model.load_state_dict(ckpt["state_dict"])
# Move to the device chosen in the first cell rather than a hard-coded "cuda",
# so the CPU fallback works. As the last expression this also displays the module.
generative_model.to(device)

/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'encoder_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder_model'])`.
/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'denoising_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['denoising_model'])`.


FM(
  (encoder_model): EncoderModel(
    (x0_from_x): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=2000, out_features=512, bias=True)
          (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=512, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (2): Linear(in_features=256, out_features=50, bias=True)
      )
    )
    (x_from_x0): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=50, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=256, out_features=512, bias=True)
          

## Read original data 

In [10]:
# Reload the training AnnData object directly from disk.
adata_original = sc.read_h5ad(data_path)
# Row sums of X stored as a per-cell size factor. `.toarray()` replaces the
# `.A` shorthand, which has been removed from SciPy sparse matrices.
adata_original.obs["size_factor"]=adata_original.X.toarray().sum(1)
# Dense tensor of the X_counts layer; `.toarray()` yields a plain ndarray
# instead of the np.matrix returned by `.todense()`.
X = torch.tensor(adata_original.layers["X_counts"].toarray())

In [11]:
# Display the cell-type annotation of every real cell.
adata_original.obs.cell_type

CV001_KM10202384-CV001_KM10202394_AAACCTGCAGATGGGT-1                           CD4-positive helper T cell
CV001_KM10202384-CV001_KM10202394_AAACCTGGTTACGCGC-1       central memory CD8-positive, alpha-beta T cell
CV001_KM10202384-CV001_KM10202394_AAACCTGGTTCGAATC-1                                         naive B cell
CV001_KM10202384-CV001_KM10202394_AAACCTGTCAGCACAT-1    naive thymus-derived CD8-positive, alpha-beta ...
CV001_KM10202384-CV001_KM10202394_AAACCTGTCCCATTAT-1    naive thymus-derived CD8-positive, alpha-beta ...
                                                                              ...                        
S28_TTTGTCAGTTCTGTTT-1                                                                natural killer cell
S28_TTTGTCATCAACCAAC-1                                                                 classical monocyte
S28_TTTGTCATCATTATCC-1                                                                 classical monocyte
S28_TTTGTCATCCTATGTT-1                        

## Generate and save cells

In [12]:
# Classes present in the real data and how many cells each one has.
unique_classes = np.unique(adata_original.obs.cell_type, return_counts=True)
class_names, class_counts = unique_classes
class_freq_dict = dict(zip(class_names, class_counts))

# Inverse-frequency weights normalised to sum to one, so that rarer classes
# receive a larger share of the generated cells.
inverse_counts = 1 / class_counts
norm_const = np.sum(inverse_counts)
class_prop = inverse_counts / norm_const
class_prop_dict = dict(zip(class_names, class_prop))

# Integer position -> class name, following the order returned by np.unique.
class_idx = dict(enumerate(class_names))

# Per-class generation budget: each weight times 400100, truncated to an
# integer. Truncation means the budgets can add up to slightly less than 400100.
to_gen = (class_prop*400100).astype(int)

In [13]:
# Expand the per-class budget into a flat list holding one class index per
# cell to generate, grouped by class in increasing index order.
samples = [class_index for class_index, n_cells in enumerate(to_gen) for _ in range(n_cells)]

In [14]:
# Total number of class indices produced by the per-class budget.
len(samples)

400086

In [15]:
# Keep the first 400000 entries. The list is grouped by class index, so any
# surplus is dropped from the end, i.e. from the highest class indices.
samples = samples[:400000]

## Conditional generation: sample conditions and size factors

In [16]:
# n_to_sample = 1000

In [17]:
# np.random.seed(42)
# samples = np.random.choice(range(len(class_idx)), size=n_to_sample, replace=True, p=sp.special.softmax(class_prop))

In [18]:
# Translate the sampled class indices into class names and count how many
# cells are requested per class.
condition_names = [class_idx[sample] for sample in samples]
condition_names_unique = np.unique(condition_names, return_counts=True)
condition_names_unique = dict(zip(condition_names_unique[0], condition_names_unique[1]))
# Sort the names so they line up with the size factors, which are built class
# by class in the same sorted order below.
condition_names = sorted(condition_names)

# Dedicated seeded generator so the resampled size factors (and hence the
# augmented dataset) are reproducible from a fresh kernel.
size_factor_rng = np.random.default_rng(42)

size_factor = []
for ct in sorted(condition_names_unique.keys()):
    # Real cells of this class.
    adata_original_c = adata_original[adata_original.obs["cell_type"]==ct]
    # Draw, with replacement, as many real cells as will be generated for this class.
    i = size_factor_rng.choice(len(adata_original_c), condition_names_unique[ct], replace=True)
    # Total X_counts per drawn cell; `.toarray()` replaces the `.A` shorthand,
    # which has been removed from SciPy sparse matrices.
    size_factor += list(adata_original_c[i].layers["X_counts"].toarray().sum(1))

# Log size factors as a column tensor, placed on the device chosen in the
# first cell instead of a hard-coded `.cuda()`.
size_factor = torch.log(torch.tensor(size_factor)).unsqueeze(1).to(device)

In [19]:
# Integer code of every requested condition, looked up in the dataset's cell_type mapping.
condition_val = torch.tensor([dataset.id2cov["cell_type"][condition_name] for condition_name in condition_names]).long()

# NOTE(review): batch_size * repetitions = 400000, which matches the number of
# requested conditions and size factors — presumably required by
# batched_sample; confirm before changing either value.
X_generated = generative_model.batched_sample(batch_size=100,
                                            repetitions=4000,
                                            n_sample_steps=2, 
                                            covariate="cell_type", 
                                            covariate_indices=condition_val, 
                                            log_size_factor=size_factor)

# Later cells concatenate X_generated with the CPU tensor X and call .numpy(),
# which requires a CPU tensor. Rebind the CPU copy to X_generated itself
# instead of parking it under a name nothing reads.
X_generated = X_generated.to("cpu")
# Kept as an alias so any code still referring to the old name keeps working.
X_generated_list = X_generated

## Merge and plot the generated and real data

In [20]:
# Observation table for the merged data: real cells first, generated cells after.
n_real = X.shape[0]
n_generated = X_generated.shape[0]
dataset_type = pd.DataFrame(
    {
        "dataset_type": ["Real"] * n_real + ["Generated"] * n_generated,
        "cell_type": list(adata_original.obs.cell_type) + condition_names,
    }
)

In [21]:
# Stack the real counts on top of the generated ones, in the same row order
# as the observation table, and wrap them in an AnnData object.
merged_counts = torch.cat([X, X_generated], dim=0).numpy()
adata_merged = sc.AnnData(X=merged_counts, obs=dataset_type)



In [22]:
# Reuse the variable annotation of the real dataset for the merged object.
adata_merged.var = adata_original.var.copy()
# Keep an untouched copy of the matrix under "X_counts" before X is transformed.
adata_merged.layers["X_counts"] = adata_merged.X.copy()

In [23]:
# Log-transform X in place (the "X_counts" layer keeps the untransformed values),
# then compute a PCA on the transformed matrix.
sc.pp.log1p(adata_merged)
sc.tl.pca(adata_merged)
# Neighbour graph and UMAP are currently disabled.
# sc.pp.neighbors(adata_merged)
# sc.tl.umap(adata_merged)

In [24]:
# sc.pl.pca(adata_merged, color="dataset_type",  annotate_var_explained=True)
# sc.pl.umap(adata_merged, color="dataset_type")
# sc.pl.umap(adata_merged, color="cell_type")

In [25]:
# sc.write("/home/icb/alessandro.palma/environment/celldreamer/project_folder/datasets/processed/classifier_experiment/augmented/pbmc_covid_augmented_prop.h5ad", adata_merged)

In [26]:
# Save the merged (real + generated) dataset to disk.
# NOTE(review): this is an absolute, user-specific path; consider deriving it from DATA_DIR.
# NOTE(review): the file name ends in "prop_10000" while 400000 cells are generated above — confirm the name.
adata_merged.write_h5ad("/home/icb/alessandro.palma/environment/cfgen/project_folder/datasets/processed/classifier_experiment_pbmc/augmented/rebuttals_augmentations/pbmc_covid_train_augmented_prop_10000.h5ad")

In [28]:
# Sanity check: total number of cells (real + generated) with a cell_type label.
adata_merged.obs.cell_type.value_counts().sum()

744820