In [None]:
import datasets
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
import matplotlib as mpl
import math
import textwrap
import pandas as pd
import os
import time
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import re, time, json, random, gc, functools
from dataclasses import dataclass, asdict
from pathlib import Path
from torchmetrics.functional.multimodal import clip_score
from functools import partial
from PIL import Image
from typing import Dict, List, Tuple, Iterable, Any, Optional
import pathlib, importlib
from tqdm.auto import tqdm
from diffusers.optimization import get_scheduler

In [None]:
# plot style from IKT215
def set_mpl_params(dpi: int = 450, figsize: Tuple[int, int] = (9, 6), grid: bool = True, font_size: int = 12, font_family: str = 'serif') -> None:
    """Apply the notebook-wide matplotlib defaults in a single call.

    Args:
        dpi: Value written to ``figure.dpi``.
        figsize: Value written to ``figure.figsize`` (width, height).
        grid: Value written to ``axes.grid``.
        font_size: Value written to ``font.size``.
        font_family: Value written to ``font.family``.
    """
    overrides = {
        'figure.dpi': dpi,
        'figure.figsize': figsize,
        'axes.grid': grid,
        'font.size': font_size,
        'font.family': font_family,
    }
    mpl.rcParams.update(overrides)

In [None]:
# Load the captioned cartoon dataset and keep a reproducible random subset of it.
dataset = load_dataset("Norod78/cartoon-blip-captions", split = "train")
# Never ask for more rows than exist: random.sample raises ValueError in that case.
num_samples = min(1500, len(dataset))
# Use a dedicated seeded RNG. The global seed is only set in a later cell, so sampling
# from the module-level `random` here would give a different subset on every kernel restart.
indices = random.Random(42).sample(range(len(dataset)), num_samples)
dataset = dataset.select(indices)

In [None]:
def show_images(dataset, samples: int = 15, cols: int = 4) -> None:
    """Plot the first ``samples`` images of ``dataset`` in a grid, each titled with its caption.

    Args:
        dataset: Indexable collection whose items expose ``'image'`` and ``'text'`` keys.
        samples: Maximum number of images to draw; clamped to ``len(dataset)``.
        cols: Number of grid columns.
    """
    set_mpl_params(figsize = (15, 15))
    # Clamp so a small dataset does not raise IndexError.
    samples = min(samples, len(dataset))
    if samples <= 0:
        return
    rows = math.ceil(samples / cols)
    for i in range(samples):
        # Fetch the row once; indexing twice would decode the image twice.
        item = dataset[i]
        wrapped = "\n".join(textwrap.wrap(item['text'], width = 40))
        plt.subplot(rows, cols, i + 1)
        plt.imshow(item['image'])
        plt.axis('off')
        plt.title(wrapped, fontsize = 8)

    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first captioned images of the sampled subset (15 images in a 4-column grid).
show_images(dataset, samples = 15, cols = 4)

In [None]:
# Sanity check of the sampled subset: summary, feature schema and row count.
for summary in (dataset, dataset.features, len(dataset)):
    print(summary)
if len(dataset) > 500:
    print("good to go")

In [None]:
def setup_device() -> tuple[torch.device, bool]:
    """Pick the compute device and report whether more than one GPU is visible.

    Prints one line per CUDA device (name and total memory in GB).

    Returns:
        ``(torch.device("cuda:0"), n_gpus > 1)`` when CUDA is available,
        otherwise ``(torch.device("cpu"), False)``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu"), False

    gpu_count = torch.cuda.device_count()
    for index in range(gpu_count):
        props = torch.cuda.get_device_properties(index)
        print(f"GPU {index}: {props.name} ({props.total_memory / 1e9:.1f}GB)")
    return torch.device("cuda:0"), gpu_count > 1

In [None]:
def seed_everything(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, all CUDA devices).

    Args:
        seed: Seed value passed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [None]:
# Seed every RNG and force deterministic cuDNN kernels, then pick the compute device.
seed_everything(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device, _multi = setup_device()

In [None]:
@dataclass
class TrainConfig:
    """Hyperparameters and paths shared by the LoRA training and evaluation cells.

    Raises:
        ValueError: If ``rank`` is not positive or ``dropout`` is outside ``[0, 1]``.
    """
    rank: int = 16  # LoRA rank handed to lora_to_unet
    alpha: Optional[int] = None  # LoRA scaling numerator (scale = alpha / rank); None -> 2 * rank
    lr: float = 0.0005  # AdamW learning rate
    batch_size: int = 4  # DataLoader batch size
    grad_accum: int = 1  # NOTE(review): not read by the training loop in this notebook
    num_epochs: int = 10  # number of passes over the training dataloader
    num_workers: int = 4  # DataLoader worker processes
    caption_key: str = "text"  # dataset column holding the caption
    image_key: str = "image"  # dataset column holding the image
    image_size: int = 512  # images are resized to image_size x image_size
    save_dir: str = "misc"  # default output directory for checkpoints and generated images
    save_every_steps: int = 0  # NOTE(review): not read here; train_LoRA uses its own save_every argument
    target_text_encoder: bool = False  # NOTE(review): not read by any code in this notebook
    max_train_samples: int = 10000  # cap on the dataset length; a falsy value disables the cap
    prompts_for_eval: int = 50  # number of prompts taken per category during evaluation
    guidance_scale: float = 7.5  # passed to the generation pipeline
    num_inference_steps: int = 30  # passed to the generation pipeline
    sd_model: str = "runwayml/stable-diffusion-v1-5"  # model id given to from_pretrained
    prompts: str = "sdlora_prompts.json"  # JSON file mapping category -> list of prompts
    # Declared explicitly because setup_training reads config.dropout (falling back to 0.1 if absent).
    dropout: float = 0.1

    def __post_init__(self):
        # rank is a divisor in the LoRA scale, so reject non-positive values up front.
        if self.rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {self.rank}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be within [0, 1], got {self.dropout}")
        if self.alpha is None:
            self.alpha = self.rank * 2

In [None]:
class LoRALayer(nn.Module):
    def __init__(self, original_layer: nn.Module, rank: int, alpha: int, dropout: float = 0.0, seed: int = 42) -> None:
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha
        self.scale = alpha / rank
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
        for p in self.original_layer.parameters(): p.requires_grad_(False)
        if hasattr(original_layer, "weight"):
            dtype = original_layer.weight.dtype
            device = original_layer.weight.device
        else:
            dtype = torch.float32
            device = next(original_layer.parameters(), torch.zeros(1)).device
        g = torch.Generator().manual_seed(seed)
        if isinstance(self.original_layer, nn.Linear):
            in_features = self.original_layer.in_features; out_features = self.original_layer.out_features
            self.lora_A = nn.Linear(in_features, rank, bias = False)
            self.lora_B = nn.Linear(rank, out_features, bias = False)
            nn.init.kaiming_uniform_(self.lora_A.weight, a = np.sqrt(5), generator = g)
            nn.init.zeros_(self.lora_B.weight)
            self.lora_A.to(device, dtype = dtype); self.lora_B.to(device, dtype = dtype)
        elif isinstance(self.original_layer, nn.Conv2d):
            self.lora_A = nn.Conv2d(self.original_layer.in_channels, rank, self.original_layer.kernel_size, bias = False)
            self.lora_B = nn.Conv2d(rank, self.original_layer.out_channels, 1, stride = self.original_layer.stride, padding = self.original_layer.padding, dilation = self.original_layer.dilation, groups = self.original_layer.groups, bias = False)
            nn.init.kaiming_uniform_(self.lora_A.weight, a = np.sqrt(5), generator = g)
            nn.init.zeros_(self.lora_B.weight)
            self.lora_A.to(device, dtype = dtype); self.lora_B.to(device, dtype = dtype)
        else:
            raise TypeError(f"Unsupported layer type: {type(original_layer)}")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = self.original_layer(x); lora = self.lora_B(self.lora_A(self.dropout(x)))
        return base + lora * self.scale

In [None]:
def lora_to_unet(unet: nn.Module, rank: int, alpha: int, dropout: float = 0.1, target_modules: list[str] | None = None) -> dict[str, LoRALayer]:
    """Replace matching ``nn.Linear``/``nn.Conv2d`` submodules of ``unet`` with ``LoRALayer`` wrappers.

    A submodule matches when its dotted path ends with one of ``target_modules``.
    Existing ``LoRALayer`` children are left untouched and not descended into.

    Args:
        unet: Module tree to modify in place.
        rank: LoRA rank for every wrapper.
        alpha: LoRA alpha for every wrapper.
        dropout: Dropout probability for every wrapper.
        target_modules: Path suffixes to wrap; ``None`` selects the built-in default list.

    Returns:
        Mapping from dotted module path to the wrapper installed at that path.
    """
    if target_modules is None:
        suffixes = ("to_k", "to_q", "to_v", "to_out.0", "ff.net.0", "ff.net.2", "attn1.to_q", "attn1.to_v")
    else:
        suffixes = tuple(target_modules)
    injected: dict[str, LoRALayer] = {}

    def visit(parent: nn.Module, path: str = "") -> None:
        for child_name, child in parent.named_children():
            if isinstance(child, LoRALayer):
                continue
            qualified = f"{path}.{child_name}" if path else child_name
            is_target = qualified.endswith(suffixes) and isinstance(child, (nn.Linear, nn.Conv2d))
            if not is_target:
                # Not a wrap target: keep searching below it.
                visit(child, qualified)
                continue
            adapter = LoRALayer(child, rank, alpha, dropout = dropout)
            setattr(parent, child_name, adapter)
            adapter._is_lora = True
            adapter._lora_target_name = qualified
            injected[qualified] = adapter

    visit(unet)
    return injected

In [None]:
@dataclass
class TransformData(Dataset):
    """Torch dataset that turns raw image/caption rows into training examples.

    Each item is a dict with ``"pixel_values"`` (the resized, normalised image tensor)
    and ``"text"`` (the caption as a ``str``).
    """
    dataset: datasets.Dataset
    config: TrainConfig

    def __post_init__(self):
        side = self.config.image_size
        steps = [
            T.Resize((side, side)),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ]
        self.image_transform = T.Compose(steps)

    def __len__(self):
        total = len(self.dataset)
        cap = self.config.max_train_samples
        # A falsy cap (0 / None) means "use everything".
        if not cap:
            return total
        return min(total, cap)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.dataset[idx]

        picture = record[self.config.image_key]
        if not isinstance(picture, Image.Image):
            picture = Image.fromarray(picture)
        pixels = self.image_transform(picture.convert("RGB"))

        text = record[self.config.caption_key]
        text = text if isinstance(text, str) else str(text)

        return {"pixel_values": pixels, "text": text}

In [None]:
def setup_training(config: TrainConfig) -> dict[str, Any]:
    """Load the Stable Diffusion components and inject LoRA adapters into the UNet.

    Args:
        config: Training configuration; ``sd_model``, ``rank``, ``alpha`` and ``dropout``
            are read with fallbacks when an attribute is missing.

    Returns:
        Dict with keys ``device``, ``dtype``, ``unet``, ``vae``, ``tokenizer``,
        ``text_encoder``, ``scheduler`` and ``lora_layers``.
    """
    device, _ = setup_device()
    dtype = torch.float32
    model_id = getattr(config, "sd_model", "runwayml/stable-diffusion-v1-5")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype)
    scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)

    # Freeze every pretrained weight, including UNet layers that do not get a LoRA wrapper.
    # Otherwise backward() stores gradients for those layers that the optimizer
    # (which only receives the LoRA parameters) never uses or clears.
    for frozen in (unet, text_encoder, vae):
        frozen.requires_grad_(False)

    # The LoRA parameters are created after the freeze and therefore stay trainable.
    lora_layers = lora_to_unet(
        unet,
        rank=getattr(config, "rank", 16),
        alpha=getattr(config, "alpha", 32),
        dropout=getattr(config, "dropout", 0.1),
    )
    unet = unet.to(device)
    text_encoder = text_encoder.to(device)
    vae = vae.to(device)

    text_encoder.eval()
    vae.eval()

    return {
        "device": device,
        "dtype": dtype,
        "unet": unet,
        "vae": vae,
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "scheduler": scheduler,
        "lora_layers": lora_layers
    }


In [None]:
def collate_fn(batch, tokenizer):
    """Stack a list of dataset examples into one batch and tokenize their captions.

    Args:
        batch: Sequence of dicts with ``"pixel_values"`` (tensor) and ``"text"`` (caption).
        tokenizer: Callable tokenizer exposing ``model_max_length``; captions are padded
            and truncated to that length.

    Returns:
        Dict with ``pixel_values``, ``input_ids``, ``attention_mask`` and the raw ``captions``.
    """
    stacked = torch.stack([example["pixel_values"] for example in batch])
    captions = [example["text"] for example in batch]

    encoded = tokenizer(
        captions,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    return dict(
        pixel_values=stacked,
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        captions=captions,
    )


In [None]:
def extract_lora_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:
    """Collect the LoRA weights (and alpha) of every ``LoRALayer`` inside ``model``.

    Returns:
        Dict keyed by ``"<module path>.lora_A.weight"``, ``"<module path>.lora_B.weight"``
        and ``"<module path>.alpha"``; the weight tensors are detached CPU copies.
    """
    collected = {}
    adapters = ((path, mod) for path, mod in model.named_modules() if isinstance(mod, LoRALayer))
    for path, adapter in adapters:
        collected.update({
            f"{path}.lora_A.weight": adapter.lora_A.weight.detach().cpu(),
            f"{path}.lora_B.weight": adapter.lora_B.weight.detach().cpu(),
            f"{path}.alpha": torch.tensor(adapter.alpha),
        })
    return collected

In [None]:
def no_grad(func):
    """Decorator that runs ``func`` with autograd disabled, preserving its metadata."""
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            result = func(*args, **kwargs)
        return result
    return functools.wraps(func)(wrapper)

In [None]:
@no_grad
def encode_inputs(vae, text_encoder, pixel_values, input_ids, attention_mask):
    """Encode a batch into scaled VAE latents and text-encoder hidden states (no gradients).

    Args:
        vae: Autoencoder exposing ``encode(...).latent_dist.sample()``.
        text_encoder: Text model whose first output is the hidden-state tensor.
        pixel_values: Image batch fed to the VAE.
        input_ids: Token ids fed to the text encoder.
        attention_mask: Token mask; only forwarded when the text encoder's config
            sets ``use_attention_mask``.

    Returns:
        ``(latents, encoder_hidden_states)``.
    """
    # Prefer the factor stored in the VAE config; 0.18215 is the previous hard-coded fallback.
    scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    latents = vae.encode(pixel_values).latent_dist.sample() * scaling_factor
    # Mirror the diffusers StableDiffusionPipeline prompt encoding: the mask is passed only when
    # the text encoder's config opts in. Always passing it makes the training-time conditioning
    # differ from what the pipeline feeds the UNet at inference.
    use_mask = bool(getattr(getattr(text_encoder, "config", None), "use_attention_mask", False))
    encoder_hidden_states = text_encoder(input_ids, attention_mask=attention_mask if use_mask else None)[0]
    return latents, encoder_hidden_states

In [None]:
def train_LoRA(config: TrainConfig, train_dataloader: DataLoader, setup: dict[str, Any], output_dir: str | None = None, save_every: int = 500) -> tuple[list[float], dict[str, Any]]:
    """Train the injected LoRA adapters with an MSE loss against the scheduler's target.

    Args:
        config: Training configuration (``rank``, ``lr``, ``num_epochs``, ``save_dir`` are read).
        train_dataloader: Yields batches with ``pixel_values``, ``input_ids`` and ``attention_mask``.
        setup: Dict produced by ``setup_training``.
        output_dir: Where checkpoints and loss plots go; defaults to ``config.save_dir``.
        save_every: Save an intermediate checkpoint every this many steps; ``0`` disables it.

    Returns:
        ``(loss_history, stats)`` where ``loss_history`` holds one loss per step and ``stats``
        holds ``rank``, ``train_time_s``, ``peak_mem_gb``, ``file_size_mb`` and ``final_loss``.

    Raises:
        ValueError: If the scheduler's prediction type is neither ``epsilon`` nor ``v_prediction``.
    """
    if output_dir is None:
        output_dir = config.save_dir

    device = setup["device"]
    # Cast inputs to the dtype the models were loaded with (float32 if not recorded).
    dtype = setup.get("dtype", torch.float32)
    unet, vae, text_encoder = setup["unet"], setup["vae"], setup["text_encoder"]
    scheduler, lora_layers = setup["scheduler"], setup["lora_layers"]

    set_mpl_params(); seed_everything(42); os.makedirs(output_dir, exist_ok = True)
    rank_dir = os.path.join(output_dir, f"rank_{config.rank}")
    os.makedirs(rank_dir, exist_ok = True)

    lora_params: list[nn.Parameter] = []
    for lora_layer in lora_layers.values():
        lora_params.extend(lora_layer.lora_A.parameters())
        lora_params.extend(lora_layer.lora_B.parameters())

    optimizer = torch.optim.AdamW(params = lora_params, lr = config.lr, weight_decay = 0.0, betas = (0.9, 0.999), eps = 1e-8)
    total_steps = len(train_dataloader) * config.num_epochs
    warmup_steps = int(total_steps * 0.01)
    lr_scheduler = get_scheduler(name = "constant_with_warmup", optimizer = optimizer, num_warmup_steps = warmup_steps, num_training_steps = total_steps)

    # Scheduler settings are loop-invariant, so resolve and validate them once.
    scheduler_config = getattr(scheduler, "config", None)
    pred_type = getattr(scheduler_config, "prediction_type", "epsilon")
    if pred_type not in ("epsilon", "v_prediction"):
        raise ValueError(f"Unknown scheduler prediction type: {pred_type}")
    # Read through scheduler.config; direct attribute access on the scheduler is deprecated in diffusers.
    num_train_timesteps = getattr(scheduler_config, "num_train_timesteps", None)
    if num_train_timesteps is None:
        num_train_timesteps = scheduler.num_train_timesteps

    global_step, loss_history, epoch_times, epoch_memory = 0, [], [], []
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
    start_time = time.perf_counter()

    for epoch in range(config.num_epochs):
        unet.train()
        epoch_start = time.perf_counter()
        epoch_loss = 0.0
        ema_loss = None
        bar = tqdm(train_dataloader, desc = f"Epoch {epoch + 1}/{config.num_epochs}")
        for batch in bar:
            pixel_values = batch["pixel_values"].to(device, dtype = dtype)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            latents, encoder_hidden_states = encode_inputs(vae, text_encoder, pixel_values, input_ids, attention_mask)

            bsz = latents.shape[0]
            timesteps = torch.randint(0, num_train_timesteps, (bsz,), device = device).long()
            noise = torch.randn_like(latents)
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            if pred_type == "epsilon":
                target = noise
            else:
                target = scheduler.get_velocity(latents, noise, timesteps)
            loss = F.mse_loss(model_pred.float(), target.float())

            optimizer.zero_grad(set_to_none = True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()
            lr_scheduler.step()

            # Read the scalar once: every .item() call forces a device-to-host synchronisation.
            loss_value = loss.item()
            global_step += 1
            loss_history.append(loss_value)
            epoch_loss += loss_value
            ema_loss = loss_value if ema_loss is None else 0.9 * ema_loss + 0.1 * loss_value
            bar.set_postfix({"loss": f"{loss_value:.4f}", "ema": f"{ema_loss:.4f}"})
            # save_every == 0 disables intermediate checkpoints instead of dividing by zero.
            if save_every and global_step % save_every == 0:
                save_path = os.path.join(output_dir, f"lora_step_{global_step}_r{config.rank}.pt")
                torch.save(extract_lora_state_dict(unet), save_path)

        duration_min = (time.perf_counter() - epoch_start) / 60
        epoch_times.append(duration_min)
        peak_mem = torch.cuda.max_memory_allocated(device) / 1e9 if torch.cuda.is_available() else 0.0
        epoch_memory.append(peak_mem)
        # Guard against an empty dataloader.
        avg_loss = epoch_loss / max(len(train_dataloader), 1)
        print(f"\nEpoch {epoch + 1}: avg_loss = {avg_loss:.4f} | time = {duration_min:.2f}m | peak_mem = {peak_mem:.2f}GB")
        if (epoch + 1) % 2 == 0:
            plt.figure(figsize = (9, 6))
            plt.plot(loss_history)
            plt.title(f"Training loss at rank {config.rank}")
            plt.xlabel("Steps")
            plt.ylabel("MSE loss")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(f"{rank_dir}/loss_epoch_{epoch + 1}.png", dpi = 300)
            plt.close()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)

    save_path = os.path.join(output_dir, f"lora_final_r{config.rank}.pt")
    torch.save(extract_lora_state_dict(unet), save_path)
    elapsed_time = time.perf_counter() - start_time
    file_size_mb = os.path.getsize(save_path) / 1024 ** 2
    stats = {"rank": config.rank, "train_time_s": elapsed_time, "peak_mem_gb": max(epoch_memory) if epoch_memory else 0, "file_size_mb": file_size_mb, "final_loss": loss_history[-1] if loss_history else None}
    print(f"Rank {config.rank} Summary → time = {elapsed_time:.1f}s | peak_mem = {stats['peak_mem_gb']:.2f}GB | file = {file_size_mb:.2f}MB")
    return loss_history, stats

In [None]:
def clean_cuda():
    """Run the garbage collector and, when CUDA is available, release cached GPU memory.

    On CPU-only machines only the garbage collection runs; the CUDA calls are skipped
    because ``torch.cuda.ipc_collect()`` raises when no CUDA runtime can be initialised.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

In [None]:
# Train one LoRA adapter per rank and collect the per-rank loss curves and run statistics.
config = TrainConfig()
train_dataset = TransformData(dataset, config)
all_stats = {}
all_losses = {}
clean_cuda()

for rank in [16, 64]:
    config = TrainConfig(rank = rank)
    setup = setup_training(config)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size = config.batch_size,
        shuffle = True,
        num_workers = config.num_workers,
        # partial binds this iteration's tokenizer now, instead of a lambda that looks up
        # the global `setup` each time a batch is collated.
        collate_fn = partial(collate_fn, tokenizer = setup["tokenizer"]),
    )

    loss_history, stats = train_LoRA(config, train_dataloader, setup)
    all_stats[rank] = stats
    all_losses[rank] = loss_history

    # Drop this rank's models before the next iteration loads a fresh set; otherwise
    # both copies are alive at the same time while the new one is being created.
    del setup, train_dataloader
    clean_cuda()

print("\nTraining summary by rank:")
for r, s in all_stats.items():
    # final_loss is None when a run produced no training steps.
    final_loss_text = f"{s['final_loss']:.4f}" if s['final_loss'] is not None else "n/a"
    print(f"Rank {r}: final_loss={final_loss_text}, time={s['train_time_s']:.1f}s, file_size={s['file_size_mb']:.2f}MB")

In [None]:
# Plot a moving average of the per-step training loss for each rank.
set_mpl_params()
plt.figure(figsize = (8, 5))
colors = ["#0d47a1", "#64b5f6"]
smooth_window = 50  # moving-average width in steps

for (rank, losses), c in zip(sorted(all_losses.items()), colors):
    if not losses:
        # np.convolve rejects empty input; nothing to draw for this rank.
        continue
    # Clamp the window: in "valid" mode a kernel longer than the series does not
    # yield a moving average of the losses.
    window = min(smooth_window, len(losses))
    smooth = np.convolve(losses, np.ones(window) / window, mode = "valid")
    plt.plot(smooth, label = f"Rank {rank}", linewidth = 2, color = c)

plt.title("Smoothed loss curves for different ranks")
plt.xlabel("Steps")
plt.ylabel("Smoothed loss")
plt.legend(title = "Rank")
plt.grid(True, linestyle = "--", alpha = 0.7)
plt.tight_layout()
# Make sure the target folder exists even if this cell runs without the training cell.
os.makedirs("misc", exist_ok = True)
plt.savefig("misc/loss_curves_smooth.png", dpi = 150)
plt.show()

In [None]:
# Compare training time, peak memory and checkpoint size across ranks on one axis.
set_mpl_params()
ranks = sorted(all_stats.keys())
times = [all_stats[r]["train_time_s"] / 60 for r in ranks]
mems = [all_stats[r]["peak_mem_gb"] for r in ranks]
sizes = [all_stats[r]["file_size_mb"] for r in ranks]

fig, ax = plt.subplots(figsize = (8, 5))
ax.plot(ranks, times, marker = "o", linewidth = 2, color = "#1f77b4", label = "Time to train (min)")
ax.plot(ranks, mems, marker = "s", linewidth = 2, color = "#2980b9", label = "Peak memory (GB)")
ax.plot(ranks, sizes, marker = "D", linewidth = 2, color = "#5dade2", label = "File size (MB)")

ax.set_title("Performance of each rank")
ax.set_xlabel("Rank")
ax.set_ylabel("Given value")
ax.set_xticks(ranks)
ax.legend()
ax.grid(True, linestyle = "--", alpha = 0.7)
fig.tight_layout()
fig.savefig("misc/performance_vs_rank.png", dpi = 150)
plt.show()

In [None]:
def load_lora_weights(unet: UNet2DConditionModel, config: TrainConfig, lora_weights_path: str) -> UNet2DConditionModel:
    """Inject LoRA layers into ``unet`` and fill them from a saved checkpoint.

    Args:
        unet: Model to modify in place; expected not to contain LoRA layers yet.
        config: Supplies ``rank`` and ``alpha`` for the injected layers.
        lora_weights_path: File written by ``torch.save(extract_lora_state_dict(...))``.

    Returns:
        The same ``unet`` object with the adapters installed.

    Raises:
        ValueError: If not a single LoRA tensor from the checkpoint could be applied.
    """
    import warnings

    state_dict = torch.load(lora_weights_path, map_location = "cpu", weights_only = True)
    # dropout = 0.0: the weights are loaded for inference. New modules start in training
    # mode, so the default 0.1 would stay active unless every caller calls .eval() afterwards.
    lora_layers = lora_to_unet(unet, rank = config.rank, alpha = config.alpha, dropout = 0.0)
    loaded, missing = 0, []
    for name, layer in lora_layers.items():
        for part in ("lora_A", "lora_B"):
            key = f"{name}.{part}.weight"
            if key in state_dict:
                getattr(layer, part).weight.data.copy_(state_dict[key])
                loaded += 1
            else:
                missing.append(key)
        # Put the new wrapper in the same train/eval mode as the model it was inserted into.
        layer.train(unet.training)
    if loaded == 0:
        # With nothing loaded, lora_B stays zero and the model silently behaves like the base model.
        raise ValueError(f"No LoRA weights from {lora_weights_path} matched the model (layers injected: {len(lora_layers)})")
    if missing:
        warnings.warn(f"{len(missing)} LoRA tensors missing from {lora_weights_path}, e.g. {missing[0]}")
    return unet

In [None]:
def generate_images(config: TrainConfig, lora_paths: list[tuple[int, str]], categories: list[str] = ["normal", "regular_abstract"], load_lora: bool = True) -> None:
    """Render the evaluation prompts to PNG files, with or without LoRA adapters.

    Output goes to ``<save_dir>/generated_rank_<rank>_<category>_lora`` when ``load_lora``
    is true, otherwise to ``<save_dir>/generated_baseline_<category>``.

    Args:
        config: Supplies the prompt file, model id, sampling settings and ``save_dir``.
        lora_paths: ``(rank, checkpoint path)`` pairs.
        categories: Keys of the prompt JSON to render (the default list is never mutated).
        load_lora: When false, the plain base model is rendered once.
    """
    seed_everything(42); device, dtype = setup_device()[0], torch.float32
    with open(config.prompts, "r") as f: data = json.load(f)
    prompt_sets = [(category, data[category][: config.prompts_for_eval]) for category in categories]
    # The baseline does not depend on the rank, so one pass is enough; looping over every
    # entry would rewrite the same files.
    jobs = lora_paths if load_lora else lora_paths[:1]
    for rank, lora_path in jobs:
        # Build the pipeline once per rank and reuse it for all categories.
        pipe = StableDiffusionPipeline.from_pretrained(config.sd_model, safety_checker = None, requires_safety_checker = False, torch_dtype = dtype).to(device)
        if load_lora:
            cfg = TrainConfig(rank = rank, alpha = rank * 2)
            pipe.unet = load_lora_weights(pipe.unet, cfg, lora_path)
        pipe.unet.eval(); pipe.text_encoder.eval(); pipe.vae.eval()
        for category, prompts in prompt_sets:
            out_dir = os.path.join(config.save_dir, f"generated_rank_{rank}_{category}_lora" if load_lora else f"generated_baseline_{category}"); os.makedirs(out_dir, exist_ok = True)
            print(f"\nGenerating {len(prompts)} images to {out_dir}")
            for i, prompt in enumerate(prompts):
                # A per-prompt seed makes each image independent of generation order.
                g = torch.Generator(device = device).manual_seed(1234 + i)
                image = pipe(prompt, num_inference_steps = config.num_inference_steps, guidance_scale = config.guidance_scale, generator = g).images[0]
                name = (f"rank{rank}_{category}_{i:03d}.png" if load_lora else f"baseline_{i:03d}.png"); image.save(os.path.join(out_dir, name))
            print(f"Saved {len(prompts)} images to {out_dir}")
        # Release this pipeline before the next one is created so two copies are never alive together.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# Render the evaluation prompts with each trained adapter (one folder per rank and category).
generate_images(
    config,
    lora_paths = [(16, "misc/lora_final_r16.pt"), (64, "misc/lora_final_r64.pt")],
    categories = ["normal", "regular_abstract"],
    load_lora = True,
)

In [None]:
def create_comparison_image(img1: Image.Image, img2: Image.Image, prompt: str, img_size: Tuple[int, int] = (512, 512)) -> Tuple[Image.Image, Image.Image, Image.Image]:
    """Place two images side by side on a white canvas with the prompt printed underneath.

    Args:
        img1: Left image.
        img2: Right image.
        prompt: Caption text; truncated to 100 characters plus ``...`` when longer.
        img_size: Size both images are resized to before pasting.

    Returns:
        ``(resized img1, resized img2, canvas)``.
    """
    from PIL import Image, ImageDraw, ImageFont
    img1 = img1.resize(img_size)
    img2 = img2.resize(img_size)
    # 20 px gap between the images, 50 px strip below them for the caption.
    canvas_width = img_size[0] * 2 + 20
    canvas_height = img_size[1] + 50
    canvas = Image.new("RGB", (canvas_width, canvas_height), "white")
    canvas.paste(img1, (0, 0))
    canvas.paste(img2, (img_size[0] + 20, 0))
    draw = ImageDraw.Draw(canvas)
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except OSError:
        # Font file not available on this system; a bare except would also hide unrelated errors.
        font = ImageFont.load_default()

    text = prompt[:100] + "..." if len(prompt) > 100 else prompt
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    # Centre the caption, but never start left of the canvas when the text is wider than it.
    text_x = max(0, (canvas_width - text_width) // 2)
    draw.text((text_x, img_size[1] + 15), text, fill = "black", font = font)

    return img1, img2, canvas

In [None]:
def comparative_test(config: TrainConfig, lora_paths: list[tuple[int, str]], categories: list[str] = ["normal", "regular_abstract"], output_dir: str = "outputs") -> None:
    """Generate base-model vs. LoRA images for the same prompts and seeds, and save and show them.

    For every category and rank, writes ``base_*``, ``lora_*`` and ``comparison_*`` PNGs to
    ``<output_dir>/rank_<rank>_<category>``.

    Args:
        config: Supplies the prompt file, model id and sampling settings.
        lora_paths: ``(rank, checkpoint path)`` pairs.
        categories: Keys of the prompt JSON to evaluate (the default list is never mutated).
        output_dir: Root folder for the comparison images.
    """
    device, _ = setup_device(); seed_everything(42); set_mpl_params(); os.makedirs(output_dir, exist_ok = True)
    # Use half precision only on the GPU, and full precision on CPU.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    with open(config.prompts, "r") as f: data = json.load(f)
    for category in categories:
        prompts = data[category][: config.prompts_for_eval]
        # Baseline images depend only on the prompt and its seed, not on the rank,
        # so they are generated once per category and reused.
        base_imgs = None
        for rank, lora_path in lora_paths:
            # Create the folder up front so the final message is valid even with no prompts.
            rank_dir = os.path.join(output_dir, f"rank_{rank}_{category}"); os.makedirs(rank_dir, exist_ok = True)
            pipe = StableDiffusionPipeline.from_pretrained(config.sd_model, torch_dtype = dtype, safety_checker = None, requires_safety_checker = False).to(device)
            pipe.enable_attention_slicing()
            if base_imgs is None:
                base_imgs = []
                for i, prompt in enumerate(prompts):
                    g = torch.Generator(device = device).manual_seed(1234 + i)
                    img = pipe(prompt, num_inference_steps = config.num_inference_steps, guidance_scale = config.guidance_scale, generator = g).images[0]
                    base_imgs.append(img)
            cfg = TrainConfig(rank = rank, alpha = rank * 2)
            pipe.unet = load_lora_weights(pipe.unet, cfg, lora_path).to(device)
            # Newly inserted modules default to training mode; switch to eval so LoRA dropout is off.
            pipe.unet.eval()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            for i, (prompt, base_img) in enumerate(zip(prompts, base_imgs)):
                g = torch.Generator(device = device).manual_seed(1234 + i)
                lora_img = pipe(prompt, num_inference_steps = config.num_inference_steps, guidance_scale = config.guidance_scale, generator = g).images[0]
                base_img, lora_img, comparison = create_comparison_image(base_img, lora_img, prompt)
                # Replace characters that are invalid in file names (path separators etc.).
                fname = re.sub(r'[\\/:*?"<>|]', "_", prompt.replace(" ", "_"))[:40]
                base_img.save(f"{rank_dir}/base_{i:02d}_{fname}.png"); lora_img.save(f"{rank_dir}/lora_{i:02d}_{fname}.png"); comparison.save(f"{rank_dir}/comparison_{i:02d}_{fname}.png")
                plt.figure(figsize = (12, 6)); plt.imshow(comparison); plt.axis("off"); plt.title(f"Comparison {i+1}"); plt.show(); plt.close()
            print(f"Saved comparisons for rank {rank}, category '{category}' to {rank_dir}")
            # Free the pipeline before the next rank loads its own copy.
            del pipe
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

In [None]:
# Side-by-side comparison of the base model against both trained adapters.
config = TrainConfig()
lora_paths = [(r, f"misc/lora_final_r{r}.pt") for r in (16, 64)]
comparative_test(config, lora_paths, ["normal", "regular_abstract"], "comparisons")

In [None]:
def calculate_clip_score(config: TrainConfig, rank: int, category: str) -> list[float]:
    """Compute one CLIP score per generated image against the prompt it was generated from.

    Images are read from ``<save_dir>/generated_rank_<rank>_<category>_lora`` in sorted
    file-name order and paired positionally with the category's prompts.

    Args:
        config: Supplies the prompt file, ``prompts_for_eval`` and ``save_dir``.
        rank: LoRA rank whose output folder is scored.
        category: Prompt category key in the prompt JSON.

    Returns:
        List of per-image CLIP scores.

    Raises:
        FileNotFoundError: If prompts exist but the folder contains no PNG files.
    """
    import warnings

    clip_score_fn = partial(clip_score, model_name_or_path = "openai/clip-vit-base-patch16")
    with open(config.prompts, "r") as f: data = json.load(f)
    prompts = data[category][: config.prompts_for_eval]
    folder = Path(config.save_dir) / f"generated_rank_{rank}_{category}_lora"
    images = sorted(folder.glob("*.png"))
    if prompts and not images:
        # Fail loudly instead of returning [] and letting the caller print NaN statistics.
        raise FileNotFoundError(f"No generated images found in {folder}")
    if len(images) != len(prompts):
        # zip() below stops at the shorter sequence; make the truncation visible.
        warnings.warn(f"{len(images)} images but {len(prompts)} prompts in {folder}; scoring the first {min(len(images), len(prompts))} pairs")
    scores = []
    with torch.no_grad():
        for path, prompt in zip(images, prompts):
            # Close the file handle as soon as the pixels are loaded.
            with Image.open(path) as handle:
                img = handle.convert("RGB")
            # HWC uint8 array -> 1 x C x H x W tensor.
            img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).unsqueeze(0)
            scores.append(clip_score_fn(img_tensor, prompt).item())
    return scores

In [None]:
# Report mean and standard deviation of the per-image CLIP scores for every category/rank pair.
config = TrainConfig()
ranks = [16, 64]
categories = ["normal", "regular_abstract"]

for category in categories:
    for rank in ranks:
        scores = calculate_clip_score(config, rank, category)
        mean_score, std_score = np.mean(scores), np.std(scores)
        print(f"Rank {rank} | Category {category} → Mean CLIP: {mean_score:.4f}, Std: {std_score:.4f}")