<a href="https://www.kaggle.com/code/sukhmansaran/dual-phase-lora-image-generation-sd-models?scriptVersionId=254720593" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

Please read the README.md file in my GitHub repository for more information.

https://github.com/sukhmansaran/fine-tuning-stable-diffusion-models-lora-dreambooth

# DreamBooth + LoRA Inference Pipeline for Stable Diffusion Models

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from accelerate import Accelerator

In [None]:
# HF TOKEN
# Read the token from the HF_TOKEN environment variable (e.g. a Kaggle/Colab
# secret) instead of hard-coding it in the notebook, where it would leak with
# every shared copy. If HF_TOKEN is unset, login() receives token=None and
# huggingface_hub asks for the token interactively instead.
import os

from huggingface_hub import login

login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Downloading Realistic Vision V5
import os
from huggingface_hub import hf_hub_download

# Local directory that will hold the downloaded single-file checkpoint.
model_dir = "your_dir_for_saving_downloaded_base_model"
os.makedirs(model_dir, exist_ok=True)

# Fetch the .safetensors checkpoint into model_dir; ckpt_path is the local
# file path returned by hf_hub_download.
ckpt_path = hf_hub_download(
    local_dir=model_dir,
    filename="Realistic_Vision_V5.1.safetensors",
    repo_id="SG161222/Realistic_Vision_V5.1_noVAE",
)

This inference config file is required. It must be exactly this file — the original Stable Diffusion v1 `v1-inference.yaml` — because the base model is a Stable Diffusion 1.5 checkpoint; you cannot use any other config file here.

In [None]:
# Downloading v1-inference.yaml — the original SD v1 model config that the
# checkpoint-conversion cell below passes to the converter.
# `-O v1-inference.yaml` writes it into the current working directory.
!wget -O v1-inference.yaml https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml

In [None]:
# Convert to diffusers format
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
import os

# Path of the downloaded .safetensors checkpoint (the `ckpt_path` returned by
# hf_hub_download in the download cell).
safetensors_path = "downloaded_base_model_path"
output_dir = "your_dir_for_saving_converted_base_model"

# The wget cell saves the config into the current working directory, so this
# must be a relative path — "/v1-inference.yaml" would point at the filesystem
# root, where the file does not exist.
original_config_path = "v1-inference.yaml"  # Must match SD1.5 or SD2.x
if not os.path.isfile(original_config_path):
    raise FileNotFoundError(
        f"{original_config_path!r} not found; run the v1-inference.yaml download cell first"
    )

converted_pipeline = download_from_original_stable_diffusion_ckpt(
    safetensors_path,
    original_config_file=original_config_path,
    from_safetensors=True,
    extract_ema=True,
    device="cuda",  # or "cpu"
)

# saving
converted_pipeline.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")

Model page: https://huggingface.co/SG161222/Realistic_Vision_V5.1_noVAE

⚠️ If the generated code snippets do not work, please open an issue on either the [model repo](https://huggingface.co/SG161222/Realistic_Vision_V5.1_noVAE)
			and/or on [huggingface.js](https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries-snippets.ts) 🙏

In [None]:
# Paths
# Generated images are written into this directory (created if missing).
output_dir = "./outputs"
os.makedirs(output_dir, exist_ok=True)
# Directory of the diffusers-format base model (e.g. the converted model saved above).
base_model_path = "your_base_model_path"  # your base model dir

Creating the pipeline for loading the model and then using it for inference.

In [None]:
from diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

# Load pipeline in fp16 with the safety checker disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16,
)

# Swap in the DPM-Solver multistep scheduler, built from the existing scheduler's config.
existing_scheduler_config = pipe.scheduler.config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(existing_scheduler_config)
pipe.to("cuda")

The code below patches the UNet and text encoder of the Stable Diffusion model with dual-phase LoRA layers and loads the weights trained during fine-tuning, for both phase 1 and phase 2. You must provide the exact rank and alpha used in each phase. A wrong rank makes the LoRA tensor shapes disagree with the saved weights, so loading fails with an error. A wrong alpha loads without error but scales the LoRA incorrectly.

In [None]:
import torch
import torch.nn as nn
from safetensors.torch import load_file

# Dual-phase LoRA wrapper
class LoRALinearDualPhase(nn.Module):
    """Wrap an ``nn.Linear`` with two LoRA adapters (phase 1 and phase 2).

    Output::

        linear(x) + (1 - w) * scale1 * lora_up(lora_down(x))
                  + w       * scale2 * lora2_up(lora2_down(x))

    where ``w`` is ``phase2_weight`` and ``scaleN = alphaN / rankN``.
    The ``up`` projections are zero-initialised, so a freshly wrapped layer
    returns exactly what the wrapped linear returns until LoRA weights are loaded.

    Args:
        linear: The ``nn.Linear`` being wrapped; it is kept as ``self.linear``.
        rank1, alpha1: Rank and alpha of the phase-1 adapter.
        rank2, alpha2: Rank and alpha of the phase-2 adapter.

    Raises:
        ValueError: If ``rank1`` or ``rank2`` is not positive.
    """

    def __init__(self, linear, rank1, alpha1, rank2, alpha2):
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError in the
        # scale computation or an opaque error from nn.Linear.
        for label, rank in (("rank1", rank1), ("rank2", rank2)):
            if rank <= 0:
                raise ValueError(f"{label} must be a positive integer, got {rank!r}")

        self.linear = linear

        self.rank1 = rank1
        self.alpha1 = alpha1
        self.rank2 = rank2
        self.alpha2 = alpha2

        # Blend factor: 0.0 -> phase 1 only, 1.0 -> phase 2 only. Can be changed during inference.
        self.phase2_weight = 1.0

        self.lora_down = nn.Linear(linear.in_features, rank1, bias=False)
        self.lora_up = nn.Linear(rank1, linear.out_features, bias=False)
        self.lora2_down = nn.Linear(linear.in_features, rank2, bias=False)
        self.lora2_up = nn.Linear(rank2, linear.out_features, bias=False)

        # Zero "up" weights make each adapter contribute nothing until weights are loaded.
        nn.init.zeros_(self.lora_up.weight)
        nn.init.zeros_(self.lora2_up.weight)
        nn.init.kaiming_uniform_(self.lora_down.weight, a=5**0.5)
        nn.init.kaiming_uniform_(self.lora2_down.weight, a=5**0.5)

        # The adapters are created with torch's default device/dtype; move them
        # next to the wrapped weights (e.g. CUDA fp16) so forward() can mix them.
        for layer in (self.lora_down, self.lora_up, self.lora2_down, self.lora2_up):
            layer.to(linear.weight.device, dtype=linear.weight.dtype)

    @property
    def scale1(self):
        """Scaling applied to the phase-1 adapter output (``alpha1 / rank1``)."""
        return self.alpha1 / self.rank1

    @property
    def scale2(self):
        """Scaling applied to the phase-2 adapter output (``alpha2 / rank2``)."""
        return self.alpha2 / self.rank2

    def forward(self, x):
        out = self.linear(x)

        lora1 = self.lora_up(self.lora_down(x)) * self.scale1
        lora2 = self.lora2_up(self.lora2_down(x)) * self.scale2

        blended = (1 - self.phase2_weight) * lora1 + self.phase2_weight * lora2
        return out + blended

# Patch UNet cross-attn with dual-phase LoRA
def patch_unet_cross_attn_with_dual_lora(unet, rank1, alpha1, rank2, alpha2):
    """Wrap the attention projections of ``unet`` in ``LoRALinearDualPhase``.

    Every submodule that has ``to_q``, ``to_k`` and ``to_v`` attributes is
    patched. Despite the name, this covers self-attention as well as
    cross-attention modules.

    - ``to_q``/``to_k``/``to_v`` are wrapped when they are plain ``nn.Linear``.
    - ``to_out[0]`` is wrapped when ``to_out`` is a non-empty ``nn.ModuleList``
      whose first entry is an ``nn.Linear``.
    - Layers that are already wrapped are skipped.

    Args:
        unet: The UNet module to patch in place.
        rank1, alpha1, rank2, alpha2: Must match the values used in training.

    Returns:
        int: Number of linear layers wrapped by this call.
    """
    # Snapshot the attention modules first: the loop below replaces submodules,
    # and a live walk of the tree would also descend into the new wrappers.
    attention_modules = [
        module
        for module in unet.modules()
        if hasattr(module, 'to_q') and hasattr(module, 'to_k') and hasattr(module, 'to_v')
    ]

    total_patched = 0
    for module in attention_modules:
        for attr in ('to_q', 'to_k', 'to_v'):
            original = getattr(module, attr)
            if isinstance(original, nn.Linear) and not isinstance(original, LoRALinearDualPhase):
                setattr(module, attr, LoRALinearDualPhase(original, rank1, alpha1, rank2, alpha2))
                total_patched += 1

        # Handle to_out[0] if it's a linear layer (guarding against an empty ModuleList).
        to_out = getattr(module, 'to_out', None)
        if isinstance(to_out, nn.ModuleList) and len(to_out) > 0:
            original = to_out[0]
            if isinstance(original, nn.Linear) and not isinstance(original, LoRALinearDualPhase):
                to_out[0] = LoRALinearDualPhase(original, rank1, alpha1, rank2, alpha2)
                total_patched += 1

    print(f"Patched {total_patched} layers in UNet with Dual-Phase LoRA")
    return total_patched

# Load dual-phase LoRA weights into model
def apply_dual_lora_weights(model, lora_path, device="cuda"):
    """Copy saved LoRA tensors into every parameter of ``model`` whose name contains ``"lora"``.

    Args:
        model: A module already patched with ``LoRALinearDualPhase`` layers.
        lora_path: Path to a ``.safetensors`` file. May also be an already-loaded
            state dict (any mapping of parameter name -> tensor), so one load can
            be reused for several models.
        device: Device the file is loaded onto when ``lora_path`` is a path;
            ignored when a mapping is passed.

    Returns:
        list[str]: Names of ``lora`` parameters that had no entry in the state dict.
    """
    from collections.abc import Mapping

    print("Applying Dual-Phase LoRA weights...")
    if isinstance(lora_path, Mapping):
        state_dict = lora_path
    else:
        state_dict = load_file(lora_path, device=device)

    missing = []
    # Loading weights is not part of autograd; copy_ also converts dtype/device
    # (e.g. fp32 tensors into fp16 CUDA parameters).
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "lora" not in name:
                continue
            if name in state_dict:
                param.copy_(state_dict[name])
            else:
                missing.append(name)

    print("LoRA weights loaded.")
    if missing:
        print("Missing LoRA keys:", missing)
    return missing

# Patch CLIP TextEncoder with dual-phase LoRA
def patch_text_encoder_with_dual_lora(text_encoder, rank1, alpha1, rank2, alpha2):
    """Wrap the attention projections of ``text_encoder`` in ``LoRALinearDualPhase``.

    A submodule is patched only when it has all of ``q_proj``, ``k_proj``,
    ``v_proj`` and ``out_proj``. Each of these that is a plain ``nn.Linear``
    gets wrapped; already-wrapped layers are skipped.

    Args:
        text_encoder: The text encoder module to patch in place.
        rank1, alpha1, rank2, alpha2: Must match the values used in training.

    Returns:
        int: Number of linear layers wrapped by this call (0 means nothing matched).
    """
    proj_names = ('q_proj', 'k_proj', 'v_proj', 'out_proj')
    # Snapshot the attention modules first, since the loop swaps submodules in place.
    attention_modules = [
        module
        for module in text_encoder.modules()
        if all(hasattr(module, attr) for attr in proj_names)
    ]

    total_patched = 0
    for module in attention_modules:
        for attr in proj_names:
            original = getattr(module, attr)
            if isinstance(original, nn.Linear) and not isinstance(original, LoRALinearDualPhase):
                setattr(module, attr, LoRALinearDualPhase(original, rank1, alpha1, rank2, alpha2))
                total_patched += 1

    if total_patched > 0:
        print(f"Patched {total_patched} layers in TextEncoder with Dual-Phase LoRA")
    else:
        print("No layers patched in TextEncoder")
    return total_patched

# Load dual-phase LoRA weights into TextEncoder
def apply_dual_lora_weights_to_text_encoder(text_encoder, lora_state_dict):
    """Copy LoRA tensors from ``lora_state_dict`` into ``text_encoder``.

    Only parameters whose name contains ``"lora"`` are considered. Each one is
    overwritten in place from the entry with the same name. Names with no entry
    are reported on stdout; nothing is returned.
    """
    lora_params = [
        (param_name, tensor)
        for param_name, tensor in text_encoder.named_parameters()
        if "lora" in param_name
    ]

    not_found = []
    for param_name, tensor in lora_params:
        if param_name not in lora_state_dict:
            not_found.append(param_name)
            continue
        tensor.data.copy_(lora_state_dict[param_name])

    print("TextEncoder LoRA weights loaded.")
    if not_found:
        print("Missing text encoder LoRA keys:", not_found)


The code below patches the pipeline for both phase 1 and phase 2, using the ranks and alphas from fine-tuning, so that both sets of weights can then be loaded into it.

**Note:** The phase 1 rank and alpha, as well as the phase 2 rank and alpha, must be the same as those used during fine-tuning.

In [None]:
# Patch UNet and Text Encoder.
# These ranks/alphas MUST match the values used for each phase during fine-tuning:
# the rank sets the LoRA tensor shapes (so a mismatch breaks weight loading) and
# alpha / rank sets each adapter's scale. Define them once so both calls agree.
PHASE1_RANK, PHASE1_ALPHA = 4, 8
PHASE2_RANK, PHASE2_ALPHA = 4, 8

dual_lora_config = dict(
    rank1=PHASE1_RANK,
    alpha1=PHASE1_ALPHA,
    rank2=PHASE2_RANK,
    alpha2=PHASE2_ALPHA,
)
patch_unet_cross_attn_with_dual_lora(pipe.unet, **dual_lora_config)
patch_text_encoder_with_dual_lora(pipe.text_encoder, **dual_lora_config)

This code loads our phase 1 and phase 2 weights into the pipeline.

In [None]:
from safetensors.torch import load_file

# Path to the trained dual-phase LoRA weights (.safetensors). Defined once so
# the UNet and text-encoder loads cannot drift apart.
lora_weights_path = "your_lora_weights_path"

state_dict = load_file(lora_weights_path, device="cuda")

apply_dual_lora_weights(pipe.unet, lora_weights_path)
apply_dual_lora_weights_to_text_encoder(pipe.text_encoder, state_dict)


This code sets the weight of phase 2 during inference. Phase 1 receives the remaining `1 - phase2_weight`. By default (`phase2_weight = 1.0`) only the phase 2 LoRA is applied; use `0.5` for an equal blend of both phases.

In [None]:
def set_phase_weight(model, phase2_weight: float) -> int:
    """
    Set the blending factor between Phase 1 and Phase 2 LoRA weights.

    Each patched layer adds ``(1 - phase2_weight) * phase1 + phase2_weight * phase2``
    to its base output.

    Args:
        model: The model (e.g., unet or text_encoder) patched with LoRALinearDualPhase.
        phase2_weight (float): Value between 0.0 (Phase 1 only) and 1.0 (Phase 2 only).

    Returns:
        int: Number of LoRALinearDualPhase layers updated. 0 means the model
        was never patched.

    Raises:
        ValueError: If phase2_weight is outside [0.0, 1.0] (or NaN).
    """
    # Outside [0, 1] one phase would get a negative weight, which is no longer a blend.
    if not 0.0 <= phase2_weight <= 1.0:
        raise ValueError(f"phase2_weight must be between 0.0 and 1.0, got {phase2_weight!r}")

    updated = 0
    for module in model.modules():
        if isinstance(module, LoRALinearDualPhase):
            module.phase2_weight = phase2_weight
            updated += 1
    return updated


For inference, make sure to include your trigger word in the prompt. You can change the prompt, negative prompt, guidance scale, number of inference steps, height and width.

The recommended guidance scale is 5.5–7.5.

The recommended number of inference steps is 30–50.

Use standard image sizes rather than arbitrary custom sizes; otherwise the model may produce unnatural, messy results.

In [None]:
# Define inference settings
# trigger_word must be defined before the prompt f-string uses it (otherwise a
# NameError is raised). Replace it with the token your subject was trained on.
trigger_word = "your_trigger_word"
prompt = f"{trigger_word}, your prompt"
negative_prompt = (
    "blurry, low resolution, grainy, overexposed, underexposed, bad lighting, jpeg artifacts, glitch, "
    "cropped, out of frame, watermark, duplicate, poorly drawn face, asymmetrical face, deformed features, "
    "bad skin texture, doll-like face, bad eyes, mutated hands, extra fingers, unrealistic proportions, "
    "cartoon, anime, illustration, painting, horror, morbid"
)
guidance_scale = 6.5
num_inference_steps = 30
height = 768
width = 768

# set phase 2 weightage during inference here (e.g. phase2_weight=0.3 leads to phase1 weightage=0.7).
# Defined once so the UNet and text encoder always use the same blend.
phase2_weight = 0.3
set_phase_weight(pipe.unet, phase2_weight=phase2_weight)
set_phase_weight(pipe.text_encoder, phase2_weight=phase2_weight)

This code generates and saves your images.

In [None]:
# Generate and save images: run the pipeline once with the settings above,
# take the first returned image, and write it into output_dir.
generation = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
)
image = generation.images[0]

save_path = os.path.join(output_dir, "output.png")
image.save(save_path)
