# In [None]:
import os
import sys
import subprocess
import json
import textwrap
import pathlib
from pathlib import Path
import shutil

REPO_URL = "https://github.com/willloe/MSML640_Group10.git"
REPO_DIR = "/content/MSML640_Group10"
BRANCH = "feature/LoRA"

# Clone on first use; otherwise reuse the existing checkout.
if not Path(REPO_DIR).exists():
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
else:
    print("Repo already present; pulling latest...")

# Sync BRANCH explicitly. Pulling *before* switching branches (as previously
# done) targets whatever branch happens to be checked out and, with
# check=True, aborts on a detached HEAD or a branch without an upstream.
# Naming "origin BRANCH" on the pull also works when the local branch has no
# upstream configured.
for git_args in (
    ["fetch", "origin", BRANCH],
    ["checkout", BRANCH],
    ["pull", "--ff-only", "origin", BRANCH],
):
    subprocess.run(["git", "-C", REPO_DIR, *git_args], check=True)

# Make the repo's diffusion sources importable from this kernel.
SRC = f"{REPO_DIR}/packages/diffusion/src"
if SRC not in sys.path:
    sys.path.append(SRC)

import importlib
import pkgutil
import torch

# Report which compute device this runtime exposes.
if torch.cuda.is_available():
    print("CUDA available:", True)
    print("CUDA device name:", torch.cuda.get_device_name(0))
else:
    print("CUDA available:", False)
    print("CUDA device name:", "CPU")

# Load diffusers by name and report its version ("unknown" if the attribute is absent).
diffusers = importlib.import_module("diffusers")
print("diffusers version:", getattr(diffusers, "__version__", "unknown"))

# print("\nSmoke_sdxl_load")
# subprocess.run(["python", f"{REPO_DIR}/scripts/smoke_sdxl_load.py"], check=False)

# print("\nSmoke_synthetic")
# subprocess.run(["python", f"{REPO_DIR}/scripts/smoke_synthetic.py"], check=False)

# print("\nSmoke_infer (base)")
# subprocess.run(["python", f"{REPO_DIR}/scripts/smoke_infer.py"], check=False)

# print("\nSmoke_infer (controlnet edge, dpmpp)")
# subprocess.run([
#     "python", f"{REPO_DIR}/scripts/smoke_infer.py",
#     "--use_controlnet", "1",
#     "--control_from", "edge",
#     "--scheduler", "dpmpp",
#     "--steps", "14",
#     "--guidance", "6.0",
#     "--width", "1024",
#     "--height", "768",
#     "--seed", "1234",
# ], check=False)

# print("\nSmoke_upscale_inpaint")
# subprocess.run(["python", f"{REPO_DIR}/scripts/smoke_upscale_inpaint.py"], check=False)

LORA_STYLE_DIR = Path(REPO_DIR) / "data" / "lora_style"
LORA_STYLE_DIR.mkdir(parents=True, exist_ok=True)

# Seed the LoRA style set with the most recent PNG renders found under outputs/.
MAX_STYLE_IMAGES = 12
outputs_dir = Path(REPO_DIR) / "outputs"

# This notebook writes its own LoRA artifacts under outputs/lora (training runs)
# and outputs/lora_ab (A/B comparisons); skip them so a re-run does not train
# the LoRA on its own outputs.
_excluded_roots = (outputs_dir / "lora", outputs_dir / "lora_ab")


def _is_excluded(path):
    """Return True if *path* lies inside one of the excluded output subtrees."""
    return any(root in path.parents for root in _excluded_roots)


pngs = []
if outputs_dir.exists():
    pngs = sorted(
        (p for p in outputs_dir.rglob("*.png") if not _is_excluded(p)),
        key=lambda p: p.stat().st_mtime,
        reverse=True,  # newest first
    )

copied = 0
used_names = set()
for p in pngs:
    if copied >= MAX_STYLE_IMAGES:
        break
    # Files are flattened into one directory by name. A same-named file from
    # another subdirectory would overwrite the newer one already copied and be
    # counted twice, so keep only the newest file per name.
    if p.name in used_names:
        continue
    try:
        shutil.copy2(p, LORA_STYLE_DIR / p.name)
    except OSError as exc:
        # Best-effort: report and move on instead of silently swallowing.
        print(f"Skipping {p}: {exc}")
        continue
    used_names.add(p.name)
    copied += 1

if copied == 0:
    # No usable renders were found: synthesize soft two-colour gradient images
    # so the style set is never empty.
    from PIL import Image
    import numpy as np

    w, h = 768, 768
    xs = np.linspace(0.0, 1.0, w, dtype=np.float32)
    ys = np.linspace(0.0, 1.0, h, dtype=np.float32)
    X, Y = np.meshgrid(xs, ys)  # (h, w) coordinate grids with values in [0, 1]

    palettes = [
        ((24, 58, 150),  (197, 219, 255)),
        ((9, 96, 121),   (232, 246, 249)),
        ((32, 32, 32),   (200, 200, 200)),
        ((106, 27, 154), (236, 224, 246)),
        ((20, 97, 48),   (216, 240, 223)),
        ((120, 54, 35),  (244, 229, 224)),
    ]

    def to_img(c0, c1, alpha, vignette=0.12, noise_amp=0.02, seed=0):
        """Blend RGB colours c0 -> c1 using a per-pixel weight map.

        alpha: array shaped like X/Y; clipped to [0, 1] (0 gives c0, 1 gives c1).
        vignette: how much the weight is scaled down toward the corners
            (pulls edge pixels toward c0).
        noise_amp: amplitude of a seeded sinusoidal texture added to the weight.
        seed: RNG seed for the texture's frequency and phase.
        Returns a PIL image built from the clipped uint8 array.
        """
        a = np.clip(alpha, 0.0, 1.0)

        rng = np.random.default_rng(seed)
        fx, fy = rng.uniform(0.4, 0.8), rng.uniform(0.4, 0.8)
        phase = rng.uniform(0, 2*np.pi)
        tex = np.sin(2*np.pi*(fx*X + fy*Y) + phase) * noise_amp

        R = np.sqrt((X - 0.5)**2 + (Y - 0.5)**2)
        vig = 1.0 - vignette * (R / R.max())**2

        a = np.clip(a + tex, 0.0, 1.0) * vig

        c0 = np.array(c0, dtype=np.float32)[None, None, :]
        c1 = np.array(c1, dtype=np.float32)[None, None, :]
        img = c0*(1.0 - a[..., None]) + c1*(a[..., None])
        return Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))

    def _normalize(a):
        """Min-max rescale *a* to [0, 1]; a constant array maps to zeros."""
        lo, hi = float(a.min()), float(a.max())
        return (a - lo) / (hi - lo) if hi > lo else np.zeros_like(a)

    seeds = [11, 22, 33, 44, 55, 66]

    alpha0 = X
    # Raw diagonal projections span [0, sqrt(2)] at 45 deg and
    # [-sqrt(2)/2, sqrt(2)/2] at 135 deg. Without rescaling, to_img's clip
    # flattens large regions (the 135 deg ramp never reaches c1), so rescale
    # them like the radial ramps below.
    alpha1 = _normalize(X*np.cos(np.deg2rad(45)) + Y*np.sin(np.deg2rad(45)))
    alpha2 = Y
    alpha3 = _normalize(X*np.cos(np.deg2rad(135)) + Y*np.sin(np.deg2rad(135)))
    alpha4 = np.sqrt((X - 0.5)**2 + (Y - 0.5)**2)
    alpha4 = (alpha4 / alpha4.max())
    alpha5 = np.sqrt((X - 0.35)**2 + (Y - 0.6)**2)
    alpha5 = (alpha5 / alpha5.max())

    alphas = [alpha0, alpha1, alpha2, alpha3, alpha4, alpha5]

    for i, (alpha, seed) in enumerate(zip(alphas, seeds)):
        c0, c1 = palettes[i % len(palettes)]
        img = to_img(c0, c1, alpha, seed=seed)
        img.save(LORA_STYLE_DIR / f"synthetic_{i:02d}.png")
        copied += 1  # count images actually written rather than hard-coding 6

print(f"LoRA Style set prepared with {copied} images at: {LORA_STYLE_DIR}")
# Sort so the sample listing is deterministic across filesystems.
print("Sample files:", [p.name for p in sorted(LORA_STYLE_DIR.glob('*.png'))[:3]])

RUN_DIR = Path(REPO_DIR) / "outputs" / "lora" / "runs" / "exp01"
MANIFESTS = RUN_DIR / "manifests"
MANIFESTS.mkdir(parents=True, exist_ok=True)
CAP_JSONL = MANIFESTS / "captions.jsonl"

FALLBACK_CAPTION = "soft abstract gradient, slide-safe, minimal clutter"
# Suffixes accepted for the fallback manifest (compared case-insensitively).
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

# Drop any manifest left over from an earlier session. Otherwise a failed
# prepare step would be masked by stale (possibly mismatched) entries.
if CAP_JSONL.exists():
    CAP_JSONL.unlink()

print("\nLoRA prepare_lora_dataset.py")
# sys.executable guarantees the child runs under this kernel's interpreter.
prep = subprocess.run([
    sys.executable, f"{REPO_DIR}/scripts/prepare_lora_dataset.py",
    "--images_dir", f"{LORA_STYLE_DIR}",
    "--out_jsonl", f"{CAP_JSONL}",
    "--fallback_caption", FALLBACK_CAPTION,
], check=False)
if prep.returncode != 0:
    print(f"LoRA prepare_lora_dataset.py exited with code {prep.returncode}.")

lines = CAP_JSONL.read_text(encoding="utf-8").strip().splitlines() if CAP_JSONL.exists() else []
if not lines:
    print("LoRA prepare_lora_dataset.py produced no lines; writing fallback captions.jsonl...")
    images = sorted(
        p for p in LORA_STYLE_DIR.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    with CAP_JSONL.open("w", encoding="utf-8") as f:
        for p in images:
            f.write(json.dumps({"image": str(p), "caption": FALLBACK_CAPTION}) + "\n")
    lines = CAP_JSONL.read_text(encoding="utf-8").strip().splitlines()

print(f"LoRA captions.jsonl lines: {len(lines)}")
assert len(lines) > 0, "captions.jsonl is empty; style set may be empty."

print("\nLoRA train_lora.py (short run)")

# Child processes need the diffusion sources on their import path as well.
env = os.environ.copy()
SRC = f"{REPO_DIR}/packages/diffusion/src"
# Join with os.pathsep and skip an empty existing value. A trailing separator
# adds an empty entry, which Python treats as the current working directory.
env["PYTHONPATH"] = os.pathsep.join(
    part for part in (SRC, env.get("PYTHONPATH", "")) if part
)

FINAL_LORA_DIR = Path(RUN_DIR) / "final_lora"

# Training is slow, so run the short job only when no LoRA weights exist yet
# (e.g. on a fresh runtime). Without this, the asserts below always fail
# unless a previous session already produced final_lora.
TRAIN_IF_MISSING = True
if any(FINAL_LORA_DIR.glob("*.pt")):
    print("Found existing LoRA weights; skipping training.")
elif TRAIN_IF_MISSING:
    cp = subprocess.run(
        [
            sys.executable, "-u", f"{REPO_DIR}/scripts/train_lora.py",
            "--images_dir", f"{LORA_STYLE_DIR}",
            "--output_dir", f"{RUN_DIR}",
            "--resolution", "512",
            "--rank", "8",
            "--batch_size", "1",
            "--gradient_accumulation_steps", "1",
            "--max_train_steps", "20",
            "--train_jsonl", f"{CAP_JSONL}",
            "--checkpoint_steps", "10",
        ],
        check=False,
        text=True,
        capture_output=True,
        env=env,
    )

    print("=== train_lora.py STDOUT ===")
    print(cp.stdout)
    print("=== train_lora.py STDERR ===")
    print(cp.stderr)
    print("train_lora.py return code:", cp.returncode)

    if cp.returncode != 0:
        raise AssertionError("train_lora.py failed; see stderr above.")

print("LoRA final_lora exists:", FINAL_LORA_DIR.exists(),
      "contents:", list(FINAL_LORA_DIR.glob("*")))

assert FINAL_LORA_DIR.exists(), "final_lora directory was not created."

lora_files = sorted(FINAL_LORA_DIR.glob("*.pt"))
assert lora_files, "No LoRA .pt file found in final_lora (expected unet_lora_peft.pt)."

# Prefer the expected weights file. Otherwise take the first in sorted order,
# so the choice is deterministic (glob order is not).
_preferred = FINAL_LORA_DIR / "unet_lora_peft.pt"
lora_file = _preferred if _preferred in lora_files else lora_files[0]
print("Using LoRA file:", lora_file)
import time

AB_OUT = Path(REPO_DIR) / "outputs" / "lora_ab"
AB_OUT.mkdir(parents=True, exist_ok=True)

print("\nLoRA smoke_lora_ab.py (A/B compare)")

# Remember when this run started so leftovers from earlier runs in AB_OUT
# cannot satisfy the output check below.
ab_started = time.time()

cp2 = subprocess.run(
    [
        sys.executable, "-u", f"{REPO_DIR}/scripts/smoke_lora_ab.py",
        "--lora_dir", f"{FINAL_LORA_DIR}",
        "--out_dir", f"{AB_OUT}",
        "--seed", "777",
        "--width", "512",   # 512x512 to be T4-friendly
        "--height", "512",
        "--steps", "20",
        "--guidance", "5.5",
        "--control_mode", "safe",
    ],
    check=False,
    text=True,
    capture_output=True,
    env=env,
)

print("=== smoke_lora_ab.py STDOUT ===")
print(cp2.stdout)
print("=== smoke_lora_ab.py STDERR ===")
print(cp2.stderr)
print("smoke_lora_ab.py return code:", cp2.returncode)

if cp2.returncode != 0:
    raise AssertionError("smoke_lora_ab.py failed; see stderr above.")


def _fresh_images(pattern):
    """Return AB_OUT files matching *pattern* written during this run, sorted.

    A 1 s slack absorbs coarse filesystem timestamp granularity.
    """
    return sorted(
        p for p in AB_OUT.glob(pattern) if p.stat().st_mtime >= ab_started - 1.0
    )


base_imgs = _fresh_images("ab_seed*_base.png")
lora_imgs = _fresh_images("ab_seed*_lora.png")
print("[A/B] base images:", base_imgs)
print("[A/B] lora images:", lora_imgs)
assert len(base_imgs) > 0 and len(lora_imgs) > 0, "A/B images not found. Check smoke_lora_ab output."

# Derive the path from REPO_DIR instead of hard-coding it.
print(f"\nCompleted. Check {Path(REPO_DIR) / 'outputs'} for images, including LoRA A/B outputs.")