In [3]:
#!/usr/bin/env python3
"""
Stegano Training + Comprehensive Evaluation (reduced metrics)
- Keeps: PSNR, SSIM, NRPC, USEI (UQI), LPIPS (optional)
- Adds: Recovery bitstream accuracy, Residual-signal estimated capacity (bpp)
- Adds separate graphs for every metric.
- Usage: same as your original script — ensure dataset path correct.
"""

import os, math, torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms, models
from torchvision.utils import make_grid, save_image
from pytorch_msssim import ssim
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
from glob import glob
from datetime import datetime
import pandas as pd
from skimage.metrics import structural_similarity as ssim_sk

# Optional / best-effort imports
has_lpips = False
try:
    import lpips
    has_lpips = True
except Exception:
    print("lpips not available — LPIPS will be skipped. Install with 'pip install lpips'")

# -----------------------------
# Config
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
torch.manual_seed(42)

img_size = 128
batch_size = 24
lr = 1e-4
epochs = 50
lambda_ssim = 1.0
lambda_rec = 15.0
lambda_perc = 0.5
save_every = 5
DATA_FOLDER = "img_128x_50000"

print("Device:", device, "img_size:", img_size)

# -----------------------------
# Dataset
# -----------------------------
class CustomImageDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        patterns = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]
        self.paths = []
        for p in patterns:
            self.paths += sorted(glob(os.path.join(folder_path, "*" + p)))
        self.transform = transform

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform: img = self.transform(img)
        return img, 0

transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(0.1,0.1,0.1,0.02),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

dataset = CustomImageDataset(DATA_FOLDER, transform=transform)
print("Loaded images:", len(dataset))

# -----------------------------
# Utility metrics
# -----------------------------
def psnr_t(a, b, data_range=2.0):
    mse_val = F.mse_loss(a, b).item()
    if mse_val == 0: return float("inf")
    return 10 * math.log10((data_range ** 2) / mse_val)

def tensor_to_vgg_input(x):
    x = (x + 1) / 2
    mean = torch.tensor([0.485, 0.456, 0.406]).to(x.device).view(1,3,1,1)
    std = torch.tensor([0.229, 0.224, 0.225]).to(x.device).view(1,3,1,1)
    return (x - mean) / std

def compute_nrpc(cover, stego):
    c = ((cover + 1) / 2).detach().cpu().numpy()
    s = ((stego + 1) / 2).detach().cpu().numpy()
    res_c = c - np.mean(c)
    res_s = s - np.mean(s)
    num = np.sum(res_c * res_s)
    denom = np.sqrt(np.sum(res_c ** 2) * np.sum(res_s ** 2)) + 1e-12
    return float(num / denom)

def compute_usei_uqi(a, b):
    a_np = ((a + 1) / 2).detach().cpu().numpy()
    b_np = ((b + 1) / 2).detach().cpu().numpy()
    Bs = a_np.shape[0]
    vals = []
    for i in range(Bs):
        uqi_ch = []
        for ch in range(a_np.shape[1]):
            x = a_np[i,ch].flatten()
            y = b_np[i,ch].flatten()
            mean_x = x.mean(); mean_y = y.mean()
            cov = np.mean((x-mean_x)*(y-mean_y))
            var_x = x.var(); var_y = y.var()
            num = 4 * mean_x * mean_y * cov
            den = (mean_x**2 + mean_y**2) * (var_x + var_y) + 1e-12
            uqi = num/den
            uqi_ch.append(uqi)
        vals.append(np.mean(uqi_ch))
    return float(np.mean(vals))

def image_to_uint8(t):
    # t expected in [-1,1]; convert -> [0,255] uint8
    with torch.no_grad():
        x = ((t + 1) / 2.0 * 255.0).round().clamp(0,255).to(torch.uint8).cpu().numpy()
    return x  # shape: (B, C, H, W), dtype=uint8

def batch_bitstream_accuracy(secret_t, recovered_t):
    """
    Convert secret and recovered images to 8-bit/channel bitstreams and compute bit accuracy.
    Returns (matches / total_bits) as float.
    """
    s_uint8 = image_to_uint8(secret_t)
    r_uint8 = image_to_uint8(recovered_t)
    # pack bits per byte (big-endian) -> shape (..., 8)
    # flatten across batch, channels, height, width
    s_flat = s_uint8.reshape(-1).astype(np.uint8)
    r_flat = r_uint8.reshape(-1).astype(np.uint8)
    # convert to bits
    s_bits = np.unpackbits(s_flat)
    r_bits = np.unpackbits(r_flat)
    # safety: lengths same
    total_bits = s_bits.size
    if total_bits == 0:
        return float('nan')
    matches = int(np.sum(s_bits == r_bits))
    return matches / total_bits

def residual_bpp_estimate(cover, stego):
    """
    Estimate residual bit-capacity (bpp) from residual distribution.
    - residual = stego - cover  (both in [-1,1])
    - map residual to 0..255 and compute Shannon entropy (bits per channel symbol)
    - multiply by number of channels (3) => bits per pixel (bpp)
    """
    with torch.no_grad():
        resid = (stego - cover).detach().cpu().numpy()  # shape (B,C,H,W)
    # map resid from [-2,2] into [0,255] (since st ranges [-1,1], resid in [-2,2])
    mapped = ((resid + 2.0) / 4.0 * 255.0).round().clip(0,255).astype(np.uint8).ravel()
    if mapped.size == 0:
        return float('nan')
    hist = np.bincount(mapped, minlength=256).astype(np.float64)
    probs = hist / (hist.sum() + 1e-12)
    probs = probs[probs > 0]
    entropy_bits = -np.sum(probs * np.log2(probs))  # bits per symbol (channel)
    # bits per pixel = entropy_bits * channels (3)
    bpp = float(entropy_bits * 3.0)
    return bpp

def compute_lpips(a, b, lpips_model=None):
    if not has_lpips or lpips_model is None: return float('nan')
    with torch.no_grad(): return float(lpips_model(a, b).mean().item())

# -----------------------------
# Model definitions (same)
# -----------------------------
class CSPBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        split_point = max(in_channels // 2, 1)
        self.split_point = split_point
        self.part1 = nn.Sequential(
            nn.Conv2d(split_point, split_point, 3, padding=1, bias=False),
            nn.BatchNorm2d(split_point),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(split_point, split_point, 3, padding=1, bias=False),
            nn.BatchNorm2d(split_point),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.transition = nn.Sequential(
            nn.Conv2d(split_point + (in_channels - split_point), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
    def forward(self, x):
        sp = self.split_point
        x1, x2 = x[:, :sp], x[:, sp:]
        y1 = self.part1(x1)
        return self.transition(torch.cat((y1, x2), dim=1))

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(6, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.csp1 = CSPBlock(64, 128)
        self.down1 = nn.Sequential(
            nn.Conv2d(128, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.csp2 = CSPBlock(128, 128)
        self.csp3 = CSPBlock(128, 64)
        self.refine = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.final = nn.Sequential(nn.Conv2d(64, 3, 1), nn.Tanh())
    def forward(self, cover, secret):
        x = torch.cat((cover, secret), dim=1)
        x = self.initial(x)
        x = self.csp1(x)
        x = self.down1(x)
        x = self.csp2(x)
        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = self.csp3(x)
        x = self.refine(x)
        return self.final(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d((4,4)),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1)
        )
    def forward(self, x): return self.main(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class UNetDecoder(nn.Module):
    def __init__(self, in_channels=3, base=64):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, base)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base, base*2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool3 = nn.MaxPool2d(2)
        self.bottleneck = ConvBlock(base*4, base*8)
        self.up1 = nn.ConvTranspose2d(base*8, base*4, 2, 2)
        self.dec1 = ConvBlock(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, 2)
        self.dec2 = ConvBlock(base*4, base*2)
        self.up3 = nn.ConvTranspose2d(base*2, base, 2, 2)
        self.dec3 = ConvBlock(base*2, base)
        self.final_conv = nn.Conv2d(base, 3, 1)
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        d1 = self.up1(b)
        if d1.size(2)!=e3.size(2): e3 = F.interpolate(e3,size=d1.shape[2:])
        d1 = self.dec1(torch.cat([d1,e3],dim=1))
        d2 = self.up2(d1)
        if d2.size(2)!=e2.size(2): e2 = F.interpolate(e2,size=d2.shape[2:])
        d2 = self.dec2(torch.cat([d2,e2],dim=1))
        d3 = self.up3(d2)
        if d3.size(2)!=e1.size(2): e1 = F.interpolate(e1,size=d3.shape[2:])
        d3 = self.dec3(torch.cat([d3,e1],dim=1))
        return torch.tanh(self.final_conv(d3))

class VGGPerceptual(nn.Module):
    def __init__(self, device):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        for p in vgg.parameters(): p.requires_grad = False
        self.vgg = vgg; self.layers = [4,9,16,23]
    def forward(self, x):
        feats=[]
        for i,l in enumerate(self.vgg):
            x=l(x)
            if i in self.layers: feats.append(x)
        return feats

# -----------------------------
# Initialize & train loop
# -----------------------------
G = Generator().to(device)
D = Discriminator().to(device)
Dec = UNetDecoder(in_channels=3, base=64).to(device)
VGG = VGGPerceptual(device)

print("Models initialized")

# Data split & loaders
val_size = int(0.2 * len(dataset))
train_ds, val_ds = random_split(dataset, [len(dataset)-val_size, val_size])
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
opt_Dec = optim.Adam(Dec.parameters(), lr=lr, betas=(0.5,0.999))

sched_G = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_G, 'max', factor=0.5, patience=4)
sched_Dec = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_Dec, 'max', factor=0.5, patience=4)

mse_loss = nn.MSELoss()       # for LSGAN
l1 = nn.L1Loss()
mse = nn.MSELoss()

scaler_G = GradScaler()
scaler_D = GradScaler()
scaler_Dec = GradScaler()

# LPIPS model init (if available)
lpips_model = None
if has_lpips:
    try:
        lpips_model = lpips.LPIPS(net='vgg').to(device)
    except Exception:
        lpips_model = None

# -----------------------------
# Training loop (with new eval metrics)
# -----------------------------
best_val_psnr = -1.0
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"runs_steg_{timestamp}"
os.makedirs(log_dir, exist_ok=True)

# Logging containers
train_D_losses, train_G_losses, train_Dec_losses = [], [], []
val_psnr_hist, val_ssim_hist = [], []
per_epoch_metrics = []

print("Starting training...")
for epoch in range(1, epochs+1):
    G.train(); D.train(); Dec.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
    iter_idx = 0
    for cover, _ in pbar:
        cover = cover.to(device)
        secret = torch.flip(cover, dims=[-1]).to(device)

        bs = cover.size(0)
        real_label = torch.full((bs,1), 0.9, device=device)
        fake_label = torch.zeros((bs,1), device=device)

        # ---- Train D ----
        opt_D.zero_grad()
        with autocast():
            with torch.no_grad():
                gen_out = G(cover, secret)
                stego_fake = torch.clamp(cover + 0.1 * gen_out, -1, 1)
            pred_real = D(cover)
            pred_fake = D(stego_fake)
            D_loss = 0.5 * (mse_loss(pred_real, real_label) + mse_loss(pred_fake, fake_label))
        scaler_D.scale(D_loss).backward(); scaler_D.step(opt_D); scaler_D.update()

        # ---- Train G ----
        opt_G.zero_grad()
        for p in Dec.parameters(): p.requires_grad = False
        with autocast():
            gen_out = G(cover, secret)
            stego = torch.clamp(cover + 0.1 * gen_out, -1, 1)
            pred = D(stego)
            adv_loss = mse_loss(pred, real_label)
            ssim_loss = 1 - ssim(stego, cover, data_range=2.0, size_average=True)
            recovered = Dec(stego)
            rec_loss = l1(recovered, secret)
            v_stego = VGG(tensor_to_vgg_input(stego))
            v_cover = VGG(tensor_to_vgg_input(cover))
            perc1 = sum(mse(a,b) for a,b in zip(v_stego, v_cover))
            v_rec = VGG(tensor_to_vgg_input(recovered))
            v_secret = VGG(tensor_to_vgg_input(secret))
            perc2 = sum(mse(a,b) for a,b in zip(v_rec, v_secret))
            perc_loss = (perc1 + perc2)
            G_loss = 0.6 * adv_loss + lambda_ssim * ssim_loss + lambda_rec * rec_loss + lambda_perc * perc_loss
        scaler_G.scale(G_loss).backward(); scaler_G.step(opt_G); scaler_G.update()
        for p in Dec.parameters(): p.requires_grad = True

        # ---- Train Decoder ----
        opt_Dec.zero_grad()
        with autocast():
            rec_pred = Dec(stego.detach())
            Dec_loss = l1(rec_pred, secret) + 0.1 * sum(mse(a,b) for a,b in zip(VGG(tensor_to_vgg_input(rec_pred)), VGG(tensor_to_vgg_input(secret))))
        scaler_Dec.scale(Dec_loss).backward(); scaler_Dec.step(opt_Dec); scaler_Dec.update()

        pbar.set_postfix({"D_loss": f"{D_loss.item():.4f}", "G_loss": f"{G_loss.item():.4f}", "Dec_loss": f"{Dec_loss.item():.4f}"})

        train_D_losses.append(D_loss.item())
        train_G_losses.append(G_loss.item())
        train_Dec_losses.append(Dec_loss.item())

        iter_idx += 1

    # ---- Validation ----
    G.eval(); Dec.eval()
    val_psnrs = []
    val_ssims = []
    val_nrpcs = []
    val_useis = []
    val_lpips = []
    val_recovery_accs = []
    val_bpps = []
    with torch.no_grad():
        for j, (vcov, _) in enumerate(val_loader):
            if j >= 50: break
            vcov = vcov.to(device)
            vsecret = torch.flip(vcov, dims=[-1]).to(device)
            gen_out = G(vcov, vsecret)
            vstego = torch.clamp(vcov + 0.1 * gen_out, -1, 1)
            vrecovered = Dec(vstego)

            # PSNR / SSIM / NRPC / USEI
            val_psnrs.append(psnr_t(vrecovered, vsecret, 2.0))
            try:
                val_ssims.append(ssim(vrecovered, vsecret, data_range=2.0, size_average=True).item())
            except Exception:
                # fallback to skimage SSIM per image
                a = ((vrecovered+1)/2).cpu().numpy()
                b = ((vsecret+1)/2).cpu().numpy()
                try:
                    ssum = np.mean([ssim_sk(a[i].transpose(1,2,0), b[i].transpose(1,2,0), multichannel=True) for i in range(a.shape[0])])
                    val_ssims.append(ssum)
                except Exception:
                    pass
            val_nrpcs.append(compute_nrpc(vcov, vstego))
            val_useis.append(compute_usei_uqi(vrecovered, vsecret))

            # LPIPS (optional)
            if has_lpips and lpips_model is not None:
                try:
                    val_lpips.append(compute_lpips(vstego, vcov, lpips_model))
                except Exception:
                    val_lpips.append(float('nan'))

            # Recovery accuracy (bitstream)
            try:
                acc = batch_bitstream_accuracy(vsecret, vrecovered)
                val_recovery_accs.append(acc)
            except Exception:
                val_recovery_accs.append(float('nan'))

            # Residual bpp estimate
            try:
                bpp = residual_bpp_estimate(vcov, vstego)
                val_bpps.append(bpp)
            except Exception:
                val_bpps.append(float('nan'))

    avg_val_psnr = float(np.nanmean(val_psnrs)) if len(val_psnrs)>0 else 0.0
    avg_val_ssim = float(np.nanmean(val_ssims)) if len(val_ssims)>0 else 0.0
    avg_val_nrpc = float(np.nanmean(val_nrpcs)) if len(val_nrpcs)>0 else float('nan')
    avg_val_usei = float(np.nanmean(val_useis)) if len(val_useis)>0 else float('nan')
    avg_val_lpips = float(np.nanmean(val_lpips)) if len(val_lpips)>0 else float('nan')
    avg_val_recovery_acc = float(np.nanmean(val_recovery_accs)) if len(val_recovery_accs)>0 else float('nan')
    avg_val_bpp = float(np.nanmean(val_bpps)) if len(val_bpps)>0 else float('nan')

    print(f"Epoch {epoch}: Avg Val PSNR = {avg_val_psnr:.3f} dB | SSIM = {avg_val_ssim:.4f} | NRPC = {avg_val_nrpc:.4f} | USEI(UQI) = {avg_val_usei:.4f} | LPIPS = {avg_val_lpips:.4f} | RecAcc = {avg_val_recovery_acc:.4f} | bpp = {avg_val_bpp:.4f}")

    val_psnr_hist.append(avg_val_psnr)
    val_ssim_hist.append(avg_val_ssim)

    per_epoch_metrics.append({
        'epoch': epoch,
        'psnr': avg_val_psnr,
        'ssim': avg_val_ssim,
        'nrpc': avg_val_nrpc,
        'usei': avg_val_usei,
        'lpips': avg_val_lpips,
        'recovery_acc': avg_val_recovery_acc,
        'bpp': avg_val_bpp
    })

    sched_G.step(avg_val_psnr)
    sched_Dec.step(avg_val_psnr)

    # save best
    if avg_val_psnr > best_val_psnr:
        best_val_psnr = avg_val_psnr
        torch.save(G.state_dict(), os.path.join(log_dir, "G_best.pth"))
        torch.save(Dec.state_dict(), os.path.join(log_dir, "Dec_best.pth"))
        torch.save(D.state_dict(), os.path.join(log_dir, "D_best.pth"))
        print(f"Saved best models (PSNR {best_val_psnr:.2f} dB)")

    # periodic checkpoint
    if epoch % save_every == 0:
        torch.save({
            "G": G.state_dict(),
            "Dec": Dec.state_dict(),
            "D": D.state_dict(),
            "opt_G": opt_G.state_dict(),
            "opt_Dec": opt_Dec.state_dict(),
            "opt_D": opt_D.state_dict(),
            "epoch": epoch
        }, os.path.join(log_dir, f"checkpoint_epoch_{epoch}.pth"))

print("Training finished. Best PSNR:", best_val_psnr)

# -----------------------------
# Final Inference / Visualization
# -----------------------------
G.load_state_dict(torch.load(os.path.join(log_dir,"G_best.pth"), map_location=device))
Dec.load_state_dict(torch.load(os.path.join(log_dir,"Dec_best.pth"), map_location=device))
G.eval(); Dec.eval()

sample, _ = next(iter(val_loader))
cover_samp = sample[:8].to(device)
secret_samp = torch.flip(cover_samp, dims=[-1]).to(device)
with torch.no_grad():
    gen_out = G(cover_samp, secret_samp)
    stego_samp = torch.clamp(cover_samp + 0.1 * gen_out, -1, 1)
    recovered_samp = Dec(stego_samp)

# Save sample grid
try:
    grid = make_grid(torch.cat([cover_samp, secret_samp, stego_samp, recovered_samp], dim=0), nrow=8)
    save_path = os.path.join(log_dir, "stegano_results_best.png")
    save_image(((grid+1)/2).clamp(0,1), save_path)
    print("Saved visualization to", save_path)
except Exception as e:
    print("Failed to save sample grid:", e)

# per-epoch metrics CSV
metrics_df = pd.DataFrame(per_epoch_metrics)
metrics_df.to_csv(os.path.join(log_dir, 'per_epoch_metrics.csv'), index=False)

# -----------------------------
# Plotting: one graph per metric (saved individually)
# -----------------------------
# Helper for plotting
def save_metric_plot(x, y, title, ylabel, save_path):
    plt.figure(figsize=(8,5))
    plt.plot(x, y, marker='o')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

epochs_arr = metrics_df['epoch'].tolist()

# PSNR
save_metric_plot(epochs_arr, metrics_df['psnr'].tolist(), 'Validation PSNR across epochs', 'PSNR (dB)', os.path.join(log_dir, 'val_psnr.png'))

# SSIM
save_metric_plot(epochs_arr, metrics_df['ssim'].tolist(), 'Validation SSIM across epochs', 'SSIM', os.path.join(log_dir, 'val_ssim.png'))

# NRPC
save_metric_plot(epochs_arr, metrics_df['nrpc'].tolist(), 'Validation NRPC across epochs', 'NRPC', os.path.join(log_dir, 'val_nrpc.png'))

# USEI / UQI
save_metric_plot(epochs_arr, metrics_df['usei'].tolist(), 'Validation USEI (UQI) across epochs', 'USEI (UQI)', os.path.join(log_dir, 'val_usei.png'))

# LPIPS
save_metric_plot(epochs_arr, metrics_df['lpips'].tolist(), 'Validation LPIPS across epochs', 'LPIPS', os.path.join(log_dir, 'val_lpips.png'))

# Recovery accuracy (bitstream)
save_metric_plot(epochs_arr, metrics_df['recovery_acc'].tolist(), 'Recovery Bitstream Accuracy across epochs', 'Accuracy (fraction)', os.path.join(log_dir, 'val_recovery_acc.png'))

# Residual capacity (bpp)
save_metric_plot(epochs_arr, metrics_df['bpp'].tolist(), 'Estimated Residual Capacity (bpp) across epochs', 'bits per pixel (bpp)', os.path.join(log_dir, 'val_bpp.png'))

# Combined training losses plot (per iteration)
plt.figure(figsize=(14,6))
plt.subplot(1,1,1)
plt.plot(train_G_losses, label='Generator Loss')
plt.plot(train_D_losses, label='Discriminator Loss')
plt.plot(train_Dec_losses, label='Decoder Loss')
plt.title('Training Losses (per iteration)')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(log_dir, 'training_losses_per_iter.png'))
plt.close()

# Histogram of residuals (sample)
try:
    resid = ((stego_samp - cover_samp).detach().cpu().numpy()).flatten()
    plt.hist(resid, bins=100)
    plt.title('Histogram of Residuals (stego - cover) [sample]')
    plt.savefig(os.path.join(log_dir, 'residual_hist.png'))
    plt.close()
except Exception as e:
    print('Failed residual histogram:', e)

print('All plots, CSVs and sample images saved to', log_dir)


Device: cuda img_size: 128
Loaded images: 50000
Models initialized
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]


  scaler_G = GradScaler()
  scaler_D = GradScaler()
  scaler_Dec = GradScaler()


Loading model from: /home/user7/.local/lib/python3.12/site-packages/lpips/weights/v0.1/vgg.pth
Starting training...


  with autocast():
  with autocast():
  with autocast():
Epoch 1/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0800, G_loss=20.2044, Dec_loss=3.2343]


Epoch 1: Avg Val PSNR = 20.014 dB | SSIM = 0.6705 | NRPC = 0.9967 | USEI(UQI) = 0.8954 | LPIPS = 0.0473 | RecAcc = 0.5974 | bpp = 10.6102
Saved best models (PSNR 20.01 dB)


Epoch 2/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0594, G_loss=12.5293, Dec_loss=2.0196]


Epoch 2: Avg Val PSNR = 21.015 dB | SSIM = 0.6813 | NRPC = 0.9978 | USEI(UQI) = 0.9372 | LPIPS = 0.0171 | RecAcc = 0.5962 | bpp = 9.8595
Saved best models (PSNR 21.02 dB)


Epoch 3/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0183, G_loss=7.1529, Dec_loss=1.1120]


Epoch 3: Avg Val PSNR = 25.762 dB | SSIM = 0.8347 | NRPC = 0.9980 | USEI(UQI) = 0.9765 | LPIPS = 0.0084 | RecAcc = 0.6507 | bpp = 9.7024
Saved best models (PSNR 25.76 dB)


Epoch 4/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0214, G_loss=4.8663, Dec_loss=0.6713]


Epoch 4: Avg Val PSNR = 27.572 dB | SSIM = 0.8690 | NRPC = 0.9985 | USEI(UQI) = 0.9824 | LPIPS = 0.0054 | RecAcc = 0.6600 | bpp = 9.2187
Saved best models (PSNR 27.57 dB)


Epoch 5/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0142, G_loss=5.2876, Dec_loss=0.6897]


Epoch 5: Avg Val PSNR = 26.312 dB | SSIM = 0.8612 | NRPC = 0.9988 | USEI(UQI) = 0.9812 | LPIPS = 0.0040 | RecAcc = 0.6529 | bpp = 8.7385


Epoch 6/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0077, G_loss=4.2299, Dec_loss=0.6005]


Epoch 6: Avg Val PSNR = 26.469 dB | SSIM = 0.8570 | NRPC = 0.9991 | USEI(UQI) = 0.9847 | LPIPS = 0.0030 | RecAcc = 0.6501 | bpp = 8.3252


Epoch 7/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0175, G_loss=3.5320, Dec_loss=0.5018]


Epoch 7: Avg Val PSNR = 28.186 dB | SSIM = 0.8734 | NRPC = 0.9992 | USEI(UQI) = 0.9876 | LPIPS = 0.0028 | RecAcc = 0.6562 | bpp = 8.1805
Saved best models (PSNR 28.19 dB)


Epoch 8/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0248, G_loss=4.0483, Dec_loss=0.5422]


Epoch 8: Avg Val PSNR = 29.338 dB | SSIM = 0.9097 | NRPC = 0.9994 | USEI(UQI) = 0.9879 | LPIPS = 0.0022 | RecAcc = 0.6638 | bpp = 7.6404
Saved best models (PSNR 29.34 dB)


Epoch 9/50: 100%|██████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.74it/s, D_loss=0.0604, G_loss=2.6578, Dec_loss=0.3565]


Epoch 9: Avg Val PSNR = 30.275 dB | SSIM = 0.8997 | NRPC = 0.9994 | USEI(UQI) = 0.9915 | LPIPS = 0.0022 | RecAcc = 0.6790 | bpp = 7.6609
Saved best models (PSNR 30.27 dB)


Epoch 10/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.74it/s, D_loss=0.0148, G_loss=2.4977, Dec_loss=0.3062]


Epoch 10: Avg Val PSNR = 30.425 dB | SSIM = 0.9048 | NRPC = 0.9995 | USEI(UQI) = 0.9915 | LPIPS = 0.0019 | RecAcc = 0.6843 | bpp = 7.4537
Saved best models (PSNR 30.43 dB)


Epoch 11/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0176, G_loss=3.9327, Dec_loss=0.5483]


Epoch 11: Avg Val PSNR = 29.574 dB | SSIM = 0.9073 | NRPC = 0.9995 | USEI(UQI) = 0.9897 | LPIPS = 0.0022 | RecAcc = 0.6771 | bpp = 7.3075


Epoch 12/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.74it/s, D_loss=0.0226, G_loss=3.4129, Dec_loss=0.4362]


Epoch 12: Avg Val PSNR = 28.513 dB | SSIM = 0.8953 | NRPC = 0.9996 | USEI(UQI) = 0.9892 | LPIPS = 0.0016 | RecAcc = 0.6675 | bpp = 6.7796


Epoch 13/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.75it/s, D_loss=0.0304, G_loss=2.8647, Dec_loss=0.3530]


Epoch 13: Avg Val PSNR = 29.883 dB | SSIM = 0.8979 | NRPC = 0.9997 | USEI(UQI) = 0.9916 | LPIPS = 0.0015 | RecAcc = 0.6810 | bpp = 6.7420


Epoch 14/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0271, G_loss=2.6150, Dec_loss=0.3675]


Epoch 14: Avg Val PSNR = 27.781 dB | SSIM = 0.8748 | NRPC = 0.9997 | USEI(UQI) = 0.9879 | LPIPS = 0.0013 | RecAcc = 0.6432 | bpp = 6.5148


Epoch 15/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.73it/s, D_loss=0.0412, G_loss=3.2388, Dec_loss=0.4117]


Epoch 15: Avg Val PSNR = 30.883 dB | SSIM = 0.9197 | NRPC = 0.9997 | USEI(UQI) = 0.9926 | LPIPS = 0.0012 | RecAcc = 0.6824 | bpp = 6.3689
Saved best models (PSNR 30.88 dB)


Epoch 16/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0251, G_loss=2.4471, Dec_loss=0.3189]


Epoch 16: Avg Val PSNR = 31.254 dB | SSIM = 0.9284 | NRPC = 0.9998 | USEI(UQI) = 0.9933 | LPIPS = 0.0012 | RecAcc = 0.6904 | bpp = 6.3466
Saved best models (PSNR 31.25 dB)


Epoch 17/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:51<00:00,  5.73it/s, D_loss=0.0151, G_loss=2.3545, Dec_loss=0.2823]


Epoch 17: Avg Val PSNR = 28.740 dB | SSIM = 0.8994 | NRPC = 0.9998 | USEI(UQI) = 0.9910 | LPIPS = 0.0012 | RecAcc = 0.6702 | bpp = 6.3728


Epoch 18/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:51<00:00,  5.73it/s, D_loss=0.0332, G_loss=3.2582, Dec_loss=0.4457]


Epoch 18: Avg Val PSNR = 28.170 dB | SSIM = 0.8695 | NRPC = 0.9998 | USEI(UQI) = 0.9896 | LPIPS = 0.0011 | RecAcc = 0.6602 | bpp = 6.2520


Epoch 19/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.73it/s, D_loss=0.0268, G_loss=2.2725, Dec_loss=0.2757]


Epoch 19: Avg Val PSNR = 32.326 dB | SSIM = 0.9358 | NRPC = 0.9998 | USEI(UQI) = 0.9942 | LPIPS = 0.0012 | RecAcc = 0.7017 | bpp = 6.3793
Saved best models (PSNR 32.33 dB)


Epoch 20/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0209, G_loss=2.5116, Dec_loss=0.3313]


Epoch 20: Avg Val PSNR = 32.225 dB | SSIM = 0.9355 | NRPC = 0.9998 | USEI(UQI) = 0.9940 | LPIPS = 0.0011 | RecAcc = 0.7022 | bpp = 6.2105


Epoch 21/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0170, G_loss=3.0645, Dec_loss=0.4324]


Epoch 21: Avg Val PSNR = 28.094 dB | SSIM = 0.8543 | NRPC = 0.9998 | USEI(UQI) = 0.9894 | LPIPS = 0.0009 | RecAcc = 0.6408 | bpp = 5.9063


Epoch 22/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.73it/s, D_loss=0.0133, G_loss=2.3553, Dec_loss=0.2824]


Epoch 22: Avg Val PSNR = 31.516 dB | SSIM = 0.9326 | NRPC = 0.9998 | USEI(UQI) = 0.9938 | LPIPS = 0.0010 | RecAcc = 0.6957 | bpp = 6.1052


Epoch 23/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0180, G_loss=2.1360, Dec_loss=0.2516]


Epoch 23: Avg Val PSNR = 30.201 dB | SSIM = 0.9000 | NRPC = 0.9998 | USEI(UQI) = 0.9925 | LPIPS = 0.0010 | RecAcc = 0.6806 | bpp = 5.9982


Epoch 24/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.79it/s, D_loss=0.0140, G_loss=2.1337, Dec_loss=0.2450]


Epoch 24: Avg Val PSNR = 30.749 dB | SSIM = 0.9171 | NRPC = 0.9998 | USEI(UQI) = 0.9937 | LPIPS = 0.0009 | RecAcc = 0.6871 | bpp = 5.9890


Epoch 25/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0143, G_loss=1.8427, Dec_loss=0.2158]


Epoch 25: Avg Val PSNR = 30.931 dB | SSIM = 0.9142 | NRPC = 0.9998 | USEI(UQI) = 0.9934 | LPIPS = 0.0009 | RecAcc = 0.6890 | bpp = 5.9112


Epoch 26/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.78it/s, D_loss=0.0161, G_loss=2.2794, Dec_loss=0.2873]


Epoch 26: Avg Val PSNR = 32.791 dB | SSIM = 0.9388 | NRPC = 0.9998 | USEI(UQI) = 0.9951 | LPIPS = 0.0008 | RecAcc = 0.7028 | bpp = 5.8141
Saved best models (PSNR 32.79 dB)


Epoch 27/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.78it/s, D_loss=0.0170, G_loss=1.9874, Dec_loss=0.2137]


Epoch 27: Avg Val PSNR = 33.492 dB | SSIM = 0.9507 | NRPC = 0.9998 | USEI(UQI) = 0.9956 | LPIPS = 0.0008 | RecAcc = 0.7114 | bpp = 5.7824
Saved best models (PSNR 33.49 dB)


Epoch 28/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0059, G_loss=2.0837, Dec_loss=0.2544]


Epoch 28: Avg Val PSNR = 33.776 dB | SSIM = 0.9480 | NRPC = 0.9998 | USEI(UQI) = 0.9958 | LPIPS = 0.0008 | RecAcc = 0.7186 | bpp = 5.7192
Saved best models (PSNR 33.78 dB)


Epoch 29/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0227, G_loss=2.4988, Dec_loss=0.3024]


Epoch 29: Avg Val PSNR = 33.651 dB | SSIM = 0.9554 | NRPC = 0.9998 | USEI(UQI) = 0.9956 | LPIPS = 0.0008 | RecAcc = 0.7145 | bpp = 5.7899


Epoch 30/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0094, G_loss=1.7996, Dec_loss=0.2085]


Epoch 30: Avg Val PSNR = 34.243 dB | SSIM = 0.9562 | NRPC = 0.9999 | USEI(UQI) = 0.9959 | LPIPS = 0.0008 | RecAcc = 0.7210 | bpp = 5.7499
Saved best models (PSNR 34.24 dB)


Epoch 31/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.77it/s, D_loss=0.0093, G_loss=2.0544, Dec_loss=0.2423]


Epoch 31: Avg Val PSNR = 34.611 dB | SSIM = 0.9589 | NRPC = 0.9998 | USEI(UQI) = 0.9963 | LPIPS = 0.0008 | RecAcc = 0.7260 | bpp = 5.7508
Saved best models (PSNR 34.61 dB)


Epoch 32/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:47<00:00,  5.79it/s, D_loss=0.0120, G_loss=2.3455, Dec_loss=0.2827]


Epoch 32: Avg Val PSNR = 34.358 dB | SSIM = 0.9527 | NRPC = 0.9998 | USEI(UQI) = 0.9963 | LPIPS = 0.0008 | RecAcc = 0.7230 | bpp = 5.7685


Epoch 33/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0103, G_loss=2.1320, Dec_loss=0.2492]


Epoch 33: Avg Val PSNR = 30.770 dB | SSIM = 0.9117 | NRPC = 0.9999 | USEI(UQI) = 0.9937 | LPIPS = 0.0008 | RecAcc = 0.6776 | bpp = 5.7475


Epoch 34/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0064, G_loss=1.8470, Dec_loss=0.2190]


Epoch 34: Avg Val PSNR = 34.850 dB | SSIM = 0.9613 | NRPC = 0.9999 | USEI(UQI) = 0.9965 | LPIPS = 0.0007 | RecAcc = 0.7286 | bpp = 5.6867
Saved best models (PSNR 34.85 dB)


Epoch 35/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0165, G_loss=1.9447, Dec_loss=0.2396]


Epoch 35: Avg Val PSNR = 31.237 dB | SSIM = 0.9266 | NRPC = 0.9999 | USEI(UQI) = 0.9940 | LPIPS = 0.0007 | RecAcc = 0.6853 | bpp = 5.7081


Epoch 36/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.78it/s, D_loss=0.0084, G_loss=2.0742, Dec_loss=0.2474]


Epoch 36: Avg Val PSNR = 33.802 dB | SSIM = 0.9469 | NRPC = 0.9999 | USEI(UQI) = 0.9963 | LPIPS = 0.0007 | RecAcc = 0.7158 | bpp = 5.6481


Epoch 37/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0089, G_loss=1.9894, Dec_loss=0.2414]


Epoch 37: Avg Val PSNR = 33.933 dB | SSIM = 0.9521 | NRPC = 0.9999 | USEI(UQI) = 0.9957 | LPIPS = 0.0008 | RecAcc = 0.7179 | bpp = 5.7943


Epoch 38/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0153, G_loss=2.2089, Dec_loss=0.2360]


Epoch 38: Avg Val PSNR = 32.743 dB | SSIM = 0.9323 | NRPC = 0.9999 | USEI(UQI) = 0.9954 | LPIPS = 0.0007 | RecAcc = 0.7133 | bpp = 5.7020


Epoch 39/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0148, G_loss=2.4763, Dec_loss=0.3051]


Epoch 39: Avg Val PSNR = 32.116 dB | SSIM = 0.9462 | NRPC = 0.9999 | USEI(UQI) = 0.9936 | LPIPS = 0.0007 | RecAcc = 0.6887 | bpp = 5.5550


Epoch 40/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:47<00:00,  5.79it/s, D_loss=0.0038, G_loss=1.8854, Dec_loss=0.2118]


Epoch 40: Avg Val PSNR = 34.534 dB | SSIM = 0.9579 | NRPC = 0.9999 | USEI(UQI) = 0.9965 | LPIPS = 0.0007 | RecAcc = 0.7285 | bpp = 5.6358


Epoch 41/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:48<00:00,  5.77it/s, D_loss=0.0047, G_loss=1.8785, Dec_loss=0.2110]


Epoch 41: Avg Val PSNR = 34.985 dB | SSIM = 0.9623 | NRPC = 0.9999 | USEI(UQI) = 0.9966 | LPIPS = 0.0007 | RecAcc = 0.7329 | bpp = 5.6182
Saved best models (PSNR 34.98 dB)


Epoch 42/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0102, G_loss=1.7441, Dec_loss=0.2067]


Epoch 42: Avg Val PSNR = 34.470 dB | SSIM = 0.9523 | NRPC = 0.9999 | USEI(UQI) = 0.9965 | LPIPS = 0.0007 | RecAcc = 0.7287 | bpp = 5.5752


Epoch 43/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0071, G_loss=1.9017, Dec_loss=0.2399]


Epoch 43: Avg Val PSNR = 32.404 dB | SSIM = 0.9378 | NRPC = 0.9999 | USEI(UQI) = 0.9951 | LPIPS = 0.0007 | RecAcc = 0.6901 | bpp = 5.5960


Epoch 44/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:45<00:00,  5.85it/s, D_loss=0.0105, G_loss=1.8067, Dec_loss=0.1970]


Epoch 44: Avg Val PSNR = 34.662 dB | SSIM = 0.9542 | NRPC = 0.9999 | USEI(UQI) = 0.9966 | LPIPS = 0.0007 | RecAcc = 0.7305 | bpp = 5.6245


Epoch 45/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0034, G_loss=1.7883, Dec_loss=0.2024]


Epoch 45: Avg Val PSNR = 35.105 dB | SSIM = 0.9641 | NRPC = 0.9999 | USEI(UQI) = 0.9968 | LPIPS = 0.0007 | RecAcc = 0.7344 | bpp = 5.5568
Saved best models (PSNR 35.10 dB)


Epoch 46/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.76it/s, D_loss=0.0054, G_loss=2.0299, Dec_loss=0.2426]


Epoch 46: Avg Val PSNR = 34.976 dB | SSIM = 0.9600 | NRPC = 0.9999 | USEI(UQI) = 0.9968 | LPIPS = 0.0007 | RecAcc = 0.7287 | bpp = 5.5496


Epoch 47/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.74it/s, D_loss=0.0044, G_loss=1.7168, Dec_loss=0.1827]


Epoch 47: Avg Val PSNR = 35.823 dB | SSIM = 0.9688 | NRPC = 0.9999 | USEI(UQI) = 0.9972 | LPIPS = 0.0007 | RecAcc = 0.7416 | bpp = 5.6065
Saved best models (PSNR 35.82 dB)


Epoch 48/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.74it/s, D_loss=0.0061, G_loss=2.1538, Dec_loss=0.2673]


Epoch 48: Avg Val PSNR = 34.106 dB | SSIM = 0.9515 | NRPC = 0.9999 | USEI(UQI) = 0.9961 | LPIPS = 0.0007 | RecAcc = 0.7245 | bpp = 5.6055


Epoch 49/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:49<00:00,  5.75it/s, D_loss=0.0047, G_loss=1.6088, Dec_loss=0.1662]


Epoch 49: Avg Val PSNR = 34.874 dB | SSIM = 0.9629 | NRPC = 0.9999 | USEI(UQI) = 0.9970 | LPIPS = 0.0007 | RecAcc = 0.7255 | bpp = 5.5379


Epoch 50/50: 100%|█████████████████████████████████████████████████████████| 1667/1667 [04:50<00:00,  5.75it/s, D_loss=0.0062, G_loss=1.7635, Dec_loss=0.1813]


Epoch 50: Avg Val PSNR = 34.877 dB | SSIM = 0.9577 | NRPC = 0.9999 | USEI(UQI) = 0.9969 | LPIPS = 0.0007 | RecAcc = 0.7289 | bpp = 5.6439
Training finished. Best PSNR: 35.82306431723115


  G.load_state_dict(torch.load(os.path.join(log_dir,"G_best.pth"), map_location=device))
  Dec.load_state_dict(torch.load(os.path.join(log_dir,"Dec_best.pth"), map_location=device))


Saved visualization to runs_steg_20251114_114126/stegano_results_best.png
All plots, CSVs and sample images saved to runs_steg_20251114_114126
