In [7]:
# ===== cell 0: config =====
from pathlib import Path
import torch, torch.nn as nn, random, numpy as np

# paths
DATA_DIR = Path("severe cases/HMI_continuum")
RUN_DIR  = Path("./wgangp_128")
(RUN_DIR/"samples").mkdir(parents=True, exist_ok=True)
(RUN_DIR/"ckpt").mkdir(parents=True, exist_ok=True)

# model
image_size = 128
nc = 1          # 1=grayscale, 3=RGB
nz = 128        # latent size
ngf = 160       # generator width (keep if you used 160 earlier)
ndf = 128       # discriminator width

# training
batch_size = 64                # drop to 64 if OOM
num_epochs = 150               # total epochs target
n_critic   = 3                 # D steps per G step (classic WGAN-GP)
lambda_gp  = 5.0              # gradient penalty coefficient
lr_g = 1.5e-4; lr_d = 8e-5     # WGAN-GP rec: betas=(0.0, 0.9)1e-4 
beta1, beta2 = 0.0, 0.9
mixed_precision = True

# reproducibility
seed = 42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)


device: cuda


In [8]:
# ===== cell 1: dataset =====
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageEnhance
import os, random

class EnhanceContrast:
    def __init__(self, p=0.4, factor_range=(1.1,1.4)): self.p=p; self.fr=factor_range
    def __call__(self, img):
        if random.random() < self.p:
            return ImageEnhance.Contrast(img).enhance(random.uniform(*self.fr))
        return img

class FlatImageFolder(Dataset):
    def __init__(self, root, transform=None, exts=(".jpg",".jpeg",".png",".tif",".tiff",".bmp")):
        p = Path(root)
        self.paths = [p/f for f in os.listdir(p) if (p/f).is_file() and (p/f).suffix.lower() in exts]
        if not self.paths: raise RuntimeError(f"No images found in {root}")
        self.transform = transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, i):
        img = Image.open(self.paths[i]).convert("RGB")
        return self.transform(img)

aug = [
    # transforms.RandomHorizontalFlip(0.5),
    # transforms.RandomVerticalFlip(0.5),
    # transforms.RandomAffine(degrees=10, translate=(0.05,0.05), scale=(0.95,1.05)),
    # EnhanceContrast(0.5, (1.1,1.5)),
]

tfm = transforms.Compose([
    transforms.Grayscale(num_output_channels=1) if nc==1 else transforms.Lambda(lambda x: x),
    *aug,
    transforms.Resize((image_size, image_size), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*nc, [0.5]*nc),  # [-1,1]
])

ds = FlatImageFolder(DATA_DIR, transform=tfm)
dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
print("images:", len(ds), "| batches/epoch:", len(dl))


images: 2705 | batches/epoch: 42


In [9]:
# ===== Generator (ResNet-style upsampling + Self-Attention at 32x32) =====
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as SN


# ---------- helper blocks ----------
def weights_init_dcgan(m):
    name = m.__class__.__name__
    if "Conv" in name and hasattr(m, "weight"):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.zeros_(m.bias)
    elif "BatchNorm" in name:
        if hasattr(m, "weight") and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.zeros_(m.bias)

class ResUpBlock(nn.Module):
    """Upsample x2 with nearest, then 2 convs + residual skip."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.skip  = nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.conv1(x); out = self.bn1(out); out = F.relu(out, inplace=True)
        out = self.conv2(out); out = self.bn2(out)
        skip = self.skip(x)
        return F.relu(out + skip, inplace=True)


class SelfAttention(nn.Module):
    """Simple self-attention block."""
    def __init__(self, ch, sn=False):
        super().__init__()
        conv = (lambda *a, **k: SN(nn.Conv2d(*a, **k))) if sn else nn.Conv2d
        self.f = conv(ch, ch//8, 1, bias=False)
        self.g = conv(ch, ch//8, 1, bias=False)
        self.h = conv(ch, ch//2, 1, bias=False)
        self.v = conv(ch//2, ch, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b,c,H,W = x.size()
        f = self.f(x).view(b,-1,H*W)
        g = self.g(x).view(b,-1,H*W)
        beta = torch.softmax(torch.bmm(f.transpose(1,2), g), dim=-1)
        h_ = self.h(x).view(b,-1,H*W)
        o = torch.bmm(h_, beta).view(b,-1,H,W)
        o = self.v(o)
        return self.gamma * o + x


# ---------- main generator ----------
class Generator128_Res(nn.Module):
    def __init__(self, nz, ngf, nc):
        super().__init__()
        self.fc = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*16), nn.ReLU(True),
            nn.Conv2d(ngf*16, ngf*16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ngf*16), nn.ReLU(True),
        )
        self.b1 = ResUpBlock(ngf*16, ngf*8)   # 4 -> 8
        self.b2 = ResUpBlock(ngf*8,  ngf*4)   # 8 -> 16
        self.b3 = ResUpBlock(ngf*4,  ngf*2)   # 16 -> 32

        # âœ… <--- this is where we add attention (32x32 resolution)
        self.attn32 = SelfAttention(ngf*2, sn=False)

        self.b4 = ResUpBlock(ngf*2,  ngf)     # 32 -> 64
        self.to_rgb = nn.Sequential(
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 64 -> 128
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z)
        x = self.b1(x)       # 4 -> 8
        x = self.b2(x)       # 8 -> 16
        x = self.b3(x)       # 16 -> 32
        x = self.attn32(x)   # <-- attention here
        x = self.b4(x)       # 32 -> 64
        x = self.to_rgb(x)   # 64 -> 128
        return x

class Discriminator128(nn.Module):
    # 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 1  (no BN)
    def __init__(self, ndf, nc):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(nc,     ndf,    4,2,1, bias=False), nn.LeakyReLU(0.2, True),  # 128->64
            nn.Conv2d(ndf,    ndf*2,  4,2,1, bias=False), nn.LeakyReLU(0.2, True),  # 64->32
            nn.Conv2d(ndf*2,  ndf*4,  4,2,1, bias=False), nn.LeakyReLU(0.2, True),  # 32->16
            nn.Conv2d(ndf*4,  ndf*8,  4,2,1, bias=False), nn.LeakyReLU(0.2, True),  # 16->8
            nn.Conv2d(ndf*8,  ndf*16, 4,2,1, bias=False), nn.LeakyReLU(0.2, True),  # 8->4
            nn.Conv2d(ndf*16, 1,      4,1,0, bias=False),                           # 4->1  (linear output)
        )
    def forward(self, x): return self.net(x).view(-1)

G = Generator128_Res(nz, ngf, nc).to(device); G.apply(weights_init_dcgan)
D = Discriminator128(ndf, nc).to(device);  D.apply(weights_init_dcgan)

optG = torch.optim.Adam(G.parameters(), lr=lr_g, betas=(beta1, beta2))
optD = torch.optim.Adam(D.parameters(), lr=lr_d, betas=(beta1, beta2))

print("G params (M):", sum(p.numel() for p in G.parameters())/1e6)
print("D params (M):", sum(p.numel() for p in D.parameters())/1e6)


G params (M): 127.479681
D params (M): 44.599296


In [10]:
# ===== cell 3: AMP =====
import inspect
from functools import partial

has_new_amp = hasattr(torch, "amp") and hasattr(torch.amp, "autocast")
if has_new_amp:
    from torch import amp as _amp
    autocast_cm = partial(_amp.autocast, device_type="cuda")
    _GradScaler = _amp.GradScaler
    if "device_type" in inspect.signature(_GradScaler).parameters:
        scaler = _GradScaler(device_type="cuda", enabled=(mixed_precision and torch.cuda.is_available()))
    else:
        scaler = _GradScaler(enabled=(mixed_precision and torch.cuda.is_available()))
else:
    from torch.cuda.amp import autocast as legacy_autocast, GradScaler as LegacyGradScaler
    autocast_cm = partial(legacy_autocast)
    scaler = LegacyGradScaler(enabled=(mixed_precision and torch.cuda.is_available()))

use_amp = (getattr(scaler, "is_enabled", lambda: False)()) or (mixed_precision and torch.cuda.is_available())
print(f"AMP API: {'new' if has_new_amp else 'legacy'} | AMP enabled: {use_amp}")


AMP API: new | AMP enabled: True


In [11]:
# ===== cell 4: losses & utils =====
num_epochs = 500

from torchvision.utils import save_image

def d_wgan_loss(d_real, d_fake):
    # maximize D(real) - D(fake) -> minimize (fake - real)
    return d_fake.mean() - d_real.mean()

def g_wgan_loss(d_fake):
    # maximize D(fake) -> minimize -D(fake)
    return -d_fake.mean()

def gradient_penalty(D, real, fake):
    b = real.size(0)
    eps = torch.rand(b, 1, 1, 1, device=real.device)
    x_hat = eps * real + (1 - eps) * fake
    x_hat.requires_grad_(True)
    d_hat = D(x_hat)
    grad = torch.autograd.grad(d_hat.sum(), x_hat, create_graph=True, retain_graph=True, only_inputs=True)[0]
    gp = (grad.view(b, -1).norm(2, dim=1) - 1.0).pow(2).mean()
    return gp

# display helpers (clear grids)
def _to_display(x, mode="linear", eps=1e-6):
    if mode == "linear":
        y = (x + 1) / 2
    elif mode == "stretch":
        B = x.size(0); flat = x.view(B, -1)
        mn = flat.min(dim=1, keepdim=True).values; mx = flat.max(dim=1, keepdim=True).values
        y = ((flat - mn) / (mx - mn + eps)).view_as(x)
    else:
        y = (x + 1) / 2
    return y.clamp(0,1)

fixed_z = torch.randn(128, nz, 1, 1, device=device)

def save_samples(epoch, model=None, nrow=16, vis_mode="stretch"):
    model = model or G
    model.eval()
    with torch.no_grad():
        x = model(fixed_z.to(next(model.parameters()).device))
        v = _to_display(x, mode=vis_mode)
    save_image(v, RUN_DIR/f"samples/epoch_{epoch:04d}.png", nrow=nrow)
    model.train()

# ckpt (with scaler state)
def save_ckpt(epoch, scaler_obj=None):
    g_tmp = RUN_DIR/"ckpt"/f".tmp_epoch_{epoch:04d}_G.pt"
    d_tmp = RUN_DIR/"ckpt"/f".tmp_epoch_{epoch:04d}_D.pt"
    g_dst = RUN_DIR/"ckpt"/f"epoch_{epoch:04d}_G.pt"
    d_dst = RUN_DIR/"ckpt"/f"epoch_{epoch:04d}_D.pt"
    torch.save({"epoch":epoch, "G":G.state_dict(), "optG":optG.state_dict(),
                "scaler": (scaler_obj.state_dict() if scaler_obj is not None else None)}, g_tmp)
    torch.save({"epoch":epoch, "D":D.state_dict(), "optD":optD.state_dict()}, d_tmp)
    import os; os.replace(g_tmp, g_dst); os.replace(d_tmp, d_dst)
    print(f"[ckpt] saved {epoch:04d}")

def resume_if_any():
    import re
    ckpts = sorted((RUN_DIR/"ckpt").glob("epoch_*_G.pt"))
    if not ckpts: print("[resume] none"); return 1
    last = ckpts[-1]
    e = int(re.search(r"epoch_(\d+)_G\.pt", last.name).group(1))
    g_state = torch.load(last, map_location=device)
    d_state = torch.load(str(last).replace("_G.pt","_D.pt"), map_location=device)
    G.load_state_dict(g_state["G"]); D.load_state_dict(d_state["D"])
    optG.load_state_dict(g_state["optG"]); optD.load_state_dict(d_state["optD"])
    if "scaler" in g_state and g_state["scaler"] is not None:
        try: scaler.load_state_dict(g_state["scaler"]); print("[resume] scaler restored")
        except Exception as e2: print("[resume] scaler not restored:", e2)
    print(f"[resume] loaded epoch {e} â†’ continue from {e+1}")
    return e+1

start_epoch = resume_if_any()


  g_state = torch.load(last, map_location=device)
  d_state = torch.load(str(last).replace("_G.pt","_D.pt"), map_location=device)


[resume] scaler restored
[resume] loaded epoch 440 â†’ continue from 441


In [12]:
# ===== cell 5: training (improved with detailed logging) =====

try:
    global_step = 0
    for epoch in range(start_epoch, num_epochs + 1):
        G.train(); D.train()

        # New meters for detailed logging
        d_meter = g_meter = 0.0
        dreal_meter = dfake_meter = w_meter = gp_meter = 0.0

        for real in dl:
            real = real.to(device, non_blocking=True)
            b = real.size(0)

            # --- train D n_critic times ---
            for _ in range(n_critic):
                optD.zero_grad(set_to_none=True)
                if use_amp:
                    with autocast_cm(enabled=True):
                        z = torch.randn(b, nz, 1, 1, device=device)
                        fake = G(z).detach()

                        d_real = D(real)
                        d_fake = D(fake)
                        loss_d = d_wgan_loss(d_real, d_fake)
                        gp = gradient_penalty(D, real, fake)
                        d_loss = loss_d + lambda_gp * gp

                        # --- logging additions ---
                        d_real_mean = d_real.mean().item()
                        d_fake_mean = d_fake.mean().item()
                        gp_val = gp.item()
                        wasserstein = d_real_mean - d_fake_mean
                        dreal_meter += d_real_mean
                        dfake_meter += d_fake_mean
                        w_meter += wasserstein
                        gp_meter += gp_val
                        # ---------------------------

                    scaler.scale(d_loss).backward()
                    scaler.step(optD); scaler.update()
                else:
                    z = torch.randn(b, nz, 1, 1, device=device)
                    fake = G(z).detach()
                    d_real = D(real); d_fake = D(fake)
                    loss_d = d_wgan_loss(d_real, d_fake)
                    gp = gradient_penalty(D, real, fake)
                    d_loss = loss_d + lambda_gp * gp

                    # --- logging additions ---
                    d_real_mean = d_real.mean().item()
                    d_fake_mean = d_fake.mean().item()
                    gp_val = gp.item()
                    wasserstein = d_real_mean - d_fake_mean
                    dreal_meter += d_real_mean
                    dfake_meter += d_fake_mean
                    w_meter += wasserstein
                    gp_meter += gp_val
                    # ---------------------------

                    d_loss.backward(); optD.step()

                d_meter += float(d_loss.detach().cpu())
                global_step += 1

            # --- train G once ---
            optG.zero_grad(set_to_none=True)
            if use_amp:
                with autocast_cm(enabled=True):
                    z = torch.randn(b, nz, 1, 1, device=device)
                    fake = G(z)
                    d_fake = D(fake)
                    g_loss = g_wgan_loss(d_fake)
                scaler.scale(g_loss).backward()
                scaler.step(optG); scaler.update()
            else:
                z = torch.randn(b, nz, 1, 1, device=device)
                fake = G(z); d_fake = D(fake)
                g_loss = g_wgan_loss(d_fake)
                g_loss.backward(); optG.step()

            g_meter += float(g_loss.detach().cpu())

        # --- end of epoch summary ---
        n_batches = len(dl) * n_critic
        print(f"[{epoch:03d}/{num_epochs}] "
              f"Wâ‰ˆ{w_meter/n_batches:.2f} | "
              f"Dreal={dreal_meter/n_batches:.2f} | "
              f"Dfake={dfake_meter/n_batches:.2f} | "
              f"GP={gp_meter/n_batches:.2f} | "
              f"G={g_meter/len(dl):.4f}")

        # --- save samples + checkpoints ---
        save_samples(epoch, model=G, nrow=16, vis_mode="stretch")
        if epoch % 10 == 0:
            save_ckpt(epoch, scaler_obj=scaler)

    save_ckpt(num_epochs, scaler_obj=scaler)
    print("âœ… training complete")

except KeyboardInterrupt:
    print("\nðŸ›‘ interrupted â€” savingâ€¦")
    try:
        save_ckpt(epoch, scaler_obj=scaler)
    except:
        save_ckpt(0, scaler_obj=scaler)
    print("âœ… saved")


[441/500] Wâ‰ˆ7.45 | Dreal=9.09 | Dfake=1.64 | GP=0.37 | G=-0.1809
[442/500] Wâ‰ˆ8.76 | Dreal=8.56 | Dfake=-0.20 | GP=0.50 | G=1.4968
[443/500] Wâ‰ˆ8.02 | Dreal=11.43 | Dfake=3.41 | GP=0.46 | G=-0.1995
[444/500] Wâ‰ˆ7.04 | Dreal=8.18 | Dfake=1.14 | GP=0.31 | G=1.3479
[445/500] Wâ‰ˆ7.45 | Dreal=8.38 | Dfake=0.93 | GP=0.39 | G=0.1622
[446/500] Wâ‰ˆ6.52 | Dreal=8.40 | Dfake=1.88 | GP=0.30 | G=-0.8682
[447/500] Wâ‰ˆ7.65 | Dreal=9.00 | Dfake=1.35 | GP=0.41 | G=-1.3803
[448/500] Wâ‰ˆ8.48 | Dreal=10.11 | Dfake=1.63 | GP=0.48 | G=0.5639
[449/500] Wâ‰ˆ8.26 | Dreal=7.92 | Dfake=-0.34 | GP=0.44 | G=0.9655
[450/500] Wâ‰ˆ7.83 | Dreal=10.38 | Dfake=2.55 | GP=0.37 | G=0.2133
[ckpt] saved 0450
[451/500] Wâ‰ˆ7.15 | Dreal=7.83 | Dfake=0.68 | GP=0.36 | G=1.4044
[452/500] Wâ‰ˆ8.43 | Dreal=8.83 | Dfake=0.39 | GP=0.46 | G=2.5385
[453/500] Wâ‰ˆ7.19 | Dreal=6.16 | Dfake=-1.03 | GP=0.38 | G=2.3958
[454/500] Wâ‰ˆ7.60 | Dreal=9.84 | Dfake=2.24 | GP=0.38 | G=0.8866
[455/500] Wâ‰ˆ7.86 | Dreal=9.81 | Dfake=1.96 | G