# BayesCap SRGAN Experiment Suite
This notebook mirrors the new `scripts/bayescap_pipeline.py` CLI while keeping a linear, reproducible workflow for running every experiment combination (ImageNet vs. DIV2K checkpoints, parameter sweeps, and qualitative visualization) on the four benchmark super-resolution datasets.

## 1. Environment & Imports
Import the shared pipeline helpers plus the network definitions so every subsequent section (dataset setup, baseline evaluation, DIV2K training, parameter sweeps, qualitative viz) can call into the same code paths as the CLI.

In [1]:
import json
import sys
import random
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

project_root = Path.cwd().parent
scripts_dir = project_root / 'scripts'
if str(scripts_dir) not in sys.path:
    sys.path.insert(0, str(scripts_dir))

from bayescap_pipeline import (
    organize_sr_benchmarks,
    evaluate_sr_metrics,
    run_parameter_sweep,
    compare_experiments,
    load_experiment_registry,
    download_div2k,
    extract_div2k,
    build_div2k_loaders,
    pretrain_srgan_on_div2k,
    finetune_bayescap_on_div2k,
    load_model_pair,
)
from ds import ImgDset
from networks_SRGAN import Generator, BayesCap
from utils import img_ssim

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

data_root = project_root / 'data'
benchmark_root = data_root / 'SR' / 'val'
sr_factor = 4
baseline_generator_ckpt = project_root / 'ckpt' / 'srgan-ImageNet-bc347d67.pth'
baseline_bayescap_ckpt = project_root / 'ckpt' / 'BayesCap_SRGAN_best.pth'

  dtype=torch.cuda.FloatTensor(),


## 2. Benchmark Dataset Organization
Organize Set5/Set14/BSD100/Urban100 HR images into `data/SR/val/<dataset>/original` so loaders and evaluation code can enumerate them consistently. Skip gracefully if the structure already exists.

> **Scripted workflow**  The heavy-lifting utilities now live in `../scripts/bayescap_pipeline.py`.  You can call the same routines from this notebook (imported below) or invoke them via `python ../scripts/bayescap_pipeline.py --help` for a pure command-line workflow.

In [2]:
datasets = ['Set5', 'Set14', 'BSD100', 'Urban100']
organize_sr_benchmarks(data_root, datasets, sr_factor=sr_factor)
val_root = benchmark_root

[Set5] copied 0 HR images to /data/oe23/BayesCap/data/SR/val/Set5/original
  (nothing to do; files already organized)
[Set14] copied 0 HR images to /data/oe23/BayesCap/data/SR/val/Set14/original
  (nothing to do; files already organized)
[BSD100] copied 0 HR images to /data/oe23/BayesCap/data/SR/val/BSD100/original
  (nothing to do; files already organized)
[Urban100] copied 0 HR images to /data/oe23/BayesCap/data/SR/val/Urban100/original
  (nothing to do; files already organized)
Done organizing SR benchmark datasets.


In [3]:
val_dataset_name = 'Set5'
val_dset = ImgDset(
    dataroot=str(benchmark_root / val_dataset_name / 'original'),
    image_size=(256, 256),
    upscale_factor=sr_factor,
    mode='val',
)
val_loader = DataLoader(
    val_dset,
    batch_size=1,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
)
print(f"Loaded {len(val_dset)} validation crops from {val_dataset_name}")

Loaded 5 validation crops from Set5


### Validation loader for qualitative inspection
A lightweight loader on Set5 powers the final visualization cell so we can peek at SR outputs, uncertainty maps, and reconstruction errors.

## 3. Baseline (ImageNet-pretrained) Checkpoints
Load the released SRGAN generator and BayesCap head so we always have a reference point before adapting to DIV2K.

In [4]:
NetG, NetC = load_model_pair(baseline_generator_ckpt, baseline_bayescap_ckpt, device)
netg_params = sum(p.numel() for p in NetG.parameters())
netc_params = sum(p.numel() for p in NetC.parameters())
print(f"Generator params: {netg_params:,}")
print(f"BayesCap params: {netc_params:,}")

Generator params: 1,547,350
BayesCap params: 2,589,658


In [5]:
benchmark_datasets = ['Set5', 'Set14', 'BSD100', 'Urban100']
baseline_metrics_df = evaluate_sr_metrics(
    NetG,
    NetC,
    benchmark_datasets,
    dataset_root=benchmark_root,
    image_size=(256, 256),
    upscale_factor=sr_factor,
    batch_size=1,
    num_bins=30,
    device=str(device),
    dtype=dtype,
)
baseline_metrics_df

Unnamed: 0_level_0,PSNR,SSIM,UCE,C.Coeff,Images
Dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
BSD100,24.286604,0.634499,0.066838,0.337899,100
Set14,23.612141,0.663415,0.072535,0.274057,14
Set5,26.784134,0.794881,0.068288,0.347054,5
Urban100,21.594371,0.633385,15.270981,0.249663,100


## 5. Experiment Registry (ImageNet vs. DIV2K)
Use the JSON registry in `../scripts/experiment_registry.example.json` (copy/edit as needed) to enumerate every generator/BayesCap checkpoint pair you want to evaluate. The helper below loops through each entry, evaluates it on all four datasets, and produces a single comparison table.

In [6]:
experiment_registry_path = project_root / 'scripts' / 'experiment_registry.example.json'
experiments = load_experiment_registry(experiment_registry_path)
comparison_df = compare_experiments(
    experiments,
    benchmark_root,
    ['Set5', 'Set14', 'BSD100', 'Urban100'],
    image_size=(256, 256),
    upscale_factor=sr_factor,
    batch_size=1,
    num_bins=30,
    device=device,
    dtype=dtype,
)
comparison_df

Unnamed: 0_level_0,Unnamed: 1_level_0,PSNR,SSIM,UCE,C.Coeff,Images,Description
Experiment,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
div2kG_imagenetC,BSD100,23.367758,0.480238,0.007322,0.349033,100,DIV2K generator + ImageNet BayesCap
div2kG_imagenetC,Set14,21.127391,0.426991,0.010483,0.356087,14,DIV2K generator + ImageNet BayesCap
div2kG_imagenetC,Set5,22.468839,0.467435,0.016124,0.393581,5,DIV2K generator + ImageNet BayesCap
div2kG_imagenetC,Urban100,19.735719,0.386285,0.014973,0.291797,100,DIV2K generator + ImageNet BayesCap
div2k_pretrained,BSD100,23.367758,0.480238,0.093616,0.065566,100,DIV2K fine-tuned weights
div2k_pretrained,Set14,21.127391,0.426991,0.062091,,14,DIV2K fine-tuned weights
div2k_pretrained,Set5,22.468839,0.467435,0.061113,0.133503,5,DIV2K fine-tuned weights
div2k_pretrained,Urban100,19.735719,0.386285,5.814424,,100,DIV2K fine-tuned weights
imagenetG_div2kC,BSD100,24.286604,0.634499,0.367766,,100,ImageNet generator + DIV2K BayesCap
imagenetG_div2kC,Set14,23.612141,0.663415,0.088027,,14,ImageNet generator + DIV2K BayesCap


## 6. Parameter sweep across datasets and scoring knobs
Explore sensitivity to crop size, batch size, dataset grouping, and UCE binning. Results are concatenated into a single DataFrame and the manifest is logged to `data/SR/val/bayescap_sweep_manifest.json`. Grayscale benchmarks (e.g., Set14) are now auto-expanded to RGB during collation so mixed-channel batches no longer crash the sweep.

In [7]:
sweep_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sweep_dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
dataset_options = [
    ['Set5', 'Set14', 'BSD100'],
    ['Set5', 'Set14', 'BSD100', 'Urban100']
]
image_sizes = [(128, 128), (256, 256), (320, 320)]
batch_sizes = [1, 2]
num_bins = [15, 20, 30, 40]
manifest_path = benchmark_root / 'bayescap_sweep_manifest.json'

sweep_df = run_parameter_sweep(
    NetG,
    NetC,
    benchmark_root,
    dataset_options,
    image_sizes,
    batch_sizes,
    num_bins,
    upscale_factor=sr_factor,
    device=sweep_device,
    dtype=sweep_dtype,
    manifest_path=manifest_path,
)
sweep_df

[cfg_01] datasets=['Set5', 'Set14', 'BSD100'], image_size=(128, 128), batch=1, num_bins=15
[cfg_02] datasets=['Set5', 'Set14', 'BSD100', 'Urban100'], image_size=(128, 128), batch=1, num_bins=15
[cfg_02] datasets=['Set5', 'Set14', 'BSD100', 'Urban100'], image_size=(128, 128), batch=1, num_bins=15
[cfg_03] datasets=['Set5', 'Set14', 'BSD100'], image_size=(128, 128), batch=1, num_bins=20
[cfg_03] datasets=['Set5', 'Set14', 'BSD100'], image_size=(128, 128), batch=1, num_bins=20
[cfg_04] datasets=['Set5', 'Set14', 'BSD100', 'Urban100'], image_size=(128, 128), batch=1, num_bins=20
[cfg_04] datasets=['Set5', 'Set14', 'BSD100', 'Urban100'], image_size=(128, 128), batch=1, num_bins=20
[cfg_05] datasets=['Set5', 'Set14', 'BSD100'], image_size=(128, 128), batch=1, num_bins=30
[cfg_05] datasets=['Set5', 'Set14', 'BSD100'], image_size=(128, 128), batch=1, num_bins=30
[cfg_06] datasets=['Set5', 'Set14', 'BSD100', 'Urban100'], image_size=(128, 128), batch=1, num_bins=30
[cfg_06] datasets=['Set5', 'Se

Unnamed: 0_level_0,Unnamed: 1_level_0,PSNR,SSIM,UCE,C.Coeff,Images,ImageSize,BatchSize,NumBins,Datasets
Config,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
cfg_01,BSD100,23.509387,0.612648,0.076594,0.284839,100,"(128, 128)",1,15,"Set5, Set14, BSD100"
cfg_01,Set14,22.376287,0.626960,0.084797,0.255654,14,"(128, 128)",1,15,"Set5, Set14, BSD100"
cfg_01,Set5,24.194658,0.740373,0.097095,0.200861,5,"(128, 128)",1,15,"Set5, Set14, BSD100"
cfg_02,BSD100,23.509387,0.612648,0.076594,0.284839,100,"(128, 128)",1,15,"Set5, Set14, BSD100, Urban100"
cfg_02,Set14,22.376287,0.626960,0.084797,0.255654,14,"(128, 128)",1,15,"Set5, Set14, BSD100, Urban100"
...,...,...,...,...,...,...,...,...,...,...
cfg_47,Set5,28.024978,0.815751,0.060392,0.362039,5,"(320, 320)",2,40,"Set5, Set14, BSD100"
cfg_48,BSD100,23.982155,0.623020,0.063944,0.333563,100,"(320, 320)",2,40,"Set5, Set14, BSD100, Urban100"
cfg_48,Set14,24.242203,0.681729,0.068491,0.271168,14,"(320, 320)",2,40,"Set5, Set14, BSD100, Urban100"
cfg_48,Set5,28.024978,0.815751,0.060392,0.362039,5,"(320, 320)",2,40,"Set5, Set14, BSD100, Urban100"


In [9]:
sweep_df.to_csv(project_root / 'results' / 'bayescap_sr_parameter_sweep_results.csv', index=False)

## 7. DIV2K pretraining & fine-tuning workflow
Download DIV2K HR images, build loaders, optionally fine-tune the SRGAN generator, and (optionally) adapt BayesCap using the newly trained backbone. Toggle the boolean flags when you're ready to launch the heavier jobs; the training helpers now expose `tqdm` progress bars so you can track batches directly in the notebook.

### 7.1 Download & extract DIV2K

In [10]:
div2k_url = 'http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip'
div2k_zip_path = download_div2k(data_root, div2k_url)
div2k_hr_dir = extract_div2k(div2k_zip_path)
print(f'DIV2K archive: {div2k_zip_path}')
print(f'DIV2K HR directory: {div2k_hr_dir}')

[skip] /data/oe23/BayesCap/data/DIV2K/DIV2K_train_HR.zip already present
[skip] /data/oe23/BayesCap/data/DIV2K/DIV2K_train_HR already extracted
DIV2K archive: /data/oe23/BayesCap/data/DIV2K/DIV2K_train_HR.zip
DIV2K HR directory: /data/oe23/BayesCap/data/DIV2K/DIV2K_train_HR


### 7.2 Build DIV2K loaders

In [11]:
div2k_train_loader, div2k_val_loader = build_div2k_loaders(
    div2k_hr_dir,
    sr_factor=sr_factor,
    train_crop=(128, 128),
    val_crop=(256, 256),
    train_batch=8,
    val_batch=4,
    num_workers=4,
)
print(f"DIV2K train batches: {len(div2k_train_loader)}, val batches: {len(div2k_val_loader)}")

DIV2K train batches: 100, val batches: 200


### 7.3 Optional SRGAN fine-tuning on DIV2K

In [13]:
div2k_generator_ckpt = project_root / 'ckpt' / 'srgan_DIV2K.pth'
run_div2k_training = True  # Set True to launch full DIV2K training

if run_div2k_training:
    NetG_div2k = Generator()
    pretrain_srgan_on_div2k(
        NetG_div2k,
        div2k_train_loader,
        div2k_val_loader,
        epochs=100,
        lr=5e-5,
        ckpt_path=div2k_generator_ckpt,
        device=device,
    )
else:
    print("Skipping DIV2K generator training (run_div2k_training=False)")

[DIV2K][Epoch 1/100] train:   0%|          | 0/100 [00:00<?, ?batch/s]

                                                                                            

[DIV2K][Epoch 1/100] train loss=0.4157, val PSNR=13.03 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 2/100] train loss=0.1755, val PSNR=15.29 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 3/100] train loss=0.1397, val PSNR=17.02 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 4/100] train loss=0.1215, val PSNR=17.98 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 5/100] train loss=0.1093, val PSNR=18.66 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 6/100] train loss=0.1014, val PSNR=19.21 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 7/100] train loss=0.0957, val PSNR=19.70 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 8/100] train loss=0.0897, val PSNR=20.23 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                            

[DIV2K][Epoch 9/100] train loss=0.0858, val PSNR=20.75 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 10/100] train loss=0.0822, val PSNR=20.80 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 11/100] train loss=0.0793, val PSNR=21.23 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 12/100] train loss=0.0773, val PSNR=21.51 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 13/100] train loss=0.0755, val PSNR=21.78 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 14/100] train loss=0.0727, val PSNR=22.00 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 15/100] train loss=0.0711, val PSNR=22.18 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 16/100] train loss=0.0699, val PSNR=21.93 dB


                                                                                              

[DIV2K][Epoch 17/100] train loss=0.0686, val PSNR=22.49 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 18/100] train loss=0.0675, val PSNR=22.65 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 19/100] train loss=0.0664, val PSNR=22.67 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 20/100] train loss=0.0656, val PSNR=22.78 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 21/100] train loss=0.0647, val PSNR=22.76 dB


                                                                                              

[DIV2K][Epoch 22/100] train loss=0.0638, val PSNR=22.96 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 23/100] train loss=0.0634, val PSNR=23.00 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 24/100] train loss=0.0626, val PSNR=23.09 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 25/100] train loss=0.0621, val PSNR=23.15 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 26/100] train loss=0.0615, val PSNR=23.26 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 27/100] train loss=0.0611, val PSNR=23.26 dB


                                                                                              

[DIV2K][Epoch 28/100] train loss=0.0605, val PSNR=23.31 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 29/100] train loss=0.0602, val PSNR=23.39 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 30/100] train loss=0.0595, val PSNR=23.43 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 31/100] train loss=0.0590, val PSNR=23.44 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 32/100] train loss=0.0588, val PSNR=23.57 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 33/100] train loss=0.0588, val PSNR=23.57 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 34/100] train loss=0.0581, val PSNR=23.61 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 35/100] train loss=0.0579, val PSNR=23.64 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 36/100] train loss=0.0576, val PSNR=23.64 dB


                                                                                             

[DIV2K][Epoch 37/100] train loss=0.0573, val PSNR=23.65 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 38/100] train loss=0.0570, val PSNR=23.58 dB


                                                                                             

[DIV2K][Epoch 39/100] train loss=0.0569, val PSNR=23.74 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 40/100] train loss=0.0565, val PSNR=23.75 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 41/100] train loss=0.0564, val PSNR=23.76 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 42/100] train loss=0.0562, val PSNR=23.81 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 43/100] train loss=0.0559, val PSNR=23.82 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 44/100] train loss=0.0559, val PSNR=23.79 dB


                                                                                              

[DIV2K][Epoch 45/100] train loss=0.0557, val PSNR=23.85 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 46/100] train loss=0.0552, val PSNR=23.87 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 47/100] train loss=0.0553, val PSNR=23.80 dB


                                                                                             

[DIV2K][Epoch 48/100] train loss=0.0550, val PSNR=23.85 dB


                                                                                             

[DIV2K][Epoch 49/100] train loss=0.0547, val PSNR=23.94 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 50/100] train loss=0.0548, val PSNR=23.94 dB


                                                                                             

[DIV2K][Epoch 51/100] train loss=0.0546, val PSNR=23.97 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 52/100] train loss=0.0546, val PSNR=23.96 dB


                                                                                             

[DIV2K][Epoch 53/100] train loss=0.0545, val PSNR=23.99 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 54/100] train loss=0.0542, val PSNR=23.93 dB


                                                                                             

[DIV2K][Epoch 55/100] train loss=0.0541, val PSNR=24.02 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 56/100] train loss=0.0540, val PSNR=24.03 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 57/100] train loss=0.0539, val PSNR=24.04 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 58/100] train loss=0.0538, val PSNR=24.05 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 59/100] train loss=0.0536, val PSNR=24.06 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 60/100] train loss=0.0536, val PSNR=24.05 dB


                                                                                             

[DIV2K][Epoch 61/100] train loss=0.0536, val PSNR=24.07 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 62/100] train loss=0.0534, val PSNR=24.08 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 63/100] train loss=0.0534, val PSNR=24.07 dB


                                                                                             

[DIV2K][Epoch 64/100] train loss=0.0533, val PSNR=24.10 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 65/100] train loss=0.0533, val PSNR=24.07 dB


                                                                                             

[DIV2K][Epoch 66/100] train loss=0.0533, val PSNR=24.10 dB


                                                                                             

[DIV2K][Epoch 67/100] train loss=0.0531, val PSNR=24.11 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 68/100] train loss=0.0530, val PSNR=24.10 dB


                                                                                             

[DIV2K][Epoch 69/100] train loss=0.0530, val PSNR=24.13 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 70/100] train loss=0.0529, val PSNR=24.13 dB


                                                                                             

[DIV2K][Epoch 71/100] train loss=0.0529, val PSNR=24.13 dB


                                                                                              

[DIV2K][Epoch 72/100] train loss=0.0529, val PSNR=24.15 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 73/100] train loss=0.0527, val PSNR=24.15 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 74/100] train loss=0.0527, val PSNR=24.16 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 75/100] train loss=0.0527, val PSNR=24.16 dB


                                                                                              

[DIV2K][Epoch 76/100] train loss=0.0526, val PSNR=24.16 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 77/100] train loss=0.0526, val PSNR=24.17 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 78/100] train loss=0.0525, val PSNR=24.17 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 79/100] train loss=0.0525, val PSNR=24.16 dB


                                                                                              

[DIV2K][Epoch 80/100] train loss=0.0525, val PSNR=24.18 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 81/100] train loss=0.0525, val PSNR=24.17 dB


                                                                                             

[DIV2K][Epoch 82/100] train loss=0.0524, val PSNR=24.18 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 83/100] train loss=0.0524, val PSNR=24.17 dB


                                                                                             

[DIV2K][Epoch 84/100] train loss=0.0524, val PSNR=24.19 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 85/100] train loss=0.0523, val PSNR=24.19 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 86/100] train loss=0.0523, val PSNR=24.18 dB


                                                                                             

[DIV2K][Epoch 87/100] train loss=0.0523, val PSNR=24.19 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 88/100] train loss=0.0523, val PSNR=24.19 dB


                                                                                             

[DIV2K][Epoch 89/100] train loss=0.0523, val PSNR=24.19 dB


                                                                                             

[DIV2K][Epoch 90/100] train loss=0.0522, val PSNR=24.19 dB


                                                                                             

[DIV2K][Epoch 91/100] train loss=0.0522, val PSNR=24.20 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 92/100] train loss=0.0523, val PSNR=24.20 dB


                                                                                              

[DIV2K][Epoch 93/100] train loss=0.0523, val PSNR=24.20 dB


                                                                                              

[DIV2K][Epoch 94/100] train loss=0.0523, val PSNR=24.20 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 95/100] train loss=0.0522, val PSNR=24.20 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                             

[DIV2K][Epoch 96/100] train loss=0.0522, val PSNR=24.20 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 97/100] train loss=0.0523, val PSNR=24.19 dB


                                                                                              

[DIV2K][Epoch 98/100] train loss=0.0522, val PSNR=24.20 dB
  -> saved new best checkpoint to /data/oe23/BayesCap/ckpt/srgan_DIV2K.pth


                                                                                              

[DIV2K][Epoch 99/100] train loss=0.0523, val PSNR=24.19 dB


                                                                                              

[DIV2K][Epoch 100/100] train loss=0.0522, val PSNR=24.20 dB




### 7.4 Optional BayesCap adaptation on DIV2K

In [14]:
div2k_bayescap_ckpt = project_root / 'ckpt' / 'BayesCap_SRGAN_DIV2K.pth'
run_div2k_bayescap = True  # Set True to adapt BayesCap with DIV2K checkpoints

if run_div2k_bayescap:
    if not div2k_generator_ckpt.exists():
        raise FileNotFoundError(f"DIV2K generator checkpoint missing at {div2k_generator_ckpt}")
    finetune_bayescap_on_div2k(
        div2k_generator_ckpt,
        div2k_bayescap_ckpt,
        div2k_train_loader,
        div2k_val_loader,
        epochs=50,
        init_lr=5e-5,
        device=device,
    )
else:
    print("Skipping BayesCap DIV2K fine-tuning (run_div2k_bayescap=False)")

Epoch 0: 100%|██████████| 100/100 [00:20<00:00,  4.87batch/s, loss=0.268]



Avg. loss: 0.8978620047867298
current score: 0.14035906575561966 | Last best score: -100000000.0
current score: 0.14035906575561966 | Last best score: -100000000.0


Epoch 1: 100%|██████████| 100/100 [00:19<00:00,  5.06batch/s, loss=0.064]



Avg. loss: 0.14754561822861434
current score: 0.2843219699431211 | Last best score: 0.14035906575561966
current score: 0.2843219699431211 | Last best score: 0.14035906575561966


Epoch 2: 100%|██████████| 100/100 [00:24<00:00,  4.06batch/s, loss=0.0238]



Avg. loss: 0.055114151276648045
current score: 0.3845425215456635 | Last best score: 0.2843219699431211
current score: 0.3845425215456635 | Last best score: 0.2843219699431211


Epoch 3: 100%|██████████| 100/100 [00:19<00:00,  5.08batch/s, loss=0.0195] 



Avg. loss: 0.012815122939646244
current score: 0.4273222695104778 | Last best score: 0.3845425215456635
current score: 0.4273222695104778 | Last best score: 0.3845425215456635


Epoch 4: 100%|██████████| 100/100 [00:21<00:00,  4.72batch/s, loss=0.00432] 

Avg. loss: -0.006812080815434456





current score: 0.4675762705132365 | Last best score: 0.4273222695104778


Epoch 5: 100%|██████████| 100/100 [00:18<00:00,  5.27batch/s, loss=-0.021] 



Avg. loss: -0.01934349499642849
current score: 0.48223540786653757 | Last best score: 0.4675762705132365
current score: 0.48223540786653757 | Last best score: 0.4675762705132365


Epoch 6: 100%|██████████| 100/100 [00:18<00:00,  5.35batch/s, loss=-0.0272]



Avg. loss: -0.027918005362153053
current score: 0.5108658525720239 | Last best score: 0.48223540786653757
current score: 0.5108658525720239 | Last best score: 0.48223540786653757


Epoch 7: 100%|██████████| 100/100 [00:22<00:00,  4.40batch/s, loss=-0.0328]



Avg. loss: -0.034407481253147125
current score: 0.5148385311290622 | Last best score: 0.5108658525720239
current score: 0.5148385311290622 | Last best score: 0.5108658525720239


Epoch 8: 100%|██████████| 100/100 [00:18<00:00,  5.27batch/s, loss=-0.0474]



Avg. loss: -0.03952380947768688
current score: 0.5212487713247538 | Last best score: 0.5148385311290622
current score: 0.5212487713247538 | Last best score: 0.5148385311290622


Epoch 9: 100%|██████████| 100/100 [00:24<00:00,  4.04batch/s, loss=-0.0375]

Avg. loss: -0.04292162876576185





current score: 0.5305551955476403 | Last best score: 0.5212487713247538


Epoch 10: 100%|██████████| 100/100 [00:19<00:00,  5.24batch/s, loss=-0.0513]



Avg. loss: -0.047335860058665274
current score: 0.5440629214793443 | Last best score: 0.5305551955476403
current score: 0.5440629214793443 | Last best score: 0.5305551955476403


Epoch 11: 100%|██████████| 100/100 [00:19<00:00,  5.09batch/s, loss=-0.0515]



Avg. loss: -0.04916819728910923
current score: 0.5562065154314041 | Last best score: 0.5440629214793443
current score: 0.5562065154314041 | Last best score: 0.5440629214793443


Epoch 12: 100%|██████████| 100/100 [00:20<00:00,  4.90batch/s, loss=-0.0442]



Avg. loss: -0.04990063924342394
current score: 0.552295536249876 | Last best score: 0.5562065154314041
current score: 0.552295536249876 | Last best score: 0.5562065154314041


Epoch 13: 100%|██████████| 100/100 [00:19<00:00,  5.04batch/s, loss=-0.0583]
Epoch 13: 100%|██████████| 100/100 [00:19<00:00,  5.04batch/s, loss=-0.0583]


Avg. loss: -0.05312271829694509
current score: 0.5642969712987542 | Last best score: 0.5562065154314041
current score: 0.5642969712987542 | Last best score: 0.5562065154314041


Epoch 14: 100%|██████████| 100/100 [00:19<00:00,  5.11batch/s, loss=-0.0602]



Avg. loss: -0.05350533246994019
current score: 0.565108755864203 | Last best score: 0.5642969712987542
current score: 0.565108755864203 | Last best score: 0.5642969712987542


Epoch 15: 100%|██████████| 100/100 [00:19<00:00,  5.20batch/s, loss=-0.0464]



Avg. loss: -0.056444117836654185
current score: 0.566836151778698 | Last best score: 0.565108755864203
current score: 0.566836151778698 | Last best score: 0.565108755864203


Epoch 16: 100%|██████████| 100/100 [00:23<00:00,  4.24batch/s, loss=-0.0694]



Avg. loss: -0.05747743662446737
current score: 0.5645318302698433 | Last best score: 0.566836151778698
current score: 0.5645318302698433 | Last best score: 0.566836151778698


Epoch 17: 100%|██████████| 100/100 [00:20<00:00,  4.92batch/s, loss=-0.063]



Avg. loss: -0.061513791531324385
current score: 0.5774367508664727 | Last best score: 0.566836151778698
current score: 0.5774367508664727 | Last best score: 0.566836151778698


Epoch 18: 100%|██████████| 100/100 [00:20<00:00,  4.85batch/s, loss=-0.0454]



Avg. loss: -0.05968246031552553
current score: 0.5823274639248848 | Last best score: 0.5774367508664727
current score: 0.5823274639248848 | Last best score: 0.5774367508664727


Epoch 19: 100%|██████████| 100/100 [00:20<00:00,  4.76batch/s, loss=-0.0447]



Avg. loss: -0.05972648050636053
current score: 0.5821414452046156 | Last best score: 0.5823274639248848
current score: 0.5821414452046156 | Last best score: 0.5823274639248848


Epoch 20: 100%|██████████| 100/100 [00:19<00:00,  5.17batch/s, loss=-0.071]



Avg. loss: -0.06427139725536107
current score: 0.5845634558796883 | Last best score: 0.5823274639248848
current score: 0.5845634558796883 | Last best score: 0.5823274639248848


Epoch 21: 100%|██████████| 100/100 [00:19<00:00,  5.25batch/s, loss=-0.0753]



Avg. loss: -0.06361483838409185
current score: 0.5894536402449012 | Last best score: 0.5845634558796883
current score: 0.5894536402449012 | Last best score: 0.5845634558796883


Epoch 22: 100%|██████████| 100/100 [00:19<00:00,  5.08batch/s, loss=-0.0706]



Avg. loss: -0.06559251621365547
current score: 0.5898199523240328 | Last best score: 0.5894536402449012
current score: 0.5898199523240328 | Last best score: 0.5894536402449012


Epoch 23: 100%|██████████| 100/100 [00:21<00:00,  4.60batch/s, loss=-0.0707]



Avg. loss: -0.06589717738330364
current score: 0.582545041386038 | Last best score: 0.5898199523240328
current score: 0.582545041386038 | Last best score: 0.5898199523240328


Epoch 24: 100%|██████████| 100/100 [00:20<00:00,  4.97batch/s, loss=-0.073]
Epoch 24: 100%|██████████| 100/100 [00:20<00:00,  4.97batch/s, loss=-0.073]


Avg. loss: -0.06747364643961191
current score: 0.5940340758860111 | Last best score: 0.5898199523240328
current score: 0.5940340758860111 | Last best score: 0.5898199523240328


Epoch 25: 100%|██████████| 100/100 [00:19<00:00,  5.19batch/s, loss=-0.077]



Avg. loss: -0.06886816140264272
current score: 0.5927378167584538 | Last best score: 0.5940340758860111
current score: 0.5927378167584538 | Last best score: 0.5940340758860111


Epoch 26: 100%|██████████| 100/100 [00:19<00:00,  5.06batch/s, loss=-0.0775]



Avg. loss: -0.0702896561473608
current score: 0.6005431950464845 | Last best score: 0.5940340758860111
current score: 0.6005431950464845 | Last best score: 0.5940340758860111


Epoch 27: 100%|██████████| 100/100 [00:19<00:00,  5.22batch/s, loss=-0.0631]
Epoch 27: 100%|██████████| 100/100 [00:19<00:00,  5.22batch/s, loss=-0.0631]


Avg. loss: -0.0681889334693551
current score: 0.5977174665033818 | Last best score: 0.6005431950464845
current score: 0.5977174665033818 | Last best score: 0.6005431950464845


Epoch 28: 100%|██████████| 100/100 [00:18<00:00,  5.46batch/s, loss=-0.0803]
Epoch 28: 100%|██████████| 100/100 [00:18<00:00,  5.46batch/s, loss=-0.0803]


Avg. loss: -0.07106866039335728
current score: 0.6018683501332999 | Last best score: 0.6005431950464845
current score: 0.6018683501332999 | Last best score: 0.6005431950464845


Epoch 29: 100%|██████████| 100/100 [00:19<00:00,  5.19batch/s, loss=-0.0881]



Avg. loss: -0.07033423509448766
current score: 0.6022527013346552 | Last best score: 0.6018683501332999
current score: 0.6022527013346552 | Last best score: 0.6018683501332999


Epoch 30: 100%|██████████| 100/100 [00:22<00:00,  4.55batch/s, loss=-0.0641]

Avg. loss: -0.07156873177736997





current score: 0.6041249547526241 | Last best score: 0.6022527013346552


Epoch 31: 100%|██████████| 100/100 [00:20<00:00,  4.95batch/s, loss=-0.0515]



Avg. loss: -0.06839698918163777
current score: 0.5979144307225943 | Last best score: 0.6041249547526241
current score: 0.5979144307225943 | Last best score: 0.6041249547526241


Epoch 32: 100%|██████████| 100/100 [00:20<00:00,  4.90batch/s, loss=-0.0809]



Avg. loss: -0.07413590386509895
current score: 0.6073627404123545 | Last best score: 0.6041249547526241
current score: 0.6073627404123545 | Last best score: 0.6041249547526241


Epoch 33: 100%|██████████| 100/100 [00:19<00:00,  5.19batch/s, loss=-0.0664]



Avg. loss: -0.0740063664317131
current score: 0.6027900956198573 | Last best score: 0.6073627404123545
current score: 0.6027900956198573 | Last best score: 0.6073627404123545


Epoch 34: 100%|██████████| 100/100 [00:18<00:00,  5.30batch/s, loss=-0.0762]



Avg. loss: -0.07308924648910761
current score: 0.6070566699281335 | Last best score: 0.6073627404123545
current score: 0.6070566699281335 | Last best score: 0.6073627404123545


Epoch 35: 100%|██████████| 100/100 [00:18<00:00,  5.33batch/s, loss=-0.0682]



Avg. loss: -0.07373499043285847
current score: 0.6026199899986386 | Last best score: 0.6073627404123545
current score: 0.6026199899986386 | Last best score: 0.6073627404123545


Epoch 36: 100%|██████████| 100/100 [00:19<00:00,  5.05batch/s, loss=-0.0794]



Avg. loss: -0.07534410238265991
current score: 0.607339261136949 | Last best score: 0.6073627404123545
current score: 0.607339261136949 | Last best score: 0.6073627404123545


Epoch 37: 100%|██████████| 100/100 [00:22<00:00,  4.44batch/s, loss=-0.0783]
Epoch 37: 100%|██████████| 100/100 [00:22<00:00,  4.44batch/s, loss=-0.0783]


Avg. loss: -0.07586916133761407
current score: 0.6084771163761615 | Last best score: 0.6073627404123545
current score: 0.6084771163761615 | Last best score: 0.6073627404123545


Epoch 38: 100%|██████████| 100/100 [00:20<00:00,  4.92batch/s, loss=-0.0773]



Avg. loss: -0.07544692810624838
current score: 0.6116893562301994 | Last best score: 0.6084771163761615
current score: 0.6116893562301994 | Last best score: 0.6084771163761615


Epoch 39: 100%|██████████| 100/100 [00:18<00:00,  5.34batch/s, loss=-0.0837]



Avg. loss: -0.0769886427745223
current score: 0.6143934966623783 | Last best score: 0.6116893562301994
current score: 0.6143934966623783 | Last best score: 0.6116893562301994


Epoch 40: 100%|██████████| 100/100 [00:19<00:00,  5.09batch/s, loss=-0.0736]



Avg. loss: -0.0772609806805849
current score: 0.6137050611525774 | Last best score: 0.6143934966623783
current score: 0.6137050611525774 | Last best score: 0.6143934966623783


Epoch 41: 100%|██████████| 100/100 [00:19<00:00,  5.26batch/s, loss=-0.0735]



Avg. loss: -0.07873340897262096
current score: 0.6144172477722168 | Last best score: 0.6143934966623783
current score: 0.6144172477722168 | Last best score: 0.6143934966623783


Epoch 42: 100%|██████████| 100/100 [00:18<00:00,  5.28batch/s, loss=-0.068]



Avg. loss: -0.07728926859796047
current score: 0.6113726418092846 | Last best score: 0.6144172477722168
current score: 0.6113726418092846 | Last best score: 0.6144172477722168


Epoch 43: 100%|██████████| 100/100 [00:20<00:00,  4.93batch/s, loss=-0.0748]
Epoch 43: 100%|██████████| 100/100 [00:20<00:00,  4.93batch/s, loss=-0.0748]


Avg. loss: -0.07803525235503912
current score: 0.6080376373231411 | Last best score: 0.6144172477722168
current score: 0.6080376373231411 | Last best score: 0.6144172477722168


Epoch 44: 100%|██████████| 100/100 [00:21<00:00,  4.57batch/s, loss=-0.0795]



Avg. loss: -0.07971204526722431
current score: 0.6166969083249569 | Last best score: 0.6144172477722168
current score: 0.6166969083249569 | Last best score: 0.6144172477722168


Epoch 45: 100%|██████████| 100/100 [00:26<00:00,  3.80batch/s, loss=-0.0813]



Avg. loss: -0.07866195943206548
current score: 0.6204213196784258 | Last best score: 0.6166969083249569
current score: 0.6204213196784258 | Last best score: 0.6166969083249569


Epoch 46: 100%|██████████| 100/100 [00:19<00:00,  5.02batch/s, loss=-0.0844]



Avg. loss: -0.08060105603188276
current score: 0.6217589554935694 | Last best score: 0.6204213196784258
current score: 0.6217589554935694 | Last best score: 0.6204213196784258


Epoch 47: 100%|██████████| 100/100 [00:19<00:00,  5.24batch/s, loss=-0.078]



Avg. loss: -0.07920422993600368
current score: 0.6103375287726521 | Last best score: 0.6217589554935694
current score: 0.6103375287726521 | Last best score: 0.6217589554935694


Epoch 48: 100%|██████████| 100/100 [00:21<00:00,  4.67batch/s, loss=-0.0815]
Epoch 48: 100%|██████████| 100/100 [00:21<00:00,  4.67batch/s, loss=-0.0815]


Avg. loss: -0.07886756937950849
current score: 0.6147329454869032 | Last best score: 0.6217589554935694
current score: 0.6147329454869032 | Last best score: 0.6217589554935694


Epoch 49: 100%|██████████| 100/100 [00:19<00:00,  5.03batch/s, loss=-0.0664]
Epoch 49: 100%|██████████| 100/100 [00:19<00:00,  5.03batch/s, loss=-0.0664]


Avg. loss: -0.08106000002473593
current score: 0.6249246019124984 | Last best score: 0.6217589554935694
current score: 0.6249246019124984 | Last best score: 0.6217589554935694


## 8. Visualize SR outputs & uncertainty
Inspect SR reconstructions, alpha/beta maps, and per-pixel error/uncertainty for a few samples from the validation loader. This provides qualitative intuition that complements the quantitative tables above.

In [None]:
viz_device = device
viz_dtype = dtype
num_imgs = 0
mean_ssim = 0.0
for idx, batch in enumerate(val_loader):
    print(f'Image {idx} ...')
    xLR, xHR = batch[0].to(viz_device), batch[1].to(viz_device)
    xLR, xHR = xLR.type(viz_dtype), xHR.type(viz_dtype)
    with torch.no_grad():
        xSR = NetG(xLR)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
    n_batch = xSRC_mu.shape[0]
    for j in range(n_batch):
        num_imgs += 1
        mean_ssim += img_ssim(xSRC_mu[j], xHR[j]).item()

    plt.figure(figsize=(30, 10))
    plt.subplot(1, 4, 1)
    plt.imshow(xLR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1, 4, 2)
    plt.imshow(xSR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1, 4, 3)
    a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
    plt.imshow(a_map.transpose(0,2).transpose(0,1), cmap='inferno')
    plt.clim(0, 0.1)
    plt.axis('off')

    plt.subplot(1, 4, 4)
    error_map = torch.mean(torch.pow(torch.abs(xSR[0]-xHR[0]),2), dim=0).to('cpu').data
    plt.imshow(error_map, cmap='jet')
    plt.clim(0,0.01)
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

    plt.figure(figsize=(30,10))
    plt.subplot(1,4,1)
    plt.imshow(xHR[0].to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1,4,2)
    plt.imshow((0.6*xSRC_mu[0]+0.4*xSR[0]).to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
    plt.axis('off')

    plt.subplot(1,4,3)
    b_map = xSRC_beta[0].to('cpu').data
    plt.imshow(b_map.transpose(0,2).transpose(0,1), cmap='cividis')
    plt.clim(0.45, 0.75)
    plt.axis('off')

    plt.subplot(1,4,4)
    u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
    plt.imshow(u_map.transpose(0,2).transpose(0,1), cmap='hot')
    plt.clim(0,0.15)
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

print(f'Mean SSIM over {num_imgs} crops: {mean_ssim / max(num_imgs, 1):.4f}')