Installing the PyTorch implementation of Imagen from lucidrains' GitHub repository: https://github.com/lucidrains/imagen-pytorch

In [1]:
# Print the version of the `python` executable that the shell resolves on its PATH.
!python --version

Python 3.10.12


In [None]:
# Install imagen-pytorch with the %pip magic so the package is installed into the
# environment of the running kernel rather than whichever pip is first on the shell PATH.
# NOTE(review): the version is not pinned; pin it once a known-good release is chosen.
%pip install imagen-pytorch



In [None]:
import torch
# Report whether torch can see a CUDA device, and the value of torch.version.cuda.
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda
print(f"CUDA Available: {cuda_available}")
print(f"CUDA Version: {cuda_version}")

CUDA Available: True
CUDA Version: 12.1


In [None]:
import torch
from imagen_pytorch import Unet, Imagen

# Two U-Nets form the cascade; the settings they share are declared once.

shared_unet_kwargs = dict(
    dim=32,
    cond_dim=512,
    dim_mults=(1, 2, 4, 8),
)

# first U-Net of the cascade
unet1 = Unet(
    num_resnet_blocks=3,
    layer_attns=(False, True, True, True),
    layer_cross_attns=(False, True, True, True),
    **shared_unet_kwargs,
)

# second U-Net of the cascade
unet2 = Unet(
    num_resnet_blocks=(2, 4, 8, 8),
    layer_attns=(False, False, False, True),
    layer_cross_attns=(False, False, False, True),
    **shared_unet_kwargs,
)

# Imagen wraps both U-Nets; image_sizes lists one size per U-Net (64, then 256).
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(64, 256),
    timesteps=1000,
    cond_drop_prob=0.1,
)
imagen = imagen.cuda()

# Random stand-ins for real training data: 4 text-embedding tensors of shape
# (256, 768) and 4 RGB images of 256x256.
# NOTE(review): 768 presumably matches the embedding width of the default text encoder - confirm.
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# One forward/backward pass per U-Net in the cascade.
# No optimizer is created here, so only gradients are computed.
for unet_number in (1, 2):
    loss = imagen(images, text_embeds=text_embeds, unet_number=unet_number)
    loss.backward()

# In real training the loop above is repeated for many steps.
# Afterwards, images can be sampled from text prompts.
prompts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

images = imagen.sample(texts=prompts, cond_scale=3.)

images.shape # (3, 3, 256, 256)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([3, 3, 256, 256])

Using the ImagenTrainer wrapper class

In [None]:
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# The two U-Nets of the cascade. unet1 leaves layer_cross_attns at the library default.

unet1 = Unet(
    dim=32,
    cond_dim=512,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=3,
    layer_attns=(False, True, True, True),
)

unet2 = Unet(
    dim=32,
    cond_dim=512,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=(2, 4, 8, 8),
    layer_attns=(False, False, False, True),
    layer_cross_attns=(False, False, False, True),
)

# Imagen holds both U-Nets and is configured to use the 't5-large' text encoder.
imagen = Imagen(
    unets=(unet1, unet2),
    text_encoder_name='t5-large',
    image_sizes=(64, 256),
    timesteps=1000,
    cond_drop_prob=0.1,
)
imagen = imagen.cuda()

# ImagenTrainer wraps the model; training steps and sampling go through it below.
trainer = ImagenTrainer(imagen)

# Random stand-ins for real data: 64 text-embedding tensors of shape (256, 1024)
# and 64 RGB images of 256x256.
text_embeds = torch.randn(64, 256, 1024).cuda()
images = torch.randn(64, 3, 256, 256).cuda()

# Train U-Net 1 only. To train U-Net 2 as well, save a checkpoint, reload it and
# continue with unet_number = 2.
# max_batch_size = 4 splits the batch of 64 into chunks of 4 and accumulates
# gradients, so the whole batch fits in memory.
loss = trainer(
    images,
    text_embeds=text_embeds,
    unet_number=1,
    max_batch_size=4,
)

trainer.update(unet_number=1)

# In real training the step above is repeated for many iterations.
# Afterwards, images can be sampled from text prompts.
prompts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
]

images = trainer.sample(texts=prompts, cond_scale=3.)

images.shape # (2, 3, 256, 256)



config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

dataloader_config = DataLoaderConfiguration(split_batches=True)


unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


model.safetensors:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([2, 3, 256, 256])

Training Imagen without text (unconditional image generation)

In [None]:
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer

# U-Nets for the unconditional cascade; cross attention is switched off on both.

unet1 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=False,
    use_linear_attn=True,
)

unet2 = SRUnet256(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=False,
)

# condition_on_text must be False for unconditional Imagen.
imagen = Imagen(
    condition_on_text=False,
    unets=(unet1, unet2),
    image_sizes=(64, 128),
    timesteps=1000,
)

trainer = ImagenTrainer(imagen)
trainer = trainer.cuda()

# Random stand-in for a real image dataset: 4 RGB images of 256x256.
training_images = torch.randn(4, 3, 256, 256).cuda()

# Each U-Net is trained separately; this example runs a single step on U-Net 1 only.
loss = trainer(training_images, unet_number=1)
trainer.update(unet_number=1)

# In real training the step above is repeated for many iterations.
# Afterwards, images can be sampled without any text input.
images = trainer.sample(batch_size=16) # (16, 3, 128, 128)

unet 2 has not been trained
when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Train only using super-resolution unets

In [None]:
import torch
from imagen_pytorch import Unet, NullUnet, Imagen

# The base slot of the cascade is filled with a placeholder NullUnet,
# so only the second U-Net is a real model.

unet1 = NullUnet()

unet2 = Unet(
    dim=32,
    cond_dim=512,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=(2, 4, 8, 8),
    layer_attns=(False, False, False, True),
    layer_cross_attns=(False, False, False, True),
)

# Imagen holding the placeholder and the real U-Net.
imagen = Imagen(
    unets=(unet1, unet2),
    image_sizes=(64, 256),
    timesteps=250,
    cond_drop_prob=0.1,
)
imagen = imagen.cuda()

# Random stand-ins for real data: 4 text-embedding tensors of shape (256, 768)
# and 4 RGB images of 256x256.
text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# Only U-Net 2 is trained.
loss = imagen(images, text_embeds=text_embeds, unet_number=2)
loss.backward()

# In real training the step above is repeated for many iterations.
# Sampling then starts from supplied low-resolution images plus text prompts.

lowres_images = torch.randn(3, 3, 64, 64).cuda()  # random 64x64 inputs to be upscaled

prompts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

images = imagen.sample(
    texts=prompts,
    start_at_unet_number=2,              # skip the placeholder, begin at U-Net 2
    start_image_or_video=lowres_images,  # low-resolution images to upscale
    cond_scale=3.,
)

images.shape # (3, 3, 256, 256)

0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([3, 3, 256, 256])

Saving and loading the trainer and all associated states

In [None]:
import os

# Save the trainer to a checkpoint file and load it straight back, verifying that
# the per-unet step counters survive the round trip.
# NOTE: `trainer` is not created in this cell; it is whichever ImagenTrainer an
# earlier cell left in the kernel, so run one of those cells first.

# Where the checkpoint is written
checkpoint_path = '/content/checkpoints/checkpoint.pt'

# Make sure the target directory exists before saving
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Remember the step counters so the reload can actually be checked, not just printed
steps_before_save = trainer.steps.tolist()

# Save the checkpoint
trainer.save(checkpoint_path)

# Load the checkpoint
trainer.load(checkpoint_path)

# The reloaded trainer must report the same number of steps per unet as before saving
steps_after_load = trainer.steps.tolist()
assert steps_after_load == steps_before_save, (
    f'step counters changed across save/load: {steps_before_save} -> {steps_after_load}'
)

# trainer.steps holds one counter per unet,
# e.g. tensor([1, 0]) after a single update of unet 1 in a two-unet cascade
print(trainer.steps)

checkpoint saved to /content/checkpoints/checkpoint.pt
checkpoint loaded from /content/checkpoints/checkpoint.pt
tensor([1, 0], device='cuda:0')


Dataloader: unconditional training

In [None]:
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os

# Unconditional training on CIFAR-10 through the trainer's built-in dataloader:
# build the model, export CIFAR-10 to a folder of PNGs, then train from that folder.

# Define the U-Net for the Imagen model
unet = Unet(
    dim=32,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=1,
    layer_attns=(False, False, False, True),
    layer_cross_attns=False
)

# Define the Imagen model (unconditional)
imagen = Imagen(
    condition_on_text=False,  # Set to False for unconditional training
    unets=unet,
    image_sizes=128,
    timesteps=1000
)

# Initialize the trainer
trainer = ImagenTrainer(
    imagen=imagen,
    split_valid_from_train=True  # Hold out part of the training dataset for validation
).cuda()  # Ensure the trainer is on GPU

# Transformations applied to each CIFAR-10 image before it is exported
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images to 128x128 pixels
    transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor
])

# Load CIFAR-10 dataset (downloaded on first use)
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Create a directory to save CIFAR-10 images
image_dir = './data/cifar10_images'
os.makedirs(image_dir, exist_ok=True)

# Export CIFAR-10 images into the directory.
# Files that already exist are skipped, so re-running the cell does not transform and
# re-encode all 50,000 images again. Indexing the dataset (rather than iterating it)
# means the transform only runs for images that still need to be written.
to_pil = transforms.ToPILImage()
for idx in range(len(cifar10_train)):
    image_path = f'{image_dir}/image_{idx}.png'
    if os.path.exists(image_path):
        continue

    image, _ = cifar10_train[idx]  # the class label is not needed for unconditional training

    # Write to a temporary name first and rename afterwards, so an interrupted run never
    # leaves a truncated PNG behind that a later run would skip as "already exported".
    tmp_path = f'{image_path}.tmp'
    to_pil(image).save(tmp_path, format='PNG')
    os.replace(tmp_path, image_path)

# Create a Dataset instance for training using the directory
dataset = Dataset(image_dir, image_size=128)  # Only images are returned for unconditional training

# Add the dataset to the trainer with a specified batch size
trainer.add_train_dataset(dataset, batch_size=16)

# Training loop
for i in range(200000):
    loss = trainer.train_step(unet_number=1, max_batch_size=4)
    print(f'loss: {loss}')

    # Validate every 50 steps (including step 0)
    if i % 50 == 0:
        valid_loss = trainer.valid_step(unet_number=1, max_batch_size=4)
        print(f'valid loss: {valid_loss}')

    # Sample and save one image every 100 steps (including step 0).
    # The trainer.is_main guard presumably restricts this to the main process
    # in a distributed run, so the file is written once - confirm.
    if i % 100 == 0 and trainer.is_main:
        images = trainer.sample(batch_size=1, return_pil_images=True)  # Sample an image
        images[0].save(f'./sample-{i // 100}.png')  # Save the sampled image


Files already downloaded and verified
training with dataset of 48750 samples and validating with randomly splitted 1250 samples
loss: 0.8159782290458679
valid loss: 0.7993254065513611


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

loss: 0.7658392488956451
loss: 0.822913408279419
loss: 0.7485736310482025
loss: 0.9333889037370682
loss: 0.7041250020265579
loss: 0.6761122047901154
loss: 0.6341600194573402
loss: 0.9220888912677765
loss: 0.8661463707685471
loss: 0.8199429214000702
loss: 0.8015386015176773
loss: 0.8706642538309097
loss: 0.8484079837799072
loss: 0.9012722373008728
loss: 0.9128666073083878
loss: 0.7628585994243622
loss: 0.8038588613271713
loss: 0.7146209329366684
loss: 0.9066739976406097
loss: 0.8205430507659912
loss: 0.8026849776506424
loss: 0.7222362533211708
loss: 0.7482868432998657
loss: 0.7565257251262665
loss: 0.7209142595529556
loss: 0.7035306990146637
loss: 0.728437140583992
loss: 0.7975873053073883
loss: 0.7773160040378571
loss: 0.7509311586618423
loss: 0.6910482496023178
loss: 0.720490574836731
loss: 0.5634416788816452
loss: 0.6729527786374092
loss: 0.6703889071941376
loss: 0.6790521889925003
loss: 0.7735923677682877
loss: 0.7333037853240967
loss: 0.8117127418518066
loss: 0.6821551471948624
los