Load a pre-trained diffusion model (`google/ddpm-cifar10-32`) and fine-tune it on an image dataset (CIFAR-10 in this notebook).

In [None]:
# Install the required libraries
!pip install diffusers accelerate transformers datasets torchvision

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12=

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from diffusers import UNet2DModel, DDPMScheduler
from torchvision import transforms

# Seed torch's global RNG so that DataLoader shuffling and any torch-based
# random sampling in later cells are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

print('Libraries imported.')

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

Libraries imported.


In [None]:
# Load the CIFAR-10 training split and expose its image column, which the
# source names 'img', under the name 'image' so later cells can read
# example["image"].
dataset = (
    load_dataset('cifar10', split='train')
    .rename_column('img', 'image')
)

print('CIFAR-10 dataset loaded.')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/120M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

CIFAR-10 dataset loaded.


In [None]:
# Define image transformations: resize to 32x32, convert to a tensor in [0, 1],
# then normalize each channel with mean 0.5 / std 0.5, mapping values to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to 32x32 to match the model's input
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def transform_example(examples):
    """Apply `transform` to every image in a batch of examples.

    `Dataset.with_transform` calls this function with a *batch*: a dict that
    maps column names to lists of values. This holds even when a single row is
    indexed (the lists then have length 1). Passing the whole list to
    `transform` would raise a TypeError, so each image is transformed
    individually.

    Images are converted to RGB first so that single-channel or RGBA images
    from a custom dataset still match the 3-channel Normalize above.

    Args:
        examples: batch dict with an "image" column of PIL images.

    Returns:
        The same dict with "image" replaced by a list of normalized tensors.
    """
    examples["image"] = [transform(image.convert("RGB")) for image in examples["image"]]
    return examples


dataset = dataset.with_transform(transform_example)

# Create a DataLoader for training
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

print('Dataset transformed and DataLoader created.')

Dataset transformed and DataLoader created.


In [None]:
# Load the pre-trained diffusion model and its scheduler from the same Hub
# checkpoint. A single constant keeps the two configs from drifting apart.
# UNet2DModel and DDPMScheduler are imported in the imports cell at the top.
PRETRAINED_MODEL_ID = 'google/ddpm-cifar10-32'
model = UNet2DModel.from_pretrained(PRETRAINED_MODEL_ID)
scheduler = DDPMScheduler.from_pretrained(PRETRAINED_MODEL_ID)

print('Pre-trained diffusion model and scheduler loaded.')

config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/143M [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/256 [00:00<?, ?B/s]

Pre-trained diffusion model and scheduler loaded.


In [None]:
# Define a custom collate function to stack images
def custom_collate_fn(batch):
    # Each element in batch is a dictionary with key "image"
    images = [example["image"] for example in batch]
    images = torch.stack(images)
    return {"image": images}

# Re-create the training DataLoader with the custom collate function, so each
# batch is a dict holding only the stacked "image" tensor.
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn)

print('DataLoader with custom collate function created.')

Dataset transformed and DataLoader created.


In [None]:
# Fine-tune the pre-trained UNet with the standard DDPM objective, then save it.
# The objective: add noise at a random timestep and train the model to predict
# that noise. Without this loop the saved "fine-tuned" model would just be an
# unmodified copy of the pre-trained checkpoint.
NUM_EPOCHS = 1
MAX_TRAIN_STEPS = 500  # cap for a quick run; set to None to train on full epochs
LEARNING_RATE = 1e-4
LOG_EVERY = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

global_step = 0
for epoch in range(NUM_EPOCHS):
    for batch in train_dataloader:
        clean_images = batch["image"].to(device)
        noise = torch.randn_like(clean_images)
        # One random diffusion timestep per image in the batch.
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (clean_images.shape[0],), device=device
        )
        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

        noise_pred = model(noisy_images, timesteps).sample
        loss = nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        global_step += 1
        if global_step % LOG_EVERY == 0:
            print(f'epoch {epoch} step {global_step}: loss {loss.item():.4f}')
        if MAX_TRAIN_STEPS is not None and global_step >= MAX_TRAIN_STEPS:
            break
    if MAX_TRAIN_STEPS is not None and global_step >= MAX_TRAIN_STEPS:
        break

model.eval()

# Save the fine-tuned model and scheduler to a directory
model.save_pretrained('./fine_tuned_diffusion_model')
scheduler.save_pretrained('./fine_tuned_diffusion_model')

print('Fine-tuned model saved.')

Fine-tuned model saved.
