# Fine Tune Diffusion Model

In [1]:
import os
import random
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T

from diffusers import StableDiffusionPipeline, DDPMScheduler
from transformers import AutoTokenizer
from accelerate import Accelerator
import warnings
warnings.filterwarnings("ignore")

A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "d:\work_space\projects\deep_learning\CAP\lib\site-packages\xformers\__init__.py", line 57, in _is_triton_available
    import triton  # noqa
ModuleNotFoundError: No module named 'triton'


## Configuration & Parameters

In [2]:
# CONFIG
# Base Stable Diffusion checkpoint that the LoRA adapters are trained on top of.
model_id = "runwayml/stable-diffusion-v1-5"

# Folder holding the training images and their caption text files.
data_dir = r"D:\work_space\projects\deep_learning\CAP6415_F25_project-Finding-and-solving-hard-to-generate-examples\Data_set\processed"
resolution = 512          # side length (pixels) handed to the dataset
train_batch_size = 1
gradient_accumulation_steps = 1  # NOTE(review): not referenced by train_lora() as written
learning_rate = 5e-4
num_epochs = 10           # will be cut off by max_train_steps
max_train_steps = 300  # total update steps (not epochs)
lora_rank = 2             # safe for 6GB
num_workers = 0           # Windows + small batch = 0 is safest
seed = 42

# Output locations: final LoRA weights go in save_dir, periodic checkpoints below it.
save_dir = r"D:\work_space\projects\deep_learning\CAP6415_F25_project-Finding-and-solving-hard-to-generate-examples\model\lora_output"
# Build the sub-directory with os.path.join rather than a hard-coded "\" so the
# result is a real child directory on every platform (identical value on Windows).
CHECKPOINT_DIR = os.path.join(save_dir, "checkpoints")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [3]:
# UTILS
def set_seed(seed: int):
    """Seed Python's ``random`` and PyTorch (CPU, then all CUDA devices) and report it.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    # Same order as before: stdlib RNG, torch default generator, CUDA generators.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    print("Seed set to", seed)

In [4]:
# DATASET
class CaptionDataset(Dataset):
    def __init__(self, folder: str | Path, res: int):
        self.items = []
        self.res = res
        folder = Path(folder)

        for img_path in sorted(folder.glob("*.jpg")):
            txt_path = img_path.with_suffix(".txt")
            if txt_path.exists():
                with open(txt_path, "r", encoding="utf-8") as f:
                    lines = [ln.strip() for ln in f.readlines() if ln.strip()]
                if not lines:
                    continue
                self.items.append((str(img_path), lines))

        self.transform = T.Compose([
            T.Resize((res, res)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3),
        ])

        print("Dataset size (images with captions):", len(self.items))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        img_path, captions = self.items[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)

        # randomly choose one caption line for this pass
        caption = random.choice(captions)
        return img, caption


# Build the training set once, at cell-execution time, from the configured
# data folder and resolution; the constructor prints how many pairs it found.
dataset = CaptionDataset(folder=data_dir, res=resolution)

Dataset size (images with captions): 222


In [5]:
# LoRA MODULES
def get_parent(model, name: str):
    """Return the object that directly owns the attribute addressed by ``name``.

    ``name`` is a dotted attribute path relative to ``model``. Every component
    except the last is resolved with ``getattr``; a path without dots therefore
    returns ``model`` itself.
    """
    *ancestors, _leaf = name.split(".")
    node = model
    for attr in ancestors:
        node = getattr(node, attr)
    return node


In [6]:
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) residual branch.

    The forward pass computes ``original(x) + scale * lora_B(lora_A(x))`` with
    ``scale = 1 / rank``. ``lora_B`` starts at zero, so a freshly wrapped layer
    produces the same output as ``original``.

    The two adapter layers are created on the same device and with the same
    dtype as ``original.weight``. Without that they would always be CPU/default
    dtype, and the forward pass would fail for a base layer that lives on
    another device or uses another dtype unless the caller moved the wrapper.

    Args:
        original: the layer to wrap. It is stored as a submodule and called
            unchanged; its ``requires_grad`` flags are not modified here.
        rank: inner dimension of the low-rank factorisation.
    """

    def __init__(self, original: nn.Linear, rank: int = 2):
        super().__init__()
        self.original = original
        in_f = original.in_features
        out_f = original.out_features

        # Keep the adapter co-located and type-compatible with the base weight.
        factory_kwargs = {
            "device": original.weight.device,
            "dtype": original.weight.dtype,
        }
        self.lora_A = nn.Linear(in_f, rank, bias=False, **factory_kwargs)
        self.lora_B = nn.Linear(rank, out_f, bias=False, **factory_kwargs)

        # A gets a random init, B is zeroed so the residual starts at exactly 0.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)

        self.scale = 1.0 / rank

    def forward(self, x):
        """Return the base projection plus the scaled low-rank update."""
        return self.original(x) + self.scale * self.lora_B(self.lora_A(x))


In [7]:
def inject_lora(unet: nn.Module, rank: int = 2, target_names=("to_q", "to_k", "to_v")):
    """Replace matching ``nn.Linear`` submodules of ``unet`` with ``LoRALinear``.

    A module is wrapped when it is an ``nn.Linear`` and its dotted name contains
    any of ``target_names``. Modules that already sit inside a ``LoRALinear``
    are skipped: their names (``...to_q.original``, ``...to_q.lora_A``,
    ``...to_q.lora_B``) also contain the target substring, so without this
    guard a second call would wrap the adapter's own layers.

    Relies on ``LoRALinear`` and ``get_parent`` being defined in this notebook.

    Args:
        unet: model to modify in place.
        rank: LoRA rank passed to every ``LoRALinear``.
        target_names: substrings that select which linear layers to wrap.
    """
    # Snapshot the tree first; it is mutated while we walk it.
    modules = list(unet.named_modules())

    # Dotted prefixes of subtrees that are already LoRA-wrapped.
    wrapped_prefixes = tuple(
        f"{name}." for name, module in modules if isinstance(module, LoRALinear)
    )

    count = 0
    for name, module in modules:
        if not isinstance(module, nn.Linear):
            continue
        if not any(key in name for key in target_names):
            continue
        # str.startswith(()) is False, so this is a no-op on a fresh model.
        if name.startswith(wrapped_prefixes):
            continue
        parent = get_parent(unet, name)
        child_name = name.rsplit(".", 1)[-1]
        setattr(parent, child_name, LoRALinear(module, rank=rank))
        count += 1
    print("Injected LoRA into", count, "linear layers")


In [8]:
# Global CUDA/cuDNN setup, run once before training.
torch.backends.cudnn.benchmark = True  # turn on cuDNN's benchmark mode
torch.cuda.empty_cache()               # drop cached allocator blocks left by earlier cells

In [9]:
# GPU memory stats
def print_cuda_stats():
    """Print PyTorch's allocated and reserved CUDA memory, in MB (two decimals)."""
    bytes_per_mb = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"VRAM Allocated: {allocated_mb:.2f} MB")
    print(f"VRAM Reserved:  {reserved_mb:.2f} MB")

In [10]:
# TRAINING FUNCTION
def train_lora():
    """Train LoRA adapters on the Stable Diffusion UNet and return the pipeline.

    Configuration is read from module-level names: ``seed``, ``model_id``,
    ``lora_rank``, ``learning_rate``, ``dataset``, ``train_batch_size``,
    ``num_workers``, ``num_epochs``, ``max_train_steps`` and ``CHECKPOINT_DIR``.

    Outline:
      1. Seed RNGs, create the ``Accelerator`` and load the pipeline in float32
         with a ``DDPMScheduler`` built from the pipeline's scheduler config.
      2. Freeze all UNet and text-encoder parameters, then inject ``LoRALinear``
         wrappers so the adapter weights are the only trainable parameters.
      3. For each batch: encode the image to latents and the caption to text
         embeddings (both under ``no_grad``), add noise at a random timestep,
         and minimise the MSE between the UNet output and that noise.
      4. Every 50 completed updates (main process only) write a checkpoint
         containing just the ``lora_`` tensors and log loss / VRAM.

    Training ends when ``max_train_steps`` updates have been made or all
    ``num_epochs`` epochs are exhausted, whichever happens first.

    Returns:
        The pipeline, with ``pipe.unet`` set to the unwrapped, trained UNet.
    """
    set_seed(seed)
    accelerator = Accelerator()
    device = accelerator.device
    print("Accelerator device:", device)

    torch.cuda.empty_cache()

    # LOAD PIPE
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        safety_checker=None,
        torch_dtype=torch.float32,
    )
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)
    print("Pipeline loaded")

    # EXTRACT COMPONENTS
    tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    unet = pipe.unet
    vae = pipe.vae
    text_encoder = pipe.text_encoder

    # GRADIENT CHECKPOINTING + FREEZE BASE WEIGHTS
    unet.enable_gradient_checkpointing()
    if hasattr(text_encoder, "gradient_checkpointing_enable"):
        text_encoder.gradient_checkpointing_enable()

    # Freeze BEFORE injecting LoRA so that only the adapter layers created
    # afterwards still have requires_grad=True.
    for p in unet.parameters():
        p.requires_grad = False
    for p in text_encoder.parameters():
        p.requires_grad = False  # we train only UNet LoRA

    print("UNet and text encoder frozen (base)")

    # ADD LORA
    inject_lora(unet, lora_rank)

    # COLLECT TRAINABLE PARAMS (ONLY LoRA)
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    print("Trainable parameters (LoRA only):", sum(p.numel() for p in trainable_params))
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

    # DATALOADER
    loader = DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    unet, optimizer, loader = accelerator.prepare(unet, optimizer, loader)
    print("Training start")

    # TRAIN LOOP
    # NOTE(review): gradient_accumulation_steps from the config cell is not
    # used here; every batch results in one optimizer step.
    checkpoint_every = 50  # completed updates between checkpoints
    steps = 0              # number of optimizer updates completed so far
    unet.train()

    print("\n Debug mode enabled — watch first few steps carefully!\n")

    for epoch in range(num_epochs):
        for img, caption in loader:

            # Stop if max steps reached
            if steps >= max_train_steps:
                print("Stopped after", steps, "steps")
                pipe.unet = accelerator.unwrap_model(unet)
                return pipe

            # DEBUG (FIRST BATCH ONLY)
            if steps == 0:
                print("[debug] First caption:", caption)
                print("[debug] Image tensor shape:", img.shape)

            img = img.to(device)

            # TOKENIZE CAPTION
            tokens = tokenizer(
                list(caption),
                truncation=True,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                return_tensors="pt"
            ).to(device)

            if steps == 0:   # debug only first iteration
                print("[debug] Tokenizer output IDs shape:", tokens.input_ids.shape)
                print("[debug] Example input IDs:", tokens.input_ids[0][:10])

            # VAE and text encoder are used as fixed feature extractors.
            with torch.no_grad():
                latents = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
                enc = text_encoder(tokens.input_ids)[0]

            # One random timestep per sample; the UNet is trained to predict
            # the noise that was added at that timestep.
            noise = torch.randn_like(latents)
            t = torch.randint(0, pipe.scheduler.config.num_train_timesteps,
                              (latents.shape[0],), device=device, dtype=torch.long)
            noisy = pipe.scheduler.add_noise(latents, noise, t)

            pred = unet(noisy, t, enc).sample
            loss = nn.functional.mse_loss(pred, noise)

            # SAVE CHECKPOINT
            # Runs before this batch's update, so step_N.pt holds the weights
            # after exactly N updates. Only the LoRA tensors are written (the
            # same "lora_" key filter used for the final save) instead of the
            # whole UNet state dict.
            if steps > 0 and steps % checkpoint_every == 0 and accelerator.is_main_process:
                torch.cuda.empty_cache()
                print(f"Saving checkpoint at step: {steps}")
                lora_dict = {
                    k: v.detach().cpu()
                    for k, v in accelerator.unwrap_model(unet).state_dict().items()
                    if "lora_" in k
                }
                torch.save(lora_dict, os.path.join(CHECKPOINT_DIR, f"step_{steps}.pt"))
                print(f"Saved checkpoint: step_{steps}.pt")
                print(f"Loss: {loss.item():.4f}")
                print_cuda_stats()
                print("\n")

            # BACKPROP
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            steps += 1

    # All epochs consumed before reaching max_train_steps.
    pipe.unet = accelerator.unwrap_model(unet)
    return pipe

In [11]:
# RUN TRAINING
if __name__ == "__main__":
    # Train, then persist only the adapter tensors (keys containing "lora_"),
    # copied to CPU, as a single file inside save_dir.
    pipe = train_lora()

    out_path = os.path.join(save_dir, "speedbump_lora.pt")
    lora_state_dict = dict(
        (key, tensor.cpu())
        for key, tensor in pipe.unet.state_dict().items()
        if "lora_" in key
    )
    torch.save(lora_state_dict, out_path)
    print("LoRA weights saved at:", out_path)

Seed set to 42
Accelerator device: cuda


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Pipeline loaded
UNet and text encoder frozen (base)
Injected LoRA into 96 linear layers
Trainable parameters (LoRA only): 298752
Training start

 Debug mode enabled — watch first few steps carefully!

[debug] First caption: ['a speed bump on a city street.']
[debug] Image tensor shape: torch.Size([1, 3, 512, 512])
[debug] Tokenizer output IDs shape: torch.Size([1, 77])
[debug] Example input IDs: tensor([49406,   320,  4163, 15877,   525,   320,  1305,  2012,   269, 49407],
       device='cuda:0')
Running test inference at step: 50
Saved checkpoint: step_50.pt
Loss: 0.0810
VRAM Allocated: 4400.38 MB
VRAM Reserved:  4574.00 MB


Running test inference at step: 100
Saved checkpoint: step_100.pt
Loss: 0.5787
VRAM Allocated: 4400.38 MB
VRAM Reserved:  4466.00 MB


Running test inference at step: 150
Saved checkpoint: step_150.pt
Loss: 0.0209
VRAM Allocated: 4400.38 MB
VRAM Reserved:  4494.00 MB


Running test inference at step: 200
Saved checkpoint: step_200.pt
Loss: 0.0216
VRAM Allocated: 