## Load the model and the config

In [None]:
import os
import torch

def save_ckp(checkpoint, checkpoint_dir, suffix=""):
    """Save ``checkpoint`` for its epoch and remove the previous epoch's file.

    The checkpoint is stored as ``checkpoint_{suffix}_{epoch}.pt`` inside
    ``checkpoint_dir``, where ``epoch`` is ``checkpoint['epoch']``. The file
    for ``epoch - 1`` (same suffix) is deleted only after the new file has
    been written successfully, so a failed save never leaves the directory
    without a usable checkpoint.

    Args:
        checkpoint: Mapping passed to ``torch.save``; must contain an
            ``'epoch'`` entry that supports ``- 1``.
        checkpoint_dir: Existing directory to write into.
        suffix: Tag inserted in the file name (e.g. a model name).

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise; in that case the
        previous checkpoint is left untouched.
    """
    epoch = checkpoint['epoch']
    old_path = f"{checkpoint_dir}/checkpoint_{suffix}_{epoch - 1}.pt"
    new_path = f"{checkpoint_dir}/checkpoint_{suffix}_{epoch}.pt"
    tmp_path = f"{new_path}.tmp"

    try:
        # Write to a temporary name first so an interrupted save cannot leave
        # a truncated file under the final checkpoint name.
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, new_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Only now is it safe to drop the previous epoch's checkpoint.
    if os.path.exists(old_path):
        os.remove(old_path)

def load_ckp(checkpoint_fpath, model, optimizer, map_location="cpu"):
    """Restore a model (and optionally an optimizer) from a checkpoint file.

    The checkpoint is expected to be a mapping with the keys
    ``'state_dict'``, ``'epoch'`` and, when ``optimizer`` is given,
    ``'optimizer'``.

    Args:
        checkpoint_fpath: Path of the file to read with ``torch.load``.
        model: Module whose ``load_state_dict`` receives ``'state_dict'``.
        optimizer: Optimizer to restore, or ``None`` to skip optimizer state.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a CUDA machine can still be read on a
            machine without CUDA; ``load_state_dict`` then copies the values
            into the model's existing parameters.

    Returns:
        Tuple ``(model, optimizer, epoch)`` with the same ``model`` and
        ``optimizer`` objects that were passed in.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

In [None]:
# Pick the compute backend: Apple MPS when available, otherwise CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from generator import RRDBNet
from discriminator import DiscriminatorForVGG
import json

# cache to read model weights from
target_dir = "cache-northern-vortex-43"
config_path = target_dir + "/metadata.json"
# explicit encoding so the JSON is read the same way on every platform
with open(config_path, encoding="utf-8") as conf_file:
    config = json.load(conf_file)

# create models with the architecture hyper-parameters stored in the config
G = RRDBNet(in_nc=config["in_nc"], out_nc=config["out_nc"], nf=config["nf"], nb=config["nb"], gc=config["gc"])
D = DiscriminatorForVGG(in_channels=config["in_channels"], out_channels=config["out_channels"], channels=config["channels"])

# load weights from the checkpoints numbered config['epochs'] (no optimizer state needed)
G, _, _ = load_ckp(target_dir + f"/checkpoint_generator_{config['epochs']}.pt", G, None)
D, _, _ = load_ckp(target_dir + f"/checkpoint_discriminator_{config['epochs']}.pt", D, None)

# move models to device (cpu, cuda or mps)
G.to(device=device)
D.to(device=device)

# The models are only evaluated below, so switch them to inference mode:
# any dropout / batch-norm layers they contain then behave deterministically
# instead of using per-batch statistics.
G.eval()
D.eval();

## Prepare the data

In [None]:
from torchvision import transforms

# Preprocessing applied to every image: convert it to a tensor first,
# then resize the tensor to 128x128.
_to_tensor = transforms.ToTensor()
_resize = transforms.Resize((128, 128), antialias=None)
train_transform = transforms.Compose([_to_tensor, _resize])

In [None]:
from custom_datasets import PairWiseImages, PairWiseImagesRGBE

# Both dataset flavours read the same LDR/HDR folders; config["RGBE"] selects
# the RGBE variant, otherwise the original RGB dataset class is used.
dataset_cls = PairWiseImagesRGBE if config["RGBE"] else PairWiseImages
pair = dataset_cls(
    "LDR-HDR-pair_Dataset-master/LDR_exposure_0/",
    "LDR-HDR-pair_Dataset-master/HDR/",
    transform=train_transform,
    device=device,
)



In [None]:
import torch
from torch.utils.data import Subset

# Select how many image pairs to use, as configured by config["nb_images"]:
#   -1 or len(pair)  -> the whole dataset
#   1 .. len(pair)-1 -> the first nb_images samples
nb_images = config["nb_images"]

if nb_images > len(pair):
    raise ValueError("Number of images to train is greater than the dataset size")

if nb_images == -1 or nb_images == len(pair):
    print("Training on the whole dataset")
    pair_subset = pair
elif nb_images <= 0:
    # 0 would give an empty subset and values below -1 make torch.arange
    # fail with an unrelated message, so reject them explicitly here.
    raise ValueError("nb_images must be -1 (whole dataset) or a positive number of images")
else:
    print("Training on a subset of the dataset")
    indices = torch.arange(nb_images)
    pair_subset = Subset(pair, indices)

In [None]:
import torch

# Hold out 20% of the selected samples (rounded down) for validation and keep
# the rest for training.
# NOTE(review): no seed is fixed here, so the split is drawn anew on every
# run — confirm it matches the split the checkpoints were trained with.
length = len(pair_subset)
test_length = int(length * 0.2)
split_sizes = [length - test_length, test_length]

train_data, valid_data = torch.utils.data.random_split(pair_subset, split_sizes)

In [None]:
BATCH_SIZE = config["batch_size"]

# Both loaders share the same settings: configured batch size, shuffled order.
loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
train_dataloader = torch.utils.data.DataLoader(train_data, **loader_options)
valid_dataloader = torch.utils.data.DataLoader(valid_data, **loader_options)

# number of batches in each loader (displayed as the cell result)
(len(train_dataloader), len(valid_dataloader))

## Evaluate the model

### Load a batch from the validation set

In [None]:
# load a batch from the validation dataset
ldr, hdr = next(iter(valid_dataloader))

# This cell only inspects model outputs, so run it without autograd: no graph
# is recorded for the generator/discriminator forward passes.
with torch.no_grad():
    # create a real, generated and random HDR image
    real_hdr = hdr
    fake_hdr = G(ldr)
    random_hdr = torch.rand_like(hdr)

    # run discriminator on the images and average its output over the batch
    real_score = D(real_hdr).mean(0, keepdim=True)
    random_score = D(random_hdr).mean(0, keepdim=True)
    fake_score = D(fake_hdr).mean(0, keepdim=True)

print("Real HDR images (values should be positive)")
print(real_score)
print()

print("Random HDR images (values should be negative)")
print(random_score)
print()

print("Generated HDR images (values should be positive for a good generator, and negative for a good discriminator)")
print(fake_score)
print()

In [None]:
import numpy as np
from utils import preprocess_tensor_to_array
import cv2 as cv

def save_img(img_tensor: torch.Tensor, f_name: str, hdr=False):
    """Write an image tensor to ``{target_dir}/image_samples/``.

    The tensor is converted with ``preprocess_tensor_to_array`` and written
    with ``cv.imwrite`` as ``{f_name}.hdr`` (``hdr=True``) or ``{f_name}.jpg``
    (``hdr=False``).

    Args:
        img_tensor: Image to save.
        f_name: File name without extension.
        hdr: When true, clip values to [0, 4] and save as Radiance ``.hdr``;
            otherwise clip to [0, 1], scale to [0, 255] and save as ``.jpg``.

    Returns:
        ``True`` if OpenCV reports the file was written, ``False`` otherwise
        (an error message is printed in that case).
    """
    img_arr = preprocess_tensor_to_array(img_tensor)
    if hdr:  # HDR: image should be in the range [0, 4]
        # np.clip returns a new array; the result must be kept.
        img_arr = np.clip(img_arr, 0, 4)
        extension = "hdr"
    else:  # LDR: image should be in the range [0, 255]
        # Build a new array instead of scaling in place, so the array returned
        # by preprocess_tensor_to_array (which may share memory with the
        # tensor) is not modified.
        img_arr = np.clip(img_arr, 0, 1) * 255
        extension = "jpg"

    out_path = f"{target_dir}/image_samples/{f_name}.{extension}"
    try:
        # imwrite signals most failures (e.g. missing directory) through its
        # boolean return value rather than by raising.
        written = bool(cv.imwrite(out_path, img_arr))
    except Exception:
        # Saving is best-effort: report and return False instead of raising.
        written = False

    if not written:
        print("Error while saving image")
    return written

In [1]:
def psnr(orig_hdr: torch.Tensor, pred_hdr: torch.Tensor, max_val: float = 4.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Computed as ``20 * log10(max_val / sqrt(MSE))`` where MSE is the mean
    squared difference over all elements.

    Args:
        orig_hdr: Reference tensor.
        pred_hdr: Predicted tensor, broadcast-compatible with ``orig_hdr``.
        max_val: Peak signal value. Defaults to 4.0, the value that was
            previously hard-coded.

    Returns:
        A 0-dim tensor holding the PSNR; ``inf`` when the inputs are identical.
    """
    mse = torch.mean((orig_hdr - pred_hdr) ** 2)
    return 20 * torch.log10(max_val / torch.sqrt(mse))

def ssim(orig_hdr: torch.Tensor, pred_hdr: torch.Tensor, data_range: float = 4.0):
    """Return the mean structural similarity (SSIM) between two images.

    OpenCV's Python module has no ``cv.SSIM`` function, so the index is
    computed here directly: local means, variances and covariance are
    estimated with an 11x11 Gaussian window (sigma 1.5) and combined with the
    usual stabilising constants ``(0.01 * L) ** 2`` and ``(0.03 * L) ** 2``,
    where ``L`` is ``data_range``.

    Args:
        orig_hdr: Reference image tensor.
        pred_hdr: Predicted image tensor.
        data_range: Dynamic range ``L`` of the pixel values. Defaults to 4.0.
            NOTE(review): chosen to match the peak value used by ``psnr`` —
            confirm.

    Returns:
        The SSIM map averaged over all pixels (and channels) as a float.
    """
    # assumes preprocess_tensor_to_array returns one H x W or H x W x C image
    # (at most 4 channels, as cv.GaussianBlur requires) — TODO confirm
    orig_arr = np.ascontiguousarray(preprocess_tensor_to_array(orig_hdr), dtype=np.float64)
    pred_arr = np.ascontiguousarray(preprocess_tensor_to_array(pred_hdr), dtype=np.float64)

    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    def _local_mean(arr):
        # Gaussian-weighted local average, applied to each channel separately.
        return cv.GaussianBlur(arr, (11, 11), 1.5)

    mu_orig = _local_mean(orig_arr)
    mu_pred = _local_mean(pred_arr)
    var_orig = _local_mean(orig_arr * orig_arr) - mu_orig ** 2
    var_pred = _local_mean(pred_arr * pred_arr) - mu_pred ** 2
    cov = _local_mean(orig_arr * pred_arr) - mu_orig * mu_pred

    ssim_map = ((2 * mu_orig * mu_pred + c1) * (2 * cov + c2)) / (
        (mu_orig ** 2 + mu_pred ** 2 + c1) * (var_orig + var_pred + c2)
    )
    return float(ssim_map.mean())

NameError: name 'torch' is not defined