In [1]:
import sys, os, platform, subprocess, shutil
print("Python:", sys.version)
print("Executable:", sys.executable)
print("Platform:", platform.platform())
print("Conda env:", os.environ.get("CONDA_DEFAULT_ENV"))
print("Has nvidia-smi?", shutil.which("nvidia-smi") is not None)
if shutil.which("nvidia-smi"):
    print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)


Python: 3.12.11 | packaged by conda-forge | (main, Jun  4 2025, 14:45:31) [GCC 13.3.0]
Executable: /venv/main/bin/python
Platform: Linux-6.8.0-60-generic-x86_64-with-glibc2.39
Conda env: None
Has nvidia-smi? True
Thu Aug 21 01:52:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.153.02             Driver Version: 570.153.02     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5060 Ti     On  |   00000000:03:00.0 Off |                  N/A |
|  0%   38C    P8              7W /  180W |       0MiB /  16311MiB |      0%      Default |
|                  

In [2]:
# --- cell 0: setup ---
import os, math, random, time, json, glob
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
from torchvision import models, utils

# Force CPU for now so everything runs
USE_CPU_FOR_NOW = True

import torch
device = torch.device('cpu' if USE_CPU_FOR_NOW else ('cuda' if torch.cuda.is_available() else 'cpu'))
print("Device:", device)


# Reproducibility
torch.manual_seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cpu
Device: cuda


In [3]:
import sys, subprocess
print("Python exe:", sys.executable)
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout.splitlines()[:5])


Python exe: /venv/main/bin/python
['Thu Aug 21 01:52:50 2025       ', '+-----------------------------------------------------------------------------------------+', '| NVIDIA-SMI 570.153.02             Driver Version: 570.153.02     CUDA Version: 12.8     |', '|-----------------------------------------+------------------------+----------------------+', '| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |']


In [47]:
import sys, subprocess

# Remove mismatched installs
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "torch", "torchvision", "torchaudio"])

# Install the CUDA 12.8 build (stable)
subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "pip"])
subprocess.check_call([sys.executable, "-m", "pip", "install",
                      "torch", "torchvision", "torchaudio",
                      "--index-url", "https://download.pytorch.org/whl/cu128"])


Found existing installation: torch 2.8.0+cu128
Uninstalling torch-2.8.0+cu128:
  Successfully uninstalled torch-2.8.0+cu128
Found existing installation: torchvision 0.23.0+cu128
Uninstalling torchvision-0.23.0+cu128:
  Successfully uninstalled torchvision-0.23.0+cu128
Found existing installation: torchaudio 2.8.0+cu128
Uninstalling torchaudio-2.8.0+cu128:
  Successfully uninstalled torchaudio-2.8.0+cu128


[0m



[0m

Looking in indexes: https://download.pytorch.org/whl/cu128
Collecting torch
  Using cached https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu128/torchvision-0.23.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/cu128/torchaudio-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (7.2 kB)
Downloading https://download.pytorch.org/whl/cu128/torch-2.8.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl (889.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m889.0/889.0 MB[0m [31m64.1 MB/s[0m  [33m0:00:12[0mm0:00:01[0m00:01[0m
[?25hUsing cached https://download.pytorch.org/whl/cu128/torchvision-0.23.0%2Bcu128-cp312-cp312-manylinux_2_28_x86_64.whl (8.6 MB)
Using cached https://download.pytorch.org/whl/cu128/torchaudio-2.8.0%2Bcu128-cp312-cp31

[0m

0

In [3]:
# --- cell 1: config ---
class Cfg:
    # paths
    data_root = str((Path.cwd() / "/workspace").resolve())
    content_dir = os.path.join(data_root, "content")   
    style_dir   = os.path.join(data_root, "style")     
    
    out_dir     = "./adain_runs"                # where to save checkpoints & samples
    # os.makedirs(out_dir, exist_ok=True)

    # training
    image_size_crop = 256        # 256x256 training crops (per AdaIN paper)
    resize_shorter_to = 512      # resize shorter side -> 512, then random crop 256
    batch_size = 8               # adjust to your GPU
    num_workers = 2
    lr = 1e-4                    # Adam LR used commonly for AdaIN decoder training
    max_iterations = 160_000     # total training iterations (common range: 80k-160k+)
    save_every = 2000            # save checkpoint every N iterations (your requirement)
    log_every = 200              # print losses every N iterations

    # loss weights (L = Lc + lambda_style * Ls), paper uses IN-statistics style loss
    lambda_style = 10.0

    # resume (set to checkpoint path if resuming)
    resume = None

cfg = Cfg()
print(cfg.__dict__)


{}


In [5]:
# Sanity check for Step 1 (Config)

from pathlib import Path
import os

print("Device OK? ->", device)

print("data_root exists:", os.path.isdir(cfg.data_root))
print("content_dir exists:", os.path.isdir(cfg.content_dir))
print("style_dir exists:", os.path.isdir(cfg.style_dir))

IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

def count_imgs(p):
    p = Path(p)
    return sum(1 for f in p.rglob("*") if f.suffix.lower() in IMG_EXTS)

def first_img(p):
    p = Path(p)
    for f in p.rglob("*"):
        if f.suffix.lower() in IMG_EXTS:
            return f
    return None

n_content = count_imgs(cfg.content_dir) if os.path.isdir(cfg.content_dir) else 0
n_style   = count_imgs(cfg.style_dir)   if os.path.isdir(cfg.style_dir)   else 0

print("content images:", n_content)
print("style images:", n_style)
print("example content:", first_img(cfg.content_dir))
print("example style:", first_img(cfg.style_dir))


Device OK? -> cuda
data_root exists: True
content_dir exists: True
style_dir exists: True
content images: 49981
style images: 49981
example content: /workspace/content/000000000045.jpg
example style: /workspace/style/1.jpg


In [6]:
# --- cell 2: dataset ---
IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

class ImageFolderFlat(Dataset):
    def __init__(self, root, resize_shorter_to=512, crop_size=256):
        self.paths = []
        for p in sorted(Path(root).rglob("*")):
            if p.suffix.lower() in IMG_EXTS:
                self.paths.append(str(p))
        if not self.paths:
            raise RuntimeError(f"No images found under {root}")

        self.transform = T.Compose([
            T.Lambda(lambda im: im.convert("RGB")),
            T.Resize(resize_shorter_to, interpolation=Image.BICUBIC),
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        img = Image.open(p)
        return self.transform(img)

class PairDataset(Dataset):
    """
    Iterates over content images; for each content item, pick a random style item.
    """
    def __init__(self, content_root, style_root, resize_shorter_to=512, crop_size=256):
        self.content_ds = ImageFolderFlat(content_root, resize_shorter_to, crop_size)
        self.style_ds   = ImageFolderFlat(style_root,   resize_shorter_to, crop_size)
        self.style_len  = len(self.style_ds)

    def __len__(self):
        return len(self.content_ds)

    def __getitem__(self, idx):
        content_img = self.content_ds[idx]
        style_img   = self.style_ds[random.randrange(self.style_len)]
        return content_img, style_img

train_ds = PairDataset(cfg.content_dir, cfg.style_dir, cfg.resize_shorter_to, cfg.image_size_crop)
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True, drop_last=True)
len(train_ds), len(train_loader)


(49981, 6247)

In [7]:
# --- cell 3: vgg encoder & helpers ---
def calc_mean_std(feat, eps=1e-5):
    # feat: (B, C, H, W)
    B, C = feat.size()[:2]
    feat_var = feat.view(B, C, -1).var(dim=2, unbiased=False) + eps
    feat_std = feat_var.sqrt().view(B, C, 1, 1)
    feat_mean = feat.view(B, C, -1).mean(dim=2).view(B, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat, eps=1e-5):
    # Channel-wise align mean & std of content to style
    size = content_feat.size()
    c_mean, c_std = calc_mean_std(content_feat, eps)
    s_mean, s_std = calc_mean_std(style_feat, eps)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

# Map torchvision VGG19 feature indices to friendly names
VGG_FEATURES = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

LAYER_NAME_MAP = {
    1 : 'relu1_1',
    6 : 'relu2_1',
    11: 'relu3_1',
    20: 'relu4_1',
}

STYLE_LAYERS = ['relu1_1','relu2_1','relu3_1','relu4_1']
CONTENT_LAYER = 'relu4_1'  # for AdaIN content loss

class VGGEncoder(nn.Module):
    """
    VGG-19 (imagenet) up to relu4_1. Returns a dict of selected layer activations.
    """
    def __init__(self):
        super().__init__()
        self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        # Freeze params
        for p in self.vgg.parameters():
            p.requires_grad_(False)

    def forward(self, x, out_keys=('relu1_1','relu2_1','relu3_1','relu4_1')):
        feats = {}
        h = x
        for i, layer in enumerate(self.vgg):
            h = layer(h)
            name = LAYER_NAME_MAP.get(i, None)
            if name in out_keys:
                feats[name] = h
            if i >= 20:  # up to relu4_1
                # We still finish this iteration to capture relu4_1
                pass
        return feats

# Decoder mirrors encoder with nearest upsample, reflection padding, no norm layers
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []

        def block(in_c, out_c, k=3, s=1, pad='reflect', act=True):
            if pad == 'reflect':
                layers.append(nn.ReflectionPad2d((k-1)//2))
                layers.append(nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=0))
            else:
                layers.append(nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=(k-1)//2))
            if act:
                layers.append(nn.ReLU(inplace=True))

        # Start from relu4_1 feature space (512 channels)
        # Mirror path back to image (3 channels)
        # A common decoder matching VGG blocks (no norm, reflection pads, NN upsample)
        self.body = nn.Sequential(
            # relu4_1 -> block3
            nn.ReflectionPad2d(1), nn.Conv2d(512, 256, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 128, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 128, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 64, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 3, 3)  # last conv, no activation
        )

    def forward(self, x):
        return self.body(x)

# Loss network wrapper using VGG-19 (fixed)
class LossNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()
        self.encoder.eval()
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def encode(self, x, keys=None):
        return self.encoder(x, out_keys=tuple(keys) if keys else tuple(STYLE_LAYERS))

    def forward(self, x):
        # not used directly
        return self.encode(x)

def mean_std_loss(x_feats, y_feats):
    """
    IN-statistics loss: sum over layers of (mean MSE + std MSE)
    """
    loss = 0.0
    for k in STYLE_LAYERS:
        xm, xs = calc_mean_std(x_feats[k])
        ym, ys = calc_mean_std(y_feats[k])
        loss = loss + F.mse_loss(xm, ym) + F.mse_loss(xs, ys)
    return loss


In [8]:
import torch, subprocess, sys, os
print("torch:", torch.__version__)
print("built CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("arch list (compiled into this wheel):", torch.cuda.get_arch_list())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    p = torch.cuda.get_device_properties(0)
    print("compute capability: sm_{}{}".format(p.major, p.minor))
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout[:400])


torch: 2.8.0+cu128
built CUDA: 12.8
CUDA available: True
arch list (compiled into this wheel): ['sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'sm_100', 'sm_120']
GPU: NVIDIA GeForce RTX 5060 Ti
compute capability: sm_120
Thu Aug 21 01:53:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.153.02             Driver Version: 570.153.02     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |



In [9]:
enc = VGGEncoder().to(device).eval()
dec = Decoder().to(device).eval()
lossnet = LossNet().to(device).eval()

import torch.nn.functional as F
x_content = torch.randn(2, 3, 256, 256, device=device)
x_style   = torch.randn(2, 3, 256, 256, device=device)

with torch.no_grad():
    c4 = enc(x_content, out_keys=['relu4_1'])['relu4_1']
    s4 = enc(x_style,   out_keys=['relu4_1'])['relu4_1']
    t  = adain(c4, s4)
    y  = dec(t).clamp(-3, 3)
    fy = lossnet.encode(y, keys=STYLE_LAYERS)
    fs = lossnet.encode(x_style, keys=STYLE_LAYERS)
    Lc = F.mse_loss(fy[CONTENT_LAYER], t).item()
    Ls = mean_std_loss(fy, fs).item()

print("Output image shape:", tuple(y.shape))    # expect (2, 3, 256, 256)
print("Content loss:", Lc)                      # positive float
print("Style loss:", Ls)                        # positive float


Output image shape: (2, 3, 256, 256)
Content loss: 6.061115264892578
Style loss: 47.425384521484375


In [11]:
# --- cell 4: training (fixed) ---
import torch.nn.functional as F

@torch.no_grad()
def denorm_for_save(x):
    # x is normalized (ImageNet), bring it back to [0,1] for saving
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
    y = x * std + mean
    return torch.clamp(y, 0, 1)

def save_checkpoint(iteration, decoder, opt, cfg, extra_samples=None):
    ckpt = {
        'iteration': iteration,
        'decoder': decoder.state_dict(),
        'optimizer': opt.state_dict(),
        'cfg': cfg.__dict__,
        'time': time.time(),
    }
    path = os.path.join(cfg.out_dir, f"decoder_iter_{iteration}.pth")
    torch.save(ckpt, path)
    print(f"[ckpt] saved: {path}")

    # save a grid of sample images if provided (for quick visual check)
    if extra_samples is not None:
        grid = utils.make_grid(extra_samples, nrow=len(extra_samples)//2 or 4)
        utils.save_image(grid, os.path.join(cfg.out_dir, f"samples_iter_{iteration}.png"))

def load_checkpoint(path, decoder, opt=None):
    data = torch.load(path, map_location='cpu')
    decoder.load_state_dict(data['decoder'], strict=True)
    if opt is not None and 'optimizer' in data:
        opt.load_state_dict(data['optimizer'])
    print(f"[ckpt] loaded: {path} @ iter={data.get('iteration')}")
    return data.get('iteration', 0)

# Instantiate nets
decoder = Decoder().to(device)
lossnet = LossNet().to(device).eval()  # VGG features (frozen)

optimizer = torch.optim.Adam(decoder.parameters(), lr=cfg.lr)

start_iter = 0
if cfg.resume:
    start_iter = load_checkpoint(cfg.resume, decoder, optimizer)

# Training
decoder.train()
global_iter = start_iter

# Simple infinite loader over epochs until max_iterations
data_iter = iter(train_loader)

print("Starting training...")
while global_iter < cfg.max_iterations:
    try:
        content, style = next(data_iter)
    except StopIteration:
        data_iter = iter(train_loader)
        content, style = next(data_iter)

    content = content.to(device, non_blocking=True)
    style   = style.to(device, non_blocking=True)

    # --- Encode targets (constants for this step) ---
    # Use no_grad ONLY for target encodes; keep graph for generated encodes.
    with torch.no_grad():
        c4 = lossnet.encoder(content, out_keys=['relu4_1'])['relu4_1']
        s4 = lossnet.encoder(style,   out_keys=['relu4_1'])['relu4_1']

    # AdaIN target feature (no grad needed; serves as fixed target)
    t_feat = adain(c4, s4)

    # --- Decode: we need gradients through the decoder ---
    g_img = decoder(t_feat)                      # (B,3,H,W) in "normalized" space
    g_img_clamped = torch.clamp(g_img, -3.0, 3.0)  # keep stability; retains grad

    # --- Re-encode generated for losses (WITH grad) ---
    g_feats_all = lossnet.encoder(g_img_clamped, out_keys=STYLE_LAYERS)

    # Style features for targets (no grad)
    with torch.no_grad():
        s_feats_all = lossnet.encoder(style, out_keys=STYLE_LAYERS)

    # --- Losses ---
    # Content loss: || f(g(t))_relu4_1 - t ||^2
    loss_c = F.mse_loss(g_feats_all[CONTENT_LAYER], t_feat)
    # Style loss: mean/std over multiple layers
    loss_s = mean_std_loss(g_feats_all, s_feats_all)
    loss   = loss_c + cfg.lambda_style * loss_s

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    global_iter += 1

    # Optional tiny grad sanity after first update
    if global_iter == start_iter + 1:
        gmean = sum((p.grad.abs().mean().item() for p in decoder.parameters() if p.grad is not None))
        print("Grad sanity (decoder mean |grad|):", f"{gmean:.6f}")

    if global_iter % cfg.log_every == 0:
        print(f"[{global_iter:>6d}/{cfg.max_iterations}] "
              f"loss={loss.item():.4f}  Lc={loss_c.item():.4f}  Ls={loss_s.item():.4f}")

    if global_iter % cfg.save_every == 0:
        # Save checkpoint + a quick sample grid (content, style, stylized)
        with torch.no_grad():
            c_show = denorm_for_save(content[:2])
            s_show = denorm_for_save(style[:2])
            g_show = denorm_for_save(g_img_clamped[:2])
            samples = torch.cat([c_show, s_show, g_show], dim=0)
        save_checkpoint(global_iter, decoder, optimizer, cfg, extra_samples=samples)

# Final save
final_path = os.path.join(cfg.out_dir, "decoder_final.pth")
torch.save({'iteration': global_iter, 'decoder': decoder.state_dict(), 'cfg': cfg.__dict__}, final_path)
print(f"[final] saved: {final_path}")



Starting training...
Grad sanity (decoder mean |grad|): 15.536363
[   200/160000] loss=52.2870  Lc=12.9208  Ls=3.9366
[   400/160000] loss=40.1096  Lc=13.9661  Ls=2.6143
[   600/160000] loss=29.9711  Lc=11.2713  Ls=1.8700
[   800/160000] loss=30.3923  Lc=14.3270  Ls=1.6065
[  1000/160000] loss=24.5743  Lc=11.6843  Ls=1.2890
[  1200/160000] loss=29.3277  Lc=10.0298  Ls=1.9298
[  1400/160000] loss=28.0367  Lc=11.8288  Ls=1.6208
[  1600/160000] loss=26.0955  Lc=10.5608  Ls=1.5535
[  1800/160000] loss=27.2486  Lc=12.2926  Ls=1.4956
[  2000/160000] loss=55.4214  Lc=16.6493  Ls=3.8772
[ckpt] saved: ./adain_runs/decoder_iter_2000.pth
[  2200/160000] loss=32.3475  Lc=13.9576  Ls=1.8390
[  2400/160000] loss=19.9653  Lc=8.1332  Ls=1.1832
[  2600/160000] loss=39.2639  Lc=14.6628  Ls=2.4601
[  2800/160000] loss=26.8605  Lc=10.9935  Ls=1.5867
[  3000/160000] loss=22.3445  Lc=10.0339  Ls=1.2311
[  3200/160000] loss=20.4982  Lc=10.3818  Ls=1.0116
[  3400/160000] loss=20.3055  Lc=8.3495  Ls=1.1956
[  

In [2]:
# --- Self-contained AdaIN inference cell (defines everything + runs test) ---

import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, utils
from PIL import Image
import torchvision.transforms as T

# ------------------------ CONFIG (edit these) ------------------------
content_path_in = "content test2.jpg"   # your content image
style_path_in   = "Van_Gogh_-_Starry_Night.jpg"     # your style image
preferred_ckpt  = "/Test_image/adain_runs/decoder_iter_160000.pth"  # try this first
fallback_dirs   = ["adain_runs", "Test_image/adain_runs"]          # auto-search if needed
out_path        = "Test_image/adain_runs/test_stylized3.png"
alpha           = 1.0      # 0..1 (0 more content, 1 more style)
max_side        = 512      # resize shorter side to this
# --------------------------------------------------------------------

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

# --- helpers: normalization / AdaIN ---
def calc_mean_std(feat, eps=1e-5):
    B, C = feat.size()[:2]
    feat_var = feat.view(B, C, -1).var(dim=2, unbiased=False) + eps
    feat_std = feat_var.sqrt().view(B, C, 1, 1)
    feat_mean = feat.view(B, C, -1).mean(dim=2).view(B, C, 1, 1)
    return feat_mean, feat_std

def adain(content_feat, style_feat, eps=1e-5):
    c_mean, c_std = calc_mean_std(content_feat, eps)
    s_mean, s_std = calc_mean_std(style_feat, eps)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

@torch.no_grad()
def denorm_for_save(x):
    # x assumed normalized for ImageNet
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1,3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1,3,1,1)
    y = x * std + mean
    return torch.clamp(y, 0, 1)

# --- VGG19 encoder up to relu4_1 (frozen), returns selected layers ---
LAYER_NAME_MAP = { 1:'relu1_1', 6:'relu2_1', 11:'relu3_1', 20:'relu4_1' }
STYLE_LAYERS = ['relu1_1','relu2_1','relu3_1','relu4_1']
CONTENT_LAYER = 'relu4_1'

class VGGEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        for p in self.vgg.parameters():
            p.requires_grad_(False)

    def forward(self, x, out_keys=('relu4_1',)):
        feats = {}
        h = x
        for i, layer in enumerate(self.vgg):
            h = layer(h)
            name = LAYER_NAME_MAP.get(i, None)
            if name in out_keys:
                feats[name] = h
            if i >= 20:
                # we only care up to relu4_1
                pass
        return feats

# --- Decoder (must match what you trained) ---
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.ReflectionPad2d(1), nn.Conv2d(512, 256, 3), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 256, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(256, 128, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 128, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(128, 64, 3), nn.ReLU(inplace=True),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1), nn.Conv2d(64, 3, 3)  # last conv, no activation
        )

    def forward(self, x):
        return self.body(x)

# --- image IO ---
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def load_image_for_vgg(path: Path, max_side=512):
    im = Image.open(path).convert('RGB')
    w, h = im.size
    if min(w, h) != max_side:
        if w < h:
            new_w, new_h = max_side, int(h * max_side / w)
        else:
            new_h, new_w = max_side, int(w * max_side / h)
        im = im.resize((new_w, new_h), Image.BICUBIC)
    x = T.ToTensor()(im)
    x = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)(x)
    return x.unsqueeze(0)

# --- path resolution & checkpoint finder ---
CWD = Path.cwd()
def resolve(p: str|Path) -> Path:
    p = Path(p)
    return p if p.is_absolute() else (CWD / p)

def find_latest_ckpt(search_dirs):
    best = None
    for d in search_dirs:
        d = resolve(d)
        if not d.is_dir():
            continue
        for p in d.glob("decoder_iter_*.pth"):
            try:
                it = int(p.stem.split("_")[-1])
                best = max(best, (it, p)) if best else (it, p)
            except:
                pass
        if best is None:
            fin = d / "decoder_final.pth"
            if fin.is_file():
                best = (0, fin)
    return best[1] if best else None

content_path = resolve(content_path_in)
style_path   = resolve(style_path_in)
if not content_path.is_file():
    raise FileNotFoundError(f"Content image not found: {content_path}")
if not style_path.is_file():
    raise FileNotFoundError(f"Style image not found: {style_path}")

ckpt_path = resolve(preferred_ckpt)
if not ckpt_path.is_file():
    print(f"Preferred checkpoint not found: {ckpt_path}")
    ckpt_path = find_latest_ckpt([preferred_ckpt] + fallback_dirs)
    if ckpt_path is None:
        print("Top-level:", list(CWD.iterdir()))
        if (CWD/"adain_runs").exists():
            print("adain_runs:", list((CWD/"adain_runs").glob("*")))
        if (CWD/"Test_image").exists():
            print("Test_image:", list((CWD/"Test_image").glob("*")))
        raise FileNotFoundError("No checkpoint found in the usual locations.")
print("Using checkpoint:", ckpt_path)

# --- run stylization ---
@torch.no_grad()
def stylize_one(content_path: Path, style_path: Path, decoder_ckpt: Path,
                alpha=1.0, out_path="stylized.png", max_side=512):
    # nets
    dec = Decoder().to(device).eval()
    ckpt = torch.load(str(decoder_ckpt), map_location=device)
    dec.load_state_dict(ckpt['decoder'])

    enc = VGGEncoder().to(device).eval()

    # imgs
    c = load_image_for_vgg(content_path, max_side=max_side).to(device)
    s = load_image_for_vgg(style_path,   max_side=max_side).to(device)

    # AdaIN at relu4_1
    c4 = enc(c, out_keys=['relu4_1'])['relu4_1']
    s4 = enc(s, out_keys=['relu4_1'])['relu4_1']
    t  = adain(c4, s4)
    if alpha < 1.0:
        t = alpha * t + (1 - alpha) * c4

    y = dec(t).clamp(-3, 3)
    y_img = denorm_for_save(y)

    out_path = resolve(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    utils.save_image(y_img, str(out_path))
    print(f"[stylize] saved: {out_path}")

stylize_one(
    content_path=content_path,
    style_path=style_path,
    decoder_ckpt=ckpt_path,
    alpha=alpha,
    out_path=out_path,
    max_side=max_side
)



Device: cuda
Preferred checkpoint not found: /Test_image/adain_runs/decoder_iter_160000.pth
Using checkpoint: /workspace/adain_runs/decoder_iter_160000.pth
[stylize] saved: /workspace/Test_image/adain_runs/test_stylized.png
