Please read the README.md file in my GitHub repository for more information.

https://github.com/sukhmansaran/fine-tuning-stable-diffusion-models-lora-dreambooth

**This notebook is part two of fine-tuning a model on a single character.**

# DreamBooth + LoRA Training Pipeline for Stable Diffusion Models

In [None]:
import os
import torch
import random
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from accelerate import Accelerator

In [None]:
# Let cuBLAS matmuls and cuDNN convolutions use TensorFloat-32 math on GPUs
# that support it: faster, at slightly reduced precision.
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True

In [None]:
# HF TOKEN
import os

from huggingface_hub import login

# Read the token from the HF_TOKEN environment variable (e.g. a Kaggle/Colab
# secret) instead of pasting it here, so it is never saved with the notebook.
# When HF_TOKEN is unset, token=None makes login() ask for it interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Downloading Realistic Vision V5
import os
from huggingface_hub import hf_hub_download

# Placeholder: point this at a writable folder for the base checkpoint.
model_dir = "your_dir_for_saving_downloaded_base_model"
os.makedirs(model_dir, exist_ok=True)

# ckpt_path receives the local path of the downloaded .safetensors file.
ckpt_path = hf_hub_download(
    filename="Realistic_Vision_V5.1.safetensors",
    repo_id="SG161222/Realistic_Vision_V5.1_noVAE",
    local_dir=model_dir,
)

This inference config file is required. It must be exactly this file, because it is the config used by Stable Diffusion 1.5; you cannot use any other file here.

In [None]:
# @title Downloading v1-inference.yaml
# `-O v1-inference.yaml` writes the file into the current working directory, i.e. ./v1-inference.yaml.
!wget -O v1-inference.yaml https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml

In [None]:
# Convert to diffusers format
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

# Path to the downloaded .safetensors file (ckpt_path from the download cell).
safetensors_path = "downloaded_base_model_path"
output_dir = "your_dir_for_saving_converted_base_model"

# The wget cell saves the config as ./v1-inference.yaml, relative to the working
# directory. The previous "/v1-inference.yaml" pointed at the filesystem root,
# where wget did not put it (unless the working directory happens to be /).
original_config_path = "v1-inference.yaml"

converted_pipeline = download_from_original_stable_diffusion_ckpt(
    safetensors_path,
    original_config_file=original_config_path,  # Must match SD1.5 or SD2.x
    from_safetensors=True,
    extract_ema=True,
    device="cuda",  # or "cpu"
)

# saving
converted_pipeline.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")

Load the phase 1 config and set up the config for phase 2. The phase 1 config file is required because the LoRA rank and alpha used in phase 1 are read from it.

The current settings are the ones I tested and that worked for me; feel free to experiment with changing them.

In [None]:
# import json
import json

class Config:
    """Training settings: phase-2 defaults, or every key of a JSON file.

    With ``json_path`` each top-level key of the JSON object becomes an
    attribute. Without it, the phase-2 defaults from :meth:`set_defaults` are
    used. The two sources are not merged, so attributes missing from a JSON
    file are simply not defined. A phase-1 file without ``lora_r`` therefore
    fails on access instead of silently inheriting the phase-2 value.
    """

    def __init__(self, json_path=None):
        if json_path:
            self.load_from_json(json_path)
        else:
            self.set_defaults()

    def load_from_json(self, path):
        """Set one attribute per top-level key of the JSON object stored at ``path``."""
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
        # would mis-decode non-ASCII values such as trigger words or paths.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for key, value in data.items():
            setattr(self, key, value)

    def set_defaults(self):
        """Populate the phase-2 defaults (paths are placeholders to edit)."""
        # ✅ Phase 2 Defaults
        self.model_path = "your_downloaded_model_path"
        self.dataset_dir = "your_dataset_path"
        self.output_dir = "your_output_dir"

        self.trigger_word = "sks"
        self.resolution = 512
        self.batch_size = 1
        self.gradient_accumulation = 4
        self.learning_rate = 4e-5
        self.max_train_steps = 3240
        self.mixed_precision = "fp16"
        self.train_text_encoder = True

        self.lr_scheduler = "cosine"
        self.lr_warmup_steps = 185
        self.lora_r = 4
        self.lora_alpha = 8
        self.lora_dropout = 0.1

        self.save_every_n_steps = 810
        self.log_every_n_steps = 405
        self.generate_every_n_steps = 405
        self.seed = 151101

        # Optional Phase 1 Reference
        self.phase1_config_path = "your_path_to_phase1_config_file"
        self.phase1_lora_path = "your_path_to_phase1_lora_weights"

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"


In [None]:
# Phase-2 run settings come from Config.set_defaults(); the phase-1 JSON is
# needed for the rank/alpha of the frozen phase-1 adapter.
cfg = Config()
phase1_cfg = Config(cfg.phase1_config_path)

In [None]:
import os, json, inspect

def save_config(cfg, path=None):
    """Write the data attributes of ``cfg`` to a JSON file.

    Args:
        cfg: settings object, e.g. a ``Config``. Data attributes declared on its
            class and set on the instance are both written; instance values win.
            Methods, descriptors and ``__dunder__`` names are skipped.
        path: destination file; defaults to ``<cfg.output_dir>/lora_config.json``.
            Missing parent directories are created.
    """
    if path is None:
        path = os.path.join(cfg.output_dir, "lora_config.json")
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises for a bare filename like "cfg.json"
        os.makedirs(parent, exist_ok=True)

    # Config keeps its settings as *instance* attributes (set via self.* and
    # setattr), so reading only the class __dict__ used to produce "{}".
    config_dict = {}
    for source in (vars(type(cfg)), vars(cfg)):
        for key, value in source.items():
            if key.startswith("__") or callable(value):
                continue
            if isinstance(value, (staticmethod, classmethod, property)):
                continue
            config_dict[key] = value

    with open(path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=4)
    print(f"✅  Config saved to {path}")

# Write cfg's settings to <cfg.output_dir>/lora_config.json (save_config's default path).
save_config(cfg)

In [None]:
# Reproducibility
torch.manual_seed(cfg.seed)
torch.cuda.manual_seed(cfg.seed)
random.seed(cfg.seed)

In [None]:
# Dataset Loader
class ImageCaptionDataset(Dataset):
    """Image/caption pairs read from a single folder.

    An image is used only when a caption file with the same stem and a ``.txt``
    extension sits next to it (``img_01.png`` + ``img_01.txt``). Files are
    visited in sorted name order.

    Each item is a dict with:
        ``pixel_values``: the image resized to ``size`` x ``size``, converted to
            a tensor and normalised with mean/std 0.5 (range [-1, 1]).
        ``input_ids``: the caption tokenized, truncated/padded to 77 ids.
    """

    # Matched case-insensitively so e.g. "IMG_01.JPG" or ".jpeg" files are not skipped.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp")

    def __init__(self, image_dir, tokenizer, size):
        self.image_paths = []
        self.caption_paths = []
        self.tokenizer = tokenizer

        for fname in sorted(os.listdir(image_dir)):
            if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                continue
            img_path = os.path.join(image_dir, fname)
            txt_path = os.path.splitext(img_path)[0] + ".txt"
            if os.path.exists(txt_path):
                self.image_paths.append(img_path)
                self.caption_paths.append(txt_path)

        self.image_transforms = transforms.Compose([
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Close the file handle Image.open keeps open for lazy loading; convert()
        # has already loaded the pixels into a new image object.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # getbbox() is None when every pixel is zero, i.e. an all-black image.
        if image.getbbox() is None:
            raise ValueError(f"Empty image found: {image_path}")
        pixel_values = self.image_transforms(image)

        with open(self.caption_paths[idx], "r", encoding="utf-8") as f:
            caption = f.read().strip()
        inputs = self.tokenizer(
            caption, truncation=True, padding="max_length", max_length=77, return_tensors="pt"
        )
        return {"pixel_values": pixel_values, "input_ids": inputs.input_ids.squeeze(0)}

This class adds two LoRAs to a layer: one for phase 1 and one for phase 2. It keeps the loaded phase 1 weights frozen so they cannot accidentally be trained; making them trainable could overwrite the features learned during phase 1. An alternative would be to load the phase 1 weights and keep updating only those, without adding new weights for training the full body with the face. I did not do that because phase 1 was trained to learn the facial features correctly. Phase 2 learns the rest of the body, and at a lower resolution such as 512x512 the facial features may not be learned correctly and can become messy. So I use the phase 1 weights for the face and the phase 2 weights for the rest of the body. I am still experimenting during inference to see what works best.

In [None]:
import torch.nn as nn

class LoRALinearDualPhase(nn.Module):
    """Wrap an ``nn.Linear`` with a frozen phase-1 LoRA and a trainable phase-2 LoRA.

    ``forward(x) = base(x) + (alpha1/rank1) * up1(down1(x)) + (alpha2/rank2) * up2(down2(x))``

    Args:
        base_linear: the wrapped layer; it is called unchanged (freezing its
            parameters is left to the caller).
        rank1, alpha1: rank/alpha of the phase-1 adapter. They must match the
            values used to train the phase-1 weights that are copied in later.
        rank2, alpha2: rank/alpha of the new phase-2 adapter.
    """

    def __init__(self, base_linear: nn.Linear, rank1, alpha1, rank2, alpha2):
        super().__init__()
        self.base = base_linear

        # Build the adapters on the wrapped layer's device, so patching a model
        # that is already on the GPU does not leave them on the CPU. They keep
        # the default float32 dtype even if the base layer is half precision,
        # so optimizer updates are not lost to fp16 rounding.
        device = base_linear.weight.device

        # Phase 1 (frozen)
        self.lora_down = nn.Linear(base_linear.in_features, rank1, bias=False, device=device)
        self.lora_up = nn.Linear(rank1, base_linear.out_features, bias=False, device=device)
        self.scale = alpha1 / rank1

        # Phase 2 (trainable)
        self.lora2_down = nn.Linear(base_linear.in_features, rank2, bias=False, device=device)
        self.lora2_up = nn.Linear(rank2, base_linear.out_features, bias=False, device=device)
        self.scale2 = alpha2 / rank2

        # Init: Phase 2 starts as a no-op (zero up-projection).
        nn.init.kaiming_uniform_(self.lora2_down.weight, a=5**0.5)
        nn.init.zeros_(self.lora2_up.weight)

        # Phase-1 values are copied in later. Zeroing the up-projection means a
        # layer whose key is missing from the phase-1 file contributes nothing,
        # rather than random output from nn.Linear's default init.
        nn.init.zeros_(self.lora_up.weight)

        # Freeze Phase 1
        for p in self.lora_down.parameters():
            p.requires_grad = False
        for p in self.lora_up.parameters():
            p.requires_grad = False

    def forward(self, x):
        base_out = self.base(x)
        # Run the adapters in their own dtype and cast the sum back. A half-precision
        # base layer then combines with float32 adapters without a dtype-mismatch error.
        x_lora = x.to(self.lora_down.weight.dtype)
        lora1_out = self.lora_up(self.lora_down(x_lora)) * self.scale
        lora2_out = self.lora2_up(self.lora2_down(x_lora)) * self.scale2
        return base_out + (lora1_out + lora2_out).to(base_out.dtype)


This patcher injects two LoRAs into the UNet. The phase 1 LoRA is created with the phase 1 rank and alpha, and the phase 2 LoRA with the phase 2 rank and alpha. The phase 1 rank and alpha must match the values used when those weights were trained. With a wrong rank, the phase 1 layers have a different shape from the saved tensors, so loading them fails. A wrong alpha changes the scale (alpha / rank) applied to their output.

In [None]:
def patch_unet_dual_lora(unet, rank1, alpha1, rank2, alpha2):
    """Replace the UNet's attention projections with LoRALinearDualPhase wrappers.

    Patched: ``to_q``/``to_k``/``to_v``/``to_out`` attributes that are
    ``nn.Linear``, and every ``nn.Linear`` inside a ``to_out`` that is an
    ``nn.ModuleList``.

    Returns:
        The phase-2 (``lora2_*``) parameters of all new wrappers, in patch order,
        i.e. the parameters the optimizer should train.
    """
    collected = []

    def _wrap(linear, label):
        wrapper = LoRALinearDualPhase(linear, rank1, alpha1, rank2, alpha2)
        collected.extend(wrapper.lora2_down.parameters())
        collected.extend(wrapper.lora2_up.parameters())
        print(f"✅ Patched {label} with dual-phase LoRA")
        return wrapper

    # Walk a snapshot: children are swapped in place while looping, and the new
    # wrappers have none of the target attributes, so they need no visit.
    for module in list(unet.modules()):
        owner = module.__class__.__name__
        for attr in ("to_q", "to_k", "to_v", "to_out"):
            if not hasattr(module, attr):
                continue
            target = getattr(module, attr)
            if isinstance(target, nn.Linear):
                setattr(module, attr, _wrap(target, f"{owner}.{attr}"))
            elif isinstance(target, nn.ModuleList):
                for idx, child in enumerate(target):
                    if isinstance(child, nn.Linear):
                        target[idx] = _wrap(child, f"{owner}.{attr}[{idx}]")

    print(f"✅ Total trainable UNet Phase 2 params: {sum(p.numel() for p in collected):,}")
    return collected


The same applies to the text encoder: this patcher injects two LoRAs into it. The phase 1 LoRA is created with the phase 1 rank and alpha, and the phase 2 LoRA with the phase 2 rank and alpha. The phase 1 rank and alpha must match the values used when those weights were trained. With a wrong rank, the layer shapes do not match the saved tensors and loading fails. A wrong alpha changes the scale (alpha / rank) applied to their output.

In [None]:
def patch_text_encoder_dual_lora(text_encoder, rank1, alpha1, rank2, alpha2):
    """Replace the text encoder's attention projections with LoRALinearDualPhase wrappers.

    Every ``nn.Linear`` found as attribute ``q_proj``, ``k_proj``, ``v_proj`` or
    ``out_proj`` of any submodule is wrapped.

    Returns:
        The phase-2 (``lora2_*``) parameters of all new wrappers, in patch order.
    """
    trainable_params = []
    # A tuple, not a set. Set iteration order of strings depends on the per-process
    # hash seed, and each wrapper draws from the RNG when it is created. So a set
    # made the initial phase-2 weights (and the optimizer's parameter order)
    # differ between sessions even with a fixed torch seed.
    target_names = ("q_proj", "k_proj", "v_proj", "out_proj")

    # Walk a snapshot, because children are replaced while looping.
    for module in list(text_encoder.modules()):
        for name in target_names:
            orig = getattr(module, name, None)
            if isinstance(orig, nn.Linear):
                dual_lora = LoRALinearDualPhase(orig, rank1, alpha1, rank2, alpha2)
                setattr(module, name, dual_lora)
                trainable_params += list(dual_lora.lora2_down.parameters())
                trainable_params += list(dual_lora.lora2_up.parameters())
                print(f"🔧 Patched {module.__class__.__name__}.{name} with dual-phase LoRA")

    print(f"✅ Total trainable Text Encoder Phase 2 params: {sum(p.numel() for p in trainable_params):,}")
    return trainable_params


This function loads the phase 1 weights into the phase 1 LoRA layers of the patched UNet and text encoder.

In [None]:
from safetensors.torch import load_file

def load_phase1_lora_weights(unet, text_encoder, weights_path):
    """Copy phase-1 LoRA tensors from a safetensors file into the frozen phase-1 layers.

    For every LoRALinearDualPhase module in ``unet`` and then in ``text_encoder``,
    the keys ``<name>.lora_down.weight`` and ``<name>.lora_up.weight`` are looked
    up. ``<name>`` is the module's ``named_modules()`` path inside its own model.
    - A missing key is reported, and that layer keeps its current weights.
    - Tensors in the file that were loaded into no layer are reported at the end
      (e.g. text-encoder keys when the text encoder was not patched).

    Raises:
        ValueError: a stored tensor's shape differs from the layer's weight,
            which typically means the phase-1 rank passed to the patchers is wrong.
    """
    state_dict = load_file(weights_path)

    loaded_keys = set()
    n_loaded = 0
    for model in (unet, text_encoder):
        for name, module in model.named_modules():
            if not isinstance(module, LoRALinearDualPhase):
                continue
            for part, layer in (("lora_down", module.lora_down), ("lora_up", module.lora_up)):
                key = f"{name}.{part}.weight"
                tensor = state_dict.get(key)
                if tensor is None:
                    print(f"⚠️ Missing key in weights: {key}")
                    continue
                if tensor.shape != layer.weight.shape:
                    raise ValueError(
                        f"Shape mismatch for {key}: file has {tuple(tensor.shape)}, "
                        f"layer expects {tuple(layer.weight.shape)} (check the phase-1 lora_r)"
                    )
                with torch.no_grad():
                    layer.weight.copy_(tensor)
                loaded_keys.add(key)
                n_loaded += 1

    unused = [k for k in state_dict if k not in loaded_keys]
    if unused:
        print(f"⚠️ {len(unused)} tensor(s) in {weights_path} were not loaded, e.g. {unused[:3]}")
    print(f"✅ Loaded {n_loaded} Phase 1 LoRA weights")


In [None]:
def setup_phase2(unet, text_encoder, rank_phase1, alpha_phase1, rank_phase2, alpha_phase2,
                 train_text_encoder, load_phase1_path):
    """Patch the models with dual-phase LoRA and load the frozen phase-1 weights.

    The UNet is always patched. The text encoder is patched only when
    ``train_text_encoder`` is true. Otherwise it gets no LoRALinearDualPhase
    modules, so no phase-1 text-encoder tensors are copied into it.

    Returns:
        The phase-2 LoRA parameters to optimise (UNet first, then text encoder).
    """
    lora_shape = (rank_phase1, alpha_phase1, rank_phase2, alpha_phase2)

    lora_params = patch_unet_dual_lora(unet, *lora_shape)
    if train_text_encoder:
        lora_params.extend(patch_text_encoder_dual_lora(text_encoder, *lora_shape))

    # Loading happens after patching: the weights go into the new phase-1 layers.
    load_phase1_lora_weights(unet, text_encoder, load_phase1_path)
    return lora_params


Next we create the pipeline. We load the model, load the tokenizer and add the trigger word to it, then take the UNet and text encoder and patch them. Importantly, the UNet and text encoder parameters must be frozen before patching, so that only the newly patched LoRA layers are trained and the base model is never updated. The phase 1 LoRA weights are frozen as well, so they cannot be updated accidentally. We use the AdamW optimizer, the optimizer originally used with the base model. Finally, we load the dataset and prepare everything for training.

In [None]:
import os, torch
from accelerate import Accelerator
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from transformers import AutoTokenizer, get_scheduler
from torch.utils.data import DataLoader

# ──────────────────────────────────────────────────────────────
# 1. Accelerator & device setup
# ──────────────────────────────────────────────────────────────
# cfg.gradient_accumulation and cfg.mixed_precision used to be defined but
# never passed in. Without gradient_accumulation_steps, accelerator.accumulate()
# in the training loop steps the optimizer on every batch. fp16 mixed precision
# autocasts the prepared models and applies loss scaling; the optimized LoRA
# parameters stay float32.
_accelerator_kwargs = dict(
    gradient_accumulation_steps=cfg.gradient_accumulation,
    mixed_precision=cfg.mixed_precision,
)
try:
    # Newer accelerate takes split_batches via DataLoaderConfiguration (the bare keyword is deprecated).
    from accelerate.utils import DataLoaderConfiguration
except ImportError:  # older accelerate only knows the bare keyword
    accelerator = Accelerator(split_batches=True, **_accelerator_kwargs)
else:
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(split_batches=True),
        **_accelerator_kwargs,
    )
device = accelerator.device

# ──────────────────────────────────────────────────────────────
# 2. Load Stable Diffusion base pipeline
# ──────────────────────────────────────────────────────────────
pipe = StableDiffusionPipeline.from_pretrained(
    cfg.model_path,
    torch_dtype=torch.float16
).to(device)

# ──────────────────────────────────────────────────────────────
# 3. Add trigger word (tokenize before dataset creation!)
# ──────────────────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(cfg.model_path, subfolder="tokenizer")
trigger_token = cfg.trigger_word

if len(tokenizer.tokenize(trigger_token)) > 1:
    tokenizer.add_tokens([trigger_token])
    pipe.text_encoder.resize_token_embeddings(len(tokenizer))
    with torch.no_grad():
        emb = pipe.text_encoder.get_input_embeddings()
        new_id = tokenizer.convert_tokens_to_ids(trigger_token)
        base_id = tokenizer.convert_tokens_to_ids("person")
        emb.weight[new_id] = emb.weight[base_id].clone()
    print(f"✅ Added custom token '{trigger_token}' (id {new_id})")
else:
    print(f"✅ '{trigger_token}' is already a single token")

pipe.tokenizer = tokenizer

# ──────────────────────────────────────────────────────────────
# 4. UNet and freezing of the base weights
# ──────────────────────────────────────────────────────────────
# Reuse the pipeline's UNet instead of loading a second fp16 copy of the same
# weights from cfg.model_path (one UNet less in GPU memory).
unet = pipe.unet

# Freeze the base model *before* patching: only the LoRA layers added in step 5 train.
unet.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)

# ──────────────────────────────────────────────────────────────
# 5. Patch UNet and text encoder with dual-phase LoRA
# ──────────────────────────────────────────────────────────────
lora_params = setup_phase2(
    unet               = unet,
    text_encoder       = pipe.text_encoder,
    rank_phase1        = phase1_cfg.lora_r,
    alpha_phase1       = phase1_cfg.lora_alpha,
    rank_phase2        = cfg.lora_r,
    alpha_phase2       = cfg.lora_alpha,
    train_text_encoder = cfg.train_text_encoder,
    load_phase1_path   = cfg.phase1_lora_path,
)

if not lora_params:
    raise RuntimeError("❌ No LoRA parameters collected!")

print(f"🔍 LoRA trainable parameters: {sum(p.numel() for p in lora_params):,}")

# ──────────────────────────────────────────────────────────────
# 6. Optimizer and LR Scheduler
# ──────────────────────────────────────────────────────────────
optimizer = torch.optim.AdamW(lora_params, lr=cfg.learning_rate, fused=False)

lr_scheduler = get_scheduler(
    cfg.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=cfg.lr_warmup_steps,
    num_training_steps=cfg.max_train_steps,
)

# ──────────────────────────────────────────────────────────────
# 7. Dataset and Dataloader
# ──────────────────────────────────────────────────────────────
dataset = ImageCaptionDataset(cfg.dataset_dir, tokenizer, cfg.resolution)
if len(dataset) == 0:
    raise RuntimeError(f"❌ No image/caption pairs found in {cfg.dataset_dir}")
dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

# ──────────────────────────────────────────────────────────────
# 8. Final Prep
# ──────────────────────────────────────────────────────────────
# Only move the text encoder; from_pretrained already loaded it as float16.
# Also passing dtype=torch.float16 here would cast its float32 phase-2 LoRA
# parameters (already handed to the optimizer) to float16, where AdamW updates
# at this learning rate tend to be lost to rounding.
pipe.text_encoder.to(device)
pipe.text_encoder.train(cfg.train_text_encoder)
unet.train()

pipe.text_encoder, unet, optimizer, dataloader = accelerator.prepare(
    pipe.text_encoder, unet, optimizer, dataloader
)


print("🚀 Setup complete – ready to train Phase 2 LoRA adapters.")


Sanity check: how many of the total parameters are trainable.

In [None]:
# Count total and trainable params across both modules
total_params = sum(p.numel() for p in unet.parameters()) + sum(p.numel() for p in pipe.text_encoder.parameters())
trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad) + \
                   sum(p.numel() for p in pipe.text_encoder.parameters() if p.requires_grad)

trainable_percent = (trainable_params / total_params) * 100
print(f"✅ Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_percent:.4f}%)")


Load the VAE to use with the base model; the Realistic Vision checkpoint is the "noVAE" variant. The VAE encodes images into latents and decodes latents back into images.

In [None]:
from diffusers import AutoencoderKL

# Attach the stabilityai/sd-vae-ft-mse VAE to the pipeline, kept in float32.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch.float32,
)
pipe.vae = vae

vae.to(accelerator.device, dtype=torch.float32)

In [None]:
# Show the device Accelerate selected for training.
print(accelerator.device)  # Should print: cuda


Sanity check: make sure the UNet parameters contain no NaN values before training begins.

In [None]:
# Report the first UNet parameter that already contains a NaN, if any.
nan_param_name = next(
    (pname for pname, ptensor in unet.named_parameters() if torch.isnan(ptensor).any()),
    None,
)
if nan_param_name is not None:
    print(f"NaN detected in UNet parameter: {nan_param_name}")

Next we define helper functions. One checks whether the trigger word is a single token in the tokenizer. Another generates sample images during training. The samples show whether the model is learning and whether it is overfitting or underfitting, which tells us whether to stop training or change the configuration.

In [None]:
import os
import random
import torch
from diffusers import DDIMScheduler
from torch.cuda.amp import autocast  # More up-to-date than torch.autocast
from torchvision.utils import save_image  # Optional if using tensor images

# ✅ Optional: Reproducibility
def seed_everything(seed=151101):
    """Seed Python and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN.

    The cuDNN flags are process-wide and remain set after this returns.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    random.seed(seed)

# ✅ Optional: Disable NSFW filter (for local testing)
def disable_safety(pipe):
    """Replace ``pipe.safety_checker`` with a pass-through that flags nothing.

    The replacement returns the images unchanged plus ``False`` for each one.
    Pipelines without a ``safety_checker`` attribute are left untouched.
    """
    if not hasattr(pipe, "safety_checker"):
        print("⚠️ No safety checker found in pipeline.")
        return

    def _passthrough(images, **kwargs):
        return images, [False] * len(images)

    pipe.safety_checker = _passthrough
    print("🛑 NSFW safety checker disabled.")

# ✅ Fix tokenizer trigger word if needed
def ensure_trigger_token(pipe, trigger_word):
    """Make sure ``trigger_word`` is a single token in ``pipe.tokenizer``.

    If it splits into several pieces, it is added as a new token. The text
    encoder's embedding matrix is then resized to the new vocabulary length,
    and the new row is initialised as a copy of the ``"person"`` embedding.
    """
    pieces = pipe.tokenizer.tokenize(trigger_word)
    if len(pieces) <= 1:
        print(f"✅ Trigger word '{trigger_word}' is already tokenized correctly: {pieces}")
        return

    print(f"⚠️ Trigger word '{trigger_word}' splits into: {pieces}. Fixing...")
    tok = pipe.tokenizer
    tok.add_tokens([trigger_word])
    pipe.text_encoder.resize_token_embeddings(len(tok))

    with torch.no_grad():
        table = pipe.text_encoder.get_input_embeddings().weight
        new_row = tok.convert_tokens_to_ids(trigger_word)
        donor_row = tok.convert_tokens_to_ids("person")
        table[new_row] = table[donor_row].clone()

    print(f"✅ Re-initialized token embedding for '{trigger_word}'")

# ✅ Generate and save preview images
def generate_sample_image(step, save_path, prompt=None, negative_prompt=None, seed=151101, num_images=4):
    """Render preview images with the LoRA-patched pipeline and save them as PNG files.

    Uses the notebook globals ``pipe``, ``unet``, ``accelerator`` and ``cfg``
    (and ``ema_unet`` when it exists and is not None). Images are written to
    ``save_path/preview_step_{step}_{i}.png`` with ``i`` counting from 1.

    Safe to call mid-training:
    - it seeds only a private generator, not the global RNGs;
    - it restores the pipeline scheduler and the train/eval mode of the UNet and
      text encoder before returning.

    Args:
        step: training step, used in the log line and file names.
        save_path: output directory (created if missing).
        prompt: positive prompt; defaults to ``f"{cfg.trigger_word} your prompt"``.
        negative_prompt: negative prompt; defaults to a generic artefact list.
        seed: seed of the generator used for sampling.
        num_images: images per call (the prompts are repeated this many times).
    """
    print(f"\n🎨 Generating preview at step {step}...")

    # Seed only a dedicated generator. Calling seed_everything() here re-seeded
    # the global RNGs at every preview, so training noise, timesteps and data
    # order replayed the same sequence after each one. It also switched cuDNN
    # to deterministic mode for the rest of training.
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Restore UNet (EMA or current)
    try:
        if 'ema_unet' in globals() and ema_unet is not None:
            pipe.unet = ema_unet
            print("📦 Using EMA UNet.")
        else:
            pipe.unet = accelerator.unwrap_model(unet)
            print("📦 Using current LoRA-patched UNet.")
    except Exception as e:
        print(f"⚠️ Error restoring UNet: {e}")

    # Restore text encoder if needed
    if cfg.train_text_encoder:
        try:
            pipe.text_encoder = accelerator.unwrap_model(pipe.text_encoder)
            print("🧠 Restored LoRA-patched text encoder.")
        except Exception as e:
            print(f"⚠️ Failed to unwrap text encoder: {e}")

    # Default prompt if not given
    if not prompt:
        prompt = f"{cfg.trigger_word} your prompt"
    if not negative_prompt:
        negative_prompt = (
            "blurry, low resolution, jpeg artifacts, cropped, watermark, distorted face, extra limbs, poorly drawn, cartoon"
        )

    # Remember what the preview changes, so training resumes exactly as before.
    unet_was_training = pipe.unet.training
    text_encoder_was_training = pipe.text_encoder.training
    training_scheduler = pipe.scheduler

    try:
        pipe.unet.eval()
        pipe.text_encoder.eval()

        disable_safety(pipe)
        ensure_trigger_token(pipe, cfg.trigger_word)

        # DDIM only for this preview; the training loop keeps using the original
        # scheduler (e.g. for add_noise) once it is restored below.
        pipe.scheduler = DDIMScheduler.from_config(training_scheduler.config)

        # torch.autocast is spelled out because the bare name `autocast` is bound
        # to torch.cuda.amp.autocast in one cell and torch.amp.autocast in another.
        # Their first positional parameters differ (`enabled` vs `device_type`).
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
            result = pipe(
                prompt=[prompt] * num_images,
                negative_prompt=[negative_prompt] * num_images,
                num_inference_steps=30,
                guidance_scale=6.0,
                height=cfg.resolution,
                width=cfg.resolution,
                generator=generator,
            )
    finally:
        pipe.scheduler = training_scheduler
        pipe.unet.train(unet_was_training)
        pipe.text_encoder.train(text_encoder_was_training)

    # Save
    os.makedirs(save_path, exist_ok=True)
    for i, image in enumerate(result.images):
        save_path_i = os.path.join(save_path, f"preview_step_{step}_{i+1}.png")
        image.save(save_path_i)
        print(f"✅ Saved: {save_path_i}")


The main training loop.

In [None]:
from torch.amp import autocast  # torch.amp.autocast(device_type, ...) is the current, non-deprecated API
from safetensors.torch import save_file
from diffusers.models.attention_processor import LoRAAttnProcessor
from PIL import Image
import torch, os
from tqdm import tqdm

# Multiplier for the sampled noise. 0.9 reproduces the original recipe; the
# standard epsilon-prediction objective uses 1.0 (unit-variance noise, matching
# the N(0, I) latents sampling starts from). Set cfg.noise_scale to override.
noise_scale = getattr(cfg, "noise_scale", 0.9)
preview_root = getattr(cfg, "preview_dir", "/kaggle/working/gen_images")


def extract_lora_weights(state_dict):
    """Keep only LoRA tensors (keys containing "lora": phase-1 lora_* and phase-2 lora2_*)."""
    return {k: v for k, v in state_dict.items() if "lora" in k.lower()}


def save_lora_checkpoint(step):
    """Save the combined UNet + text-encoder LoRA tensors to <cfg.output_dir>/step_<step>."""
    save_path = os.path.join(cfg.output_dir, f"step_{step}")
    os.makedirs(save_path, exist_ok=True)
    unet_lora = extract_lora_weights(accelerator.unwrap_model(unet).state_dict())
    text_lora = extract_lora_weights(accelerator.unwrap_model(pipe.text_encoder).state_dict())
    combined_lora = {**unet_lora, **text_lora}
    save_file(combined_lora, os.path.join(save_path, "lora_only.safetensors"))
    print(f"✅ Saved combined LoRA weights → {save_path}/lora_only.safetensors")


global_step = 0
unet.train()
if cfg.train_text_encoder:
    pipe.text_encoder.train()

# Loop over epochs until max_train_steps optimizer steps are done. The former
# fixed range(100) could end early on small datasets.
while global_step < cfg.max_train_steps:
    steps_before_epoch = global_step
    for batch in tqdm(dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=torch.float32)
            input_ids = batch["input_ids"].to(accelerator.device)

            # Encode images with the frozen VAE; no autograd graph is needed.
            with torch.no_grad(), autocast("cuda", dtype=torch.float16):
                latents = pipe.vae.encode(pixel_values.to(dtype=torch.float16)).latent_dist.sample()
                latents = latents.clamp(-10, 10) * pipe.vae.config.scaling_factor
            latents = latents.to(accelerator.device, dtype=torch.float16)

            noise = noise_scale * torch.randn_like(latents)
            # For the first 100 optimizer steps, only timesteps below 300 are sampled.
            max_timestep = 300 if global_step < 100 else pipe.scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, max_timestep, (latents.shape[0],), device=latents.device).long()
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            # Build a graph through the text encoder when its phase-2 LoRA is trained.
            # Under torch.no_grad() (as before) those parameters never got gradients.
            with torch.set_grad_enabled(cfg.train_text_encoder):
                encoder_hidden_states = pipe.text_encoder(input_ids)[0]
            encoder_hidden_states = encoder_hidden_states.to(accelerator.device, dtype=torch.float16)

            # Autocast context kept as in the original recipe.
            with autocast("cuda", dtype=torch.float32):
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

            if not torch.isfinite(model_pred).all():
                print("❌ NaN/Inf in model_pred!")
                continue

            # Compute the loss in float32 even when the UNet returned float16.
            loss = torch.nn.functional.l1_loss(model_pred.float(), noise.float())

            if not torch.isfinite(loss):
                print(f"⚠️ Skipping invalid loss at step {global_step}")
                continue

            accelerator.backward(loss)

            if accelerator.sync_gradients:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Count the step before the periodic checks. Checkpoints then land
                # on step_{k * save_every_n_steps} up to step_{max_train_steps}.
                # Incrementing afterwards saved/previewed at step 0 and never saved
                # the final step.
                global_step += 1

                # Check only trainable (LoRA) tensors for NaN/Inf: frozen base weights
                # cannot change, and scanning every UNet tensor on every batch forced
                # a GPU sync per parameter.
                for name, param in unet.named_parameters():
                    if param.requires_grad and not torch.isfinite(param).all():
                        print(f"[❌ NaN/Inf detected] in: {name}")
                        break

                if global_step % cfg.log_every_n_steps == 0:
                    print(f"Step {global_step} | Loss: {loss.item():.4f}")

                # ✅ Save Combined LoRA Weights
                if accelerator.is_main_process and global_step % cfg.save_every_n_steps == 0:
                    save_lora_checkpoint(global_step)

                # 🔍 Preview
                if accelerator.is_main_process and global_step % cfg.generate_every_n_steps == 0:
                    generate_sample_image(global_step, os.path.join(preview_root, f"step_{global_step}"))
                    # The preview switches the models to eval mode; resume training mode.
                    unet.train()
                    if cfg.train_text_encoder:
                        pipe.text_encoder.train()

        if global_step >= cfg.max_train_steps:
            break

    if global_step == steps_before_epoch:
        # Guard against looping forever when every batch of an epoch was skipped.
        print("❌ No optimizer step completed in a whole epoch – stopping.")
        break

# Save the final weights unless the last step already produced a checkpoint.
if accelerator.is_main_process and global_step > 0 and global_step % cfg.save_every_n_steps != 0:
    save_lora_checkpoint(global_step)


Code for saving the phase 1 and phase 2 weights into a single file at the end of training. Note that the commented-out code below saves the full UNet and text encoder state dicts, base weights included, not only the LoRA tensors.

In [None]:
# # Final save
# # Save both UNet and text encoder
# unet_state_dict = accelerator.unwrap_model(unet).state_dict()
# text_encoder_state_dict = accelerator.unwrap_model(pipe.text_encoder).state_dict()

# # Combine and save as safetensors
# save_file({**unet_state_dict, **text_encoder_state_dict}, "full_model_lora.safetensors")