## Load the model and the config

In [1]:
import os
import torch

def save_ckp(checkpoint, checkpoint_dir, suffix=""):
    old_path = f"{checkpoint_dir}/checkpoint_{suffix}_{checkpoint['epoch']-1}.pt"
    if os.path.exists(old_path):
        os.remove(old_path)
    f_path = f"{checkpoint_dir}/checkpoint_{suffix}_{checkpoint['epoch']}.pt"
    torch.save(checkpoint, f_path)

def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

In [2]:
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from generator import RRDBNet
from discriminator import DiscriminatorForVGG, DiscriminatorForVGG2
import json

# cache to read model weights from
target_dir = "cache-northern-vortex-43"
config_path = target_dir + "/metadata.json"
with open(config_path) as conf_file:
    config = json.load(conf_file)

# create models
G = RRDBNet(in_nc=config["in_nc"], out_nc=config["out_nc"], nf=config["nf"], nb=config["nb"], gc=config["gc"])
D = DiscriminatorForVGG(in_channels=config["in_channels"], out_channels=config["out_channels"], channels=config["channels"])

# load weights
G, _, _ = load_ckp(target_dir + f"/checkpoint_generator_{config['epochs']}.pt", G, None)
D, _, _ = load_ckp(target_dir + f"/checkpoint_discriminator_{config['epochs']}.pt", D, None)

# move models to device (cpu, cuda or mps)
G.to(device=device)
D.to(device=device);

## Prepare the data

In [4]:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Resize((128, 128), antialias=None), 
])

In [5]:
from custom_datasets import PairWiseImages, PairWiseImagesRGBE

### RGBE format
if config["RGBE"]:
    pair = PairWiseImagesRGBE("LDR-HDR-pair_Dataset-master/LDR_exposure_0/", 
                        "LDR-HDR-pair_Dataset-master/HDR/", 
                        transform=train_transform, device=device)
    
### Original dataset RGB format
else:
    pair = PairWiseImages("LDR-HDR-pair_Dataset-master/LDR_exposure_0/", 
                        "LDR-HDR-pair_Dataset-master/HDR/", 
                        transform=train_transform, device=device)


In [6]:
import torch
from torch.utils.data import Subset

if config["nb_images"] > len(pair):
    raise ValueError("Number of images to train is greater than the dataset size")

elif config["nb_images"] == -1 or config["nb_images"] == len(pair):
    print("Training on the whole dataset")
    pair_subset = pair
else:
    print("Training on a subset of the dataset")
    indices = torch.arange(config["nb_images"])
    pair_subset = Subset(pair, indices)

Training on the whole dataset


In [7]:
import torch

length = len(pair_subset)
test_length = int(0.2 * length)

train_data, valid_data = torch.utils.data.random_split(pair_subset, [length - test_length, test_length])

In [8]:
BATCH_SIZE = 2

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=True)
len(train_dataloader), len(valid_dataloader)

(71, 18)

## Evaluate the model

### Load a batch from the validation set

In [9]:
# load a batch from the validation dataset
ldr, hdr = next(iter(valid_dataloader))

# create a real, generated and random HDR image
real_hdr = hdr
fake_hdr = G(ldr)
random_hdr = torch.rand_like(hdr)

# run discriminator on the images
print("Real HDR images (values should be positive)")
print(D(hdr).mean(0, keepdim=True).detach())
print()

print("Random HDR images (values should be negative)")
print(D(random_hdr).mean(0, keepdim=True).detach())
print()

print("Generated HDR images (values should be positive for a good generator, and negative for a good discriminator)")
print(D(fake_hdr).mean(0, keepdim=True).detach())
print()

Real HDR images (values should be positive)
tensor([[7.7343, 7.9246, 7.8803]], device='mps:0')

Random HDR images (values should be negative)
tensor([[-25.0027, -24.6086, -25.0165]], device='mps:0')

Generated HDR images (values should be positive for a good generator, and negative for a good discriminator)
tensor([[4.9257, 5.1651, 5.1004]], device='mps:0')



In [10]:
import numpy as np
from utils import preprocess_tensor_to_array
import cv2 as cv

def save_img(img_tensor: torch.Tensor, f_name: str, hdr=False):
    img_arr = preprocess_tensor_to_array(img_tensor)
    if hdr: # HDR: image should be in the range [0, 4]
        np.clip(img_arr, 0, 4)
        try:
            cv.imwrite(f"{target_dir}/image_samples/{f_name}.hdr", img_arr)
            return True
        except:
            print("Error while saving image")
            return False
    else: # LDR: image should be in the range [0, 255]
        np.clip(img_arr, 0, 1)
        img_arr *= 255
        try:
            cv.imwrite(f"{target_dir}/image_samples/{f_name}.jpg", img_arr)
            return True
        except:
            print("Error while saving image")
            return False

In [38]:
import torch
from skimage.metrics import structural_similarity #type:ignore

def psnr(orig_hdr: torch.Tensor, pred_hdr: torch.Tensor):
    mse = torch.mean((orig_hdr - pred_hdr) ** 2)
    return 20 * torch.log10(4 / torch.sqrt(mse))

def ssim(orig_hdr: torch.Tensor, pred_hdr: torch.Tensor):
    orig_hdr = preprocess_tensor_to_array(orig_hdr)
    pred_hdr = preprocess_tensor_to_array(pred_hdr)
    return structural_similarity(orig_hdr.flatten(), pred_hdr.flatten(), data_range=4)

In [39]:
psnr(real_hdr[0], fake_hdr[0]).item()

22.314022064208984

In [40]:
ssim(real_hdr[0].detach(), fake_hdr[0].detach())

0.7643911497236188

In [22]:
small_psnr_avg = 0
big_psnr_avg = 0

for _ in range(100):
    small1 = torch.rand(3, 128, 128, dtype=torch.float32) * 4
    small2 = torch.rand(3, 128, 128, dtype=torch.float32) * 4
    big1 = torch.rand(3, 1024, 1024, dtype=torch.float32) * 4
    big2 = torch.rand(3, 1024, 1024, dtype=torch.float32) * 4
    small_psnr_avg += psnr(small1, small2)
    big_psnr_avg += psnr(big1, big2)

small_psnr_avg /= 100
big_psnr_avg /= 100
print(f"Small PSNR: {small_psnr_avg}")
print(f"Big PSNR: {big_psnr_avg}")

Small PSNR: 7.7792510986328125
Big PSNR: 7.781223297119141


In [12]:
from custom_datasets import PairWiseImages, PairWiseImagesRGBE
    
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Resize((128, 128), antialias=None), 
])

### Original dataset RGB format
pair = PairWiseImages("LDR-HDR-pair_Dataset-master/LDR_exposure_0/", 
                    "LDR-HDR-pair_Dataset-master/HDR/", 
                    transform=train_transform, device="cpu")


import torch

length = len(pair)
test_length = int(0.2 * length)

torch.manual_seed(config["seed"])
train_data, valid_data = torch.utils.data.random_split(pair, [length - test_length, test_length])

KeyError: 'seed'

In [24]:
train_data[0]

(tensor([[[0.7598, 0.8706, 0.8324,  ..., 0.9990, 0.9569, 0.9461],
          [0.8686, 0.8657, 0.9049,  ..., 0.9922, 1.0000, 0.9775],
          [0.8657, 0.8902, 0.8873,  ..., 0.9961, 0.9598, 0.9069],
          ...,
          [0.5363, 0.7020, 0.5980,  ..., 0.6627, 0.6451, 0.7127],
          [0.6265, 0.6333, 0.5794,  ..., 0.6363, 0.6882, 0.6422],
          [0.4863, 0.4392, 0.4069,  ..., 0.6265, 0.6794, 0.7147]],
 
         [[0.6912, 0.8873, 0.8520,  ..., 0.9833, 0.9127, 0.9284],
          [0.8412, 0.8853, 0.9127,  ..., 0.9725, 0.9824, 0.9657],
          [0.8578, 0.9098, 0.8951,  ..., 0.9843, 0.9373, 0.8755],
          ...,
          [0.5147, 0.6765, 0.5706,  ..., 0.6510, 0.6333, 0.6814],
          [0.6059, 0.6020, 0.5559,  ..., 0.6206, 0.6745, 0.6304],
          [0.4608, 0.4137, 0.3873,  ..., 0.6098, 0.6598, 0.7029]],
 
         [[0.6520, 0.8843, 0.8637,  ..., 0.9716, 0.9049, 0.9157],
          [0.7863, 0.8814, 0.9127,  ..., 0.9608, 0.9667, 0.9520],
          [0.8225, 0.9059, 0.8951,  ...,