<a href="https://colab.research.google.com/github/syedfahdali/htmlfrontpage/blob/main/floral_text_generation_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning Stable Diffusion for Floral Text Style using LoRA

This notebook outlines the process of fine-tuning a Stable Diffusion model using Low-Rank Adaptation (LoRA) to learn a specific "floral text" style from a small dataset of images.

The process involves:
1. Setting up the environment and installing necessary libraries.
2. Loading the base Stable Diffusion model.
3. Preparing a dataset from images in a specified folder, using filenames as captions.
4. Configuring and running the LoRA training process.
5. Saving the trained LoRA weights.
6. Evaluating the fine-tuned model by generating images using the base model combined with the LoRA weights.

In [3]:
# prompt: write zip extraction code for Images.zip

import zipfile
import os

def extract_zip(zip_filepath, extract_dir):
    """Unpack every member of a zip archive into a directory.

    Args:
        zip_filepath: Location of the archive to read.
        extract_dir: Destination directory handed to ``ZipFile.extractall``.
    """
    archive = zipfile.ZipFile(zip_filepath, mode='r')
    try:
        archive.extractall(path=extract_dir)
    finally:
        # Always release the file handle, even if extraction fails.
        archive.close()

# Example usage: unpack the uploaded archive next to the notebook.
zip_file_path = './Images.zip'  # Replace with your zip file path
extract_directory = 'images' # Replace with your desired extraction directory

# exist_ok=True keeps this cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(extract_directory, exist_ok=True)

extract_zip(zip_file_path, extract_directory)
print(f"Successfully extracted {zip_file_path} to {extract_directory}")


Successfully extracted ./Images.zip to images


## 1. Setup: Installations and Imports

In [None]:
# Install necessary libraries
# We need diffusers, transformers, accelerate, bitsandbytes (for 8-bit optimizer), peft (for LoRA), and datasets
%pip install -q diffusers transformers accelerate bitsandbytes ftfy Pillow # Base dependencies
%pip install -q peft datasets # LoRA and dataset handling

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m36.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
import os
import re
from PIL import Image
from pathlib import Path
import numpy as np
import math

from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import Dataset as HFDataset # Renamed to avoid conflict with torch Dataset
import accelerate
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers.optimization import get_scheduler
from peft import LoraConfig, get_peft_model
from diffusers.utils import make_image_grid

# Report the runtime environment once, then pick the compute device from it.
_has_cuda = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"CUDA version: {torch.version.cuda}")

# GPU when present, CPU otherwise.
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# --- Training Configuration ---
# All tunable settings for the run live in this cell; later cells read these
# names directly.
# Base checkpoint: the tokenizer, text encoder, VAE, UNet and noise scheduler
# are each loaded from a sub-folder of this model id.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
revision = None # Forwarded as `revision=` to from_pretrained; None means use the main branch

# --- Your Data ---
# Folder scanned (non-recursively) for .jpg/.jpeg/.png/.webp training images.
# The extraction cell unpacks Images.zip into ./images, so this path assumes
# the archive contains a top-level "Images" folder — TODO confirm for your zip.
instance_data_dir = "./images/Images"
output_dir = "./models/lora_floral_text" # Where LoRA weights, checkpoints, logs and validation grids are written

# --- LoRA Parameters ---
lora_rank = 16 # Rank of the LoRA matrices. Higher rank means more parameters, potentially more expressive but prone to overfitting.
lora_alpha = lora_rank # Passed to LoraConfig as lora_alpha; tied to the rank here
lora_dropout = 0.1 # Passed to LoraConfig as lora_dropout

# --- Training Parameters ---
resolution = 512 # Side length (pixels) that training images are resized and cropped to.
center_crop = True # Center crop when True; otherwise a random crop of the same size is taken.
train_batch_size = 1 # Batch size (per device) for training. Reduce if OOM.
num_train_epochs = 100 # Number of training epochs. Adjust based on results (start with more for small datasets).
max_train_steps = None # If set, the number of epochs is derived from it instead of num_train_epochs.
learning_rate = 1e-4 # Initial learning rate.
# NOTE(review): no cell visible in this notebook reads scale_lr — confirm
# whether learning-rate scaling is actually applied before relying on it.
scale_lr = False # Intended to scale the learning rate by accumulation steps, batch size and GPU count.
lr_scheduler_name = "constant" # Choose from "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
lr_warmup_steps = 0 # Warmup steps handed to the learning-rate scheduler.
use_8bit_adam = True # Whether to use 8-bit AdamW optimizer (requires bitsandbytes).
# AdamW hyperparameters, forwarded unchanged to the optimizer constructor.
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08
max_grad_norm = 1.0 # Max gradient norm for clipping.
gradient_accumulation_steps = 1 # Number of batches whose gradients are accumulated before each optimizer update.
mixed_precision = "fp16" # Choose: "no", "fp16", "bf16".
allow_tf32 = True # Allow TF32 on Ampere GPUs for potentially faster training.

# --- Other ---
seed = 42 # Seed used for the validation image generator.
checkpointing_steps = 200 # Save a checkpoint of the training state every X updates.
validation_prompt = "A floral design with the text 'Hello World'" # Prompt to use for generating validation images during training
num_validation_images = 4 # Images generated per validation run (shown side by side in one grid).
validation_epochs = 20 # Run validation every X epochs

# Create the output directory up front; no error if it already exists.
os.makedirs(output_dir, exist_ok=True)

# Snapshot of the hyperparameters above, handed to the experiment tracker so
# every run records the settings it was launched with.
config_to_log = dict(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    revision=revision,
    instance_data_dir=instance_data_dir,
    output_dir=output_dir,
    lora_rank=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    resolution=resolution,
    center_crop=center_crop,
    train_batch_size=train_batch_size,
    # The two step/epoch counts are captured as configured here; the training
    # setup may recompute the live variables afterwards.
    num_train_epochs=num_train_epochs,
    max_train_steps=max_train_steps,
    learning_rate=learning_rate,
    scale_lr=scale_lr,
    lr_scheduler_name=lr_scheduler_name,  # scheduler is logged by name
    lr_warmup_steps=lr_warmup_steps,
    use_8bit_adam=use_8bit_adam,
    adam_beta1=adam_beta1,
    adam_beta2=adam_beta2,
    adam_weight_decay=adam_weight_decay,
    adam_epsilon=adam_epsilon,
    max_grad_norm=max_grad_norm,
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    allow_tf32=allow_tf32,
    seed=seed,
    # Validation and checkpoint settings, included for completeness.
    validation_prompt=validation_prompt,
    num_validation_images=num_validation_images,
    validation_epochs=validation_epochs,
    checkpointing_steps=checkpointing_steps,
)

## 3. Load Models and Tokenizer

In [None]:
# Every component comes from the same checkpoint and revision; only the
# sub-folder differs. Each model is frozen/placed right after it is loaded.

# Tokenizer for the captions.
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
)

# Text encoder: not trained, so its parameters are frozen.
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
text_encoder.requires_grad_(False)
text_encoder.to(device)

# VAE: taken as-is from the base model, frozen, and stored in float16.
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision
)
vae.requires_grad_(False)
vae.to(device, dtype=torch.float16)

# UNet: the only component that is put in training mode.
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision
)
unet.train()
unet.to(device)

print("Models loaded.")

## 4. Prepare Dataset

In [None]:
# Custom Dataset Class
class ImageCaptionDataset(Dataset):
    """Image/caption dataset whose captions are derived from the filenames.

    Every ``.jpg``/``.jpeg``/``.png``/``.webp`` file directly inside
    ``data_dir`` becomes one sample. The caption is the file stem with
    underscores and hyphens turned into spaces.

    Args:
        data_dir: Directory containing the training images (not recursive).
        tokenizer: Callable tokenizer exposing ``model_max_length``; used to
            turn each caption into ``input_ids``.
        size: Target side length in pixels for the resize and crop.
        center_crop: Center crop when True, random crop otherwise.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, data_dir, tokenizer, size=512, center_crop=True):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop

        # iterdir() yields entries in arbitrary order; sorting keeps the
        # index -> image mapping stable across runs and filesystems.
        self.image_paths = sorted(
            p for p in self.data_dir.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        print(f"Found {len(self.image_paths)} images in {data_dir}")

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]), # Normalize to [-1, 1]
            ]
        )

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.image_paths)

    @staticmethod
    def _caption_from_path(img_path):
        """Build a caption from a filename: drop the extension, turn '_'/'-' into spaces, collapse whitespace."""
        caption = img_path.stem.replace('_', ' ').replace('-', ' ')
        return re.sub(r'\s+', ' ', caption).strip()

    def _tokenize(self, caption):
        """Tokenize a caption to a 1-D ``input_ids`` tensor padded/truncated to the tokenizer's max length."""
        # Use the tokenizer given to this instance, not a module-level global,
        # so the dataset works with whatever tokenizer it was built with.
        input_ids = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        return input_ids.squeeze(0)  # Remove batch dim from input_ids

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "input_ids"}`` for sample ``idx``.

        Returns ``None`` when the image cannot be opened or transformed; the
        failure is printed and the caller (e.g. a collate function) is
        expected to drop such samples.
        """
        img_path = self.image_paths[idx]
        caption = self._caption_from_path(img_path)

        try:
            # The context manager closes the underlying file handle as soon
            # as the pixels have been converted.
            with Image.open(img_path) as img:
                pixel_values = self.image_transforms(img.convert("RGB"))
        except Exception as e:
            # Deliberate best-effort: one unreadable image must not abort training.
            print(f"Error loading or processing image {img_path}: {e}")
            return None

        return {"pixel_values": pixel_values, "input_ids": self._tokenize(caption)}

# Build the training set from the instance images; captions come from the
# filenames and are tokenized with the base model's tokenizer.
train_dataset = ImageCaptionDataset(
    data_dir=instance_data_dir,
    tokenizer=tokenizer,
    size=resolution,
    center_crop=center_crop,
)

# Batch builder that tolerates samples the dataset failed to load (None).
def collate_fn(examples):
    """Stack the valid samples of a batch into tensors.

    Args:
        examples: List of sample dicts (``pixel_values``, ``input_ids``) or
            ``None`` for samples that could not be loaded.

    Returns:
        A dict with float, contiguous ``pixel_values`` and stacked
        ``input_ids``, or ``None`` when no sample in the batch is valid.
    """
    usable = list(filter(lambda item: item is not None, examples))
    if len(usable) == 0:
        return None  # Or raise an error if no valid data in batch

    images = torch.stack([item["pixel_values"] for item in usable])
    token_ids = torch.stack([item["input_ids"] for item in usable])
    return {
        "pixel_values": images.to(memory_format=torch.contiguous_format).float(),
        "input_ids": token_ids,
    }

# Wrap the dataset in a shuffling loader; collate_fn filters out samples
# that failed to load.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

print(f"Dataset size: {len(train_dataset)}")
print(f"Dataloader created. Batch size: {train_batch_size}")

## 5. Configure LoRA

In [None]:
# Layers of the UNet that receive LoRA adapters (common choice for SD LoRA).
lora_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

# Adapter hyperparameters come from the configuration cell.
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
    bias="none",  # Usually set to none for LoRA
)

# Wrap the UNet with the adapters and report how many parameters will train.
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()

print("LoRA configured for UNet.")

## 6. Training Setup

In [None]:
# Set up accelerator: it handles device placement, mixed precision, gradient
# accumulation and TensorBoard logging (written under <output_dir>/logs).
accelerator = accelerate.Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    log_with="tensorboard",
    project_dir=os.path.join(output_dir, "logs")
)

# Apply the configured seed so data shuffling and noise sampling are
# reproducible; without this call the `seed` setting only affects validation.
if seed is not None:
    accelerate.utils.set_seed(seed)

# Honour the allow_tf32 switch from the configuration cell.
if allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Use 8-bit AdamW if enabled
if use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError as err:
        raise ImportError("Please install bitsandbytes to use 8-bit AdamW. `pip install bitsandbytes`") from err
    optimizer_cls = bnb.optim.AdamW8bit
else:
    optimizer_cls = torch.optim.AdamW

# Hand the optimizer only the parameters that actually train (the LoRA
# adapters); the frozen base weights of the UNet are left out.
trainable_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = optimizer_cls(
    trainable_params,
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

# Fail early with a readable message instead of a ZeroDivisionError below.
if len(train_dataloader) == 0:
    raise ValueError(
        f"The training dataloader is empty - no usable images were found in {instance_data_dir}"
    )

# Calculate total training steps. Remember whether max_train_steps was derived
# from the epoch count so it can be re-derived after accelerator.prepare().
overrode_max_train_steps = max_train_steps is None
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
else:
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# Learning rate scheduler
# The scheduler name is read from the logged configuration dictionary.
current_scheduler_name = config_to_log['lr_scheduler_name']
lr_scheduler_obj = get_scheduler(
    current_scheduler_name, # Pass the fetched name string
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
)

# Prepare everything with accelerator
unet, optimizer, train_dataloader, lr_scheduler_obj = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler_obj # Pass the scheduler object
)

# The dataloader length may have changed after prepare(), so recompute the
# step/epoch bookkeeping. max_train_steps is only re-derived when it came from
# the epoch count; an explicitly configured value is kept.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# Move text_encoder and vae to the accelerator device (if not already there) -
# they aren't prepared by accelerator
text_encoder.to(accelerator.device)
vae.to(accelerator.device)

# Keep vae and text_encoder in eval mode as we don't train them
vae.eval()
text_encoder.eval()

# We need to initialize the trackers we use, and also store our configuration.
if accelerator.is_main_process:
    # Use the specific config dictionary for logging
    accelerator.init_trackers("lora_floral_text", config=config_to_log)

print(f"***** Running training *****")
print(f"  Num examples = {len(train_dataset)}")
print(f"  Num Epochs = {num_train_epochs}")
print(f"  Instantaneous batch size per device = {train_batch_size}")
print(f"  Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * accelerator.num_processes * gradient_accumulation_steps}")
print(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
print(f"  Total optimization steps = {max_train_steps}")

## 7. Training Loop

In [None]:
global_step = 0
first_epoch = 0

# Load noise scheduler (defines the forward diffusion used to noise the latents)
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

for epoch in range(first_epoch, num_train_epochs):
    # Validation below switches the UNet to eval mode, so training mode is
    # restored at the start of every epoch.
    unet.train()
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        # collate_fn returns None when no sample in the batch could be loaded.
        if batch is None:
            print(f"Skipping step {step} due to batch error.")
            continue

        with accelerator.accumulate(unet):
            # Convert images to latent space. The VAE is frozen, so no graph is
            # needed; inputs are cast to whatever dtype the VAE is stored in.
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"].to(dtype=vae.dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get text embedding for conditioning (text encoder is frozen too)
            with torch.no_grad():
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Define loss target according to the scheduler's prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Model prediction and loss (computed in float32)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            # Average the loss across processes for logging
            avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
            train_loss += avg_loss.item() / gradient_accumulation_steps

            # Backward
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = unet.parameters()
                accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
            optimizer.step()
            lr_scheduler_obj.step()
            optimizer.zero_grad()

        # sync_gradients is True once an optimizer update has really happened
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0

            if global_step % checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    print(f"Saved state to {save_path}")

        logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler_obj.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

        if global_step >= max_train_steps:
            break

    # Validation loop
    if accelerator.is_main_process:
        if epoch % validation_epochs == 0 or epoch == num_train_epochs - 1:
            print(f"\nRunning validation... Epoch {epoch}")
            unet_inference = accelerator.unwrap_model(unet)
            # Switch to eval mode so the LoRA dropout is inactive while the
            # validation images are generated.
            unet_inference.eval()
            pipeline = StableDiffusionPipeline.from_pretrained(
                pretrained_model_name_or_path,
                unet=unet_inference,
                text_encoder=text_encoder,
                vae=vae,
                revision=revision,
                torch_dtype=torch.float16,
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # `is not None` so that a seed of 0 is still honoured.
            generator = (
                torch.Generator(device=accelerator.device).manual_seed(seed)
                if seed is not None
                else None
            )
            images = []
            for _ in range(num_validation_images):
                with torch.autocast("cuda"):
                    image = pipeline(validation_prompt, num_inference_steps=25, generator=generator).images[0]
                images.append(image)

            # Create image grid and log it
            grid = make_image_grid(images, rows=1, cols=num_validation_images)

            # Log the grid to TensorBoard through the tracker's writer
            writer = accelerator.get_tracker("tensorboard").writer
            writer.add_image("validation", transforms.ToTensor()(grid), epoch)

            # Save grid locally
            grid.save(os.path.join(output_dir, f"validation_epoch_{epoch}.png"))

            # Drop the temporary pipeline before asking CUDA to release memory
            del pipeline
            torch.cuda.empty_cache()

progress_bar.close()

# End training: wait until every process has finished its last step before
# anything is written to disk.
accelerator.wait_for_everyone()

# Save LoRA weights (main process only, so the files are written once).
if accelerator.is_main_process:
    # NOTE: this rebinds the notebook-level `unet` to the unwrapped model, so
    # later cells see the unwrapped object.
    unet = accelerator.unwrap_model(unet)
    unet.save_pretrained(output_dir)
    print(f"LoRA weights saved to {output_dir}")

# Close the trackers that were initialised in the training-setup cell.
accelerator.end_training()
print("Training finished.")


## 8. Evaluation / Example Generation with LoRA

Load the base model and attach the trained LoRA weights to generate images.

In [None]:
from diffusers import UNet2DConditionModel
# Re-export the LoRA weights as a single file for the generation cell.
# The file name and format are pinned explicitly because that cell loads
# "pytorch_lora_weights.bin"; leaving them to the library default can produce
# a differently named file, depending on the installed diffusers version.
unet = accelerator.unwrap_model(unet)
UNet2DConditionModel.save_attn_procs(
    unet,
    output_dir,
    weight_name="pytorch_lora_weights.bin",
    safe_serialization=False,
)
print(f"✅ LoRA weights saved correctly to {output_dir}")


In [None]:
import torch
import re
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

# --- Setup ---
# The adapter must be attached to the same base checkpoint it was trained on
# (the training cells use stable-diffusion-2-1-base); a different base model
# has differently shaped attention layers and the LoRA weights will not fit.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
output_dir = "./models/lora_floral_text"
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42
# Half precision is only used on GPU; on CPU the pipeline stays in float32.
inference_dtype = torch.float16 if device == "cuda" else torch.float32

# Prefer rich display inside a notebook; fall back to writing PNG files when
# IPython is not available. Resolved once instead of on every iteration.
try:
    from IPython.display import display
except ImportError:
    display = None

# --- Load Base ---
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=inference_dtype
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)

# --- Load LoRA Adapter with specific weight file ---
# Replace with the exact name of the LoRA weight file if different
lora_weight_path = f"{output_dir}/pytorch_lora_weights.bin"
print(f"Loading LoRA adapter from: {lora_weight_path}")
pipe.load_lora_weights(output_dir, weight_name="pytorch_lora_weights.bin")
print("LoRA adapter loaded ✅")

# --- Prompts ---
test_prompts = [
    'A floral design with the text "Welcome" in elegant script',
    'The word "Love" made of pink roses and green vines',
    '"Shine" text with sunflowers and bright yellow petals',
    '"Dream" written in cursive with lavender flowers'
]

num_inference_steps = 30
guidance_scale = 7.5

# --- Generate Images ---
for prompt in test_prompts:
    print(f"\n--- Generating: {prompt} ---")
    # A fresh generator per prompt: every image starts from the same seed.
    generator = torch.Generator(device=device).manual_seed(seed)

    # Autocast is kept for the GPU path and switched off on CPU.
    with torch.autocast(device, enabled=(device == "cuda")):
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        ).images[0]

    if display is not None:
        display(image)
    else:
        # Turn the prompt into a filesystem-safe, length-limited file name.
        safe_prompt = re.sub(r'[^a-zA-Z0-9_]+', '_', prompt)[:50]
        image.save(f"{safe_prompt}_lora_output.png")
        print(f"Saved to {safe_prompt}_lora_output.png")


In [None]:
!pip show diffusers