# DeLDRify - ESRGAN applied to single-exposure LDR-to-HDR image conversion

## Import and initialize the models

In [1]:
from generator import RRDBNet
from discriminator import DiscriminatorForVGG

# Generator: RRDB network, 3 channels in and 3 channels out
# (in_nc/out_nc presumably input/output channel counts — confirm in generator.py).
G = RRDBNet(
    nb=4,
    nf=32,
    gc=16,
    in_nc=3,
    out_nc=3,
)
# Discriminator scoring 3-channel images.
D = DiscriminatorForVGG(
    channels=16,
    in_channels=3,
    out_channels=3,
)

In [2]:
# Report the total number of parameter elements in each network.
for _label, _model in (("Number of param (G):", G), ("Number of param (D):", D)):
    print(_label, sum(_param.numel() for _param in _model.parameters()))

Number of param (G): 731011
Number of param (D): 1061619


## Prepare the data

In [15]:
from torch.utils.data import Dataset
import os
import cv2 as cv
import numpy as np

class PairWiseImages(Dataset):
    """Dataset of paired LDR/HDR images matched by sorted file-name order.

    The i-th LDR file (after sorting, hidden files skipped) is paired with the
    i-th HDR file, so both directories must hold the same number of images in
    corresponding order. Images are returned in OpenCV's BGR channel order.

    Args:
        ldr_path: Directory containing the LDR inputs (scaled to [0, 1]).
        hdr_path: Directory containing the HDR targets (read at native depth).
        transform: Optional callable applied separately to each image.
        hdr_scale: Divisor applied to HDR pixel values (default 4.0, the value
            previously hard-coded).

    Raises:
        ValueError: If the two directories contain a different number of images.
    """

    def __init__(self, ldr_path, hdr_path, transform=None, hdr_scale=4.0) -> None:
        self.ldr_path = ldr_path
        self.hdr_path = hdr_path
        self.transform = transform
        self.hdr_scale = hdr_scale
        self.ldr_list = self._list_images(ldr_path)
        self.hdr_list = self._list_images(hdr_path)
        # Pairing is purely index-based, so a count mismatch would silently
        # misalign every pair after the first missing file.
        if len(self.ldr_list) != len(self.hdr_list):
            raise ValueError(
                f"LDR/HDR count mismatch: {len(self.ldr_list)} files in {ldr_path!r} "
                f"vs {len(self.hdr_list)} files in {hdr_path!r}"
            )

    @staticmethod
    def _list_images(path):
        """Return the sorted names of the regular, non-hidden files in *path*."""
        # Skip dotfiles such as macOS `.DS_Store`, which would otherwise shift
        # the index-based pairing between the two directories.
        return sorted(
            name
            for name in os.listdir(path)
            if not name.startswith(".") and os.path.isfile(os.path.join(path, name))
        )

    def __len__(self):
        return len(self.ldr_list)

    def __getitem__(self, idx):
        ldr_img_path = os.path.join(self.ldr_path, self.ldr_list[idx])
        hdr_img_path = os.path.join(self.hdr_path, self.hdr_list[idx])

        ldr_img = cv.imread(ldr_img_path)
        if ldr_img is None:
            raise FileNotFoundError(f"Could not read LDR image: {ldr_img_path}")
        ldr_img = ldr_img.astype(np.float32) / 255.0

        # IMREAD_ANYDEPTH on its own makes OpenCV decode to a single grayscale
        # channel; OR-ing in IMREAD_COLOR keeps three (BGR) channels, matching
        # the LDR image, while preserving the HDR file's float/16-bit depth.
        hdr_img = cv.imread(hdr_img_path, flags=cv.IMREAD_ANYDEPTH | cv.IMREAD_COLOR)
        if hdr_img is None:
            raise FileNotFoundError(f"Could not read HDR image: {hdr_img_path}")
        # Convert first: in-place true division on an integer array would raise.
        hdr_img = hdr_img.astype(np.float32) / self.hdr_scale

        if self.transform is not None:
            ldr_img = self.transform(ldr_img)
            hdr_img = self.transform(hdr_img)
        return ldr_img, hdr_img

In [16]:
from torchvision import transforms

# ToTensor turns each HxWxC numpy image into a CxHxW tensor; Resize then brings
# every LDR and HDR image to 128x128.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    # antialias=None is deprecated and, for tensor inputs, means "no antialiasing",
    # which aliases when downscaling full-size photos; enable it explicitly.
    transforms.Resize((128, 128), antialias=True),
])

In [17]:
# Pair each image from the LDR_exposure_0 folder with its HDR counterpart.
pair = PairWiseImages(
    ldr_path="LDR-HDR-pair_Dataset-master/LDR_exposure_0/",
    hdr_path="LDR-HDR-pair_Dataset-master/HDR/",
    transform=train_transform,
)

In [18]:
import torch
from torch.utils.data import Subset

# Work on a small subset (the first 40 pairs) for quick experiments. Clamp to the
# dataset size so a smaller dataset does not raise IndexError when sampled.
indices = list(range(min(40, len(pair))))
pair_40 = Subset(pair, indices)

In [19]:
import torch

length = len(pair_40)
# 20 % of the subset is held out for validation.
test_length = int(0.2 * length)

# A seeded generator keeps the train/validation split identical across kernel
# restarts, so metrics from different runs are computed on the same images.
train, valid = torch.utils.data.random_split(
    pair_40,
    [length - test_length, test_length],
    generator=torch.Generator().manual_seed(42),
)

In [20]:
BATCH_SIZE = 4

# Shuffle training batches each epoch.
train_dataloader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
# Validation order does not affect an averaged metric, so keep it deterministic.
valid_data_loader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)

## Train the models

In [21]:
import wandb
# Log in to Weights & Biases, then open a run in the "DeLDRify" project;
# `run` keeps the handle to that run.
wandb.login()
run = wandb.init(
    project="DeLDRify",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moskarjor[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [22]:
# Pick the best available backend: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'mps'

In [23]:
# Adversarial loss works on raw logits; the pixel loss is plain L1.
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

In [24]:
from datetime import datetime

# Timestamped checkpoint directory, unique per launch of this cell.
_stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-models")
results_dir = "./cache-" + _stamp
results_dir

'./cache-2023-11-27-20-55-53-models'

In [28]:
from tqdm.notebook import tqdm

epochs = 100

# Move both networks to the selected device *before* building the optimizers;
# previously nothing was moved, so training silently ran on the CPU.
G.to(device)
D.to(device)
G.train()
D.train()

optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.9, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.9, 0.999))

# Weight of the adversarial term relative to the L1 pixel loss.
loss_scaling_factor = 1e-3

os.makedirs(results_dir, exist_ok=True)

for epoch in tqdm(range(epochs)):
    total_loss_G = 0.0
    total_loss_D = 0.0

    for ldr, hdr in tqdm(train_dataloader, leave=False):
        ldr = ldr.to(device)
        hdr = hdr.to(device)

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        fake_hdr = G(ldr)
        loss_pixel = criterion_pixel(fake_hdr, hdr)

        pred_real = D(hdr).detach()
        pred_fake = D(fake_hdr)

        # Targets shaped like D's output, float32 and on the same device.
        # The old np.ones/np.zeros tensors were float64 CPU tensors with an
        # assumed (batch, out_channels) shape. The new names also stop this loop
        # from overwriting the `valid` dataset split.
        real_target = torch.ones_like(pred_fake)
        fake_target = torch.zeros_like(pred_fake)

        # Relativistic-average adversarial term: fake logits relative to the
        # mean real logit should look "real".
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), real_target)

        loss_G = loss_pixel + loss_scaling_factor * loss_GAN
        total_loss_G += loss_G.item()

        # This backward also accumulates grads into D's parameters; they are
        # cleared by optimizer_D.zero_grad() below before D is updated.
        loss_G.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()

        pred_real = D(hdr)
        pred_fake = D(fake_hdr.detach())

        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), real_target)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake_target)

        loss_D = (loss_real + loss_fake) / 2
        total_loss_D += loss_D.item()

        loss_D.backward()
        optimizer_D.step()

    wandb.log({"loss_G": total_loss_G / len(train_dataloader), "loss_D": total_loss_D / len(train_dataloader), "epoch": epoch})

    # Overwrite the "last" checkpoints every epoch.
    torch.save(G.state_dict(), f"{results_dir}/generator_last.pth")
    torch.save(D.state_dict(), f"{results_dir}/discriminator_last.pth")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Mark the wandb run started by wandb.init(...) as finished.
wandb.finish()

In [None]:
# `import PIL` alone does not load the `PIL.Image` submodule; this binds `PIL`
# and guarantees `PIL.Image` is available.
import PIL.Image

# After the training loop `ldr`/`hdr` still hold the last mini-batch (4-D) and may
# live on the training device. Reduce them to one CPU (C, H, W) sample: PIL needs
# an (H, W, C) array, and later cells call `ldr.unsqueeze(0)`, which expects 3-D.
if ldr.dim() == 4:
    ldr, hdr = ldr[0], hdr[0]
ldr, hdr = ldr.detach().cpu(), hdr.detach().cpu()

# Scale to [0, 255], clip so out-of-range values saturate instead of wrapping in
# uint8, and reverse the channel axis (OpenCV loads images as BGR; PIL wants RGB).
ldr_img = PIL.Image.fromarray(
    np.ascontiguousarray(np.clip(ldr.permute(1, 2, 0).numpy() * 255, 0, 255).astype(np.uint8)[:, :, ::-1])
)
hdr_img = PIL.Image.fromarray(
    np.ascontiguousarray(np.clip(hdr.permute(1, 2, 0).numpy() / 4 * 255, 0, 255).astype(np.uint8)[:, :, ::-1])
)

In [None]:
# Display the LDR preview image as this cell's output.
ldr_img

In [None]:
# Display the ground-truth HDR preview image as this cell's output.
hdr_img

In [None]:
# Run the generator on the single preview sample without recording autograd
# history. Feed the input on whatever device G lives on (training may have moved
# it to mps/cuda) and bring the result back to the CPU for numpy/PIL use.
_g_device = next(G.parameters()).device
with torch.no_grad():
    fake_hdr = G(ldr.unsqueeze(0).to(_g_device)).cpu()

In [None]:
# (1, C, H, W) -> (H, W, C), using the same /4 * 255 display scaling as the HDR
# preview. Move to CPU first (numpy cannot read device tensors), clip so
# out-of-range predictions saturate instead of wrapping in uint8, and reverse the
# channel axis because the data is in OpenCV's BGR order.
_fake_np = fake_hdr[0].detach().cpu().permute(1, 2, 0).numpy() / 4 * 255
fake_hdr_img = PIL.Image.fromarray(np.ascontiguousarray(np.clip(_fake_np, 0, 255).astype(np.uint8)[:, :, ::-1]))

In [None]:
# Display the generator's HDR prediction as this cell's output.
fake_hdr_img

In [None]:
# Flat index of the element where the prediction deviates most from the target.
(fake_hdr.flatten() - hdr.flatten()).abs().argmax()

In [None]:
# Target value at the position of the largest absolute prediction error.
_abs_err = (fake_hdr.flatten() - hdr.flatten()).abs()
hdr.flatten()[_abs_err.argmax()]

In [None]:
# Score the generated preview with the discriminator. Move it onto D's device in
# case training placed D on mps/cuda, and skip autograd bookkeeping. The result
# is kept as the cell's last top-level expression so the notebook still shows it.
with torch.no_grad():
    _d_scores = D(fake_hdr.to(next(D.parameters()).device))
_d_scores