# DeLDRify - ESRGAN applied to single-track LDR to HDR image conversion

## Import and initialize the models

In [1]:
import torch

# Pick the fastest available backend: Apple-silicon MPS first, then CUDA,
# falling back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'mps'

In [21]:
from generator import RRDBNet
from discriminator import DiscriminatorForVGG

# Generator: 3-channel LDR in, 4-channel output; the discriminator consumes the
# same 4-channel layout (presumably RGBE, matching the dataset cell below).
# nn.Module.to() moves parameters in place and returns the module itself.
G = RRDBNet(in_nc=3, out_nc=4, nf=64, nb=8, gc=32).to(device=device)
D = DiscriminatorForVGG(in_channels=4, out_channels=4, channels=64).to(device=device)

In [22]:
# Report model sizes in millions of parameters.
for label, model in (("G", G), ("D", D)):
    n_params_m = sum(p.numel() for p in model.parameters()) / 1_000_000
    print(f"Number of param ({label}):", n_params_m, "M")

Number of param (G): 5.79642 M
Number of param (D): 14.50028 M


## Prepare the data

In [23]:
from torchvision import transforms

# Convert to a CHW tensor, then resize to the fixed 128x128 training
# resolution (presumably applied by the dataset to both images of a pair).
# antialias=None keeps torchvision's legacy behaviour, which for tensor input
# means no antialiasing.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(128, 128), antialias=None),
    ]
)

In [24]:
from custom_datasets import PairWiseImages, PairWiseImagesRGBE

# Paired (LDR exposure 0, HDR) images, with targets in RGBE format.
# For the plain-RGB variant, swap in PairWiseImages with the same two
# directories and transform=train_transform (no device argument).
pair = PairWiseImagesRGBE(
    "LDR-HDR-pair_Dataset-master/LDR_exposure_0/",
    "LDR-HDR-pair_Dataset-master/HDR/",
    transform=train_transform,
    device=device,
)

In [25]:
import torch
from torch.utils.data import Subset

# Work on a small fixed subset (the first 40 pairs) for quick experiments.
# Plain Python ints are Subset's documented index type; torch.arange would hand
# the dataset's __getitem__ 0-d tensors instead. min() keeps this valid when
# the dataset holds fewer than 40 pairs.
indices = list(range(min(40, len(pair))))
pair_40 = Subset(pair, indices)

In [26]:
import torch

length = len(pair_40)
# 80/20 split; `test_length` is the size of the held-out (validation) part.
test_length = int(0.2 * length)

# Seeded generator so the split -- and therefore every downstream metric -- is
# reproducible across kernel restarts (random_split otherwise draws from the
# global RNG).
SPLIT_SEED = 42
train, valid = torch.utils.data.random_split(
    pair_40,
    [length - test_length, test_length],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [27]:
BATCH_SIZE = 4

# Shuffle only the training data; validation order carries no information, and
# a fixed order keeps evaluation deterministic and comparable between epochs.
train_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
valid_data_loader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)

## Train the models

In [28]:
import wandb

# Authenticate (cached credentials or an interactive prompt), then open a run
# in the DeLDRify project; `run.name` later names the checkpoint directory.
wandb.login()
run = wandb.init(project="DeLDRify")



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
epoch,▁▃▆█
loss_D,█▁▁▁
loss_D_fake,█▁▁▁
loss_D_real,█▁▁▁
loss_G,█▄▂▁

0,1
epoch,3.0
loss_D,9e-05
loss_D_fake,6e-05
loss_D_real,0.00012
loss_G,0.12597


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01116782361111114, max=1.0)…

In [29]:
# L1 reconstruction loss for the generator's pixel term, and BCE on raw logits
# for the relativistic adversarial terms in the training loop.
criterion_pixel, criterion_GAN = (
    torch.nn.L1Loss().to(device),
    torch.nn.BCEWithLogitsLoss().to(device),
)

In [30]:
from datetime import datetime

# Checkpoints for this run go to a directory keyed by the wandb run name,
# e.g. './cache-elated-rain-5'.
results_dir = "".join(("./cache-", run.name))
results_dir

'./cache-elated-rain-5'

In [31]:
from tqdm.notebook import tqdm
import os
import numpy as np  # not used in this cell, but later inference cells rely on `np`

epochs = 100

optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.9, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.9, 0.999))

# Weight of the adversarial term relative to the L1 pixel term in the G loss.
loss_scaling_factor = 1e-3

# Idempotent on re-run and creates missing parent directories.
os.makedirs(results_dir, exist_ok=True)

# Loop-invariant: the label width. Assumes D returns logits shaped
# (batch, D.out_channels) -- TODO confirm against DiscriminatorForVGG.
D_output_shape = D.out_channels
n_batches = len(train_dataloader)
assert n_batches > 0, "training split is empty"

G.train()
D.train()

for epoch in tqdm(range(epochs)):
    total_loss_G = 0
    total_loss_D = 0
    total_loss_D_real = 0
    total_loss_D_fake = 0

    for ldr, hdr in tqdm(train_dataloader, leave=False):
        # No-op if the dataset already yields tensors on `device`.
        ldr = ldr.to(device)
        hdr = hdr.to(device)

        # BCE targets, built directly on the device. Deliberately NOT named
        # `valid`/`fake`: `valid` is the validation split from the split cell.
        real_label = torch.ones((ldr.size(0), D_output_shape), dtype=torch.float32, device=device)
        fake_label = torch.zeros((ldr.size(0), D_output_shape), dtype=torch.float32, device=device)

        # --- Generator step: L1 pixel loss + relativistic-average GAN loss ---
        optimizer_G.zero_grad()

        fake_hdr = G(ldr)
        loss_pixel = criterion_pixel(fake_hdr, hdr)

        # Real logits only serve as a reference here, so skip building a graph
        # (same values as forward + .detach(), less memory).
        with torch.no_grad():
            pred_real = D(hdr)
        pred_fake = D(fake_hdr)

        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), real_label)

        loss_G = loss_pixel + loss_scaling_factor * loss_GAN
        total_loss_G += loss_G.item()

        loss_G.backward()
        optimizer_G.step()

        # --- Discriminator step ---
        # zero_grad also discards the gradients the G step left on D's params.
        optimizer_D.zero_grad()

        pred_real = D(hdr)
        pred_fake = D(fake_hdr.detach())  # don't backprop into G here

        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), real_label)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake_label)

        loss_D = (loss_real + loss_fake) / 2
        total_loss_D_real += loss_real.item()
        total_loss_D_fake += loss_fake.item()
        total_loss_D += loss_D.item()

        loss_D.backward()
        optimizer_D.step()

    # Per-epoch means over batches.
    wandb.log({
        "loss_G": total_loss_G / n_batches,
        "loss_D": total_loss_D / n_batches,
        "loss_D_real": total_loss_D_real / n_batches,
        "loss_D_fake": total_loss_D_fake / n_batches,
        "epoch": epoch,
    })

    # Overwrite the "last" checkpoints every epoch.
    torch.save(G.state_dict(), f"{results_dir}/generator_last.pth")
    torch.save(D.state_dict(), f"{results_dir}/discriminator_last.pth")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Close the wandb run opened earlier, flushing any pending logs.
run.finish()

## Load the model and do inference

In [None]:
import numpy as np
import PIL.Image  # binds `PIL` AND loads the Image submodule; plain `import PIL` does not

ldr, hdr = train[0]
# NOTE(review): assumes the dataset yields channels in BGR (OpenCV) order -- confirm
# in custom_datasets. For 4-channel RGBE samples this also drops the 4th channel.
# .cpu(): the dataset is built with device=device, and .numpy() needs host tensors.
ldr = ldr[[2, 1, 0], :, :].cpu()  # BGR to RGB
hdr = hdr[[2, 1, 0], :, :].cpu()  # BGR to RGB

# Clip to [0, 1] before scaling: HDR values above 1 would otherwise overflow
# the uint8 cast and produce garbage pixels.
ldr_img = PIL.Image.fromarray((ldr.clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8))
hdr_img = PIL.Image.fromarray((hdr.clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8))

def tensor_to_hdr_img1(hdr_tensor):
    """Render an HDR tensor as an 8-bit PIL image by clipping and gamma encoding.

    Values are clipped to [0, 1] (highlights above 1 saturate), gamma-encoded
    with exponent 1/2.2, scaled to [0, 255] and truncated to uint8.

    Args:
        hdr_tensor: torch tensor in (C, H, W) layout (the permute below assumes
            this); presumably C == 3 so PIL can build an RGB image.

    Returns:
        PIL.Image.Image of size (W, H).
    """
    # `import PIL` alone does not load the Image submodule.
    from PIL import Image

    # Tensor-native clamp (np.clip on a torch tensor only works via numpy's
    # fallback dispatch); move to host memory so .numpy() works for mps/cuda.
    new_hdr = hdr_tensor.detach().cpu().clamp(0, 1)
    new_hdr = new_hdr ** (1 / 2.2)
    new_hdr_img = Image.fromarray((new_hdr.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
    return new_hdr_img

def tensor_to_hdr_img2(hdr_tensor):
    """Tone-map an HDR tensor to an 8-bit PIL image with OpenCV's Drago operator.

    Args:
        hdr_tensor: torch tensor in (C, H, W) layout (the transpose below assumes
            this); OpenCV's tonemappers expect 3 channels, so presumably C == 3.

    Returns:
        PIL.Image.Image of size (W, H).
    """
    # Local imports: `cv` is otherwise only imported in a later cell, so on a
    # fresh kernel this function raised NameError; `import PIL` alone does not
    # load the Image submodule.
    import cv2 as cv
    from PIL import Image

    # Host memory, HWC layout, contiguous float32 -- the layout the OpenCV
    # tonemapper is fed.
    new_hdr2 = np.ascontiguousarray(
        np.transpose(hdr_tensor.detach().cpu().numpy(), (1, 2, 0)), dtype=np.float32
    )

    tonemap = cv.createTonemapDrago(2.2)  # gamma = 2.2
    # Divide out the operator's saturation setting (1.0 by default, making this a no-op).
    scale = 1 / tonemap.getSaturation()
    new_hdr2 = scale * tonemap.process(new_hdr2)

    # Replace non-finite values (e.g. from zero-radiance pixels) before clipping,
    # since NaN survives np.clip and has no defined uint8 value.
    new_hdr2 = np.clip(np.nan_to_num(new_hdr2, nan=0.0, posinf=1.0, neginf=0.0), 0, 1)
    new_hdr2_img = Image.fromarray((new_hdr2 * 255).astype(np.uint8))
    return new_hdr2_img

new_hdr_img = tensor_to_hdr_img1(hdr)
new_hdr2_img = tensor_to_hdr_img2(hdr)

# Show the LDR input next to both renderings of the ground-truth HDR
# (clip + gamma, and Drago tone mapping) so the two operators can be compared;
# previously the Drago result was computed and then thrown away.
images = [ldr_img, new_hdr_img, new_hdr2_img]

display(*images)

In [None]:
from generator import RRDBNet
from discriminator import DiscriminatorForVGG
import torch

target_dir = "cache-2023-11-29-20-52-36-models"

# NOTE(review): this config (3 channels, nb=6) differs from the RGBE models
# trained above (4 channels, nb=8). It must match whatever produced the
# checkpoint in target_dir; a 3-channel D presumably cannot consume the
# 4-channel RGBE samples in `train` -- confirm before mixing the two.
G = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=6, gc=32)
D = DiscriminatorForVGG(in_channels=3, out_channels=3, channels=64)

# map_location: load checkpoints saved on mps/cuda onto the current device.
# weights_only=True: a state dict needs only tensors, so refuse arbitrary pickles.
G.load_state_dict(torch.load(f"{target_dir}/generator_last.pth", map_location=device, weights_only=True))
D.load_state_dict(torch.load(f"{target_dir}/discriminator_last.pth", map_location=device, weights_only=True))

# Put the models on the same device as the data, in inference mode.
G.to(device=device).eval()
D.to(device=device).eval();

In [None]:
# Re-fetch the first training pair as the dataset yields it (no channel swap, original device).
(ldr, hdr) = train[0]

In [None]:
# Type string (dtype + backend) and shape: the D call below needs the channel
# count to match D's in_channels and the tensor to live on D's device.
hdr.type(), tuple(hdr.shape)

In [None]:
# Sanity-check the discriminator on one HDR sample (batch of 1). The sample is
# moved to whatever device D's parameters are on, since the dataset may yield
# tensors on a different device; no_grad because nothing is trained here.
with torch.no_grad():
    d_logits = D(hdr.unsqueeze(0).to(next(D.parameters()).device))
d_logits

In [None]:
from custom_datasets import RGB_to_RGBE
import cv2 as cv

hdr_path = "LDR-HDR-pair_Dataset-master/HDR/HDR_001.hdr"
# IMREAD_ANYDEPTH keeps the float32 depth, but without IMREAD_COLOR OpenCV
# collapses the image to a single grayscale channel; OR them to keep all three
# (BGR-ordered) channels.
hdr = cv.imread(hdr_path, flags=cv.IMREAD_ANYDEPTH | cv.IMREAD_COLOR)
if hdr is None:  # cv.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError(hdr_path)
hdr_RGBE = RGB_to_RGBE(hdr)

In [None]:
# dtype plus shape: the shape exposes the channel count RGB_to_RGBE produced.
hdr_RGBE.dtype, hdr_RGBE.shape