# In [None]:
"""
U-CNN (Conv) Pipeline 
=================================================
Goal
----
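1) Train ONLY the U-CNN pipeline end-to-end on your Mac (M4) using PyTorch MPS:
   self-supervised pretraining of a denoising autoencoder on the unlabeled images,
   whose encoder weights then initialize the classifier.

2) Supervised classifier training with heavy imbalance handling, F1-based threshold tuning,
   early stopping, checkpoints, and resume-on-start.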

3) Produce submission CSV: columns [label, ID] where ID = "test/xxxxx.jpg".

"""

from __future__ import annotations

import os, json, time, math, random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms

# For safety with partially-corrupt JPEGs
ImageFile.LOAD_TRUNCATED_IMAGES = True

# -----------------------------
# 0) Repro + device
# -----------------------------
# Seed the three random number generators used in this file (Python's
# `random`, NumPy, PyTorch) with the same fixed value so that shuffles and
# random draws made through them are repeatable from run to run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def get_device():
    """Pick the torch device to run on.

    Preference order: Apple MPS, then CUDA, otherwise CPU.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

# Device chosen once at import time; the training and inference functions
# below move models and batches onto it.
DEVICE = get_device()
print("Device:", DEVICE)

# -----------------------------
# 1) Paths
# -----------------------------
# Dataset root and the sub-folders this script reads from.
DATA_ROOT = Path("/Users/zamir/AI/dsm/dsm-2025")
TRAIN_DIR = DATA_ROOT / "train"
NEG_DIR   = TRAIN_DIR / "0"          # images that get label 0
POS_DIR   = TRAIN_DIR / "1"          # images that get label 1
UNLAB_DIR = TRAIN_DIR / "Unlabeled"  # images used for the SSL denoising stage
TEST_DIR  = DATA_ROOT / "test"       # images scored for the submission

# Output locations; both directories are created on import if missing.
OUT_DIR   = DATA_ROOT / "outputs_ucnn"
CKPT_DIR  = OUT_DIR / "_checkpoints"
OUT_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(parents=True, exist_ok=True)

print("DATA_ROOT:", DATA_ROOT)
print("TRAIN_DIR:", TRAIN_DIR)
print("NEG_DIR  :", NEG_DIR)
print("POS_DIR  :", POS_DIR)
print("UNLAB_DIR:", UNLAB_DIR)
print("TEST_DIR :", TEST_DIR)
print("OUT_DIR  :", OUT_DIR)
print("CKPT_DIR :", CKPT_DIR)

# Fail fast if an input folder is missing.
# NOTE(review): `assert` is stripped under `python -O`, and UNLAB_DIR is
# required even when SSL is disabled via SSL_SUBSET = 0.
assert NEG_DIR.exists(), f"Missing: {NEG_DIR}"
assert POS_DIR.exists(), f"Missing: {POS_DIR}"
assert TEST_DIR.exists(), f"Missing: {TEST_DIR}"
assert UNLAB_DIR.exists(), f"Missing: {UNLAB_DIR}"

# -----------------------------
# 2) Speed knobs (Mac MPS)
# -----------------------------
# Mac MPS likes smallish batches; keep model lightweight.
# Side length (pixels) every image is resized to by the transforms below.
IMG_SIZE = 224
NUM_WORKERS = 0  # IMPORTANT on macOS notebooks to avoid multiprocessing pickling issues
PIN_MEMORY = False  # MPS doesn't benefit like CUDA
PERSISTENT_WORKERS = False

# To fit memory and time budget:
BATCH_SSL = 32  # batch size for the SSL denoising loader
BATCH_CLS = 32  # batch size for the train / val / test loaders

# Time budget controls: the training loops check budget_ok() at the start of
# each epoch and stop once more than MAX_HOURS_BUDGET hours have elapsed
# since START_TIME.
MAX_HOURS_BUDGET = 4.0
START_TIME = time.time()

# SSL subset: maximum number of unlabeled images used for SSL (0 disables SSL).
SSL_SUBSET = 30000
EPOCHS_SSL = 10
EPOCHS_CLS = 25
# Number of consecutive epochs without a validation-F1 improvement after
# which classifier training stops.
EARLY_STOP_PATIENCE = 10


# -----------------------------
# 3) Helpers
# -----------------------------
# Lower-case file suffixes that list_images() treats as images.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

def now_h():
    """Return the wall-clock time elapsed since START_TIME, in hours."""
    elapsed_seconds = time.time() - START_TIME
    return elapsed_seconds / 3600.0

def budget_ok():
    """Return True while the elapsed run time is within MAX_HOURS_BUDGET hours."""
    hours_used = now_h()
    return hours_used <= MAX_HOURS_BUDGET

def list_images(folder: Path) -> List[Path]:
    """Recursively collect image files under *folder*.

    A file counts as an image when its lower-cased suffix is in IMG_EXTS.
    The result holds full paths, sorted so the order is deterministic.
    """
    found = []
    for entry in folder.rglob("*"):
        if not entry.is_file():
            continue
        if entry.suffix.lower() in IMG_EXTS:
            found.append(entry)
    return sorted(found)

def safe_open_rgb(path: Path) -> Image.Image:
    """Open the image at *path* and return it converted to RGB.

    The file is opened inside a context manager so the underlying file
    handle is released even when decoding fails part-way. The returned image
    comes from ``convert("RGB")``, so it no longer depends on the file
    handle that is closed on exit.

    Raises:
        Whatever PIL raises for a missing or undecodable file; callers such
        as validate_images() rely on the exception to flag bad files.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

def validate_images(paths: List[Path], max_check: int = 2000) -> Tuple[List[Path], List[Tuple[str, str]]]:
    """
    Spot-check that images can be opened, on a random sample of at most
    max_check paths (for speed).

    Returns (good_paths, bad_list):
      - good_paths: every input path except those that failed the check,
        in the original order (unchecked paths are kept);
      - bad_list: [(path_as_str, repr_of_exception)] for each failure.
    """
    # Shuffle all indices, then keep the first max_check of them as the sample.
    order = list(range(len(paths)))
    random.shuffle(order)
    sample = order[: min(max_check, len(paths))]

    failures = []
    rejected = set()
    for count, pos in enumerate(sample, 1):
        candidate = paths[pos]
        try:
            safe_open_rgb(candidate)
        except Exception as e:
            failures.append((str(candidate), repr(e)))
            rejected.add(candidate)

        # Progress line every 500 checked images.
        if count % 500 == 0:
            print(f"[validate] checked {count}/{len(sample)} bad={len(failures)}")

    kept = [p for p in paths if p not in rejected]
    return kept, failures

def save_json(path: Path, obj: dict):
    """Serialize *obj* as JSON with 2-space indentation and write it to *path*."""
    text = json.dumps(obj, indent=2)
    path.write_text(text)

def load_json(path: Path) -> dict:
    """Read the file at *path* and return its parsed JSON content."""
    raw = path.read_text()
    return json.loads(raw)

# -----------------------------
# 4) Build file lists (FULL PATHS ONLY)
# -----------------------------
# Scan each input folder once at import time; every list holds full, sorted paths.
neg_paths = list_images(NEG_DIR)
pos_paths = list_images(POS_DIR)
unlab_paths = list_images(UNLAB_DIR)
test_paths = list_images(TEST_DIR)

print("\nDisk counts (by full path):")
print("train/0:", len(neg_paths))
print("train/1:", len(pos_paths))
print("unlabeled:", len(unlab_paths))
print("test:", len(test_paths))

# Report bare filenames that occur in both class folders. This is diagnostic
# only: the overlap is printed, nothing is removed or relabeled.
neg_names = set(p.name for p in neg_paths)
pos_names = set(p.name for p in pos_paths)
overlap = sorted(list(neg_names & pos_names))
print("\nFilename overlap (should be >0 in your dataset):", len(overlap))
if len(overlap) > 0:
    print("Overlap examples:", overlap[:10])
print("‚úÖ This code is safe because it never keys by filename-only.")

# -----------------------------
# 5) Build supervised split (full paths + labels)
# -----------------------------
# Ensures class ratio doesn't collapse:
# - Use ALL positives (293)
# - Sample negatives to keep training stable (e.g., 10x negatives)
# - Keep val with some positives

# Work on copies so the shuffles below do not reorder pos_paths / neg_paths.
POS_ALL = pos_paths[:]  # 293 (author's note; actual count depends on the data on disk)
NEG_ALL = neg_paths[:]  # 5000 (author's note; actual count depends on the data on disk)

# Choose negative multiplier
NEG_MULT = 10  # "extreme" push: still keeps majority but reduces dominance
# Keep at most NEG_MULT negatives per positive, capped by what is available.
neg_needed = min(len(NEG_ALL), len(POS_ALL) * NEG_MULT)
random.shuffle(NEG_ALL)
NEG_USE = NEG_ALL[:neg_needed]

# Merge into (path, label) pairs and shuffle before splitting.
all_labeled = [(p, 1) for p in POS_ALL] + [(p, 0) for p in NEG_USE]
random.shuffle(all_labeled)

# Stratified split: each class is cut separately with the same fraction, so
# train and val keep roughly the same class ratio.
val_frac = 0.25
pos_items = [(p,y) for (p,y) in all_labeled if y==1]
neg_items = [(p,y) for (p,y) in all_labeled if y==0]

# max(1, ...) requests at least one validation item per class.
n_pos_val = max(1, int(len(pos_items)*val_frac))
n_neg_val = max(1, int(len(neg_items)*val_frac))

pos_val = pos_items[:n_pos_val]
neg_val = neg_items[:n_neg_val]
pos_trn = pos_items[n_pos_val:]
neg_trn = neg_items[n_neg_val:]

train_items = pos_trn + neg_trn
val_items   = pos_val + neg_val
random.shuffle(train_items)
random.shuffle(val_items)

# Labels are 0/1, so sum(y) counts positives and sum(1-y) counts negatives.
print("\nSplit:")
print(f"train: {len(train_items)}  pos={sum(y for _,y in train_items)} neg={sum(1-y for _,y in train_items)}")
print(f"val  : {len(val_items)}  pos={sum(y for _,y in val_items)} neg={sum(1-y for _,y in val_items)}")

# -----------------------------
# 6) Transforms (fixed 224x224 ALWAYS)
# -----------------------------
# Train aug: moderate (too strong can hurt small positives)
# Train aug: moderate (too strong can hurt small positives).
# Resize to IMG_SIZE x IMG_SIZE, random flips, mild color jitter, then ToTensor.
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02),
    transforms.ToTensor(),
])

# Deterministic transform for validation and test: resize + ToTensor only.
eval_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

# SSL transform: produces the clean target; the corrupted input is derived
# from it by corrupt_tensor().
ssl_clean_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

def corrupt_tensor(
    x: torch.Tensor,
    noise_std: float = 0.08,
    num_blocks: int = 6,
    min_block: int = 12,
    max_block: int = 40,
) -> torch.Tensor:
    """Return a corrupted copy of *x*: Gaussian noise plus zeroed rectangles.

    The input is not modified. Noise scaled by *noise_std* is added and the
    result is clamped to [0, 1]; then *num_blocks* rectangles, each with side
    lengths drawn uniformly from [min_block, max_block], are set to 0.

    Block sizes are clipped to the image size, so images smaller than the
    drawn block no longer raise ``ValueError`` from ``random.randint``.
    For images at least *max_block* pixels on each side, the random draws
    are made in the same order as before, so results under a fixed seed are
    unchanged with the default arguments.

    Args:
        x: 3-D tensor; the last two dimensions are treated as height and
            width (the unpacking below requires exactly three dimensions).
        noise_std: multiplier applied to standard-normal noise.
        num_blocks: number of cut-out rectangles.
        min_block: smallest rectangle side, in pixels.
        max_block: largest rectangle side, in pixels.

    Returns:
        A new tensor with the same shape as *x*.
    """
    noise = torch.randn_like(x) * noise_std
    x2 = (x + noise).clamp(0, 1)  # new tensor, so the in-place writes below are safe

    # random cutout blocks
    _, H, W = x2.shape
    for _ in range(num_blocks):
        # Draw first, clip second: keeps the RNG call sequence identical.
        h = min(random.randint(min_block, max_block), H)
        w = min(random.randint(min_block, max_block), W)
        y0 = random.randint(0, H - h)
        x0 = random.randint(0, W - w)
        x2[:, y0:y0+h, x0:x0+w] = 0.0
    return x2

# -----------------------------
# 7) Datasets (FULL PATH IDs)
# -----------------------------
class LabeledDataset(Dataset):
    """Dataset over (image path, integer label) pairs.

    Each sample is (transformed image, label as a long tensor, path string).
    """

    def __init__(self, items: List[Tuple[Path,int]], tfm):
        self.tfm = tfm
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        tensor = self.tfm(safe_open_rgb(path))
        target = torch.tensor(label, dtype=torch.long)
        # The FULL path is returned as the id so equal filenames never collide.
        return tensor, target, str(path)

class UnlabeledDenoiseDataset(Dataset):
    """Dataset of unlabeled images for the denoising stage.

    Each sample is (corrupted tensor, clean tensor, path string).
    """

    def __init__(self, paths: List[Path]):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        target = ssl_clean_tf(safe_open_rgb(path))
        # The corrupted input is derived from the clean target on every access.
        noisy = corrupt_tensor(target)
        return noisy, target, str(path)

class TestDataset(Dataset):
    """Dataset of test images; each sample is (eval-transformed tensor, path string)."""

    def __init__(self, paths: List[Path]):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        tensor = eval_tf(safe_open_rgb(path))
        return tensor, str(path)

def make_loader(ds, batch, shuffle):
    """Wrap *ds* in a DataLoader using the module-level worker/memory settings.

    Args:
        ds: dataset to iterate.
        batch: batch size.
        shuffle: whether the loader reshuffles the data each epoch.
    """
    loader_kwargs = dict(
        batch_size=batch,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        persistent_workers=PERSISTENT_WORKERS,
    )
    return DataLoader(ds, **loader_kwargs)

# -----------------------------
# 8) Model: Tiny U-Net style encoder + classifier head
# -----------------------------
def conv_block(in_ch, out_ch):
    """Build two stacked (3x3 conv -> GroupNorm(8) -> SiLU) units.

    The first unit maps in_ch -> out_ch channels, the second out_ch -> out_ch.
    Padding of 1 keeps the spatial size unchanged. Because GroupNorm uses 8
    groups, out_ch must be divisible by 8.
    """
    layers = []
    channels = in_ch
    for _ in range(2):
        layers.append(nn.Conv2d(channels, out_ch, 3, padding=1))
        layers.append(nn.GroupNorm(8, out_ch))
        layers.append(nn.SiLU(inplace=True))
        channels = out_ch
    return nn.Sequential(*layers)

class Encoder(nn.Module):
    """Four-stage convolutional encoder.

    A stem block at full resolution is followed by three stages that each
    halve the spatial size with MaxPool2d(2) and double the channel count,
    ending at base*8 channels. For a 224x224 input the feature map sizes are
    224 -> 112 -> 56 -> 28.
    """

    def __init__(self, base=32):
        super().__init__()
        widths = [base, base*2, base*4, base*8]
        self.stem = conv_block(3, widths[0])
        self.down1 = nn.Sequential(nn.MaxPool2d(2), conv_block(widths[0], widths[1]))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), conv_block(widths[1], widths[2]))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), conv_block(widths[2], widths[3]))

    def forward(self, x):
        # Only the deepest feature map (the latent) is returned.
        feats = x
        for stage in (self.stem, self.down1, self.down2, self.down3):
            feats = stage(feats)
        return feats

class Decoder(nn.Module):
    """Decoder that upsamples the encoder latent back to a 3-channel image.

    Three stages each double the spatial size (bilinear) and halve the
    channel count from base*8 down to base; a 1x1 convolution then maps to
    3 channels and a sigmoid bounds the output to (0, 1).
    """

    def __init__(self, base=32):
        super().__init__()

        def up_stage(in_ch, out_ch):
            # One decoder stage: 2x bilinear upsample followed by a conv block.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
                conv_block(in_ch, out_ch),
            )

        self.up2 = up_stage(base*8, base*4)
        self.up1 = up_stage(base*4, base*2)
        self.up0 = up_stage(base*2, base)
        self.out = nn.Conv2d(base, 3, kernel_size=1)

    def forward(self, z):
        # For a 28x28 latent the sizes are 56 -> 112 -> 224.
        h = self.up0(self.up1(self.up2(z)))
        return torch.sigmoid(self.out(h))

class DenoiseAE(nn.Module):
    """Autoencoder made of an Encoder (`enc`) followed by a Decoder (`dec`)."""

    def __init__(self, base=32):
        super().__init__()
        self.enc = Encoder(base=base)
        self.dec = Decoder(base=base)

    def forward(self, x):
        # Encode to the latent, then reconstruct an image from it.
        latent = self.enc(x)
        reconstruction = self.dec(latent)
        return reconstruction

class Classifier(nn.Module):
    """Binary classifier: Encoder -> global average pool -> small MLP head.

    forward() returns one raw logit per sample (no sigmoid applied).
    """

    def __init__(self, base=32):
        super().__init__()
        latent_ch = base*8
        hidden = base*4
        self.enc = Encoder(base=base)
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(latent_ch, hidden),
            nn.SiLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        pooled = self.pool(self.enc(x))
        # The head emits shape (batch, 1); squeeze(1) reduces it to (batch,).
        return self.head(pooled).squeeze(1)

# -----------------------------
# 9) Checkpoint / resume utils
# -----------------------------
def ckpt_path(name: str) -> Path:
    """Return the path of the file called *name* inside CKPT_DIR."""
    return CKPT_DIR.joinpath(name)

def save_ckpt(name: str, payload: dict):
    """Write *payload* with torch.save to the checkpoint file *name*, atomically.

    The payload is first written to a sibling ``<name>.tmp`` file and then
    moved over the final path with ``os.replace``. An interrupted save
    therefore leaves the previous checkpoint intact instead of a truncated
    file that would make the next resume fail while loading.

    Args:
        name: file name inside the checkpoint directory.
        payload: object handed to ``torch.save``.
    """
    dst = ckpt_path(name)
    tmp = dst.with_name(dst.name + ".tmp")
    torch.save(payload, tmp)
    os.replace(tmp, dst)

def load_ckpt(name: str):
    """Load checkpoint *name* onto the CPU, or return None if the file is absent."""
    target = ckpt_path(name)
    if not target.exists():
        return None
    return torch.load(target, map_location="cpu")

def mark_done(flag: str, meta: dict):
    """Write the completion marker ``<flag>.json`` into CKPT_DIR.

    The file records ``done: True``, the given *meta* dict, and the current
    Unix timestamp.
    """
    marker = CKPT_DIR / f"{flag}.json"
    record = {"done": True, "meta": meta, "time": time.time()}
    save_json(marker, record)

def is_done(flag: str) -> bool:
    """Return True if the completion marker ``<flag>.json`` exists in CKPT_DIR."""
    marker = CKPT_DIR / f"{flag}.json"
    return marker.exists()

# -----------------------------
# 10) Metrics (F1 + threshold search)
# -----------------------------
def f1_at_threshold(probs: np.ndarray, y_true: np.ndarray, thr: float) -> Tuple[float,float,float]:
    """Compute (F1, precision, recall) for the positive class at threshold *thr*.

    A sample is predicted positive when its probability is >= thr. A small
    epsilon (1e-9) in every denominator avoids division by zero, so an empty
    denominator yields 0.0 instead of raising.
    """
    eps = 1e-9
    pred_pos = probs >= thr
    is_pos = y_true == 1
    is_neg = y_true == 0

    tp = np.count_nonzero(pred_pos & is_pos)
    fp = np.count_nonzero(pred_pos & is_neg)
    fn = np.count_nonzero(~pred_pos & is_pos)

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2*prec*rec/(prec+rec+eps)
    return float(f1), float(prec), float(rec)

def best_threshold(probs: np.ndarray, y_true: np.ndarray) -> Tuple[float,float,float,float]:
    """Scan 19 evenly spaced thresholds in [0.05, 0.95] and keep the best F1.

    Returns (f1, threshold, precision, recall) for the winning threshold.
    Ties keep the lowest threshold, because only a strictly greater F1
    replaces the current best.
    """
    top_f1, top_thr, top_p, top_r = -1.0, 0.5, 0.0, 0.0
    for cand in np.linspace(0.05, 0.95, 19):
        cand = float(cand)
        f1, p, r = f1_at_threshold(probs, y_true, cand)
        if f1 > top_f1:
            top_f1, top_thr, top_p, top_r = f1, cand, p, r
    return top_f1, top_thr, top_p, top_r

# -----------------------------
# 11) Training loops (MPS friendly)
# -----------------------------
def train_ssl(ae: DenoiseAE, loader: DataLoader, epochs: int, lr: float = 2e-4):
    """Train the denoising autoencoder with an L1 reconstruction loss.

    Resumes from ``ssl_last.pt`` when present, writes ``ssl_last.pt`` after
    every epoch, and always writes ``ssl_final.pt`` (model weights only) at
    the end so the classifier can be initialized from the encoder.

    The ``SSL_DONE`` marker is written only when all *epochs* have been
    completed. If the time budget cuts training short the marker is left
    absent, so the next start resumes from ``ssl_last.pt`` instead of
    treating a partial run as finished.

    Args:
        ae: autoencoder to train; moved to DEVICE in place.
        loader: yields (corrupted batch, clean batch, ids).
        epochs: total number of epochs, counting already-completed ones.
        lr: AdamW learning rate.
    """
    ae.to(DEVICE)
    opt = torch.optim.AdamW(ae.parameters(), lr=lr, weight_decay=1e-4)

    start_epoch = 0
    resume = load_ckpt("ssl_last.pt")
    if resume is not None:
        try:
            ae.load_state_dict(resume["model"])
            opt.load_state_dict(resume["opt"])
            start_epoch = int(resume["epoch"]) + 1
            print(f"üîÅ [SSL] resume from epoch {start_epoch}")
        except Exception as e:
            # Best effort: an incompatible checkpoint must not block training.
            print("‚ö†Ô∏è [SSL] could not load resume, starting fresh:", e)

    # Number of epochs finished so far, including those from a resumed run.
    completed_epochs = start_epoch

    for epoch in range(start_epoch, epochs):
        if not budget_ok():
            print("‚è±Ô∏è Budget exceeded, stopping SSL early.")
            break

        ae.train()
        t0 = time.time()
        total = 0.0
        n = 0

        print(f"\nüöÄ [SSL] epoch {epoch+1}/{epochs} lr={lr:.2e}")

        for it, (x_corrupt, x_clean, _ids) in enumerate(loader, 1):
            x_corrupt = x_corrupt.to(DEVICE)
            x_clean   = x_clean.to(DEVICE)

            opt.zero_grad(set_to_none=True)
            recon = ae(x_corrupt)
            loss = F.l1_loss(recon, x_clean)

            loss.backward()
            opt.step()

            total += float(loss.item())
            n += 1

            # Running-average loss every 50 iterations.
            if it % 50 == 0:
                print(f"   [SSL] it {it:5d}/{len(loader)} loss={total/n:.4f}")

        avg = total / max(n,1)
        dt = time.time() - t0
        print(f"‚úÖ [SSL] epoch {epoch+1} avg_loss={avg:.4f} time={dt:.1f}s")

        save_ckpt("ssl_last.pt", {"model": ae.state_dict(), "opt": opt.state_dict(), "epoch": epoch})
        completed_epochs = epoch + 1

    # Always export the current weights so the classifier stage can use them.
    save_ckpt("ssl_final.pt", {"model": ae.state_dict()})

    if completed_epochs >= epochs:
        mark_done("SSL_DONE", {"epochs": epochs, "subset": len(loader.dataset)})
        print("üèÅ [SSL] done")
    else:
        print(
            f"[SSL] stopped after {completed_epochs}/{epochs} epochs; "
            "SSL_DONE not written so the next run resumes."
        )

def train_classifier(model: Classifier, train_loader: DataLoader, val_loader: DataLoader, epochs: int, lr: float = 2e-4):
    """Train the binary classifier with class-weighted BCE and F1-based model selection.

    After each epoch the validation set is scored, the F1-optimal threshold
    is found with best_threshold(), and two checkpoints are maintained:
    ``cls_best.pt`` (best validation F1 so far) and ``cls_last.pt`` (latest
    state, used for resume). Training stops early after EARLY_STOP_PATIENCE
    consecutive epochs without improvement, or when the time budget runs out.

    Changes relative to the previous version:
      * Training labels are read from ``dataset.items`` when available,
        instead of iterating the dataset, which decoded and augmented every
        image just to count classes. This also means the random augmentation
        draws of that pass no longer happen.
      * ``cls_last.pt`` is written after the best/bad-epoch bookkeeping, so
        a resumed run sees the state including the epoch just finished.
      * A resumed run whose patience is already exhausted does not train
        further.

    Args:
        model: classifier to train; moved to DEVICE in place.
        train_loader: yields (images, labels, ids).
        val_loader: yields (images, labels, ids).
        epochs: total number of epochs, counting already-completed ones.
        lr: AdamW learning rate.

    Returns:
        (best_f1, best_thr): best validation F1 and the threshold achieving it.
    """
    model.to(DEVICE)

    # Weighted BCE (helps with imbalance)
    # pos_weight = (#neg / #pos) in TRAIN set
    train_ds = train_loader.dataset
    items = getattr(train_ds, "items", None)
    if items is not None:
        # (path, label) pairs are already in memory; no image I/O needed.
        y_train = [int(label) for _, label in items]
    else:
        # Fallback for datasets without `.items`: read the label from each sample.
        y_train = [int(sample[1]) for sample in train_ds]
    n_pos = sum(y_train)
    n_neg = len(y_train) - n_pos
    pos_weight = torch.tensor([max(1.0, n_neg / max(1, n_pos))], device=DEVICE)

    print(f"\n[CLS] pos_weight={float(pos_weight.item()):.3f} (neg={n_neg}, pos={n_pos})")

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best_f1 = -1.0
    best_thr = 0.5
    best_epoch = -1
    bad_epochs = 0

    # Resume
    resume = load_ckpt("cls_last.pt")
    start_epoch = 0
    if resume is not None:
        try:
            model.load_state_dict(resume["model"])
            opt.load_state_dict(resume["opt"])
            best_f1 = float(resume.get("best_f1", best_f1))
            best_thr = float(resume.get("best_thr", best_thr))
            best_epoch = int(resume.get("best_epoch", best_epoch))
            start_epoch = int(resume["epoch"]) + 1
            bad_epochs = int(resume.get("bad_epochs", 0))
            print(f"üîÅ [CLS] resume from epoch {start_epoch} (best_f1={best_f1:.4f})")
        except Exception as e:
            # Best effort: an incompatible checkpoint must not block training.
            print("‚ö†Ô∏è [CLS] could not load resume, starting fresh:", e)

    for epoch in range(start_epoch, epochs):
        if not budget_ok():
            print("‚è±Ô∏è Budget exceeded, stopping classifier early.")
            break

        # A resumed run that had already early-stopped must not keep training.
        if bad_epochs >= EARLY_STOP_PATIENCE:
            print("[CLS] Early-stop patience already exhausted in the resumed checkpoint; not training further.")
            break

        # ---- train
        model.train()
        t0 = time.time()
        total = 0.0
        n = 0

        print(f"\nüöÄ [CLS] Epoch {epoch+1}/{epochs} lr={lr:.2e}")

        for it, (x, y, _ids) in enumerate(train_loader, 1):
            x = x.to(DEVICE)
            y = y.to(DEVICE).float()

            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=pos_weight)
            loss.backward()
            opt.step()

            total += float(loss.item())
            n += 1

            if it % 50 == 0:
                print(f"   [CLS] it {it:5d}/{len(train_loader)} loss={total/n:.4f}")

        train_loss = total / max(n,1)

        # ---- val
        model.eval()
        probs_all = []
        y_all = []
        with torch.no_grad():
            for x, y, _ids in val_loader:
                x = x.to(DEVICE)
                logits = model(x)
                probs = torch.sigmoid(logits).detach().cpu().numpy()
                probs_all.append(probs)
                y_all.append(y.numpy())

        probs_all = np.concatenate(probs_all, axis=0)
        y_all = np.concatenate(y_all, axis=0).astype(np.int32)

        f1, thr, p, r = best_threshold(probs_all, y_all)
        dt = time.time() - t0
        print(f"‚úÖ [CLS] epoch {epoch+1} train_loss={train_loss:.4f} val_f1={f1:.4f} thr={thr:.2f} (P={p:.3f}, R={r:.3f}) time={dt:.1f}s")

        # Update the best / patience bookkeeping BEFORE writing cls_last.pt,
        # so the resume checkpoint reflects this epoch's outcome.
        improved = f1 > best_f1 + 1e-5
        if improved:
            best_f1 = f1
            best_thr = thr
            best_epoch = epoch
            bad_epochs = 0
            save_ckpt("cls_best.pt", {"model": model.state_dict(), "best_f1": best_f1, "best_thr": best_thr, "epoch": epoch})
            print(f"üíæ [CLS] New BEST saved: cls_best.pt (best_f1={best_f1:.4f}, thr={best_thr:.2f})")
        else:
            bad_epochs += 1

        # Save last
        save_ckpt("cls_last.pt", {
            "model": model.state_dict(),
            "opt": opt.state_dict(),
            "epoch": epoch,
            "best_f1": best_f1,
            "best_thr": best_thr,
            "best_epoch": best_epoch,
            "bad_epochs": bad_epochs,
        })

        if not improved and bad_epochs >= EARLY_STOP_PATIENCE:
            print(f"üõë [CLS] Early stop (no val F1 improvement for {EARLY_STOP_PATIENCE} epochs).")
            break

    mark_done("CLS_DONE", {"epochs": epochs, "best_f1": best_f1, "best_thr": best_thr})
    return best_f1, best_thr

# -----------------------------
# 12) Inference + submission
# -----------------------------
def make_submission(model: Classifier, test_loader: DataLoader, thr: float, out_csv: Path):
    """Score the test set and write the submission CSV.

    Each test image gets label 1 when its sigmoid probability is >= thr,
    otherwise 0. The CSV has the columns ``label`` and ``ID``, where ID is
    ``"test/<filename>"``. The columns are set explicitly so that an empty
    test set still produces a CSV with the correct header.

    Args:
        model: trained classifier; moved to DEVICE and put in eval mode.
        test_loader: yields (images, path strings).
        thr: decision threshold applied to the probabilities.
        out_csv: destination path of the CSV.

    Returns:
        The DataFrame that was written.
    """
    model.to(DEVICE)
    model.eval()

    rows = []
    with torch.no_grad():
        for x, pid in test_loader:
            x = x.to(DEVICE)
            logits = model(x)
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            for pth, pr in zip(pid, probs):
                # ID must be "test/filename.jpg"
                fname = Path(pth).name
                rows.append({"label": int(pr >= thr), "ID": f"test/{fname}"})

    df = pd.DataFrame(rows, columns=["label", "ID"])
    df.to_csv(out_csv, index=False)
    print("‚úÖ Wrote submission:", out_csv)
    return df

# -----------------------------
# 13) MAIN PIPELINE
# -----------------------------
def main():
    """Run the full pipeline: SSL pretraining -> classifier -> submission -> run metadata.

    Steps:
      1. Spot-check a sample of the unlabeled images and drop the ones that
         fail to open, so they cannot crash the SSL data loader later.
      2. Optionally pretrain the denoising autoencoder (skipped when the
         SSL_DONE marker and ssl_final.pt both exist).
      3. Initialize the classifier encoder from the SSL weights if present,
         then train the classifier.
      4. Reload the best checkpoint, write the submission CSV, and save the
         run metadata JSON.
    """
    print("\n[Data] quick validation samples (optional)")
    good_unlab, bad_unlab = validate_images(unlab_paths, max_check=800)
    if bad_unlab:
        (OUT_DIR / "bad_unlabeled_sample.json").write_text(json.dumps(bad_unlab[:200], indent=2))
        print(f"‚ö†Ô∏è Found some bad unlabeled samples: {len(bad_unlab)} (logged first 200)")

    # --- Build unlabeled subset
    # Start from the validated list so files that failed to open are excluded.
    # With no failures this is the same list, in the same order, as unlab_paths.
    ssl_paths = list(good_unlab)
    random.shuffle(ssl_paths)
    if SSL_SUBSET > 0:
        ssl_paths = ssl_paths[: min(SSL_SUBSET, len(ssl_paths))]
        print(f"\n‚ö° SSL subset enabled: {len(ssl_paths):,} unlabeled images")
    else:
        ssl_paths = []
        print("\n‚ö° SSL disabled (SSL_SUBSET=0)")

    print("\nCounts used by this run:")
    print("Unlabeled:", len(ssl_paths))
    print("Train labeled:", len(train_items), " (pos=", sum(y for _,y in train_items), ")")
    print("Val labeled  :", len(val_items),   " (pos=", sum(y for _,y in val_items),   ")")
    print("Test         :", len(test_paths))

    # --- Loaders
    train_ds = LabeledDataset(train_items, train_tf)
    val_ds   = LabeledDataset(val_items, eval_tf)
    test_ds  = TestDataset(test_paths)

    train_loader = make_loader(train_ds, BATCH_CLS, shuffle=True)
    val_loader   = make_loader(val_ds,   BATCH_CLS, shuffle=False)
    test_loader  = make_loader(test_ds,  BATCH_CLS, shuffle=False)

    # Sanity check: pull one batch and print its shapes.
    xb, yb, _ = next(iter(train_loader))
    print("\nSanity batch shapes:", xb.shape, yb.shape, "(must be [B,3,224,224])")

    # --- SSL stage: denoise autoencoder
    if SSL_SUBSET > 0 and EPOCHS_SSL > 0:
        print("\n==============================")
        print("U-ONLY PIPELINE: SSL Denoise AE -> Classifier -> Submission")
        print("==============================")

        if is_done("SSL_DONE") and ckpt_path("ssl_final.pt").exists():
            print("‚è≠Ô∏è  Skip SSL (DONE flag found).")
        else:
            ssl_ds = UnlabeledDenoiseDataset(ssl_paths)
            ssl_loader = make_loader(ssl_ds, BATCH_SSL, shuffle=True)

            # Use a smaller base if memory is tight
            ae = DenoiseAE(base=24)  # lighter than 32
            train_ssl(ae, ssl_loader, EPOCHS_SSL, lr=2e-4)

    # --- Classifier stage
    # Same `base` as the autoencoder so the encoder weights are shape-compatible.
    cls = Classifier(base=24)

    # If SSL exists, load encoder weights into classifier encoder
    ssl_final = load_ckpt("ssl_final.pt")
    if ssl_final is not None and "model" in ssl_final:
        print("\n[Init] Loading encoder weights from SSL AE into classifier.")
        ae_state = ssl_final["model"]
        # Copy only encoder keys, stripping just the leading "enc." prefix
        # (str.replace would also rewrite any later occurrence inside a key).
        prefix = "enc."
        enc_state = {k[len(prefix):]: v for k, v in ae_state.items() if k.startswith(prefix)}
        missing, unexpected = cls.enc.load_state_dict(enc_state, strict=False)
        print("Encoder load missing:", len(missing), "unexpected:", len(unexpected))

    best_f1, best_thr = train_classifier(cls, train_loader, val_loader, EPOCHS_CLS, lr=2e-4)

    # Load BEST before submission
    best_ckpt = load_ckpt("cls_best.pt")
    if best_ckpt is not None and "model" in best_ckpt:
        cls.load_state_dict(best_ckpt["model"])
        best_thr = float(best_ckpt.get("best_thr", best_thr))
        best_f1  = float(best_ckpt.get("best_f1", best_f1))
        print(f"\n[Best] loaded cls_best.pt best_f1={best_f1:.4f} thr={best_thr:.2f}")

    # Submission
    out_csv = OUT_DIR / "submission_ucnn.csv"
    df = make_submission(cls, test_loader, best_thr, out_csv)

    # Preview
    print("\nSubmission head:")
    print(df.head(10))

    # Save run meta
    meta = {
        "device": str(DEVICE),
        "img_size": IMG_SIZE,
        "ssl_subset": SSL_SUBSET,
        "epochs_ssl": EPOCHS_SSL,
        "epochs_cls": EPOCHS_CLS,
        "best_f1": best_f1,
        "best_thr": best_thr,
        "time_hours": now_h(),
        "train_counts": {
            "train_total": len(train_items),
            "train_pos": int(sum(y for _,y in train_items)),
            "train_neg": int(sum(1-y for _,y in train_items)),
            "val_total": len(val_items),
            "val_pos": int(sum(y for _,y in val_items)),
            "val_neg": int(sum(1-y for _,y in val_items)),
        },
    }
    save_json(OUT_DIR / "run_meta.json", meta)
    print("\n‚úÖ Wrote meta:", OUT_DIR / "run_meta.json")

# Run the pipeline only when this file is executed as a script.
# NOTE: the path setup, file scans and train/val split above run at import
# time regardless of this guard.
if __name__ == "__main__":
    main()
