In [1]:
# --- Imports -----------------------------------------------------------------
import random
from copy import deepcopy
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import torch
import yaml
from matplotlib import rc, rcParams
from torch import nn

from celldreamer.data.scrnaseq_loader import RNAseqLoader
from celldreamer.eval.eval_utils import join_real_generated, normalize_and_compute_metrics
from celldreamer.eval.optimal_transport import wasserstein
from celldreamer.models.base.encoder_model import EncoderModel
from celldreamer.models.base.utils import unsqueeze_right
from celldreamer.models.featurizers.category_featurizer import CategoricalFeaturizer
from celldreamer.models.fm.fm import FM
from celldreamer.paths import DATA_DIR

# --- Configuration -----------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook touches so that sampled cells are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Reset matplotlib *before* applying scanpy's figure parameters. Calling
# rcdefaults() afterwards (as this cell used to) silently discarded the
# rcParams that set_figure_params had just configured.
matplotlib.rcdefaults()
sc.set_figure_params(dpi=100, frameon=False, fontsize=12)

# Serif font for all figures. Uncomment the 'text' entry for LaTeX text rendering.
matplotlib_rc = {
    # 'text': {'usetex': True},
    'font': {'family': 'serif'},
}
for k, v in matplotlib_rc.items():
    rc(k, **v)

FIGSIZE = (3, 3)
rcParams["figure.figsize"] = FIGSIZE

KeyboardInterrupt: 

## Step 1: Initialize data

In [None]:
# Data configuration: NeurIPS test split, counts taken from the 'X_counts' layer,
# conditioned on cell type and donor.
dataset_config = dict(
    dataset_path=DATA_DIR / 'processed_full_genome' / 'neurips' / 'neurips_test.h5ad',
    layer_key='X_counts',
    covariate_keys=['cell_type', 'DonorNumber'],
    subsample_frac=1,
    encoder_type='learnt_autoencoder',
    one_hot_encode_features=False,
    split_rates=[0.90, 0.10],
    cov_embedding_dimensions=128,
    multimodal=False,
    is_binarized=False,
    theta_covariate='cell_type',
    size_factor_covariate='DonorNumber',
    guidance_weights={'cell_type': 1, 'DonorNumber': 1},
)

data_path = dataset_config["dataset_path"]

# Only the loader-relevant subset of the config is forwarded to RNAseqLoader.
dataset = RNAseqLoader(
    data_path=data_path,
    **{
        option: dataset_config[option]
        for option in ("layer_key", "covariate_keys", "subsample_frac",
                       "encoder_type", "multimodal", "is_binarized")
    },
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

# Dimensions and size-factor statistics read from the loaded dataset.
in_dim = dataset.X.shape[1]
size_factor_statistics = {
    "mean": dataset.log_size_factor_mu,
    "sd": dataset.log_size_factor_sd,
}
n_cat = len(dataset.id2cov["cell_type"])

## Step 2: Initialize encoder

In [None]:
# Architecture/optimizer settings of the pretrained autoencoder. These presumably
# must match the checkpoint below for load_state_dict to succeed — TODO confirm.
encoder_config = dict(
    x0_from_x_kwargs=dict(
        dims=[512, 256, 50],
        batch_norm=True,
        dropout=False,
        dropout_p=0.0,
    ),
    learning_rate=0.001,
    weight_decay=0.00001,
    covariate_specific_theta=False,
    multimodal=False,
    is_binarized=False,
)

# Trained autoencoder checkpoint (absolute, machine-specific path).
state_dict_path = "/home/icb/alessandro.palma/environment/cfgen/project_folder/experiments/off_train_autoencoder_neurips_whole_genome/eabc6534-947b-4486-8012-c9e351b297ca/checkpoints/epoch_59.ckpt"

In [None]:
# Rebuild the autoencoder with its training-time configuration and restore its weights.
encoder_model = EncoderModel(in_dim=in_dim,
                             n_cat=n_cat,
                             conditioning_covariate=dataset_config["theta_covariate"],
                             encoder_type=dataset_config["encoder_type"],
                             **encoder_config)

# Inference only: disables dropout and makes batch-norm use its running statistics.
encoder_model.eval()

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
encoder_model.load_state_dict(torch.load(state_dict_path, map_location=device)["state_dict"])

## Step 3: Initialize the FM model

In [None]:
generative_model_config = {'learning_rate': 0.0001,
                            'weight_decay': 0.00001,
                            'antithetic_time_sampling': True,
                            'sigma': 0.5}

In [None]:
ckpt = torch.load("/home/icb/alessandro.palma/environment/celldreamer/project_folder/experiments/GUIDED_MULTILAB_NEURIPS/04ddda45-0e7f-4d9c-84fa-00386d0e9668/checkpoints/last.ckpt")

denoising_model = ckpt["hyper_parameters"]["denoising_model"]
denoising_model.multimodal = False

In [None]:
print(ckpt["hyper_parameters"]["feature_embeddings"]["cell_type"].embeddings.weight)
feature_embeddings = ckpt["hyper_parameters"]["feature_embeddings"]

### Build the FM model and load its trained weights

In [None]:
generative_model = FM(
            encoder_model=encoder_model,
            denoising_model=denoising_model,
            feature_embeddings=feature_embeddings,
            plotting_folder=None,
            in_dim=50,
            size_factor_statistics=size_factor_statistics,
            covariate_list=dataset_config["covariate_keys"],
            theta_covariate=dataset_config["theta_covariate"],
            size_factor_covariate=dataset_config["size_factor_covariate"],
            model_type=denoising_model.model_type, 
            encoder_type=dataset_config["encoder_type"],
            multimodal=dataset_config["multimodal"],
            is_binarized=False,
            modality_list=None,
            guidance_weights={'cell_type': 1, 'DonorNumber': 1},
            **generative_model_config  # model_kwargs should contain the rest of the arguments
            )

generative_model.load_state_dict(ckpt["state_dict"])
generative_model.to("cuda")

In [None]:
generative_model.denoising_model.guided_conditioning = True

## Check the real data

In [None]:
adata_original = sc.read_h5ad(data_path)
adata_original.obs["size_factor"]=adata_original.X.A.sum(1)
adata_original.X = adata_original.layers["X_counts"].A.copy()
X = torch.tensor(adata_original.layers["X_counts"].todense())

In [None]:
sc.pl.umap(adata_original, color=["cell_type"], s=10)
sc.pl.umap(adata_original, color=["DonorNumber"], s=10)

In [None]:
sc.pl.umap(adata_original, color="cell_type", groups="CD4+ T activated")

## Generate a combination of conditions (cell type × donor)

In [None]:
# Inspect the default guidance weights stored on the model.
generative_model.guidance_weights

In [None]:
# Cell-type label -> integer id (these ids are used as covariate_indices when sampling).
dataset.id2cov["cell_type"]

In [None]:
# Donor label -> integer id (these ids are used as covariate_indices when sampling).
dataset.id2cov["DonorNumber"]

In [None]:
# Pick the condition to generate: one (cell type, donor) combination.
covariates_cell_types = 'CD14+ Mono'
covariates_donor = 'donor1'

# Integer ids of the chosen labels, as expected by the model.
condition_id_ct = dataset.id2cov["cell_type"][covariates_cell_types]
condition_id_donor = dataset.id2cov["DonorNumber"][covariates_donor]

# Real cells with this condition; their count sets how many cells to generate.
cond_mask = np.logical_and(adata_original.obs.cell_type == covariates_cell_types,
                           adata_original.obs.DonorNumber == covariates_donor)
adata_cond = adata_original[cond_mask]
n_to_generate = adata_cond.shape[0]
assert n_to_generate > 0, (
    f"No real cells for cell_type={covariates_cell_types!r}, DonorNumber={covariates_donor!r}"
)

# One covariate id per generated cell (integer tensors) ...
classes = {"cell_type": torch.full((n_to_generate,), condition_id_ct, dtype=torch.long),
           "DonorNumber": torch.full((n_to_generate,), condition_id_donor, dtype=torch.long)}

# ... and the matching string labels used when merging with the real data.
classes_str = {"cell_type": [covariates_cell_types] * n_to_generate,
               "DonorNumber": [covariates_donor] * n_to_generate}

# Log library size of every matching real cell, on the model's device
# (previously `.cuda()`, which fails on CPU-only machines).
log_size_factors = torch.log(torch.tensor(adata_cond.layers["X_counts"].sum(1))).to(device)

In [None]:
# Generate one cell per matching real cell with 100 sampling steps. Sampling is
# conditioned on both covariates and reuses the real cells' log library sizes.
X_generated = generative_model.sample(
    batch_size=n_to_generate,
    n_sample_steps=100,
    theta_covariate="cell_type",
    size_factor_covariate="cell_type",
    conditioning_covariates=["cell_type", "DonorNumber"],
    covariate_indices=classes,
    log_size_factor=log_size_factors,
).cpu()

# Wrap the generated counts in an AnnData for joint analysis with the real cells.
adata_generated = sc.AnnData(X=X_generated.numpy())

In [None]:
adata_merged = join_real_generated(adata_original, adata_generated, True, classes_str, ["cell_type", "DonorNumber"])

In [None]:
sc.pp.neighbors(adata_merged)
sc.tl.umap(adata_merged)

In [None]:
sc.pl.umap(adata_merged[adata_merged.obs.dataset_type=="Real"], color="cell_type")
sc.pl.umap(adata_merged[adata_merged.obs.dataset_type=="Real"], color="DonorNumber")
sc.pl.umap(adata_merged, color="dataset_type", groups="Generated", s=10)

In [None]:
# sc.pl.pca(adata_merged, color="dataset_type",  annotate_var_explained=True)
sc.pl.umap(adata_merged, color="dataset_type", groups="Generated",s=10)
sc.pl.umap(adata_merged, color="cell_type", groups="CD14+ Mono",s=10)
sc.pl.umap(adata_merged, color="DonorNumber", groups="donor1",s=10, palette="jet", add_outline=True)

## Experiment: varying the donor guidance weight

In [None]:
# Sweep the donor guidance weight for a fixed (cell type, donor) condition while
# keeping the cell-type guidance weight at 1.
covariates_cell_types = 'CD14+ Mono'
covariates_donor = 'donor1'
n_per_weight = 2000            # generated cells per guidance weight
donor_weights = [0, 1, 2, 5]
# NOTE(review): 2 sampling steps here vs. 100 in the single-condition run — confirm intended.
n_sample_steps_sweep = 2

# Integer ids of the chosen labels.
condition_id_ct = dataset.id2cov["cell_type"][covariates_cell_types]
condition_id_donor = dataset.id2cov["DonorNumber"][covariates_donor]

# Covariate ids and string labels share the same batch size. Previously the
# labels were sized by the number of matching real cells, not by the batch.
classes = {"cell_type": torch.full((n_per_weight,), condition_id_ct, dtype=torch.long),
           "DonorNumber": torch.full((n_per_weight,), condition_id_donor, dtype=torch.long)}
classes_str = {"cell_type": [covariates_cell_types] * n_per_weight,
               "DonorNumber": [covariates_donor] * n_per_weight}

donor_weight_data = {}
for donor_weight in donor_weights:
    # log_size_factor=None: no real library sizes are supplied for this sweep.
    # The real-cell log size factors this cell used to compute were never used
    # and required CUDA, so they are no longer computed.
    donor_weight_data[donor_weight] = generative_model.sample(
        batch_size=n_per_weight,
        n_sample_steps=n_sample_steps_sweep,
        theta_covariate="cell_type",
        size_factor_covariate="cell_type",
        conditioning_covariates=["cell_type", "DonorNumber"],
        covariate_indices=classes,
        log_size_factor=None,
        guidance_weights={'cell_type': 1, 'DonorNumber': donor_weight},
    ).to("cpu")

In [None]:
adata_joint_mat = [adata_original.layers["X_counts"].A.copy()]
weights = ["Real data" for _ in range(len(adata_joint_mat[0]))]
donor = list(adata_original.obs["DonorNumber"])
cell_type = list(adata_original.obs["cell_type"])


for w in donor_weight_data:
    adata_joint_mat.append(donor_weight_data[w])
    weights += [str(w) for _ in range(len(donor_weight_data[w]))]
    donor += ["g_donor1" for _ in range(len(donor_weight_data[w]))]
    cell_type += ["g_CD14+ Mono" for _ in range(len(donor_weight_data[w]))]

# adata_joint_mat.append(X_generated_uncond)
# weights += ["0_0" for _ in range(len(X_generated_uncond))]
# donor += ["g_donor1 0_0" for _ in range(len(X_generated_uncond))]
# cell_type += ["g_CD14+ 0_0" for _ in range(len(X_generated_uncond))]

In [None]:
obs = pd.DataFrame(weights)
obs.columns = ["guidance weight"]
obs["donor"] = donor
obs["cell_type"] = cell_type

In [None]:
adata_joint = sc.AnnData(X=np.concatenate(adata_joint_mat,axis=0),
                        obs=obs)

In [None]:
sc.pp.normalize_total(adata_joint, target_sum=1e4)
sc.pp.log1p(adata_joint)
sc.tl.pca(adata_joint)
sc.pp.neighbors(adata_joint)
sc.tl.umap(adata_joint)

In [None]:
# Joint UMAP of real and generated cells, coloured by each annotation in turn.
for obs_key in ("cell_type", "donor", "guidance weight"):
    sc.pl.umap(adata_joint, color=obs_key)

In [None]:
adata_joint.uns["cell_type_colors"][15]="steelblue"
sc.pl.umap(adata_joint, color='cell_type', groups="CD14+ Mono",s=8,save="real_neurips_cd14.png")

In [None]:
adata_joint.uns["donor_colors"][0]="steelblue"
sc.pl.umap(adata_joint, color='donor', groups="donor1",s=8,save="real_neurips_d1.png")

In [None]:
adata_joint.uns["guidance weight_colors"][0] = "firebrick"
sc.pl.umap(adata_joint, color='guidance weight', groups="0", save="neurips_cd14_d1_0.png")

In [None]:
adata_joint.uns["guidance weight_colors"][1] = "firebrick"
sc.pl.umap(adata_joint, color='guidance weight', groups="1", save="neurips_cd14_d1_1.png")

In [None]:
adata_joint.uns["guidance weight_colors"][2] = "firebrick"
sc.pl.umap(adata_joint, color='guidance weight', groups="2", save="neurips_cd14_d1_2.png")

In [None]:
adata_joint.uns["guidance weight_colors"][3] = "firebrick"
sc.pl.umap(adata_joint, color='guidance weight', groups="5", save="neurips_cd14_d1_5.png")

## Unconditional generation (all guidance weights set to 0)

In [None]:
# Baseline with both guidance weights set to 0. The covariate indices for the
# chosen condition are still passed to the sampler.
classes = {
    covariate: torch.full((12000,), cov_id, dtype=torch.long)
    for covariate, cov_id in (("cell_type", condition_id_ct), ("DonorNumber", condition_id_donor))
}
X_generated_uncond = generative_model.sample(
    batch_size=12000,
    n_sample_steps=2,
    theta_covariate="cell_type",
    size_factor_covariate="cell_type",
    conditioning_covariates=["cell_type", "DonorNumber"],
    covariate_indices=classes,
    log_size_factor=None,
    guidance_weights={'cell_type': 0, 'DonorNumber': 0},
).cpu()

In [None]:
adata_joint_uncond.obs["dataset_type"] = np.where(adata_joint_uncond.obs["dataset_type"]=="gen", "Generated", "Real")

In [None]:
gener_uncond = sc.AnnData(X = X_generated_uncond.numpy())

In [None]:
adata_joint_uncond = sc.AnnData(X=np.concatenate([adata_original.layers["X_counts"].A, gener_uncond.X], axis=0), 
                               obs={"dataset_type":["real" for _ in range(len(adata_original))]+["gen" for _ in range(len(gener_uncond))]})

In [None]:
sc.pp.normalize_total(adata_joint_uncond, target_sum=1e4)
sc.pp.log1p(adata_joint_uncond)

In [None]:
sc.tl.pca(adata_joint_uncond)
sc.pp.neighbors(adata_joint_uncond)
sc.tl.umap(adata_joint_uncond)

In [None]:
sc.pl.umap(adata_joint_uncond, color="dataset_type", palette="tab20", save="neurips_0_0.png")