In [35]:
import os
import glob
import random
import math
import shutil

from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

# Reproducibility: one shared seed for Python's `random` and for PyTorch.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [10]:
class PairedImageDataset(Dataset):
    """Dataset of (LR, HR) image pairs read from two folders.

    Every regular file directly inside ``lr_root`` and ``hr_root`` is collected
    and sorted by path; sample ``i`` is the i-th LR file together with the i-th
    HR file, so both folders must sort into the same order for the pairs to
    line up. Each image is converted to RGB, resized to 256x256 with bicubic
    interpolation when it has a different size, and returned through
    ``torchvision.transforms.ToTensor`` (values in [0, 1]).

    Args:
        lr_root: directory holding the input images.
        hr_root: directory holding the target images.
        training: when True, the same random flips / 90-degree rotations are
            applied to both images of a pair.

    Raises:
        AssertionError: if the two folders hold a different number of files,
            or if they are empty.
    """

    def __init__(self, lr_root, hr_root, training=True):
        super().__init__()
        self.lr_root = lr_root
        self.hr_root = hr_root
        self.training = training

        self.lr_paths = sorted(
            [p for p in glob.glob(os.path.join(lr_root, "*")) if os.path.isfile(p)]
        )
        self.hr_paths = sorted(
            [p for p in glob.glob(os.path.join(hr_root, "*")) if os.path.isfile(p)]
        )

        assert len(self.lr_paths) == len(self.hr_paths) and len(self.lr_paths) > 0, \
            "Mismatch or empty LR/HR folders"

        self.to_tensor = T.ToTensor()

    def __len__(self):
        """Number of (LR, HR) pairs."""
        return len(self.lr_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the underlying file.

        ``Image.open`` keeps the file handle open until the image object is
        closed; the ``with`` block releases it deterministically instead of
        leaving that to garbage collection inside long-lived loader workers.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    def _random_augment(self, lr_img, hr_img):
        """Apply one identical random flip/rotation combination to both images."""
        # simple, symmetric augments
        if random.random() < 0.5:
            lr_img = lr_img.transpose(Image.FLIP_LEFT_RIGHT)
            hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)
        if random.random() < 0.5:
            lr_img = lr_img.transpose(Image.FLIP_TOP_BOTTOM)
            hr_img = hr_img.transpose(Image.FLIP_TOP_BOTTOM)
        # k quarter turns; k == 0 leaves the pair unrotated
        k = random.randint(0, 3)
        if k:
            lr_img = lr_img.rotate(90 * k, expand=False)
            hr_img = hr_img.rotate(90 * k, expand=False)
        return lr_img, hr_img

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``(lr_tensor, hr_tensor)``."""
        lr = self._load_rgb(self.lr_paths[idx])
        hr = self._load_rgb(self.hr_paths[idx])

        # Both images are brought to the same fixed square size.
        target_size = (256, 256)
        if lr.size != target_size:
            lr = lr.resize(target_size, Image.BICUBIC)
        if hr.size != target_size:
            hr = hr.resize(target_size, Image.BICUBIC)

        if self.training:
            lr, hr = self._random_augment(lr, hr)

        lr = self.to_tensor(lr)  # [0,1]
        hr = self.to_tensor(hr)

        return lr, hr


In [11]:
# Folder layout: LR folders hold the network inputs, HR folders the targets.
train_lr_root, train_hr_root = "LR_train/train", "HR_train/train"
val_lr_root, val_hr_root = "LR_val/val", "HR_val/val"

# Augmentation is switched on for the training split only.
train_dataset = PairedImageDataset(train_lr_root, train_hr_root, training=True)
val_dataset = PairedImageDataset(val_lr_root, val_hr_root, training=False)

batch_size = 8

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
# Validation runs one image per batch, in a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))


Train samples: 258
Val samples: 100


In [12]:
# ----- Swish activation (NPU-friendly: uses Sigmoid + Mul only) -----
class Swish(nn.Module):
    """Swish activation written out as an explicit sigmoid followed by a multiply."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


# ----- Depthwise-separable Conv (DW + PW) -----
class DWConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the pointwise stage.
        kernel_size: kernel size of the depthwise stage.
        padding: padding of the depthwise stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        # Depthwise stage: groups == in_channels, so each channel gets its own filter.
        depthwise_cfg = dict(kernel_size=kernel_size, padding=padding, groups=in_channels)
        self.dw = nn.Conv2d(in_channels, in_channels, **depthwise_cfg)
        # Pointwise stage: 1x1 conv that maps in_channels -> out_channels.
        self.pw = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pw(self.dw(x))


# ----- AGDN-style nano block -----
class AGDBNano(nn.Module):
    """Residual block: channel split, channel attention, multi-scale spatial attention."""

    def __init__(self, channels, distill_ratio=0.5):
        """
        channels: feature channels (very small, e.g. 24–40)
        distill_ratio: fraction of channels considered "distilled"
        """
        super().__init__()
        self.channels = channels
        self.distilled_channels = int(channels * distill_ratio)
        self.remaining_channels = channels - self.distilled_channels

        self.conv_in = DWConv(channels, channels, kernel_size=3, padding=1)
        self.act = Swish()

        # One separable conv per branch of the channel split.
        n_dist, n_rest = self.distilled_channels, self.remaining_channels
        self.conv_distill = DWConv(n_dist, n_dist, kernel_size=3, padding=1)
        self.conv_remain = DWConv(n_rest, n_rest, kernel_size=3, padding=1)

        # Channel attention (SE-like, tiny bottleneck)
        ca_mid = max(channels // 4, 4)
        self.ca_conv1 = nn.Conv2d(channels, ca_mid, kernel_size=1)
        self.ca_conv2 = nn.Conv2d(ca_mid, channels, kernel_size=1)

        # Multi-scale spatial attention (3x3, 5x5, 7x7); padding k // 2 keeps the size.
        for k in (3, 5, 7):
            setattr(self, f"sa_conv{k}", nn.Conv2d(1, 1, kernel_size=k, padding=k // 2))
        self.sa_fuse = nn.Conv2d(3, 1, kernel_size=1)

        self.conv_out = DWConv(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        skip = x

        feat = self.act(self.conv_in(x))

        # Channel distillation: split along channels, process each part, re-join.
        split_sizes = [self.distilled_channels, self.remaining_channels]
        part_d, part_r = torch.split(feat, split_sizes, dim=1)
        part_d = self.act(self.conv_distill(part_d))
        part_r = self.act(self.conv_remain(part_r))
        feat = torch.cat([part_d, part_r], dim=1)

        # Channel attention: spatial mean -> bottleneck -> sigmoid gate per channel.
        ch_gate = feat.mean(dim=[2, 3], keepdim=True)
        ch_gate = self.act(self.ca_conv1(ch_gate))
        ch_gate = torch.sigmoid(self.ca_conv2(ch_gate))
        feat = feat * ch_gate

        # Spatial attention: channel mean -> three kernel sizes -> fused sigmoid gate.
        sp_map = feat.mean(dim=1, keepdim=True)
        multi_scale = torch.cat(
            [self.sa_conv3(sp_map), self.sa_conv5(sp_map), self.sa_conv7(sp_map)],
            dim=1,
        )
        sp_gate = torch.sigmoid(self.sa_fuse(multi_scale))
        feat = feat * sp_gate

        # Output conv plus the block-level residual connection.
        return self.conv_out(feat) + skip


class AGDNNanoRestorer(nn.Module):
    """
    Ultra-light AGDN-style restorer for 256x256 -> 256x256.
    Default config (base_channels=32, num_blocks=6):
        ~0.028M params (inside 0.01–0.05M).

    Layout: 3x3 shallow conv -> ``num_blocks`` AGDBNano blocks -> separable
    conv, added back to the shallow features -> 3x3 output conv whose result
    is added to the network input.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 base_channels=32,
                 num_blocks=6,
                 distill_ratio=0.5):
        super().__init__()

        self.shallow = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)

        stack = []
        for _ in range(num_blocks):
            stack.append(AGDBNano(base_channels, distill_ratio=distill_ratio))
        self.blocks = nn.ModuleList(stack)

        self.conv_mid = DWConv(base_channels, base_channels, kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        shallow_feat = self.shallow(x)

        deep_feat = shallow_feat
        for blk in self.blocks:
            deep_feat = blk(deep_feat)

        # Long skip around the whole block stack.
        fused = shallow_feat + self.conv_mid(deep_feat)

        # The output conv predicts a residual that is added to the input image.
        return x + self.conv_out(fused)


In [13]:
def count_parameters(model):
    """Return the total number of elements over everything in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# You can adjust base_channels / num_blocks if needed
base_channels, num_blocks = 32, 6

model = AGDNNanoRestorer(
    in_channels=3,
    out_channels=3,
    base_channels=base_channels,
    num_blocks=num_blocks,
    distill_ratio=0.5,
)
model = model.to(device)

params = count_parameters(model)
print(f"Model parameters: {params}  ({params/1e6:.6f} M)")

# quick shape sanity check: one random 3x256x256 input through the model
x = torch.randn(1, 3, 256, 256).to(device)
with torch.no_grad():
    y = model(x)
print("Output shape:", y.shape)


Model parameters: 28687  (0.028687 M)
Output shape: torch.Size([1, 3, 256, 256])


In [14]:
def conv_flops(in_c, out_c, k, h, w, groups=1):
    """Estimate the FLOPs of a k x k convolution over an h x w output map.

    MACs = (in_c // groups) * out_c * k * k * h * w, and each MAC is counted
    as two FLOPs (one multiply, one add). Bias terms are not counted.
    """
    channels_per_group = in_c // groups
    macs = channels_per_group * out_c * k * k * h * w
    return 2 * macs

def estimate_flops_256(model, base_channels, num_blocks, distill_ratio=0.5):
    """Analytically estimate the convolution FLOPs for one 3x256x256 input.

    Only convolutions that run at full 256x256 resolution are counted. The
    channel-attention 1x1 convs (which act on pooled features) and all
    element-wise operations are left out.

    Args:
        model: not used by the computation; kept so existing calls still work.
        base_channels: feature width of the network.
        num_blocks: number of AGDBNano blocks.
        distill_ratio: fraction of the block channels routed through the
            "distill" branch. The split uses ``int(c * distill_ratio)``, the
            same rounding the block applies, so pass the value the model was
            built with. The default 0.5 reproduces the former hard-coded
            ``c // 2`` split.

    Returns:
        Estimated FLOPs as an integer (2 FLOPs per multiply-accumulate).
    """
    H = W = 256
    c = base_channels

    def separable(ch):
        # Depthwise 3x3 followed by pointwise 1x1 on `ch` channels.
        if ch == 0:
            # A zero-width branch does no work; also avoids dividing by groups=0.
            return 0
        return (conv_flops(ch, ch, 3, H, W, groups=ch)
                + conv_flops(ch, ch, 1, H, W, groups=1))

    cd = int(c * distill_ratio)
    cr = c - cd

    # ---- one AGDBNano block ----
    block = separable(c)                      # conv_in
    block += separable(cd) + separable(cr)    # conv_distill + conv_remain
    # channel attention (1x1 convs on pooled features: negligible for FLOPs)
    block += sum(conv_flops(1, 1, k, H, W, groups=1) for k in (3, 5, 7))  # spatial attention
    block += conv_flops(3, 1, 1, H, W, groups=1)                          # fuse conv 3->1
    block += separable(c)                     # conv_out

    # ---- whole network ----
    flops = conv_flops(3, c, 3, H, W, groups=1)    # shallow: 3 -> base_channels, 3x3
    flops += num_blocks * block
    flops += separable(c)                          # conv_mid (DW + PW)
    flops += conv_flops(c, 3, 3, H, W, groups=1)   # conv_out (3x3 standard)
    return flops

# Estimate the conv FLOPs for the configuration built above and report them in GFLOPs.
flops = estimate_flops_256(model, base_channels, num_blocks)
print(f"Estimated FLOPs for 1x3x256x256: {flops/1e9:.4f} GFLOPs")


Estimated FLOPs for 1x3x256x256: 3.1588 GFLOPs


In [15]:
# PSNR (RGB) – range [0,1]
def calc_psnr(pred, target):
    """PSNR in dB between two tensors, with a peak value of 1.0.

    Both tensors are clamped to [0, 1] first and the MSE is taken over every
    element. A perfect match (MSE of exactly 0) returns 99.0.
    """
    pred_c = pred.clamp(0.0, 1.0)
    target_c = target.clamp(0.0, 1.0)
    mse = (pred_c - target_c).pow(2).mean().item()
    if mse == 0:
        return 99.0
    return 10 * math.log10(1.0 / mse)

# MSE loss: PSNR is computed from the MSE, so lowering one raises the other.
criterion = nn.MSELoss()

# Schedule length in epochs (you can tweak).
num_epochs = 1000

# Adam with a cosine-annealed learning rate going from initial_lr down to eta_min.
initial_lr = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-6,
)

print("Ready for training.")


Ready for training.


In [16]:
os.makedirs("checkpoints_small_v2", exist_ok=True)

best_val_psnr = 0.0

for epoch in range(1, num_epochs + 1):
    # ---- TRAIN: one pass over the training loader ----
    model.train()
    loss_total = 0.0

    for lr_img, hr_img in train_loader:
        lr_img = lr_img.to(device, non_blocking=True)
        hr_img = hr_img.to(device, non_blocking=True)

        optimizer.zero_grad()
        sr = model(lr_img)
        loss = criterion(sr, hr_img)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size to get a per-sample average below.
        loss_total += loss.item() * lr_img.size(0)

    train_loss = loss_total / len(train_loader.dataset)

    # ---- VAL: mean PSNR over the validation batches ----
    model.eval()
    psnr_total = 0.0

    with torch.no_grad():
        for lr_img, hr_img in val_loader:
            lr_img = lr_img.to(device, non_blocking=True)
            hr_img = hr_img.to(device, non_blocking=True)

            sr = model(lr_img)
            psnr_total += calc_psnr(sr, hr_img)

    val_psnr = psnr_total / len(val_loader)

    # The cosine schedule advances once per epoch.
    scheduler.step()

    print(f"Epoch [{epoch}/{num_epochs}] "
          f"Train MSE: {train_loss:.6f}  Val PSNR: {val_psnr:.4f} dB")

    # The same snapshot is written as "last" and, on improvement, as "best".
    snapshot = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "val_psnr": val_psnr,
    }

    # save last
    torch.save(snapshot, "checkpoints_small_v2/last_model.pt")

    # save best
    if val_psnr > best_val_psnr:
        best_val_psnr = val_psnr
        torch.save(snapshot, "checkpoints_small_v2/best_model.pt")
        print(f"  -> New BEST model (PSNR={best_val_psnr:.4f} dB)")


Epoch [1/1000] Train MSE: 0.020609  Val PSNR: 18.7444 dB
  -> New BEST model (PSNR=18.7444 dB)
Epoch [2/1000] Train MSE: 0.014447  Val PSNR: 19.0916 dB
  -> New BEST model (PSNR=19.0916 dB)
Epoch [3/1000] Train MSE: 0.013050  Val PSNR: 19.3848 dB
  -> New BEST model (PSNR=19.3848 dB)
Epoch [4/1000] Train MSE: 0.012006  Val PSNR: 19.6161 dB
  -> New BEST model (PSNR=19.6161 dB)
Epoch [5/1000] Train MSE: 0.011207  Val PSNR: 19.8587 dB
  -> New BEST model (PSNR=19.8587 dB)
Epoch [6/1000] Train MSE: 0.010549  Val PSNR: 20.0883 dB
  -> New BEST model (PSNR=20.0883 dB)
Epoch [7/1000] Train MSE: 0.010032  Val PSNR: 20.2851 dB
  -> New BEST model (PSNR=20.2851 dB)
Epoch [8/1000] Train MSE: 0.009619  Val PSNR: 20.4669 dB
  -> New BEST model (PSNR=20.4669 dB)
Epoch [9/1000] Train MSE: 0.009277  Val PSNR: 20.6123 dB
  -> New BEST model (PSNR=20.6123 dB)
Epoch [10/1000] Train MSE: 0.009078  Val PSNR: 20.6795 dB
  -> New BEST model (PSNR=20.6795 dB)
Epoch [11/1000] Train MSE: 0.008845  Val PSNR: 20

In [23]:
import torch
from torch import nn
import torch.optim as optim
import math
import os

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# === IMPORTANT: set these to exactly what you used when training ===
base_channels = 32      # or 40, etc. (whatever you trained the best model with)
num_blocks    = 6       # same number of blocks you used before

# Recreate model with SAME architecture as used for best_model.pt
model = AGDNNanoRestorer(
    in_channels=3,
    out_channels=3,
    base_channels=base_channels,
    num_blocks=num_blocks,
    distill_ratio=0.5,  # same as before
)
model = model.to(device)

# Load the best checkpoint and restore only the model weights from it.
best_ckpt_path = "checkpoints_small_v2/best_model.pt"
checkpoint = torch.load(best_ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model"])

# Resume the epoch counter right after the checkpointed epoch; both keys fall
# back to a default when the checkpoint does not contain them.
start_epoch = checkpoint.get("epoch", 0) + 1
best_val_psnr = checkpoint.get("val_psnr", 0.0)

print(f"Loaded best_model.pt from epoch {start_epoch-1} "
      f"with Val PSNR = {best_val_psnr:.4f} dB")


Loaded best_model.pt from epoch 901 with Val PSNR = 22.8717 dB


In [24]:
def calc_psnr_rgb(pred, target):
    """PSNR in dB over all elements of two tensors clamped to [0, 1] (peak 1.0).

    Returns 99.0 when the clamped tensors are identical (MSE of exactly 0).
    """
    diff = torch.clamp(pred, 0.0, 1.0) - torch.clamp(target, 0.0, 1.0)
    mse = torch.mean(diff ** 2).item()
    if mse != 0:
        return 10 * math.log10(1.0 / mse)
    return 99.0


def rgb_to_y(img):
    """Weighted sum of the first three channels: 0.299 R + 0.587 G + 0.114 B.

    img: Bx3xHxW, range [0,1]. The result keeps a singleton channel axis.
    """
    red, green, blue = (img[:, ch:ch + 1, :, :] for ch in range(3))
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return luma


def calc_psnr_y(pred, target):
    """PSNR in dB on the luma channel produced by ``rgb_to_y`` (peak 1.0).

    Both inputs are clamped to [0, 1] before conversion. Returns 99.0 when the
    two luma maps are identical (MSE of exactly 0).
    """
    luma_pred = rgb_to_y(torch.clamp(pred, 0.0, 1.0))
    luma_target = rgb_to_y(torch.clamp(target, 0.0, 1.0))
    mse = torch.mean((luma_pred - luma_target) ** 2).item()
    if mse != 0:
        return 10 * math.log10(1.0 / mse)
    return 99.0


In [27]:
# Same loss as before; the optimizer and scheduler below are created fresh.
criterion = nn.MSELoss()

# Fine-tune for N extra epochs
finetune_epochs = 5000   # adjust if you want more / less

# OneCycleLR is stepped per batch, so it needs the number of batches per epoch.
steps_per_epoch = len(train_loader)

# Max LR a bit higher than your old "final tiny LR", but not crazy
max_lr = 1e-4          # you can try 2e-4 if training is stable

optimizer = optim.Adam(model.parameters(), lr=max_lr)

one_cycle_cfg = dict(
    max_lr=max_lr,
    steps_per_epoch=steps_per_epoch,
    epochs=finetune_epochs,
    pct_start=0.3,            # 30% warmup, then anneal
    anneal_strategy='cos',
    div_factor=10.0,          # initial lr = max_lr / div_factor
    final_div_factor=100.0,
)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, **one_cycle_cfg)

print("Finetune setup ready.")


Finetune setup ready.


In [28]:
os.makedirs("checkpoints_small_v2", exist_ok=True)

# Start the "best so far" tracker from the values read out of the loaded checkpoint.
current_best_psnr_rgb = best_val_psnr
current_best_epoch = start_epoch - 1

for epoch in range(start_epoch, start_epoch + finetune_epochs):
    # -------- TRAIN --------
    model.train()
    loss_total = 0.0

    for lr_img, hr_img in train_loader:
        lr_img = lr_img.to(device, non_blocking=True)
        hr_img = hr_img.to(device, non_blocking=True)

        optimizer.zero_grad()
        sr = model(lr_img)
        loss = criterion(sr, hr_img)
        loss.backward()
        optimizer.step()
        # The one-cycle schedule advances after every batch, not every epoch.
        scheduler.step()

        loss_total += loss.item() * lr_img.size(0)

    train_loss = loss_total / len(train_loader.dataset)

    # -------- VALIDATION: mean RGB and Y PSNR over the validation batches --------
    model.eval()
    psnr_rgb_total = 0.0
    psnr_y_total = 0.0

    with torch.no_grad():
        for lr_img, hr_img in val_loader:
            lr_img = lr_img.to(device, non_blocking=True)
            hr_img = hr_img.to(device, non_blocking=True)

            sr = model(lr_img)

            psnr_rgb_total += calc_psnr_rgb(sr, hr_img)
            psnr_y_total += calc_psnr_y(sr, hr_img)

    n_val_batches = len(val_loader)
    val_psnr_rgb = psnr_rgb_total / n_val_batches
    val_psnr_y = psnr_y_total / n_val_batches

    print(f"[Finetune] Epoch {epoch} | "
          f"Train MSE: {train_loss:.6f} | "
          f"Val PSNR(RGB): {val_psnr_rgb:.4f} dB | "
          f"Val PSNR(Y): {val_psnr_y:.4f} dB")

    # One snapshot per epoch; "val_psnr" stores the RGB PSNR as before.
    snapshot = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "val_psnr": val_psnr_rgb,
    }

    # Save last
    torch.save(snapshot, "checkpoints_small_v2/last_model.pt")

    # Update best (based on RGB PSNR, to be consistent with previous)
    if val_psnr_rgb > current_best_psnr_rgb:
        current_best_psnr_rgb = val_psnr_rgb
        current_best_epoch = epoch
        torch.save(snapshot, "checkpoints_small_v2/best_model.pt")
        print(f"  -> New BEST (RGB) model: {current_best_psnr_rgb:.4f} dB at epoch {epoch}")


[Finetune] Epoch 902 | Train MSE: 0.005790 | Val PSNR(RGB): 22.8660 dB | Val PSNR(Y): 23.5567 dB
[Finetune] Epoch 903 | Train MSE: 0.005783 | Val PSNR(RGB): 22.8698 dB | Val PSNR(Y): 23.5609 dB
[Finetune] Epoch 904 | Train MSE: 0.005787 | Val PSNR(RGB): 22.8744 dB | Val PSNR(Y): 23.5656 dB
  -> New BEST (RGB) model: 22.8744 dB at epoch 904
[Finetune] Epoch 905 | Train MSE: 0.005787 | Val PSNR(RGB): 22.8691 dB | Val PSNR(Y): 23.5587 dB
[Finetune] Epoch 906 | Train MSE: 0.005793 | Val PSNR(RGB): 22.8719 dB | Val PSNR(Y): 23.5625 dB
[Finetune] Epoch 907 | Train MSE: 0.005790 | Val PSNR(RGB): 22.8707 dB | Val PSNR(Y): 23.5614 dB
[Finetune] Epoch 908 | Train MSE: 0.005779 | Val PSNR(RGB): 22.8722 dB | Val PSNR(Y): 23.5627 dB
[Finetune] Epoch 909 | Train MSE: 0.005787 | Val PSNR(RGB): 22.8724 dB | Val PSNR(Y): 23.5629 dB
[Finetune] Epoch 910 | Train MSE: 0.005783 | Val PSNR(RGB): 22.8720 dB | Val PSNR(Y): 23.5630 dB
[Finetune] Epoch 911 | Train MSE: 0.005786 | Val PSNR(RGB): 22.8691 dB | Val

In [33]:
# Copy a random subset (at most 50) of the LR training images into a calibration folder.
calib_dir = "checkpoints_small_v2/calib_imgs"
os.makedirs(calib_dir, exist_ok=True)

lr_train_files = sorted(
    filter(os.path.isfile, glob.glob(os.path.join(train_lr_root, "*")))
)

# Shuffle in place, then take the leading slice as the calibration set.
random.shuffle(lr_train_files)
num_calib = min(50, len(lr_train_files))

for src_path in lr_train_files[:num_calib]:
    shutil.copy(src_path, os.path.join(calib_dir, os.path.basename(src_path)))

print(f"Copied {num_calib} calibration images into '{calib_dir}'")


Copied 50 calibration images into 'checkpoints_small_v2/calib_imgs'


In [30]:
# Reload the best checkpoint on the CPU for export.
best_ckpt_path = "checkpoints_small_v2/best_model.pt"
checkpoint = torch.load(best_ckpt_path, map_location="cpu")

# Rebuild the network with the configuration used for training, then load the weights.
export_model = AGDNNanoRestorer(
    in_channels=3,
    out_channels=3,
    base_channels=base_channels,   # use the same you trained with
    num_blocks=num_blocks,
    distill_ratio=0.5,
)
export_model.load_state_dict(checkpoint["model"])
export_model.eval()

# Example input handed to the exporter; it fixes the exported input shape.
dummy_input = torch.randn(1, 3, 256, 256)

onnx_path = "agdn_nano.onnx"

export_options = dict(
    input_names=["input"],
    output_names=["output"],
    opset_version=17,          # or whatever opset your NPU prefers
    do_constant_folding=True,
    dynamic_axes=None,         # fixed 1x3x256x256 for clean .mxq conversion
)
torch.onnx.export(export_model, dummy_input, onnx_path, **export_options)

print("Exported ONNX to:", onnx_path)


Exported ONNX to: agdn_nano.onnx


In [31]:
import time

# Latency benchmark: average forward-pass time for a single 1x3x256x256 input.
model.eval()
dummy = torch.randn(1, 3, 256, 256).to(device)

iters = 50

# Run every forward pass under no_grad so autograd bookkeeping is not part of
# the measured time (model.eval() alone does not turn gradient tracking off).
with torch.no_grad():
    # warm-up
    for _ in range(10):
        _ = model(dummy)

    # CUDA kernels are launched asynchronously: wait for the warm-up to finish
    # before starting the clock, and for the timed work before stopping it.
    if device.type == "cuda":
        torch.cuda.synchronize()

    # perf_counter is a monotonic, high-resolution clock intended for timing.
    start = time.perf_counter()
    for _ in range(iters):
        _ = model(dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iters * 1000.0

print(f"Average inference time small: {elapsed:.3f} ms per 256x256 image")

Average inference time small: 4.618 ms per 256x256 image
