In [1]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
import torch
import yaml
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.model_summary import ModelSummary

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

from celldreamer.paths import ROOT
from celldreamer.estimator.celldreamer_estimator import CellDreamerEstimator
from celldreamer.paths import DATA_DIR
from celldreamer.data.utils import Args

In [4]:
cd $ROOT

/home/icb/till.richter/git/celldreamer


Load configuration 

In [5]:
config = yaml.safe_load(open(ROOT / "configs/toy/config_ddpm.yaml", 
                            "rb"))
args_toy = Args(config["args"]) 

Initialize estimator 

In [6]:
estimator = CellDreamerEstimator(args_toy)

Create the training folders...
Initialize data module...




GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Initialize feature embeddings...
Initialize model...


In [7]:
estimator.generative_model

ConditionalGaussianDDPM(
  (denoising_model): UNetTimeStepClassSetConditioned(
    (downsample_blocks): ModuleList(
      (0): ResBlockTimeEmbedCond(
        (linear_map_class): Identity()
        (time_embed_net): Sequential(
          (0): Linear(in_features=100, out_features=32, bias=True)
          (1): SELU()
          (2): Linear(in_features=32, out_features=32, bias=True)
        )
        (conv): Sequential(
          (0): GroupNorm(3, 3, eps=1e-05, affine=True)
          (1): GELU(approximate='none')
          (2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (relu): ReLU()
        (l_embedding): Sequential(
          (0): GELU(approximate='none')
          (1): Linear(in_features=100, out_features=32, bias=True)
        )
        (out_layer): Sequential(
          (0): GroupNorm(4, 32, eps=1e-05, affine=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Conv2d(32, 32, kernel_size=(3, 

In [11]:
# sample images from the toy dataset
shapes = ["circle", "square"]
colors = ["lightblue", "lightgreen", "lightyellow", "lightgray"]
sizes = ["small", "medium", "large"]
positions = ["topleft", "topright", "bottomleft", "bottomright"]
for data in estimator.datamodule.train_dataloader():
    images = data['X']
    y = data['y']
    shape_labels = y['y_shapes']
    color_labels = y['y_colors']
    size_labels = y['y_sizes']
    position_labels = y['y_positions']
    # create a matplotlib figure with 4x4 images with the title of each sub-image being the set of labels
    fig, ax = plt.subplots(4, 4, figsize=(10, 10))
    # increase distance between subplots
    fig.subplots_adjust(hspace=0.6, wspace=0.6)
    for i in range(4):
        for j in range(4):
            ax[i, j].imshow(images[i*4+j].permute(1, 2, 0))
            ax[i, j].set_title(f'{sizes[size_labels[i*4+j]]} {colors[color_labels[i*4+j]]} \n {shapes[shape_labels[i*4+j]]} at {positions[position_labels[i*4+j]]}')
    break

NameError: name 'plt' is not defined

Train model

In [8]:
estimator.train()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                            | Params
--------------------------------------------------------------------
0 | denoising_model | UNetTimeStepClassSetConditioned | 1.3 M 
1 | mse             | MSELoss                         | 0     
--------------------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.070     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

TypeError: 'method' object is not iterable

In [None]:
estimator.generative_model = estimator.generative_model.to("cuda")
estimator.generative_model.denoising_model = estimator.generative_model.denoising_model.to("cuda")

In [None]:
# ckpt = torch.load("/nfs/students/pala/celldreamer/try_experiment_pbmc/checkpoints/epoch_270.ckpt")
# estimator.generative_model.load_state_dict(ckpt["state_dict"])

**Generate**

In [None]:
T = estimator.generative_model.T
T

In [None]:
vec = torch.randn(10, 50).to("cuda")
t1 = 1000*torch.ones(10).to("cuda")
t2 = 1*torch.ones(10).to("cuda")

In [None]:
estimator.generative_model.denoising_model(vec, t1, None)

In [None]:
estimator.generative_model.denoising_model(vec.to('cuda'), t2.to('cuda'), None)

**Check timestep embedding**

In [None]:
# X_gen = estimator.generative_model.sample(batch_size=1000,
#                                              y=None, 
#                                    `          return_all_timesteps=False,
#                                              clip_denoised=True)

X_gen= estimator.generative_model.ddim_sample(batch_size=1000, 
                      y=None, 
                      return_all_timesteps = False, 
                      ddim_sampling_eta=0)

In [None]:
X_gen

**Plot generated**

In [None]:
adata_tmp = sc.AnnData(X=X_gen.detach().cpu().numpy())
sc.tl.pca(adata_tmp)
sc.pp.neighbors(adata_tmp)
sc.tl.umap(adata_tmp)

In [None]:
sc.pl.umap(adata_tmp)

In [None]:
d = []

for batch in estimator.datamodule.train_dataloader:
    d.append(batch["X"])
    
d = torch.cat(d, dim=0)
# d = torch.clip(d, -3,3)

In [None]:
adata = sc.AnnData(X = np.concatenate([X_gen.detach().cpu().numpy(), d.cpu().numpy()]),
                   obs = pd.DataFrame({"type":["gen"]*len(X_gen)+["real"]*len(d)}))

In [None]:
sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)

In [None]:
sc.pl.pca(adata, color="type")

In [None]:
sc.pl.umap(adata, color="type")

In [None]:
d.mean()

In [None]:
X_gen.mean()

In [None]:
d.max()

In [None]:
d.max()

In [None]:
d.min()

In [None]:
d.mean()

In [None]:
d.min()

In [None]:
# d = 2 * (d - d.min(1).values.unsqueeze(-1)) / (d.max(1).values.unsqueeze(-1) - d.min(1).values.unsqueeze(-1)) - 1

In [None]:
# d = (d - d.min(1).values.unsqueeze(-1)) / (d.max(1).values.unsqueeze(-1) - d.min(1).values.unsqueeze(-1)) 

In [None]:
# d