In [1]:
import os
from pathlib import Path

import torch
from torch.utils.data import  DataLoader

from accelerate import Accelerator
from accelerate.utils import set_seed

from diffusers import DDPMScheduler
from peft import LoraConfig
from diffusers import StableDiffusionXLPipeline
import sdxl_lora_trainer_helper as helper

# Report the runtime environment up front so a failing smoke run can be
# attributed to the installed torch build / CUDA visibility.
for label, value in (
    ("torch:", torch.__version__),
    ("cuda available:", torch.cuda.is_available()),
):
    print(label, value)


torch: 2.7.0
cuda available: False


In [2]:
# Smoke-test configuration: a tiny random SDXL checkpoint, CPU unless CUDA is
# available, a small resolution, and a single training step. Every value can be
# overridden through the corresponding SMOKE_* environment variable.


def dtype_for_device(device: str) -> torch.dtype:
    """Return the dtype the models are loaded in for ``device``.

    Any CUDA device string ("cuda", "cuda:0", ...) selects float16; every
    other device selects float32.
    """
    return torch.float16 if device.startswith("cuda") else torch.float32


MODEL_ID = os.getenv("SMOKE_MODEL", "dg845/tiny-random-stable-diffusion-xl")
DEVICE = os.getenv("SMOKE_DEVICE", "cuda" if torch.cuda.is_available() else "cpu").lower()
STEPS = int(os.getenv("SMOKE_STEPS", "1"))
RES = int(os.getenv("SMOKE_RES", "64"))

# Match on the prefix so indexed devices such as "cuda:0" are validated (and
# get float16) exactly like plain "cuda".
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    raise RuntimeError("DEVICE=cuda requested but CUDA is not available.")

DTYPE = dtype_for_device(DEVICE)

print("MODEL_ID:", MODEL_ID)
print("DEVICE:", DEVICE)
print("DTYPE:", DTYPE)
print("STEPS:", STEPS, "RES:", RES)


MODEL_ID: dg845/tiny-random-stable-diffusion-xl
DEVICE: cpu
DTYPE: torch.float32
STEPS: 1 RES: 64


In [3]:
# Fix the random seed so repeated smoke runs are comparable.
set_seed(0)

# Dataset used for the single training step, and the directory that receives
# the artifacts written by this smoke test.
ds_root = Path("dataset/small_test")

out_root = Path("./smoke_out")
out_root.mkdir(parents=True, exist_ok=True)

# Keep accelerate on the same device family as DEVICE. Without `cpu=True`
# accelerate picks an available GPU on its own, so a run configured with
# SMOKE_DEVICE=cpu on a CUDA host would be moved to the GPU by
# `accelerator.prepare` / `.to(accelerator.device)` further down.
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision=None,
    cpu=(DEVICE == "cpu"),
)

# Only hit the local cache when the Hugging Face hub is in offline mode.
local_files_only = os.getenv("HF_HUB_OFFLINE", "") == "1"
print("HF_HUB_OFFLINE:", os.getenv("HF_HUB_OFFLINE", ""), "=> local_files_only:", local_files_only)

print("Loading pipeline…")
pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    use_safetensors=True,
    local_files_only=local_files_only,
)
pipe = pipe.to(DEVICE)
print("Pipeline loaded ")

# Pull the individual components out of the loaded pipeline.
unet, vae = pipe.unet, pipe.vae
te1, te2 = pipe.text_encoder, pipe.text_encoder_2
tok1, tok2 = pipe.tokenizer, pipe.tokenizer_2

# Freeze every base weight; only the adapter weights added below are trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
te1.requires_grad_(False)
te2.requires_grad_(False)

# Attach a rank-4 LoRA adapter to the attention projections of the UNet.
unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# Collect the parameters the optimizer will update: trainable ones whose
# name contains "lora".
lora_params = [
    param
    for name, param in unet.named_parameters()
    if param.requires_grad and "lora" in name
]
print("trainable LoRA params:", sum(param.numel() for param in lora_params))
assert lora_params, "No LoRA params marked trainable."

# Training-time noise scheduler built from the pipeline scheduler's config.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# One sample per batch, fixed order, loaded in the main process.
ds = helper.ImageCaptionFolder(str(ds_root), resolution=RES, center_crop=True)
dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

opt = torch.optim.AdamW(lora_params, lr=1e-4)

# Let accelerate wrap the trainable model, its optimizer and the loader.
unet, opt, dl = accelerator.prepare(unet, opt, dl)
device_t = accelerator.device

# The frozen modules are not passed through `prepare`, so move them by hand.
vae.to(device_t)
te1.to(device_t)
te2.to(device_t)
unet.train()

# Run exactly STEPS optimizer steps of the LoRA denoising objective.
# The loader is re-iterated as often as needed so that a dataset shorter than
# STEPS still yields the requested number of steps, and a pass that produces
# no batch at all fails loudly instead of letting the smoke test "pass"
# without ever training.
print("Running train step…")
global_step = 0
while global_step < STEPS:
    steps_before_pass = global_step

    for batch in dl:
        if global_step >= STEPS:
            break

        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(device_t, dtype=vae.dtype)

            # Encode images to latents with the frozen VAE and apply its scaling factor.
            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            # Draw noise and one random timestep per sample, then noise the latents.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device_t, dtype=torch.long
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Tokenize the captions once per text encoder.
            captions = batch["caption"]
            input_ids_1 = tok1(
                captions, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
            ).input_ids.to(device_t)
            input_ids_2 = tok2(
                captions, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
            ).input_ids.to(device_t)

            with torch.no_grad():
                enc1 = te1(input_ids_1, output_hidden_states=True)
                enc2 = te2(input_ids_2, output_hidden_states=True)
                # Concatenate the penultimate hidden states of both encoders
                # along the feature axis.
                prompt_embeds = torch.cat([enc1.hidden_states[-2], enc2.hidden_states[-2]], dim=-1)
                # First output of the second encoder, passed on as "text_embeds".
                pooled_prompt_embeds = enc2[0]

            # Additional time ids, one row per sample; presumably
            # (original size, crop top-left, target size) — TODO confirm.
            add_time_ids = torch.tensor(
                [RES, RES, 0, 0, RES, RES], device=device_t, dtype=prompt_embeds.dtype
            ).unsqueeze(0).repeat(bsz, 1)

            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            ).sample

            # Regress onto whatever the checkpoint's scheduler says the model
            # predicts; MODEL_ID is configurable, so do not assume epsilon.
            prediction_type = noise_scheduler.config.prediction_type
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unsupported prediction type: {prediction_type}")

            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            opt.step()
            opt.zero_grad(set_to_none=True)

        print(f"step={global_step} loss={loss.item():.6f}")
        global_step += 1

    # A full pass that did not advance the step counter means the loader is empty.
    if global_step == steps_before_pass:
        raise RuntimeError(f"No training batches were produced from {ds_root}; cannot run a train step.")

# Round-trip check: save the trained LoRA weights, then load them into a
# freshly constructed pipeline.
ckpt = out_root / "ckpt"
print("Saving LoRA…")
# Unwrap first so the helper receives the plain UNet rather than accelerate's wrapper.
helper.save_unet_lora_peft(accelerator.unwrap_model(unet), str(ckpt))
print("Saved ")

print("Reload LoRA into a fresh pipeline…")
# Load with the same dtype as the training pipeline; without `torch_dtype`
# the reload pipeline would differ from the trained one whenever DTYPE is
# not the library default.
pipe2 = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    use_safetensors=True,
    local_files_only=local_files_only,
).to(DEVICE)

helper.load_lora_into_pipe(pipe2, str(ckpt))
print("Reload ")

print("\n TEST PASSED")
print("Artifacts dir:", out_root.resolve())


HF_HUB_OFFLINE:  => local_files_only: False
Loading pipeline…


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Pipeline loaded 
trainable LoRA params: 8832
Running train step…


  arr = torch.ByteTensor(torch.ByteStorage.from_buffer(image.tobytes()))


step=0 loss=1.087599
Saving LoRA…
Saved 
Reload LoRA into a fresh pipeline…


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]



Reload 

 TEST PASSED
Artifacts dir: /techfak/user/sguszausky/ArtistDiffusionModel/smoke_out
