Satvika Eda, Divya Sri Bandaru & Dhriti Anjaria
13th April 2025

# Stroke-Based Image Colorization with Stable Diffusion & ControlNet

This notebook contains the **fine-tuning of the pretrained ControlNet model** and **zero-shot inference using pretrained Stable Diffusion + fine tuned ControlNet** for stroke-based image colorization.


In [1]:
import torch.multiprocessing as mp
mp.set_start_method('spawn')  # or 'forkserver'

In [2]:
import os
import glob
import random
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import models
import matplotlib.pyplot as plt

### Loading the Multi Cue Stroke Dataset

In [3]:
import sys
# sys.path.remove('/kaggle/input/lab-seg-dataset')
sys.path.append("/kaggle/input/dataset-seg")

from dataset import MultiCueStrokeDataset
print("Dataset class loaded successfully ✅")

Dataset class loaded successfully ✅


In [4]:
# configuration parameters for training 
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LR = 1e-5
ACCUM_STEPS = 2
IMG_SIZE = 256
BATCH_SIZE = 4
EPOCHS = 4

In [5]:
# === DataLoader wrapper with prefetching ===
def get_dataloader(image_paths, batch_size=4, num_workers=4, shuffle=True, img_size=IMG_SIZE, device='cuda'):
    dataset = MultiCueStrokeDataset(image_paths, img_size=img_size, device=device)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, prefetch_factor=2)


In [6]:
import os
import glob

# all training image paths (.JPEG files) from the subset class folders
TRAIN_FOLDER = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train"
VAL_FOLDER = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/val"
IMG_SIZE = IMG_SIZE

subset_class_ids = [
    "n01440764", "n01514859", "n01629819", "n01664065", "n01742172",
    "n01882714", "n01978455", "n02002556", "n02071294", "n02099601",
    "n02104029", "n02112137", "n02123045", "n02165456", "n02206856",
    "n02279972", "n02317335", "n02395406", "n02415577", "n02480495",
    "n02509815", "n02692877", "n02786058", "n02823428", "n02879718",
    "n02966193", "n03047690", "n03126707", "n03179701", "n03255030",
    "n03379051", "n03424325", "n03494278", "n03584829", "n03633091",
    "n03770439", "n03814639", "n03888257", "n03976657", "n04037443",
    "n04118538", "n04552348", "n02113799", "n02391049", "n03478589",
    "n03085013", "n03100240", "n03666591", "n03314780", "n02795169"
]

train_image_paths = []
for class_id in subset_class_ids:
    class_path = os.path.join(TRAIN_FOLDER, class_id)
    if os.path.exists(class_path):
        images = glob.glob(os.path.join(class_path, "*.JPEG"))
        train_image_paths.extend(images)
        
# all validation image paths (.JPEG files) from the val folder
val_image_paths = glob.glob(os.path.join(VAL_FOLDER, "*.JPEG"))
len(train_image_paths), len(val_image_paths), train_image_paths[:2], val_image_paths[:2]

(64733,
 50000,
 ['/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train/n01440764/n01440764_3198.JPEG',
  '/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train/n01440764/n01440764_10845.JPEG'],
 ['/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/val/ILSVRC2012_val_00003485.JPEG',
  '/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/val/ILSVRC2012_val_00021211.JPEG'])

In [None]:
# Training and validation Dataset
train_dataset = MultiCueStrokeDataset(train_image_paths, img_size=IMG_SIZE, device='cuda')
val_dataset = MultiCueStrokeDataset(val_image_paths, img_size=IMG_SIZE, device='cuda')

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth
100%|██████████| 161M/161M [00:00<00:00, 213MB/s] 


In [8]:
# Dataloader
device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader = get_dataloader(train_image_paths, batch_size=BATCH_SIZE, device=device)
val_loader = get_dataloader(val_image_paths, batch_size=BATCH_SIZE, shuffle=False, device=device)


In [9]:
# LOADING PIPELINE
# ====================
# Loading the Stable Diffusion pipeline with ControlNet
# ControlNet model for Canny edge detection
# https://huggingface.co/lllyasviel/control_v11p_sd15_canny
# Stable Diffusion model
# https://huggingface.co/runwayml/stable-diffusion-v1-5

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(DEVICE)

pipe.vae.enable_tiling()
pipe.enable_attention_slicing()
# === FREEZE UNET ===
for param in pipe.unet.parameters():
    param.requires_grad = False

2025-04-23 22:28:51.358629: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745447331.560837      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745447331.611078      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [10]:
# LORA CONFIG

# import torch
# import torch.nn as nn
# import torch.nn.functional as 

# class LoRALinear(nn.Module):
#     def __init__(self, original, r=4, alpha=16):
#         super().__init__()
#         self.original = original
#         self.lora_A = nn.Linear(original.in_features, r, bias=False)
#         self.lora_B = nn.Linear(r, original.out_features, bias=False)
#         self.scaling = alpha / r

#         # Init weights
#         nn.init.kaiming_uniform_(self.lora_A.weight, a=5**0.5)
#         nn.init.zeros_(self.lora_B.weight)

#     def forward(self, x):
#         # Force all components on same device
#         device = self.original.weight.device
#         x = x.to(device)
#         self.lora_A = self.lora_A.to(device)
#         self.lora_B = self.lora_B.to(device)

#         lora_out = self.lora_B(self.lora_A(x)) * self.scaling
#         return self.original(x) + lora_out


In [11]:
# # === MANUALLY INJECT LORA INTO LINEAR PROJECTIONS (to_q, to_k, to_v) ===
# import torch.nn as nn
# def patch_controlnet_lora(controlnet, r=4, alpha=16):
#     for name, module in controlnet.named_modules():
#         if hasattr(module, "to_q") and isinstance(module.to_q, nn.Linear):
#             device = module.to_q.weight.device
#             module.to_q = LoRALinear(module.to_q, r=r, alpha=alpha).to(device)
#             module.to_k = LoRALinear(module.to_k, r=r, alpha=alpha).to(device)
#             module.to_v = LoRALinear(module.to_v, r=r, alpha=alpha).to(device)
#     return controlnet

# pipe.controlnet = patch_controlnet_lora(pipe.controlnet)

# # === COLLECT LORA PARAMS ONLY ===
# lora_params = [p for n, p in pipe.controlnet.named_parameters() if "lora" in n and p.requires_grad]
# optimizer = torch.optim.AdamW(lora_params, lr=LR)
# loss_fn = nn.MSELoss()


In [12]:
# converting lab tensor to rgb for control net training

def lab_tensor_to_rgb(lab_tensor):
    lab_np = lab_tensor.detach().cpu().numpy()
    rgb_images = []
    for img in lab_np:
        L = (img[0] * 100).astype(np.float32)
        ab = (img[1:3] * 128).astype(np.float32)
        lab = np.stack([L, ab[0], ab[1]], axis=-1)
        lab = lab.astype(np.uint8)
        rgb = cv2.cvtColor(lab, cv2.COLOR_Lab2RGB)
        rgb_images.append(rgb)
    rgb_images = np.stack(rgb_images, axis=0)
    rgb_tensor = torch.from_numpy(rgb_images).permute(0, 3, 1, 2).float() / 255.0
    return rgb_tensor


In [13]:
# === TRAINING CONTROLNET ===
def train_controlnet(pipe, train_loader, val_loader, device="cuda", save_dir="controlnet_ckpts"):
    os.makedirs(save_dir, exist_ok=True)
    optimizer = torch.optim.AdamW(pipe.controlnet.parameters(), lr=1e-5)
    loss_fn = nn.MSELoss()
    ACCUM_STEPS = 1
    EPOCHS = 3

    global_step = 0

    for epoch in range(EPOCHS):
        pipe.controlnet.train()
        pipe.unet.eval()
        total_loss = 0

        optimizer.zero_grad()
        for step, (inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch} - Training")):
            inputs = inputs.to(device, dtype=pipe.vae.dtype)
            targets = targets.to(device)

            # ✅ Convert AB strokes into RGB control image
            control_image = lab_tensor_to_rgb(inputs).to(device, dtype=pipe.vae.dtype)

            # ✅ Dummy text prompt → encoder_hidden_states
            prompt = ["a photo"] * inputs.shape[0]
            with torch.no_grad():
                encoder_hidden_states = pipe.text_encoder(
                    pipe.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").input_ids.to(device)
                )[0]

                latents = pipe.vae.encode(inputs).latent_dist.sample()

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            down_block_res_samples, mid_block_res_sample = pipe.controlnet(
                sample=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=control_image,
                return_dict=False
            )

            pred = pipe.unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample

            loss = loss_fn(pred, noise)
            loss.backward()

            if (step + 1) % ACCUM_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()
            global_step += 1

        print(f"Epoch {epoch} | Train Loss: {total_loss / len(train_loader):.4f}")
        torch.save(pipe.controlnet.state_dict(), os.path.join(save_dir, f"controlnet_epoch{epoch}.pth"))

        # === VALIDATION ===
        pipe.controlnet.eval()
        pipe.unet.eval()
        val_loss = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch} - Validation"):
                inputs = inputs.to(device, dtype=pipe.vae.dtype)
                targets = targets.to(device)
                control_image = lab_tensor_to_rgb(inputs).to(device, dtype=pipe.vae.dtype)

                prompt = ["a photo"] * inputs.shape[0]
                encoder_hidden_states = pipe.text_encoder(
                    pipe.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").input_ids.to(device)
                )[0]

                latents = pipe.vae.encode(inputs).latent_dist.sample()
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()
                noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

                down_block_res_samples, mid_block_res_sample = pipe.controlnet(
                    sample=noisy_latents,
                    timestep=timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_image,
                    return_dict=False
                )

                pred = pipe.unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                val_loss += loss_fn(pred, noise).item()

            print(f"Epoch {epoch} | Val Loss: {val_loss / len(val_loader):.4f}")
            torch.cuda.empty_cache()


In [None]:
# TRAINING
train_controlnet(pipe, train_loader, val_loader)

Epoch 0 - Training: 100%|██████████| 16184/16184 [3:27:16<00:00,  1.30it/s]  


Epoch 0 | Train Loss: nan


Epoch 0 - Validation:   5%|▍         | 612/12500 [04:56<1:31:41,  2.16it/s]