In [None]:
%pip install diffusers
%pip install peft

In [None]:
!pip install protobuf==3.20.*

In [None]:
from pathlib import Path
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import login
from peft import LoraConfig
import torch
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import math
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from peft.utils import get_peft_model_state_dict
from diffusers.utils import convert_state_dict_to_diffusers
from datasets import load_dataset
from functools import partial
from PIL import Image
from kaggle_secrets import UserSecretsClient
from torch.utils.data import Dataset
import pandas as pd
from pydantic import BaseModel
from diffusers.training_utils import compute_snr
import os

In [None]:
# Authenticate with the Hugging Face Hub. The token is read from the Kaggle
# secret named "Longphanryu" so it never appears in the notebook itself.
login(token=UserSecretsClient().get_secret("Longphanryu"))

In [None]:
def get_models(model_name: str, dtype=torch.float16):
    """
    Load the Stable Diffusion components used for LoRA fine-tuning.

    The text encoder, VAE and UNet are loaded directly in ``dtype`` (through
    ``torch_dtype``) instead of being loaded in float32 and cast afterwards,
    which avoids holding a transient full-precision copy of the weights.
    Callers are expected to up-cast the parameters they train (e.g. the LoRA
    weights) to float32 after adding their adapters.

    Args:
        model_name: Hub id or local path of a checkpoint that has
            ``tokenizer``, ``text_encoder``, ``vae``, ``scheduler`` and
            ``unet`` subfolders.
        dtype: torch dtype used for the text encoder, VAE and UNet.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    # The tokenizer and the scheduler carry no large weights, so dtype does not apply.
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    # Heavy modules: load straight into the requested dtype to save memory.
    text_encoder = CLIPTextModel.from_pretrained(
        model_name, subfolder="text_encoder", torch_dtype=dtype
    )
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype)

    return tokenizer, text_encoder, vae, scheduler, unet

In [None]:
class FashionDataset(Dataset):
    """
    Image/caption pairs described by a CSV file with ``image`` and ``caption`` columns.

    Rows are dropped at construction time when the caption is missing or
    blank, or when the image file name is missing, not a string, or does not
    exist inside ``image_folder``.

    Each item is a dict with:
        ``pixel_values``: the transformed image tensor (resized, randomly
            flipped horizontally, normalised with mean/std 0.5 per channel).
        ``input_ids``: the caption token ids, padded/truncated to
            ``tokenizer.model_max_length``.
    """

    def __init__(self, csv_path: str, image_folder: str, tokenizer: CLIPTokenizer, image_size: int = 512):
        self.image_folder = image_folder
        self.tokenizer = tokenizer

        frame = pd.read_csv(csv_path)

        def _has_text(value) -> bool:
            # NaN cells are floats, so the isinstance check also rejects them.
            return isinstance(value, str) and len(value.strip()) > 0

        def _image_on_disk(value) -> bool:
            # Guard against NaN / non-string file names before touching os.path.join,
            # which would otherwise raise TypeError.
            return isinstance(value, str) and os.path.exists(os.path.join(image_folder, value))

        # astype(bool) keeps the masks boolean even when the frame has no rows.
        caption_ok = frame["caption"].map(_has_text).astype(bool)
        image_ok = frame["image"].map(_image_on_disk).astype(bool)
        self.df = frame[caption_ok & image_ok].reset_index(drop=True)

        # Normalisation uses one mean/std value per RGB channel.
        self.transforms = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_folder, row["image"])
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            # Keep the exception type the original code raised, but chain the
            # cause so the real failure (corrupt file, permissions, ...) stays visible.
            raise FileNotFoundError(f"Cannot open image: {img_path}. Exception: {e}") from e

        image = self.transforms(image)
        caption = row["caption"]

        input_ids = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"][0]

        return {
            "pixel_values": image,
            "input_ids": input_ids,
        }

In [None]:
def get_lora_params(model: torch.nn.Module):
    """
    Return the list of parameters of ``model`` whose ``requires_grad`` is True.

    The function does not look at parameter names; it relies on the caller
    having already enabled gradients only on the adapter/LoRA parameters.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))

def enable_lora_trainable_params(module: torch.nn.Module):
    """
    Make only the adapter parameters of ``module`` trainable.

    A parameter is treated as an adapter parameter when its lower-cased name
    contains one of a few common substrings ("lora", "adapter", "lora_up",
    "lora_down", "alpha"). Matching parameters get ``requires_grad=True``;
    every other parameter gets ``requires_grad=False``. Adjust the substring
    list if the adapter implementation in use names its parameters differently.

    Returns:
        The names of the parameters that were made trainable, in
        ``named_parameters()`` order.
    """
    markers = ("lora", "adapter", "lora_up", "lora_down", "alpha")
    enabled_names = []
    for param_name, param in module.named_parameters():
        lowered = param_name.lower()
        is_adapter_param = any(marker in lowered for marker in markers)
        param.requires_grad = is_adapter_param
        if is_adapter_param:
            enabled_names.append(param_name)
    return enabled_names

# -------------------------
# Setup models for training
# -------------------------
def setup_models_for_training(model_name: str, rank: int = 128, dtype=torch.float16, device=None):
    """
    Load the pipeline components and prepare them for LoRA fine-tuning.

    Steps performed:
      * load all components through ``get_models`` in ``dtype``;
      * freeze the VAE, text encoder and UNet base weights;
      * add a LoRA adapter of the given ``rank`` (``lora_alpha`` equal to
        ``rank``) to the UNet attention projections and to the text encoder
        q/k/v projections;
      * re-enable gradients on the adapter parameters by name and store those
        parameters in float32;
      * move the three modules to ``device``.

    Args:
        model_name: checkpoint id/path forwarded to ``get_models``.
        rank: LoRA rank used for both adapters.
        dtype: dtype forwarded to ``get_models``.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        Tuple ``(tokenizer, text_encoder, vae, scheduler, unet)``.
    """
    tokenizer, text_encoder, vae, scheduler, unet = get_models(model_name, dtype=dtype)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The VAE is never trained: inference mode, no gradients.
    vae.eval()
    vae.requires_grad_(False)

    # Freeze the base weights before the adapters are attached.
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # LoRA on the UNet attention projections.
    unet.add_adapter(
        LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
    )

    # LoRA on the text encoder attention projections.
    text_encoder.add_adapter(
        LoraConfig(
            r=rank,
            lora_alpha=rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
    )

    # Turn gradients back on for the adapter parameters only (matched by name).
    unet_matched = enable_lora_trainable_params(unet)
    text_matched = enable_lora_trainable_params(text_encoder)

    # Keep the trainable parameters in float32 for stable updates.
    for trained_module in (unet, text_encoder):
        for param in trained_module.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.float32)

    # Move everything to the target device; the text encoder's train/eval mode
    # is left untouched here and is set by the training loop.
    vae.to(device)
    vae.eval()
    text_encoder.to(device)
    unet.to(device)
    unet.train()

    # Debug information about which parameters were picked up.
    print(f"LoRA params in UNet matched: {len(unet_matched)} names (examples): {unet_matched[:10]}")
    print(f"LoRA params in TextEncoder matched: {len(text_matched)} names (examples): {text_matched[:10]}")

    return tokenizer, text_encoder, vae, scheduler, unet

In [None]:
class TrainingConfig:
    """
    Plain container for the hyper-parameters and paths of a training run.

    Every constructor argument is stored unchanged on the instance under the
    same name.
    """

    def __init__(
        self,
        train_steps: int = 100,
        lr: float = 1e-5,
        batch_size: int = 2,
        accumulation_steps: int = 4,
        rank: int = 128,
        max_grad_norm: float = 1.0,
        pretrained_name: str = "runwayml/stable-diffusion-v1-5",
        csv_path: str = "/kaggle/input/fashion-dataset/train_fixed.csv",
        image_folder: str = "/kaggle/input/fashion-dataset/train_images",
        snr_gamma: float = -1,
        seed: int = 42,
    ):
        # Optimisation schedule
        self.train_steps, self.lr = train_steps, lr
        self.batch_size, self.accumulation_steps = batch_size, accumulation_steps
        self.max_grad_norm = max_grad_norm
        # Loss weighting: the training loop applies SNR weighting only when this is > 0.
        self.snr_gamma = snr_gamma
        # Model
        self.rank = rank
        self.pretrained_name = pretrained_name
        # Data
        self.csv_path = csv_path
        self.image_folder = image_folder
        # Reproducibility
        self.seed = seed

In [None]:
def compute_snr(scheduler: DDPMScheduler, timesteps: torch.Tensor):
    """
    Signal-to-noise ratio ``alphas_cumprod[t] / (1 - alphas_cumprod[t])`` for each timestep.

    ``scheduler.alphas_cumprod`` is moved to the device of ``timesteps`` before
    being indexed. The result is used by the training loop for loss weighting.
    """
    cumulative_alphas = scheduler.alphas_cumprod.to(timesteps.device)
    signal = cumulative_alphas[timesteps]
    noise = 1 - signal
    return signal / noise

In [None]:
def train(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    config: TrainingConfig,
    device=None
):
    """
    Fine-tune the trainable (LoRA) parameters of ``unet`` and ``text_encoder``.

    The loop runs until ``config.train_steps`` optimizer steps have been taken,
    where one optimizer step consumes ``config.accumulation_steps`` batches.

    Args:
        tokenizer: tokenizer handed to ``FashionDataset``.
        text_encoder: text encoder whose trainable parameters are optimised.
        vae: frozen VAE used to encode images into latents.
        scheduler: noise scheduler used to add noise to the latents.
        unet: UNet whose trainable parameters are optimised.
        config: hyper-parameters and data paths.
        device: ``torch.device`` or device string; defaults to CUDA when
            available, else CPU.

    Returns:
        ``{"losses": [...]}`` with one float per processed batch.

    Raises:
        ValueError: if ``config.accumulation_steps`` is smaller than 1.
        RuntimeError: if no parameter is trainable, or if the dataset yields
            no full batch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Accept plain strings such as "cuda"; the code below reads device.type.
        device = torch.device(device)

    if config.accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {config.accumulation_steps}")

    # Collect the trainable (LoRA) parameters of the UNet and the text encoder.
    lora_params = get_lora_params(unet) + get_lora_params(text_encoder)
    if len(lora_params) == 0:
        raise RuntimeError("Không có tham số trainable (LoRA). Kiểm tra add_adapter / enable_lora_trainable_params.")

    # Put models on the device and in the right mode.
    vae.to(device).eval()
    text_encoder.to(device).train()
    unet.to(device).train()

    # Dataset / dataloader
    train_dataset = FashionDataset(config.csv_path, config.image_folder, tokenizer)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    batches_per_epoch = len(train_dataloader)
    if batches_per_epoch == 0:
        raise RuntimeError(
            f"Dataset has {len(train_dataset)} usable samples, which is not enough "
            f"for a single batch of size {config.batch_size}."
        )

    # The step counter carries over between epochs, so the number of epochs
    # must be derived from the total number of batches needed. Rounding the
    # optimizer steps per epoch up first would end training too early whenever
    # the batch count is not a multiple of accumulation_steps.
    total_micro_steps = config.train_steps * config.accumulation_steps
    steps_per_epoch = math.ceil(batches_per_epoch / config.accumulation_steps)  # reported only
    epochs = math.ceil(total_micro_steps / batches_per_epoch)

    # Scale lr with accumulation and batch size (kept as in the original).
    # NOTE(review): the loss is not divided by accumulation_steps, so the
    # accumulated gradients are summed rather than averaged; left unchanged on
    # purpose to preserve the existing training behaviour.
    lr = config.lr * config.accumulation_steps * config.batch_size
    optimizer = torch.optim.AdamW(lora_params, lr=lr)

    # Loss scaling only makes sense on CUDA; off-CUDA the scaler is a pass-through.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

    global_step = 0
    progress_bar = tqdm(range(config.train_steps), desc="Steps")

    print("Training config:")
    print(f" train_steps: {config.train_steps}, epochs (est): {epochs}, steps_per_epoch: {steps_per_epoch}")
    print(f" batch_size: {config.batch_size}, accumulation_steps: {config.accumulation_steps}, effective lr: {lr}")

    losses = []
    torch.manual_seed(config.seed)
    for _ in range(epochs):
        for batch in train_dataloader:
            bs = batch["input_ids"].shape[0]

            with torch.autocast(device_type="cuda" if device.type == "cuda" else "cpu", dtype=torch.float16):
                # No torch.no_grad here: the text encoder LoRA weights need gradients.
                encoder_hidden_states = text_encoder(batch["input_ids"].to(device), return_dict=False)[0]

                # The VAE is frozen, so encoding to latents needs no gradients.
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"].to(device)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample timesteps and noise, then noise the latents.
                timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bs,), device=device).long()
                noise = torch.randn_like(latents)
                noisy_latents = scheduler.add_noise(latents, noise, timesteps)

                # The UNet predicts the added noise.
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

                # Plain MSE, or min(SNR, gamma) / SNR weighted MSE when snr_gamma > 0.
                if config.snr_gamma > 0:
                    snr = compute_snr(scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, config.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0]
                    mse_loss_weights = mse_loss_weights / snr
                    loss = F.mse_loss(noise_pred, noise, reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()
                else:
                    loss = F.mse_loss(noise_pred, noise, reduction="mean")

            global_step += 1
            scaler.scale(loss).backward()

            if global_step % config.accumulation_steps == 0:
                # Gradients must be unscaled before they can be clipped.
                if config.max_grad_norm > 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                progress_bar.update(1)

            losses.append(loss.item())
            progress_bar.set_postfix({"loss": losses[-1]})

            if global_step >= total_micro_steps:
                break
        # Leave the epoch loop as well once the step budget is used up.
        if global_step >= total_micro_steps:
            break

    progress_bar.close()
    return {"losses": losses}

In [None]:
def save_lora_weights(unet, text_encoder, save_path="lora.safetensors"):
    """
    Write every trainable parameter of ``unet`` and ``text_encoder`` to one safetensors file.

    Keys are the parameter names from ``named_parameters()`` prefixed with
    ``"unet."`` or ``"text_encoder."``; values are detached CPU copies.
    """
    from safetensors.torch import save_file

    lora_state = {}
    # UNet entries are inserted first, then the text encoder ones.
    for prefix, module in (("unet", unet), ("text_encoder", text_encoder)):
        lora_state.update(
            (f"{prefix}.{param_name}", tensor.detach().cpu())
            for param_name, tensor in module.named_parameters()
            if tensor.requires_grad
        )

    save_file(lora_state, save_path)
    print(f"✔ Saved LoRA weights → {save_path}")


In [None]:
# Entry point of the training part: build the config, prepare the models,
# train the LoRA adapters and save them to "my_lora.safetensors".
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg = TrainingConfig(
        train_steps=10,        # adjust as needed
        lr=1e-5,
        batch_size=2,
        accumulation_steps=4,
        rank=64,                # lower the rank if VRAM is limited on a T4
        max_grad_norm=1.0,
        pretrained_name="runwayml/stable-diffusion-v1-5",
        csv_path="/kaggle/input/fashion-dataset/train_fixed.csv",
        image_folder="/kaggle/input/fashion-dataset/train_images",
        snr_gamma=-1,
        seed=42
    )

    # Load & setup
    tokenizer, text_encoder, vae, scheduler, unet = setup_models_for_training(
        cfg.pretrained_name,
        rank=cfg.rank,
        dtype=torch.float16,
        device=device
    )

    # Train
    metrics = train(tokenizer, text_encoder, vae, scheduler, unet, cfg, device=device)
    save_lora_weights(unet, text_encoder, "my_lora.safetensors")
    print("Done. Final losses (last 5):", metrics["losses"][-5:])

===============================================================

In [None]:
!pip install diffusers transformers accelerate torch torchvision pillow pandas numpy tqdm scipy ftfy open_clip_torch safetensors
!pip install pytorch-fid
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import csv

# For FID calculation
from scipy import linalg

# For CLIP Score
import clip

# Dataset class
class ValidationDataset(Dataset):
    """
    Validation samples listed in a CSV file with ``image`` and ``caption`` columns.

    Each item is the tuple ``(image, caption, image_name)`` where ``image`` is
    the RGB image loaded from ``image_folder`` and, when a transform was
    given, passed through it.
    """

    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Read the row once and pull both columns from it.
        record = self.data.iloc[idx]
        img_name, caption = record['image'], record['caption']

        image = Image.open(os.path.join(self.image_folder, img_name)).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, caption, img_name

# Load model with LoRA for both UNet and text encoder
def load_model_with_lora():
    """
    Build a Stable Diffusion pipeline and load the trained LoRA weights into it.

    The weight file is expected to contain the trainable adapter parameters
    under keys prefixed with ``"unet."`` and, optionally, ``"text_encoder."``,
    with the parameter names produced after ``add_adapter`` was called on each
    module. Those names only exist once an adapter has been attached, so this
    function first adds a LoRA adapter (rank inferred from the stored
    ``lora_A`` matrices, ``lora_alpha`` equal to the rank, same target modules
    as used for training) and then loads the prefix-stripped state dict.

    Loading the raw file into the base modules with ``strict=False`` would
    match no key at all and silently evaluate the un-tuned base model, which
    is why the result of every load is verified here.

    Requires the ``peft`` package.

    Returns:
        The pipeline, moved to CUDA when available.

    Raises:
        RuntimeError: if the file has no UNet LoRA weights, or if the stored
            keys do not line up with the adapter parameters of the pipeline.
    """
    from diffusers import StableDiffusionPipeline
    from peft import LoraConfig
    from safetensors.torch import load_file

    # Load base model
    model_id = "runwayml/stable-diffusion-v1-5"  # Adjust as needed
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )

    lora_weights = load_file("/kaggle/input/lora-weight/my_lora.safetensors")

    def _attach_lora(module, prefix, target_modules):
        """Add an adapter to ``module`` and load the weights stored under ``prefix``."""
        sub_state = {
            key[len(prefix):]: tensor
            for key, tensor in lora_weights.items()
            if key.startswith(prefix)
        }
        if not sub_state:
            return False

        # A lora_A matrix has shape (rank, in_features).
        ranks = [tensor.shape[0] for key, tensor in sub_state.items() if ".lora_A." in key]
        if not ranks:
            raise RuntimeError(f"No lora_A weights found under prefix '{prefix}'")
        rank = ranks[0]

        module.add_adapter(LoraConfig(r=rank, lora_alpha=rank, target_modules=target_modules))

        # strict=False because the file holds adapter weights only; base weights
        # are expected to be reported as missing.
        outcome = module.load_state_dict(sub_state, strict=False)
        missing_lora = [key for key in outcome.missing_keys if "lora" in key.lower()]
        if outcome.unexpected_keys or missing_lora:
            raise RuntimeError(
                f"LoRA weights under '{prefix}' do not match the model: "
                f"{len(outcome.unexpected_keys)} unexpected keys "
                f"(e.g. {list(outcome.unexpected_keys)[:3]}), "
                f"{len(missing_lora)} adapter keys missing (e.g. {missing_lora[:3]})"
            )
        return True

    # UNet LoRA is mandatory: without it the base model would be evaluated.
    if not _attach_lora(pipe.unet, "unet.", ["to_k", "to_q", "to_v", "to_out.0"]):
        raise RuntimeError("No keys starting with 'unet.' found in the LoRA weight file")

    # Text encoder LoRA is loaded only if present in the file.
    _attach_lora(pipe.text_encoder, "text_encoder.", ["q_proj", "k_proj", "v_proj"])

    if torch.cuda.is_available():
        pipe = pipe.to("cuda")

    return pipe

# Generate images for validation
def generate_images(pipe, captions, output_dir, image_size=512):
    """
    Generate one image per caption and save it as ``generated_<index>.png``.

    Images are produced with 50 inference steps and guidance scale 7.5 at
    ``image_size`` x ``image_size`` under CUDA autocast. ``output_dir`` is
    created if needed.

    Returns:
        The list of saved file paths, in caption order.
    """
    os.makedirs(output_dir, exist_ok=True)

    sampling_kwargs = dict(
        num_inference_steps=50,
        guidance_scale=7.5,
        height=image_size,
        width=image_size,
    )

    generated_paths = []
    for index, prompt in enumerate(tqdm(captions, desc="Generating images")):
        target_path = os.path.join(output_dir, f"generated_{index:05d}.png")

        with torch.autocast("cuda"):
            picture = pipe(prompt, **sampling_kwargs).images[0]

        picture.save(target_path)
        generated_paths.append(target_path)

    return generated_paths

# Calculate CLIP Score
def calculate_clip_score(model, preprocess, generated_images, captions, batch_size=16):
    """
    Mean cosine similarity between CLIP image and text embeddings.

    Images are opened, preprocessed and moved to the model's device one batch
    at a time, so memory use is bounded by ``batch_size`` instead of growing
    with the number of images.

    Args:
        model: CLIP model providing ``encode_image`` and ``encode_text``.
        preprocess: callable turning a PIL image into a tensor for ``model``.
        generated_images: paths of the generated images.
        captions: captions, aligned index-by-index with ``generated_images``.
        batch_size: number of image/caption pairs encoded per forward pass.

    Returns:
        ``(clip_score, similarities)``: the mean similarity as a float and a
        float32 CPU tensor with one similarity per pair.

    Raises:
        ValueError: if there are no images or the two inputs differ in length.
    """
    image_paths = list(generated_images)
    captions = list(captions)
    if len(image_paths) != len(captions):
        raise ValueError(
            f"Got {len(image_paths)} images but {len(captions)} captions; they must be aligned"
        )
    if not image_paths:
        raise ValueError("No generated images to score")

    device = next(model.parameters()).device

    similarities = []
    for start in tqdm(range(0, len(image_paths), batch_size), desc="Calculating CLIP scores"):
        stop = start + batch_size

        # Load only the images of this batch; close each file once it is decoded.
        tensors = []
        for img_path in image_paths[start:stop]:
            with Image.open(img_path) as handle:
                tensors.append(preprocess(handle.convert('RGB')))
        batch_images = torch.stack(tensors).to(device)

        batch_texts = clip.tokenize(captions[start:stop], truncate=True).to(device)

        with torch.no_grad():
            image_features = model.encode_image(batch_images)
            text_features = model.encode_text(batch_texts)

            # Normalize features so the dot product is a cosine similarity.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            similarity = (image_features * text_features).sum(dim=-1)
            # Accumulate in float32 so the mean is not computed in half precision.
            similarities.append(similarity.float().cpu())

    similarities = torch.cat(similarities)
    clip_score = similarities.mean().item()

    return clip_score, similarities

# Extract features for FID calculation using InceptionV3
def extract_features_inception(image_paths, batch_size=16, dims=2048):
    """
    Extract InceptionV3 features for FID from the images at ``image_paths``.

    The pytorch-fid ``InceptionV3`` wrapper expects inputs in the range (0, 1)
    and rescales them itself (``normalize_input=True`` by default), so images
    are only resized and converted to tensors here. Applying ImageNet
    mean/std normalisation beforehand would feed the network values outside
    the range it was calibrated for and distort the resulting FID.

    Args:
        image_paths: paths of the images to encode. Unreadable files are
            reported on stdout and skipped.
        batch_size: number of images per forward pass.
        dims: feature dimensionality; must be a key of
            ``InceptionV3.BLOCK_INDEX_BY_DIM``.

    Returns:
        ``numpy`` array of shape ``(num_loaded_images, dims)``.

    Raises:
        ValueError: if no image could be loaded.
    """
    from pytorch_fid.inception import InceptionV3

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)
    model.eval()

    # Resize so images of different sizes can be batched; values stay in [0, 1].
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
    ])

    all_features = []

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Extracting features for FID"):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []

        for path in batch_paths:
            try:
                with Image.open(path) as handle:
                    batch_images.append(transform(handle.convert('RGB')).unsqueeze(0))
            except Exception as e:
                print(f"Error loading {path}: {e}")
                continue

        if not batch_images:
            continue

        batch_images = torch.cat(batch_images).to(device)

        with torch.no_grad():
            pred = model(batch_images)[0]
            # Blocks other than the final one return spatial feature maps;
            # average them down to one vector per image.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))
            features = pred.squeeze(3).squeeze(2).cpu().numpy()
            all_features.append(features)

    if not all_features:
        raise ValueError("No features extracted - check your image paths")

    return np.concatenate(all_features)

def calculate_fid(real_features, generated_features, eps=1e-6):
    """
    Fréchet distance between two sets of feature vectors.

    Computes ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))`` where
    ``mu`` and ``S`` are the mean and covariance of each feature set.

    Args:
        real_features: array of shape ``(n_real, dims)``.
        generated_features: array of shape ``(n_generated, dims)``.
        eps: value added to the diagonal of both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: if an input is not 2-D, has fewer than two rows (the
            covariance is undefined), or the feature dimensions differ.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    generated_features = np.asarray(generated_features, dtype=np.float64)

    for label, feats in (("real_features", real_features), ("generated_features", generated_features)):
        if feats.ndim != 2:
            raise ValueError(f"{label} must be 2-D (samples x dims), got shape {feats.shape}")
        if feats.shape[0] < 2:
            raise ValueError(f"{label} needs at least 2 samples to estimate a covariance")
    if real_features.shape[1] != generated_features.shape[1]:
        raise ValueError(
            f"Feature dimensions differ: {real_features.shape[1]} vs {generated_features.shape[1]}"
        )

    mu1 = real_features.mean(axis=0)
    mu2 = generated_features.mean(axis=0)
    # atleast_2d keeps the covariances as matrices when dims == 1.
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(generated_features, rowvar=False))

    diff = mu1 - mu2

    # Called without the `disp` flag so the return value is the matrix itself.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # A (near-)singular product can yield inf/nan; retry with a small ridge.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can introduce a small imaginary component.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)

# Calculate both metrics
def calculate_metrics(real_image_paths, generated_paths, captions):
    """
    Compute the FID between real and generated images and the CLIP score of the generated ones.

    FID uses the Inception features of both image sets; the CLIP score uses a
    freshly loaded ``ViT-B/32`` CLIP model on CUDA when available, else CPU.

    Returns:
        ``(fid_value, clip_score)``. Per-image CLIP similarities are computed
        but not returned.
    """
    # --- FID ---
    print("Extracting real image features for FID...")
    feats_real = extract_features_inception(real_image_paths)

    print("Extracting generated image features for FID...")
    feats_generated = extract_features_inception(generated_paths)

    print("Calculating FID score...")
    fid_value = calculate_fid(feats_real, feats_generated)

    # --- CLIP score ---
    print("Calculating CLIP Score...")
    clip_device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=clip_device)
    clip_score, _per_image = calculate_clip_score(
        clip_model, clip_preprocess, generated_paths, captions
    )

    return fid_value, clip_score

# Main validation function
def main():
    """
    Run the validation pipeline end to end.

    Reads the validation CSV, generates one image per caption with the
    LoRA-augmented pipeline, computes FID and CLIP score, writes
    ``validation_results.csv`` and ``per_image_scores.csv`` and prints a
    summary.

    Raises:
        FileNotFoundError: if any real validation image listed in the CSV is
            missing from the image folder.
    """
    # Configuration
    csv_file = "/kaggle/input/fashion-dataset/val_fixed.csv"
    image_folder = "/kaggle/input/fashion-dataset/val_images"
    output_dir = "generated_val_images"

    # Setup device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load dataset
    dataset = ValidationDataset(csv_file, image_folder)

    # Read names and captions straight from the table; indexing the dataset
    # would open and decode every image only to throw it away.
    image_names = dataset.data['image'].tolist()
    all_captions = dataset.data['caption'].tolist()
    real_image_paths = [os.path.join(image_folder, name) for name in image_names]

    # Fail fast if any real image is missing, before spending time on generation.
    missing_paths = [path for path in real_image_paths if not os.path.exists(path)]
    if missing_paths:
        raise FileNotFoundError(
            f"{len(missing_paths)} real images not found, e.g. {missing_paths[:5]}"
        )

    # Load LoRA model
    print("Loading LoRA model...")
    pipe = load_model_with_lora()

    # Generate images
    print("Generating validation images...")
    generated_paths = generate_images(pipe, all_captions, output_dir)

    # Calculate metrics
    fid_value, clip_score = calculate_metrics(real_image_paths, generated_paths, all_captions)

    # Save results
    results = {
        'FID': fid_value,
        'CLIP_Score': clip_score,
        'num_samples': len(dataset)
    }

    with open('validation_results.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Metric', 'Value'])
        for key, value in results.items():
            writer.writerow([key, value])

    # Save the mapping from each validation sample to its generated file.
    with open('per_image_scores.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['image_name', 'caption', 'generated_image'])
        for i, (name, caption) in enumerate(zip(image_names, all_captions)):
            writer.writerow([name, caption, f"generated_{i:05d}.png"])

    print("\n" + "="*50)
    print("VALIDATION RESULTS:")
    print("="*50)
    print(f"FID Score: {fid_value:.4f}")
    print(f"CLIP Score: {clip_score:.4f}")
    print(f"Number of samples: {len(dataset)}")
    print("="*50)

    # Interpretation
    print("\nINTERPRETATION:")
    print("- Lower FID is better (0 = perfect)")
    print("- Higher CLIP Score is better (1.0 = perfect alignment)")
    print("- Typical CLIP Scores range from 0.2 to 0.4 for good models")

if __name__ == "__main__":
    main()