# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. The largest model, DiT-XL/2, outperforms all prior diffusion models on the class-conditional ImageNet 256x256 and 512x512 benchmarks, reaching an FID of 2.27 at 256x256.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](https://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)

# 1. Setup

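We recommend using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the DiT GitHub repo and set up PyTorch; you only have to run it once. The clone and install commands are commented out, so uncomment them when running in a fresh environment.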

```python
# One-time setup (uncomment in a fresh environment):
# !git clone https://github.com/facebookresearch/DiT.git
# import DiT, os
# os.chdir('DiT')
# os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
# !pip install diffusers timm --upgrade

# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display

torch.set_grad_enabled(False)  # inference only, so no autograd graph is needed
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")
```

GPU not found. Using CPU instead.


# Download DiT-XL/2 Models

You can choose between a 512x512 model and a 256x256 model. You can also swap out the LDM VAE (`sd-vae-ft-ema` or `sd-vae-ft-mse`).
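The image size determines the shapes the transformer actually works on. A minimal sketch of that arithmetic, assuming the VAE downsamples each side by 8 (as `latent_size = int(image_size) // 8` in the next cell does), the latents have 4 channels, and DiT-XL/2 splits the latent into 2x2 patches (matching the `Conv2d(4, 1152, kernel_size=(2, 2), stride=(2, 2))` patch embedding in the model summary further down). The helper name `dit_shapes` is hypothetical.

```python
def dit_shapes(image_size: int, vae_factor: int = 8, patch_size: int = 2,
               latent_channels: int = 4):
    """Return (latent shape, number of transformer tokens) for a square image.

    Assumes an SD-style VAE that downsamples by `vae_factor` and a DiT that
    cuts the latent into non-overlapping `patch_size` x `patch_size` patches.
    """
    latent_size = image_size // vae_factor       # e.g. 256 -> 32
    tokens_per_side = latent_size // patch_size  # e.g. 32 -> 16
    latent_shape = (latent_channels, latent_size, latent_size)
    return latent_shape, tokens_per_side ** 2


for size in (256, 512):
    shape, tokens = dit_shapes(size)
    print(f"{size}x{size} image -> latent {shape} -> {tokens} tokens")
```

Under these assumptions the 512x512 model processes 4x as many tokens (1024 vs. 256), so each self-attention map is 16x larger per layer.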

```python
# Choose the resolution and VAE:
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
latent_size = int(image_size) // 8  # the VAE downsamples each side by 8

# Load model:
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")  # pre-trained checkpoint
model.load_state_dict(state_dict)
model.eval()  # important! switches layers such as dropout to inference mode

# Optional: torch.compile requires PyTorch 2.0 or newer.
compiled_model = torch.compile(model)

vae = AutoencoderKL.from_pretrained(vae_model).to(device)
```


# 2. Sample from Pre-trained DiT Models

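You can customize several sampling options: the random seed, the number of sampling steps, the classifier-free guidance scale, the ImageNet class labels to sample, and the number of images per row in the output grid. For the full list of ImageNet class indices, see [this list](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).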
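Of these options, `num_sampling_steps` sets how many of the 1000 training timesteps are visited at sampling time; `create_diffusion(str(num_sampling_steps))` passes it as the timestep-respacing string. The sketch below picks evenly spaced timesteps to illustrate the idea. It is an assumption-labelled illustration, not necessarily the exact rounding used by the repository's respacing code, and `even_timesteps` is a hypothetical name.

```python
def even_timesteps(num_train_steps: int, num_sample_steps: int) -> list[int]:
    """Pick `num_sample_steps` roughly evenly spaced timesteps in [0, num_train_steps - 1]."""
    if not 1 <= num_sample_steps <= num_train_steps:
        raise ValueError("num_sample_steps must be between 1 and num_train_steps")
    if num_sample_steps == 1:
        return [0]
    stride = (num_train_steps - 1) / (num_sample_steps - 1)  # fractional stride
    return [round(i * stride) for i in range(num_sample_steps)]


steps = even_timesteps(1000, 250)
print(len(steps), steps[:5], steps[-1])  # 250 [0, 4, 8, 12, 16] 999
```

Fewer steps typically trade some sample quality for speed, which matters most when running on CPU.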

```python
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}

# Create diffusion object (respaces the training schedule to num_sampling_steps steps):
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

# Set up classifier-free guidance: duplicate the noise and pair the copy
# with the null class label 1000 (the unconditional branch).
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# Sample images. torch.compile wraps only `forward`, so calling
# `forward_with_cfg` on the eager model behaves the same as on the compiled wrapper.
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False,
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
samples = vae.decode(samples / 0.18215).sample  # undo the latent scaling factor, then decode

# Save and display images:
save_image(samples, "sample.png", nrow=int(samples_per_row),
           normalize=True, value_range=(-1, 1))
grid = Image.open("sample.png")
display(grid)
```
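The sampling cell runs classifier-free guidance by doubling the batch: the first half carries the requested labels and the second half carries the null label 1000 (the label embedding table has 1001 rows, as the model summary further down shows). A minimal sketch of how the two halves are usually combined, using plain Python lists in place of tensors. `split_halves` and `guided_eps` are hypothetical helpers, and this is the standard guidance formula rather than the exact code inside `forward_with_cfg`, which may differ in details such as which output channels are guided.

```python
def split_halves(batch):
    """Split a doubled batch [cond..., uncond...] into its two halves."""
    half = len(batch) // 2
    return batch[:half], batch[half:]


def guided_eps(cond, uncond, cfg_scale):
    """Standard classifier-free guidance, elementwise: uncond + s * (cond - uncond)."""
    return [u + cfg_scale * (c - u) for c, u in zip(cond, uncond)]


# Labels laid out as in the sampling cell: real classes first, then null class 1000.
class_labels = [207, 360]
y = class_labels + [1000] * len(class_labels)

# Toy scalar "noise predictions", one per batch element (made-up values for illustration).
model_out = [0.5, 1.0, 0.25, 0.5]
cond, uncond = split_halves(model_out)
print(y, guided_eps(cond, uncond, cfg_scale=4.0))  # [207, 360, 1000, 1000] [1.25, 2.5]
```

With `cfg_scale = 1` the result equals the conditional prediction; larger values push each sample further from the unconditional prediction toward its class. The notebook's default is 4.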



```python
# Sanity check: one forward pass on an all-ones latent at timestep 966.
# Arguments are (x, t, y); t needs one timestep per batch element (2 * n here).
t = torch.tensor([966] * n * 2, device=device)
model(torch.ones_like(z), t, y)[0]
```

```
tensor([[[ 0.8955,  0.9233,  0.8968,  ...,  0.9205,  0.9010,  0.9260],
         [ 0.8428,  0.8666,  0.8464,  ...,  0.8727,  0.8511,  0.8756],
         [ 0.8961,  0.9231,  0.8967,  ...,  0.9195,  0.9006,  0.9239],
         ...,
         [ 0.8519,  0.8558,  0.8536,  ...,  0.8614,  0.8582,  0.8665],
         [ 0.8973,  0.9237,  0.8996,  ...,  0.9205,  0.9005,  0.9239],
         [ 0.8524,  0.8561,  0.8541,  ...,  0.8620,  0.8584,  0.8666]],

        [[ 0.9287,  0.9070,  0.9294,  ...,  0.9108,  0.9420,  0.9078],
         [ 0.9394,  0.9604,  0.9344,  ...,  0.9586,  0.9473,  0.9533],
         [ 0.9307,  0.9122,  0.9322,  ...,  0.9159,  0.9424,  0.9138],
         ...,
         [ 0.9403,  0.9738,  0.9338,  ...,  0.9695,  0.9482,  0.9664],
         [ 0.9317,  0.9052,  0.9333,  ...,  0.9076,  0.9428,  0.9049],
         [ 0.9403,  0.9800,  0.9343,  ...,  0.9753,  0.9488,  0.9723]],

        [[ 0.9305,  0.9313,  0.9229,  ...,  0.9306,  0.9177,  0.9333],
         [ 0.9580,  0.9133,  0.9497,  ...,  0
```

```python
model.eval()
```

```
DiT(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(4, 1152, kernel_size=(2, 2), stride=(2, 2))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=1152, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1152, out_features=1152, bias=True)
    )
  )
  (y_embedder): LabelEmbedder(
    (embedding_table): Embedding(1001, 1152)
  )
  (blocks): ModuleList(
    (0): DiTBlock(
      (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=1152, out_features=3456, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1152, out_features=1152, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=1152, out_features=4608, bias=True)
        (act): GELU(approximate=
```
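The layer sizes in this summary follow from DiT-XL's hidden width of 1152 (the patch embedding's output channels). A small sketch that recomputes a few of them, assuming the usual transformer conventions that a fused `qkv` projection outputs 3x the width and the MLP expands the width by 4x; `dit_block_widths` is a hypothetical helper, not part of the DiT code.

```python
def dit_block_widths(hidden: int, mlp_ratio: float = 4.0, num_classes: int = 1000) -> dict:
    """Expected layer sizes for one DiT block plus the label embedding table (illustrative)."""
    qkv_out = 3 * hidden  # fused query, key and value projections
    return {
        "qkv_out": qkv_out,
        "qkv_params": hidden * qkv_out + qkv_out,  # weight matrix + bias
        "mlp_hidden": int(hidden * mlp_ratio),     # MLP expansion width (fc1 output)
        "label_rows": num_classes + 1,             # one extra row for the null (CFG) label
    }


print(dit_block_widths(1152))
# {'qkv_out': 3456, 'qkv_params': 3984768, 'mlp_hidden': 4608, 'label_rows': 1001}
```

Comparing these numbers with `Linear(in_features=1152, out_features=3456)`, `fc1` with `out_features=4608`, and `Embedding(1001, 1152)` in the summary is a quick way to confirm which width a checkpoint was built with.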

```python
model.out_channels
```

```
8
```
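`model.out_channels` is 8 for a 4-channel latent because this DiT learns the variance (`learn_sigma`) as well as the noise: the first 4 output channels are the noise prediction and the remaining 4 parameterize the variance. A sketch of that split on a plain list standing in for one sample's channels; in PyTorch the equivalent is `out[:, :4], out[:, 4:]` or `torch.split(out, 4, dim=1)`. `split_eps_and_var` is a hypothetical helper.

```python
def split_eps_and_var(channels, in_channels=4):
    """Split one sample's output channels into (noise, variance) parts."""
    if len(channels) != 2 * in_channels:
        raise ValueError("expected 2 * in_channels output channels when learn_sigma is on")
    return channels[:in_channels], channels[in_channels:]


# Stand-in for one sample of model output: 8 channels, each labelled by its role.
out = [f"eps{i}" for i in range(4)] + [f"var{i}" for i in range(4)]
eps, var = split_eps_and_var(out)
print(eps, var)
```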