In [7]:
#Adain Hyper Parameter
import sys, os, platform, subprocess, shutil
import os, math, random, time, json, glob
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
from torchvision import models, utils

# Force CPU for now so everything runs
USE_CPU_FOR_NOW = True

import torch
device = torch.device('cpu' if USE_CPU_FOR_NOW else ('cuda' if torch.cuda.is_available() else 'cpu'))
print("Device:", device)


# Reproducibility
torch.manual_seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cpu
Device: cuda


In [8]:
USE_CPU_FOR_NOW = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [2]:
# ============================================================
# Neural Style Transfer (AdaIN, Huang & Belongie 2017)
# - Hyperparameter sweep over lambda_content and lambda_style
# - Optional alpha blending for AdaIN targets
# - Saves sample grids & checkpoints every 200 iterations
# - FIXED: config copying works even if Cfg uses class attributes
# ============================================================

import os, random, time, json
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
from torchvision import models, utils

# -----------------------------
# Device & Reproducibility
# -----------------------------
torch.backends.cudnn.benchmark = True
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# -----------------------------
# Config
# -----------------------------
class Cfg:
    # Paths (adjust if needed)
    data_root = str((Path.cwd() / "").resolve())
    content_dir = os.path.join(data_root, "content")
    style_dir   = os.path.join(data_root, "style")
    out_dir     = "./Folder/adain_runs3"

    # Training
    image_size_crop   = 256
    resize_shorter_to = 512
    batch_size = 8
    num_workers = 2
    lr = 1e-4
    max_iterations = 80_000

    # Logging & saving
    save_every = 5000            # <-- save grid + checkpoint every 200 steps
    log_every  = 200

    # Loss weights (total = λc * Lc + λs * Ls)
    lambda_content = 0.5
    lambda_style   = 10.0

    # AdaIN blend strength for target feature (1.0 = pure style stats)
    alpha = 1.0

    # Resume (leave None for fresh runs)
    resume = None

cfg = Cfg()
os.makedirs(cfg.out_dir, exist_ok=True)

# Helper: turn a Cfg instance (class-level attrs) into a dict
def cfg_to_dict(obj):
    out = {}
    for k in dir(obj):
        if k.startswith("_"):
            continue
        v = getattr(obj, k)
        if callable(v):
            continue
        out[k] = v
    return out

print("CFG:", cfg_to_dict(cfg))

# -----------------------------
# Basic image helpers
# -----------------------------
IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def count_imgs(p):
    p = Path(p)
    return sum(1 for f in p.rglob("*") if f.suffix.lower() in IMG_EXTS)

def first_img(p):
    p = Path(p)
    for f in p.rglob("*"):
        if f.suffix.lower() in IMG_EXTS:
            return str(f)
    return None

print("content_dir exists:", os.path.isdir(cfg.content_dir), "n_images:", count_imgs(cfg.content_dir) if os.path.isdir(cfg.content_dir) else 0)
print("style_dir   exists:", os.path.isdir(cfg.style_dir),   "n_images:", count_imgs(cfg.style_dir)   if os.path.isdir(cfg.style_dir)   else 0)
print("example content:", first_img(cfg.content_dir))
print("example style  :", first_img(cfg.style_dir))

# -----------------------------
# Dataset
# -----------------------------
class ImageFolderFlat(Dataset):
    def __init__(self, root, resize_shorter_to=512, crop_size=256):
        self.paths = []
        for p in sorted(Path(root).rglob("*")):
            if p.suffix.lower() in IMG_EXTS:
                self.paths.append(str(p))
        if not self.paths:
            raise RuntimeError(f"No images found under {root}")

        self.transform = T.Compose([
            T.Lambda(lambda im: im.convert("RGB")),
            T.Resize(resize_shorter_to, interpolation=Image.BICUBIC),
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        img = Image.open(p)
        return self.transform(img)

class PairDataset(Dataset):
    """
    Iterates over content images; for each content item, pick a random style item.
    """
    def __init__(self, content_root, style_root, resize_shorter_to=512, crop_size=256):
        self.content_ds = ImageFolderFlat(content_root, resize_shorter_to, crop_size)
        self.style_ds   = ImageFolderFlat(style_root,   resize_shorter_to, crop_size)
        self.style_len  = len(self.style_ds)

    def __len__(self):
        return len(self.content_ds)

    def __getitem__(self, idx):
        content_img = self.content_ds[idx]
        style_img   = self.style_ds[random.randrange(self.style_len)]
        return content_img, style_img

# Build dataloader once; reused across sweeps
train_ds = PairDataset(cfg.content_dir, cfg.style_dir, cfg.resize_shorter_to, cfg.image_size_crop)
train_loader = DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    pin_memory=(device.type == 'cuda'),
    drop_last=True
)
print("Dataset size:", len(train_ds), "Batches/epoch (approx):", len(train_loader))

# -----------------------------
# VGG Encoder & AdaIN
# -----------------------------
# Map torchvision VGG19 feature indices to friendly names
LAYER_NAME_MAP = {
    1 : 'relu1_1',
    6 : 'relu2_1',
    11: 'relu3_1',
    20: 'relu4_1',
}

STYLE_LAYERS = ['relu1_1','relu2_1','relu3_1','relu4_1']
CONTENT_LAYER = 'relu4_1'

def calc_mean_std(feat, eps=1e-5):
    # feat: (B, C, H, W)
    B, C = feat.size()[:2]
    feat_var = feat.view(B, C, -1).var(dim=2, unbiased=False) + eps
    feat_std = feat_var.sqrt().view(B, C, 1, 1)
    feat_mean = feat.view(B, C, -1).mean(dim=2).view(B, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat, eps=1e-5):
    # Channel-wise align mean & std of content to style
    c_mean, c_std = calc_mean_std(content_feat, eps)
    s_mean, s_std = calc_mean_std(style_feat, eps)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

class VGGEncoder(nn.Module):
    """
    VGG-19 (imagenet) up to relu4_1. Returns dict of selected layer activations.
    """
    def __init__(self):
        super().__init__()
        try:
            self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except Exception:
            self.vgg = models.vgg19(pretrained=True).features
        for p in self.vgg.parameters():
            p.requires_grad_(False)

    def forward(self, x, out_keys=('relu1_1','relu2_1','relu3_1','relu4_1')):
        feats = {}
        h = x
        for i, layer in enumerate(self.vgg):
            h = layer(h)
            name = LAYER_NAME_MAP.get(i, None)
            if name in out_keys:
                feats[name] = h
            if i >= 20:  # after relu4_1
                pass
        return feats

# -----------------------------
# Decoder (mirror-ish of VGG)
# -----------------------------
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            # relu4_1 -> block3
            nn.ReflectionPad2d(1), nn.Conv2d(512, 256, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 128, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 128, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 64, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 3, 3)
        )

    def forward(self, x):
        return self.body(x)

# -----------------------------
# Loss wrapper & style loss
# -----------------------------
class LossNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, x, keys=None):
        return self.encoder(x, out_keys=tuple(keys) if keys else tuple(STYLE_LAYERS))

def mean_std_loss(x_feats, y_feats):
    """
    IN-statistics style loss: sum over layers of (mean MSE + std MSE)
    """
    loss = 0.0
    for k in STYLE_LAYERS:
        xm, xs = calc_mean_std(x_feats[k])
        ym, ys = calc_mean_std(y_feats[k])
        loss = loss + F.mse_loss(xm, ym) + F.mse_loss(xs, ys)
    return loss

# -----------------------------
# Utility: (de)normalization & saving
# -----------------------------
@torch.no_grad()
def denorm_for_save(x):
    # x is normalized (ImageNet), bring it back to [0,1] for saving
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
    y = x * std + mean
    return torch.clamp(y, 0, 1)

# -----------------------------
# Training: one run with overrides  (FIXED cfg copy)
# -----------------------------
def train_once(cfg_base, overrides):
    # Build a dict from cfg_base (even if class-level attrs), then apply overrides
    base_dict = cfg_to_dict(cfg_base)
    base_dict.update(overrides)
    run_cfg = SimpleNamespace(**base_dict)

    # Unique run folder
    tag = f"c{run_cfg.lambda_content:g}_s{run_cfg.lambda_style:g}_a{run_cfg.alpha:g}_it{run_cfg.max_iterations}"
    run_out = os.path.join(run_cfg.out_dir, tag)
    os.makedirs(run_out, exist_ok=True)

    # Fresh model & optimizer
    decoder = Decoder().to(device)
    lossnet = LossNet().to(device).eval()
    optimizer = torch.optim.Adam(decoder.parameters(), lr=run_cfg.lr)

    # Resume (decoder-only resume supported; optional)
    start_iter = 0
    if run_cfg.resume:
        data = torch.load(run_cfg.resume, map_location=device)
        decoder.load_state_dict(data['decoder'], strict=True)
        if 'optimizer' in data:
            optimizer.load_state_dict(data['optimizer'])
        start_iter = data.get('iteration', 0)
        print(f"[resume] loaded {run_cfg.resume} @ iter {start_iter}")

    global_iter = start_iter
    data_iter = iter(train_loader)
    decoder.train()

    print(f"\n=== RUN {tag} ===")
    print({k: getattr(run_cfg, k) for k in ['lambda_content','lambda_style','alpha','max_iterations','save_every','log_every']})

    while global_iter < run_cfg.max_iterations:
        try:
            content, style = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            content, style = next(data_iter)

        content = content.to(device, non_blocking=True)
        style   = style.to(device, non_blocking=True)

        # --- Targets ---
        with torch.no_grad():
            c4 = lossnet.encoder(content, out_keys=['relu4_1'])['relu4_1']
            s4 = lossnet.encoder(style,   out_keys=['relu4_1'])['relu4_1']
        t_adain = adain(c4, s4)
        t_feat  = run_cfg.alpha * t_adain + (1.0 - run_cfg.alpha) * c4  # blend with content features

        # --- Decode ---
        g_img = decoder(t_feat)
        g_img_clamped = torch.clamp(g_img, -3.0, 3.0)

        # --- Re-encode generated (with grad) ---
        g_feats_all = lossnet.encoder(g_img_clamped, out_keys=STYLE_LAYERS)

        # --- Style targets (no grad) ---
        with torch.no_grad():
            s_feats_all = lossnet.encoder(style, out_keys=STYLE_LAYERS)

        # --- Losses ---
        loss_c = F.mse_loss(g_feats_all[CONTENT_LAYER], t_feat)
        loss_s = mean_std_loss(g_feats_all, s_feats_all)
        loss   = run_cfg.lambda_content * loss_c + run_cfg.lambda_style * loss_s

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        global_iter += 1

        if global_iter % run_cfg.log_every == 0:
            print(f"[{global_iter:>6d}/{run_cfg.max_iterations}] "
                  f"loss={loss.item():.4f}  Lc={loss_c.item():.4f}  Ls={loss_s.item():.4f}")

        if global_iter % run_cfg.save_every == 0:
            with torch.no_grad():
                c_show = denorm_for_save(content[:2])
                s_show = denorm_for_save(style[:2])
                g_show = denorm_for_save(g_img_clamped[:2])
                samples = torch.cat([c_show, s_show, g_show], dim=0)
            grid = utils.make_grid(samples, nrow=2)
            grid_path = os.path.join(run_out, f"samples_iter_{global_iter}.png")
            utils.save_image(grid, grid_path)

            ckpt = {
                'iteration': global_iter,
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'cfg': vars(run_cfg),
                'time': time.time(),
            }
            ckpt_path = os.path.join(run_out, f"decoder_iter3_{global_iter}.pth")
            torch.save(ckpt, ckpt_path)
            try:
                os.sync()
            except Exception:
                pass

    # Final save
    final_path = os.path.join(run_out, "decoder_final3.pth")
    torch.save({'iteration': global_iter, 'decoder': decoder.state_dict(), 'cfg': vars(run_cfg)}, final_path)
    print(f"[final] saved: {final_path}")

    return {
        "tag": tag,
        "final_iteration": global_iter,
        "out_dir": run_out,
        "lambda_content": run_cfg.lambda_content,
        "lambda_style": run_cfg.lambda_style,
        "alpha": run_cfg.alpha
    }

# -----------------------------
# Evaluation: stylize fixed pairs with alpha sweep (after training)
# -----------------------------
@torch.no_grad()
def stylize_pairs(ckpt_path, pairs, alpha_vals=(1.0,), save_path="eval_grid.png"):
    # Load trained decoder
    decoder = Decoder().to(device).eval()
    data = torch.load(ckpt_path, map_location=device)
    decoder.load_state_dict(data['decoder'], strict=True)

    enc = VGGEncoder().to(device).eval()

    to_norm = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    to_tensor = T.Compose([
        T.Lambda(lambda im: im.convert("RGB")),
        T.Resize(512, interpolation=Image.BICUBIC),
        T.CenterCrop(256),
        T.ToTensor(),
        to_norm
    ])

    rows = []
    for c_path, s_path in pairs:
        c = to_tensor(Image.open(c_path)).unsqueeze(0).to(device)
        s = to_tensor(Image.open(s_path)).unsqueeze(0).to(device)
        c4 = enc(c, out_keys=['relu4_1'])['relu4_1']
        s4 = enc(s, out_keys=['relu4_1'])['relu4_1']
        t  = adain(c4, s4)
        row_imgs = [denorm_for_save(c), denorm_for_save(s)]
        for a in alpha_vals:
            t_blend = a * t + (1.0 - a) * c4
            y = decoder(t_blend).clamp(-3,3)
            row_imgs.append(denorm_for_save(y))
        rows.append(torch.cat(row_imgs, dim=0))

    grid = utils.make_grid(torch.cat(rows, dim=0), nrow=2 + len(alpha_vals))
    utils.save_image(grid, save_path)
    print("Saved eval grid:", save_path)

# -----------------------------
# Sweep definition & execution
# -----------------------------
if __name__ == "__main__":
    # Define a first-pass sweep over style weights; keep λc = 1, α = 1
    sweep = []
    for lam_s in [0.5, 1.0, 2.0, 5.0, 10.0, 20.0]:
        sweep.append({"lambda_content": 1.0, "lambda_style": lam_s, "alpha": 1.0, "max_iterations": 10_000})

    # Optional: vary content weight a bit (still α = 1)
    for lam_c in [0.5, 1.0, 2.0]:
        sweep.append({"lambda_content": lam_c, "lambda_style": 5.0, "alpha": 1.0, "max_iterations": 10_000})

    results = []
    for overrides in sweep:
        results.append(train_once(cfg, overrides))

    # Save a summary JSON
    with open(os.path.join(cfg.out_dir, "sweep_results.json"), "w") as f:
        json.dump(results, f, indent=2)

    # Example: quick evaluation grid for one trained checkpoint (edit paths if needed)
    try:
        ex_content = first_img(cfg.content_dir)
        ex_style   = first_img(cfg.style_dir)
        if ex_content and ex_style:
            ckpt_example = Path(cfg.out_dir) / "c1_s5_a1_it10000" / "decoder_final.pth"
            if ckpt_example.exists():
                stylize_pairs(
                    str(ckpt_example),
                    pairs=[(ex_content, ex_style)],           # add more pairs here
                    alpha_vals=(0.3, 0.6, 1.0),
                    save_path=str(Path(cfg.out_dir) / "eval_s5.png")
                )
    except Exception as e:
        print("Eval example skipped:", e)


Device: cuda
CFG: {'alpha': 1.0, 'batch_size': 8, 'content_dir': '/workspace/content', 'data_root': '/workspace', 'image_size_crop': 256, 'lambda_content': 0.5, 'lambda_style': 10.0, 'log_every': 200, 'lr': 0.0001, 'max_iterations': 80000, 'num_workers': 2, 'out_dir': './Folder/adain_runs3', 'resize_shorter_to': 512, 'resume': None, 'save_every': 5000, 'style_dir': '/workspace/style'}
content_dir exists: True n_images: 49981
style_dir   exists: True n_images: 49981
example content: /workspace/content/000000000045.jpg
example style  : /workspace/style/1.jpg
Dataset size: 49981 Batches/epoch (approx): 6247

=== RUN c1_s0.5_a1_it10000 ===
{'lambda_content': 1.0, 'lambda_style': 0.5, 'alpha': 1.0, 'max_iterations': 10000, 'save_every': 5000, 'log_every': 200}
[   200/10000] loss=10.7049  Lc=8.5227  Ls=4.3644
[   400/10000] loss=7.4854  Lc=6.0294  Ls=2.9120
[   600/10000] loss=9.6542  Lc=7.6899  Ls=3.9287
[   800/10000] loss=9.8296  Lc=8.1587  Ls=3.3418
[  1000/10000] loss=9.7106  Lc=8.0580

In [5]:
#Adain Hyper 2

import os, random, time, json
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
from torchvision import models, utils

# -----------------------------
# Device & Reproducibility
# -----------------------------
torch.backends.cudnn.benchmark = True
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# -----------------------------
# Config
# -----------------------------
class Cfg:
    # Paths (adjust if needed)
    data_root = str((Path.cwd() / "").resolve())
    content_dir = os.path.join(data_root, "content")
    style_dir   = os.path.join(data_root, "style")
    out_dir     = "./Folder/adain_runs4"

    # Training
    image_size_crop   = 256
    resize_shorter_to = 512
    batch_size = 8
    num_workers = 2
    lr = 1e-4
    max_iterations = 80_000

    # Logging & saving
    save_every = 5000            # <-- save grid + checkpoint every 200 steps
    log_every  = 200

    # Loss weights (total = λc * Lc + λs * Ls)
    lambda_content = 0.8
    lambda_style   = 58.0

    # AdaIN blend strength for target feature (1.0 = pure style stats)
    alpha = 1.0

    # Resume (leave None for fresh runs)
    resume = None

cfg = Cfg()
os.makedirs(cfg.out_dir, exist_ok=True)

# Helper: turn a Cfg instance (class-level attrs) into a dict
def cfg_to_dict(obj):
    out = {}
    for k in dir(obj):
        if k.startswith("_"):
            continue
        v = getattr(obj, k)
        if callable(v):
            continue
        out[k] = v
    return out

print("CFG:", cfg_to_dict(cfg))

# -----------------------------
# Basic image helpers
# -----------------------------
IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def count_imgs(p):
    p = Path(p)
    return sum(1 for f in p.rglob("*") if f.suffix.lower() in IMG_EXTS)

def first_img(p):
    p = Path(p)
    for f in p.rglob("*"):
        if f.suffix.lower() in IMG_EXTS:
            return str(f)
    return None

print("content_dir exists:", os.path.isdir(cfg.content_dir), "n_images:", count_imgs(cfg.content_dir) if os.path.isdir(cfg.content_dir) else 0)
print("style_dir   exists:", os.path.isdir(cfg.style_dir),   "n_images:", count_imgs(cfg.style_dir)   if os.path.isdir(cfg.style_dir)   else 0)
print("example content:", first_img(cfg.content_dir))
print("example style  :", first_img(cfg.style_dir))

# -----------------------------
# Dataset
# -----------------------------
class ImageFolderFlat(Dataset):
    def __init__(self, root, resize_shorter_to=512, crop_size=256):
        self.paths = []
        for p in sorted(Path(root).rglob("*")):
            if p.suffix.lower() in IMG_EXTS:
                self.paths.append(str(p))
        if not self.paths:
            raise RuntimeError(f"No images found under {root}")

        self.transform = T.Compose([
            T.Lambda(lambda im: im.convert("RGB")),
            T.Resize(resize_shorter_to, interpolation=Image.BICUBIC),
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        img = Image.open(p)
        return self.transform(img)

class PairDataset(Dataset):
    """
    Iterates over content images; for each content item, pick a random style item.
    """
    def __init__(self, content_root, style_root, resize_shorter_to=512, crop_size=256):
        self.content_ds = ImageFolderFlat(content_root, resize_shorter_to, crop_size)
        self.style_ds   = ImageFolderFlat(style_root,   resize_shorter_to, crop_size)
        self.style_len  = len(self.style_ds)

    def __len__(self):
        return len(self.content_ds)

    def __getitem__(self, idx):
        content_img = self.content_ds[idx]
        style_img   = self.style_ds[random.randrange(self.style_len)]
        return content_img, style_img

# Build dataloader once; reused across sweeps
train_ds = PairDataset(cfg.content_dir, cfg.style_dir, cfg.resize_shorter_to, cfg.image_size_crop)
train_loader = DataLoader(
    train_ds,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    pin_memory=(device.type == 'cuda'),
    drop_last=True
)
print("Dataset size:", len(train_ds), "Batches/epoch (approx):", len(train_loader))

# -----------------------------
# VGG Encoder & AdaIN
# -----------------------------
# Map torchvision VGG19 feature indices to friendly names
LAYER_NAME_MAP = {
    1 : 'relu1_1',
    6 : 'relu2_1',
    11: 'relu3_1',
    20: 'relu4_1',
}

STYLE_LAYERS = ['relu1_1','relu2_1','relu3_1','relu4_1']
CONTENT_LAYER = 'relu4_1'

def calc_mean_std(feat, eps=1e-5):
    # feat: (B, C, H, W)
    B, C = feat.size()[:2]
    feat_var = feat.view(B, C, -1).var(dim=2, unbiased=False) + eps
    feat_std = feat_var.sqrt().view(B, C, 1, 1)
    feat_mean = feat.view(B, C, -1).mean(dim=2).view(B, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat, eps=1e-5):
    # Channel-wise align mean & std of content to style
    c_mean, c_std = calc_mean_std(content_feat, eps)
    s_mean, s_std = calc_mean_std(style_feat, eps)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

class VGGEncoder(nn.Module):
    """
    VGG-19 (imagenet) up to relu4_1. Returns dict of selected layer activations.
    """
    def __init__(self):
        super().__init__()
        try:
            self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except Exception:
            self.vgg = models.vgg19(pretrained=True).features
        for p in self.vgg.parameters():
            p.requires_grad_(False)

    def forward(self, x, out_keys=('relu1_1','relu2_1','relu3_1','relu4_1')):
        feats = {}
        h = x
        for i, layer in enumerate(self.vgg):
            h = layer(h)
            name = LAYER_NAME_MAP.get(i, None)
            if name in out_keys:
                feats[name] = h
            if i >= 20:  # after relu4_1
                pass
        return feats

# -----------------------------
# Decoder (mirror-ish of VGG)
# -----------------------------
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            # relu4_1 -> block3
            nn.ReflectionPad2d(1), nn.Conv2d(512, 256, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 128, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 128, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 64, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 3, 3)
        )

    def forward(self, x):
        return self.body(x)

# -----------------------------
# Loss wrapper & style loss
# -----------------------------
class LossNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, x, keys=None):
        return self.encoder(x, out_keys=tuple(keys) if keys else tuple(STYLE_LAYERS))

def mean_std_loss(x_feats, y_feats):
    """
    IN-statistics style loss: sum over layers of (mean MSE + std MSE)
    """
    loss = 0.0
    for k in STYLE_LAYERS:
        xm, xs = calc_mean_std(x_feats[k])
        ym, ys = calc_mean_std(y_feats[k])
        loss = loss + F.mse_loss(xm, ym) + F.mse_loss(xs, ys)
    return loss

# -----------------------------
# Utility: (de)normalization & saving
# -----------------------------
@torch.no_grad()
def denorm_for_save(x):
    # x is normalized (ImageNet), bring it back to [0,1] for saving
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
    y = x * std + mean
    return torch.clamp(y, 0, 1)

# -----------------------------
# Training: one run with overrides  (FIXED cfg copy)
# -----------------------------
def train_once(cfg_base, overrides):
    # Build a dict from cfg_base (even if class-level attrs), then apply overrides
    base_dict = cfg_to_dict(cfg_base)
    base_dict.update(overrides)
    run_cfg = SimpleNamespace(**base_dict)

    # Unique run folder
    tag = f"c{run_cfg.lambda_content:g}_s{run_cfg.lambda_style:g}_a{run_cfg.alpha:g}_it{run_cfg.max_iterations}"
    run_out = os.path.join(run_cfg.out_dir, tag)
    os.makedirs(run_out, exist_ok=True)

    # Fresh model & optimizer
    decoder = Decoder().to(device)
    lossnet = LossNet().to(device).eval()
    optimizer = torch.optim.Adam(decoder.parameters(), lr=run_cfg.lr)

    # Resume (decoder-only resume supported; optional)
    start_iter = 0
    if run_cfg.resume:
        data = torch.load(run_cfg.resume, map_location=device)
        decoder.load_state_dict(data['decoder'], strict=True)
        if 'optimizer' in data:
            optimizer.load_state_dict(data['optimizer'])
        start_iter = data.get('iteration', 0)
        print(f"[resume] loaded {run_cfg.resume} @ iter {start_iter}")

    global_iter = start_iter
    data_iter = iter(train_loader)
    decoder.train()

    print(f"\n=== RUN {tag} ===")
    print({k: getattr(run_cfg, k) for k in ['lambda_content','lambda_style','alpha','max_iterations','save_every','log_every']})

    while global_iter < run_cfg.max_iterations:
        try:
            content, style = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            content, style = next(data_iter)

        content = content.to(device, non_blocking=True)
        style   = style.to(device, non_blocking=True)

        # --- Targets ---
        with torch.no_grad():
            c4 = lossnet.encoder(content, out_keys=['relu4_1'])['relu4_1']
            s4 = lossnet.encoder(style,   out_keys=['relu4_1'])['relu4_1']
        t_adain = adain(c4, s4)
        t_feat  = run_cfg.alpha * t_adain + (1.0 - run_cfg.alpha) * c4  # blend with content features

        # --- Decode ---
        g_img = decoder(t_feat)
        g_img_clamped = torch.clamp(g_img, -3.0, 3.0)

        # --- Re-encode generated (with grad) ---
        g_feats_all = lossnet.encoder(g_img_clamped, out_keys=STYLE_LAYERS)

        # --- Style targets (no grad) ---
        with torch.no_grad():
            s_feats_all = lossnet.encoder(style, out_keys=STYLE_LAYERS)

        # --- Losses ---
        loss_c = F.mse_loss(g_feats_all[CONTENT_LAYER], t_feat)
        loss_s = mean_std_loss(g_feats_all, s_feats_all)
        loss   = run_cfg.lambda_content * loss_c + run_cfg.lambda_style * loss_s

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        global_iter += 1

        if global_iter % run_cfg.log_every == 0:
            print(f"[{global_iter:>6d}/{run_cfg.max_iterations}] "
                  f"loss={loss.item():.4f}  Lc={loss_c.item():.4f}  Ls={loss_s.item():.4f}")

        if global_iter % run_cfg.save_every == 0:
            with torch.no_grad():
                c_show = denorm_for_save(content[:2])
                s_show = denorm_for_save(style[:2])
                g_show = denorm_for_save(g_img_clamped[:2])
                samples = torch.cat([c_show, s_show, g_show], dim=0)
            grid = utils.make_grid(samples, nrow=2)
            grid_path = os.path.join(run_out, f"samples_iter_{global_iter}.png")
            utils.save_image(grid, grid_path)

            ckpt = {
                'iteration': global_iter,
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'cfg': vars(run_cfg),
                'time': time.time(),
            }
            ckpt_path = os.path.join(run_out, f"decoder_iter4_{global_iter}.pth")
            torch.save(ckpt, ckpt_path)
            try:
                os.sync()
            except Exception:
                pass

    # Final save
    final_path = os.path.join(run_out, "decoder_final4.pth")
    torch.save({'iteration': global_iter, 'decoder': decoder.state_dict(), 'cfg': vars(run_cfg)}, final_path)
    print(f"[final] saved: {final_path}")

    return {
        "tag": tag,
        "final_iteration": global_iter,
        "out_dir": run_out,
        "lambda_content": run_cfg.lambda_content,
        "lambda_style": run_cfg.lambda_style,
        "alpha": run_cfg.alpha
    }

# -----------------------------
# Evaluation: stylize fixed pairs with alpha sweep (after training)
# -----------------------------
@torch.no_grad()
def stylize_pairs(ckpt_path, pairs, alpha_vals=(1.0,), save_path="eval_grid.png"):
    # Load trained decoder
    decoder = Decoder().to(device).eval()
    data = torch.load(ckpt_path, map_location=device)
    decoder.load_state_dict(data['decoder'], strict=True)

    enc = VGGEncoder().to(device).eval()

    to_norm = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    to_tensor = T.Compose([
        T.Lambda(lambda im: im.convert("RGB")),
        T.Resize(512, interpolation=Image.BICUBIC),
        T.CenterCrop(256),
        T.ToTensor(),
        to_norm
    ])

    rows = []
    for c_path, s_path in pairs:
        c = to_tensor(Image.open(c_path)).unsqueeze(0).to(device)
        s = to_tensor(Image.open(s_path)).unsqueeze(0).to(device)
        c4 = enc(c, out_keys=['relu4_1'])['relu4_1']
        s4 = enc(s, out_keys=['relu4_1'])['relu4_1']
        t  = adain(c4, s4)
        row_imgs = [denorm_for_save(c), denorm_for_save(s)]
        for a in alpha_vals:
            t_blend = a * t + (1.0 - a) * c4
            y = decoder(t_blend).clamp(-3,3)
            row_imgs.append(denorm_for_save(y))
        rows.append(torch.cat(row_imgs, dim=0))

    grid = utils.make_grid(torch.cat(rows, dim=0), nrow=2 + len(alpha_vals))
    utils.save_image(grid, save_path)
    print("Saved eval grid:", save_path)

# -----------------------------
# Sweep definition & execution
# -----------------------------
if __name__ == "__main__":
    # Define a first-pass sweep over style weights; keep λc = 1, α = 1
    sweep = []
    for lam_s in [0.5, 1.0, 2.0, 5.0, 10.0, 20.0]:
        sweep.append({"lambda_content": 1.0, "lambda_style": lam_s, "alpha": 1.0, "max_iterations": 10_000})

    # Optional: vary content weight a bit (still α = 1)
    for lam_c in [0.5, 1.0, 2.0]:
        sweep.append({"lambda_content": lam_c, "lambda_style": 5.0, "alpha": 1.0, "max_iterations": 10_000})

    results = []
    for overrides in sweep:
        results.append(train_once(cfg, overrides))

    # Save a summary JSON
    with open(os.path.join(cfg.out_dir, "sweep_results.json"), "w") as f:
        json.dump(results, f, indent=2)

    # Example: quick evaluation grid for one trained checkpoint (edit paths if needed)
    try:
        ex_content = first_img(cfg.content_dir)
        ex_style   = first_img(cfg.style_dir)
        if ex_content and ex_style:
            ckpt_example = Path(cfg.out_dir) / "c1_s5_a1_it10000" / "decoder_final.pth"
            if ckpt_example.exists():
                stylize_pairs(
                    str(ckpt_example),
                    pairs=[(ex_content, ex_style)],           # add more pairs here
                    alpha_vals=(0.3, 0.6, 1.0),
                    save_path=str(Path(cfg.out_dir) / "eval_s5.png")
                )
    except Exception as e:
        print("Eval example skipped:", e)


Device: cuda
CFG: {'alpha': 1.0, 'batch_size': 8, 'content_dir': '/workspace/content', 'data_root': '/workspace', 'image_size_crop': 256, 'lambda_content': 0.8, 'lambda_style': 58.0, 'log_every': 200, 'lr': 0.0001, 'max_iterations': 80000, 'num_workers': 2, 'out_dir': './Folder/adain_runs4', 'resize_shorter_to': 512, 'resume': None, 'save_every': 5000, 'style_dir': '/workspace/style'}
content_dir exists: True n_images: 49981
style_dir   exists: True n_images: 49981
example content: /workspace/content/000000000045.jpg
example style  : /workspace/style/1.jpg
Dataset size: 49981 Batches/epoch (approx): 6247

=== RUN c1_s0.5_a1_it10000 ===
{'lambda_content': 1.0, 'lambda_style': 0.5, 'alpha': 1.0, 'max_iterations': 10000, 'save_every': 5000, 'log_every': 200}
[   200/10000] loss=10.6290  Lc=8.4794  Ls=4.2993
[   400/10000] loss=7.4365  Lc=5.9102  Ls=3.0526
[   600/10000] loss=9.6552  Lc=7.7455  Ls=3.8193
[   800/10000] loss=9.7455  Lc=7.9632  Ls=3.5646
[  1000/10000] loss=9.5849  Lc=7.8290

In [6]:
# ===== AdaIN inference (no crop, preserve exact size) — self-contained =====
import os, torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from torchvision import models, utils

# ---------------- Device ----------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Inference device:", device)

# ---------------- Core helpers ----------------
def calc_mean_std(feat, eps=1e-5):
    B, C = feat.size()[:2]
    var = feat.view(B, C, -1).var(dim=2, unbiased=False) + eps
    std = var.sqrt().view(B, C, 1, 1)
    mean = feat.view(B, C, -1).mean(dim=2).view(B, C, 1, 1)
    return mean, std

def adain(content_feat, style_feat, eps=1e-5):
    c_mean, c_std = calc_mean_std(content_feat, eps)
    s_mean, s_std = calc_mean_std(style_feat, eps)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

@torch.no_grad()
def denorm_for_save(x):
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
    y = x * std + mean
    return torch.clamp(y, 0, 1)

# ---------------- VGG encoder up to relu4_1 ----------------
LAYER_NAME_MAP = {1:'relu1_1', 6:'relu2_1', 11:'relu3_1', 20:'relu4_1'}
class VGGEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        try:
            self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        except Exception:
            self.vgg = models.vgg19(pretrained=True).features
        for p in self.vgg.parameters():
            p.requires_grad_(False)
    def forward(self, x, out_keys=('relu1_1','relu2_1','relu3_1','relu4_1')):
        feats, h = {}, x
        for i, layer in enumerate(self.vgg):
            h = layer(h)
            name = LAYER_NAME_MAP.get(i, None)
            if name in out_keys:
                feats[name] = h
        return feats

# ---------------- Decoder (must match training) ----------------
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.ReflectionPad2d(1), nn.Conv2d(512, 256, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 128, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 128, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 64, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 3, 3)
        )
    def forward(self, x): return self.body(x)

# ---------------- Robust checkpoint loader ----------------
def _strip_module_prefix(sd):
    if not any(k.startswith('module.') for k in sd.keys()):
        return sd
    return {k.replace('module.', '', 1): v for k,v in sd.items()}

def load_decoder_from_ckpt(ckpt_path, device):
    assert os.path.isfile(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    dec = Decoder().to(device).eval()
    data = torch.load(ckpt_path, map_location=device)
    last_err = None
    if isinstance(data, dict):
        for key in ('decoder','state_dict','model'):
            if key in data and isinstance(data[key], dict):
                try:
                    dec.load_state_dict(_strip_module_prefix(data[key]), strict=True)
                    return dec
                except Exception as e:
                    last_err = e
        # raw state_dict?
        try:
            dec.load_state_dict(_strip_module_prefix(data), strict=True)
            return dec
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Could not load decoder weights from {ckpt_path}. Last error: {last_err}")

# ---------------- Inference that preserves exact size ----------------
@torch.no_grad()
def stylize_preserve_size(
    ckpt_path: str,
    content_path: str,
    style_path: str,
    out_path: str = "stylized_preserve.png",
    alpha: float = 1.0,
    max_long_side: int | None = None,   # None = keep original; or e.g. 1024 to downscale for memory
    style_max_long_side: int = 512      # style can be smaller
):
    assert os.path.isfile(content_path), f"Content not found: {content_path}"
    assert os.path.isfile(style_path),   f"Style not found: {style_path}"
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    dec = load_decoder_from_ckpt(ckpt_path, device)
    enc = VGGEncoder().to(device).eval()

    # No crop transforms
    to_norm = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    def load_no_crop(path, limit=None):
        im = Image.open(path).convert("RGB")
        if limit is not None:
            w, h = im.size
            scale = min(1.0, limit / max(w, h))  # downscale only
            if scale < 1.0:
                im = im.resize((round(w*scale), round(h*scale)), Image.BICUBIC)
        t = T.ToTensor()(im)
        return to_norm(t), im.size  # (C,H,W), (W,H)

    c_t, (cw, ch) = load_no_crop(content_path, max_long_side)      # (C,H,W), (W,H)
    s_t, _        = load_no_crop(style_path,   style_max_long_side)

    c = c_t.unsqueeze(0).to(device)  # [1,3,H,W]
    s = s_t.unsqueeze(0).to(device)

    # Pad to multiples of 8 so decode returns same spatial size
    H, W = c.shape[-2:]
    mul = 8
    newH = ((H + mul - 1) // mul) * mul
    newW = ((W + mul - 1) // mul) * mul
    pad_top    = (newH - H) // 2
    pad_bottom = newH - H - pad_top
    pad_left   = (newW - W) // 2
    pad_right  = newW - W - pad_left
    if newH != H or newW != W:
        c = F.pad(c, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')

    # Encode, AdaIN, blend, decode
    c4 = enc(c, out_keys=['relu4_1'])['relu4_1']
    s4 = enc(s, out_keys=['relu4_1'])['relu4_1']
    t  = adain(c4, s4)
    t_blend = alpha * t + (1.0 - alpha) * c4
    y = dec(t_blend).clamp(-3, 3)   # [1,3,newH,newW]

    # Remove padding back to original size
    if newH != H or newW != W:
        y = y[..., pad_top:pad_top+H, pad_left:pad_left+W]

    y_vis = denorm_for_save(y)[0].cpu()
    utils.save_image(y_vis, out_path)
    print(f"Saved → {out_path} | original (HxW): {(ch, cw)} | output (HxW): {tuple(y_vis.shape[-2:])}")
    return out_path

# ---------------- Example call (EDIT THESE PATHS) ----------------
ckpt = "/workspace/Folder/adain_runs4/c2_s5_a1_it10000/decoder_final4.pth"  # ensure the filename matches your saved model
content = "/workspace/content test2.jpg"
style   = "/workspace/Van_Gogh_-_Starry_Night.jpg"
stylize_preserve_size(ckpt, content, style, out_path="stylized_full_exact3.png", alpha=1.0, max_long_side=None)


Inference device: cuda
Saved → stylized_full_exact3.png | original (HxW): (480, 910) | output (HxW): (480, 910)


'stylized_full_exact3.png'