# Ancient-to-Film GAN — CycleGAN Training Notebook
**Deliverable 2 | Implementation & Early Evaluation**

This notebook provides a minimal, end-to-end CycleGAN training scaffold for unpaired image-to-image translation.
- **Domain A**: Ancient paintings (`data/A`)
- **Domain B**: Film-style photos (`data/B`)

> Tip: Start with small images (128×128) and a few epochs to verify the pipeline.

## 0. Environment & Config

In [None]:

import os, random, itertools, time, math
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's `random`, NumPy and PyTorch with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths
# When the working directory is named "notebooks" the project root is its
# parent; otherwise the working directory itself is the root.
if Path.cwd().name == "notebooks":
    ROOT = Path("..").resolve()
else:
    ROOT = Path(".").resolve()
DATA_A, DATA_B = ROOT / "data" / "A", ROOT / "data" / "B"
OUT_SAMPLES = ROOT / "results" / "samples"
OUT_CKPTS = ROOT / "results" / "checkpoints"
# Create the output folders up front so later cells can write into them.
OUT_SAMPLES.mkdir(parents=True, exist_ok=True)
OUT_CKPTS.mkdir(parents=True, exist_ok=True)

# Hyperparameters (start small for demo)
IMG_SIZE, BATCH_SIZE, EPOCHS = 128, 2, 5
LR_G = LR_D = 2e-4          # Adam learning rates for generators / discriminators
LAMBDA_CYCLE = 10.0         # weight of the cycle-consistency loss
LAMBDA_ID = 5.0             # weight of the identity loss
NUM_WORKERS = 2


: 

## 1. Dataset (Unpaired)

In [None]:

# --- Robust Dataset Setup (safe version) ---
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch

# Fall back to defaults when the environment cell has not been run.
if "IMG_SIZE" not in globals():
    IMG_SIZE = 128
if "BATCH_SIZE" not in globals():
    BATCH_SIZE = 2
NUM_WORKERS = 0  # start with 0 on macOS / Windows to avoid multiprocessing problems

# Infer the project root (the notebook usually lives inside notebooks/).
if Path.cwd().name == "notebooks":
    ROOT = Path("..").resolve()
else:
    ROOT = Path(".").resolve()
DATA_A = ROOT / "data" / "A"
DATA_B = ROOT / "data" / "B"
print("ROOT =", ROOT)
print("DATA_A exists:", DATA_A.exists(), "| DATA_B exists:", DATA_B.exists())

# Count the image files in a directory.
def count_images(p):
    """Return the number of .jpg/.jpeg/.png entries directly inside *p*.

    Every directory entry is visited exactly once and its suffix is compared
    case-insensitively. Mixed-case names such as ``photo.Jpg`` are therefore
    counted, and an entry can never be counted twice on platforms where glob
    matching is case-insensitive (e.g. Windows), which could happen when
    globbing ``*.jpg`` and ``*.JPG`` separately.

    Args:
        p: Directory to scan, as a ``str`` or ``pathlib.Path``. The scan is
            not recursive.

    Returns:
        int: Number of matching entries (0 if there are none).
    """
    exts = {".jpg", ".jpeg", ".png"}
    return sum(1 for entry in Path(p).glob("*") if entry.suffix.lower() in exts)

cntA, cntB = count_images(DATA_A), count_images(DATA_B)
print(f"Found images -> A: {cntA} | B: {cntB}")
# Both domains need at least one image before a dataset can be built.
assert min(cntA, cntB) > 0, "data/A Êàñ data/B Ê≤íÊúâÊâæÂà∞ÂúñÊ™îÔºà.jpg/.pngÔºâ„ÄÇ"

# Preprocessing: resize to IMG_SIZE x IMG_SIZE (bicubic), convert to a tensor,
# then normalize every channel with mean 0.5 and std 0.5.
tfm = transforms.Compose([
    transforms.Resize(
        (IMG_SIZE, IMG_SIZE),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

class UnpairedImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Each item is a tuple ``(image_a, image_b)``: ``image_a`` is selected by
    index (wrapping around when domain A is the smaller one) and ``image_b``
    is drawn at random from domain B, so the two images are not paired.

    Args:
        dir_a: Directory holding the domain-A images (not searched recursively).
        dir_b: Directory holding the domain-B images (not searched recursively).
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        RuntimeError: If either directory contains no .jpg/.jpeg/.png file.
    """

    # Accepted extensions; suffixes are lowercased before the lookup so that
    # mixed-case names such as "scan.Jpg" are not silently skipped.
    IMAGE_EXTS = frozenset({".jpg", ".jpeg", ".png"})

    def __init__(self, dir_a, dir_b, transform=None):
        self.paths_a = self._list_images(dir_a)
        self.paths_b = self._list_images(dir_b)
        self.transform = transform
        if not self.paths_a or not self.paths_b:
            raise RuntimeError("Ë´ãÁ¢∫Ë™ç data/A Ëàá data/B ÂÖßÊúâÂúñÊ™îÔºà.jpg/.pngÔºâ„ÄÇ")

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted image paths found directly inside *directory*."""
        return sorted(
            p for p in Path(directory).glob("*") if p.suffix.lower() in cls.IMAGE_EXTS
        )

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy, and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        # One epoch walks the larger of the two domains.
        return max(len(self.paths_a), len(self.paths_b))

    def __getitem__(self, idx):
        pa = self.paths_a[idx % len(self.paths_a)]
        # Domain B is sampled at random to keep the pairing arbitrary.
        pb = random.choice(self.paths_b)
        ia, ib = self._load_rgb(pa), self._load_rgb(pb)
        if self.transform:
            ia = self.transform(ia)
            ib = self.transform(ib)
        return ia, ib

ds = UnpairedImageDataset(dir_a=DATA_A, dir_b=DATA_B, transform=tfm)
# drop_last=True avoids shape problems caused by a short final batch.
dl = DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)
print("‚úÖ DataLoader Âª∫Á´ãÊàêÂäüÔºådataset Èï∑Â∫¶ =", len(ds))



In [None]:
import sys, platform
# Report which interpreter this kernel is running.
print(f"Python: {sys.version}")
print(f"Executable: {sys.executable}")

In [None]:
import numpy as np, PIL, torch, torchvision
# Print the installed version of each library this notebook depends on.
print(f"numpy: {np.__version__}")
print(f"Pillow: {PIL.__version__}")
print(f"torch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")

In [None]:
import numpy as np, torch, torchvision, PIL
# Versions noted for this check: numpy 1.26.4, torch 2.2.2, torchvision 0.17.2.
print(np.__version__, torch.__version__, torchvision.__version__, sep="\n")
# NumPy -> torch interop check: this conversion should not raise.
x = torch.from_numpy(np.zeros((3, 3), dtype=np.float32))
print(x.shape)

In [None]:
import numpy as np, torch, torchvision, PIL
print(f"numpy: {np.__version__}")                 # 2.x (fine to keep)
print(f"torch: {torch.__version__}")              # 2.4.x (or newer, but < 2.6)
print(f"torchvision: {torchvision.__version__}")  # 0.19.x (or < 0.21)
print(f"pillow: {PIL.__version__}")
import numpy as _np
import torch as _t
# NumPy -> torch interop check; should print torch.Size([2, 2]).
print(_t.from_numpy(_np.zeros((2, 2), dtype=_np.float32)).shape)

### Preview a mini-batch

In [None]:

import matplotlib.pyplot as plt
import torch
import numpy

def denorm(x):
    """Undo the mean-0.5 / std-0.5 normalization and clamp the result to [0, 1]."""
    rescaled = 0.5 * x + 0.5
    return rescaled.clamp(min=0, max=1)

# Make sure the DataLoader from the Dataset cell exists before previewing.
if 'dl' not in globals():
    raise RuntimeError("DataLoader (dl) Â∞öÊú™Âª∫Á´ãÔºåË´ãÂÖàÂü∑Ë°å Dataset ÁöÑ cell„ÄÇ")

batch = next(iter(dl))
a_batch, b_batch = batch
print("Batch shapes:", a_batch.shape, b_batch.shape)

# One row of axes: for every sample, the A image followed by the B image.
n_pairs = a_batch.shape[0]
fig, axes = plt.subplots(1, n_pairs * 2, figsize=(12, 3))
for i in range(n_pairs):
    columns = ((a_batch[i], "A (Ancient)"), (b_batch[i], "B (Film)"))
    for offset, (img, title) in enumerate(columns):
        ax = axes[2 * i + offset]
        # CHW tensor -> HWC array in [0, 1] for matplotlib.
        ax.imshow(denorm(img).permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis("off")
plt.show()



## 2. Models (Generator & Discriminator)

In [None]:

from src.cyclegan_min import GeneratorResnet, DiscriminatorPatchGAN

# Generators: A->B (G_AB) and B->A (G_BA), created in that order on DEVICE.
G_AB, G_BA = (GeneratorResnet().to(DEVICE) for _ in range(2))

# Discriminators: for domain B (D_B) and domain A (D_A), created in that order.
D_B, D_A = (DiscriminatorPatchGAN().to(DEVICE) for _ in range(2))

# Init weights
def init_weights(m):
    """Initialise one module in place; intended for use with ``Module.apply``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). For an
    InstanceNorm2d with ``affine=True`` the weight is drawn from N(1, 0.02)
    and the bias is zeroed. Any other module is left untouched.
    """
    init = torch.nn.init
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
        init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, torch.nn.InstanceNorm2d) and m.affine:
        init.normal_(m.weight, mean=1.0, std=0.02)
        init.zeros_(m.bias)

# Apply the custom initialisation to all four networks (same order as before).
for _net in (G_AB, G_BA, D_A, D_B):
    _net.apply(init_weights)

# Losses & Optimizers
mse = nn.MSELoss()  # used for the adversarial terms
l1 = nn.L1Loss()    # used for the identity and cycle-consistency terms
# Both generators share one optimizer; each discriminator has its own.
opt_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=LR_G,
    betas=(0.5, 0.999),
)
opt_D_A = torch.optim.Adam(
    D_A.parameters(),
    lr=LR_D,
    betas=(0.5, 0.999),
)
opt_D_B = torch.optim.Adam(
    D_B.parameters(),
    lr=LR_D,
    betas=(0.5, 0.999),
)

# Buffers for GAN targets
def real_like(x):
    """Return an all-ones tensor shaped like *x*, allocated on DEVICE ("real" target)."""
    return torch.ones_like(x, device=DEVICE)


def fake_like(x):
    """Return an all-zeros tensor shaped like *x*, allocated on DEVICE ("fake" target)."""
    return torch.zeros_like(x, device=DEVICE)


## 3. Training Loop (Toy Epochs)

In [None]:

def save_sample(a, b_fake, step, out_dir=OUT_SAMPLES):
    """Save the first input image of a batch and its A->B translation as PNGs.

    Writes ``step{step:06d}_A.png`` and ``step{step:06d}_AtoB.png`` into
    *out_dir*. Only the first element of each batch is saved.

    Args:
        a: Batch of domain-A inputs; indexed as ``a[0]`` and permuted from
            channel-first to channel-last before saving.
        b_fake: Batch of generator outputs for *a*, handled the same way.
        step: Integer training step, used in the file names.
        out_dir: Target directory (``str`` or ``pathlib.Path``).
    """
    out_dir = Path(out_dir)  # also accept plain string paths
    a = denorm(a.detach().cpu())
    b_fake = denorm(b_fake.detach().cpu())
    # save first pair; round before the uint8 cast so values are not
    # systematically truncated downwards (e.g. 254.7 -> 254)
    a0 = (a[0].permute(1,2,0).numpy()*255).round().astype("uint8")
    b0 = (b_fake[0].permute(1,2,0).numpy()*255).round().astype("uint8")
    Image.fromarray(a0).save(out_dir / f"step{step:06d}_A.png")
    Image.fromarray(b0).save(out_dir / f"step{step:06d}_AtoB.png")

# Main CycleGAN training loop: per batch, one generator update followed by one
# update of each discriminator; a checkpoint is written after every epoch.
step = 0
for epoch in range(1, EPOCHS+1):
    pbar = tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for a, b in pbar:
        a = a.to(DEVICE); b = b.to(DEVICE)

        # --------------------
        #  Train Generators
        # --------------------
        opt_G.zero_grad()

        b_fake = G_AB(a)       # A -> B
        a_rec  = G_BA(b_fake)  # A -> B -> A
        a_fake = G_BA(b)       # B -> A
        b_rec  = G_AB(a_fake)  # B -> A -> B

        # identity loss (optional but stabilizes): a generator fed an image
        # that is already in its OUTPUT domain should return it unchanged,
        # i.e. G_BA(a) ~ a and G_AB(b) ~ b. Comparing the translated images
        # with their own inputs (G_BA(b) vs b, G_AB(a) vs a) would instead
        # penalise the translation itself and work against the GAN loss.
        loss_id_a = l1(G_BA(a), a) * LAMBDA_ID
        loss_id_b = l1(G_AB(b), b) * LAMBDA_ID

        # adversarial (least-squares): generators want D to output "real"
        pred_b = D_B(b_fake)
        pred_a = D_A(a_fake)
        loss_gan_ab = mse(pred_b, real_like(pred_b))
        loss_gan_ba = mse(pred_a, real_like(pred_a))

        # cycle-consistency
        loss_cyc_a = l1(a_rec, a) * LAMBDA_CYCLE
        loss_cyc_b = l1(b_rec, b) * LAMBDA_CYCLE

        loss_G = loss_id_a + loss_id_b + loss_gan_ab + loss_gan_ba + loss_cyc_a + loss_cyc_b
        loss_G.backward()
        opt_G.step()

        # --------------------
        #  Train D_A
        # --------------------
        opt_D_A.zero_grad()
        pred_real_a = D_A(a)
        # detach() keeps the discriminator loss from back-propagating into G_BA
        pred_fake_a = D_A(a_fake.detach())
        loss_D_A = (mse(pred_real_a, real_like(pred_real_a)) + mse(pred_fake_a, fake_like(pred_fake_a))) * 0.5
        loss_D_A.backward()
        opt_D_A.step()

        # --------------------
        #  Train D_B
        # --------------------
        opt_D_B.zero_grad()
        pred_real_b = D_B(b)
        # detach() keeps the discriminator loss from back-propagating into G_AB
        pred_fake_b = D_B(b_fake.detach())
        loss_D_B = (mse(pred_real_b, real_like(pred_real_b)) + mse(pred_fake_b, fake_like(pred_fake_b))) * 0.5
        loss_D_B.backward()
        opt_D_B.step()

        step += 1
        # dump a sample pair every 50 steps
        if step % 50 == 0:
            save_sample(a, b_fake, step)
        pbar.set_postfix({
            "G": f"{loss_G.item():.3f}",
            "D_A": f"{loss_D_A.item():.3f}",
            "D_B": f"{loss_D_B.item():.3f}"
        })

    # save checkpoint each epoch (network weights and epoch number only)
    torch.save({
        "G_AB": G_AB.state_dict(),
        "G_BA": G_BA.state_dict(),
        "D_A": D_A.state_dict(),
        "D_B": D_B.state_dict(),
        "epoch": epoch
    }, OUT_CKPTS / f"cyclegan_epoch_{epoch:02d}.pt")
    print(f"[Epoch {epoch}] checkpoint saved.")
print("Training loop finished.")


## 4. Inference Helper

In [None]:

@torch.inference_mode()
def translate_image(path_in, path_out, ckpt=None, direction="A2B"):
    """Translate one image file with a generator and save the result.

    The image is loaded as RGB, preprocessed with ``tfm``, passed through
    ``G_AB`` (when *direction* is exactly ``"A2B"``) or ``G_BA`` (any other
    value), de-normalised and written to *path_out*.

    Args:
        path_in: Path of the input image.
        path_out: Path of the output image; the format is chosen by Pillow
            from the file extension.
        ckpt: Optional checkpoint path. When given, the ``"G_AB"`` and
            ``"G_BA"`` state dicts stored in it are loaded into the global
            generators before translating.
        direction: ``"A2B"`` selects ``G_AB``; anything else selects ``G_BA``.

    Returns:
        *path_out*, unchanged.
    """
    # Close the source file as soon as an RGB copy has been made.
    with Image.open(path_in) as src:
        img = src.convert("RGB")
    x = tfm(img).unsqueeze(0).to(DEVICE)  # add a batch dimension
    if ckpt:
        state = torch.load(ckpt, map_location=DEVICE)
        G_AB.load_state_dict(state["G_AB"]); G_BA.load_state_dict(state["G_BA"])
    net = G_AB if direction == "A2B" else G_BA
    # Run in eval mode so train-only layers (e.g. dropout, if the generator
    # has any) are disabled, then restore whatever mode the network was in so
    # calling this mid-training does not change training behaviour.
    was_training = net.training
    net.eval()
    try:
        y = net(x)
    finally:
        net.train(was_training)
    y = denorm(y[0].cpu()).permute(1,2,0).numpy()
    # Round before the uint8 cast so values are not truncated downwards.
    Image.fromarray((y*255).round().astype("uint8")).save(path_out)
    return path_out

# Example:
# translate_image(ROOT/'data/A/sample.jpg', ROOT/'results/samples/sample_A2B.png', ckpt=ROOT/'results/checkpoints/cyclegan_epoch_01.pt')


## 5. Notes for Deliverable 2 Report


- Keep **EPOCHS small** (e.g., 3–5) and **IMG_SIZE=128** for a quick demo.
- Save a few outputs in `results/samples/` and insert them into your report.
- Record losses per epoch; optionally add a simple loss curve.
- Mention compute setup (CPU/GPU), and any training instability or artifacts observed.
- For interface (Step 3), load the latest checkpoint in `ui/app.py` and call `translate_image`.
