# DreamBooth + LoRA Training Pipeline for Realistic Vision V5.1

In [None]:
import os
import torch
import random
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from accelerate import Accelerator

2025-07-12 17:17:33.358642: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752340653.714298      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752340653.820441      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [None]:
# HF TOKEN
from huggingface_hub import login
login("your_token")

In [None]:
# Downloading Realistic Vision V5
from huggingface_hub import hf_hub_download
import os

model_dir = "/kaggle/working//base_model"
os.makedirs(model_dir, exist_ok=True)

ckpt_path = hf_hub_download(
    repo_id="SG161222/Realistic_Vision_V5.1_noVAE",
    filename="Realistic_Vision_V5.1.safetensors",
    local_dir=model_dir,
)

Realistic_Vision_V5.1.safetensors:   0%|          | 0.00/4.27G [00:00<?, ?B/s]

In [None]:
# Downloading v1-inference.yaml
!wget -O v1-inference.yaml https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml

--2025-07-12 17:18:08--  https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1873 (1.8K) [text/plain]
Saving to: ‘v1-inference.yaml’


2025-07-12 17:18:09 (19.2 MB/s) - ‘v1-inference.yaml’ saved [1873/1873]



In [None]:
# Converting to diffusers format
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

safetensors_path = "/kaggle/working/base_model/Realistic_Vision_V5.1.safetensors"
output_dir = "/kaggle/working/realistic_vision_diffusers"

converted_pipeline = download_from_original_stable_diffusion_ckpt(
    safetensors_path,
    "/kaggle/working/v1-inference.yaml",  # Must match SD1.5 or SD2.x
    from_safetensors=True,
    extract_ema=True,
    device="cuda"  # or "cpu"
)

# saving
converted_pipeline.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")

config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]



Model saved to /kaggle/working/realistic_vision_diffusers


**Defining configurations for the training**

trigger_word: This let's you call your character. Whenever you use this it means that you are calling/referring your character.

batch_size: how many images it will train in each step

gradient_accumulation: Using this I can increase my batch size with same GPU memory usage.

effective batch size = (batch_size) x (gradient_accumulation)

learning_rate: how fast you want the model to learn the character (too high the model might overfit, too low the model would not learn your character fully) (effective proven and tested learning rate would be from 5e-5 to 1e-6 and depending upon the training steps and dataset variation) 

max_train_steps: how many steps you want the training to run (1 step = 1 batch run)

train_text_encoder: whether we want to train train_text_encoder or not. (if yes we would update weights of train text encoder used in the base model using our setup) (Note: training this is necessary when training on specific characters as the base model might struggle learning your trigger_word and how it is in different scenarios)

lr_scheduler: how do you want your learning rate to decay. (I tried constant, no lr scheduler it doesn't work, it model overfitted or underfitted, cosine worked best for me as the learning_rate would decay slowly.)

lr_warmup_steps: when you want your learning rate to start decay (e.g. 100 means leraning rate would start decaying after 100 steps with the lr_scheduler)

lora_r: Determines the rank of the low-rank adaptation matrices, controlling the number of trainable parameters added for fine-tuning.

lora_alpha: A scaling factor applied to the LoRA updates, adjusting their overall impact on the model's weights.

lora_dropout: Probability of randomly dropping LoRA adaptation connections during training, which helps prevent overfitting by adding regularization.

In [None]:
class Config:
    # ✅ Paths
    model_path = "/kaggle/working/realistic_vision_diffusers"
    dataset_dir = "/kaggle/input/gwen-phase1"
    output_dir = "/kaggle/working/lora_gwen"

    # ✅ Training Behavior
    trigger_word = "sks"  # Keep this for identity consistency
    resolution = 512       # Good middle ground for face detail
    
    batch_size = 1
    gradient_accumulation = 4
    learning_rate = 4e-5   # 
    max_train_steps = 3330 # 
    
    mixed_precision = "fp16"
    train_text_encoder = True  

    # ✅ LoRA Settings
    lr_scheduler = "cosine"
    lr_warmup_steps = 185
    lora_r = 4
    lora_alpha = 8
    lora_dropout = 0.1
    # lora_target_modules = ["CrossAttention", "Attention"]

    # ✅ Logging & Checkpoints
    save_every_n_steps = 1110
    log_every_n_steps = 370
    generate_every_n_steps = 370
    seed = 151101

cfg = Config()

In [9]:
import os, json, inspect

def save_config(cfg, path=None):
    if path is None:
        path = os.path.join(cfg.output_dir, "lora_config.json")
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # grab only data attributes declared on the *class*
    config_dict = {
        k: v for k, v in cfg.__class__.__dict__.items()
        if not k.startswith("__") and not inspect.isfunction(v) and not inspect.ismethod(v)
    }

    with open(path, "w") as f:
        json.dump(config_dict, f, indent=4)
    print(f"✅  Config saved to {path}")

save_config(cfg)

✅  Config saved to /kaggle/working/lora_gwen/lora_config.json


In [10]:
# @title Reproducibility
torch.manual_seed(cfg.seed)
torch.cuda.manual_seed(cfg.seed)
random.seed(cfg.seed)

In [None]:
#  Dataset Loader
class ImageCaptionDataset(Dataset):
    def __init__(self, image_dir, tokenizer, size):
        self.image_paths = []
        self.caption_paths = []
        self.tokenizer = tokenizer

        for fname in sorted(os.listdir(image_dir)):
            if fname.endswith(".png") or fname.endswith(".jpg"):
                img_path = os.path.join(image_dir, fname)
                txt_path = os.path.splitext(img_path)[0] + ".txt"
                if os.path.exists(txt_path):
                    self.image_paths.append(img_path)
                    self.caption_paths.append(txt_path)

        self.image_transforms = transforms.Compose([
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if image.getbbox() is None:
            raise ValueError(f"Empty image found: {self.image_paths[idx]}")
        else:
            image = self.image_transforms(image)
            with open(self.caption_paths[idx], "r") as f:
                caption = f.read().strip()

            inputs = self.tokenizer(caption, truncation=True, padding="max_length", max_length=77, return_tensors="pt")
            
        return {"pixel_values": image, "input_ids": inputs.input_ids.squeeze(0)}

I was frustrated by version mismatches and compatibility issues with popular libraries (like peft) for LoRA in pytorch. Instead of fighting package dependencies, I researched the fundamental principles behind LoRA and wrote a minimal, robust implementation compatible with any pytorch model using nn.Linear layers.

My LoRALinear class is a simple, direct wrapper for any nn.Linear layer. It adds trainable “LoRA” weights on top of the original weights, enabling efficient fine-tuning.

In [12]:
import torch.nn as nn

# 🔧 Safer and more flexible LoRA wrapper for nn.Linear
class LoRALinear(nn.Module):
    def __init__(self, linear: nn.Linear, rank: int, alpha: float):
        super().__init__()

        if not isinstance(linear, nn.Linear):
            raise TypeError(f"LoRALinear can only wrap nn.Linear, but got {type(linear)}")

        self.linear = linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # LoRA layers
        self.lora_down = nn.Linear(linear.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, linear.out_features, bias=False)

        # Initialization (standard LoRA practice)
        nn.init.kaiming_uniform_(self.lora_down.weight, a=5**0.5)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x):
        # Original + LoRA residual
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scaling


Now our patcher for patching unet.

The most important question what are we patching, why that part and we would we patch?

Answer: We would patch some layers of the model not all, the layers most important for learning, new style, new character, new things are "to_q" ,"to_k", "to_v" and "to_out". These layers work together when you train your model they contain their own weights for what they learned. If we patched the old weights we might mess with the working of the model and it may produce messy, or noisy results during image generation. That's why we patch new layers and train the model these contain new weights for your specific style, character and new things. And you can load these on top of your Stable Diffusion Model and use them for image generation. 

In [13]:
def patch_unet_cross_attn_with_lora(unet, rank, alpha):
    lora_params = []

    for module in unet.modules():
        for attr in ['to_q', 'to_k', 'to_v', 'to_out']:
            if hasattr(module, attr):
                original = getattr(module, attr)

                # ✅ Directly patch if it's a Linear
                if isinstance(original, nn.Linear) and not isinstance(original, LoRALinear):
                    lora_layer = LoRALinear(original, rank=rank, alpha=alpha)
                    setattr(module, attr, lora_layer)
                    lora_params.extend(lora_layer.lora_down.parameters())
                    lora_params.extend(lora_layer.lora_up.parameters())
                    print(f"✅ Patched {attr} in {module.__class__.__name__}")

                # 🔍 Special case: to_out is a ModuleList with a Linear inside
                elif isinstance(original, nn.ModuleList):
                    for i, sublayer in enumerate(original):
                        if isinstance(sublayer, nn.Linear) and not isinstance(sublayer, LoRALinear):
                            lora_layer = LoRALinear(sublayer, rank=rank, alpha=alpha)
                            original[i] = lora_layer
                            lora_params.extend(lora_layer.lora_down.parameters())
                            lora_params.extend(lora_layer.lora_up.parameters())
                            print(f"✅ Patched {attr}[{i}] in {module.__class__.__name__}")
                        else:
                            print(f"⚠️ Skipping {attr}[{i}] — Not a Linear: {type(sublayer)}")

                else:
                    print(f"⚠️ Skipping {attr} in {module.__class__.__name__} — Not a Linear or ModuleList")

    total_params = sum(p.numel() for p in lora_params)
    print(f"\n✅ UNet LoRA patched. Total trainable LoRA params: {total_params}")
    return lora_params


This function is to unfrreze train_text encoder. I defined it for testing.  

In [None]:
# def unfreeze_text_encoder_attention(text_encoder):
#     trainable_params = []
#     total_unfroze = 0

#     for module in text_encoder.modules():
#         if all(hasattr(module, attr) for attr in ['q_proj', 'k_proj', 'v_proj', 'out_proj']):
#             for attr in ['q_proj', 'k_proj', 'v_proj', 'out_proj']:
#                 proj = getattr(module, attr)
#                 for param in proj.parameters():
#                     param.requires_grad = True
#                     trainable_params.append(param)
#                 total_unfroze += 1
#                 print(f"✅ Unfroze text encoder layer: {module.__class__.__name__} → {attr}")

#     print(f"\n✅ Total unfrozen attention projection blocks in text encoder: {total_unfroze}")
#     return trainable_params


Now we will patch the text_encoder the theory and reason is same as patching the unet layers. We would patch "q_proj", "k_proj", "v_proj", "out_proj" of the text_encoder.

In [15]:
def patch_text_encoder(text_encoder, rank, alpha):
    lora_params = []
    target_names = {"q_proj", "k_proj", "v_proj", "out_proj"}

    for module in text_encoder.modules():
        for name in target_names:
            if hasattr(module, name):
                proj = getattr(module, name)

                # Only patch raw nn.Linear layers (skip if already patched)
                if isinstance(proj, nn.Linear) and not isinstance(proj, LoRALinear):
                    lora_layer = LoRALinear(proj, rank=rank, alpha=alpha)
                    setattr(module, name, lora_layer)

                    lora_params += list(lora_layer.lora_down.parameters())
                    lora_params += list(lora_layer.lora_up.parameters())

                    print(f"🔧 Patched {module.__class__.__name__}.{name} with LoRA")

    print(f"✅ Finished patching text encoder — LoRA params: {sum(p.numel() for p in lora_params):,}")
    return lora_params


Now this is where we would create our pipeline, load our model into pipeline, load our tokenizer, load our trigger word into the tokenizer, load our unet and patch it, load our text_encoder and patch it. But the important thing is that we need to freeze our params of unet and text_encoder before patching so we accidently don't train them or update them and we only train or update our newly patched layers. We are using AdamW optimer here originally used with base our base model. We would also load our dataset into the pipeline and prepare everything for training.

In [16]:
import torch, os
from accelerate import Accelerator
from diffusers          import StableDiffusionPipeline, UNet2DConditionModel
from transformers        import AutoTokenizer, get_scheduler
from torch.utils.data    import DataLoader

# ──────────────────────────────────────────────────────────
#  1.  Accelerate & Config
# ──────────────────────────────────────────────────────────
accelerator = Accelerator(split_batches=True)
device       = accelerator.device
cfg          = Config()                       # <- your existing config class

# ──────────────────────────────────────────────────────────
#  2.  Load full pipeline   (VAE, Text‑Encoder, UNet, etc.)
# ──────────────────────────────────────────────────────────
pipe = StableDiffusionPipeline.from_pretrained(
    cfg.model_path,
    torch_dtype=torch.float16
).to(device)

# ──────────────────────────────────────────────────────────
#  3.  Trigger‑token → add BEFORE any dataset tokenisation
# ──────────────────────────────────────────────────────────
tokenizer      = AutoTokenizer.from_pretrained(cfg.model_path, subfolder="tokenizer")
trigger_token  = cfg.trigger_word                    # e.g. "pendugwen"

if len(tokenizer.tokenize(trigger_token)) > 1:       # splits → need custom token
    tokenizer.add_tokens([trigger_token])
    pipe.text_encoder.resize_token_embeddings(len(tokenizer))

    # initialise the new embedding from "person"
    with torch.no_grad():
        emb           = pipe.text_encoder.get_input_embeddings()
        new_id        = tokenizer.convert_tokens_to_ids(trigger_token)
        base_id       = tokenizer.convert_tokens_to_ids("person")
        emb.weight[new_id] = emb.weight[base_id].clone()

    print(f"✅ Added custom token '{trigger_token}' (id {new_id})")
else:
    print(f"✅ '{trigger_token}' already a single token")

pipe.tokenizer = tokenizer        # keep pipeline & dataset in sync

# ──────────────────────────────────────────────────────────
#  4.  Load UNet separately (so we can LoRA‑patch it)
# ──────────────────────────────────────────────────────────
unet = UNet2DConditionModel.from_pretrained(
    cfg.model_path,
    subfolder="unet",
    torch_dtype=torch.float16
).to(device)

# Freeze everything first
for p in unet.parameters():            p.requires_grad = False
for p in pipe.text_encoder.parameters(): p.requires_grad = False

# ──────────────────────────────────────────────────────────
#  5.  LoRA‑patch UNet  (+ optional text‑encoder)
# ──────────────────────────────────────────────────────────
lora_params  = patch_unet_cross_attn_with_lora(unet, cfg.lora_r, cfg.lora_alpha)

if cfg.train_text_encoder:
    lora_params += patch_text_encoder(pipe.text_encoder, cfg.lora_r, cfg.lora_alpha)
    print("✅ Text‑encoder LoRA patched")

if not lora_params:
    raise RuntimeError("No trainable LoRA params collected!")

print(f"🔍 LoRA trainable parameters: {sum(p.numel() for p in lora_params):,}")

# ──────────────────────────────────────────────────────────
#  6.  Optimiser & Scheduler (LoRA params only)
# ──────────────────────────────────────────────────────────
optimizer     = torch.optim.AdamW(lora_params, lr=cfg.learning_rate)

lr_scheduler  = get_scheduler(
    cfg.lr_scheduler,
    optimizer          = optimizer,
    num_warmup_steps   = cfg.lr_warmup_steps,
    num_training_steps = cfg.max_train_steps,
)

# ──────────────────────────────────────────────────────────
#  7.  Dataset & Dataloader  (tokenizer now has fixed token!)
# ──────────────────────────────────────────────────────────
dataset    = ImageCaptionDataset(cfg.dataset_dir, tokenizer, cfg.resolution)
dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

# ──────────────────────────────────────────────────────────
#  8.  Prepare for Accelerate & training
# ──────────────────────────────────────────────────────────
pipe.text_encoder.to(device, dtype=torch.float16)
unet.train()

unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

print("🚀 Setup complete – ready to train LoRA adapters.")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


✅ Added custom token 'pendugwen' (id 49408)
✅ Patched to_q in Attention
✅ Patched to_k in Attention
✅ Patched to_v in Attention
✅ Patched to_out[0] in Attention
⚠️ Skipping to_out[1] — Not a Linear: <class 'torch.nn.modules.dropout.Dropout'>
✅ Patched to_q in Attention
✅ Patched to_k in Attention
✅ Patched to_v in Attention
✅ Patched to_out[0] in Attention
⚠️ Skipping to_out[1] — Not a Linear: <class 'torch.nn.modules.dropout.Dropout'>
✅ Patched to_q in Attention
✅ Patched to_k in Attention
✅ Patched to_v in Attention
✅ Patched to_out[0] in Attention
⚠️ Skipping to_out[1] — Not a Linear: <class 'torch.nn.modules.dropout.Dropout'>
✅ Patched to_q in Attention
✅ Patched to_k in Attention
✅ Patched to_v in Attention
✅ Patched to_out[0] in Attention
⚠️ Skipping to_out[1] — Not a Linear: <class 'torch.nn.modules.dropout.Dropout'>
✅ Patched to_q in Attention
✅ Patched to_k in Attention
✅ Patched to_v in Attention
✅ Patched to_out[0] in Attention
⚠️ Skipping to_out[1] — Not a Linear: <class 't

Just check that how many params we are going to train now out of total params. 

In [18]:
total_params = sum(p.numel() for p in unet.parameters()) + sum(p.numel() for p in pipe.text_encoder.parameters())
trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad) + \
                   sum(p.numel() for p in pipe.text_encoder.parameters() if p.requires_grad)

print(f"✅ Trainable parameters: {trainable_params} / {total_params}")


✅ Trainable parameters: 1092096 / 983674308


Loading the vae used in the base model. The vae is what encodes and decodes our latent spaces into images.

In [19]:
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
pipe.vae = vae

pipe.vae.to(accelerator.device, dtype=torch.float32)

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [20]:
print(accelerator.device)  # Should print: cuda


cuda


Sanity check that if our paramters contain NAN values before our training begins.

In [21]:
for name, param in unet.named_parameters():
    if torch.isnan(param).any():
        print(f"NaN detected in UNet parameter: {name}")
        break


Now we would create functions for checking if our trigger_word is in the tokenizer, a function to generate sample images during the model training for checks if our model is learning or not, whether it is overfitting or underfitting. Whether if we need to stop our training, whether we need to change our configurations settings.

In [22]:
import os
import random
import torch
from diffusers import DDIMScheduler
from torch import autocast

# ✅ Optional: Reproducibility
def seed_everything(seed=151101):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ✅ Optional: Disable NSFW filter (for local testing)
def disable_safety(pipe):
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

# ✅ Fix tokenizer trigger word if needed
def ensure_trigger_token(pipe, trigger_word):
    tokens = pipe.tokenizer.tokenize(trigger_word)
    if len(tokens) > 1:
        print(f"⚠️ Trigger word '{trigger_word}' is split: {tokens}. Fixing...")
        pipe.tokenizer.add_tokens([trigger_word])
        pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))

        with torch.no_grad():
            embeddings = pipe.text_encoder.get_input_embeddings()
            init_id = pipe.tokenizer.convert_tokens_to_ids("person")
            new_id = pipe.tokenizer.convert_tokens_to_ids(trigger_word)
            embeddings.weight[new_id] = embeddings.weight[init_id].clone()

        print(f"✅ Re-added and initialized embedding for trigger word '{trigger_word}'")
    else:
        print(f"✅ Trigger word '{trigger_word}' is tokenized correctly: {tokens}")

# ✅ Generate and save a sample image
def generate_sample_image(step, save_path, prompt=None, negative_prompt=None, seed=151101):
    print(f"\n🎨 Generating preview at step {step}...")

    # ✅ Reproducible randomness
    seed_everything(seed)
    generator = torch.Generator("cuda").manual_seed(seed)

    # ✅ Restore UNet (EMA if available, else current)
    if 'ema_unet' in globals() and ema_unet is not None:
        pipe.unet = ema_unet
        print("📦 Using EMA UNet for inference.")
    else:
        pipe.unet = accelerator.unwrap_model(unet)
        print("📦 Using current UNet for inference.")
    
    # ✅ Restore LoRA-trained text_encoder if trained
    if cfg.train_text_encoder:
        pipe.text_encoder = accelerator.unwrap_model(pipe.text_encoder)
        print("🧠 Restored LoRA-trained text encoder for inference.")

    pipe.to("cuda")

    # ✅ Disable NSFW checker for previewing
    disable_safety(pipe)

    # ✅ Ensure tokenizer supports the trigger word
    ensure_trigger_token(pipe, cfg.trigger_word)

    # ✅ Use fast scheduler
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # ✅ Default prompts if not provided
    if prompt is None:
        prompt = f"close-up of {cfg.trigger_word}, detailed facial features, natural skin texture, soft expression, white blonde hair, illuminated by natural light, shallow depth of field, blurred background, ultra high quality"
    if negative_prompt is None:
        negative_prompt = (
            "blurry, low resolution, grainy, overexposed, underexposed, poor lighting, jpeg artifacts, glitch, "
            "cropped, out of frame, watermark, duplicate, poorly drawn face, asymmetrical face, deformed face, "
            "unnatural skin texture, doll-like face, bad eyes, mutated hands, extra fingers, distorted anatomy, "
            "unrealistic proportions, cartoon, anime, illustration, painting, horror, morbid"
        )

    # ✅ Generate image
    with autocast("cuda"):
        # Generate 4 images
        result = pipe(
            prompt=[prompt] * 4,
            negative_prompt=[negative_prompt] * 4,
            num_inference_steps=30,
            guidance_scale=6.0,
            height=cfg.resolution,
            width=cfg.resolution,
            generator=generator,
        )

    # ✅ Save images
    os.makedirs(save_path, exist_ok=True)
    for i, image in enumerate(result.images):
        save_name = os.path.join(save_path, f"preview_step_{step}_{i+1}.png")
        image.save(save_name)
        print(f"✅ Saved: {save_name}")




Our main training loop for training our model. 

In [None]:
from torch.amp import autocast
from safetensors.torch import save_file
from diffusers.models.attention_processor import LoRAAttnProcessor
from PIL import Image

# 🛠️ Optional: Force LoRA layers to float32 for stability
for module in unet.modules():
    if isinstance(module, LoRAAttnProcessor):
        for param in module.parameters():
            param.data = param.data.to(torch.float32)

global_step = 0
unet.train()

for epoch in range(100):
    for step, batch in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(unet):
            # ✅ 1. Move batch to device
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float32)  # VAE prefers float32
            input_ids = batch["input_ids"].to(accelerator.device)
            # print(tokenizer.decode(input_ids[0]))
# 
            # ✅ 2. Encode with VAE (float32), then clamp and convert
            with torch.no_grad():
                latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                latents = latents.clamp(-10, 10)  # Avoid extreme latent values
                latents = latents * 0.18215
                latents = latents.to(accelerator.device, dtype=torch.float16)

            # ✅ 3. Add scaled noise
            noise = 0.9 * torch.randn_like(latents)  # reduce intensity for early stability
            max_timestep = 300 if global_step < 100 else pipe.scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, max_timestep, (latents.shape[0],), device=latents.device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            for name, param in unet.named_parameters():
                if torch.isnan(param).any() or torch.isinf(param).any():
                    print(f"[❌ NaN/Inf detected] in parameter: {name}")
                    break


            # ✅ 4. Encode text
            with torch.no_grad():
                encoder_hidden_states = pipe.text_encoder(input_ids)[0]
                encoder_hidden_states = encoder_hidden_states.to(accelerator.device, dtype=torch.float16)

            # ✅ 5. UNet forward with autocast
            with autocast("cuda", dtype=torch.float32):
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

            # ✅ Debug logs
            # print(f"Latents mean/std: {latents.mean().item():.4f}/{latents.std().item():.4f}")
            # print(f"model_pred mean/std: {model_pred.mean().item():.4f}/{model_pred.std().item():.4f}")

            # ✅ Check for NaNs
            if torch.isnan(model_pred).any():
                print("❌ NaN in model_pred!")
                print("Input stats:", noisy_latents.mean(), noisy_latents.std())
                print("Timesteps:", timesteps)
                print("Encoder stats:", encoder_hidden_states.mean(), encoder_hidden_states.std())
                continue

            # ✅ 6. Compute loss
            noise = noise.to(model_pred.dtype)
            loss = torch.nn.functional.l1_loss(model_pred, noise)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"⚠️ Skipping invalid loss at step {global_step}")
                continue

            # ✅ 7. Backward + optimizer
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # ✅ Logging
                if global_step % cfg.log_every_n_steps == 0:
                    print(f"Step {global_step} | Loss: {loss.item():.4f}")


                # ✅ Save LoRA Weights
                if global_step % cfg.save_every_n_steps == 0:
                    save_path = os.path.join(cfg.output_dir, f"step_{global_step}")
                    os.makedirs(save_path, exist_ok=True)
                
                    def extract_lora_weights(state_dict):
                        return {k: v for k, v in state_dict.items() if "lora" in k.lower()}
                
                    # ✅ Extract only LoRA weights from UNet and text encoder
                    unet_lora = extract_lora_weights(accelerator.unwrap_model(unet).state_dict())
                    text_lora = extract_lora_weights(accelerator.unwrap_model(pipe.text_encoder).state_dict())
                
                    # ✅ Combine weights into a single dictionary
                    combined_lora = {**unet_lora, **text_lora}
                
                    # ✅ Save using safetensors
                    save_file(combined_lora, os.path.join(save_path, "lora_only.safetensors"))
                
                    print(f"✅ Saved combined LoRA checkpoint at step {global_step} → {save_path}/lora_only.safetensors")

                    
                # 🔍 Generate sample image
                if global_step % cfg.generate_every_n_steps == 0:
                    img_path = os.path.join('/kaggle/working/gen_images', f"step_{global_step}")
                    generate_sample_image(global_step, img_path)

        # ✅ Exit condition
        if global_step >= cfg.max_train_steps:
            break
    if global_step >= cfg.max_train_steps:
        break


100%|██████████| 37/37 [00:44<00:00,  1.20s/it]
100%|██████████| 37/37 [00:44<00:00,  1.20s/it]
100%|██████████| 37/37 [00:46<00:00,  1.27s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
 97%|█████████▋| 36/37 [00:44<00:01,  1.24s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (89 > 77). Running this sequence through the model will result in indexing errors


Step 370 | Loss: 0.5961

🎨 Generating preview at step 370...
📦 Using current UNet for inference.
🧠 Restored LoRA-trained text encoder for inference.
✅ Trigger word 'pendugwen' is tokenized correctly: ['pendugwen']


  0%|          | 0/30 [00:00<?, ?it/s]

✅ Saved: /kaggle/working/gen_images/step_370/preview_step_370_1.png
✅ Saved: /kaggle/working/gen_images/step_370/preview_step_370_2.png
✅ Saved: /kaggle/working/gen_images/step_370/preview_step_370_3.png


100%|██████████| 37/37 [01:09<00:00,  1.87s/it]


✅ Saved: /kaggle/working/gen_images/step_370/preview_step_370_4.png


100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
 97%|█████████▋| 36/37 [00:45<00:01,  1.25s/it]

Step 740 | Loss: 0.2361

🎨 Generating preview at step 740...
📦 Using current UNet for inference.
🧠 Restored LoRA-trained text encoder for inference.
✅ Trigger word 'pendugwen' is tokenized correctly: ['pendugwen']


  0%|          | 0/30 [00:00<?, ?it/s]

✅ Saved: /kaggle/working/gen_images/step_740/preview_step_740_1.png
✅ Saved: /kaggle/working/gen_images/step_740/preview_step_740_2.png


100%|██████████| 37/37 [01:09<00:00,  1.88s/it]


✅ Saved: /kaggle/working/gen_images/step_740/preview_step_740_3.png
✅ Saved: /kaggle/working/gen_images/step_740/preview_step_740_4.png


100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.26s/it]
100%|██████████| 37/37 [00:46<00:00,  1.26s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.26s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
 97%|█████████▋| 36/37 [00:45<00:01,  1.25s/it]

Step 1110 | Loss: 0.2345
✅ Saved combined LoRA checkpoint at step 1110 → /kaggle/working/lora_gwen/step_1110/lora_only.safetensors

🎨 Generating preview at step 1110...
📦 Using current UNet for inference.
🧠 Restored LoRA-trained text encoder for inference.
✅ Trigger word 'pendugwen' is tokenized correctly: ['pendugwen']


  0%|          | 0/30 [00:00<?, ?it/s]

✅ Saved: /kaggle/working/gen_images/step_1110/preview_step_1110_1.png
✅ Saved: /kaggle/working/gen_images/step_1110/preview_step_1110_2.png


100%|██████████| 37/37 [01:09<00:00,  1.88s/it]


✅ Saved: /kaggle/working/gen_images/step_1110/preview_step_1110_3.png
✅ Saved: /kaggle/working/gen_images/step_1110/preview_step_1110_4.png


100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
 97%|█████████▋| 36/37 [00:44<00:01,  1.25s/it]

Step 1480 | Loss: 0.2336

🎨 Generating preview at step 1480...
📦 Using current UNet for inference.
🧠 Restored LoRA-trained text encoder for inference.
✅ Trigger word 'pendugwen' is tokenized correctly: ['pendugwen']


  0%|          | 0/30 [00:00<?, ?it/s]

✅ Saved: /kaggle/working/gen_images/step_1480/preview_step_1480_1.png
✅ Saved: /kaggle/working/gen_images/step_1480/preview_step_1480_2.png


100%|██████████| 37/37 [01:09<00:00,  1.87s/it]


✅ Saved: /kaggle/working/gen_images/step_1480/preview_step_1480_3.png
✅ Saved: /kaggle/working/gen_images/step_1480/preview_step_1480_4.png


100%|██████████| 37/37 [00:45<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.24s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
 97%|█████████▋| 36/37 [00:45<00:01,  1.26s/it]

Step 1850 | Loss: 0.2331

🎨 Generating preview at step 1850...
📦 Using current UNet for inference.
🧠 Restored LoRA-trained text encoder for inference.
✅ Trigger word 'pendugwen' is tokenized correctly: ['pendugwen']


  0%|          | 0/30 [00:00<?, ?it/s]

✅ Saved: /kaggle/working/gen_images/step_1850/preview_step_1850_1.png
✅ Saved: /kaggle/working/gen_images/step_1850/preview_step_1850_2.png


100%|██████████| 37/37 [01:09<00:00,  1.88s/it]


✅ Saved: /kaggle/working/gen_images/step_1850/preview_step_1850_3.png
✅ Saved: /kaggle/working/gen_images/step_1850/preview_step_1850_4.png


100%|██████████| 37/37 [00:46<00:00,  1.25s/it]
 78%|███████▊  | 29/37 [00:36<00:10,  1.26s/it]