In [None]:
from google.colab import drive
# Mount Google Drive so the project folders under /content/drive/MyDrive are readable.
drive.mount(mountpoint='/content/drive')

In [None]:
from pathlib import Path
import os

In [None]:
# Project root on Google Drive; every dataset / manifest path below hangs off it.
PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
os.makedirs(PROJ, exist_ok=True)

In [None]:
# Raw SketchyCOCO folders, derived from PROJ so the project root is defined in one place.
RAW_SCENE_ROOT  = f"{PROJ}/Scene"
RAW_OBJECT_ROOT = f"{PROJ}/Object"  # used by the object-level cells further down

print("PROJ:", PROJ)
print("RAW_SCENE_ROOT:", RAW_SCENE_ROOT, "Exists?", os.path.isdir(RAW_SCENE_ROOT))
print("RAW_OBJECT_ROOT:", RAW_OBJECT_ROOT, "Exists?", os.path.isdir(RAW_OBJECT_ROOT))

print("\nSubfolders under RAW_SCENE_ROOT:")
if os.path.isdir(RAW_SCENE_ROOT):
    # os.listdir order is arbitrary; sort for a stable, readable listing.
    for name in sorted(os.listdir(RAW_SCENE_ROOT)):
        print(" -", name)
else:
    print(" !! RAW_SCENE_ROOT does not exist. Double-check the path.")

In [None]:
from pathlib import Path
import glob, shutil, os


In [None]:
# Target layout for a normalized scene split: <root>/{train,val}/{sketch,photo}.
NORM_SCENE_ROOT = f"{PROJ}/data/sketchycoco/scene"
SCENE_TRAIN_SK  = Path(NORM_SCENE_ROOT, "train", "sketch")
SCENE_TRAIN_PH  = Path(NORM_SCENE_ROOT, "train", "photo")
SCENE_VAL_SK    = Path(NORM_SCENE_ROOT, "val", "sketch")
SCENE_VAL_PH    = Path(NORM_SCENE_ROOT, "val", "photo")

for p in (SCENE_TRAIN_SK, SCENE_TRAIN_PH, SCENE_VAL_SK, SCENE_VAL_PH):
    p.mkdir(parents=True, exist_ok=True)

print("Normalized scene root:", NORM_SCENE_ROOT)

def g(pattern):
    """Return the sorted list of paths matching ``pattern``; ``**`` recurses into subfolders."""
    matches = glob.glob(pattern, recursive=True)
    matches.sort()
    return matches

In [None]:
def find_subdir(root, name_options):
    """Return the path of the first directory directly under ``root`` whose name
    equals one of ``name_options`` case-insensitively, or None if none matches.

    Candidates are checked in ``os.listdir`` order; plain files are ignored.
    """
    wanted = {opt.lower() for opt in name_options}
    return next(
        (
            os.path.join(root, entry)
            for entry in os.listdir(root)
            if entry.lower() in wanted and os.path.isdir(os.path.join(root, entry))
        ),
        None,
    )

# Locate the Sketch and GT folders: first directly under RAW_SCENE_ROOT, then (for any
# still missing) inside a nested "Scene" folder one level down.
sketch_root = find_subdir(RAW_SCENE_ROOT, ["sketch", "Sketch"])
gt_root     = find_subdir(RAW_SCENE_ROOT, ["gt", "GT"])

if None in (sketch_root, gt_root):
    nested_scene_path = find_subdir(RAW_SCENE_ROOT, ["Scene"])
    if nested_scene_path:
        print(f"Searching for sketch/gt in nested path: {nested_scene_path}")
        sketch_root = sketch_root or find_subdir(nested_scene_path, ["sketch", "Sketch"])
        gt_root = gt_root or find_subdir(nested_scene_path, ["gt", "GT"])


print("Detected sketch_root:", sketch_root)
print("Detected gt_root    :", gt_root)

if sketch_root is None or gt_root is None:
    raise RuntimeError("Could not find Sketch/GT folders under RAW_SCENE_ROOT or its immediate 'Scene' subfolder. Check names and path structure.")

# Debug aid: show what the sketch folder holds, to help diagnose the split-detection cell.
print(f"\nContents of sketch_root ({sketch_root}):")
if not (sketch_root and os.path.isdir(sketch_root)):
    print(" - sketch_root is not a valid directory or does not exist.")
else:
    for item in os.listdir(sketch_root):
        print(f" - {item}")

In [None]:
def dive_if_single_subdir(root):
    """If ``root`` contains exactly one subdirectory (e.g. a lone wrapper folder),
    return that subdirectory's path; otherwise return ``root`` unchanged.
    """
    with os.scandir(root) as entries:
        only_dirs = [entry for entry in entries if entry.is_dir()]
    if len(only_dirs) != 1:
        return root
    (single,) = only_dirs
    print(f"Found single subdir '{single.name}' under {root}, diving into it.")
    return single.path

# Skip a lone wrapper folder (if any) so the split folders sit directly under these roots.
split_sketch_root, split_gt_root = (dive_if_single_subdir(r) for r in (sketch_root, gt_root))

print("\nUsing these as split roots:")
print("  split_sketch_root:", split_sketch_root)
print("  split_gt_root    :", split_gt_root)

In [None]:
# NOTE: repeats the path setup and glob helper of the earlier setup cell (identical values).
NORM_SCENE_ROOT = f"{PROJ}/data/sketchycoco/scene"
SCENE_TRAIN_SK, SCENE_TRAIN_PH = (Path(NORM_SCENE_ROOT, "train", kind) for kind in ("sketch", "photo"))
SCENE_VAL_SK,   SCENE_VAL_PH   = (Path(NORM_SCENE_ROOT, "val", kind) for kind in ("sketch", "photo"))

for p in [SCENE_TRAIN_SK, SCENE_TRAIN_PH, SCENE_VAL_SK, SCENE_VAL_PH]:
    p.mkdir(exist_ok=True, parents=True)

def g(pattern):
    """Sorted paths matching ``pattern``; ``**`` recurses into subdirectories."""
    return sorted(glob.glob(pattern, recursive=True))

In [None]:
# Scene split detection + path collection. Nothing is copied into the normalized
# folders; the path lists built here are consumed directly by the integrity check
# and manifest cells. (The earlier copy-based version of this cell was removed.)

def _dir_names(root):
    """Sorted names of the immediate subdirectories of ``root``."""
    with os.scandir(root) as it:
        return sorted(entry.name for entry in it if entry.is_dir())


def _collect(root, splits):
    """All paths matching ``<root>/<split>/**/*.*`` for each split, in split order."""
    found = []
    for split_name in splits:
        found.extend(g(f"{root}/{split_name}/**/*.*"))
    return found


# 1. IDENTIFY SPLITS present under both the sketch and GT roots
sketch_splits = _dir_names(split_sketch_root)
gt_splits     = _dir_names(split_gt_root)

print("\nSketch splits under", split_sketch_root, ":", sketch_splits)
print("GT splits under", split_gt_root, ":", gt_splits)

common_splits = sorted(set(sketch_splits).intersection(gt_splits))
print("Common splits (Sketch ∩ GT):", common_splits)

# Every common split except 'val' counts as TRAIN; 'val' is VAL.
val_splits   = [name for name in common_splits if name.lower() == "val"]
train_splits = [name for name in common_splits if name not in val_splits]

print("Train splits we will use:", train_splits)
print("Val splits we will use  :", val_splits)

# 2. COLLECT FILE PATHS (glob only, no copying)
scene_sk_train  = _collect(split_sketch_root, train_splits)
scene_img_train = _collect(split_gt_root,     train_splits)
scene_sk_val    = _collect(split_sketch_root, val_splits)
scene_img_val   = _collect(split_gt_root,     val_splits)

print("\nRaw counts (Found files):")
print("  train sketches:", len(scene_sk_train))
print("  train images  :", len(scene_img_train))
print("  val sketches  :", len(scene_sk_val))
print("  val images    :", len(scene_img_val))

# 3. VERIFY PAIRS (in-memory only; no files are read or copied)
def verify_pairs_only(sketch_list, photo_list):
    """Count how many sketches have a photo with the same filename stem.

    Each sketch is counted independently, so repeated sketch stems count repeatedly.
    """
    photo_stems = {Path(p).stem for p in photo_list}
    return sum(Path(s).stem in photo_stems for s in sketch_list)

# Integrity check only: count sketch->photo matches by filename stem, without copying files.
paired_train_count = verify_pairs_only(scene_sk_train, scene_img_train)
paired_val_count   = verify_pairs_only(scene_sk_val, scene_img_val)

print("------------------------------------------------")
print("Data Integrity Check (No files were copied):")
print(f"  Valid Train Pairs: {paired_train_count} / {len(scene_sk_train)}")
print(f"  Valid Val Pairs  : {paired_val_count} / {len(scene_sk_val)}")

# The manifest cell pairs files through {stem: path} dicts, which silently keep only the
# LAST path when two photos share a stem (e.g. same name in two train splits, or .png and
# .jpg). A matching stem then pairs with an arbitrary photo, so surface collisions here.
from collections import Counter
for _label, _photos in (("train", scene_img_train), ("val", scene_img_val)):
    _dup_stems = [s for s, c in Counter(Path(p).stem for p in _photos).items() if c > 1]
    if _dup_stems:
        print(f"  WARNING: {len(_dup_stems)} duplicate photo stems in {_label} "
              f"(e.g. {_dup_stems[:3]}); only one photo per stem will be used.")
print("------------------------------------------------")

# NOTE: this notebook copies nothing into 'data/sketchycoco/scene', so those folders may be
# empty. Downstream code should use the in-memory lists (scene_sk_train, scene_img_train,
# scene_sk_val, scene_img_val) or the manifest built from them.

In [None]:
import os, json, csv, re, glob
from pathlib import Path
import pandas as pd

# Path only; the manifest file itself is written by a later cell.
SCENE_MANIFEST = os.path.join(PROJ, "data", "sketchycoco", "scene_manifest.csv")
os.makedirs(os.path.dirname(SCENE_MANIFEST), exist_ok=True)

print("Writing scene manifest to:", SCENE_MANIFEST)



In [None]:
# Look for an annotation folder (name starting with "annot", case-insensitive) anywhere
# under RAW_SCENE_ROOT and load any captions it contains into `scene_captions`, keyed by
# filename stem. If none is found, captions stay empty and the manifest cell falls back
# to a generic caption.
ann_dir = None
for root, dirs, files in os.walk(RAW_SCENE_ROOT):
    for d in dirs:
        if d.lower().startswith("annot"):
            ann_dir = os.path.join(root, d)
            break
    if ann_dir is not None:
        break

scene_captions = {}

def g(pattern):
    return sorted(glob.glob(pattern, recursive=True))


def _captions_from_json(data):
    """Extract {stem: caption} from parsed JSON.

    Accepts a list of records (filename under 'filename' / 'file_name' / 'image',
    caption under 'caption' / 'text') or a flat {filename: caption} dict. Entries that do
    not fit (non-dict records, non-string filenames or captions) are skipped one by one;
    previously a single such entry raised inside the loop and discarded the rest of the file.
    """
    found = {}
    if isinstance(data, list):
        for row in data:
            if not isinstance(row, dict):
                continue
            fn  = row.get("filename") or row.get("file_name") or row.get("image") or ""
            cap = row.get("caption") or row.get("text") or ""
            if isinstance(fn, str) and isinstance(cap, str) and fn and cap:
                found[Path(fn).stem] = cap
    elif isinstance(data, dict):
        for k, v in data.items():
            if isinstance(v, str):
                found[Path(k).stem] = v
    return found


def _captions_from_tsv(fh):
    """Extract {stem: caption} from lines of the form "filename<TAB>caption"."""
    found = {}
    for line in fh:
        parts = re.split(r'\t+', line.strip(), maxsplit=1)
        if len(parts) == 2:
            found[Path(parts[0]).stem] = parts[1]
    return found


if ann_dir is not None:
    print("Annotation dir found at:", ann_dir)
    ann_files = g(f"{ann_dir}/**/*.json") + g(f"{ann_dir}/**/*.txt") + g(f"{ann_dir}/**/*.tsv")
    print("Found annotation files:", len(ann_files))
    for af in ann_files:
        try:
            with open(af, "r", encoding="utf-8", errors="ignore") as f:
                if af.endswith(".json"):
                    scene_captions.update(_captions_from_json(json.load(f)))
                else:
                    # txt/tsv: assume "filename<TAB>caption"
                    scene_captions.update(_captions_from_tsv(f))
        except Exception as e:
            # Best-effort: report the unreadable file and continue with the others.
            print("Could not parse annotation file:", af, "->", e)
else:
    print("No Annotation dir found; captions will be empty.")

print("Total captions loaded:", len(scene_captions))

In [None]:
from pathlib import Path
import csv
import pandas as pd

# Repeats the manifest-path setup of the earlier cell (same value); nothing is written yet.
SCENE_MANIFEST = str(Path(PROJ) / "data" / "sketchycoco" / "scene_manifest.csv")
Path(SCENE_MANIFEST).parent.mkdir(parents=True, exist_ok=True)

print("Writing scene manifest to:", SCENE_MANIFEST)


In [None]:
# Stem -> photo path lookups per split (if two photos share a stem, the last one wins).
train_photo_by_stem = {Path(p).stem: p for p in scene_img_train}
val_photo_by_stem   = {Path(p).stem: p for p in scene_img_val}

rows = []

def add_rows(split_name, sketch_list, photo_dict):
    """Append [split, sketch_path, photo_path, caption] to the global ``rows`` for every
    sketch whose stem has a photo in ``photo_dict``; return how many sketches had none.

    The caption comes from ``scene_captions`` when present, else a synthetic placeholder.
    """
    unmatched = 0
    for sketch_path in sketch_list:
        key = Path(sketch_path).stem
        if key not in photo_dict:
            unmatched += 1
        else:
            rows.append([split_name, sketch_path, photo_dict[key],
                         scene_captions.get(key, "a realistic everyday scene")])
    return unmatched

missing_train = add_rows("train", scene_sk_train, train_photo_by_stem)
missing_val   = add_rows("val",   scene_sk_val,   val_photo_by_stem)

print(f"Missing train pairs (no photo for stem): {missing_train}")
print(f"Missing val pairs   (no photo for stem): {missing_val}")
print("Total paired rows to write:", len(rows))

In [None]:
import csv, os
from pathlib import Path
import pandas as pd

SCENE_MANIFEST = f"{PROJ}/data/sketchycoco/scene_manifest.csv"
os.makedirs(os.path.dirname(SCENE_MANIFEST), exist_ok=True)

print("Writing scene manifest to:", SCENE_MANIFEST)

# Header row, then one row per paired (sketch, photo, caption) built by the previous cell.
with open(SCENE_MANIFEST, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["split", "sketch_path", "photo_path", "caption"])
    w.writerows(rows)

print("✅ Wrote scene manifest!")

# Read the file back as a quick sanity check.
df_scene = pd.read_csv(SCENE_MANIFEST)
print("Rows in scene_manifest:", len(df_scene))
df_scene.head()


In [None]:
from pathlib import Path
import os, glob

# Re-declare PROJ and RAW_OBJECT_ROOT (same values as above), deriving the root from PROJ.
PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
RAW_OBJECT_ROOT = f"{PROJ}/Object"

print("RAW_OBJECT_ROOT:", RAW_OBJECT_ROOT, "Exists?", os.path.isdir(RAW_OBJECT_ROOT))
if not os.path.isdir(RAW_OBJECT_ROOT):
    # Fail with a clear message (as the FS-COCO / Edges2Shoes cells do) instead of a bare
    # FileNotFoundError from os.listdir.
    raise RuntimeError("RAW_OBJECT_ROOT path is wrong. Fix it before continuing.")
print("Contents of RAW_OBJECT_ROOT:")
for name in sorted(os.listdir(RAW_OBJECT_ROOT)):
    print(" -", name)

# Helper: exact, case-insensitive lookup of an immediate subdirectory by name.
def find_subdir(root, name_options):
    """Return the first immediate subdirectory of ``root`` whose name equals any entry of
    ``name_options`` (case-insensitive), or None. Entries are checked in os.scandir order."""
    wanted = set(opt.lower() for opt in name_options)
    with os.scandir(root) as it:
        for entry in it:
            if entry.is_dir() and entry.name.lower() in wanted:
                return entry.path
    return None

# Helper: descend through a lone wrapper folder (like 'paper_version') if present.
def dive_if_single_subdir(root):
    """Return the path of ``root``'s only subdirectory when it has exactly one;
    otherwise return ``root`` itself."""
    dirs = list(filter(os.DirEntry.is_dir, os.scandir(root)))
    if len(dirs) != 1:
        return root
    print(f"Found single subdir '{dirs[0].name}' under {root}, diving into it.")
    return dirs[0].path

In [None]:
# 1) Locate the sketch / GT / Edge folders directly under RAW_OBJECT_ROOT.
sketch_root_raw, gt_root_raw, edge_root_raw = (
    find_subdir(RAW_OBJECT_ROOT, opts)
    for opts in (["sketch"], ["gt", "GT"], ["edge", "Edge"])
)

print("\nInitial detected roots:")
print("  sketch_root_raw:", sketch_root_raw)
print("  gt_root_raw    :", gt_root_raw)
print("  edge_root_raw  :", edge_root_raw)

if any(r is None for r in (sketch_root_raw, gt_root_raw, edge_root_raw)):
    raise RuntimeError("Could not find sketch/GT/Edge dirs directly under RAW_OBJECT_ROOT. Check names.")

In [None]:
# 2) Skip a lone wrapper folder (like 'paper_version') under each root, if there is one.
sketch_root, gt_root, edge_root = map(
    dive_if_single_subdir, (sketch_root_raw, gt_root_raw, edge_root_raw)
)

print("\nUsing these as split roots:")
print("  sketch_root:", sketch_root)
print("  gt_root    :", gt_root)
print("  edge_root  :", edge_root)

In [None]:
# 3) List split folders (e.g. train/val) under each root and keep those present in all three.
def _split_names(root):
    """Sorted names of the immediate subdirectories of ``root``."""
    with os.scandir(root) as it:
        return sorted(entry.name for entry in it if entry.is_dir())

sketch_splits = _split_names(sketch_root)
gt_splits     = _split_names(gt_root)
edge_splits   = _split_names(edge_root)

print("\nSketch splits under", sketch_root, ":", sketch_splits)
print("GT splits under", gt_root, ":", gt_splits)
print("Edge splits under", edge_root, ":", edge_splits)

common_splits = sorted(set(sketch_splits).intersection(gt_splits, edge_splits))
print("Common splits (Sketch ∩ GT ∩ Edge):", common_splits)

# Everything except 'val' counts as train; 'val' is validation.
val_splits   = [name for name in common_splits if name.lower() == "val"]
train_splits = [name for name in common_splits if name.lower() != "val"]

print("Train splits we will use:", train_splits)
print("Val splits we will use  :", val_splits)

# Recursive glob helper: sorted paths matching ``pattern``.
def g(pattern):
    return sorted(glob.glob(pattern, recursive=True))


In [None]:
# 4) Collect file paths per split and modality (paths only; nothing is copied).
obj_sk_train, obj_img_train, obj_edge_train = [], [], []
obj_sk_val,   obj_img_val,   obj_edge_val   = [], [], []

for split_group, (sk_out, img_out, edge_out) in (
    (train_splits, (obj_sk_train, obj_img_train, obj_edge_train)),
    (val_splits,   (obj_sk_val,   obj_img_val,   obj_edge_val)),
):
    for split in split_group:
        sk_out.extend(g(f"{sketch_root}/{split}/**/*.*"))
        img_out.extend(g(f"{gt_root}/{split}/**/*.*"))
        edge_out.extend(g(f"{edge_root}/{split}/**/*.*"))

print("\nRaw counts (found files):")
for label, found in (
    ("  TRAIN  sketches:", obj_sk_train),
    ("  TRAIN  images  :", obj_img_train),
    ("  TRAIN  edges   :", obj_edge_train),
    ("  VAL    sketches:", obj_sk_val),
    ("  VAL    images  :", obj_img_val),
    ("  VAL    edges   :", obj_edge_val),
):
    print(label, len(found))

In [None]:
# 5) quick integrity check: just count how many have matching photos (no IO)
from pathlib import Path

def verify_pairs_only(src_list, tgt_list):
    """Return how many paths in ``src_list`` have a filename stem that also occurs in
    ``tgt_list`` (in-memory check; no files are opened)."""
    target_stems = {Path(t).stem for t in tgt_list}
    matches = [s for s in src_list if Path(s).stem in target_stems]
    return len(matches)

# Count stem matches for each (split, source modality) combination.
train_sk_photo_pairs, val_sk_photo_pairs = (
    verify_pairs_only(src, photos)
    for src, photos in ((obj_sk_train, obj_img_train), (obj_sk_val, obj_img_val))
)
train_edge_photo_pairs, val_edge_photo_pairs = (
    verify_pairs_only(src, photos)
    for src, photos in ((obj_edge_train, obj_img_train), (obj_edge_val, obj_img_val))
)

print("\nData integrity check (no files copied):")
for label, count in (
    ("Train sketch-photo pairs", train_sk_photo_pairs),
    ("Val   sketch-photo pairs", val_sk_photo_pairs),
    ("Train edge-photo pairs  ", train_edge_photo_pairs),
    ("Val   edge-photo pairs  ", val_edge_photo_pairs),
):
    print(f"  {label}: {count}")

In [None]:
from pathlib import Path
import csv, pandas as pd

# Output locations for the two object-level manifests.
DATA_ROOT = os.path.join(PROJ, "data", "sketchycoco")
os.makedirs(DATA_ROOT, exist_ok=True)

OBJECT_MANIFEST       = os.path.join(DATA_ROOT, "object_manifest.csv")
OBJECT_EDGES_MANIFEST = os.path.join(DATA_ROOT, "object_edges_manifest.csv")

print("Will write:")
print("  OBJECT_MANIFEST       :", OBJECT_MANIFEST)
print("  OBJECT_EDGES_MANIFEST :", OBJECT_EDGES_MANIFEST)

In [None]:
# 1) Stem -> path lookups per split (if two files share a stem, the last one wins).
train_photo_by_stem, val_photo_by_stem = (
    {Path(p).stem: p for p in paths} for paths in (obj_img_train, obj_img_val)
)
train_edge_by_stem, val_edge_by_stem = (
    {Path(p).stem: p for p in paths} for paths in (obj_edge_train, obj_edge_val)
)

object_rows = []        # sketch-photo-caption rows for object_manifest
object_edge_rows = []   # edge-photo rows for object_edges_manifest


def add_object_rows(split_name, sketch_list, photo_dict):
    """Append [split, sketch_path, photo_path, caption] to ``object_rows`` for each sketch
    whose stem is in ``photo_dict``; return how many sketches had no photo.
    The caption is a fixed synthetic placeholder for now."""
    unmatched = 0
    for sketch_path in sketch_list:
        key = Path(sketch_path).stem
        if key in photo_dict:
            object_rows.append([split_name, sketch_path, photo_dict[key],
                                "a realistic photo of an everyday object"])
        else:
            unmatched += 1
    return unmatched


def add_edge_rows(split_name, edge_list, photo_dict):
    """Append [split, edge_path, photo_path] to ``object_edge_rows`` for each edge map whose
    stem is in ``photo_dict``; return how many edge maps had no photo."""
    unmatched = 0
    for edge_path in edge_list:
        key = Path(edge_path).stem
        if key not in photo_dict:
            unmatched += 1
            continue
        object_edge_rows.append([split_name, edge_path, photo_dict[key]])
    return unmatched


missing_train_obj  = add_object_rows("train", obj_sk_train, train_photo_by_stem)
missing_val_obj    = add_object_rows("val",   obj_sk_val,   val_photo_by_stem)
missing_train_edge = add_edge_rows("train",   obj_edge_train, train_photo_by_stem)
missing_val_edge   = add_edge_rows("val",     obj_edge_val,   val_photo_by_stem)

print("\nMissing sketch-photo pairs:")
print("  train:", missing_train_obj)
print("  val  :", missing_val_obj)
print("Missing edge-photo pairs:")
print("  train:", missing_train_edge)
print("  val  :", missing_val_edge)

print("\nTotal rows to write:")
print("  object_rows       :", len(object_rows))
print("  object_edge_rows  :", len(object_edge_rows))

In [None]:
# 2) Write both manifests (header row + data rows), then read them back as a sanity check.
def _write_manifest(path, header, data_rows):
    """Write ``header`` followed by ``data_rows`` to ``path`` as a UTF-8 CSV."""
    with open(path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(data_rows)


_write_manifest(OBJECT_MANIFEST, ["split", "sketch_path", "photo_path", "caption"], object_rows)
_write_manifest(OBJECT_EDGES_MANIFEST, ["split", "edge_path", "photo_path"], object_edge_rows)

print("\n✅ Wrote object manifests.")

df_obj = pd.read_csv(OBJECT_MANIFEST)
df_obj_edges = pd.read_csv(OBJECT_EDGES_MANIFEST)
print("Rows in object_manifest       :", len(df_obj))
print("Rows in object_edges_manifest :", len(df_obj_edges))
df_obj.head()

In [None]:
# FS-COCO: collect (image, sketch, caption) triplets and write the FS-COCO manifest

In [None]:
from pathlib import Path
import os, glob

# Same project root as above.
PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"

# EDIT THIS to your FS-COCO folder on Drive; here it is the nested fscoco/fscoco folder.
RAW_FSC_ROOT = f"{PROJ}/fscoco/fscoco"

print("RAW_FSC_ROOT:", RAW_FSC_ROOT, "Exists?", os.path.isdir(RAW_FSC_ROOT))
print("Contents of RAW_FSC_ROOT:")
if not os.path.isdir(RAW_FSC_ROOT):
    raise RuntimeError("RAW_FSC_ROOT path is wrong. Fix it before continuing.")
for name in os.listdir(RAW_FSC_ROOT):
    print(" -", name)

In [None]:
# Helper to find a subdir by name (case-insensitive, partial/substring match).
def _find_fsc_subdir(root, name_options):
    """Return the first immediate subdirectory of ``root`` whose lower-cased name contains
    any of ``name_options`` (case-insensitive substring match), or None.

    Subdirectories are checked in sorted name order, so the result does not depend on the
    arbitrary os.scandir order when several folders match. Kept under its own name rather
    than ``find_subdir``: the SketchyCOCO cells above use an exact-match ``find_subdir``,
    and rebinding that name to a substring matcher could change their results on re-run.
    """
    options = [n.lower() for n in name_options]
    with os.scandir(root) as it:
        subdirs_sorted = sorted((e for e in it if e.is_dir()), key=lambda e: e.name)
    for entry in subdirs_sorted:
        lower = entry.name.lower()
        if any(opt in lower for opt in options):
            return entry.path
    return None

images_root = _find_fsc_subdir(RAW_FSC_ROOT, ["image"])
rast_root   = _find_fsc_subdir(RAW_FSC_ROOT, ["raster"])
vec_root    = _find_fsc_subdir(RAW_FSC_ROOT, ["vector"])
text_root   = _find_fsc_subdir(RAW_FSC_ROOT, ["text"])

print("\nDetected sub-roots:")
print("  images_root:", images_root)
print("  rast_root  :", rast_root)
print("  vec_root   :", vec_root)
print("  text_root  :", text_root)

if images_root is None or text_root is None or (rast_root is None and vec_root is None):
    raise RuntimeError("Could not find images/text/raster/vector subfolders. Check FS-COCO structure.")

from collections import defaultdict

def g(pattern):
    return sorted(glob.glob(pattern, recursive=True))

# Lookups keyed by (bucket, stem), filled by the next cells.
images_by_key = {}
sketches_by_key = {}
captions_by_key = {}


In [None]:
# 1) Images under images/<bucket>/... (searched recursively), indexed by (bucket, stem).
for bucket_entry in filter(lambda e: e.is_dir(), os.scandir(images_root)):
    bucket = bucket_entry.name
    images_by_key.update(
        ((bucket, Path(img_path).stem), img_path)
        for img_path in g(f"{bucket_entry.path}/**/*.*")
    )

In [None]:
# 2) Sketches keyed by (bucket, stem): raster_sketches first; vector_sketches only fill
#    keys that have no raster sketch.
if rast_root is not None:
    for bucket_entry in os.scandir(rast_root):
        if not bucket_entry.is_dir():
            continue
        bucket = bucket_entry.name
        for sk_path in g(f"{bucket_entry.path}/**/*.*"):
            stem = Path(sk_path).stem
            sketches_by_key[(bucket, stem)] = sk_path

if vec_root is not None:
    for bucket_entry in os.scandir(vec_root):
        if not bucket_entry.is_dir():
            continue
        bucket = bucket_entry.name
        for sk_path in g(f"{bucket_entry.path}/**/*.*"):
            stem = Path(sk_path).stem
            # don't overwrite a raster sketch if we already have one
            sketches_by_key.setdefault((bucket, stem), sk_path)

    # Vector-sketch files land in the same sketch_path column as raster images, but they
    # may not be images at all (format depends on the dataset release -- TODO confirm).
    # Report how many entries come from vector_sketches so the fallback is not silent.
    # Counted from the dict itself, so the number stays correct if the cell is re-run.
    vec_root_path = Path(vec_root)
    vector_sketch_count = sum(
        1 for p in sketches_by_key.values() if vec_root_path in Path(p).parents
    )
    if vector_sketch_count:
        print(f"WARNING: {vector_sketch_count} sketches come from vector_sketches "
              "(no raster sketch with the same key); check your loader can read them.")

In [None]:
# 3) Captions: assume text/<bucket>/*.txt, one caption per image, keyed by (bucket, stem)
#    so they join with the image/sketch lookups.
caption_read_errors = []
for bucket_entry in os.scandir(text_root):
    if not bucket_entry.is_dir():
        continue
    bucket = bucket_entry.name
    for txt_path in g(f"{bucket_entry.path}/**/*.txt"):
        stem = Path(txt_path).stem
        try:
            with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
                caption = f.read().strip()
        except Exception as e:
            # Best-effort: continue with an empty caption (replaced by a placeholder later),
            # but record the failure instead of swallowing it silently.
            caption = ""
            caption_read_errors.append((txt_path, e))
        captions_by_key[(bucket, stem)] = caption

if caption_read_errors:
    print(f"WARNING: could not read {len(caption_read_errors)} caption files, e.g.:")
    for bad_path, err in caption_read_errors[:3]:
        print("   ", bad_path, "->", err)

print("\nCounts per dictionary:")
print("  images_by_key   :", len(images_by_key))
print("  sketches_by_key :", len(sketches_by_key))
print("  captions_by_key :", len(captions_by_key))

In [None]:
# 4) Join on (bucket, stem): every image with a sketch becomes one
#    (image_path, sketch_path, caption) triplet. Images without a sketch are dropped;
#    a missing or empty caption is replaced by a generic placeholder.
fscoco_triplets = [
    (
        img_path,
        sketches_by_key[key],
        captions_by_key.get(key, "") or "a realistic everyday scene or object",
    )
    for key, img_path in images_by_key.items()
    if key in sketches_by_key
]

print("\nTotal valid (image, sketch, caption) triplets:", len(fscoco_triplets))

In [None]:
import random
from pathlib import Path
import csv, pandas as pd

print("Total triplets before split:", len(fscoco_triplets))

# The incoming order comes from os.scandir + dict insertion order, which is not guaranteed
# to be stable across machines/runs. random.shuffle also works in place, so re-running this
# cell would otherwise reshuffle an already-shuffled list. Sorting first and then re-seeding
# makes the split reproducible and the cell safe to re-run.
fscoco_triplets.sort()
random.seed(1337)
random.shuffle(fscoco_triplets)

n = len(fscoco_triplets)
n_train = int(0.8 * n)   # 80% train
n_val   = int(0.1 * n)   # 10% val; the remainder (~10%) is test
n_test  = n - n_train - n_val

train_triplets = fscoco_triplets[:n_train]
val_triplets   = fscoco_triplets[n_train:n_train+n_val]
test_triplets  = fscoco_triplets[n_train+n_val:]

print("Split sizes:")
print("  train:", len(train_triplets))
print("  val  :", len(val_triplets))
print("  test :", len(test_triplets))

FSC_MANIFEST = f"{PROJ}/data/fscoco_manifest.csv"
Path(os.path.dirname(FSC_MANIFEST)).mkdir(parents=True, exist_ok=True)

# Named fsc_rows (not `rows`) so this cell does not overwrite the global `rows` that the
# scene-manifest writer cell reads.
fsc_rows = [
    [split_name, sketch_path, img_path, caption]
    for split_name, triples in (
        ("train", train_triplets),
        ("val",   val_triplets),
        ("test",  test_triplets),
    )
    for img_path, sketch_path, caption in triples
]

print("Total rows to write:", len(fsc_rows))

with open(FSC_MANIFEST, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["split", "sketch_path", "photo_path", "caption"])
    w.writerows(fsc_rows)

print("✅ Wrote FS-COCO manifest:", FSC_MANIFEST)

df_fsc = pd.read_csv(FSC_MANIFEST)
print("Rows in fscoco_manifest:", len(df_fsc))
df_fsc.head()


In [None]:
from pathlib import Path
import os, glob

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"

# EDIT THIS to the folder holding your Edges2Shoes data (here: the nested
# edges2shoes/edges2shoes folder).
RAW_E2S_ROOT = f"{PROJ}/edges2shoes/edges2shoes"

print("RAW_E2S_ROOT:", RAW_E2S_ROOT, "Exists?", os.path.isdir(RAW_E2S_ROOT))
print("Contents of RAW_E2S_ROOT:")
if not os.path.isdir(RAW_E2S_ROOT):
    raise RuntimeError("RAW_E2S_ROOT path is wrong. Fix it before continuing.")
for name in os.listdir(RAW_E2S_ROOT):
    print(" -", name)


In [None]:
import re

def g(pattern):
    """Sorted list of paths matching ``pattern`` (``**`` recurses into subfolders)."""
    return sorted(glob.glob(pattern, recursive=True))


def _norm_dir_name(name):
    """Lower-case a folder name and drop '_' / '-', so 'trainA', 'train_A' and 'train-a'
    all become 'traina'."""
    return name.lower().replace("_", "").replace("-", "")


def _find_e2s_subdir(root, name_substring):
    """Return the first immediate subdirectory of ``root`` (in sorted name order) whose
    normalized name contains ``name_substring``, or None.

    ``name_substring`` must already be normalized (lower-case, no '_' / '-').
    Not called ``find_subdir`` on purpose: earlier cells define ``find_subdir(root,
    name_options)`` taking a LIST, and rebinding that name to a string-based helper would
    make those cells raise TypeError if re-run afterwards.
    """
    for entry in sorted(os.scandir(root), key=lambda e: e.name):
        if entry.is_dir() and name_substring in _norm_dir_name(entry.name):
            return entry.path
    return None


subdirs = [d for d in os.scandir(RAW_E2S_ROOT) if d.is_dir()]
subdir_names = [d.name.lower() for d in subdirs]
print("\nTop-level subdirs:", subdir_names)

# Case 1: Pix2Pix official format: 'train', 'val', 'test' with concatenated edge+photo
has_train_val_test = all(name in subdir_names for name in ["train", "val"])  # test is optional

# Case 2: Separated A/B (trainA/trainB, train_A/train_B, ...) or edges/images dirs.
# Names are normalized so detection and lookup agree: previously 'train_a' was detected
# here but never found by the 'traina' substring lookup, silently yielding zero pairs.
has_trainA = "traina" in {_norm_dir_name(d.name) for d in subdirs}
has_edges  = any("edge" in name for name in subdir_names)

if has_train_val_test:
    print("\nDetected Pix2Pix-style concatenated format (train/val[/test]).")
    # Each file is a single concatenated image; only its path is stored.
    e2s_entries = []  # (split, concat_path)

    for split in ["train", "val", "test"]:
        split_dir = os.path.join(RAW_E2S_ROOT, split)
        if not os.path.isdir(split_dir):
            continue
        img_paths = g(f"{split_dir}/**/*.*")
        for p in img_paths:
            e2s_entries.append((split, p))

    print("Total concatenated images found:", len(e2s_entries))

    # The manifest cell writes: split, concat_path
    e2s_mode = "concat"

elif has_trainA:
    print("\nDetected trainA/trainB style (separate inputs/targets).")
    # Assumed convention: A = edges (inputs), B = photos (targets).
    trainA = _find_e2s_subdir(RAW_E2S_ROOT, "traina")
    trainB = _find_e2s_subdir(RAW_E2S_ROOT, "trainb")
    valA   = _find_e2s_subdir(RAW_E2S_ROOT, "vala")
    valB   = _find_e2s_subdir(RAW_E2S_ROOT, "valb")
    testA  = _find_e2s_subdir(RAW_E2S_ROOT, "testa")
    testB  = _find_e2s_subdir(RAW_E2S_ROOT, "testb")

    print("trainA:", trainA)
    print("trainB:", trainB)
    print("valA  :", valA)
    print("valB  :", valB)
    print("testA :", testA)
    print("testB :", testB)

    def build_pairs(input_root, target_root, split_name):
        """(split, input_path, target_path) for every input whose stem has a target."""
        if input_root is None or target_root is None or not os.path.isdir(input_root) or not os.path.isdir(target_root):
            return []
        input_paths = g(f"{input_root}/**/*.*")
        target_by_stem = {Path(p).stem: p for p in g(f"{target_root}/**/*.*")}
        pairs = []
        for ip in input_paths:
            stem = Path(ip).stem
            if stem in target_by_stem:
                pairs.append((split_name, ip, target_by_stem[stem]))
        return pairs

    e2s_edges_pairs = []  # (split, edge_path, photo_path)
    e2s_edges_pairs += build_pairs(trainA, trainB, "train")
    e2s_edges_pairs += build_pairs(valA,   valB,   "val")
    e2s_edges_pairs += build_pairs(testA,  testB,  "test")

    print("Total edge-photo pairs:", len(e2s_edges_pairs))
    e2s_mode = "edges_pairs"

elif has_edges:
    print("\nDetected edges/images style (edges + images dirs).")
    edges_dir = None
    imgs_dir  = None
    for entry in subdirs:
        lname = entry.name.lower()
        if "edge" in lname:
            edges_dir = entry.path
        elif "img" in lname or "image" in lname or "photo" in lname:
            imgs_dir = entry.path

    print("edges_dir:", edges_dir)
    print("imgs_dir :", imgs_dir)

    if edges_dir is None or imgs_dir is None:
        raise RuntimeError("Could not clearly find edges/images directories.")

    edge_paths = g(f"{edges_dir}/**/*.*")
    img_by_stem = {Path(p).stem: p for p in g(f"{imgs_dir}/**/*.*")}

    e2s_edges_pairs = []
    for ep in edge_paths:
        stem = Path(ep).stem
        if stem in img_by_stem:
            # no split information here; treat all as 'train' for now, subset later
            e2s_edges_pairs.append(("train", ep, img_by_stem[stem]))

    print("Total edge-photo pairs:", len(e2s_edges_pairs))
    e2s_mode = "edges_pairs"

else:
    raise RuntimeError("Could not recognize Edges2Shoes structure. Check RAW_E2S_ROOT contents.")


In [None]:
from pathlib import Path
import csv, pandas as pd
import os

DATA_ROOT = f"{PROJ}/data"
Path(DATA_ROOT).mkdir(parents=True, exist_ok=True)

E2S_MANIFEST = f"{DATA_ROOT}/edges2shoes_manifest.csv"
print("Will write Edges2Shoes manifest to:", E2S_MANIFEST)

# Every Edges2Shoes row gets the same generic caption.
E2S_CAPTION = "a realistic photo of a shoe"

if e2s_mode == "concat":
    # e2s_entries: (split, concat_path). Each concat image holds edge + photo, and a
    # pix2pix-style loader is expected to split it left/right later. That is why the
    # photo_path column is intentionally left empty in this mode (pandas reads it back as NaN).
    header = ["split", "concat_path", "photo_path", "caption"]
    rows = [[split, concat_path, "", E2S_CAPTION] for split, concat_path in e2s_entries]

elif e2s_mode == "edges_pairs":
    # e2s_edges_pairs: (split, edge_path, photo_path)
    header = ["split", "edge_path", "photo_path", "caption"]
    rows = [
        [split, edge_path, photo_path, E2S_CAPTION]
        for split, edge_path, photo_path in e2s_edges_pairs
    ]

else:
    raise RuntimeError(f"Unknown e2s_mode: {e2s_mode}")

print("Total rows to write:", len(rows))

with open(E2S_MANIFEST, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)

print("Wrote Edges2Shoes manifest:", E2S_MANIFEST)

# Read back as a sanity check and preview.
df_e2s = pd.read_csv(E2S_MANIFEST)
print("Rows in edges2shoes_manifest:", len(df_e2s))
df_e2s.head()


In [None]:
import os
from pathlib import Path
import pandas as pd

# Project layout: manifests are read from data/, sampled subsets are written to subsets/.
PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
DATA_ROOT, SUBSET_ROOT = (f"{PROJ}/{leaf}" for leaf in ("data", "subsets"))
os.makedirs(SUBSET_ROOT, exist_ok=True)

for label, value in (("DATA_ROOT   :", DATA_ROOT), ("SUBSET_ROOT :", SUBSET_ROOT)):
    print(label, value)

# ---- Helper: load + sample per split ----
def make_subset(
    in_path,
    out_path,
    target_per_split,
    split_col="split",
    seed=1337,
):
    """Sample up to N rows per split from a manifest CSV and write them to ``out_path``.

    Args:
        in_path: Manifest CSV to read; it must contain ``split_col``.
        out_path: Destination CSV for the sampled subset (written without the index).
        target_per_split: Mapping ``split name -> max rows``. ``None`` keeps every row
            of that split. Splits smaller than requested are taken in full. Splits with
            no rows are skipped with a warning.
        split_col: Column holding the split label.
        seed: ``random_state`` for ``DataFrame.sample``, so subsets are reproducible.

    Returns:
        The subset DataFrame that was written, or ``None`` when nothing was written
        (missing input file, missing split column, or no rows selected).
    """
    print(f"\n=== Subsetting {os.path.basename(in_path)} ===")
    if not os.path.isfile(in_path):
        # Report and skip, like the "no rows selected" case below, so the remaining
        # manifests still get processed.
        print(f"  [ERROR] Manifest not found: {in_path}, skipping.")
        return None

    df = pd.read_csv(in_path)
    print("  Total rows:", len(df))
    if split_col not in df.columns:
        print(f"  [ERROR] Column '{split_col}' not found in {in_path}, skipping.")
        return None
    print("  Split counts:\n", df[split_col].value_counts())

    samples = []
    for split, target_n in target_per_split.items():
        df_split = df[df[split_col] == split]
        available = len(df_split)
        if available == 0:
            print(f"  [WARN] No rows for split '{split}' in {in_path}")
            continue
        if target_n is None:
            n = available
        else:
            n = min(target_n, available)
            if n < target_n:
                print(f"  [INFO] Split '{split}': requested {target_n}, only {available} available -> taking {n}.")
        samples.append(df_split.sample(n=n, random_state=seed))

    if not samples:
        print("  [ERROR] No rows selected at all, skipping write.")
        return None

    df_sub = pd.concat(samples, axis=0).reset_index(drop=True)
    df_sub.to_csv(out_path, index=False)
    print(f" Wrote subset: {out_path}")
    print("  Subset split counts:\n", df_sub[split_col].value_counts())
    return df_sub

In [None]:
# ========= 1) Scene-level SketchyCOCO =========
# Rows requested per split; make_subset takes fewer when a split is smaller.
scene_targets = dict(train=4000, val=400)

SCENE_MANIFEST = f"{DATA_ROOT}/sketchycoco/scene_manifest.csv"
SCENE_SUBSET   = f"{SUBSET_ROOT}/scene_subset.csv"

make_subset(SCENE_MANIFEST, SCENE_SUBSET, scene_targets);


In [None]:
# ========= 2) Object-level SketchyCOCO (sketch-photo) =========
# Upper bounds per split; smaller splits are taken in full by make_subset.
object_targets = dict(train=4000, val=400)

OBJECT_MANIFEST = f"{DATA_ROOT}/sketchycoco/object_manifest.csv"
OBJECT_SUBSET   = f"{SUBSET_ROOT}/object_subset.csv"

make_subset(OBJECT_MANIFEST, OBJECT_SUBSET, object_targets);

In [None]:
# ========= 3) Object-level edges (for Pix2Pix) =========
# Upper bounds per split; smaller splits are taken in full by make_subset.
object_edges_targets = dict(train=4000, val=400)

OBJECT_EDGES_MANIFEST = f"{DATA_ROOT}/sketchycoco/object_edges_manifest.csv"
OBJECT_EDGES_SUBSET   = f"{SUBSET_ROOT}/object_edges_subset.csv"

make_subset(OBJECT_EDGES_MANIFEST, OBJECT_EDGES_SUBSET, object_edges_targets);

In [None]:
# ========= 4) FS-COCO =========
# FS-COCO also has a test split; each value is an upper bound for make_subset.
fsc_targets = dict(train=3000, val=300, test=600)

FSC_MANIFEST = f"{DATA_ROOT}/fscoco_manifest.csv"
FSC_SUBSET   = f"{SUBSET_ROOT}/fscoco_subset.csv"

make_subset(FSC_MANIFEST, FSC_SUBSET, fsc_targets);

In [None]:
# ========= 5) Edges2Shoes =========
# The split distribution is unknown, so these are upper bounds. A split absent
# from the manifest (e.g. 'test') is reported by make_subset and skipped.
e2s_targets = dict(train=5000, val=500, test=1000)

E2S_MANIFEST = f"{DATA_ROOT}/edges2shoes_manifest.csv"
E2S_SUBSET   = f"{SUBSET_ROOT}/edges2shoes_subset.csv"

make_subset(E2S_MANIFEST, E2S_SUBSET, e2s_targets);

In [None]:
# Final marker once every subset cell above has run (print's default sep supplies the space).
print("\n", "All subset creation attempts finished.")

In [None]:
# Colab setup: install the diffusion / evaluation packages (quiet, unpinned).
# NOTE(review): pin versions so the notebook is reproducible across runtimes.
!pip install -q accelerate transformers safetensors
!pip install -q "diffusers[torch]" peft controlnet-aux
!pip install -q clean-fid lpips open-clip-torch


In [None]:
import os
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
DATA_ROOT = f"{PROJ}/data"
SUBSET_ROOT = f"{PROJ}/subsets"

print("Using project root:", PROJ)
print("Data root   :", DATA_ROOT)
print("Subsets dir :", SUBSET_ROOT)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Side length (pixels) that every sketch/photo is resized to. Larger sizes cost
# more GPU memory; lower this if memory is tight.
IMAGE_SIZE = 512

# Paths to subset CSVs we created
SCENE_SUBSET_CSV        = f"{SUBSET_ROOT}/scene_subset.csv"
OBJECT_SUBSET_CSV       = f"{SUBSET_ROOT}/object_subset.csv"
FSC_SUBSET_CSV          = f"{SUBSET_ROOT}/fscoco_subset.csv"
OBJECT_EDGES_SUBSET_CSV = f"{SUBSET_ROOT}/object_edges_subset.csv"
E2S_SUBSET_CSV          = f"{SUBSET_ROOT}/edges2shoes_subset.csv"

# Check every subset defined above, not only the sketch→image ones.
for p in [SCENE_SUBSET_CSV, OBJECT_SUBSET_CSV, FSC_SUBSET_CSV, OBJECT_EDGES_SUBSET_CSV, E2S_SUBSET_CSV]:
    print(p, "exists?", os.path.isfile(p))


In [None]:
# Report which subset CSVs are present on disk.
paths = dict(
    SCENE_SUBSET_CSV=SCENE_SUBSET_CSV,
    OBJECT_SUBSET_CSV=OBJECT_SUBSET_CSV,
    FSC_SUBSET_CSV=FSC_SUBSET_CSV,
    OBJECT_EDGES_SUBSET_CSV=OBJECT_EDGES_SUBSET_CSV,
    E2S_SUBSET_CSV=E2S_SUBSET_CSV,
)

for name in paths:
    p = paths[name]
    print(f"{name}: {p} -> exists? {os.path.isfile(p)}")


In [None]:
import numpy as np

def pil_to_tensor(img: Image.Image, size=IMAGE_SIZE, to_gray=False):
    """
    Turn a PIL image into a float32 torch tensor shaped (C, H, W) with values in [-1, 1].

    The image is converted to mode "L" (C=1) when ``to_gray`` is true, otherwise to
    "RGB" (C=3), then resized to ``size`` x ``size`` with bicubic resampling.
    """
    mode = "L" if to_gray else "RGB"
    resized = img.convert(mode).resize((size, size), Image.BICUBIC)
    pixels = np.array(resized).astype("float32") / 255.0

    if pixels.ndim == 2:
        # A single plane: add a channel axis for gray, or replicate it three times for RGB.
        pixels = pixels[..., None] if to_gray else np.repeat(pixels[..., None], 3, axis=-1)

    # HWC -> CHW, then [0, 1] -> [-1, 1]
    chw = pixels.transpose(2, 0, 1)
    return torch.from_numpy(chw * 2.0 - 1.0)


In [None]:
class Sketch2ImageDataset(Dataset):
    """
    Generic sketch → image dataset for diffusion, using manifests:
    columns: split, sketch_path, photo_path, caption

    Each item is a dict with:
      - "sketch": 1-channel tensor from pil_to_tensor(..., to_gray=True)
      - "image":  3-channel tensor from pil_to_tensor(..., to_gray=False)
      - "caption": caption string ("" when the manifest cell is empty or the column is absent)
      - "sketch_path" / "photo_path": the source file paths from the manifest
    """

    def __init__(self, manifest_csv, split="train", image_size=IMAGE_SIZE):
        """Load ``manifest_csv`` and keep only the rows whose ``split`` column equals ``split``.

        Raises:
            ValueError: if the CSV has no ``split`` column, or no rows for ``split``.
        """
        self.manifest_csv = manifest_csv
        self.split = split
        self.image_size = image_size

        df = pd.read_csv(manifest_csv)
        if "split" not in df.columns:
            raise ValueError(f"'split' column not found in {manifest_csv}")

        self.df = df[df["split"] == split].reset_index(drop=True)
        if len(self.df) == 0:
            raise ValueError(f"No rows for split='{split}' in {manifest_csv}")

        print(f"[{os.path.basename(manifest_csv)}] Loaded {len(self.df)} rows for split='{split}'")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        sketch_path = row["sketch_path"]
        photo_path  = row["photo_path"]
        caption     = row.get("caption", "")
        # pandas reads empty CSV cells as NaN; str(NaN) would otherwise yield the caption "nan".
        if pd.isna(caption):
            caption = ""

        # Open each file in a `with` block so its handle is closed right after conversion.
        # pil_to_tensor's convert()/resize() copy the pixels out of the file-backed image.
        with Image.open(sketch_path) as sketch_img:
            sketch_tensor = pil_to_tensor(sketch_img, size=self.image_size, to_gray=True)   # (1, H, W)
        with Image.open(photo_path) as photo_img:
            photo_tensor  = pil_to_tensor(photo_img,  size=self.image_size, to_gray=False)  # (3, H, W)

        return {
            "sketch": sketch_tensor,
            "image":  photo_tensor,
            "caption": str(caption),
            "sketch_path": sketch_path,
            "photo_path":  photo_path,
        }


In [None]:
# Smoke test: build a shuffled train DataLoader for each sketch→image subset and
# print what one batch looks like.

def _print_batch_summary(title, sample_batch):
    """Print a title, the sketch/image tensor shapes and the first caption of a batch."""
    print(title)
    print("  sketch:", sample_batch["sketch"].shape)
    print("  image :", sample_batch["image"].shape)
    print("  captions[0]:", sample_batch["caption"][0])

# Scene-level dataset test (also reports dtype and value range)
scene_train_ds = Sketch2ImageDataset(SCENE_SUBSET_CSV, split="train", image_size=IMAGE_SIZE)
scene_train_dl = DataLoader(scene_train_ds, batch_size=4, shuffle=True)
batch = next(iter(scene_train_dl))
print("Scene batch:")
for key, label in (("sketch", "  sketch:"), ("image", "  image :")):
    t = batch[key]
    print(label, t.shape, t.dtype, t.min().item(), t.max().item())
print("  captions[0]:", batch["caption"][0])

# Object-level dataset test
obj_train_ds = Sketch2ImageDataset(OBJECT_SUBSET_CSV, split="train", image_size=IMAGE_SIZE)
obj_train_dl = DataLoader(obj_train_ds, batch_size=4, shuffle=True)
batch2 = next(iter(obj_train_dl))
_print_batch_summary("\nObject batch:", batch2)

# FS-COCO dataset test
fsc_train_ds = Sketch2ImageDataset(FSC_SUBSET_CSV, split="train", image_size=IMAGE_SIZE)
fsc_train_dl = DataLoader(fsc_train_ds, batch_size=4, shuffle=True)
batch3 = next(iter(fsc_train_dl))
_print_batch_summary("\nFS-COCO batch:", batch3)


In [None]:
# Load Stable Diffusion 1.5 and its tokenizer

In [None]:
# Upgrade the diffusion stack (quiet, unpinned).
# NOTE(review): if any of these packages were already imported in this session, a runtime
# restart is likely needed for the upgraded versions to take effect.
!pip install -q --upgrade diffusers transformers accelerate safetensors peft


In [None]:
import torch
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"

# float16 on GPU, float32 on CPU.
use_cuda = torch.cuda.is_available()
torch_dtype = torch.float16 if use_cuda else torch.float32
device = "cuda" if use_cuda else "cpu"
print("Device:", device, "| dtype:", torch_dtype)

# safety_checker=None: fine for local research use
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch_dtype, safety_checker=None
).to(device)

# Direct handles to the pipeline components.
tokenizer, vae, unet, scheduler = pipe.tokenizer, pipe.vae, pipe.unet, pipe.scheduler

# Memory savers; xFormers is optional and may be unavailable on this runtime.
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
    print("xFormers not enabled:", e)
else:
    print("xFormers enabled")

print("Pipeline loaded.")


In [None]:
from torch.utils.data import DataLoader
import pandas as pd
import os
from pathlib import Path
from PIL import Image
import numpy as np

# This cell rebuilds the object-level loaders on its own (e.g. after a runtime
# restart). pil_to_tensor and Sketch2ImageDataset below need torch, which the
# import lines of this cell do not bring in, so import it here.
import torch

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
SUBSET_ROOT = f"{PROJ}/subsets"
OBJECT_SUBSET_CSV = f"{SUBSET_ROOT}/object_subset.csv"

print("Object subset exists?", os.path.isfile(OBJECT_SUBSET_CSV))

# Side length (pixels) images are resized to.
IMAGE_SIZE = 512

def pil_to_tensor(img: Image.Image, size=IMAGE_SIZE, to_gray=False):
    """Resize ``img`` to ``size`` x ``size`` (bicubic) and return a float32 (C, H, W) tensor in [-1, 1].

    ``to_gray=True`` gives 1 channel (PIL mode "L"); otherwise 3 RGB channels.
    """
    target_mode = "L" if to_gray else "RGB"
    pixels = np.array(img.convert(target_mode).resize((size, size), Image.BICUBIC)).astype("float32") / 255.0

    if pixels.ndim == 2:
        # single-plane result: add a channel axis (gray) or replicate it to RGB
        pixels = pixels[..., None] if to_gray else np.stack([pixels] * 3, axis=-1)

    scaled = pixels.transpose(2, 0, 1) * 2.0 - 1.0
    return torch.from_numpy(scaled)

class Sketch2ImageDataset(torch.utils.data.Dataset):
    """Sketch → photo dataset read from a manifest CSV.

    Expected columns: split, sketch_path, photo_path, and optionally caption. Items are
    dicts with a 1-channel "sketch" tensor, a 3-channel "image" tensor (both from
    pil_to_tensor), the "caption" string ("" when missing) and both source paths.
    """

    def __init__(self, manifest_csv, split="train", image_size=IMAGE_SIZE):
        """Keep only the manifest rows whose ``split`` column equals ``split``.

        Raises:
            ValueError: if the CSV has no ``split`` column, or no rows for ``split``.
        """
        self.manifest_csv = manifest_csv
        self.split = split
        self.image_size = image_size

        df = pd.read_csv(manifest_csv)
        if "split" not in df.columns:
            raise ValueError(f"'split' column not found in {manifest_csv}")

        self.df = df[df["split"] == split].reset_index(drop=True)
        if len(self.df) == 0:
            raise ValueError(f"No rows for split='{split}' in {manifest_csv}")

        print(f"[{os.path.basename(manifest_csv)}] Loaded {len(self.df)} rows for split='{split}'")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sketch_path = row["sketch_path"]
        photo_path  = row["photo_path"]
        caption     = row.get("caption", "")
        # Empty CSV cells come back as NaN; avoid turning them into the caption "nan".
        if pd.isna(caption):
            caption = ""

        # `with` closes each file handle once pil_to_tensor has copied the pixels out.
        with Image.open(sketch_path) as sketch_img:
            sketch_tensor = pil_to_tensor(sketch_img, size=self.image_size, to_gray=True)
        with Image.open(photo_path) as photo_img:
            photo_tensor  = pil_to_tensor(photo_img,  size=self.image_size, to_gray=False)

        return {
            "sketch": sketch_tensor,
            "image":  photo_tensor,
            "caption": str(caption),
            "sketch_path": sketch_path,
            "photo_path":  photo_path,
        }

# Object-level train/val datasets (built in that order) and their loaders.
obj_train_ds, obj_val_ds = (
    Sketch2ImageDataset(OBJECT_SUBSET_CSV, split=which, image_size=IMAGE_SIZE)
    for which in ("train", "val")
)

BATCH_SIZE = 2
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
obj_train_dl = DataLoader(obj_train_ds, shuffle=True, **loader_kwargs)
obj_val_dl   = DataLoader(obj_val_ds, shuffle=False, **loader_kwargs)

batch = next(iter(obj_train_dl))
print("Train batch shapes:")
for key, label in (("sketch", "  sketch:"), ("image", "  image :")):
    print(label, batch[key].shape)
print("  caption[0]:", batch["caption"][0])


In [None]:
from peft import LoraConfig
import torch

LORA_RANK = 8

def add_lora_to_unet(unet, rank=LORA_RANK):
    """Attach a PEFT LoRA adapter of rank ``rank`` to ``unet`` (in place) and return it.

    The adapter targets modules named to_k, to_q, to_v and to_out.0, uses
    ``lora_alpha == rank``, and initialises the LoRA weights with "gaussian".
    """
    config = LoraConfig(
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        r=rank,
        lora_alpha=rank,
        init_lora_weights="gaussian",
    )
    unet.add_adapter(config)
    return unet

# 1) Add adapters
add_lora_to_unet(unet, rank=LORA_RANK)

# 2) Freeze the whole UNet, 3) then re-enable gradients only on LoRA weights.
unet.requires_grad_(False)
lora_params = [param for pname, param in unet.named_parameters() if "lora" in pname]
for param in lora_params:
    param.requires_grad_(True)

print("LoRA successfully added via PEFT.")

trainable_params = sum(param.numel() for param in lora_params)
print(f"Trainable LoRA params: {trainable_params:,}")


In [None]:
import torch.nn.functional as F

def encode_prompts(captions, tokenizer, device, dtype, text_encoder=None):
    """Tokenize ``captions`` and return frozen text-encoder hidden states.

    Args:
        captions: Iterable of caption strings (a DataLoader batch gives a list/tuple).
        tokenizer: Tokenizer called with padding/truncation to ``tokenizer.model_max_length``.
        device: Device the token ids and the returned embeddings are moved to.
        dtype: dtype of the returned embeddings.
        text_encoder: Encoder to run. Defaults to ``pipe.text_encoder`` (the global
            pipeline), which is what this helper always used.

    Returns:
        The first output (``[0]``) of the text encoder, moved to ``device``/``dtype``.
        It is computed under ``torch.no_grad()``, so no gradient reaches the encoder.
    """
    if text_encoder is None:
        text_encoder = pipe.text_encoder
    text_inputs = tokenizer(
        list(captions),
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(device)
    with torch.no_grad():
        text_embeds = text_encoder(input_ids)[0]
    return text_embeds.to(device=device, dtype=dtype)

def encode_images_to_latents(images: torch.Tensor, vae, device, dtype):
    """Encode images with ``vae`` into scaled latents, without gradients.

    ``images`` are moved to ``device``/``dtype`` first; the loaders in this notebook
    feed (B, 3, H, W) tensors in [-1, 1]. A latent is sampled from ``latent_dist`` and
    multiplied by the VAE's ``config.scaling_factor``. The previously hard-coded
    0.18215 is used when the VAE has no such config value.
    """
    images = images.to(device=device, dtype=dtype)
    scale = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample() * scale
    return latents


In [None]:
# DISABLED: earlier LoRA training loop, kept commented out for reference only.
# It trained with torch_dtype = vae.dtype (float16 on CUDA, per its own comment below).
# The active loops further down reload the pipeline in float32 and retrain from there.
# from torch.optim import AdamW
# from tqdm.auto import tqdm

# learning_rate = 1e-4
# optimizer = AdamW(lora_params, lr=learning_rate)

# scheduler = pipe.scheduler  # just use as-is; no set_format

# NUM_EPOCHS = 1
# MAX_TRAIN_STEPS = 2000
# global_step = 0

# unet.train()
# torch_dtype = vae.dtype  # should be float16 on cuda

# for epoch in range(NUM_EPOCHS):
#     pbar = tqdm(obj_train_dl, desc=f"Epoch {epoch+1}")
#     for batch in pbar:
#         if global_step >= MAX_TRAIN_STEPS:
#             break

#         images   = batch["image"]         # (B,3,H,W)
#         captions = batch["caption"]       # list[str]

#         # Encode
#         text_embeds = encode_prompts(captions, tokenizer, device, torch_dtype)
#         latents     = encode_images_to_latents(images, vae, device, torch_dtype)

#         # Noise & timesteps
#         bsz = latents.shape[0]
#         noise = torch.randn_like(latents)
#         timesteps = torch.randint(
#             0, scheduler.config.num_train_timesteps,
#             (bsz,),
#             device=device,
#             dtype=torch.long,
#         )
#         noisy_latents = scheduler.add_noise(latents, noise, timesteps)

#         # UNet forward
#         model_pred = unet(
#             noisy_latents,
#             timesteps,
#             encoder_hidden_states=text_embeds,
#         ).sample

#         # Loss
#         loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

#         optimizer.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
#         optimizer.step()

#         global_step += 1
#         pbar.set_postfix({"loss": loss.item(), "step": global_step})

#     if global_step >= MAX_TRAIN_STEPS:
#         print("Reached max train steps, stopping.")
#         break

# print("Training loop finished.")


In [None]:
from pathlib import Path

from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict

SAVE_DIR = f"{PROJ}/checkpoints/lora_object_unet"
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# Filtering unet.state_dict() on "lora" keeps PEFT-style key names, and checkpoints
# saved that way later failed to load ("Invalid LoRA checkpoint"). Instead, extract the
# adapter weights with PEFT and convert them to the diffusers LoRA key layout.
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

pipe.save_lora_weights(SAVE_DIR, unet_lora_layers=unet_lora_layers)
print("Saved LoRA weights to:", SAVE_DIR)

In [None]:
# Fallback: also serialize the whole UNet module with save_pretrained.
# NOTE(review): nothing here merges/fuses the LoRA adapters into the base weights, so
# "LoRA baked in" is not guaranteed. Confirm that reloading this directory actually
# restores the adapter before relying on it.
UNET_SAVE_DIR = f"{PROJ}/checkpoints/unet_with_lora"
os.makedirs(UNET_SAVE_DIR, exist_ok=True)
unet.save_pretrained(UNET_SAVE_DIR)
print("Saved full UNet (with LoRA) to:", UNET_SAVE_DIR)


In [None]:
import torch
from diffusers import StableDiffusionPipeline

import contextlib

from IPython.display import display

model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Same dtype rule as the first pipeline-loading cell: float16 only on GPU.
# Half-precision weights on CPU are generally not usable for inference.
infer_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=infer_dtype,
    safety_checker=None,
).to(device)

LORA_DIR = f"{PROJ}/checkpoints/lora_object_unet"

pipe.load_lora_weights(LORA_DIR)
pipe.enable_attention_slicing()

prompt = "a realistic photo of an everyday object on a clean background"
# CUDA autocast only applies on a GPU; on CPU run in plain float32.
amp_ctx = (
    torch.autocast(device_type="cuda", dtype=torch.float16)
    if device == "cuda"
    else contextlib.nullcontext()
)
with amp_ctx:
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

display(image)


In [None]:
import torch
from diffusers import StableDiffusionPipeline

import contextlib

from IPython.display import display

model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 only on GPU; fall back to float32 on CPU (same rule as the first pipeline cell).
infer_dtype = torch.float16 if device == "cuda" else torch.float32

# Plain SD 1.5 without any LoRA, as a baseline.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=infer_dtype,
    safety_checker=None,
).to(device)

pipe.enable_attention_slicing()

prompt = "a high quality photo of a red sneaker on a white table"
amp_ctx = (
    torch.autocast(device_type="cuda", dtype=torch.float16)
    if device == "cuda"
    else contextlib.nullcontext()
)
with amp_ctx:
    img_base = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

display(img_base)


In [None]:
import torch

def check_lora_for_nans(unet):
    """Report LoRA parameters of ``unet`` that contain NaN or Inf values.

    Prints one line per offending parameter, or a single all-clear line, and returns
    True if any LoRA parameter is non-finite, else False.
    """
    offenders = [
        pname
        for pname, tensor in unet.named_parameters()
        if "lora" in pname and not torch.isfinite(tensor).all()
    ]
    for pname in offenders:
        print("❌ NaNs/Infs in", pname)
    if not offenders:
        print("✅ No NaNs/Infs found in LoRA params.")
    return bool(offenders)

# Scan the current UNet's LoRA weights for NaN/Inf; the returned bool is the cell output.
check_lora_for_nans(unet)


In [None]:
import torch
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reload the pipeline in full precision: LoRA training below runs in float32.
pipe = (
    StableDiffusionPipeline
    .from_pretrained(model_id, torch_dtype=torch.float32, safety_checker=None)
    .to(device)
)

tokenizer, vae, unet, scheduler = pipe.tokenizer, pipe.vae, pipe.unet, pipe.scheduler

for enable in (pipe.enable_attention_slicing, pipe.enable_vae_slicing):
    enable()

print("Pipeline reloaded in float32.")


In [None]:
from peft import LoraConfig

LORA_RANK = 4  # keep small & stable to start

def add_lora_to_unet(unet, rank=LORA_RANK):
    """Add a PEFT LoRA adapter (r = lora_alpha = ``rank``) to ``unet`` in place and return it.

    Adapted modules: to_k, to_q, to_v, to_out.0; LoRA weights use "gaussian" init.
    """
    adapter_targets = ["to_k", "to_q", "to_v", "to_out.0"]
    unet.add_adapter(
        LoraConfig(
            r=rank,
            lora_alpha=rank,
            target_modules=adapter_targets,
            init_lora_weights="gaussian",
        )
    )
    return unet

# fresh adapters
add_lora_to_unet(unet, rank=LORA_RANK)

# Freeze base weights module-wide, then switch gradients back on for LoRA tensors only.
unet.requires_grad_(False)
lora_params = [p for pname, p in unet.named_parameters() if "lora" in pname]
for p in lora_params:
    p.requires_grad_(True)

print("LoRA params:", sum(p.numel() for p in lora_params))

# sanity check: there should be NO NaNs now
def check_lora_for_nans(unet):
    """Print every LoRA parameter of ``unet`` holding NaN/Inf; return True if any does."""
    found = False
    lora_items = ((pname, t) for pname, t in unet.named_parameters() if "lora" in pname)
    for param_name, tensor in lora_items:
        if torch.isfinite(tensor).all():
            continue
        print("❌ NaNs/Infs in", param_name)
        found = True
    if not found:
        print("No NaNs/Infs found in LoRA params.")
    return found

# Verify the freshly initialised LoRA weights are finite before training.
check_lora_for_nans(unet)


In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
import torch.nn.functional as F

learning_rate = 5e-5       # smaller LR for stability
optimizer = AdamW(lora_params, lr=learning_rate)

NUM_EPOCHS = 1
MAX_TRAIN_STEPS = 2000      # do a short run just to confirm it behaves
global_step = 0
# Set when the loss turns NaN/Inf. The guard's `break` only leaves the batch loop,
# so the epoch loop checks this flag to stop training entirely.
diverged = False

unet.train()

for epoch in range(NUM_EPOCHS):
    pbar = tqdm(obj_train_dl, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        if global_step >= MAX_TRAIN_STEPS:
            break

        images   = batch["image"]        # (B,3,H,W), float32 [-1,1]
        captions = batch["caption"]

        # text & latents stay in float32
        text_embeds = encode_prompts(captions, tokenizer, device, torch.float32)
        latents     = encode_images_to_latents(images, vae, device, torch.float32)

        # Sample noise and one random timestep per example, then noise the latents.
        bsz = latents.shape[0]
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps,
            (bsz,),
            device=device,
            dtype=torch.long,
        )
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
        ).sample

        # The UNet is trained (only LoRA params are in the optimizer) to predict the added noise.
        loss = F.mse_loss(model_pred, noise, reduction="mean")

        # NaN/Inf guard
        if not torch.isfinite(loss):
            print(f"❌ Loss became NaN/Inf at step {global_step}, stopping training.")
            diverged = True
            break

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
        optimizer.step()

        global_step += 1
        pbar.set_postfix({"loss": float(loss.item()), "step": global_step})

    if diverged:
        break
    if global_step >= MAX_TRAIN_STEPS:
        print("Reached max train steps, stopping.")
        break

print("Training loop finished; total steps:", global_step)


In [None]:
# Post-training check: LoRA weights should still be finite.
check_lora_for_nans(unet)


In [None]:
from pathlib import Path

from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict

SAVE_DIR = f"{PROJ}/checkpoints/lora_object_unet_v2"
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)

# Passing the UNet module itself as unet_lora_layers would serialize the module's
# state_dict (every UNet weight, not just the adapter). Extract only the PEFT LoRA
# weights and convert them to the diffusers LoRA key layout before saving.
unet_lora_state = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

pipe.save_lora_weights(
    SAVE_DIR,
    unet_lora_layers=unet_lora_state,
    text_encoder_lora_layers=None,  # we didn't train text encoder LoRA
)

print("✅ Saved LoRA weights to:", SAVE_DIR)


In [None]:
# Re-check LoRA weights after saving (saving should not modify them).
check_lora_for_nans(unet)


In [None]:
from pathlib import Path
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline

NEW_LORA_DIR = f"{PROJ}/checkpoints/lora_object_unet_v3"
Path(NEW_LORA_DIR).mkdir(parents=True, exist_ok=True)

# No accelerator wrapper is involved, so the UNet can be used directly.
unwrapped_unet = unet

# Adapter-only weights from PEFT, re-keyed to the diffusers LoRA format and written
# as safetensors.
peft_weights = get_peft_model_state_dict(unwrapped_unet)
lora_state_dict = convert_state_dict_to_diffusers(peft_weights)
StableDiffusionPipeline.save_lora_weights(
    save_directory=NEW_LORA_DIR,
    unet_lora_layers=lora_state_dict,
    safe_serialization=True,
)

print("✅ Correct LoRA saved to:", NEW_LORA_DIR)



In [None]:
import torch
from diffusers import StableDiffusionPipeline

import contextlib

from IPython.display import display

model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 only on GPU; float32 on CPU (same rule as the first pipeline-loading cell).
infer_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=infer_dtype,
    safety_checker=None,
).to(device)

LORA_DIR = f"{PROJ}/checkpoints/lora_object_unet_v3"

# this should now work without "Invalid LoRA checkpoint"
pipe.load_lora_weights(LORA_DIR)
pipe.enable_attention_slicing()

prompt = "a realistic photo of an everyday object on a clean background"
amp_ctx = (
    torch.autocast(device_type="cuda", dtype=torch.float16)
    if device == "cuda"
    else contextlib.nullcontext()
)
with amp_ctx:
    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]

display(image)


In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
import torch.nn.functional as F
from pathlib import Path
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline

# ==== hyperparams ====
learning_rate = 5e-5
MAX_TRAIN_STEPS = 2000   # you can change later
SAVE_EVERY = 500         # checkpoint interval in optimizer steps
global_step = 0

# Training runs in fp32, using the pipeline's own noise scheduler.
torch_dtype = torch.float32
scheduler = pipe.scheduler

# Fresh optimizer over the existing LoRA params (they keep their current values).
optimizer = AdamW(lora_params, lr=learning_rate)
unet.train()

def save_lora_checkpoint(step):
    """
    Write the current UNet's LoRA weights (diffusers format, safetensors) to
    ``{PROJ}/checkpoints/lora_object_steps/step_{step}``.
    """
    ckpt_dir = f"{PROJ}/checkpoints/lora_object_steps/step_{step}"
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    # Adapter-only PEFT weights, re-keyed to the diffusers LoRA layout.
    adapter_weights = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

    StableDiffusionPipeline.save_lora_weights(
        save_directory=ckpt_dir,
        unet_lora_layers=adapter_weights,
        safe_serialization=True,
    )

    print(f" Saved LoRA checkpoint at step {step} -> {ckpt_dir}")

# Epochs are unbounded in practice (9999 is just a ceiling); training ends at
# MAX_TRAIN_STEPS or when the loss diverges.
# A NaN/Inf loss only breaks the batch loop, so this flag also stops the epoch loop.
# Without it, each following epoch would re-hit NaN and the loop would spin on.
stop_training = False

for epoch in range(9999):   # we'll break by steps, so epoch number is not important
    pbar = tqdm(obj_train_dl, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        if global_step >= MAX_TRAIN_STEPS:
            break

        images   = batch["image"]        # (B,3,H,W)
        captions = batch["caption"]

        # 1) encode text + images (fp32)
        text_embeds = encode_prompts(captions, tokenizer, device, torch.float32)
        latents     = encode_images_to_latents(images, vae, device, torch.float32)

        # 2) noise & timesteps
        bsz   = latents.shape[0]
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps,
            (bsz,),
            device=device,
            dtype=torch.long,
        )
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # 3) forward UNet
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
        ).sample

        # 4) noise prediction loss
        loss = F.mse_loss(model_pred, noise, reduction="mean")

        # NaN/Inf guard
        if not torch.isfinite(loss):
            print(f"Loss became NaN/Inf at step {global_step}, stopping training.")
            stop_training = True
            break

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
        optimizer.step()

        global_step += 1
        pbar.set_postfix({"loss": float(loss.item()), "step": global_step})

        # periodic checkpoint
        if global_step % SAVE_EVERY == 0:
            save_lora_checkpoint(global_step)

    if stop_training:
        break
    if global_step >= MAX_TRAIN_STEPS:
        print("Reached max train steps, stopping.")
        break

print("Training loop finished; total steps:", global_step)


In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image

import contextlib

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"
# float16 on GPU; float32 on CPU, where half-precision inference is generally not usable.
PIPE_DTYPE = torch.float16 if device == "cuda" else torch.float32
# Both pipelines start from the same seed, so the two images differ only because of the LoRA.
COMPARE_SEED = 1337

# helper to load base pipe
def load_base_pipe(dtype=torch.float16):
    """Load SD 1.5 (no safety checker) in ``dtype`` on ``device`` with attention slicing."""
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)
    pipe.enable_attention_slicing()
    return pipe

def _generate(pipeline, text):
    """Run ``pipeline`` on ``text`` with the shared settings and seed; return the first image."""
    generator = torch.Generator(device=device).manual_seed(COMPARE_SEED)
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.float16)
        if device == "cuda"
        else contextlib.nullcontext()
    )
    with amp_ctx:
        return pipeline(text, num_inference_steps=30, guidance_scale=7.5, generator=generator).images[0]

# choose one of your checkpoints
CKPT_STEP = 1000   # or 500, 1500, etc.
LORA_DIR  = f"{PROJ}/checkpoints/lora_object_steps/step_{CKPT_STEP}"

prompt = "a realistic photo of an everyday object on a clean background"

# 1) base model
pipe_base = load_base_pipe(PIPE_DTYPE)
img_base = _generate(pipe_base, prompt)

# 2) LoRA model
pipe_lora = load_base_pipe(PIPE_DTYPE)
pipe_lora.load_lora_weights(LORA_DIR)
img_lora = _generate(pipe_lora, prompt)

# side-by-side
w, h = img_base.size
combined = Image.new("RGB", (2 * w, h))
combined.paste(img_base, (0, 0))
combined.paste(img_lora, (w, 0))

print("Left: base SD | Right: LoRA at step", CKPT_STEP)
display(combined)


In [None]:
from pathlib import Path
import pandas as pd

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"   # you already use this

SCENE_MANIFEST = f"{PROJ}/data/sketchycoco/scene_manifest.csv"   # 👈 set to your real scene manifest
scene_df = pd.read_csv(SCENE_MANIFEST)
print(scene_df.head())
print("Total rows in scene manifest:", len(scene_df))

# Make a smaller subset for Colab training
# (change sizes if you want bigger)
MAX_TRAIN = 4000
MAX_VAL   = 500
SCENE_SEED = 42

def _sample_split(df, split, max_rows, seed=SCENE_SEED):
    """Return up to ``max_rows`` randomly chosen rows of ``df`` whose "split" equals ``split``."""
    rows = df[df["split"] == split]
    return rows.sample(n=min(max_rows, len(rows)), random_state=seed)

scene_train = _sample_split(scene_df, "train", MAX_TRAIN)
scene_val   = _sample_split(scene_df, "val", MAX_VAL)

scene_subset = pd.concat([scene_train, scene_val]).reset_index(drop=True)

# NOTE(review): this overwrites the scene_subset.csv produced by make_subset
# (seed 1337, val target 400) with a different seed and val size — confirm that is intended.
SCENE_SUBSET_CSV = f"{PROJ}/subsets/scene_subset.csv"
Path(f"{PROJ}/subsets").mkdir(parents=True, exist_ok=True)
scene_subset.to_csv(SCENE_SUBSET_CSV, index=False)

print("Saved scene_subset.csv at:", SCENE_SUBSET_CSV)
print("Subset sizes -> train:", len(scene_train), " val:", len(scene_val))

In [None]:
from torch.utils.data import DataLoader

IMAGE_SIZE = 512  # same as before
BATCH_SIZE = 2

# Scene train/val datasets (built train first, then val) and their loaders;
# only the train loader shuffles.
scene_train_ds, scene_val_ds = (
    Sketch2ImageDataset(SCENE_SUBSET_CSV, split=which, image_size=IMAGE_SIZE)
    for which in ("train", "val")
)
scene_train_dl, scene_val_dl = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2)
    for ds, shuffle in ((scene_train_ds, True), (scene_val_ds, False))
)

batch = next(iter(scene_train_dl))
print("Scene train batch:")
for key, label in (("sketch", "  sketch:"), ("image", "  image :")):
    print(label, batch[key].shape)
print("  caption[0]:", batch["caption"][0])


In [None]:
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

import gc

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Release the object-level pipelines before loading another full float32 copy of
# SD 1.5. Otherwise up to three earlier pipelines (pipe, pipe_base, pipe_lora) still
# hold GPU memory while this one loads. The later cleanup cell drops the same names.
for _stale_name in ("pipe", "unet", "vae", "scheduler", "pipe_base", "pipe_lora"):
    globals().pop(_stale_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

pipe_scene = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,      # train in fp32
    safety_checker=None,
).to(device)

tokenizer_scene = pipe_scene.tokenizer
vae_scene       = pipe_scene.vae
unet_scene      = pipe_scene.unet
scheduler_scene = pipe_scene.scheduler

pipe_scene.enable_attention_slicing()
pipe_scene.enable_vae_slicing()

LORA_RANK_SCENE = 4

def add_lora_to_unet(unet, rank=LORA_RANK_SCENE):
    """Attach a LoRA adapter to ``unet`` in place and hand the same object back.

    The adapter wraps the attention projections ``to_k``/``to_q``/``to_v``/``to_out.0``
    with rank ``rank`` and Gaussian-initialised weights. ``lora_alpha`` is set equal
    to the rank.
    """
    adapter_cfg = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        init_lora_weights="gaussian",
    )
    unet.add_adapter(adapter_cfg)
    return unet

# Attach the scene adapter; unet_scene is modified in place (the return value is unused).
add_lora_to_unet(unet_scene, rank=LORA_RANK_SCENE)

# freeze base weights
for p in unet_scene.parameters():
    p.requires_grad_(False)

# collect LoRA params
# (re-enable gradients only for parameters whose name contains "lora"; this list is
#  what the optimizer is later built from)
lora_params_scene = []
for name, p in unet_scene.named_parameters():
    if "lora" in name:
        p.requires_grad_(True)
        lora_params_scene.append(p)

print("Scene LoRA params:", sum(p.numel() for p in lora_params_scene))


In [None]:
import torch.nn.functional as F

def encode_prompts_scene(captions, tokenizer, device, dtype, text_encoder=None):
    """Tokenize captions and run them through the (frozen) text encoder.

    Args:
        captions: a sequence of caption strings (e.g. ``batch["caption"]`` from the
            DataLoader). A single string is treated as one caption rather than being
            split into characters by ``list()``.
        tokenizer: tokenizer exposing ``model_max_length``. Captions are padded and
            truncated to exactly that length.
        device: device the token ids and the returned embeddings are placed on.
        dtype: dtype of the returned embeddings.
        text_encoder: encoder to run. Defaults to ``pipe_scene.text_encoder`` so
            existing four-argument calls behave as before.

    Returns:
        The first element of the encoder output, moved to ``device``/``dtype``.
    """
    if text_encoder is None:
        # Backward-compatible default: the text encoder of the global scene pipeline.
        text_encoder = pipe_scene.text_encoder
    if isinstance(captions, str):
        captions = [captions]
    text_inputs = tokenizer(
        list(captions),
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(device)
    with torch.no_grad():  # the text encoder is not trained
        text_embeds = text_encoder(input_ids)[0]
    return text_embeds.to(device=device, dtype=dtype)

def encode_images_to_latents_scene(images: torch.Tensor, vae, device, dtype, scaling_factor=None):
    """Encode images into scaled VAE latents, without tracking gradients.

    Args:
        images: image batch tensor passed straight to ``vae.encode`` after being moved
            to ``device``/``dtype``.
        vae: autoencoder whose ``encode(x).latent_dist.sample()`` yields latents.
        device: target device.
        dtype: target dtype.
        scaling_factor: multiplier applied to the sampled latents. When None, use
            ``vae.config.scaling_factor`` if the VAE exposes it. Otherwise fall back to
            0.18215, the constant previously hard-coded here.

    Returns:
        ``vae.encode(images).latent_dist.sample() * scaling_factor``.
    """
    if scaling_factor is None:
        scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    images = images.to(device=device, dtype=dtype)
    with torch.no_grad():  # the VAE is frozen during LoRA training
        latents = vae.encode(images).latent_dist.sample() * scaling_factor
    return latents


In [None]:
import gc, torch

# Try to delete any big old pipelines/models if they exist
# (removes these exact names from the notebook namespace so their memory can be
#  reclaimed once nothing else references them; pipe_scene / unet_scene are kept)
for name in [
    "pipe", "unet", "vae", "scheduler",
    "pipe_base", "pipe_lora",
    "obj_train_dl", "obj_val_dl",
]:
    if name in globals():
        del globals()[name]

gc.collect()
if torch.cuda.is_available():
    # Release cached-but-unused blocks held by PyTorch's CUDA allocator.
    torch.cuda.empty_cache()

print(" Cleared old references and emptied CUDA cache.")
if torch.cuda.is_available():
    print(torch.cuda.mem_get_info())


In [None]:
# ---- create scene loaders ----
# Rebuilds the scene datasets/loaders with a smaller batch, replacing the
# BATCH_SIZE=2 loaders created earlier.
scene_train_ds = Sketch2ImageDataset(SCENE_SUBSET_CSV, split="train", image_size=IMAGE_SIZE)
scene_val_ds   = Sketch2ImageDataset(SCENE_SUBSET_CSV, split="val",   image_size=IMAGE_SIZE)

BATCH_SIZE = 1   # 👈 reduce from 2 to 1 to save memory

scene_train_dl = DataLoader(scene_train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2)
scene_val_dl   = DataLoader(scene_val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Sanity check: one batch's shapes and first caption.
batch = next(iter(scene_train_dl))
print("Scene train batch:")
print("  sketch:", batch["sketch"].shape)
print("  image :", batch["image"].shape)
print("  caption[0]:", batch["caption"][0])


In [None]:
# add_lora_to_unet(unet_scene, rank=LORA_RANK_SCENE) # Removed: LoRA is already attached to unet_scene in the model-loading cell above, so it is not added a second time here.

# 👇 Add this line to save memory
# (gradient checkpointing trades extra recomputation in the backward pass for
#  lower activation memory)
unet_scene.enable_gradient_checkpointing()

# freeze base weights
for p in unet_scene.parameters():
    p.requires_grad_(False)

# collect only LoRA params
# (rebuilds lora_params_scene; parameters whose name contains "lora" are trainable)
lora_params_scene = []
for name, p in unet_scene.named_parameters():
    if "lora" in name:
        p.requires_grad_(True)
        lora_params_scene.append(p)

print("Scene LoRA trainable params:", sum(p.numel() for p in lora_params_scene))

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
from pathlib import Path
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline

# Scene LoRA training hyper-parameters.
learning_rate_scene   = 5e-5
MAX_TRAIN_STEPS_SCENE = 2000
SAVE_EVERY_SCENE      = 500   # write a checkpoint every N optimizer steps

# Re-running this cell resets the step counter and optimizer state; the LoRA
# weights themselves live in unet_scene and are not reset here.
global_step_scene = 0
optimizer_scene   = AdamW(lora_params_scene, lr=learning_rate_scene)

unet_scene.train()
torch_dtype_scene = torch.float32

def save_lora_checkpoint_scene(step):
    """Write the scene UNet's LoRA weights to ``{PROJ}/checkpoints/lora_scene_steps/step_<step>``.

    The peft adapter state dict is converted to diffusers key names and saved with
    ``StableDiffusionPipeline.save_lora_weights`` as safetensors.
    """
    ckpt_dir = f"{PROJ}/checkpoints/lora_scene_steps/step_{step}"
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_scene))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=ckpt_dir,
        unet_lora_layers=unet_lora_layers,
        safe_serialization=True,
    )
    print(f"💾 Saved SCENE LoRA checkpoint at step {step} -> {ckpt_dir}")

# Set when the loss diverges. A plain `break` only leaves the inner (batch) loop,
# so this flag is what actually stops the outer epoch loop.
stop_training_scene = False

for epoch in range(9999):   # we'll break by step count
    pbar = tqdm(scene_train_dl, desc=f"[SCENE] Epoch {epoch+1}")
    for batch in pbar:
        if global_step_scene >= MAX_TRAIN_STEPS_SCENE:
            break

        images   = batch["image"]
        captions = batch["caption"]

        # encode text + images (the helpers run both encoders under no_grad)
        text_embeds = encode_prompts_scene(captions, tokenizer_scene, device, torch_dtype_scene)
        latents     = encode_images_to_latents_scene(images, vae_scene, device, torch_dtype_scene)

        # Forward diffusion: one random timestep per sample, noise the latents to it.
        bsz   = latents.shape[0]
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler_scene.config.num_train_timesteps,
            (bsz,),
            device=device,
            dtype=torch.long,
        )
        noisy_latents = scheduler_scene.add_noise(latents, noise, timesteps)

        # UNet forward
        model_pred = unet_scene(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
        ).sample

        # The UNet is trained to predict the noise that was added.
        loss = F.mse_loss(model_pred, noise, reduction="mean")

        if not torch.isfinite(loss):
            print(f"Loss became NaN/Inf at step {global_step_scene}, stopping.")
            stop_training_scene = True  # also stop the epoch loop (see above)
            break

        optimizer_scene.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lora_params_scene, 1.0)
        optimizer_scene.step()

        global_step_scene += 1
        pbar.set_postfix({"loss": float(loss.item()), "step": global_step_scene})

        if global_step_scene % SAVE_EVERY_SCENE == 0:
            save_lora_checkpoint_scene(global_step_scene)
    pbar.close()  # a `break` leaves the bar open otherwise

    if stop_training_scene:
        break
    if global_step_scene >= MAX_TRAIN_STEPS_SCENE:
        print(" Reached max SCENE train steps, stopping.")
        break

print("SCENE training loop finished; total steps:", global_step_scene)


In [None]:
import os
from pathlib import Path

CKPT_ROOT = f"{PROJ}/checkpoints/lora_scene_steps"
print("Looking in:", CKPT_ROOT)


def _step_sort_key(folder_name):
    """Order 'step_<N>' folders by N numerically; any other names go last, alphabetically.

    Plain string sorting would list 'step_500' after 'step_2000'.
    """
    prefix, _, num = folder_name.partition("_")
    if prefix == "step" and num.isdigit():
        return (0, int(num), folder_name)
    return (1, 0, folder_name)


if not os.path.isdir(CKPT_ROOT):
    print("⚠️ No scene checkpoints folder found.")
else:
    for name in sorted(os.listdir(CKPT_ROOT), key=_step_sort_key):
        full = os.path.join(CKPT_ROOT, name)
        if os.path.isdir(full):
            print(" -", name)


In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image

# Inference device and base model shared by the comparison cells below.
device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

def load_base_pipe(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype``, for inference only.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

#  set this to one of the step folders you saw above
CKPT_STEP_SCENE = 1000      # e.g. 500, 1000, 1500, 2000
LORA_SCENE_DIR  = f"{PROJ}/checkpoints/lora_scene_steps/step_{CKPT_STEP_SCENE}"
# Fail before loading two pipelines if the chosen checkpoint does not exist.
if not os.path.isdir(LORA_SCENE_DIR):
    raise FileNotFoundError(f"No such LoRA dir: {LORA_SCENE_DIR}")

prompt = "a busy city street scene with cars and pedestrians during daytime"

# Both renders use the same seed, so they start from identical initial noise. Without
# this, the side-by-side mostly shows sampling randomness rather than the LoRA.
SEED_COMPARE = 42

# 1) base model
pipe_base = load_base_pipe(torch.float16)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    img_base = pipe_base(
        prompt, num_inference_steps=30, guidance_scale=7.5,
        generator=torch.Generator(device=device).manual_seed(SEED_COMPARE),
    ).images[0]

# 2) model with scene LoRA
pipe_scene_eval = load_base_pipe(torch.float16)
pipe_scene_eval.load_lora_weights(LORA_SCENE_DIR)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    img_lora = pipe_scene_eval(
        prompt, num_inference_steps=30, guidance_scale=7.5,
        generator=torch.Generator(device=device).manual_seed(SEED_COMPARE),
    ).images[0]

# side-by-side
w, h = img_base.size
combined = Image.new("RGB", (2 * w, h))
combined.paste(img_base, (0, 0))
combined.paste(img_lora, (w, 0))

print(f"Left: base SD | Right: SCENE LoRA @ step {CKPT_STEP_SCENE}")
display(combined)


In [None]:
# Save the side-by-side image (`combined`) built in the previous cell, named after
# the checkpoint step that was compared.
OUT_DIR = f"{PROJ}/results/scene_lora"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

out_path = f"{OUT_DIR}/base_vs_lora_step{CKPT_STEP_SCENE}.png"
combined.save(out_path)
print("✅ Saved comparison image to:", out_path)


In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image
from pathlib import Path

device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Output folder for the scene base-vs-LoRA comparison PNGs.
RESULTS_DIR = f"{PROJ}/results/scene_lora"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

def load_base_pipe(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype``, for inference only.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

def compare_and_save_scene(prompt, step, filename_suffix, seed=42):
    """Render ``prompt`` with base SD and with the scene LoRA, then save and show them side by side.

    Args:
        prompt: text prompt used for both images.
        step: training step whose checkpoint folder
            ``{PROJ}/checkpoints/lora_scene_steps/step_<step>`` is loaded.
        filename_suffix: appended to the PNG file name written to ``RESULTS_DIR``.
        seed: seed for the initial noise. The same seed is used for both images, so
            only the LoRA differs between them. ``None`` gives the previous unseeded
            behaviour.

    Raises:
        AssertionError: if the checkpoint folder does not exist.
    """
    lora_dir = f"{PROJ}/checkpoints/lora_scene_steps/step_{step}"
    assert os.path.isdir(lora_dir), f"No such LoRA dir: {lora_dir}"

    def _generator():
        # A fresh generator per render, so both images consume identical noise.
        return None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # A single pipeline serves both renders: base image first, then the LoRA is attached
    # to the same pipeline (avoids loading SD twice per comparison).
    pipe = load_base_pipe(torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_base = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    pipe.load_lora_weights(lora_dir)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_lora = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    w, h = img_base.size
    combined = Image.new("RGB", (2 * w, h))
    combined.paste(img_base, (0, 0))
    combined.paste(img_lora, (w, 0))

    out_path = f"{RESULTS_DIR}/scene_base_vs_lora_step{step}_{filename_suffix}.png"
    combined.save(out_path)
    print(f"Saved: {out_path}")
    display(combined)

# generate a few nice comparisons
# (each call raises AssertionError if the chosen step folder does not exist)
CKPT_STEP_SCENE = 1000   # or 1500 / 2000 if you have them

compare_and_save_scene("a busy city street scene with cars and pedestrians during daytime",
                       CKPT_STEP_SCENE, "city_day")
compare_and_save_scene("a crowded outdoor market with colorful stalls and people shopping",
                       CKPT_STEP_SCENE, "market")
compare_and_save_scene("a park scene with trees, benches, and people walking dogs",
                       CKPT_STEP_SCENE, "park")


In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image
from pathlib import Path

device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Output folder for the object base-vs-LoRA comparison PNGs.
RESULTS_OBJ_DIR = f"{PROJ}/results/object_lora"
Path(RESULTS_OBJ_DIR).mkdir(parents=True, exist_ok=True)

def load_base_pipe_obj(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype`` for the object comparisons.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

def compare_and_save_object(prompt, step, filename_suffix, seed=42):
    """Render ``prompt`` with base SD and with the object LoRA, then save and show them side by side.

    Args:
        prompt: text prompt used for both images.
        step: training step whose checkpoint folder
            ``{PROJ}/checkpoints/lora_object_steps/step_<step>`` is loaded.
        filename_suffix: appended to the PNG file name written to ``RESULTS_OBJ_DIR``.
        seed: seed for the initial noise. The same seed is used for both images, so
            only the LoRA differs between them. ``None`` gives the previous unseeded
            behaviour.

    Raises:
        AssertionError: if the checkpoint folder does not exist.
    """
    lora_dir = f"{PROJ}/checkpoints/lora_object_steps/step_{step}"
    assert os.path.isdir(lora_dir), f"No such LoRA dir: {lora_dir}"

    def _generator():
        # A fresh generator per render, so both images consume identical noise.
        return None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # A single pipeline serves both renders: base image first, then the LoRA is attached.
    pipe = load_base_pipe_obj(torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_base = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    pipe.load_lora_weights(lora_dir)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_lora = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    w, h = img_base.size
    combined = Image.new("RGB", (2 * w, h))
    combined.paste(img_base, (0, 0))
    combined.paste(img_lora, (w, 0))

    out_path = f"{RESULTS_OBJ_DIR}/object_base_vs_lora_step{step}_{filename_suffix}.png"
    combined.save(out_path)
    print(f"Saved: {out_path}")
    display(combined)

# set to one of the steps you actually have (e.g., 500 / 1000 / 1500)
# (each call raises AssertionError if that step folder does not exist)
CKPT_STEP_OBJ = 1000

compare_and_save_object("a realistic photo of an everyday object on a plain background",
                        CKPT_STEP_OBJ, "generic_object")
compare_and_save_object("a red mug on a wooden table",
                        CKPT_STEP_OBJ, "red_mug")
compare_and_save_object("a pair of sneakers on a white background",
                        CKPT_STEP_OBJ, "sneakers")


In [None]:
# A few SCENE LoRA comparisons

In [None]:
import torch, os
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image
from pathlib import Path

device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Output folder for the scene base-vs-LoRA comparison PNGs.
RESULTS_SCENE_DIR = f"{PROJ}/results/scene_lora"
Path(RESULTS_SCENE_DIR).mkdir(parents=True, exist_ok=True)

def load_base_pipe(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype``, for inference only.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

def compare_and_save_scene(prompt, step, filename_suffix, seed=42):
    """Render ``prompt`` with base SD and with the scene LoRA, then save and show them side by side.

    Args:
        prompt: text prompt used for both images.
        step: training step whose checkpoint folder
            ``{PROJ}/checkpoints/lora_scene_steps/step_<step>`` is loaded.
        filename_suffix: appended to the PNG file name written to ``RESULTS_SCENE_DIR``.
        seed: seed for the initial noise. The same seed is used for both images, so
            only the LoRA differs between them. ``None`` gives the previous unseeded
            behaviour.

    Raises:
        AssertionError: if the checkpoint folder does not exist.
    """
    lora_dir = f"{PROJ}/checkpoints/lora_scene_steps/step_{step}"
    assert os.path.isdir(lora_dir), f"No such LoRA dir: {lora_dir}"

    def _generator():
        # A fresh generator per render, so both images consume identical noise.
        return None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # A single pipeline serves both renders: base image first, then the LoRA is attached.
    pipe = load_base_pipe(torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_base = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    pipe.load_lora_weights(lora_dir)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_lora = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    w, h = img_base.size
    combined = Image.new("RGB", (2 * w, h))
    combined.paste(img_base, (0, 0))
    combined.paste(img_lora, (w, 0))

    out_path = f"{RESULTS_SCENE_DIR}/scene_base_vs_lora_step{step}_{filename_suffix}.png"
    combined.save(out_path)
    print(f"Saved: {out_path}")
    display(combined)

CKPT_STEP_SCENE = 1000   # use the step that looked good

# Each call writes <RESULTS_SCENE_DIR>/scene_base_vs_lora_step<step>_<suffix>.png.
compare_and_save_scene("a busy city street scene with cars and pedestrians during daytime",
                       CKPT_STEP_SCENE, "city_day")
compare_and_save_scene("a crowded outdoor market with colorful stalls and people shopping",
                       CKPT_STEP_SCENE, "market")
compare_and_save_scene("a park with trees, benches, and people walking dogs",
                       CKPT_STEP_SCENE, "park")


In [None]:
#OBJECT LoRA

In [None]:
import torch, os
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image
from pathlib import Path

device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Output folder for the object base-vs-LoRA comparison PNGs.
RESULTS_OBJ_DIR = f"{PROJ}/results/object_lora"
Path(RESULTS_OBJ_DIR).mkdir(parents=True, exist_ok=True)

def load_base_pipe_obj(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype`` for the object comparisons.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

def compare_and_save_object(prompt, step, filename_suffix, seed=42):
    """Render ``prompt`` with base SD and with the object LoRA, then save and show them side by side.

    Args:
        prompt: text prompt used for both images.
        step: training step whose checkpoint folder
            ``{PROJ}/checkpoints/lora_object_steps/step_<step>`` is loaded.
        filename_suffix: appended to the PNG file name written to ``RESULTS_OBJ_DIR``.
        seed: seed for the initial noise. The same seed is used for both images, so
            only the LoRA differs between them. ``None`` gives the previous unseeded
            behaviour.

    Raises:
        AssertionError: if the checkpoint folder does not exist.
    """
    lora_dir = f"{PROJ}/checkpoints/lora_object_steps/step_{step}"
    assert os.path.isdir(lora_dir), f"No such LoRA dir: {lora_dir}"

    def _generator():
        # A fresh generator per render, so both images consume identical noise.
        return None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # A single pipeline serves both renders: base image first, then the LoRA is attached.
    pipe = load_base_pipe_obj(torch.float16)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_base = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    pipe.load_lora_weights(lora_dir)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        img_lora = pipe(prompt, num_inference_steps=30, guidance_scale=7.5,
                        generator=_generator()).images[0]

    w, h = img_base.size
    combined = Image.new("RGB", (2 * w, h))
    combined.paste(img_base, (0, 0))
    combined.paste(img_lora, (w, 0))

    out_path = f"{RESULTS_OBJ_DIR}/object_base_vs_lora_step{step}_{filename_suffix}.png"
    combined.save(out_path)
    print(f"Saved: {out_path}")
    display(combined)

CKPT_STEP_OBJ = 1000   # use one of your object steps

# Each call writes <RESULTS_OBJ_DIR>/object_base_vs_lora_step<step>_<suffix>.png.
compare_and_save_object("a realistic photo of an everyday object on a plain background",
                        CKPT_STEP_OBJ, "generic_object")
compare_and_save_object("a red mug on a wooden table",
                        CKPT_STEP_OBJ, "red_mug")
compare_and_save_object("a pair of sneakers on a white background",
                        CKPT_STEP_OBJ, "sneakers")


In [None]:
#fscoco

In [None]:
from google.colab import drive
# Re-mount Drive and re-declare PROJ so the FSCOCO section can run on its own.
drive.mount('/content/drive', force_remount=True)

import os
from pathlib import Path

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
Path(PROJ).mkdir(parents=True, exist_ok=True)
print("PROJ:", PROJ)

# It should contain subfolders: images, raster_sketches, text, etc.
FS_ROOT = "/content/drive/MyDrive/Gen AI/Gen AI project/fscoco/fscoco"

print("FS_ROOT:", FS_ROOT)
print("Exists?", os.path.isdir(FS_ROOT))
# os.listdir raises if FS_ROOT is missing; the "Exists?" line above shows why.
print("Subfolders:", os.listdir(FS_ROOT))


In [None]:
import glob
import random
import pandas as pd
from pathlib import Path

# Files are paired across these folders by (first sub-folder, file stem); see key_from_path.
images_root   = os.path.join(FS_ROOT, "images")
sketch_root   = os.path.join(FS_ROOT, "raster_sketches")
text_root     = os.path.join(FS_ROOT, "text")  # may or may not exist

print("images_root:", images_root, "exists?", os.path.isdir(images_root))
print("sketch_root:", sketch_root, "exists?", os.path.isdir(sketch_root))
print("text_root  :", text_root,   "exists?", os.path.isdir(text_root))

def g(pattern):
    """Recursive glob (``**`` spans directories); returns the matching paths sorted."""
    matches = glob.glob(pattern, recursive=True)
    matches.sort()
    return matches

# 1) collect images
image_files = g(f"{images_root}/**/*.*")
print("Found images:", len(image_files))

# 2) collect sketches
sketch_files = g(f"{sketch_root}/**/*.*")
print("Found raster sketches:", len(sketch_files))

# map (folder_id, stem) -> path
def key_from_path(path, root):
    """Pairing key for a dataset file: ``(first sub-folder under root, file stem)``.

    Returns None for files that sit directly in ``root`` (no sub-folder to key on).
    """
    pieces = os.path.relpath(path, root).split(os.sep)
    if len(pieces) >= 2:
        return pieces[0], Path(pieces[-1]).stem
    return None

# Index files by (folder_id, stem). If several files share a key, the last one in
# sorted order wins; files directly under the root (key None) are skipped.
img_by_key   = {}
for p in image_files:
    k = key_from_path(p, images_root)
    if k is not None:
        img_by_key[k] = p

sk_by_key   = {}
for p in sketch_files:
    k = key_from_path(p, sketch_root)
    if k is not None:
        sk_by_key[k] = p

# 3) optional captions
# (a caption file's whole stripped contents become the caption; empty files are ignored)
cap_by_key = {}
if os.path.isdir(text_root):
    text_files = g(f"{text_root}/**/*.*")
    print("Found text files:", len(text_files))
    for p in text_files:
        k = key_from_path(p, text_root)
        if k is None:
            continue
        try:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                text = f.read().strip()
            if text:
                cap_by_key[k] = text
        except Exception as e:
            # Best effort: an unreadable caption file is reported and skipped.
            print("Could not read caption from", p, "->", e)
else:
    print("No text_root folder; will use synthetic captions.")

# 4) build paired rows
# Only sketches that have a matching photo key are kept; sketches without a caption
# get a fixed generic caption.
rows = []
for key, sk_path in sk_by_key.items():
    if key not in img_by_key:
        continue
    img_path = img_by_key[key]
    cap = cap_by_key.get(key, "a realistic photo corresponding to a sketch")
    rows.append([sk_path, img_path, cap])

print("Total FSCOCO paired rows:", len(rows))

fscoco_df = pd.DataFrame(rows, columns=["sketch_path", "photo_path", "caption"])

# 5) random split 90/10 train/val
# (seeded shuffle of row indices; the first 90% of the shuffled order become "train")
random.seed(42)
indices = list(range(len(fscoco_df)))
random.shuffle(indices)

split_idx = int(0.9 * len(indices))
train_idx = set(indices[:split_idx])

splits = ["train" if i in train_idx else "val" for i in range(len(fscoco_df))]
fscoco_df["split"] = splits

MANIFEST_DIR = f"{PROJ}/manifests"
Path(MANIFEST_DIR).mkdir(parents=True, exist_ok=True)

FS_MANIFEST = f"{MANIFEST_DIR}/fscoco_manifest.csv"
fscoco_df.to_csv(FS_MANIFEST, index=False)
print("✅ Saved full FSCOCO manifest to:", FS_MANIFEST)
print("Train rows:", (fscoco_df["split"]=="train").sum(),
      "Val rows:", (fscoco_df["split"]=="val").sum())

# 6) Optional smaller subset for Colab
# (min(...) keeps .sample() from raising when a split is smaller than the cap)
MAX_TRAIN_FS = 4000
MAX_VAL_FS   = 500

train_fs = fscoco_df[fscoco_df["split"]=="train"].sample(
    n=min(MAX_TRAIN_FS, (fscoco_df["split"]=="train").sum()),
    random_state=42
)
val_fs = fscoco_df[fscoco_df["split"]=="val"].sample(
    n=min(MAX_VAL_FS, (fscoco_df["split"]=="val").sum()),
    random_state=42
)

fs_subset = pd.concat([train_fs, val_fs]).reset_index(drop=True)

SUBSETS_DIR = f"{PROJ}/subsets"
Path(SUBSETS_DIR).mkdir(parents=True, exist_ok=True)

FS_SUBSET_CSV = f"{SUBSETS_DIR}/fscoco_subset.csv"
fs_subset.to_csv(FS_SUBSET_CSV, index=False)

print("✅ Saved FSCOCO subset to:", FS_SUBSET_CSV)
print("Subset sizes -> train:", len(train_fs), " val:", len(val_fs))


In [None]:
#Dataset + dataloader for FSCOCO

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torch

# Side length for sketches/photos; captured as the default of pil_to_tensor and
# Sketch2ImageDataset when they are defined below.
IMAGE_SIZE = 512

def pil_to_tensor(img: Image.Image, size=IMAGE_SIZE, to_gray=False):
    """Turn a PIL image into a float32 (C, H, W) tensor scaled to [-1, 1].

    The image is converted to single-channel "L" when ``to_gray`` is true, otherwise
    to "RGB". It is then resized to ``size`` x ``size`` (bicubic), and pixel values
    are mapped from [0, 255] to [-1, 1].
    """
    mode = "L" if to_gray else "RGB"
    resized = img.convert(mode).resize((size, size), Image.BICUBIC)
    pixels = np.array(resized).astype("float32") / 255.0

    if pixels.ndim == 2:
        # A 2-D array has no channel axis: add one for gray, replicate to 3 for colour.
        pixels = pixels[..., None] if to_gray else np.stack([pixels] * 3, axis=-1)

    chw = pixels.transpose(2, 0, 1)
    return torch.from_numpy(chw * 2.0 - 1.0)

class Sketch2ImageDataset(Dataset):
    """Paired (sketch, photo, caption) samples read from a manifest CSV.

    The CSV must have a ``split`` column (ValueError otherwise) plus ``sketch_path``
    and ``photo_path``; ``caption`` is optional. Only rows whose ``split`` equals
    ``split`` are kept, and ValueError is raised if none match.

    Each item is a dict with:
        "sketch": 1-channel tensor from ``pil_to_tensor(..., to_gray=True)``;
        "image": 3-channel tensor from ``pil_to_tensor(..., to_gray=False)``;
        "caption": caption string ("" when missing);
        "sketch_path" / "photo_path": the source file paths.
    """

    def __init__(self, manifest_csv, split="train", image_size=IMAGE_SIZE):
        self.manifest_csv = manifest_csv
        self.split = split
        self.image_size = image_size

        df = pd.read_csv(manifest_csv)
        if "split" not in df.columns:
            raise ValueError(f"'split' column not found in {manifest_csv}")

        self.df = df[df["split"] == split].reset_index(drop=True)
        if len(self.df) == 0:
            raise ValueError(f"No rows for split='{split}' in {manifest_csv}")

        print(f"[{os.path.basename(manifest_csv)}] Loaded {len(self.df)} rows for split='{split}'")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sketch_path = row["sketch_path"]
        photo_path  = row["photo_path"]

        # Empty caption cells come back from read_csv as NaN. str(NaN) would turn them
        # into the literal prompt "nan", so missing captions map to "" instead.
        caption = row.get("caption", "")
        caption = "" if pd.isna(caption) else str(caption)

        # Context managers close each file once its pixels are converted, so handles
        # do not pile up across DataLoader workers.
        with Image.open(sketch_path) as sketch_img:
            sketch_tensor = pil_to_tensor(sketch_img, size=self.image_size, to_gray=True)
        with Image.open(photo_path) as photo_img:
            photo_tensor = pil_to_tensor(photo_img, size=self.image_size, to_gray=False)

        return {
            "sketch": sketch_tensor,
            "image":  photo_tensor,
            "caption": caption,
            "sketch_path": sketch_path,
            "photo_path":  photo_path,
        }

# ---- FSCOCO loaders ----
# Built from the FSCOCO subset CSV written in the manifest cell above.
fs_train_ds = Sketch2ImageDataset(FS_SUBSET_CSV, split="train", image_size=IMAGE_SIZE)
fs_val_ds   = Sketch2ImageDataset(FS_SUBSET_CSV, split="val",   image_size=IMAGE_SIZE)

BATCH_SIZE_FS = 1   # keep 1 for safety on L4

fs_train_dl = DataLoader(fs_train_ds, batch_size=BATCH_SIZE_FS, shuffle=True,  num_workers=2)
fs_val_dl   = DataLoader(fs_val_ds,   batch_size=BATCH_SIZE_FS, shuffle=False, num_workers=2)

# Sanity check: one batch's shapes and the start of its first caption.
batch = next(iter(fs_train_dl))
print("FSCOCO train batch:")
print("  sketch:", batch["sketch"].shape)
print("  image :", batch["image"].shape)
print("  caption[0]:", batch["caption"][0][:80])


In [None]:
#clear GPU, load SD, attach FSCOCO LoRA

In [None]:
import gc, torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

# clear old stuff from GPU
# clear old stuff from GPU
# Drop leftover pipelines, model components and optimizers from earlier sections
# (any global whose name starts with these prefixes) so their memory can be reclaimed.
# The "pipe" prefix also covers every "pipe_*" name.
for name in list(globals().keys()):
    if name.startswith(("pipe", "unet_", "vae_", "scheduler_", "optimizer_")):
        # pop() cannot fail here. The previous try/del/bare-except also swallowed
        # KeyboardInterrupt and SystemExit.
        globals().pop(name, None)

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA memory after clear:", torch.cuda.mem_get_info())

device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

# Fresh SD 1.5 pipeline used as the base for FSCOCO LoRA training (no safety checker).
pipe_fs = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,   # train in fp32
    safety_checker=None,
).to(device)

# Component handles used by the FSCOCO helpers and training loop.
tokenizer_fs = pipe_fs.tokenizer
vae_fs       = pipe_fs.vae
unet_fs      = pipe_fs.unet
scheduler_fs = pipe_fs.scheduler

pipe_fs.enable_attention_slicing()
pipe_fs.enable_vae_slicing()

# LoRA rank for the FSCOCO adapter (also the default of add_lora_to_unet_fs).
LORA_RANK_FS = 4

def add_lora_to_unet_fs(unet, rank=LORA_RANK_FS):
    """Attach the FSCOCO LoRA adapter to ``unet`` in place and hand the same object back.

    The adapter wraps the attention projections ``to_k``/``to_q``/``to_v``/``to_out.0``
    with rank ``rank`` and Gaussian-initialised weights. ``lora_alpha`` is set equal
    to the rank.
    """
    adapter_cfg = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        init_lora_weights="gaussian",
    )
    unet.add_adapter(adapter_cfg)
    return unet

# Attach the adapter; unet_fs is modified in place (the return value is unused).
add_lora_to_unet_fs(unet_fs, rank=LORA_RANK_FS)

# save memory
if hasattr(unet_fs, "enable_gradient_checkpointing"):
    unet_fs.enable_gradient_checkpointing()

# freeze base weights
for p in unet_fs.parameters():
    p.requires_grad_(False)

# collect LoRA params
# (only parameters whose name contains "lora" are re-enabled and given to the optimizer)
lora_params_fs = []
for name, p in unet_fs.named_parameters():
    if "lora" in name:
        p.requires_grad_(True)
        lora_params_fs.append(p)

print("FSCOCO LoRA trainable params:", sum(p.numel() for p in lora_params_fs))


In [None]:
#Helpers for FSCOCO (encode captions + images)

In [None]:
import torch.nn.functional as F

def encode_prompts_fs(captions, tokenizer, device, dtype, text_encoder=None):
    """Tokenize captions and run them through the (frozen) text encoder.

    Args:
        captions: a sequence of caption strings (e.g. ``batch["caption"]``). A single
            string is treated as one caption rather than being split into characters.
        tokenizer: tokenizer exposing ``model_max_length``. Captions are padded and
            truncated to exactly that length.
        device: device the token ids and the returned embeddings are placed on.
        dtype: dtype of the returned embeddings.
        text_encoder: encoder to run. Defaults to ``pipe_fs.text_encoder`` so existing
            four-argument calls behave as before.

    Returns:
        The first element of the encoder output, moved to ``device``/``dtype``.
    """
    if text_encoder is None:
        # Backward-compatible default: the text encoder of the global FSCOCO pipeline.
        text_encoder = pipe_fs.text_encoder
    if isinstance(captions, str):
        captions = [captions]
    text_inputs = tokenizer(
        list(captions),
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(device)
    with torch.no_grad():  # the text encoder is not trained
        text_embeds = text_encoder(input_ids)[0]
    return text_embeds.to(device=device, dtype=dtype)

def encode_images_to_latents_fs(images: torch.Tensor, vae, device, dtype, scaling_factor=None):
    """Encode images into scaled VAE latents, without tracking gradients.

    Args:
        images: image batch tensor passed to ``vae.encode`` after being moved to
            ``device``/``dtype``.
        vae: autoencoder whose ``encode(x).latent_dist.sample()`` yields latents.
        device: target device.
        dtype: target dtype.
        scaling_factor: multiplier applied to the sampled latents. When None, use
            ``vae.config.scaling_factor`` if the VAE exposes it. Otherwise fall back to
            0.18215, the constant previously hard-coded here.

    Returns:
        ``vae.encode(images).latent_dist.sample() * scaling_factor``.
    """
    if scaling_factor is None:
        scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    images = images.to(device=device, dtype=dtype)
    with torch.no_grad():  # the VAE is frozen during LoRA training
        latents = vae.encode(images).latent_dist.sample() * scaling_factor
    return latents


In [None]:
#Training loop for FSCOCO LoRA (with checkpoints)

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
from pathlib import Path
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline

# FSCOCO LoRA training hyper-parameters.
learning_rate_fs   = 5e-5
MAX_TRAIN_STEPS_FS = 4000     # set 1000 for a quicker run
SAVE_EVERY_FS      = 500      # write a checkpoint every N optimizer steps

# Re-running this cell resets the step counter and optimizer state; the LoRA
# weights live in unet_fs and are not reset here.
global_step_fs = 0
optimizer_fs   = AdamW(lora_params_fs, lr=learning_rate_fs)

unet_fs.train()
torch_dtype_fs = torch.float32

def save_lora_checkpoint_fs(step):
    """Write the FSCOCO UNet's LoRA weights to ``{PROJ}/checkpoints/lora_fscoco_steps/step_<step>``.

    The peft adapter state dict is converted to diffusers key names and saved with
    ``StableDiffusionPipeline.save_lora_weights`` as safetensors.
    """
    ckpt_dir = f"{PROJ}/checkpoints/lora_fscoco_steps/step_{step}"
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_fs))
    StableDiffusionPipeline.save_lora_weights(
        save_directory=ckpt_dir,
        unet_lora_layers=unet_lora_layers,
        safe_serialization=True,
    )
    print(f"Saved FSCOCO LoRA checkpoint at step {step} -> {ckpt_dir}")

# Set when the loss diverges. A plain `break` only leaves the inner (batch) loop,
# so this flag is what actually stops the outer epoch loop.
stop_training_fs = False

for epoch in range(9999):
    pbar = tqdm(fs_train_dl, desc=f"[FSCOCO] Epoch {epoch+1}")
    for batch in pbar:
        if global_step_fs >= MAX_TRAIN_STEPS_FS:
            break

        images   = batch["image"]
        captions = batch["caption"]

        # Frozen encoders (run under no_grad inside the helpers).
        text_embeds = encode_prompts_fs(captions, tokenizer_fs, device, torch_dtype_fs)
        latents     = encode_images_to_latents_fs(images, vae_fs, device, torch_dtype_fs)

        # Forward diffusion: one random timestep per sample, noise the latents to it.
        bsz   = latents.shape[0]
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler_fs.config.num_train_timesteps,
            (bsz,),
            device=device,
            dtype=torch.long,
        )
        noisy_latents = scheduler_fs.add_noise(latents, noise, timesteps)

        model_pred = unet_fs(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeds,
        ).sample

        # The UNet is trained to predict the noise that was added.
        loss = F.mse_loss(model_pred, noise, reduction="mean")

        if not torch.isfinite(loss):
            print(f"❌ Loss became NaN/Inf at step {global_step_fs}, stopping.")
            stop_training_fs = True  # also stop the epoch loop (see above)
            break

        optimizer_fs.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lora_params_fs, 1.0)
        optimizer_fs.step()

        global_step_fs += 1
        pbar.set_postfix({"loss": float(loss.item()), "step": global_step_fs})

        if global_step_fs % SAVE_EVERY_FS == 0:
            save_lora_checkpoint_fs(global_step_fs)
    pbar.close()  # a `break` leaves the bar open otherwise

    if stop_training_fs:
        break
    if global_step_fs >= MAX_TRAIN_STEPS_FS:
        print("✅ Reached max FSCOCO train steps, stopping.")
        break

print("FSCOCO training loop finished; total steps:", global_step_fs)


In [None]:
#Compare base SD vs FSCOCO LoRA

In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image

# Inference device and base model for the FSCOCO comparison below.
device   = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"

def load_base_pipe_fs(dtype=torch.float16):
    """Load a fresh ``model_id`` pipeline on ``device`` in ``dtype`` for the FSCOCO comparison.

    The safety checker is disabled and attention slicing is switched on.
    """
    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, safety_checker=None)
    loaded = loaded.to(device)
    loaded.enable_attention_slicing()
    return loaded

CKPT_STEP_FS = 1000   # change to a step you actually have
LORA_FS_DIR  = f"{PROJ}/checkpoints/lora_fscoco_steps/step_{CKPT_STEP_FS}"
# Fail before loading two pipelines if the chosen checkpoint does not exist.
if not os.path.isdir(LORA_FS_DIR):
    raise FileNotFoundError(f"No such LoRA dir: {LORA_FS_DIR}")

prompt = "a sketch-like scene of people and objects in an indoor environment"

# Both renders use the same seed, so they start from identical initial noise and the
# side-by-side isolates the effect of the LoRA.
SEED_COMPARE_FS = 42

# base SD
pipe_base = load_base_pipe_fs(torch.float16)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    img_base = pipe_base(
        prompt, num_inference_steps=30, guidance_scale=7.5,
        generator=torch.Generator(device=device).manual_seed(SEED_COMPARE_FS),
    ).images[0]

# FSCOCO LoRA
pipe_fs_eval = load_base_pipe_fs(torch.float16)
pipe_fs_eval.load_lora_weights(LORA_FS_DIR)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    img_lora = pipe_fs_eval(
        prompt, num_inference_steps=30, guidance_scale=7.5,
        generator=torch.Generator(device=device).manual_seed(SEED_COMPARE_FS),
    ).images[0]

w, h = img_base.size
combined = Image.new("RGB", (2*w, h))
combined.paste(img_base, (0, 0))
combined.paste(img_lora, (w, 0))

print(f"Left: base SD | Right: FSCOCO LoRA @ step {CKPT_STEP_FS}")
display(combined)


In [None]:
#edge2shoes

In [None]:
from google.colab import drive
# Re-mount Drive and re-declare PROJ so the edges2shoes section can run on its own.
drive.mount('/content/drive', force_remount=True)

import os
from pathlib import Path

PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
Path(PROJ).mkdir(parents=True, exist_ok=True)
print("PROJ:", PROJ)

# Typical structure (Isola pix2pix): edge2shoes/train, edge2shoes/val each with 256x512 images
EDGE_ROOT = "/content/drive/MyDrive/Gen AI/Gen AI project/edges2shoes/edges2shoes"

print("EDGE_ROOT:", EDGE_ROOT, "exists?", os.path.isdir(EDGE_ROOT))
# os.listdir raises if EDGE_ROOT is missing; the "exists?" line above shows why.
print("Subfolders:", os.listdir(EDGE_ROOT))


In [None]:
import glob
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Rebinds the notebook-wide IMAGE_SIZE (512 in the FSCOCO section); read by
# pil_to_tensor_01's default and by Edge2ShoesDataset below.
IMAGE_SIZE = 256   # edges2shoes default

def pil_to_tensor_01(img: Image.Image, size=IMAGE_SIZE):
    """Return ``img`` as an RGB float32 (C, H, W) tensor in [0, 1], resized to size x size (bicubic)."""
    rgb = img.convert("RGB").resize((size, size), Image.BICUBIC)
    hwc = np.array(rgb).astype("float32") / 255.0
    return torch.from_numpy(np.transpose(hwc, (2, 0, 1)))

class Edge2ShoesDataset(Dataset):
    """edges2shoes pairs stored pix2pix-style as side-by-side images.

    Each file in ``<root>/<split>`` is split down the middle: the left half becomes
    ``"edge"`` and the right half ``"shoe"``. Both are returned as float tensors in
    [-1, 1], together with the source ``"path"``.

    Args:
        root: dataset root containing one folder per split.
        split: sub-folder to read (e.g. "train", "val").
        max_images: keep only the first ``max_images`` files (sorted by path); None keeps all.
        image_size: side length each half is resized to. None (default) reads the
            global ``IMAGE_SIZE`` at load time, as before.
    """

    def __init__(self, root, split="train", max_images=None, image_size=None):
        self.root = os.path.join(root, split)
        # Use .jpg files; fall back to .png only when the folder has no .jpg at all.
        files = sorted(glob.glob(os.path.join(self.root, "*.jpg")))
        if len(files) == 0:
            files = sorted(glob.glob(os.path.join(self.root, "*.png")))
        if max_images is not None:
            files = files[:max_images]
        self.files = files
        self.image_size = image_size
        print(f"[edges2shoes] split={split}, images={len(self.files)} in {self.root}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        size = IMAGE_SIZE if self.image_size is None else self.image_size

        # Close the file handle as soon as the pixels are decoded.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        w, h = img.size
        w2 = w // 2

        edge = img.crop((0, 0, w2, h))   # left half
        shoe = img.crop((w2, 0, w, h))   # right half

        edge_t = pil_to_tensor_01(edge, size)
        shoe_t = pil_to_tensor_01(shoe, size)

        # normalize to [-1,1]
        edge_t = (edge_t * 2.0) - 1.0
        shoe_t = (shoe_t * 2.0) - 1.0

        return {
            "edge": edge_t,
            "shoe": shoe_t,
            "path": path,
        }

MAX_TRAIN_E2S = 6000   # you can lower for faster runs
MAX_VAL_E2S   = 500

# Each dataset keeps the first N files of its split folder (sorted by path).
train_e2s_ds = Edge2ShoesDataset(EDGE_ROOT, split="train", max_images=MAX_TRAIN_E2S)
val_e2s_ds   = Edge2ShoesDataset(EDGE_ROOT, split="val",   max_images=MAX_VAL_E2S)

BATCH_E2S = 4

# Validation is iterated one image at a time, in file order.
train_e2s_dl = DataLoader(train_e2s_ds, batch_size=BATCH_E2S, shuffle=True,  num_workers=2)
val_e2s_dl   = DataLoader(val_e2s_ds,   batch_size=1,      shuffle=False, num_workers=2)

# Sanity check on one training batch.
batch = next(iter(train_e2s_dl))
print("edge:", batch["edge"].shape, "shoe:", batch["shoe"].shape)


In [None]:
#Define pix2pix Generator (U-Net) & PatchGAN Discriminator

In [None]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Downsampling unit: 4x4 stride-2 conv (halves H and W), optional BatchNorm, LeakyReLU(0.2).

    The conv bias is dropped when BatchNorm follows, because BN's own shift makes it redundant.
    """
    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=not norm)
        bn = [nn.BatchNorm2d(out_c)] if norm else []
        # Sequential indices (main.0, main.1, ...) are baked into saved checkpoints; keep the order.
        self.main = nn.Sequential(conv, *bn, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        return self.main(x)

class DeconvBlock(nn.Module):
    """Upsampling unit: 4x4 stride-2 transposed conv (doubles H and W) -> BatchNorm -> ReLU,
    with an optional trailing Dropout(0.5)."""
    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
        tail = [nn.Dropout(0.5)] if dropout else []
        self.main = nn.Sequential(upsample, nn.BatchNorm2d(out_c), nn.ReLU(inplace=True), *tail)

    def forward(self, x):
        return self.main(x)

class UNetGenerator(nn.Module):
    """pix2pix U-Net generator: 8 ConvBlock downsamplings and 8 upsamplings with skip links.

    The layer comments assume 256x256 input: eight stride-2 convs shrink it to 1x1, and
    each decoder output is concatenated with the encoder activation of the same size.
    The output has ``out_channels`` channels and is squashed to [-1, 1] by tanh.
    """
    def __init__(self, in_channels=3, out_channels=3, num_filters=64):
        super().__init__()
        nf = num_filters
        # Encoder: first and last blocks skip BatchNorm.
        self.down1 = ConvBlock(in_channels, nf, norm=False)   # 256 -> 128
        self.down2 = ConvBlock(nf, nf * 2)                    # 128 -> 64
        self.down3 = ConvBlock(nf * 2, nf * 4)                # 64  -> 32
        self.down4 = ConvBlock(nf * 4, nf * 8)                # 32  -> 16
        self.down5 = ConvBlock(nf * 8, nf * 8)                # 16  -> 8
        self.down6 = ConvBlock(nf * 8, nf * 8)                # 8   -> 4
        self.down7 = ConvBlock(nf * 8, nf * 8)                # 4   -> 2
        self.down8 = ConvBlock(nf * 8, nf * 8, norm=False)    # 2   -> 1

        # Decoder: from up2 on, input width is doubled by the skip concatenation.
        self.up1 = DeconvBlock(nf * 8, nf * 8, dropout=True)
        self.up2 = DeconvBlock(nf * 16, nf * 8, dropout=True)
        self.up3 = DeconvBlock(nf * 16, nf * 8, dropout=True)
        self.up4 = DeconvBlock(nf * 16, nf * 8)
        self.up5 = DeconvBlock(nf * 16, nf * 4)
        self.up6 = DeconvBlock(nf * 8, nf * 2)
        self.up7 = DeconvBlock(nf * 4, nf)
        self.up8 = nn.ConvTranspose2d(nf * 2, out_channels, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Run the encoder and keep every activation for the skip connections.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4,
                     self.down5, self.down6, self.down7, self.down8):
            h = down(h)
            skips.append(h)

        # skips[-1] is the 1x1 bottleneck; d7 .. d1 are fed to up1 .. up7 in turn.
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for up, skip in zip(decoders, reversed(skips[:-1])):
            h = torch.cat([up(h), skip], dim=1)

        return self.tanh(self.up8(h))

class PatchDiscriminator(nn.Module):
    """Conditional PatchGAN: scores (edge, shoe) pairs with a map of raw per-patch logits.

    ``in_channels`` is the concatenated channel count (3 edge + 3 shoe by default).
    """
    def __init__(self, in_channels=6, num_filters=64):
        super().__init__()
        # Channel progression in -> nf -> 2nf -> 4nf -> 8nf; each ConvBlock halves H and W
        # (256 -> 128 -> 64 -> 32 -> 16). Only the first block skips BatchNorm.
        widths = [in_channels, num_filters, num_filters * 2, num_filters * 4, num_filters * 8]
        stages = [
            ConvBlock(c_in, c_out, norm=(i > 0))
            for i, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
        ]
        # Final 4x4 stride-1 conv: one logit per patch (16x16 features -> 15x15 map).
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, edge, shoe):
        return self.model(torch.cat((edge, shoe), dim=1))


In [None]:
#Instantiate models, optimizers, loss functions

In [None]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

G = UNetGenerator(in_channels=3, out_channels=3).to(device)
D = PatchDiscriminator(in_channels=6).to(device)

lr, beta1 = 2e-4, 0.5

# Same Adam settings for both networks; G is created first, then D.
optim_G, optim_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999)) for net in (G, D)
)

bce_loss = nn.BCEWithLogitsLoss()   # D ends in a bare Conv2d, i.e. it emits logits
l1_loss  = nn.L1Loss()

def weights_init(m):
    """pix2pix/DCGAN-style initialisation of one module, meant for ``model.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias (if any) set to 0.
    - BatchNorm2d / affine InstanceNorm2d: weight ~ N(1, 0.02), bias set to 0.
    Other modules are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        # Norm layers without affine params (e.g. default InstanceNorm2d) have weight=None.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Module.apply visits every submodule recursively.
for _net in (G, D):
    _net.apply(weights_init)

print(f"Generator params: {sum(p.numel() for p in G.parameters()) / 1e6} M")
print(f"Discriminator params: {sum(p.numel() for p in D.parameters()) / 1e6} M")


In [None]:
#Training loop (light pix2pix)

In [None]:
import gc, torch

# Free memory held by earlier diffusion / LoRA experiments before GAN training.
# Every global whose name starts with one of these prefixes is deleted; then the garbage
# collector and the CUDA caching allocator can release the memory.
# NOTE: "optimizer_" also matches optimizers created later in this notebook
# (optimizer_G, optimizer_D, optimizer_reg). Re-running this cell deletes them too.
stale_names = sorted(
    n for n in list(globals())
    if n.startswith(("pipe", "unet", "vae", "scheduler", "optimizer_", "lora", "LORA"))
)
for name in stale_names:
    # pop() rather than `del` inside a bare `except:`. A missing name is the only expected
    # failure, and a bare except would also swallow KeyboardInterrupt/SystemExit.
    globals().pop(name, None)
print(f"Deleted {len(stale_names)} stale globals: {stale_names}")

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    free, total = torch.cuda.mem_get_info()
    print(f"CUDA memory after cleanup: free {free/1e9:.2f} GB / total {total/1e9:.2f} GB")


In [None]:
import glob
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

IMAGE_SIZE = 128   # smaller than before

def pil_to_tensor_01(img: Image.Image, size=IMAGE_SIZE):
    """Return ``img`` as an RGB float32 tensor of shape (3, size, size) with values in [0, 1].

    ``size`` defaults to the IMAGE_SIZE in effect when this function was defined (128 here).
    """
    resized = img.convert("RGB").resize((size, size), Image.BICUBIC)
    pixels = np.asarray(resized).astype(np.float32) / 255.0   # (H, W, 3)
    return torch.from_numpy(pixels.transpose(2, 0, 1))          # (C, H, W)

class Edge2ShoesDataset(Dataset):
    """Paired edges2shoes dataset.

    Every file under ``<root>/<split>`` is a side-by-side image. The left half is the
    edge map and the right half is the shoe photo. Both halves are resized to a square
    and scaled to [-1, 1].

    Args:
        root: dataset root with one sub-folder per split.
        split: sub-folder name, e.g. "train" or "val".
        max_images: optional cap on the number of files, applied after sorting.
        image_size: output side length. Defaults to the module-level IMAGE_SIZE *at
            construction time*. Later reassignments of IMAGE_SIZE therefore do not change
            an existing dataset's resolution.
    """
    def __init__(self, root, split="train", max_images=None, image_size=None):
        self.root = os.path.join(root, split)
        # Prefer .jpg files; fall back to .png only when no .jpg exists.
        files = sorted(glob.glob(os.path.join(self.root, "*.jpg")))
        if len(files) == 0:
            files = sorted(glob.glob(os.path.join(self.root, "*.png")))
        if max_images is not None:
            files = files[:max_images]
        self.files = files
        self.image_size = IMAGE_SIZE if image_size is None else image_size
        print(f"[edges2shoes] split={split}, images={len(self.files)} in {self.root}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return {"edge": (3,S,S) tensor, "shoe": (3,S,S) tensor, "path": str}, tensors in [-1, 1]."""
        path = self.files[idx]
        # The context manager closes the file handle promptly (many DataLoader workers).
        with Image.open(path) as src:
            img = src.convert("RGB")
        w, h = img.size
        w2 = w // 2

        edge = img.crop((0, 0, w2, h))
        shoe = img.crop((w2, 0, w, h))

        edge_t = pil_to_tensor_01(edge, self.image_size)
        shoe_t = pil_to_tensor_01(shoe, self.image_size)

        # Scale [0, 1] -> [-1, 1] to match the generators' tanh output range.
        edge_t = (edge_t * 2.0) - 1.0
        shoe_t = (shoe_t * 2.0) - 1.0

        return {
            "edge": edge_t,
            "shoe": shoe_t,
            "path": path,
        }

MAX_TRAIN_E2S = 3000   # smaller subset is fine
MAX_VAL_E2S   = 200
BATCH_E2S = 1  # critical: batch=1 to save memory

# Each constructor prints a one-line summary; train is built first, then val.
train_e2s_ds, val_e2s_ds = (
    Edge2ShoesDataset(EDGE_ROOT, split=split_name, max_images=cap)
    for split_name, cap in (("train", MAX_TRAIN_E2S), ("val", MAX_VAL_E2S))
)

# Only the training loader shuffles; both use batch size 1 here.
train_e2s_dl, val_e2s_dl = (
    DataLoader(ds, batch_size=bs, shuffle=shuf, num_workers=2)
    for ds, bs, shuf in ((train_e2s_ds, BATCH_E2S, True), (val_e2s_ds, 1, False))
)

# Smoke test: pull one training batch and show tensor shapes.
batch = next(iter(train_e2s_dl))
print("edge:", batch["edge"].shape, "shoe:", batch["shoe"].shape)


In [None]:
#redefine Generator & Discriminator (no BatchNorm/InstanceNorm)

In [None]:
import torch
import torch.nn as nn

# --------- Simple blocks: Conv + LeakyReLU / Deconv + ReLU, NO normalization ---------

class DownBlock(nn.Module):
    """Downsampling unit without normalisation: 4x4 stride-2 conv (with bias) + LeakyReLU(0.2).

    Halves H and W. ``first`` is kept for call-site compatibility only: first and
    non-first blocks are built identically, so the flag does not change the layers.
    """
    def __init__(self, in_c, out_c, first=False):
        super().__init__()
        # Layer order defines state_dict keys (main.0 = conv), which saved checkpoints rely on.
        self.main = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.main(x)

class UpBlock(nn.Module):
    """Upsampling unit without normalisation: 4x4 stride-2 transposed conv (with bias) + ReLU.

    Doubles H and W.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=True)
        self.main = nn.Sequential(upsample, nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.main(x)
        return out

class SimpleUNetGenerator(nn.Module):
    """
    6 down / 6 up U-Net without normalisation, sized for 128x128 inputs
    (128 -> 2x2 bottleneck -> 128). Output is tanh-squashed to [-1, 1].

    NOTE(review): DownBlock/UpBlock are looked up when this class is instantiated. A later
    cell redefines them with BatchNorm and different arguments, so instantiate it before
    that cell runs.
    """
    def __init__(self, in_channels=3, out_channels=3, base_ch=32):
        super().__init__()
        c = base_ch
        # Encoder
        self.down1 = DownBlock(in_channels, c, first=True)   # 128 -> 64
        self.down2 = DownBlock(c, c * 2)                     # 64  -> 32
        self.down3 = DownBlock(c * 2, c * 4)                 # 32  -> 16
        self.down4 = DownBlock(c * 4, c * 8)                 # 16  -> 8
        self.down5 = DownBlock(c * 8, c * 8)                 # 8   -> 4
        self.down6 = DownBlock(c * 8, c * 8)                 # 4   -> 2  (bottleneck)

        # Decoder: from up2 on, inputs are doubled by the skip concatenation.
        self.up1 = UpBlock(c * 8, c * 8)                     # 2  -> 4
        self.up2 = UpBlock(c * 8 * 2, c * 8)                 # 4  -> 8
        self.up3 = UpBlock(c * 8 * 2, c * 4)                 # 8  -> 16
        self.up4 = UpBlock(c * 4 * 2, c * 2)                 # 16 -> 32
        self.up5 = UpBlock(c * 2 * 2, c)                     # 32 -> 64
        self.up6 = nn.ConvTranspose2d(c * 2, out_channels,
                                      kernel_size=4, stride=2, padding=1)  # 64 -> 128
        self.tanh = nn.Tanh()

    def forward(self, x):
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            h = down(h)
            skips.append(h)

        # skips[-1] is the 2x2 bottleneck; d5 .. d1 are concatenated after up1 .. up5.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4, self.up5),
                            reversed(skips[:-1])):
            h = torch.cat([up(h), skip], 1)

        return self.tanh(self.up6(h))

class SimplePatchDiscriminator(nn.Module):
    """
    PatchGAN-style discriminator without normalisation.

    Input: edge and shoe tensors, concatenated on channels (``in_channels`` = 6 by default).
    Output: a map of raw per-patch logits (14x14 for 128x128 inputs).
    """
    def __init__(self, in_channels=6, base_ch=32):
        super().__init__()
        # (out_channels, stride) of each Conv2d + LeakyReLU stage.
        plan = [(base_ch, 2), (base_ch * 2, 2), (base_ch * 4, 2), (base_ch * 8, 1)]
        layers = []
        prev = in_channels
        for width, stride in plan:
            layers += [
                nn.Conv2d(prev, width, kernel_size=4, stride=stride, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            prev = width
        layers.append(nn.Conv2d(prev, 1, kernel_size=4, stride=1, padding=1))  # Patch output
        self.model = nn.Sequential(*layers)

    def forward(self, edge, shoe):
        pair = torch.cat([edge, shoe], dim=1)
        return self.model(pair)

# --------- Instantiate models, optimizers, losses ---------
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

# Both networks use 32 base channels.
G = SimpleUNetGenerator(in_channels=3, out_channels=3, base_ch=32).to(device)
D = SimplePatchDiscriminator(in_channels=6, base_ch=32).to(device)

lr, beta1 = 2e-4, 0.5

# Identical Adam settings for G and D (G's optimizer is created first).
optim_G, optim_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999)) for net in (G, D)
)

bce_loss = nn.BCEWithLogitsLoss()   # D ends in a bare Conv2d, i.e. it emits logits
l1_loss  = nn.L1Loss()

print("Generator params (M):", sum(p.numel() for p in G.parameters()) / 1e6)
print("Discriminator params (M):", sum(p.numel() for p in D.parameters()) / 1e6)


In [None]:
from tqdm.auto import tqdm
import os
from pathlib import Path

# Conditional GAN training: D learns to separate real (edge, shoe) pairs from (edge, G(edge))
# pairs; G is trained to fool D plus an L1 reconstruction term weighted by LAMBDA_L1.
EPOCHS = 5
LAMBDA_L1 = 100.0

CKPT_DIR = f"{PROJ}/checkpoints/edges2shoes_gan"
Path(CKPT_DIR).mkdir(parents=True, exist_ok=True)

for epoch in range(1, EPOCHS+1):
    G.train()
    D.train()
    pbar = tqdm(train_e2s_dl, desc=f"[edges2shoes] Epoch {epoch}/{EPOCHS}")
    for batch in pbar:
        edge = batch["edge"].to(device)
        real_shoe = batch["shoe"].to(device)

        # --------- D ----------
        # Also clears D gradients that the previous generator step accumulated.
        optim_D.zero_grad()

        out_real = D(edge, real_shoe)
        real_label = torch.ones_like(out_real)
        loss_D_real = bce_loss(out_real, real_label)

        # Fake sample without a graph through G: the D step must not update G.
        with torch.no_grad():
            fake_shoe_detached = G(edge)
        out_fake = D(edge, fake_shoe_detached)
        fake_label = torch.zeros_like(out_fake)
        loss_D_fake = bce_loss(out_fake, fake_label)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optim_D.step()

        # --------- G ----------
        optim_G.zero_grad()

        # Fresh forward with gradients; G wants D to label its output as real (ones).
        fake_shoe = G(edge)
        out_fake_for_G = D(edge, fake_shoe)
        real_label_for_G = torch.ones_like(out_fake_for_G)

        loss_G_adv = bce_loss(out_fake_for_G, real_label_for_G)
        loss_G_l1  = l1_loss(fake_shoe, real_shoe) * LAMBDA_L1

        # backward() also writes grads into D's params; only optim_G steps here.
        loss_G = loss_G_adv + loss_G_l1
        loss_G.backward()
        optim_G.step()

        pbar.set_postfix({"loss_D": float(loss_D.item()),
                          "loss_G": float(loss_G.item())})

    # One checkpoint per epoch (overwrites files from earlier runs with the same name).
    ckpt_path = os.path.join(CKPT_DIR, f"epoch_{epoch}.pt")
    torch.save({
        "G": G.state_dict(),
        "D": D.state_dict(),
        "optim_G": optim_G.state_dict(),
        "optim_D": optim_D.state_dict(),
        "epoch": epoch,
    }, ckpt_path)
    print(f"💾 Saved checkpoint at epoch {epoch} -> {ckpt_path}")


In [None]:
import torch
import os

# Restore G and D saved by the GAN training loop and switch both to eval mode.
# Point CKPT_PATH at epoch_3.pt / epoch_4.pt to compare earlier snapshots.
CKPT_DIR = f"{PROJ}/checkpoints/edges2shoes_gan"
CKPT_PATH = os.path.join(CKPT_DIR, "epoch_5.pt")

ckpt = torch.load(CKPT_PATH, map_location=device)
G.load_state_dict(ckpt["G"])
G.eval()
D.load_state_dict(ckpt["D"])
D.eval()

print("Loaded checkpoint:", CKPT_PATH)


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from pathlib import Path

OUT_DIR = f"{PROJ}/outputs/edges2shoes_samples"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

def tensor_to_pil(x):
    """Convert a generator-style tensor to a PIL image.

    Args:
        x: (3, H, W) or (1, H, W) tensor with values in [-1, 1]; values outside are clamped.

    Returns:
        PIL image: RGB for 3 channels, grayscale ("L") for 1 channel.
    """
    x = x.detach().cpu().clamp(-1, 1)
    x = (x + 1.0) / 2.0          # [0,1]
    # Round instead of truncating: .byte() floors, so a value meant to be k that lands
    # at k - epsilon after the [-1, 1] round trip would come back as k - 1.
    x = (x * 255.0).round().to(torch.uint8).numpy()
    if x.shape[0] == 1:
        return Image.fromarray(x[0])   # (H, W) grayscale; (H, W, 1) is rejected by PIL
    x = np.transpose(x, (1, 2, 0))  # HWC
    return Image.fromarray(x)


In [None]:
import torch
from IPython.display import display

# Sanity check: run the GAN generator in eval mode on the first validation example.
G.eval()

batch = next(iter(val_e2s_dl))
edge = batch["edge"].to(device)       # (1,3,H,W) since the val loader uses batch_size=1
real_shoe = batch["shoe"].to(device)

with torch.no_grad():
    fake_shoe = G(edge)

print("edge shape:", edge.shape)
print("fake shape:", fake_shoe.shape)
print("fake min/max:", float(fake_shoe.min()), float(fake_shoe.max()))   # tanh => within [-1, 1]


In [None]:
# Lower the learning rate of every param group in both optimizers (was 2e-4 before).
for param_group in (*optim_G.param_groups, *optim_D.param_groups):
    param_group["lr"] = 1e-4

print("New LR for G:", optim_G.param_groups[0]["lr"])
print("New LR for D:", optim_D.param_groups[0]["lr"])


In [None]:
from tqdm.auto import tqdm
import os
from pathlib import Path

# Pure-L1 regression baseline: train G_reg to reproduce the shoe photo from the edge map.
EPOCHS_REG = 10
CKPT_DIR_REG = f"{PROJ}/checkpoints/edges2shoes_reg"
Path(CKPT_DIR_REG).mkdir(parents=True, exist_ok=True)

# This cell depends on state built in other cells. Fail early with an actionable message
# instead of a NameError from inside the loop on a fresh kernel.
if "G_reg" not in globals():
    raise NameError(
        "G_reg is not defined: run the cell that builds it "
        "(G_reg = SimpleUNetE2S_IN_NoLastNorm().to(device)) before this training cell."
    )
# No visible cell in this section creates optimizer_reg, and the GPU-cleanup cell deletes
# every global starting with "optimizer_". Build one if missing, using the same Adam
# settings as the other edges2shoes models in this notebook.
if "optimizer_reg" not in globals():
    optimizer_reg = torch.optim.Adam(G_reg.parameters(), lr=2e-4, betas=(0.5, 0.999))
    print("Created optimizer_reg: Adam(lr=2e-4, betas=(0.5, 0.999))")
if "l1_loss" not in globals():
    l1_loss = torch.nn.L1Loss()

for epoch in range(1, EPOCHS_REG + 1):
    G_reg.train()
    running_loss = 0.0
    pbar = tqdm(train_e2s_dl, desc=f"[E2S REG] Epoch {epoch}/{EPOCHS_REG}")

    for batch in pbar:
        edge = batch["edge"].to(device)       # (B,3,H,W), normalized [-1,1]
        real_shoe = batch["shoe"].to(device)  # (B,3,H,W), normalized [-1,1]

        optimizer_reg.zero_grad()

        fake_shoe = G_reg(edge)
        loss = l1_loss(fake_shoe, real_shoe)

        loss.backward()
        optimizer_reg.step()

        # Weight the batch-mean loss by batch size so the epoch mean is per-sample.
        running_loss += loss.item() * edge.size(0)
        pbar.set_postfix({"L1": float(loss.item())})

    epoch_loss = running_loss / len(train_e2s_dl.dataset)
    print(f"Epoch {epoch} mean L1 loss: {epoch_loss:.4f}")

    ckpt_path = os.path.join(CKPT_DIR_REG, f"epoch_{epoch}.pt")
    torch.save({
        "G_reg": G_reg.state_dict(),
        "optimizer_reg": optimizer_reg.state_dict(),
        "epoch": epoch,
        "epoch_loss": epoch_loss,
    }, ckpt_path)
    print(f"💾 Saved regression checkpoint -> {ckpt_path}")


In [None]:
import torch
import os

BEST_EPOCH = 10  # change if you see a better epoch in logs

# Restore the chosen regression snapshot and switch G_reg to inference mode.
CKPT_PATH_REG = os.path.join(CKPT_DIR_REG, f"epoch_{BEST_EPOCH}.pt")
ckpt_reg = torch.load(CKPT_PATH_REG, map_location=device)
G_reg.load_state_dict(ckpt_reg["G_reg"])
G_reg.eval()

print("Loaded regression baseline from:", CKPT_PATH_REG)


In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path

OUT_DIR_REG = f"{PROJ}/outputs/edges2shoes_reg_samples"
Path(OUT_DIR_REG).mkdir(parents=True, exist_ok=True)

def tensor_to_pil(x):
    """Convert a generator-style tensor to a PIL image.

    Args:
        x: (3, H, W) or (1, H, W) tensor with values in [-1, 1]; values outside are clamped.

    Returns:
        PIL image: RGB for 3 channels, grayscale ("L") for 1 channel.
    """
    x = x.detach().cpu().clamp(-1, 1)
    x = (x + 1.0) / 2.0
    # Round instead of truncating: .byte() floors, so k - epsilon would become k - 1.
    x = (x * 255.0).round().to(torch.uint8).numpy()
    if x.shape[0] == 1:
        return Image.fromarray(x[0])   # (H, W) grayscale; (H, W, 1) is rejected by PIL
    x = np.transpose(x, (1, 2, 0))
    return Image.fromarray(x)


In [None]:
# edges2shoes: helper that converts a tensor in [-1,1], [0,1] or 0..255 back to a PIL image

In [None]:
import numpy as np
from PIL import Image
import torch

def tensor_to_pil_auto(x: torch.Tensor, value_range=None) -> Image.Image:
    """
    Convert a (C, H, W) tensor (C == 3 for RGB, C == 1 for grayscale) to a PIL image.

    Args:
        x: image tensor in [-1, 1], [0, 1] or 0..255.
        value_range: optional ``(low, high)`` pair, e.g. ``(-1, 1)`` or ``(0, 255)``.
            When given, it is used as-is. When None (default), the range is guessed:
            min < -0.1 means [-1, 1]; afterwards max > 1.5 means 0..255; otherwise [0, 1].
            The guess can be wrong. For example, a [-1, 1] image with no pixel below -0.1
            is treated as [0, 1], so pass ``value_range`` explicitly when you know it.

    Returns:
        PIL image ("RGB" for 3 channels, "L" for 1 channel).
    """
    x = x.detach().cpu()

    if value_range is not None:
        low, high = value_range
        x = (x - low) / (high - low)
    else:
        # If it looks like [-1,1], rescale
        if x.min() < -0.1:
            x = (x + 1.0) / 2.0
        # If it's >1 (e.g., 0..255), normalize
        if x.max() > 1.5:
            x = x / 255.0

    x = x.clamp(0, 1)
    # Round rather than truncate (.byte() would floor e.g. 127.9999 to 127).
    x = (x * 255.0).round().to(torch.uint8).numpy()
    if x.shape[0] == 1:
        return Image.fromarray(x[0])   # (H, W) grayscale; (H, W, 1) is rejected by PIL
    x = np.transpose(x, (1, 2, 0))  # C,H,W -> H,W,C
    return Image.fromarray(x)


In [None]:
# Show the first (edge, shoe) pair of a shuffled training batch; each tensor is (3,H,W).
batch = next(iter(train_e2s_dl))
edge, shoe = batch["edge"][0], batch["shoe"][0]

display(tensor_to_pil_auto(edge), tensor_to_pil_auto(shoe))


In [None]:
# Regression-baseline generator: U-Net with InstanceNorm (no norm on the deepest encoder stages)

In [None]:
import torch
import torch.nn as nn

class ConvBlockIN(nn.Module):
    """4x4 stride-2 conv unit with affine InstanceNorm.

    down=True : Conv2d (halves H, W) -> InstanceNorm2d -> LeakyReLU(0.2)
    down=False: ConvTranspose2d (doubles H, W) -> InstanceNorm2d -> ReLU
    An optional Dropout(0.5) is appended.
    """
    def __init__(self, in_c, out_c, down=True, use_dropout=False):
        super().__init__()
        if down:
            layers = [
                nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(out_c, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        else:
            layers = [
                nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(out_c, affine=True),
                nn.ReLU(inplace=True),
            ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class DownNoNorm(nn.Module):
    """Encoder unit without normalisation: 4x4 stride-2 conv (with bias) + LeakyReLU(0.2).

    Used for the deepest encoder stages (2x2 and 1x1 outputs for a 256x256 input).
    InstanceNorm over a single spatial element is degenerate, and PyTorch rejects it.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_c, out_c, 4, 2, 1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
    def forward(self, x):
        return self.block(x)


class SimpleUNetE2S_IN_NoLastNorm(nn.Module):
    """
    Pix2Pix-style U-Net (8 down / 8 up) for edges2shoes.

    InstanceNorm is used in the ConvBlockIN stages. enc1 (first layer), enc7/enc8
    (2x2 and 1x1 outputs) and the final dec8 projection are unnormalised.

    Input must be (B, in_ch, H, W) with H and W multiples of 256. Eight stride-2 convs
    reduce 256 to 1, and the decoder's skip concatenations only line up when every
    encoder stage halves the size exactly. Other sizes (e.g. 128x128) raise a
    ValueError instead of an opaque "kernel size can't be greater than input" error.
    Output: (B, out_ch, H, W) in [-1, 1] (tanh).
    """
    def __init__(self, in_ch=3, out_ch=3, base_ch=64):
        super().__init__()
        # Encoder (spatial sizes for a 256x256 input)
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, base_ch, 4, 2, 1, bias=True),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.enc2 = ConvBlockIN(base_ch,      base_ch*2, down=True)  # 128 -> 64
        self.enc3 = ConvBlockIN(base_ch*2,    base_ch*4, down=True)  # 64 -> 32
        self.enc4 = ConvBlockIN(base_ch*4,    base_ch*8, down=True)  # 32 -> 16
        self.enc5 = ConvBlockIN(base_ch*8,    base_ch*8, down=True)  # 16 -> 8
        self.enc6 = ConvBlockIN(base_ch*8,    base_ch*8, down=True)  # 8 -> 4
        self.enc7 = DownNoNorm(base_ch*8,    base_ch*8)             # 4 -> 2 (NO NORM)
        self.enc8 = DownNoNorm(base_ch*8,     base_ch*8)             # 2 -> 1 (NO NORM)

        # Decoder (from dec2 on, inputs are doubled by the skip concatenation)
        self.dec1 = ConvBlockIN(base_ch*8,     base_ch*8, down=False, use_dropout=True)   # 1->2
        self.dec2 = ConvBlockIN(base_ch*16,    base_ch*8, down=False, use_dropout=True)   # 2->4
        self.dec3 = ConvBlockIN(base_ch*16,    base_ch*8, down=False, use_dropout=True)   # 4->8
        self.dec4 = ConvBlockIN(base_ch*16,    base_ch*8, down=False)                     # 8->16
        self.dec5 = ConvBlockIN(base_ch*16,    base_ch*4, down=False)                     # 16->32
        self.dec6 = ConvBlockIN(base_ch*8,     base_ch*2, down=False)                     # 32->64
        self.dec7 = ConvBlockIN(base_ch*4,     base_ch,  down=False)                      # 64->128
        self.dec8 = nn.ConvTranspose2d(base_ch*2, out_ch, 4, 2, 1)                        # 128->256

        self.tanh = nn.Tanh()

    def forward(self, x):
        h, w = x.shape[-2:]
        if h % 256 or w % 256:
            raise ValueError(
                f"SimpleUNetE2S_IN_NoLastNorm needs H and W to be multiples of 256, got {h}x{w}; "
                "build the dataset with image size 256 (or use a shallower U-Net for 128x128)."
            )

        # Encode
        e1 = self.enc1(x)   # 128
        e2 = self.enc2(e1)  # 64
        e3 = self.enc3(e2)  # 32
        e4 = self.enc4(e3)  # 16
        e5 = self.enc5(e4)  # 8
        e6 = self.enc6(e5)  # 4
        e7 = self.enc7(e6)  # 2
        e8 = self.enc8(e7)  # 1

        # Decode with skips
        d1 = self.dec1(e8)
        d1 = torch.cat([d1, e7], dim=1)

        d2 = self.dec2(d1)
        d2 = torch.cat([d2, e6], dim=1)

        d3 = self.dec3(d2)
        d3 = torch.cat([d3, e5], dim=1)

        d4 = self.dec4(d3)
        d4 = torch.cat([d4, e4], dim=1)

        d5 = self.dec5(d4)
        d5 = torch.cat([d5, e3], dim=1)

        d6 = self.dec6(d5)
        d6 = torch.cat([d6, e2], dim=1)

        d7 = self.dec7(d6)
        d7 = torch.cat([d7, e1], dim=1)

        out = self.dec8(d7)
        return self.tanh(out)


In [None]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Regression-baseline generator with the class defaults (in_ch=3, out_ch=3, base_ch=64).
G_reg = SimpleUNetE2S_IN_NoLastNorm().to(device)

def init_weights_in_nolast(m):
    """Initialise one module in place; intended for ``model.apply(init_weights_in_nolast)``.

    Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0 (if present).
    InstanceNorm2d: weight ~ N(1, 0.02), bias = 0 (only when the layer is affine).
    Any other module is left unchanged.
    """
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    if not (is_conv or isinstance(m, (nn.InstanceNorm2d,))):
        return
    # Convs are centred at 0, norm scales at 1.
    if m.weight is not None:
        nn.init.normal_(m.weight, 0.0 if is_conv else 1.0, 0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

G_reg.apply(init_weights_in_nolast)   # recursive: visits every submodule

print("G_reg params (M):", sum(map(torch.numel, G_reg.parameters())) / 1e6)
torch.cuda.empty_cache()

In [None]:
# Reload the trained GAN checkpoint (G, D) from checkpoints/edges2shoes_gan

In [None]:
import os
import torch

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

CKPT_DIR_GAN = f"{PROJ}/checkpoints/edges2shoes_gan"
BEST_EPOCH = 5  # or 3 if that visually looked better to you

# Restore both networks from the chosen epoch; only G is switched to eval mode here.
ckpt_path = os.path.join(CKPT_DIR_GAN, f"epoch_{BEST_EPOCH}.pt")
state = torch.load(ckpt_path, map_location=device)

G.load_state_dict(state["G"])
D.load_state_dict(state["D"])
G.to(device).eval()

print("Loaded GAN checkpoint:", ckpt_path)


In [None]:
batch = next(iter(train_e2s_dl))
for key in ("edge", "shoe"):
    print(f"{key} shape :", batch[key].shape)


In [None]:
#fresh pix2pix

In [None]:
import torch
import torch.nn as nn

# Use the GPU when available; the pix2pix models built in this cell are moved to it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Basic blocks ----------

class DownBlock(nn.Module):
    """pix2pix encoder unit: Conv2d 4x4 / stride 2 -> [BatchNorm2d] -> LeakyReLU(0.2).

    Halves H and W. With ``norm=True`` the conv has no bias, because BatchNorm supplies
    the shift.
    """
    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        stack = [nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=not norm)]
        stack += [nn.BatchNorm2d(out_c)] if norm else []
        stack.append(nn.LeakyReLU(0.2, inplace=True))
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        return self.main(x)


class UpBlock(nn.Module):
    """pix2pix decoder unit: ConvTranspose2d 4x4 / stride 2 -> BatchNorm2d -> ReLU [-> Dropout(0.5)].

    Doubles H and W.
    """
    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        stack = [
            nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        stack += [nn.Dropout(0.5)] if dropout else []
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        return self.main(x)


# ---------- Generator for 128x128 inputs (7 downs / 7 ups) ----------

class Pix2PixUNet128(nn.Module):
    """
    U-Net generator for 128x128 -> 128x128 (pix2pix layout, 7 downs instead of 8).

      Down: 3->64->128->256->512->512->512->512   (d1..d7, 128 -> 1x1)
      Up  : mirror image with skip connections; final tanh keeps output in [-1, 1].

    The first and the 1x1 bottleneck encoder blocks skip BatchNorm.
    """
    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super().__init__()

        # Encoder
        self.down1 = DownBlock(input_nc, ngf, norm=False)     # 128 -> 64
        self.down2 = DownBlock(ngf, ngf * 2)                  # 64  -> 32
        self.down3 = DownBlock(ngf * 2, ngf * 4)              # 32  -> 16
        self.down4 = DownBlock(ngf * 4, ngf * 8)              # 16  -> 8
        self.down5 = DownBlock(ngf * 8, ngf * 8)              # 8   -> 4
        self.down6 = DownBlock(ngf * 8, ngf * 8)              # 4   -> 2
        self.down7 = DownBlock(ngf * 8, ngf * 8, norm=False)  # 2   -> 1 (no BN)

        # Decoder: from up2 on, inputs are doubled by the skip concatenation.
        self.up1 = UpBlock(ngf * 8, ngf * 8, dropout=True)        # 1 -> 2, concat d6
        self.up2 = UpBlock(ngf * 8 * 2, ngf * 8, dropout=True)    # 2 -> 4, concat d5
        self.up3 = UpBlock(ngf * 8 * 2, ngf * 8, dropout=True)    # 4 -> 8, concat d4
        self.up4 = UpBlock(ngf * 8 * 2, ngf * 4)                  # 8 -> 16, concat d3
        self.up5 = UpBlock(ngf * 4 * 2, ngf * 2)                  # 16 -> 32, concat d2
        self.up6 = UpBlock(ngf * 2 * 2, ngf)                      # 32 -> 64, concat d1
        self.up7 = nn.ConvTranspose2d(ngf * 2, output_nc,
                                      kernel_size=4, stride=2, padding=1)  # 64 -> 128

        self.tanh = nn.Tanh()

    def forward(self, x):
        # Encoder: remember every activation for the skips.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4,
                     self.down5, self.down6, self.down7):
            h = down(h)
            skips.append(h)

        # skips[-1] is the 1x1 bottleneck; d6 .. d1 are concatenated after up1 .. up6.
        decoders = (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6)
        for up, skip in zip(decoders, reversed(skips[:-1])):
            h = torch.cat([up(h), skip], dim=1)

        return self.tanh(self.up7(h))


# ---------- PatchGAN discriminator (edge+shoe = 6 channels) ----------

class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator (pix2pix layout) that returns a map of raw per-patch logits.

    Layout: Conv(stride 2) + LeakyReLU, then (n_layers - 1) Conv(stride 2) + BN + LeakyReLU
    stages, one Conv(stride 1) + BN + LeakyReLU stage, and a final 1-channel Conv(stride 1).
    Widths are ndf * min(2**n, 8). Input is the (edge, shoe) pair concatenated on channels.
    """
    def __init__(self, input_nc=6, ndf=64, n_layers=3):
        super().__init__()
        kw = 4
        padw = 1

        def norm_stage(c_in, c_out, stride):
            # BN follows, so the conv bias is redundant.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kw, stride=stride, padding=padw, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        mult = [min(2 ** n, 8) for n in range(n_layers + 1)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for n in range(1, n_layers):
            layers += norm_stage(ndf * mult[n - 1], ndf * mult[n], stride=2)
        layers += norm_stage(ndf * mult[max(n_layers - 1, 0)], ndf * mult[n_layers], stride=1)
        layers.append(nn.Conv2d(ndf * mult[n_layers], 1, kernel_size=kw, stride=1, padding=padw))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


# ---------- Instantiate models ----------

G_pix = Pix2PixUNet128(input_nc=3, output_nc=3, ngf=64).to(device)
D_pix = NLayerDiscriminator(input_nc=6, ndf=64).to(device)   # 3 edge + 3 shoe channels

for tag, net in (("G_pix params:", G_pix), ("D_pix params:", D_pix)):
    print(tag, sum(p.numel() for p in net.parameters()) / 1e6, "M")


In [None]:
import torch.nn.functional as F
from tqdm.auto import tqdm
import os
from pathlib import Path

# pix2pix training: D_pix sees (edge, shoe) pairs concatenated on channels; G_pix is trained
# with an adversarial term plus an L1 term weighted by lambda_L1.
criterion_GAN = nn.BCEWithLogitsLoss()   # D_pix ends in a bare Conv2d, i.e. it emits logits
criterion_L1  = nn.L1Loss()

lambda_L1 = 100.0
lr       = 2e-4
epochs   = 5

optimizer_G = torch.optim.Adam(G_pix.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D_pix.parameters(), lr=lr, betas=(0.5, 0.999))

CKPT_DIR_PIX = f"{PROJ}/checkpoints/edges2shoes_pix2pix"
Path(CKPT_DIR_PIX).mkdir(parents=True, exist_ok=True)

for epoch in range(1, epochs + 1):
    G_pix.train()
    D_pix.train()

    pbar = tqdm(train_e2s_dl, desc=f"[pix2pix E2S] Epoch {epoch}/{epochs}")
    for batch in pbar:
        edge      = batch["edge"].to(device)   # (B,3,128,128)
        real_shoe = batch["shoe"].to(device)   # (B,3,128,128)

        # --------------------- D ---------------------
        # Also clears D gradients accumulated by the previous generator step.
        optimizer_D.zero_grad()

        real_input = torch.cat([edge, real_shoe], dim=1)   # (B,6,H,W)
        pred_real  = D_pix(real_input)
        valid      = torch.ones_like(pred_real)
        loss_D_real = criterion_GAN(pred_real, valid)

        # No graph through G here: the D step must not update the generator.
        with torch.no_grad():
            fake_shoe_detached = G_pix(edge)
        fake_input = torch.cat([edge, fake_shoe_detached], dim=1)
        pred_fake  = D_pix(fake_input)
        fake       = torch.zeros_like(pred_fake)
        loss_D_fake = criterion_GAN(pred_fake, fake)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # --------------------- G ---------------------
        optimizer_G.zero_grad()

        fake_shoe = G_pix(edge)
        fake_input_for_G = torch.cat([edge, fake_shoe], dim=1)
        pred_fake_for_G  = D_pix(fake_input_for_G)

        # G wants D to call its output real; `valid` has the same shape as D's output map.
        loss_G_GAN = criterion_GAN(pred_fake_for_G, valid)
        loss_G_L1  = criterion_L1(fake_shoe, real_shoe) * lambda_L1
        loss_G     = loss_G_GAN + loss_G_L1

        loss_G.backward()
        optimizer_G.step()

        pbar.set_postfix({
            "loss_D": float(loss_D.item()),
            "loss_G": float(loss_G.item()),
        })

    ckpt_path = os.path.join(CKPT_DIR_PIX, f"epoch_{epoch}.pt")
    # Optimizer states are saved too (as the first GAN loop does), so training can resume
    # with its Adam moments instead of restarting them from zero. Loaders reading only
    # "G_pix"/"D_pix" are unaffected.
    torch.save({
        "G_pix": G_pix.state_dict(),
        "D_pix": D_pix.state_dict(),
        "optimizer_G": optimizer_G.state_dict(),
        "optimizer_D": optimizer_D.state_dict(),
        "epoch": epoch,
    }, ckpt_path)
    print(f"💾 Saved pix2pix checkpoint at epoch {epoch} -> {ckpt_path}")


In [None]:
import torch
import matplotlib.pyplot as plt

G_pix.eval()

# Use validation loader if it exists, else fall back to train
loader = val_e2s_dl if "val_e2s_dl" in globals() else train_e2s_dl

# Helper to denormalize from [-1,1] -> [0,1]
def denorm(t):
    """Map a tensor from the [-1, 1] range to [0, 1], clamping values outside that range."""
    return t.add(1).div(2).clamp(0, 1)

# Take one batch and run the generator on its edge maps.
batch = next(iter(loader))
edge = batch["edge"].to(device)       # (B,3,128,128) or (B,1,128,128) depending on your dataset code
real_shoe = batch["shoe"].to(device)  # (B,3,128,128)

with torch.no_grad():
    fake_shoe = G_pix(edge)

# Move everything to CPU and map [-1, 1] -> [0, 1] for display.
edge_cpu, real_cpu, fake_cpu = (denorm(x.cpu()) for x in (edge, real_shoe, fake_shoe))

B = min(4, edge_cpu.size(0))  # show up to 4 examples

fig, axes = plt.subplots(B, 3, figsize=(9, 3 * B))
if B == 1:
    # plt.subplots returns a 1-D axes array for a single row; make it 2-D.
    axes = axes.reshape(1, 3)

for row in range(B):
    edge_img = edge_cpu[row]
    if edge_img.size(0) == 1:
        # Replicate a single-channel edge map to RGB for display.
        edge_img = edge_img.repeat(3, 1, 1)

    panels = (
        (edge_img, "Edge"),
        (real_cpu[row], "Real shoe"),
        (fake_cpu[row], "pix2pix (GAN) shoe"),
    )
    for col, (img, title) in enumerate(panels):
        ax = axes[row, col]
        ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
import torch.nn.functional as F

G_pix.eval()
# Validation split if it was defined in this session, otherwise the training split.
loader = val_e2s_dl if "val_e2s_dl" in globals() else train_e2s_dl

# Sum absolute errors over every element and divide once at the end, so the
# result is an exact per-element mean even when the last batch is smaller.
total_l1 = 0.0
count = 0

with torch.no_grad():
    for batch in loader:
        edge = batch["edge"].to(device)
        real_shoe = batch["shoe"].to(device)
        fake_shoe = G_pix(edge)

        total_l1 += F.l1_loss(fake_shoe, real_shoe, reduction="sum").item()
        # numel() counts every channel value (B*C*H*W), not just spatial pixels,
        # so the mean below is per element (pixel x channel).
        count += real_shoe.numel()

if count == 0:
    mean_l1 = float("nan")
    print("No samples in the loader; cannot compute mean L1.")
else:
    mean_l1 = total_l1 / count
    # NOTE(review): tensors are presumably in the normalized [-1, 1] range (the
    # display helper above assumes so), so this L1 is on that scale.
    print(f"Mean L1 per element (pix2pix GAN baseline): {mean_l1:.6f}")


In [None]:
#Qualitative comparison: SketchyCOCO vs FSCOCO vs edges2shoes

In [None]:
import itertools
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch

def take_n_from_dl(dataloader, n):
    """Take the first ``n`` examples from a dataloader and merge them into one batch.

    Batches are drawn until at least ``n`` examples are collected, or until the
    loader is exhausted (a warning is printed in that case). Tensor fields are
    concatenated along dim 0. List/tuple fields (e.g. captions collated as a
    list of strings) are flattened. Every field is truncated to ``n`` entries.

    Raises:
        ValueError: if the dataloader yields no batches at all.
    """
    def _batch_len(batch):
        # Number of examples in a collated batch: leading dim of the first
        # tensor field, or length of the first list field.
        for value in batch.values():
            if isinstance(value, torch.Tensor):
                return value.size(0)
            if isinstance(value, (list, tuple)):
                return len(value)
        return 1

    batches = []
    collected = 0
    it = iter(dataloader)
    while collected < n:
        try:
            batch = next(it)
        except StopIteration:
            print(f"[WARN] Dataloader ran out of items after {collected} examples (requested {n}).")
            break
        batches.append(batch)
        collected += _batch_len(batch)

    if not batches:
        raise ValueError("No items could be retrieved from the dataloader.")

    merged = {}
    for key, first in batches[0].items():
        values = [b[key] for b in batches]
        if isinstance(first, torch.Tensor):
            merged[key] = torch.cat(values, dim=0)[:n]
        elif isinstance(first, (list, tuple)):
            # Flatten per-batch lists so e.g. captions stay a flat list of str.
            merged[key] = [item for chunk in values for item in chunk][:n]
        else:
            merged[key] = values[:n]
    return merged


In [None]:
#SD + LoRA generation helper

In [None]:
from torchvision.transforms.functional import to_pil_image

def sd_lora_compare(pipe_base, lora_dir, batch, title, num_examples=4):
    """
    Generate SD+LoRA images from the batch captions and plot them next to the data.

    batch: dict with keys:
      - "sketch"  -> (B, 1 or 3, H, W) optional, expected in [-1, 1]
      - "image"   -> (B, 3, H, W)       real RGB, expected in [-1, 1]
      - "caption" -> list of text (one prompt per example)

    Any LoRA already on ``pipe_base`` is unloaded before ``lora_dir`` is loaded;
    the new LoRA stays loaded after the call. Prints a warning and returns
    without plotting when the batch holds no examples.
    """
    pipe_base.to("cuda")
    pipe_base.unload_lora_weights()  # make sure clean
    pipe_base.load_lora_weights(lora_dir)

    sketches = batch.get("sketch", None)
    reals    = batch["image"]
    caps     = batch["caption"]

    n = min(num_examples, reals.size(0))
    if n == 0:
        # plt.subplots(0, ...) would raise; report instead.
        print(f"[WARN] No examples to display for {title}.")
        return

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        gen_imgs = []
        for i in range(n):
            prompt = caps[i]
            img = pipe_base(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
            gen_imgs.append(img)

    fig, axes = plt.subplots(n, 3 if sketches is not None else 2,
                             figsize=(10 if sketches is not None else 7, 3*n))
    if n == 1:
        axes = axes.reshape(1, -1)

    for i in range(n):
        col = 0
        if sketches is not None:
            ax = axes[i, col]
            sketch = (sketches[i].cpu() * 0.5 + 0.5).clamp(0, 1)
            if sketch.shape[0] == 1:
                # A 1-channel tensor becomes a mode-"L" PIL image, which imshow
                # would colour-map; replicate it to RGB so it renders in gray.
                sketch = sketch.repeat(3, 1, 1)
            ax.imshow(to_pil_image(sketch))
            ax.set_title("Sketch")
            ax.axis("off")
            col += 1

        ax = axes[i, col]
        ax.imshow(to_pil_image((reals[i].cpu() * 0.5 + 0.5).clamp(0,1)))
        ax.set_title("Real")
        ax.axis("off")

        ax = axes[i, col+1]
        ax.imshow(gen_imgs[i])
        ax.set_title("SD + LoRA")
        ax.axis("off")

    fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()


In [None]:
#Call it for the three diffusion datasets

In [None]:
import torch
from diffusers import StableDiffusionPipeline
from IPython.display import display
from PIL import Image
from pathlib import Path
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# --- Re-establish core paths and variables that might have been cleared ---
PROJ = "/content/drive/MyDrive/Gen AI/Gen AI project"
SUBSET_ROOT = PROJ + "/subsets"

# Model / runtime configuration
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_SIZE = 512         # side length images are resized to; keep equal to the training size
BATCH_SIZE_LOADER = 1    # batch size of the validation data loaders built below

# --- Helper functions (assuming they were previously defined) ---
# pil_to_tensor function definition (from cell XRBQ9kr4L2Rm)
def pil_to_tensor(img: Image.Image, size=IMAGE_SIZE, to_gray=False):
    """
    Convert a PIL image into a float tensor scaled to [-1, 1].

    The image is converted to grayscale ("L", 1 channel) when ``to_gray`` is
    true and to RGB (3 channels) otherwise. It is then resized to
    ``size`` x ``size`` with bicubic resampling.
    Returns a tensor of shape (C, H, W).
    """
    mode = "L" if to_gray else "RGB"
    resized = img.convert(mode).resize((size, size), Image.BICUBIC)
    pixels = np.array(resized).astype("float32") / 255.0

    if pixels.ndim == 2:
        # A 2-D array is either the expected single gray channel, or (defensively)
        # an RGB request that came back without a channel axis.
        pixels = pixels[..., None] if to_gray else np.stack([pixels] * 3, axis=-1)

    # [0, 1] -> [-1, 1], then HWC -> CHW
    return torch.from_numpy((pixels * 2.0 - 1.0).transpose(2, 0, 1))

# Sketch2ImageDataset class definition (from cell d4oEhM0eL_gZ)
class Sketch2ImageDataset(Dataset):
    """
    Generic sketch → image dataset for diffusion, using manifests:
    columns: split, sketch_path, photo_path, caption

    Each item is a dict with:
      - "sketch":  (1, H, W) float tensor in [-1, 1] (grayscale)
      - "image":   (3, H, W) float tensor in [-1, 1] (RGB)
      - "caption": str ("" when the manifest cell is missing or empty)
      - "sketch_path", "photo_path": the paths read from the manifest row
    """

    def __init__(self, manifest_csv, split="train", image_size=IMAGE_SIZE):
        self.manifest_csv = manifest_csv
        self.split = split
        self.image_size = image_size

        df = pd.read_csv(manifest_csv)
        if "split" not in df.columns:
            raise ValueError(f"'split' column not found in {manifest_csv}")

        self.df = df[df["split"] == split].reset_index(drop=True)
        if len(self.df) == 0:
            print(f"[WARN] No rows for split='{split}' in {manifest_csv}. This might lead to errors.")

        print(f"[{os.path.basename(manifest_csv)}] Loaded {len(self.df)} rows for split='{split}'")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        sketch_path = row["sketch_path"]
        photo_path  = row["photo_path"]
        caption     = row.get("caption", "")
        # pandas reads empty CSV cells as NaN; str(NaN) would become the prompt "nan".
        if pd.isna(caption):
            caption = ""

        # Open inside `with` so file handles are released as soon as the pixels
        # are decoded (pil_to_tensor converts/resizes while the file is open).
        with Image.open(sketch_path) as sketch_img:
            sketch_tensor = pil_to_tensor(sketch_img, size=self.image_size, to_gray=True)   # (1, H, W)
        with Image.open(photo_path) as photo_img:
            photo_tensor = pil_to_tensor(photo_img, size=self.image_size, to_gray=False)   # (3, H, W)

        return {
            "sketch": sketch_tensor,
            "image":  photo_tensor,
            "caption": str(caption),
            "sketch_path": sketch_path,
            "photo_path":  photo_path,
        }

# load_base_pipe function definition (from cell Y5FnGglnwPOQ)
def load_base_pipe(dtype=torch.float16):
    """Load the base Stable Diffusion pipeline ``model_id`` onto ``device``.

    The safety checker is disabled and attention slicing is enabled to lower
    attention memory use. Returns the ready-to-use pipeline.
    """
    base_pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=dtype, safety_checker=None
    )
    base_pipe = base_pipe.to(device)
    base_pipe.enable_attention_slicing()
    return base_pipe

# take_n_from_dl function definition (from cell qZa0_4OzMBiK)
def take_n_from_dl(dataloader, n):
    """Take the first ``n`` examples from a dataloader and merge them into one batch.

    Batches are drawn until at least ``n`` examples are collected, or until the
    loader is exhausted (a warning is printed in that case). Tensor fields are
    concatenated along dim 0. List/tuple fields (e.g. captions or paths collated
    as lists of strings) are flattened rather than nested. Every field is
    truncated to ``n`` entries.

    Raises:
        ValueError: if the dataloader yields no batches at all.
    """
    def _batch_len(batch):
        # Number of examples in a collated batch: leading dim of the first
        # tensor field, or length of the first list field.
        for value in batch.values():
            if isinstance(value, torch.Tensor):
                return value.size(0)
            if isinstance(value, (list, tuple)):
                return len(value)
        return 1

    batches = []
    collected = 0
    it = iter(dataloader)
    while collected < n:
        try:
            batch = next(it)
        except StopIteration:
            print(f"[WARN] Dataloader ran out of items after {collected} examples (requested {n}).")
            break
        batches.append(batch)
        collected += _batch_len(batch)

    if not batches:
        raise ValueError("No items could be retrieved from the dataloader.")

    merged = {}
    for key, first in batches[0].items():
        values = [b[key] for b in batches]
        if isinstance(first, torch.Tensor):
            merged[key] = torch.cat(values, dim=0)[:n]
        elif isinstance(first, (list, tuple)):
            # Flatten so captions end up as a flat list of str, not a list of lists.
            merged[key] = [item for chunk in values for item in chunk][:n]
        else:
            merged[key] = values[:n]
    return merged

# sd_lora_compare function definition (from cell 1k8GDW25MIxJ)
def sd_lora_compare(pipe_base, lora_dir, batch, title, num_examples=4):
    """
    Swap ``lora_dir`` into ``pipe_base``, generate one image per caption and plot
    each generation beside its source data.

    batch: dict with keys:
      - "sketch"  -> (B, 1 or 3, H, W) optional; 1-channel sketches are shown in gray
      - "image"   -> (B, 3, H, W)       real RGB
      - "caption" -> list of text, used as the prompts
    Tensors are displayed via ``x * 0.5 + 0.5``, i.e. they are expected in [-1, 1].
    Prints a warning and returns early when there is nothing to show.
    """
    pipe_base.to("cuda")
    # Drop adapters left over from a previous call before loading the new one.
    pipe_base.unload_lora_weights()
    pipe_base.load_lora_weights(lora_dir)

    sketches = batch.get("sketch", None)
    reals = batch["image"]
    caps = batch["caption"]
    has_sketch = sketches is not None

    n = min(num_examples, reals.size(0))
    if n == 0:
        print(f"[WARN] No examples to display for {title}.")
        return

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        gen_imgs = [
            pipe_base(caps[k], num_inference_steps=30, guidance_scale=7.5).images[0]
            for k in range(n)
        ]

    fig, axes = plt.subplots(
        n,
        3 if has_sketch else 2,
        figsize=(10 if has_sketch else 7, 3 * n),
    )
    if n == 1:
        # A single row comes back as a 1-D axes array; make it 2-D for [row, col] indexing.
        axes = axes.reshape(1, -1)

    for row in range(n):
        panels = []
        if has_sketch:
            sketch_view = sketches[row].cpu() * 0.5 + 0.5
            if sketch_view.shape[0] == 1:
                sketch_view = sketch_view.repeat(3, 1, 1)
            panels.append((to_pil_image(sketch_view.clamp(0, 1)), "Sketch"))
        panels.append((to_pil_image((reals[row].cpu() * 0.5 + 0.5).clamp(0, 1)), "Real"))
        panels.append((gen_imgs[row], "SD + LoRA"))

        for col, (picture, label) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(picture)
            ax.set_title(label)
            ax.axis("off")

    fig.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()


# === Paths to subset CSVs ===
SCENE_SUBSET_CSV, OBJECT_SUBSET_CSV, FSC_SUBSET_CSV = (
    f"{SUBSET_ROOT}/{name}_subset.csv" for name in ("scene", "object", "fscoco")
)


def _make_val_loader(manifest_csv):
    """Build the 'val' split dataset for ``manifest_csv`` plus its unshuffled DataLoader."""
    ds = Sketch2ImageDataset(manifest_csv, split="val", image_size=IMAGE_SIZE)
    return ds, DataLoader(ds, batch_size=BATCH_SIZE_LOADER, shuffle=False, num_workers=2)


# === Re-instantiate data loaders (objects, scenes, FSCOCO — in that order) ===
obj_val_ds, obj_val_dl = _make_val_loader(OBJECT_SUBSET_CSV)
scene_val_ds, scene_val_dl = _make_val_loader(SCENE_SUBSET_CSV)
fscoco_val_ds, fscoco_val_dl = _make_val_loader(FSC_SUBSET_CSV)


# === LoRA checkpoint directories ===
# Using checkpoint step 1000 as a default for all three datasets.
CKPT_STEP_OBJ = 1000
LORA_OBJ_DIR  = f"{PROJ}/checkpoints/lora_object_steps/step_{CKPT_STEP_OBJ}"

CKPT_STEP_SCENE = 1000
LORA_SCENE_DIR  = f"{PROJ}/checkpoints/lora_scene_steps/step_{CKPT_STEP_SCENE}"

CKPT_STEP_FS = 1000
LORA_FSCOCO_DIR  = f"{PROJ}/checkpoints/lora_fscoco_steps/step_{CKPT_STEP_FS}"

# Check the checkpoints up front instead of assuming they exist: a missing
# directory would otherwise only surface later, inside load_lora_weights,
# after the (slow) base pipeline load.
for _lora_dir in (LORA_OBJ_DIR, LORA_SCENE_DIR, LORA_FSCOCO_DIR):
    if not os.path.isdir(_lora_dir):
        print(f"[WARN] LoRA checkpoint directory not found: {_lora_dir}")

# === Re-instantiate base pipeline for evaluation ===
pipe_base = load_base_pipe(torch.float16)


def _take_preview(label, dataloader):
    """Announce which dataset is being generated and grab its first 4 examples."""
    print(f"Generating for {label}...")
    return take_n_from_dl(dataloader, 4)


# 1) SketchyCOCO objects
batch_obj = _take_preview("SketchyCOCO Objects", obj_val_dl)
sd_lora_compare(pipe_base, LORA_OBJ_DIR, batch_obj,
                title="SketchyCOCO Objects: Sketch / Real / SD+LoRA")

# 2) SketchyCOCO scenes
batch_scene = _take_preview("SketchyCOCO Scenes", scene_val_dl)
sd_lora_compare(pipe_base, LORA_SCENE_DIR, batch_scene,
                title="SketchyCOCO Scenes: Sketch / Real / SD+LoRA")

# 3) FSCOCO
batch_fscoco = _take_preview("FSCOCO", fscoco_val_dl)
sd_lora_compare(pipe_base, LORA_FSCOCO_DIR, batch_fscoco,
                title="FSCOCO: Sketch / Real / SD+LoRA")


In [None]:
# Re-run the three SD+LoRA comparisons (first 4 validation examples each).

# 1) SketchyCOCO objects
batch_obj = take_n_from_dl(obj_val_dl, n=4)
sd_lora_compare(
    pipe_base,
    lora_dir=LORA_OBJ_DIR,
    batch=batch_obj,
    title="SketchyCOCO Objects: Sketch / Real / SD+LoRA",
)

# 2) SketchyCOCO scenes
batch_scene = take_n_from_dl(scene_val_dl, n=4)
sd_lora_compare(
    pipe_base,
    lora_dir=LORA_SCENE_DIR,
    batch=batch_scene,
    title="SketchyCOCO Scenes: Sketch / Real / SD+LoRA",
)

# 3) FSCOCO
batch_fscoco = take_n_from_dl(fscoco_val_dl, n=4)
sd_lora_compare(
    pipe_base,
    lora_dir=LORA_FSCOCO_DIR,
    batch=batch_fscoco,
    title="FSCOCO: Sketch / Real / SD+LoRA",
)


In [None]:
# Re-run the SD+LoRA comparison for the three diffusion datasets

In [None]:
_N_PREVIEW = 4  # examples per dataset in the comparison grids

# 1) SketchyCOCO objects
batch_obj = take_n_from_dl(obj_val_dl, _N_PREVIEW)
sd_lora_compare(pipe_base, LORA_OBJ_DIR, batch_obj,
                title="SketchyCOCO Objects: Sketch / Real / SD+LoRA")

# 2) SketchyCOCO scenes
batch_scene = take_n_from_dl(scene_val_dl, _N_PREVIEW)
sd_lora_compare(pipe_base, LORA_SCENE_DIR, batch_scene,
                title="SketchyCOCO Scenes: Sketch / Real / SD+LoRA")

# 3) FSCOCO
batch_fscoco = take_n_from_dl(fscoco_val_dl, _N_PREVIEW)
sd_lora_compare(pipe_base, LORA_FSCOCO_DIR, batch_fscoco,
                title="FSCOCO: Sketch / Real / SD+LoRA")


In [None]:
#Add edges2shoes pix2pix GAN row

In [None]:
G_pix.eval()

# Same loader choice as the earlier pix2pix cells: validation split if it is
# defined in this session, otherwise the training split.
loader_e2s = val_e2s_dl if "val_e2s_dl" in globals() else train_e2s_dl
batch_e2s = next(iter(loader_e2s))
edge  = batch_e2s["edge"].to(device)  # (B,C,128,128); C may be 1 or 3 depending on the dataset code
real  = batch_e2s["shoe"].to(device)  # (B,3,128,128)
n = min(4, edge.size(0))

with torch.no_grad():
    fake  = G_pix(edge[:n]).cpu()

fig, axes = plt.subplots(n, 3, figsize=(10, 3*n))
if n == 1: # Reshape axes to be 2D if only one row
    axes = axes.reshape(1, -1)

for i in range(n):
    edge_img = (edge[i].cpu() * 0.5 + 0.5).clamp(0, 1)
    if edge_img.size(0) == 1:
        # A 1-channel tensor becomes a mode-"L" PIL image, which imshow would
        # colour-map; replicate to RGB so the edge map renders in grayscale.
        edge_img = edge_img.repeat(3, 1, 1)
    axes[i,0].imshow(to_pil_image(edge_img))
    axes[i,0].set_title("Edge")
    axes[i,0].axis("off")

    axes[i,1].imshow(to_pil_image((real[i].cpu()*0.5+0.5).clamp(0,1)))
    axes[i,1].set_title("Real shoe")
    axes[i,1].axis("off")

    axes[i,2].imshow(to_pil_image((fake[i]*0.5+0.5).clamp(0,1)))
    axes[i,2].set_title("pix2pix (GAN)")
    axes[i,2].axis("off")

plt.suptitle("edges2shoes: Edge / Real / pix2pix", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
#FID / Inception Score for one diffusion setup (SketchyCOCO scenes)

In [None]:
!pip install -q clean-fid torchmetrics

from cleanfid import fid
from torchvision.utils import save_image
import torch, os
from torchmetrics.image.inception import InceptionScore


In [None]:
REAL_DIR = f"{PROJ}/metrics/sketchycoco_scene/real"
FAKE_DIR = f"{PROJ}/metrics/sketchycoco_scene/fake"
os.makedirs(REAL_DIR, exist_ok=True)
os.makedirs(FAKE_DIR, exist_ok=True)

pipe_base.to("cuda")
pipe_base.unload_lora_weights()
pipe_base.load_lora_weights(LORA_SCENE_DIR)

num_samples = 500
saved = 0

# The validation subset may hold fewer than `num_samples` examples: iterate the
# loader directly so it stops cleanly instead of raising StopIteration.
for batch in scene_val_dl:
    if saved >= num_samples:
        break
    imgs = batch["image"]      # real images in [-1, 1]; save_image handles CPU tensors
    caps = batch["caption"]

    take = min(imgs.size(0), num_samples - saved)

    # save real
    for i in range(take):
        save_image((imgs[i] * 0.5 + 0.5).clamp(0, 1),
                   os.path.join(REAL_DIR, f"{saved + i:05d}.png"))

    # generate fake and save using PIL
    for i in range(take):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            pil = pipe_base(caps[i], num_inference_steps=30, guidance_scale=7.5).images[0]
        pil.save(os.path.join(FAKE_DIR, f"{saved + i:05d}.png"))

    saved += take
    print(f"saved {saved}/{num_samples}")

if saved < num_samples:
    print(f"[WARN] scene_val_dl ran out after {saved} samples (requested {num_samples}); "
          "FID will be computed on fewer images.")

# Files are overwritten by index, so PNGs left over from an earlier, larger run
# would silently be included in the FID computed from these folders.
for metric_dir in (REAL_DIR, FAKE_DIR):
    n_png = sum(1 for fn in os.listdir(metric_dir) if fn.endswith(".png"))
    if n_png != saved:
        print(f"[WARN] {metric_dir} holds {n_png} PNGs but {saved} were written this run; "
              "stale files from a previous run will be included in the FID.")

In [None]:
# Compare the folder of saved real scene images against the generated ones.
fid_score = fid.compute_fid(fdir1=REAL_DIR, fdir2=FAKE_DIR)
print("SketchyCOCO scene SD+LoRA FID:", fid_score)


In [None]:
# Inception Score (IS) of the generated SketchyCOCO scene images in FAKE_DIR

In [None]:
# Install torch-fidelity if not present
# (the reload below exists so torchmetrics notices it; its image metrics
# presumably depend on it — TODO confirm against the torchmetrics docs).
# NOTE: the `!pip` line is IPython/Colab shell syntax and only runs in a notebook.
try:
    import torch_fidelity
    print("torch-fidelity is correctly imported.")
except ImportError:
    print("Installing torch-fidelity...")
    !pip install -q torch-fidelity
    print("Installation complete.")

# --- FIX: Force reload ALL relevant torchmetrics modules ---
# Order matters: utilities.imports is reloaded first so that, when fid and
# inception are reloaded, they re-read its (presumably import-time cached)
# package-availability checks and can see a freshly installed torch-fidelity.
import importlib
import torchmetrics.utilities.imports
import torchmetrics.image.fid
import torchmetrics.image.inception

print("Reloading torchmetrics modules (imports, fid, inception) to detect torch-fidelity...")
importlib.reload(torchmetrics.utilities.imports)
importlib.reload(torchmetrics.image.fid)
importlib.reload(torchmetrics.image.inception)
# --------------------------------------------------------------
# --------------------------------------------------------------

from PIL import Image
from torchvision import transforms
import os
import torch
from torchmetrics.image.inception import InceptionScore

transform = transforms.Compose([
    transforms.Resize((299,299)),
    transforms.ToTensor(),   # PIL -> float tensor in [0, 1]
])

IS_BATCH = 32  # images per metric update; bounds how many images sit on the GPU at once

# Ensure we only try to load if FAKE_DIR exists and has files
if os.path.isdir(FAKE_DIR) and len(os.listdir(FAKE_DIR)) > 0:
    image_files = [
        fn for fn in sorted(os.listdir(FAKE_DIR))
        if fn.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]

    if image_files:
        # normalize=True: the metric receives floats in [0, 1] (from ToTensor).
        is_metric = InceptionScore(normalize=True).to(device)

        # Feed the metric chunk by chunk (update accumulates internally, compute
        # scores everything seen) instead of stacking every image onto the GPU.
        for start in range(0, len(image_files), IS_BATCH):
            chunk = []
            for fn in image_files[start:start + IS_BATCH]:
                # `with` closes each file once it has been decoded.
                with Image.open(os.path.join(FAKE_DIR, fn)) as img:
                    chunk.append(transform(img.convert("RGB")))
            is_metric.update(torch.stack(chunk, dim=0).to(device))

        is_score, is_std = is_metric.compute()
        print("Inception Score:", float(is_score), "+/-", float(is_std))
    else:
        print("No valid images found in FAKE_DIR to calculate Inception Score.")
else:
    print(f"Directory {FAKE_DIR} does not exist or is empty.")

In [None]:
# LoRA ablation on FSCOCO: rank 4 vs rank 8 (the cells below train the rank-4 LoRA)

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from pathlib import Path
from tqdm.auto import tqdm

def train_fscoco_lora(rank, dataloader, max_steps=500, lr=5e-5):
    """
    Tiny LoRA training loop on FSCOCO for ablation.
    Uses ONLY the passed dataloader (e.g., fscoco_val_dl) and a fresh SD1.5 pipe.

    Args:
        rank: LoRA rank ``r`` (``lora_alpha`` is set equal to it, i.e. LoRA scale 1).
        dataloader: yields dicts with "image" (B,3,H,W) tensors in [-1, 1] and
            "caption" (list of str), as produced by Sketch2ImageDataset.
        max_steps: number of optimizer steps to run.
        lr: AdamW learning rate for the LoRA parameters.

    Returns:
        The directory (str) the LoRA weights were saved to.

    Raises:
        ValueError: if the dataloader yields no batches (training could never finish).
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ----- load base SD pipeline -----
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to(device)

    tokenizer     = pipe.tokenizer
    text_encoder  = pipe.text_encoder
    vae           = pipe.vae
    unet          = pipe.unet
    scheduler     = pipe.scheduler
    torch_dtype   = torch.float16

    # ----- attach LoRA to UNet -----
    lora_cfg = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights="gaussian",
    )
    unet.add_adapter(lora_cfg)

    # freeze base UNet, train only LoRA params
    for p in unet.parameters():
        p.requires_grad_(False)
    lora_params = [p for n, p in unet.named_parameters() if "lora_" in n]
    for p in lora_params:
        # Keep the trainable LoRA weights in fp32: AdamW updates on fp16 params
        # lose small steps to rounding (and its eps=1e-8 underflows in fp16).
        # The forward pass still runs in fp16 under autocast.
        p.data = p.data.float()
        p.requires_grad_(True)

    optimizer = AdamW(lora_params, lr=lr)
    # Loss scaling keeps small fp16 activation gradients from flushing to zero.
    # GradScaler cannot unscale fp16 gradients, which the fp32 upcast above avoids.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
    global_step = 0

    # ----- local encoders (no external helpers needed) -----
    def encode_prompts_local(captions):
        enc = tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = enc.input_ids.to(device)
        with torch.no_grad():
            text_embeds = text_encoder(input_ids)[0]
        return text_embeds.to(device=device, dtype=torch_dtype)

    def encode_images_local(images):
        # The SD VAE encoder expects pixels in [-1, 1] — the range the dataset
        # already provides — so no rescaling is applied here.
        imgs = images.to(device=device, dtype=torch_dtype)
        with torch.no_grad():
            latents = vae.encode(imgs).latent_dist.sample()
        # Scale latents by the VAE's configured factor (0.18215 for SD1.5).
        return latents * vae.config.scaling_factor

    # ----- training loop -----
    while global_step < max_steps:
        pbar = tqdm(dataloader, desc=f"[FSCOCO LoRA rank={rank}] step={global_step}")
        batches_seen = 0
        for batch in pbar:
            if global_step >= max_steps:
                break
            batches_seen += 1

            images   = batch["image"].to(device)   # (B,3,H,W), range [-1,1]
            captions = batch["caption"]

            text_embeds = encode_prompts_local(captions)
            latents     = encode_images_local(images)

            bsz   = latents.shape[0]
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, scheduler.config.num_train_timesteps,
                (bsz,), device=device, dtype=torch.long
            )
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=text_embeds,
                ).sample
            # Noise-prediction objective; reduce in fp32 for a stable loss value.
            loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            pbar.set_postfix({"loss": float(loss.item()), "step": global_step})

            if global_step >= max_steps:
                break

        if batches_seen == 0:
            # Without this the outer while-loop would spin forever on an empty loader.
            raise ValueError("dataloader yielded no batches; cannot train the LoRA.")

        if global_step >= max_steps:
            print(f"✅ Reached {max_steps} LoRA steps for rank={rank}")
            break

    # ----- save LoRA weights -----
    ckpt_dir = f"{PROJ}/checkpoints/fscoco_lora_rank{rank}"
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    lora_state = get_peft_model_state_dict(unet)
    lora_state = convert_state_dict_to_diffusers(lora_state)

    StableDiffusionPipeline.save_lora_weights(
        save_directory=ckpt_dir,
        unet_lora_layers=lora_state,
        safe_serialization=True,
    )
    print(f" Saved FSCOCO LoRA rank={rank} -> {ckpt_dir}")
    return ckpt_dir


In [None]:
# Rank-4 FSCOCO LoRA ablation, trained on the FSCOCO *val* loader as the data
# source for this small experiment. An earlier configuration (max_steps=501 with
# the default lr) was replaced by a longer schedule with a smaller learning rate
# for stability.
LORA_FSCOCO_R4_DIR = train_fscoco_lora(
    4,               # rank
    fscoco_val_dl,   # dataloader
    max_steps=1001,
    lr=1e-5,
)


In [None]:
#Visual comparison

In [None]:
batch_f = next(iter(fscoco_val_dl))

def compare_fscoco_lora(batch, lora_dir, num_examples=4, pipe=None):
    """
    Plot Sketch | Real | LoRA rows (or Real | LoRA without sketches) for ``batch``.

    Args:
        batch: dict with "image" (B,3,H,W) and "caption" (list of str), plus an
            optional "sketch" (B,1 or 3,H,W); tensors are displayed via
            ``x * 0.5 + 0.5``, i.e. expected in [-1, 1].
        lora_dir: LoRA weights passed to ``load_lora_weights``.
        num_examples: maximum number of rows to show.
        pipe: optional already-loaded StableDiffusionPipeline to reuse (the LoRA
            stays loaded on it afterwards). When omitted, a fresh SD1.5 fp16
            pipeline is loaded for this call and released at the end.
    """
    sketches = batch.get("sketch", None)
    reals    = batch["image"]
    caps     = batch["caption"]
    n = min(num_examples, reals.size(0))
    if n == 0:
        # plt.subplots(0, ...) would raise; report instead.
        print("[WARN] No examples to display.")
        return

    owns_pipe = pipe is None
    if owns_pipe:
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to(device)

    # Load the requested FSCOCO LoRA on a clean adapter state
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_dir)
    gen_lora = []
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        for i in range(n):
            gen_lora.append(
                pipe(caps[i], num_inference_steps=30, guidance_scale=7.5).images[0]
            )

    if owns_pipe:
        # Release the temporary pipeline so repeated calls don't pile up GPU memory.
        del pipe
        torch.cuda.empty_cache()

    # Plot: Sketch | Real | LoRA  (or Real | LoRA if no sketch)
    num_cols = 3 if sketches is not None else 2
    fig, axes = plt.subplots(n, num_cols, figsize=(12, 3 * n))
    if n == 1:
        axes = axes.reshape(1, -1)

    for i in range(n):
        col = 0
        if sketches is not None:
            sketch = (sketches[i].cpu() * 0.5 + 0.5).clamp(0, 1)
            if sketch.shape[0] == 1:
                # A 1-channel tensor becomes a mode-"L" PIL image, which imshow
                # would colour-map; replicate to RGB so it renders in gray.
                sketch = sketch.repeat(3, 1, 1)
            axes[i, col].imshow(to_pil_image(sketch))
            axes[i, col].set_title("Sketch")
            axes[i, col].axis("off")
            col += 1

        axes[i, col].imshow(to_pil_image((reals[i].cpu() * 0.5 + 0.5).clamp(0, 1)))
        axes[i, col].set_title("Real")
        axes[i, col].axis("off")

        axes[i, col + 1].imshow(gen_lora[i])
        axes[i, col + 1].set_title("LoRA (FSCOCO)")
        axes[i, col + 1].axis("off")

    plt.tight_layout()
    plt.show()

compare_fscoco_lora(batch_f, LORA_FSCOCO_DIR)

