In [1]:
import numpy as np
import pandas as pd
import yaml
import torch
from copy import deepcopy
import matplotlib.pyplot as plt
from celldreamer.eval.eval_utils import normalize_and_compute_metrics
import scipy.sparse as sp
import scipy
from tqdm import tqdm 

from torch import nn
import pickle
import scanpy as sc    

from celldreamer.data.scrnaseq_loader import RNAseqLoader
from celldreamer.models.featurizers.category_featurizer import CategoricalFeaturizer
from celldreamer.models.fm.fm import FM
from celldreamer.eval.optimal_transport import wasserstein
import random
from celldreamer.models.base.encoder_model import EncoderModel
from celldreamer.models.base.utils import unsqueeze_right

from celldreamer.paths import DATA_DIR

# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
device  = "cuda" if torch.cuda.is_available() else "cpu"

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Step 1: Initialize data

In [5]:
# Settings describing the HLCA training split and how it is conditioned.
dataset_config = {
    'dataset_path': DATA_DIR / 'processed' / 'classifier_experiment_hlca' / 'hlca_train.h5ad',
    'layer_key': 'X_counts',
    'covariate_keys': ['cell_type'],
    'conditioning_covariate': 'cell_type',
    'subsample_frac': 1,
    'encoder_type': 'learnt_autoencoder',
    'n_dimensions': None,
    'one_hot_encode_features': True,
    'split_rates': [0.90, 0.05, 0.05],
    'cov_embedding_dimensions': None,
}

data_path = dataset_config["dataset_path"]

# Only these entries of the config are forwarded to the loader, in this order.
loader_keys = ("layer_key", "covariate_keys", "subsample_frac", "encoder_type")
dataset = RNAseqLoader(data_path=data_path,
                       **{key: dataset_config[key] for key in loader_keys})

# Quantities derived from the loaded dataset and reused by the model cells below.
id2cov = dataset.id2cov
n_cat = len(dataset.id2cov["cell_type"])
in_dim = dataset.X.shape[1]
size_factor_statistics = dict(mean=dataset.log_size_factor_mu,
                              sd=dataset.log_size_factor_sd)



## Step 2: Initialize encoder

In [6]:
# Keyword arguments unpacked into the EncoderModel constructor.
encoder_config = dict(
    x0_from_x_kwargs=dict(
        dims=[512, 256, 100],
        batch_norm=True,
        dropout=False,
        dropout_p=0.0,
    ),
    learning_rate=1e-3,
    weight_decay=1e-5,
    covariate_specific_theta=False,
)

# Checkpoint file whose "state_dict" entry is loaded into the encoder.
state_dict_path = "/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/CLASSIFICATION_off_train_autoencoder_hlca_core/3dde9d9c-2fbc-4417-bdd7-9e8ebe1c948b/checkpoints/last.ckpt"

In [7]:
# Rebuild the encoder from the dataset dimensions and the encoder settings, and
# switch it to evaluation mode before restoring its weights.
encoder_model = EncoderModel(in_dim=in_dim,
                              n_cat=n_cat,
                              conditioning_covariate=dataset_config["conditioning_covariate"],
                              encoder_type=dataset_config["encoder_type"],
                              **encoder_config)
encoder_model.eval()

# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU can still be read when the notebook has fallen back to the CPU.
encoder_model.load_state_dict(torch.load(state_dict_path, map_location=device)["state_dict"])

<All keys matched successfully>

## Step 3: Initialize FM model

In [8]:
# Keyword arguments unpacked into the FM constructor.
generative_model_config = dict(
    learning_rate=1e-4,
    weight_decay=1e-5,
    antithetic_time_sampling=True,
    sigma=1e-4,
)

In [9]:
# Load the generative-model checkpoint. map_location remaps the stored tensors
# onto `device`, so the file can be read on a CPU-only host as well.
# NOTE: torch.load unpickles whole Python objects here (the denoising model is
# stored under "hyper_parameters"), so only open checkpoints from a trusted source.
ckpt = torch.load("/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/CLASSIFIER_off_fm_resnet_autoencoder_hlca_core_whole_genome/b32358e8-ce61-4095-aa92-a9a153cbc8fb/checkpoints/last.ckpt",
                  map_location=device)

# Reuse the denoising network object saved with the checkpoint and turn off its
# multimodal flag.
denoising_model = ckpt["hyper_parameters"]["denoising_model"]
denoising_model.multimodal = False

In [10]:
# Show every hyper-parameter stored in the checkpoint; the entries include whole
# module objects, so the printed output is long.
ckpt["hyper_parameters"]

{'encoder_model': EncoderModel(
   (x0_from_x): MLP(
     (net): Sequential(
       (0): Sequential(
         (0): Linear(in_features=2000, out_features=512, bias=True)
         (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): ELU(alpha=1.0)
       )
       (1): Sequential(
         (0): Linear(in_features=512, out_features=256, bias=True)
         (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): ELU(alpha=1.0)
       )
       (2): Linear(in_features=256, out_features=100, bias=True)
     )
   )
   (x_from_x0): MLP(
     (net): Sequential(
       (0): Sequential(
         (0): Linear(in_features=100, out_features=256, bias=True)
         (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (2): ELU(alpha=1.0)
       )
       (1): Sequential(
         (0): Linear(in_features=256, out_features=512, bias=True)
         (1): BatchNorm1d(512, eps=1e

In [11]:
# Keep the covariate embeddings saved with the checkpoint, then print the
# weight matrix of the cell-type embedding.
feature_embeddings = ckpt["hyper_parameters"]["feature_embeddings"]
print(feature_embeddings["cell_type"].embeddings.weight)

Parameter containing:
tensor([[-0.6750,  1.5693, -0.0830,  ..., -3.1055,  1.3521,  1.8101],
        [ 0.4216,  0.1655, -3.0730,  ..., -0.9761, -0.6504, -0.3432],
        [ 1.1484,  1.6554,  0.4003,  ..., -0.2824,  1.6959,  0.0993],
        ...,
        [ 0.8393,  0.4958, -2.2871,  ...,  1.3499,  0.6753, -0.8039],
        [-0.4984,  0.1710, -1.4032,  ...,  1.5160, -0.7917, -0.8067],
        [ 0.7541,  1.8842,  2.0116,  ...,  0.8489,  0.4313,  0.9325]],
       device='cuda:0', requires_grad=True)


Initializations

In [12]:
# Assemble the FM model from the restored encoder, the denoising network and the
# covariate embeddings taken from the checkpoint.
# NOTE(review): in_dim is hard-coded to 512 here, while the dataset-derived
# `in_dim` is dataset.X.shape[1] — confirm this value is intended.
generative_model = FM(
            encoder_model=encoder_model,
            denoising_model=denoising_model,
            feature_embeddings=feature_embeddings,
            plotting_folder=None,
            in_dim=512,
            size_factor_statistics=size_factor_statistics,
            encoder_type=dataset_config["encoder_type"],
            conditioning_covariate=dataset_config["conditioning_covariate"],
            model_type=denoising_model.model_type,
            **generative_model_config  # learning_rate, weight_decay, antithetic_time_sampling, sigma
            )

generative_model.load_state_dict(ckpt["state_dict"])
# Move to the `device` chosen at the top of the notebook rather than a
# hard-coded "cuda", so the cell does not fail when no GPU is available.
generative_model.to(device)

/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'encoder_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder_model'])`.
/home/icb/alessandro.palma/miniconda3/envs/celldreamer/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'denoising_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['denoising_model'])`.


FM(
  (encoder_model): EncoderModel(
    (x0_from_x): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=2000, out_features=512, bias=True)
          (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=512, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (2): Linear(in_features=256, out_features=100, bias=True)
      )
    )
    (x_from_x0): MLP(
      (net): Sequential(
        (0): Sequential(
          (0): Linear(in_features=100, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ELU(alpha=1.0)
        )
        (1): Sequential(
          (0): Linear(in_features=256, out_features=512, bias=True)
        

In [13]:
# Drop the notebook-level names for the checkpoint and the encoder; the encoder
# was passed to the FM constructor and is still referenced by generative_model.
del ckpt, encoder_model

## Read original data 

In [14]:
adata_original = sc.read_h5ad(data_path)
# Row totals are reduced on the stored matrix itself; np.asarray(...).ravel()
# flattens the result to 1-D without first building a dense copy of X through
# the legacy `.A` shorthand.
# NOTE(review): this sums `.X`, whose row maxima shown further down are ~3.5,
# whereas the size factors used for sampling are taken from layers["X_counts"]
# — confirm which matrix this column is meant to summarise.
adata_original.obs["size_factor"] = np.asarray(adata_original.X.sum(axis=1)).ravel()
# Dense tensor of the count layer; .toarray() yields a plain ndarray rather than
# the np.matrix returned by .todense().
X = torch.tensor(adata_original.layers["X_counts"].toarray())

In [15]:
# Display a summary of the loaded AnnData object.
adata_original

AnnData object with n_obs × n_vars = 429918 × 2000
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'dataset', 'entropy_dataset_leiden_3', 'entropy_original_ann_level_1_leiden_3', 'entropy_original_ann_level_2_clean_leiden_3', 'entropy_original_ann_level_3_clean_leiden_3', 'entropy_subject_ID_leiden_3', 'fresh_or_frozen', 'leiden_1', 'leiden_2', 'leiden_3', 'leiden_4', 'leiden_5', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'n_genes_detected', 'original_ann_highest_res', 'original_ann_level_1', 'o

In [16]:
# Per-cell maximum of X, reduced on the sparse matrix so that only the resulting
# single column is densified instead of the whole matrix.
adata_original.X.max(axis=1).toarray().ravel()

array([3.534778 , 3.1689792, 3.0558972, ..., 3.1154528, 3.119567 ,
       3.487646 ], dtype=float32)

## Generate and save cells

In [17]:
# unique_classes = np.unique(adata_original.obs.cell_type, return_counts=True)
# class_freq_dict = dict(zip(unique_classes[0], unique_classes[1]))
# norm_const = np.sum(1/unique_classes[1])
# class_prop_dict = dict(zip(unique_classes[0], (1/unique_classes[1])/norm_const))
# class_prop = (1/unique_classes[1])/norm_const
# class_idx = dict(zip(range(len(unique_classes[0])), unique_classes[0]))

In [18]:
# Cell types present in the real data and how many cells each one has.
unique_classes = np.unique(adata_original.obs.cell_type, return_counts=True)
class_names, class_counts = unique_classes
class_freq_dict = dict(zip(class_names, class_counts))

# Sampling proportions are inversely proportional to class frequency, so the
# rarer a cell type is, the larger its share of generated cells.
inverse_freq = 1 / class_counts
norm_const = np.sum(inverse_freq)
class_prop = inverse_freq / norm_const
class_prop_dict = dict(zip(class_names, class_prop))
class_idx = dict(enumerate(class_names))

# Per-class number of cells to generate, out of a budget of 300100; the cast
# to int truncates each entry, so the total can fall slightly below the budget.
to_gen = (class_prop * 300100).astype(int)

In [20]:
# One class id per cell to generate, grouped by class in increasing id order.
samples = [class_id
           for class_id, n_cells in enumerate(to_gen)
           for _ in range(n_cells)]

In [21]:
# Keep at most 300000 labels (batch_size * repetitions in the sampling call
# further down); any surplus is dropped from the end of the list, i.e. from the
# highest class ids.
samples = samples[:300000]

## Conditional generation

In [23]:
# Translate the sampled class ids into cell-type names and count how many cells
# are requested for each type.
condition_names = [class_idx[sample] for sample in samples]
condition_names_unique = np.unique(condition_names, return_counts=True)
condition_names_unique = dict(zip(condition_names_unique[0], condition_names_unique[1]))
# Sorted so the labels follow the same cell-type order as the size factors
# appended in the loop below.
condition_names = sorted(condition_names)

# For each cell type, resample (with replacement) the total counts of real cells
# of that type, one value per cell to generate.
# NOTE(review): np.random is not seeded in the visible cells, so the draw
# differs between runs.
size_factor = []
for ct in sorted(condition_names_unique.keys()):
    adata_original_c = adata_original[adata_original.obs["cell_type"]==ct]
    # Row totals are reduced on the count layer itself (no dense copy through
    # the legacy `.A` shorthand) and computed once per real cell.
    totals_c = np.asarray(adata_original_c.layers["X_counts"].sum(axis=1)).ravel()
    i = np.random.choice(len(adata_original_c), condition_names_unique[ct], replace=True)
    size_factor += list(totals_c[i])

# Log totals as a column tensor, placed on the `device` selected at the top of
# the notebook instead of a hard-coded .cuda().
size_factor = torch.log(torch.tensor(size_factor)).unsqueeze(1).to(device)

In [None]:
# GENERATE REAL
# ids = np.random.choice(range(len(adata_original)), size=n_to_sample)
# condition_names = list(adata_original[ids].obs.cell_type)
# log_lib_sizes = torch.log(torch.tensor(adata_original[ids].layers["X_counts"].sum(1))).cuda()

In [24]:
# Look up the integer code of every requested cell type.
cell_type_codes = dataset.id2cov["cell_type"]
condition_val = torch.tensor([cell_type_codes[name] for name in condition_names]).long()

# Sample 3000 batches of 100 cells, conditioned on the cell-type codes and the
# resampled log size factors.
X_generated = generative_model.batched_sample(
    covariate="cell_type",
    covariate_indices=condition_val,
    log_size_factor=size_factor,
    batch_size=100,
    repetitions=3000,
    n_sample_steps=2,
)

In [25]:
# sc.pp.normalize_total(adata_generated, target_sum=1e4)
# sc.pp.log1p(adata_generated)
# sc.tl.pca(adata_generated)
# sc.pp.neighbors(adata_generated)
# sc.tl.umap(adata_generated)

In [26]:
# adata_generated.obs["size_factor"] = adata_generated.X.sum(1)
# sc.pl.pca(adata_generated, color="size_factor", annotate_var_explained=True)
# sc.pl.umap(adata_generated, color="size_factor")

## Preprocess real data object

In [27]:
# sc.pp.normalize_total(adata_original, target_sum=1e4)
# sc.pp.log1p(adata_original)
# sc.tl.pca(adata_original)
# sc.pp.neighbors(adata_original)
# sc.tl.umap(adata_original)

In [28]:
# sc.pl.pca(adata_original, color=["size_factor", "cell_type"], annotate_var_explained=True)
# sc.pl.umap(adata_original, color=["size_factor", "cell_type"])

In [29]:
# adata_generated.write_h5ad("/home/icb/alessandro.palma/environment/celldreamer/project_folder/datasets/generated/hlca_core.h5ad")

## Merge and plot the generated and real data

In [30]:
# Per-cell annotations for the merged object: origin of each row (real cells
# first, generated cells after) and its cell type.
dataset_type = pd.DataFrame({
    "dataset_type": ["Real"] * X.shape[0] + ["Generated"] * X_generated.shape[0],
    "cell_type": list(adata_original.obs.cell_type) + condition_names,
})

In [31]:
# Real counts stacked on top of the generated ones, row-aligned with the
# `dataset_type` annotation table.
adata_merged = sc.AnnData(
    obs=dataset_type,
    X=torch.cat((X, X_generated), dim=0).numpy(),
)



In [32]:
# Reuse the gene annotations of the real dataset for the merged object.
adata_merged.var = adata_original.var.copy()
# Keep a copy of the current X under "X_counts" before X is transformed.
adata_merged.layers["X_counts"] = adata_merged.X.copy()

In [33]:
# Log-transform the merged matrix and compute its PCA; a copy of the matrix
# taken before this step is kept in layers["X_counts"].
# NOTE(review): no normalize_total step runs before log1p here — confirm that
# this is intended.
sc.pp.log1p(adata_merged)
sc.tl.pca(adata_merged)
# sc.pp.neighbors(adata_merged)
# sc.tl.umap(adata_merged)

In [34]:
# sc.pl.pca(adata_merged, color="dataset_type",  annotate_var_explained=True)
# sc.pl.umap(adata_merged, color="dataset_type")
# sc.pl.umap(adata_merged, color="cell_type")

In [35]:
# Save the merged (real + generated) object to disk.
# NOTE(review): the target is an absolute, user-specific path, and the file name
# says "prop_10000" although 300000 generated cells are appended — confirm both.
adata_merged.write_h5ad("/home/icb/alessandro.palma/environment/cfgen/project_folder/datasets/processed/classifier_experiment_hlca/augmented/rebuttals_augmentations/hlca_train_augmented_prop_10000.h5ad")

In [36]:
# Display a summary of the merged AnnData object.
adata_merged

AnnData object with n_obs × n_vars = 729918 × 2000
    obs: 'dataset_type', 'cell_type'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'log1p', 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
    layers: 'X_counts'

In [37]:
# Number of cells per cell type in the merged (real + generated) object.
adata_merged.obs.cell_type.value_counts()

cell_type
hematopoietic stem cell                            73584
respiratory basal cell                             66788
alveolar macrophage                                47892
type II pneumocyte                                 41558
lung neuroendocrine cell                           29656
club cell                                          29068
nasal mucosa goblet cell                           27550
brush cell of trachebronchial tree                 25832
ciliated columnar cell of tracheobronchial tree    25028
elicited macrophage                                22538
CD8-positive, alpha-beta T cell                    21610
mesothelial cell                                   19426
capillary endothelial cell                         19107
CD4-positive, alpha-beta T cell                    16209
natural killer cell                                14896
conventional dendritic cell                        14849
fibroblast                                         14364
classical monocyte   