In [None]:
!gdown http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
!gdown http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
!unzip -q DIV2K_valid_HR.zip
!unzip -q DIV2K_train_HR.zip

Downloading...
From: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
To: /content/DIV2K_valid_HR.zip
100% 449M/449M [00:21<00:00, 20.8MB/s]
Downloading...
From: http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
To: /content/DIV2K_train_HR.zip
100% 3.53G/3.53G [02:34<00:00, 22.8MB/s]


In [None]:
!pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp312-cp312-linux_x86_64.whl (798.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m798.9/798.9 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m135.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp312-cp312-linux_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m116.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.1)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu

In [None]:
import os
import argparse
import random
from glob import glob
from pathlib import Path
import time
import csv
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
from tqdm import tqdm

In [None]:
!pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4


In [None]:
import lpips

In [None]:
!pip install --upgrade scikit-image




In [None]:
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import cv2
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage.restoration import denoise_bilateral
from skimage.util import random_noise

In [None]:
class DenoiseDataset(Dataset):
    """Pairs of (noisy, clean) images synthesised on the fly from a folder of images.

    Every file under ``hr_dir`` (searched recursively) whose extension is listed
    in ``extensions`` is one sample. On access the image is loaded as RGB,
    resized to ``hr_size`` x ``hr_size`` with bicubic interpolation, scaled by
    1/255, and a noisy copy is produced with ``skimage.util.random_noise`` in
    Gaussian mode using a standard deviation drawn at random from ``sigmas``.

    Args:
        hr_dir: Root directory that is globbed for images.
        hr_size: Side length of the square image returned for each sample.
        sigmas: Candidate noise standard deviations; one is chosen per access.
        extensions: File extensions (without the dot) to collect.
    """

    def __init__(
        self,
        hr_dir,
        hr_size=256,
        sigmas=(0.01,0.03),
        extensions=('png',)
      ):
        # Gather the matches of every extension, then fix a deterministic order.
        matches = [
            found
            for ext in extensions
            for found in glob(os.path.join(hr_dir, f'**/*.{ext}'), recursive=True)
        ]
        self.hr_paths = sorted(matches)
        self.hr_size = hr_size
        self.sigmas = sigmas
        self.to_tensor = transforms.ToTensor()
        self.resize_hr = transforms.Resize(
            (hr_size, hr_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )

    def __len__(self):
        """Number of image files discovered under ``hr_dir``."""
        return len(self.hr_paths)

    def __getitem__(self, idx):
        """Return a dict with the CHW float tensors ``noisy`` and ``clean``,
        the ``sigma`` that was drawn, and the source ``path``."""
        path = self.hr_paths[idx]
        resized = self.resize_hr(Image.open(path).convert('RGB'))

        clean_hwc = np.asarray(resized).astype(np.float32) / 255.0

        # The noise level is re-drawn on every access, so repeated reads of the
        # same index generally differ.
        sigma = random.choice(self.sigmas)
        noisy_hwc = random_noise(clean_hwc, mode='gaussian', var=sigma ** 2)

        def as_chw_tensor(hwc):
            # HWC numpy array -> CHW float32 tensor
            return torch.from_numpy(hwc.transpose(2, 0, 1)).float()

        clean = as_chw_tensor(clean_hwc)
        noisy = as_chw_tensor(noisy_hwc)
        return {'noisy': noisy, 'clean': clean, 'sigma': sigma, 'path': path}

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels and the second keeps
    ``out_ch``. ``padding=1`` keeps the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        stage_in = in_ch
        for _ in range(2):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            stage_in = out_ch
        # Kept as a flat Sequential under the name `net` so the parameter names
        # (net.0.*, net.1.*, net.3.*, net.4.*) are unchanged.
        self.net = nn.Sequential(*stages)

    def forward(self,x):
        """Apply both convolution stages to ``x``."""
        return self.net(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self,x):
        """Halve the spatial size of ``x``, then convolve to ``out_ch`` channels."""
        pooled = self.pool(x)
        return self.conv(pooled)

class Up(nn.Module):
    """Decoder stage: upsample, align to the skip tensor, concatenate, convolve.

    ``in_ch`` is the channel count after concatenation, i.e. the channels of
    the upsampled tensor plus those of the skip connection.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1`` by 2, fit it to ``x2``'s height/width and fuse both.

        The size difference is split between the two sides of each axis; a
        negative difference makes ``F.pad`` crop instead of pad.
        """
        upsampled = self.up(x1)
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom)
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)

class UNetDenoise(nn.Module):
    """Three-level U-Net that predicts the clean image directly.

    Args:
        in_channels: Channels of the input (and output) image.
        base_filters: Channel width of the first level; each deeper level
            doubles it.

    The output is clamped to [0, 1]. The input height and width need to be at
    least 8 so that every pooling step has something to pool.
    """

    def __init__(self, in_channels=3, base_filters=32):
        super().__init__()
        self.inc = DoubleConv(in_channels, base_filters)
        self.down1 = Down(base_filters, base_filters*2)
        self.down2 = Down(base_filters*2, base_filters*4)
        self.bot = DoubleConv(base_filters*4, base_filters*8)
        self.up2 = Up(base_filters*8 + base_filters*4, base_filters*4)
        self.up1 = Up(base_filters*4 + base_filters*2, base_filters*2)
        self.up0 = Up(base_filters*2 + base_filters, base_filters)
        self.final = nn.Conv2d(base_filters, in_channels, kernel_size=1)

    def forward(self, x):
        """Denoise ``x`` and return a tensor of the same shape in [0, 1]."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # The bottleneck has to run one level below x3. Without this pooling
        # step `up2` doubles a tensor that already has x3's size and Up's
        # negative padding then centre-crops it, so the bottleneck features no
        # longer line up spatially with the skip connection.
        # Parameter names and shapes are unchanged, but weights trained without
        # this step saw a different bottleneck and should be retrained.
        xb = self.bot(F.max_pool2d(x3, 2))

        x = self.up2(xb, x3)
        x = self.up1(x, x2)
        x = self.up0(x, x1)
        out = self.final(x)
        return torch.clamp(out, 0.0, 1.0)

In [None]:
class UNetDenoiseV2(nn.Module):
    """Four-level residual U-Net: the network predicts the noise, which is
    subtracted from the input.

    Args:
        in_channels: Channels of the input (and output) image.
        base_filters: Channel width of the first level; each deeper level
            doubles it.

    The output is clamped to [0, 1]. The input height and width need to be at
    least 16 so that every pooling step has something to pool.
    """

    def __init__(self, in_channels=3, base_filters=32):
        super().__init__()
        self.inc = DoubleConv(in_channels, base_filters)
        self.down1 = Down(base_filters, base_filters*2)
        self.down2 = Down(base_filters*2, base_filters*4)
        self.down3 = Down(base_filters*4, base_filters*8)

        self.bot = DoubleConv(base_filters*8, base_filters*16)

        self.up3 = Up(base_filters*16 + base_filters*8, base_filters*8)
        self.up2 = Up(base_filters*8 + base_filters*4, base_filters*4)
        self.up1 = Up(base_filters*4 + base_filters*2, base_filters*2)
        self.up0 = Up(base_filters*2 + base_filters, base_filters)

        self.final = nn.Conv2d(base_filters, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` minus the predicted noise, clamped to [0, 1]."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # The bottleneck has to run one level below x4. Without this pooling
        # step `up3` doubles a tensor that already has x4's size and Up's
        # negative padding then centre-crops it, so the bottleneck features no
        # longer line up spatially with the skip connection.
        # Parameter names and shapes are unchanged, but weights trained without
        # this step saw a different bottleneck and should be retrained.
        xb = self.bot(F.max_pool2d(x4, 2))

        x_dec = self.up3(xb, x4)
        x_dec = self.up2(x_dec, x3)
        x_dec = self.up1(x_dec, x2)
        x_dec = self.up0(x_dec, x1)

        noise = self.final(x_dec)
        out = x - noise
        return torch.clamp(out, 0.0, 1.0)

In [None]:
def tensor_to_uint8_image(t):
    """Convert a CHW tensor with values nominally in [0, 1] to an HWC uint8 image.

    Values are scaled by 255, rounded to the nearest integer and clipped to
    [0, 255]. Rounding (rather than the truncation done by a bare
    ``astype(np.uint8)``) avoids biasing every pixel downward by up to one
    grey level.

    Args:
        t: Tensor of shape (C, H, W); it may live on any device and may
            require grad.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C) and dtype ``uint8``.
    """
    # detach() so tensors that are part of an autograd graph can be converted.
    arr = t.detach().cpu().numpy()
    arr = np.transpose(arr, (1, 2, 0))
    return np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)

def compute_psnr(hr_uint8, pred_uint8, data_range=255.0):
    """PSNR of ``pred_uint8`` against the reference ``hr_uint8``.

    Thin wrapper around skimage's ``peak_signal_noise_ratio``; ``data_range``
    is forwarded unchanged.
    """
    psnr_value = compare_psnr(hr_uint8, pred_uint8, data_range=data_range)
    return psnr_value

def compute_ssim_safe(hr_uint8, pred_uint8):
    """SSIM between two HWC images with a 7x7 window and data range 255.

    Calls skimage's ``structural_similarity`` with ``channel_axis=2`` first; if
    that raises ``TypeError`` the call is repeated with ``multichannel=True``
    instead.
    """
    shared_kwargs = dict(data_range=255.0, win_size=7)
    try:
        score = compare_ssim(hr_uint8, pred_uint8, channel_axis=2, **shared_kwargs)
    except TypeError:
        score = compare_ssim(hr_uint8, pred_uint8, multichannel=True, **shared_kwargs)
    return score

def compute_snr_db(hr_tensor, pred_tensor):
    """Signal-to-noise ratio in dB of ``pred_tensor`` relative to ``hr_tensor``.

    Signal power is the sum of squares of the reference and noise power is the
    sum of squared differences. Returns ``inf`` when the noise power is at most
    1e-12.
    """
    reference = hr_tensor.cpu().numpy()
    estimate = pred_tensor.cpu().numpy()

    residual = reference - estimate
    noise_power = np.sum(residual ** 2)
    if noise_power <= 1e-12:
        # (numerically) identical inputs: the ratio would blow up
        return float('inf')

    signal_power = np.sum(reference ** 2)
    return 10.0 * np.log10(signal_power / noise_power)

def compute_lpips(lpips_fn, hr_tensor, pred_tensor):
    """LPIPS distance between a reference image and a prediction.

    Args:
        lpips_fn: LPIPS model (or any callable taking two image batches and a
            ``normalize`` keyword).
        hr_tensor: Reference image, CHW, expected in [0, 1].
        pred_tensor: Predicted image, CHW, expected in [0, 1].

    Returns:
        The mean LPIPS value as a Python float.
    """
    # LPIPS works on inputs in [-1, 1]. The rescaling is done here, so
    # `normalize` must stay off: normalize=True makes LPIPS apply the same
    # [0, 1] -> [-1, 1] mapping again, pushing the inputs to [-3, 1].
    hr_n = hr_tensor.unsqueeze(0) * 2.0 - 1.0
    pr_n = pred_tensor.unsqueeze(0) * 2.0 - 1.0
    with torch.no_grad():
        d = lpips_fn(hr_n, pr_n, normalize=False)
    return float(d.mean().item())

In [None]:
def baseline_bilateral(
    noisy_uint8,
    sigma_color=0.05,
    sigma_spatial=15
  ):
    """Denoise an HWC uint8 image with skimage's bilateral filter.

    The image is scaled by 1/255 before filtering; the filtered result is
    scaled back by 255, clipped to [0, 255] and returned as uint8.
    ``sigma_color`` and ``sigma_spatial`` are forwarded to
    ``denoise_bilateral``.
    """
    scaled = noisy_uint8.astype(np.float32) / 255.0
    filtered = denoise_bilateral(
        scaled,
        sigma_color=sigma_color,
        sigma_spatial=sigma_spatial,
        channel_axis=-1,
    )
    return np.clip(filtered * 255.0, 0, 255).astype(np.uint8)

In [None]:
def train_one_epoch(
    model,
    loader,
    opt,
    device,
    criterion,
    epoch,
    lpips_fn=None
  ):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of batches with ``'noisy'`` and ``'clean'`` tensors.
        opt: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        criterion: Loss called as ``criterion(pred, clean)``.
        epoch: Epoch number, used only in the progress-bar label.
        lpips_fn: Optional perceptual model; when given, ``0.1 * LPIPS`` is
            added to the loss.

    Returns:
        Mean loss over the batches processed, or ``nan`` when ``loader``
        yields no batch.
    """
    model.train()
    running = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc=f"Train {epoch}")
    # Batches are counted explicitly: tqdm only refreshes `pbar.n` when it
    # redraws the bar, so using it as a divisor can skew the running mean.
    for num_batches, batch in enumerate(pbar, start=1):
        noisy = batch['noisy'].to(device)
        clean = batch['clean'].to(device)
        pred = model(noisy)
        loss = criterion(pred, clean)
        if lpips_fn is not None:
            # LPIPS expects inputs in [-1, 1]
            lp = lpips_fn((pred*2-1), (clean*2-1)).mean()
            loss = loss + 0.1 * lp
        opt.zero_grad()
        loss.backward()
        opt.step()
        running += loss.item()
        pbar.set_postfix(loss = running / num_batches)
    if num_batches == 0:
        # Nothing was trained; avoid dividing by zero.
        return float('nan')
    return running / num_batches

def evaluate(
    model,
    loader,
    device,
    lpips_fn=None,
  ):
    """Score ``model`` on every sample produced by ``loader``.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of batches with ``'noisy'``, ``'clean'`` and
            ``'path'`` entries and, optionally, ``'sigma'``.
        device: Device the tensors are moved to.
        lpips_fn: Optional LPIPS model; when ``None`` the LPIPS entry of every
            record is ``None``.

    Returns:
        A list with one dict per sample holding ``path``, ``sigma``,
        ``psnr_pred``, ``ssim_pred``, ``snr_pred_db`` and ``lpips_pred``.
        PSNR and SSIM are computed on uint8 images; SNR and LPIPS on the float
        tensors.
    """
    model.eval()
    results = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval"):
            noisy = batch['noisy'].to(device)
            clean = batch['clean'].to(device)
            paths = batch['path']
            # Looked up once per batch instead of once per sample.
            sigmas = batch['sigma'] if 'sigma' in batch else None

            pred = model(noisy)

            for i in range(noisy.size(0)):
                clean_uint8 = tensor_to_uint8_image(clean[i])
                pred_uint8 = tensor_to_uint8_image(pred[i])

                lp_pred = None
                if lpips_fn is not None:
                    lp_pred = compute_lpips(lpips_fn, clean[i], pred[i])

                results.append({
                    'path': paths[i],
                    'sigma': float(sigmas[i]) if sigmas is not None else None,
                    'psnr_pred': compute_psnr(clean_uint8, pred_uint8),
                    'ssim_pred': compute_ssim_safe(clean_uint8, pred_uint8),
                    'snr_pred_db': compute_snr_db(clean[i], pred[i]),
                    'lpips_pred': lp_pred
                })
    return results

In [None]:
def collate_fn(batch):
    """Merge dataset samples into one batch dict.

    ``noisy`` and ``clean`` are stacked along a new leading dimension,
    ``sigma`` becomes a float32 tensor and ``path`` stays a plain list.
    """
    noisy_items, clean_items, sigma_items, path_items = [], [], [], []
    for sample in batch:
        noisy_items.append(sample['noisy'])
        clean_items.append(sample['clean'])
        sigma_items.append(sample['sigma'])
        path_items.append(sample['path'])

    return {
        'noisy': torch.stack(noisy_items, dim=0),
        'clean': torch.stack(clean_items, dim=0),
        'sigma': torch.tensor(sigma_items, dtype=torch.float32),
        'path': path_items,
    }

def save_results_csv(results, out_csv='results_denoise.csv'):
    """Write per-image metrics plus a trailing ``MEAN`` row to ``out_csv``.

    ``results`` is a list of per-sample dicts. The summary row averages each
    metric column; the SNR columns ignore infinite values, and optional
    columns that are absent from ``results`` get ``None`` in the summary row.
    """
    frame = pd.DataFrame(results)
    available = frame.columns

    def plain_mean(column):
        return frame[column].mean()

    def finite_mean(column):
        # infinities would otherwise dominate the average
        return frame[column].replace([np.inf, -np.inf], np.nan).mean()

    def if_present(column, reducer):
        return reducer(column) if column in available else None

    summary = {'path': 'MEAN'}
    summary['sigma'] = if_present('sigma', plain_mean)
    summary['psnr_pred'] = plain_mean('psnr_pred')
    summary['ssim_pred'] = plain_mean('ssim_pred')
    summary['snr_pred_db'] = finite_mean('snr_pred_db')
    summary['lpips_pred'] = if_present('lpips_pred', plain_mean)
    summary['psnr_baseline'] = if_present('psnr_baseline', plain_mean)
    summary['ssim_baseline'] = if_present('ssim_baseline', plain_mean)
    summary['snr_baseline_db'] = if_present('snr_baseline_db', finite_mean)
    summary['lpips_baseline'] = if_present('lpips_baseline', plain_mean)

    with_summary = pd.concat([frame, pd.DataFrame([summary])], ignore_index=True)
    with_summary.to_csv(out_csv, index=False)
    print(f"Saved results to {out_csv}")

In [None]:
def parse_args(args):
    """Parse the training options.

    Args:
        args: List of command-line style strings, or ``None`` to read
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the recognised options. Unrecognised
        arguments are reported on stdout and otherwise ignored, so the
        function does not fail on arguments it does not define.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--train_dir', default='./train', help='folder with training images')
    p.add_argument('--val_dir', default='./val', help='folder with validation images')
    p.add_argument('--hr_size', type=int, default=256, help='side length images are resized to')
    p.add_argument('--epochs', type=int, default=30, help='number of the last epoch to train')
    p.add_argument('--batch_size', type=int, default=8, help='training batch size')
    p.add_argument('--lr', type=float, default=1e-4, help='optimizer learning rate')
    p.add_argument('--base_filters', type=int, default=32, help='channel width of the first U-Net level')
    p.add_argument('--num_workers', type=int, default=4, help='DataLoader worker processes')
    p.add_argument('--resume', default='', help='checkpoint to resume from (empty: start fresh)')
    p.add_argument('--out_csv', default='results_denoise_v3.csv', help='base name of the per-epoch metrics CSV')

    # parse_known_args(None) already falls back to sys.argv[1:].
    parsed, unknown = p.parse_known_args(args)

    if unknown:
        print("Ignored unknown arguments:", unknown)
    return parsed

In [None]:
def main(arg1):
    """Train the denoiser, evaluating and checkpointing after every epoch.

    Args:
        arg1: Argument list forwarded to ``parse_args`` (``None`` reads
            ``sys.argv``).

    Side effects per epoch: writes ``checkpoint_{epoch}.pth``, writes a
    metrics CSV named after ``--out_csv`` with an ``_epoch{epoch}`` suffix, and
    overwrites ``best_denoise.pth`` whenever the mean validation PSNR improves.
    """
    args = parse_args(arg1)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    train_ds = DenoiseDataset(args.train_dir, hr_size=args.hr_size, sigmas=(0.01, 0.03))
    val_ds = DenoiseDataset(args.val_dir, hr_size=args.hr_size, sigmas=(0.01, 0.03))

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)

    # UNetDenoise(in_channels=3, base_filters=args.base_filters) is the
    # shallower, non-residual alternative.
    model = UNetDenoiseV2(in_channels=3, base_filters=args.base_filters).to(device)

    # The same LPIPS model serves as a loss term and as an evaluation metric.
    lpips_fn = lpips.LPIPS(net='alex').to(device)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.L1Loss()

    start_epoch = 1
    if args.resume and os.path.exists(args.resume):
        # Checkpoints written below contain only state dicts and an int, so
        # the restricted unpickler is sufficient and safer than the default.
        chk = torch.load(args.resume, map_location=device, weights_only=True)
        model.load_state_dict(chk['model'])
        opt.load_state_dict(chk['opt'])
        start_epoch = chk.get('epoch', 1) + 1
        print(f"Resumed from {args.resume}")

    # Keep the directory part of --out_csv: using only `.stem` would silently
    # redirect every per-epoch CSV into the current working directory.
    csv_base = Path(args.out_csv)
    csv_base.parent.mkdir(parents=True, exist_ok=True)

    # NOTE: the best score is not stored in the checkpoints, so after --resume
    # the first evaluated epoch always overwrites best_denoise.pth.
    best = None
    for epoch in range(start_epoch, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, opt, device, criterion, epoch, lpips_fn=lpips_fn)
        print(f"Epoch {epoch} train loss: {train_loss:.6f}")

        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch}, f'checkpoint_{epoch}.pth')

        results = evaluate(model, val_loader, device, lpips_fn=lpips_fn)
        out_csv = str(csv_base.parent / f"{csv_base.stem}_epoch{epoch}.csv")
        save_results_csv(results, out_csv=out_csv)

        mean_psnr = float(pd.DataFrame(results)['psnr_pred'].mean())
        if best is None or mean_psnr > best:
            best = mean_psnr
            torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch}, 'best_denoise.pth')
            print(f"New best model saved (epoch {epoch}) mean PSNR {mean_psnr:.4f}")

    print("Training finished. Best mean PSNR:", best)

In [None]:
# Launch training on the DIV2K train/validation folders for 25 epochs with a
# batch size of 4; every option not listed here keeps the default declared in
# parse_args.
main([
    '--train_dir', '/content/DIV2K_train_HR',
    '--val_dir', '/content/DIV2K_valid_HR',
    '--epochs', '25',
    '--batch_size', '4'
])

Device: cuda
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


Train 1: 100%|██████████| 200/200 [01:28<00:00,  2.27it/s, loss=0.0508]


Epoch 1 train loss: 0.050810


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.71it/s]


Saved results to results_denoise_v2_epoch1.csv
New best model saved (epoch 1) mean PSNR 32.6396


Train 2: 100%|██████████| 200/200 [01:27<00:00,  2.29it/s, loss=0.0246]


Epoch 2 train loss: 0.024551


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.72it/s]


Saved results to results_denoise_v2_epoch2.csv


Train 3: 100%|██████████| 200/200 [01:26<00:00,  2.33it/s, loss=0.0236]


Epoch 3 train loss: 0.023637


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.75it/s]


Saved results to results_denoise_v2_epoch3.csv


Train 4: 100%|██████████| 200/200 [01:27<00:00,  2.29it/s, loss=0.0221]


Epoch 4 train loss: 0.022120


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.75it/s]


Saved results to results_denoise_v2_epoch4.csv
New best model saved (epoch 4) mean PSNR 34.2740


Train 5: 100%|██████████| 200/200 [01:25<00:00,  2.34it/s, loss=0.0208]


Epoch 5 train loss: 0.020808


Eval: 100%|██████████| 100/100 [00:13<00:00,  7.35it/s]


Saved results to results_denoise_v2_epoch5.csv


Train 6: 100%|██████████| 200/200 [01:26<00:00,  2.32it/s, loss=0.0215]


Epoch 6 train loss: 0.021503


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.78it/s]


Saved results to results_denoise_v2_epoch6.csv


Train 7: 100%|██████████| 200/200 [01:26<00:00,  2.31it/s, loss=0.0213]


Epoch 7 train loss: 0.021277


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.91it/s]


Saved results to results_denoise_v2_epoch7.csv


Train 8: 100%|██████████| 200/200 [01:24<00:00,  2.36it/s, loss=0.0195]


Epoch 8 train loss: 0.019453


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.70it/s]


Saved results to results_denoise_v2_epoch8.csv
New best model saved (epoch 8) mean PSNR 35.0697


Train 9: 100%|██████████| 200/200 [01:26<00:00,  2.31it/s, loss=0.0188]


Epoch 9 train loss: 0.018818


Eval: 100%|██████████| 100/100 [00:12<00:00,  7.76it/s]


Saved results to results_denoise_v2_epoch9.csv


Train 10: 100%|██████████| 200/200 [01:26<00:00,  2.30it/s, loss=0.0185]


Epoch 10 train loss: 0.018531


Eval: 100%|██████████| 100/100 [00:12<00:00,  8.04it/s]


Saved results to results_denoise_v2_epoch10.csv


Train 11:   8%|▊         | 16/200 [00:08<01:32,  1.99it/s, loss=0.0185]


KeyboardInterrupt: 

In [None]:
def run_baseline_denoising(noisy_uint8, sigma=None, method='bilateral'):
    """Apply a classical baseline to an HWC uint8 image.

    Args:
        noisy_uint8: Image to denoise.
        sigma: Used as the bilateral filter's ``sigma_color``; 0.05 is used
            when it is ``None``.
        method: ``'bilateral'`` to filter, ``'none'`` to return a copy of the
            input.

    Returns:
        Dict with the ``method`` used and the result under ``denoised_uint8``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'none':
        result = noisy_uint8.copy()
    elif method == 'bilateral':
        colour_sigma = 0.05 if sigma is None else float(sigma)
        result = baseline_bilateral(noisy_uint8, sigma_color=colour_sigma, sigma_spatial=15)
    else:
        raise ValueError(f"Unknown baseline method: {method}")

    return {'method': method, 'denoised_uint8': result}

In [None]:
def evaluate_solution(
        model,
        dataset,
        device,
        filename,
        num_samples=8,
        reference_implementation=None,
    ):
    """Evaluate ``model`` on a random subset of ``dataset`` and save a comparison figure.

    Args:
        model: Denoising network; switched to eval mode. If it returns a tuple
            or list, the first element is used.
        dataset: Indexable dataset whose items are dicts with ``'noisy'``,
            ``'clean'``, ``'path'`` and ``'sigma'``.
        device: Device used for the model and for LPIPS.
        filename: Path the figure is saved to.
        num_samples: Maximum number of images to draw at random.
        reference_implementation: Only tested against ``None``. When it is
            set, the baseline returned by ``run_baseline_denoising`` is
            evaluated and plotted as well; the object itself is not called.

    Raises:
        ValueError: If no sample can be selected.

    Prints the selected paths and the mean PSNR / SSIM / LPIPS of the model
    (and of the baseline when enabled).
    """
    model.eval()

    n_total = len(dataset)
    n = min(num_samples, n_total)
    if n < 1:
        raise ValueError(
            f"Need at least one sample to evaluate (num_samples={num_samples}, dataset size={n_total})"
        )

    selected_indices = random.sample(range(n_total), n)
    samples = [dataset[i] for i in selected_indices]

    noisy = torch.stack([s['noisy'] for s in samples])
    clean = torch.stack([s['clean'] for s in samples])
    paths = [s['path'] for s in samples]
    sigmas = [s['sigma'] for s in samples]

    print("Images used for evaluation:")
    for p in paths:
        print(" -", p)

    lpips_model = lpips.LPIPS(net='alex').to(device)

    with torch.no_grad():
        output = model(noisy.to(device))
        if isinstance(output, (tuple, list)):
            output = output[0]
        output_cpu = output.cpu()

    ref_output_cpu = None
    if reference_implementation is not None:
        ref_outputs = []
        for i in range(n):
            noisy_uint8 = tensor_to_uint8_image(noisy[i])
            baseline_result = run_baseline_denoising(noisy_uint8, float(sigmas[i]))
            bas_uint8 = baseline_result['denoised_uint8']
            # Stays on the CPU: the metrics and the plot below call .numpy().
            ref_outputs.append(
                torch.from_numpy(bas_uint8.astype(np.float32) / 255.0).permute(2, 0, 1)
            )
        ref_output_cpu = torch.stack(ref_outputs)

    def compute_metrics(denoised, clean):
        """Mean PSNR / SSIM / LPIPS of CPU batch ``denoised`` against ``clean``."""
        psnr_list, ssim_list, lpips_list = [], [], []
        for i in range(n):
            denoised_np = denoised[i].permute(1, 2, 0).numpy()
            clean_np = clean[i].permute(1, 2, 0).numpy()
            psnr_list.append(compare_psnr(clean_np, denoised_np, data_range=1.0))
            ssim_list.append(compare_ssim(clean_np, denoised_np, data_range=1.0, channel_axis=2))
            # The LPIPS weights live on `device`, so its inputs must be moved
            # there too; they are rescaled from [0, 1] to [-1, 1] first.
            with torch.no_grad():
                lpips_val = lpips_model(
                    (denoised[i].unsqueeze(0) * 2 - 1).to(device),
                    (clean[i].unsqueeze(0) * 2 - 1).to(device),
                ).item()
            lpips_list.append(lpips_val)
        return {
            "PSNR": np.mean(psnr_list),
            "SSIM": np.mean(ssim_list),
            "LPIPS": np.mean(lpips_list),
        }

    model_metrics = compute_metrics(output_cpu, clean)
    ref_metrics = compute_metrics(ref_output_cpu, clean) if ref_output_cpu is not None else None

    print("\nMetrics (average over {} samples):".format(n))
    print("Main Model:")
    print(" - PSNR: {:.2f}".format(model_metrics["PSNR"]))
    print(" - SSIM: {:.4f}".format(model_metrics["SSIM"]))
    print(" - LPIPS: {:.4f}".format(model_metrics["LPIPS"]))

    if ref_metrics is not None:
        print("\nReference Implementation:")
        print(" - PSNR: {:.2f}".format(ref_metrics["PSNR"]))
        print(" - SSIM: {:.4f}".format(ref_metrics["SSIM"]))
        print(" - LPIPS: {:.4f}".format(ref_metrics["LPIPS"]))

    # One figure row per image kind, one column per sample.
    if ref_output_cpu is not None:
        row_labels = ['Noisy', 'Denoised (Model)', 'Reference', 'Clean']
        row_tensors = [noisy, output_cpu, ref_output_cpu, clean]
    else:
        row_labels = ['Noisy', 'Denoised (Model)', 'Clean']
        row_tensors = [noisy, output_cpu, clean]
    num_rows = len(row_labels)

    # squeeze=False always yields a 2-D axes array, also when n == 1.
    fig, axes = plt.subplots(
        num_rows, n, figsize=(max(4, n * 3), num_rows * 3), squeeze=False
    )

    for col in range(n):
        for row, img_tensor in enumerate(row_tensors):
            img = img_tensor[col].permute(1, 2, 0).numpy()
            img = np.clip(img, 0, 1)
            axes[row, col].imshow(img)
            axes[row, col].axis('off')
        axes[0, col].set_title(f"Sample {col + 1}", fontsize=10)

    # Reserve a strip on the left for the vertical row labels.
    left_margin = 0.08
    plt.tight_layout(rect=[left_margin, 0, 1, 1])

    for row, label in enumerate(row_labels):
        y = 1.0 - (row + 0.5) / num_rows
        fig.text(
            left_margin / 2.0,
            y,
            label,
            va='center',
            ha='center',
            fontsize=14,
            fontweight='bold',
            rotation='vertical',
            rotation_mode='anchor'
        )

    fig.savefig(filename, bbox_inches='tight', dpi=200)
    plt.close(fig)

In [None]:
# Reload the best checkpoint written during training and compare the network
# with the bilateral baseline on a few random validation images.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint holds only state dicts and an epoch number, so the restricted
# unpickler (weights_only=True) is enough and avoids running arbitrary pickled
# code from the file.
checkpoint = torch.load("best_denoise.pth", map_location=device, weights_only=True)

# Built with the default base_filters; this has to match the value used for
# training or load_state_dict fails on mismatched shapes.
denoiser = UNetDenoiseV2()
denoiser.to(device)
denoiser.load_state_dict(checkpoint['model'])

val_dataset = DenoiseDataset('./DIV2K_valid_HR', 256, (0.01, 0.03))

evaluate_solution(
    denoiser,
    val_dataset,
    device,
    'denoised.png',
    num_samples=8,
    reference_implementation=baseline_bilateral
)


  checkpoint = torch.load("best_denoise.pth", map_location=device)


Images used for evaluation:
 - ./DIV2K_valid_HR/0831.png
 - ./DIV2K_valid_HR/0898.png
 - ./DIV2K_valid_HR/0812.png
 - ./DIV2K_valid_HR/0861.png
 - ./DIV2K_valid_HR/0873.png
 - ./DIV2K_valid_HR/0895.png
 - ./DIV2K_valid_HR/0896.png
 - ./DIV2K_valid_HR/0857.png
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!