# Tutorial: initialization and use of the conditional DDPM

In [3]:
import torch
from celldreamer.data.utils import Args
from celldreamer.models.base.autoencoder import MLP_AutoEncoder
from celldreamer.models.diffusion.denoising_model import MLPTimeStep
from celldreamer.models.diffusion.conditional_ddpm import ConditionalGaussianDDPM

Define the hyperparameters of the autoencoder and the denoising module. The DDPM hyperparameters are passed directly when the DDPM class is constructed further below.

In [4]:
# Reproducibility: the torch.nn.Linear layers created by the model cells below
# are randomly initialised by torch, so seed the global torch RNG once here,
# before any module is built, to make "Restart & Run All" reproducible.
SEED = 0
torch.manual_seed(SEED)

args = Args({
    # Hyperparameters of the MLP autoencoder (architecture + optimisation).
    "autoencoder_kwargs": {
        "in_dim": 2000,
        "batch_size": 32,
        "hidden_dim_encoder": [256, 128, 64],
        # NOTE(review): the decoder does not mirror the encoder ([256, 128, 64]);
        # confirm [64, 128, 64] is intended rather than [64, 128, 256].
        "hidden_dim_decoder": [64, 128, 64],
        "batch_norm": True,
        "layer_norm": False,
        "activation": torch.nn.ReLU,
        "output_activation": torch.nn.Identity,
        "reconst_loss": "mse",
        "dropout": 0.0,
        "weight_decay": 0.1,
        "learning_rate": 0.001,
        "optimizer": torch.optim.Adam,
        "lr_scheduler": None,
        "lr_scheduler_kwargs": None,
    },
    # Hyperparameters of the time- and class-conditioned denoising MLP.
    # In the printed model below, the class-embedding MLP takes 3 + 250 = 253
    # inputs (the sizes in num_classes), and the main net takes
    # in_dim + time_embed_size = 2100 inputs.
    "denoising_module_kwargs": {
        "in_dim": 2000,
        "dims": [256, 128, 64],
        "time_embed_size": 100,
        "num_classes": {"cell_type": 3, "drug": 250},
        "class_emb_size": 100,
        "dropout": 0.0,
    },
})

Initialize the modules

In [5]:
# Build the autoencoder from its entry in the config dict; the bare name on
# the final line is what Jupyter renders as this cell's output (module tree).
autoencoder = MLP_AutoEncoder(
    **args.autoencoder_kwargs
)
autoencoder

MLP_AutoEncoder(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (encoder): MLP(
    (0): Linear(in_features=2000, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.0, inplace=False)
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): BatchNorm1d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.0, inpl

Denoising model

In [6]:
# Build the denoising network from its entry in the config dict; the bare name
# on the final line is displayed by Jupyter as the cell output.
denoising_model = MLPTimeStep(
    **args.denoising_module_kwargs
)
denoising_model

MLPTimeStep(
  (encoder): ModuleList(
    (0): MLPTimeEmbedCond(
      (linear_map_class): Sequential(
        (0): Linear(in_features=253, out_features=100, bias=True)
        (1): ReLU()
        (2): Linear(in_features=100, out_features=100, bias=True)
      )
      (net): Sequential(
        (0): Linear(in_features=2100, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=256, bias=True)
      )
      (l_embedding): Sequential(
        (0): GELU(approximate='none')
        (1): Linear(in_features=100, out_features=256, bias=True)
      )
      (relu): ReLU()
      (out_layer): Sequential(
        (0): GELU(approximate='none')
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=256, out_features=256, bias=True)
      )
      (skip_connection): Sequential(
        (0): Linear(in_features=2100, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, ou

DDPM class: combine the denoising model and the autoencoder into a conditional Gaussian DDPM

In [7]:
# Assemble the conditional DDPM from the denoising model and autoencoder built
# above; the bare name on the last line displays the resulting module tree.
generative_model = ConditionalGaussianDDPM(
    denoising_model=denoising_model,
    autoencoder_model=autoencoder,
    task="perturbation_modelling",
    feature_embeddings=None,
    T=4000,  # default: 4_000
    w=0.3,  # default: 0.3
    v=0.1,
    # Derived from the config instead of hard-coding 2, so it stays in sync with
    # the covariates the denoising model embeds ("cell_type" and "drug" today).
    # NOTE(review): assumes n_covariates counts the keys of num_classes — confirm.
    n_covariates=len(args.denoising_module_kwargs["num_classes"]),
    # NOTE(review): with classifier_free=False, p_uncond (and possibly w) may be
    # unused — confirm against ConditionalGaussianDDPM.
    p_uncond=0.1,
    logging_freq=5,
    classifier_free=False,
    optimizer=torch.optim.Adam,
)

generative_model

ConditionalGaussianDDPM(
  (denoising_model): MLPTimeStep(
    (encoder): ModuleList(
      (0): MLPTimeEmbedCond(
        (linear_map_class): Sequential(
          (0): Linear(in_features=253, out_features=100, bias=True)
          (1): ReLU()
          (2): Linear(in_features=100, out_features=100, bias=True)
        )
        (net): Sequential(
          (0): Linear(in_features=2100, out_features=256, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=256, out_features=256, bias=True)
        )
        (l_embedding): Sequential(
          (0): GELU(approximate='none')
          (1): Linear(in_features=100, out_features=256, bias=True)
        )
        (relu): ReLU()
        (out_layer): Sequential(
          (0): GELU(approximate='none')
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=256, out_features=256, bias=True)
        )
        (skip_connection): Sequential(
          (0): Linear(in_features=2100, out_features