In [None]:
import json
import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import math
from typing import List
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import cv2
from diffusers import DDPMScheduler

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 256       # side length (px) of the rendered sketches fed to the UNet
BATCH_SIZE = 32
TIMESTEP_COUNT = 100   # number of diffusion steps used for training and for sampling
EPOCHS = 100
SEED = 42

# Seed the RNGs used below (weight init, noise/timestep sampling, random flips)
# so that Restart & Run All reproduces the same results.
np.random.seed(SEED)
torch.manual_seed(SEED)

<p style="color:red; font-weight:bold; font-size:16px">
Because training on the CPU was very slow, the model was trained on only a small subset of the dataset, and the results were not good enough to share.
</p>

In [None]:
class QuickDrawDataset(Dataset):
    """QuickDraw sketches read lazily from an ``.ndjson`` file and rasterised on demand.

    Each item is one sketch drawn as black 2-px strokes on a white
    ``img_size`` x ``img_size`` grayscale canvas, converted with ``ToTensor`` and
    ``Normalize(mean=0.5, std=0.5)``.

    Args:
        ndjson_path: Path to the ``.ndjson`` file (one JSON sketch per line, with a
            ``"drawing"`` key holding the strokes).
        indice_path: Path to a JSON file mapping split names (e.g. ``"train"``,
            ``"test"``) to lists of line indices into ``ndjson_path``.
        img_size: Side length of the rendered canvas in pixels.
        test_train: Which split of ``indice_path`` to use. Random horizontal
            flipping is applied only for the ``"train"`` split, so evaluation is
            deterministic.
    """

    def __init__(
            self,
            ndjson_path: str,
            indice_path: str,
            img_size: int,
            test_train: str = "train"
    ):
        self.ndjson_path = ndjson_path
        self.indice_path = indice_path
        self.img_size = img_size
        self.test_train = test_train
        # Byte offset of every selected line; filled by _build_indices so that
        # __getitem__ can seek straight to a sketch instead of rescanning the file.
        self._line_offsets = {}
        self.sketch_indices = self._build_indices()
        augment = [transforms.RandomHorizontalFlip()] if test_train == "train" else []
        self.transform = transforms.Compose([
            *augment,
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def _build_indices(self) -> List:
        """Return the split's line indices that exist in the ndjson file.

        Indices are returned in ascending file order, without duplicates; indices
        past the end of the file are dropped. Also records the byte offset of each
        kept line in ``self._line_offsets``.
        """
        with open(self.indice_path, 'r') as index_file:
            wanted = set(json.load(index_file)[self.test_train])  # O(1) membership tests
        indices = []
        offsets = {}
        offset = 0
        # Binary mode: len(line) is then an exact byte count usable with seek().
        with open(self.ndjson_path, 'rb') as data_file:
            for line_idx, line in enumerate(data_file):
                if line_idx in wanted:
                    indices.append(line_idx)
                    offsets[line_idx] = offset
                offset += len(line)
        self._line_offsets = offsets
        return indices

    def _load_sketch(self, line_idx: int) -> dict:
        """Parse and return the JSON record on line ``line_idx`` of the ndjson file."""
        with open(self.ndjson_path, 'rb') as data_file:
            data_file.seek(self._line_offsets[line_idx])
            return json.loads(data_file.readline())

    def create_images_with_lines(self, coords):
        """Draw the strokes in ``coords`` on a white canvas and return the transformed tensor.

        ``coords`` is a list of strokes whose first two entries are the x and y
        coordinate lists (any further entries are ignored). Coordinates are used
        directly as pixel positions, so they are assumed to lie within
        ``[0, img_size)`` — TODO confirm for non-simplified data. Strokes with fewer
        than two points are skipped.
        """
        img = Image.new('L', (self.img_size, self.img_size), 'white')
        draw = ImageDraw.Draw(img)
        for stroke in coords:
            points = list(zip(stroke[0], stroke[1]))
            if len(points) > 1:
                draw.line(points, fill="black", width=2)
        return self.transform(img)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the normalised image tensor of the ``idx``-th sketch in this split."""
        sketch_data = self._load_sketch(self.sketch_indices[idx])
        return self.create_images_with_lines(sketch_data["drawing"])

    def __len__(self):
        return len(self.sketch_indices)

<p style="color:red; font-weight:bold; font-size:16px">
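This function denormalizes a model output (undoing <code>Normalize(mean=0.5, std=0.5)</code>) and converts it to a uint8 image that can be displayed or saved.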
</p>

In [None]:
def denormalize(tensor, mean=(0.5,), std=(0.5,)):
    """Undo ``Normalize(mean, std)`` and convert the result to a uint8 image array.

    Args:
        tensor: Normalised image tensor, ``(C, H, W)`` for a single image. It may
            live on any device; it is moved to the CPU only for the NumPy conversion.
        mean, std: Per-channel statistics that were used for normalisation.

    Returns:
        ``np.ndarray`` of dtype uint8 with values in [0, 255]. A ``(C, H, W)``
        image is returned channel-last, and size-1 dimensions are squeezed out
        (a grayscale image becomes ``(H, W)``).
    """
    # Build the statistics on the input's device: CPU tensors cannot be combined
    # with a CUDA tensor (e.g. a freshly generated sample).
    mean_t = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    std_t = torch.tensor(std, device=tensor.device).view(-1, 1, 1)
    image = torch.clamp(tensor.detach() * std_t + mean_t, 0, 1)
    if image.dim() == 3:  # single image: CHW -> HWC for OpenCV / matplotlib
        image = image.permute(1, 2, 0)
    image = image.squeeze().cpu().numpy()
    return (image * 255).clip(0, 255).astype(np.uint8)

<p style="color:red; font-weight:bold; font-size:16px">
The SinusoidalTimeEmbedding module encodes diffusion timesteps into a vector representation that can be added to image features inside the UNet. 
This gives the model information about which stage of the denoising process it's in. In diffusion models, each image is progressively noised over T timesteps. During training and sampling, the model needs to know how much noise is present — in other words, what timestep t it's on. We encode this timestep using a sinusoidal positional encoding.
</p>

In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal positional embedding for time steps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, embedding_dim)`` float tensor.
    The first half of each row holds ``sin(t * f_k)`` and the second half
    ``cos(t * f_k)``. The frequencies ``f_k`` are geometrically spaced from 1 down
    to 1/10000 (when there are at least two of them). For an odd
    ``embedding_dim``, a trailing zero column pads the output to the requested width.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, t):
        half_dim = self.embedding_dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=1)
        if self.embedding_dim % 2 == 1:
            # sin/cos halves cover only 2 * half_dim columns; pad to embedding_dim.
            emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
        return emb

<p style="color:red; font-weight:bold; font-size:16px">
Each ResidualBlock adds a shortcut from its input to its output, so information (and gradient signal) is not lost as it passes through the block's layers. Separately, the UNet's skip connections pass encoder feature maps directly to the matching decoder layers.
</p>

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions conditioned on a time embedding, plus a shortcut.

    forward(x, t_emb):
        x     -- feature map with ``in_ch`` channels (assumed (B, in_ch, H, W))
        t_emb -- time embedding with ``time_emb_dim`` features (assumed (B, time_emb_dim))
    Returns ReLU(conv2(ReLU(conv1(x) + proj(t_emb)))) + shortcut(x), with ``out_ch``
    channels and the same spatial size as ``x``. The shortcut is a 1x1 convolution
    when the channel counts differ and the identity otherwise.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.in_ch, self.out_ch = in_ch, out_ch
        # Submodules are created in this exact order so that parameter
        # initialisation and state_dict keys stay stable.
        self.time_proj = nn.Linear(time_emb_dim, out_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        # Project the embedding to out_ch and broadcast it over height and width.
        time_bias = self.time_proj(t_emb)[:, :, None, None]
        hidden = self.act1(self.conv1(x) + time_bias)
        hidden = self.act2(self.conv2(hidden))
        return hidden + self.skip(x)

<p style="color:red; font-weight:bold; font-size:16px">
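UNet is an encoder–decoder network originally popular for semantic segmentation. Its skip connections preserve spatial detail, which makes it a common backbone for predicting the noise in diffusion models.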
</p>

In [None]:
class SimpleUNet(nn.Module):
    """Small three-level UNet that predicts the noise contained in a noised image.

    Layout (c = base_channels):
        encoder   : ResidualBlock -> stride-2 conv, three times (c, c, 2c channels)
        bottleneck: ResidualBlock at 2c channels
        decoder   : transposed conv (x2 upsampling) -> concat with the matching
                    encoder output -> ResidualBlock, three times
        output    : 1x1 conv to ``out_channels``
    Every ResidualBlock receives the same time embedding produced by ``time_mlp``.

    The three stride-2 downsamplings mean the input height and width must be
    divisible by 8 for the skip concatenations to line up.

    forward(x, t):
        x -- image batch with ``in_channels`` channels (assumed (B, C, H, W))
        t -- 1-D tensor of timesteps, one per batch element
    Returns a tensor with ``out_channels`` channels and the spatial size of ``x``.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=16, time_emb_dim=64):
        super().__init__()
        c = base_channels
        # Submodules are created in this exact order so that weight initialisation
        # and state_dict keys (saved checkpoints) stay stable.
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Encoder: each Conv2d(k=4, s=2, p=1) halves height and width.
        self.enc_conv1 = ResidualBlock(in_channels, c, time_emb_dim)
        self.down1 = nn.Conv2d(c, c, 4, 2, 1)
        self.enc_conv2 = ResidualBlock(c, c, time_emb_dim)
        self.down2 = nn.Conv2d(c, c, 4, 2, 1)
        self.enc_conv3 = ResidualBlock(c, 2 * c, time_emb_dim)
        self.down3 = nn.Conv2d(2 * c, 2 * c, 4, 2, 1)

        self.bottleneck = ResidualBlock(2 * c, 2 * c, time_emb_dim)

        # Decoder: each ConvTranspose2d(k=4, s=2, p=1) doubles height and width; the
        # following ResidualBlock consumes [upsampled, encoder skip] concatenated.
        self.up3 = nn.ConvTranspose2d(2 * c, 2 * c, 4, 2, 1)
        self.dec_conv3 = ResidualBlock(4 * c, c, time_emb_dim)
        self.up2 = nn.ConvTranspose2d(c, c, 4, 2, 1)
        self.dec_conv2 = ResidualBlock(2 * c, c, time_emb_dim)
        self.up1 = nn.ConvTranspose2d(c, c, 4, 2, 1)
        self.dec_conv1 = ResidualBlock(2 * c, c, time_emb_dim)

        self.out_conv = nn.Conv2d(c, out_channels, kernel_size=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)

        skip1 = self.enc_conv1(x, t_emb)
        skip2 = self.enc_conv2(self.down1(skip1), t_emb)
        skip3 = self.enc_conv3(self.down2(skip2), t_emb)

        hidden = self.bottleneck(self.down3(skip3), t_emb)

        decoder_stages = (
            (self.up3, self.dec_conv3, skip3),
            (self.up2, self.dec_conv2, skip2),
            (self.up1, self.dec_conv1, skip1),
        )
        for upsample, decode, skip in decoder_stages:
            hidden = decode(torch.cat([upsample(hidden), skip], dim=1), t_emb)

        return self.out_conv(hidden)

<p style="color:red; font-weight:bold; font-size:16px">
This function generates a sample by starting from pure Gaussian noise and iteratively denoising it, one scheduler timestep at a time.
</p>

In [None]:
def generate_samples(model, scheduler, num_samples=1, image_size=256, channels=1):
    """Generate images by running the reverse diffusion process from pure noise.

    Args:
        model: Noise-prediction network, called as ``model(x_t, t_batch)``.
        scheduler: diffusers-style scheduler exposing ``timesteps`` (iterated in
            the order given) and ``step(noise, t, x)``, which returns an object with a
            ``prev_sample`` attribute.
        num_samples, image_size, channels: Shape of the generated batch,
            ``(num_samples, channels, image_size, image_size)``. The defaults
            reproduce the original single 1x256x256 sample.

    Returns:
        The final denoised batch, on the same device as the model's parameters.
        The model's train/eval mode is restored before returning.
    """
    first_param = next(model.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x = torch.randn(num_samples, channels, image_size, image_size, device=sample_device)
            for t in scheduler.timesteps:
                t = int(t)
                t_batch = torch.full((num_samples,), t, device=sample_device, dtype=torch.long)
                pred_noise = model(x, t_batch)
                x = scheduler.step(pred_noise, t, x).prev_sample
    finally:
        model.train(was_training)
    return x

<p style="color:red; font-weight:bold; font-size:16px">
This is the training pipeline, run separately for each of the three classes. The noise scheduler (<code>DDPMScheduler</code>) comes from the <code>diffusers</code> library.
</p>

In [None]:
CATEGORIES = ['bus', 'cat', 'rabbit']  # QuickDraw classes trained and visualised below
# Per-category loss history, filled by the training loop:
# {category: {"train_loss": [...], "test_loss": [...]}}
loss_dict = {}

In [None]:
def _train_one_epoch(model, loader, scheduler, optimizer, criterion):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Each batch is noised at uniformly random timesteps, and the model is trained
    to predict the added noise (loss given by ``criterion``).

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. the split is smaller
            than the batch size while ``drop_last=True``).
    """
    model.train()
    num_timesteps = scheduler.config.num_train_timesteps
    running_loss, num_batches = 0.0, 0
    for images in loader:
        images = images.to(device)  # assumed [B, 1, IMAGE_SIZE, IMAGE_SIZE]
        timesteps = torch.randint(0, num_timesteps, (images.size(0),), device=device)
        noise = torch.randn_like(images)
        # x_t: the clean images noised to level t by the scheduler.
        noisy_images = scheduler.add_noise(images, noise, timesteps)
        loss = criterion(model(noisy_images, timesteps), noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("Training loader yielded no batches; the train split may be "
                         "smaller than BATCH_SIZE (drop_last=True).")
    return running_loss / num_batches


def _evaluate(model, loader, scheduler, criterion, seed=0):
    """Return the mean noise-prediction loss of ``model`` over ``loader``.

    Timesteps and noise come from a CPU generator re-seeded with ``seed`` on every
    call. Each epoch is therefore scored on the same random draws, which makes the
    best-checkpoint comparison far less noisy.

    Raises:
        ValueError: if ``loader`` yields no batches (empty test split).
    """
    model.eval()
    num_timesteps = scheduler.config.num_train_timesteps
    generator = torch.Generator().manual_seed(seed)
    running_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            timesteps = torch.randint(0, num_timesteps, (images.size(0),), generator=generator).to(device)
            noise = torch.randn(images.shape, generator=generator, dtype=images.dtype).to(device)
            noisy_images = scheduler.add_noise(images, noise, timesteps)
            running_loss += criterion(model(noisy_images, timesteps), noise).item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("Test loader yielded no batches; the test split is empty.")
    return running_loss / num_batches


for ctgry in CATEGORIES:
    loss_dict[ctgry] = {
        "train_loss": [],
        "test_loss": []
    }
    train_dataset = QuickDrawDataset(ndjson_path=f"data/{ctgry}.ndjson",
                                     indice_path=f"subset/{ctgry}/indices.json",
                                     img_size=IMAGE_SIZE)
    test_dataset = QuickDrawDataset(ndjson_path=f"data/{ctgry}.ndjson",
                                    indice_path=f"subset/{ctgry}/indices.json",
                                    test_train="test",
                                    img_size=IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    # Evaluation uses a fixed order and keeps every sample; a small test split may
    # not even fill one batch.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    scheduler = DDPMScheduler(num_train_timesteps=TIMESTEP_COUNT, beta_start=1e-3, beta_end=1e-1)
    scheduler.set_timesteps(TIMESTEP_COUNT)
    model = SimpleUNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.MSELoss()
    print(f"Training for {ctgry}...")
    min_loss = float("inf")
    for epoch in range(1, EPOCHS + 1):
        avg_train_loss = _train_one_epoch(model, train_loader, scheduler, optimizer, criterion)
        avg_test_loss = _evaluate(model, test_loader, scheduler, criterion)

        loss_dict[ctgry]["train_loss"].append(avg_train_loss)
        loss_dict[ctgry]["test_loss"].append(avg_test_loss)
        # Keep the checkpoint with the lowest test loss seen so far.
        if avg_test_loss < min_loss:
            torch.save(model.state_dict(), f"model_{ctgry}.pth")
            min_loss = avg_test_loss

        print(f"Epoch {epoch}: avg_train_loss={avg_train_loss:.4f}, avg_test_loss={avg_test_loss:.4f}")

        # Every 10 epochs, save one generated sample from the current model.
        if epoch % 10 == 0:
            # Move to the CPU first: the sample lives on the model's device.
            sample_img = denormalize(generate_samples(model, scheduler)[0].cpu())
            cv2.imwrite(f"sample_{ctgry}.png", sample_img)
    # Note: this writes the whole loss_dict (all categories trained so far).
    joblib.dump(loss_dict, f"loss_dict_{ctgry}.sav")

<p style="color:red; font-weight:bold; font-size:16px">
Training results: train and test loss curves for each class.
</p>

In [None]:
# One figure per category: train vs. test noise-prediction loss per epoch.
for ctgry in CATEGORIES:
    losses = pd.DataFrame(loss_dict[ctgry])
    losses.index = range(1, len(losses) + 1)  # epochs are numbered from 1 in the training loop
    fig, ax = plt.subplots()
    losses.plot(ax=ax)
    ax.set(title=f"Noise-prediction loss: {ctgry}", xlabel="Epoch", ylabel="MSE loss")
    plt.show()
    plt.close(fig)

<p style="color:red; font-weight:bold; font-size:16px">
Because the training results were insufficient, the final drawing videos were created from dataset images instead of generated samples.
</p>

In [25]:


# Parameters of the "drawing process" animation.
VIDEO_SKETCH_COUNT = 3  # test sketches drawn one after another in each category's video
VIDEO_FPS = 20.0
STEP_LENGTH = 5         # approximate pixels advanced per video frame
JITTER_SIGMA = 0.5      # std-dev (px) of the per-point wobble imitating a hand; points are rounded to whole pixels


def _detect_segments(img):
    """Detect straight line segments in a uint8 grayscale sketch.

    Returns a list of ``(x1, y1, x2, y2)`` tuples sorted by the first endpoint
    (x, then y), or an empty list when no segment is found.
    """
    # A GaussianBlur with a 1x1 kernel is an exact no-op, so no blur is applied here.
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    detector = cv2.createLineSegmentDetector(0)
    detected = detector.detect(binary)[0]
    if detected is None:  # detect() can return None when nothing is found
        return []
    segments = [tuple(seg[0]) for seg in detected]
    return sorted(segments, key=lambda s: (s[0], s[1]))


def _draw_segments(video, canvas, segments, rng):
    """Draw ``segments`` onto ``canvas`` a few pixels at a time, writing one frame per step.

    ``canvas`` is modified in place. Each segment is split into points about
    ``STEP_LENGTH`` px apart, always including both endpoints so that short
    segments are drawn too. Each point is jittered once, so consecutive steps
    share endpoints and the stroke stays connected.
    """
    for x1, y1, x2, y2 in segments:
        n_steps = max(int(np.hypot(x2 - x1, y2 - y1) / STEP_LENGTH), 1)
        xs = np.linspace(x1, x2, n_steps + 1) + rng.normal(0, JITTER_SIGMA, n_steps + 1)
        ys = np.linspace(y1, y2, n_steps + 1) + rng.normal(0, JITTER_SIGMA, n_steps + 1)
        points = [(int(x), int(y)) for x, y in zip(np.rint(xs), np.rint(ys))]
        for start, end in zip(points, points[1:]):
            cv2.line(canvas, start, end, color=0, thickness=2)
            video.write(canvas)


video_rng = np.random.default_rng(0)  # seeded so the jitter (and thus the videos) is reproducible
for ctgry in CATEGORIES:
    test_dataset = QuickDrawDataset(ndjson_path=f"data/{ctgry}.ndjson",
                                    indice_path=f"subset/{ctgry}/indices.json",
                                    test_train="test",
                                    img_size=IMAGE_SIZE)

    # One writer per category, so all drawn sketches end up in the same file.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(f'drawing_process_{ctgry}.mp4', fourcc, VIDEO_FPS,
                            (IMAGE_SIZE, IMAGE_SIZE), isColor=False)
    try:
        for idx in range(min(VIDEO_SKETCH_COUNT, len(test_dataset))):
            img = denormalize(test_dataset[idx])
            canvas = np.full_like(img, 255)  # fresh white canvas for each sketch
            _draw_segments(video, canvas, _detect_segments(img), video_rng)
    finally:
        video.release()

<p style="color:red; font-weight:bold; font-size:16px">
References
</p>
<p style="color:red; font-weight:bold; font-size:14px">
Books
</p>

*Hands-On Generative AI with Transformers and Diffusion Models*, O'Reilly Media (2024), Omar Sanseviero, Pedro Cuenca, Apolinário Passos, Jonathan Whitaker

<p style="color:red; font-weight:bold; font-size:14px">
Websites
</p>

[Github Awesome List](https://github.com/zju-pi/Awesome-Conditional-Diffusion-Models)

[PyTorch Website](https://pytorch.org/)

[Quick Draw Dataset](https://github.com/googlecreativelab/quickdraw-dataset)

[Quick Draw Documentation](https://quickdraw.readthedocs.io/en/latest/index.html)

[Diffusers](https://huggingface.co/docs/diffusers/en/index)

[DZData Medium Tutorial](https://dzdata.medium.com/intro-to-diffusion-model-part-1-29fe7724c043)

[ExplainingAI Youtube Channel](https://www.youtube.com/@Explaining-AI)

[DeepFindr](https://www.youtube.com/@DeepFindr)

[Kaggle-Class-conditioned-diffusion-model](https://www.kaggle.com/code/riddhich/class-conditioned-diffusion-model#Creating-a-Class-Conditioned-UNet)

[(Paper)Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239)

[(Paper)Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233)
