In [1]:
# Google Colaboratory only: mount Google Drive and work from the project folder.
import sys, os

if 'google.colab' in sys.modules:
    from google.colab import drive

    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/LLIE_Project'
    print(path_to_file)
    # Make the project folder the working directory so the relative paths used
    # later in the notebook resolve against it.
    os.chdir(path_to_file)
    print(os.getcwd())

Mounted at /content/gdrive
/content/gdrive/My Drive/LLIE_Project
/content/gdrive/My Drive/LLIE_Project


In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from DatasetAndAugmentation.LowHighDataAugment import PairedTransforms
from DatasetAndAugmentation.LowHightDataset import LOLPairedDataset
import matplotlib.pyplot as plt
from model.model import RELLIE

try:
  from piqa import SSIM
except:
  !pip install piqa
  from piqa import SSIM

import torch
import torch.nn as nn
from torch.optim import AdamW, Adam

Collecting piqa
  Downloading piqa-1.3.2-py3-none-any.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.12.0->piqa)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.12.0->piqa)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.12.0->piqa)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.12.0->piqa)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.12.0->piqa)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.12.0->piqa)
  Downloading nvidia_cufft_

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


# Define the dataset directories

In [4]:
# Root directory of the LOL dataset; every split path below is derived from it,
# so pointing the notebook at another copy only requires changing this line.
dataset_dir = "./LOLdataset"
# directory of low-light training images
train_low_dir = f"{dataset_dir}/train/low"
# directory of high-light (target) training images
train_bright_dir = f"{dataset_dir}/train/high"

# test split
test_low_dir = f"{dataset_dir}/test/low"
test_bright_dir = f"{dataset_dir}/test/high"

# Create the train/test image transforms, datasets, and DataLoaders

In [5]:
# Paired transforms applied jointly to each (low, bright) image pair; presumably they
# also convert the images to tensors (TODO confirm in PairedTransforms).
train_batch_size = 1
train_transform = PairedTransforms(image_size=(400, 600), train=True)
train_dataset = LOLPairedDataset(train_low_dir, train_bright_dir, transform=train_transform, train=True)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)


# The test set must use the test-mode transform (train=False) so that any
# training-only augmentation is not applied to the evaluation images.
test_transform = PairedTransforms(image_size=(400, 600), train=False)
test_dataset = LOLPairedDataset(test_low_dir, test_bright_dir, transform=test_transform, train=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Initialize the model, optimizer, and loss functions

In [6]:
class SSIMLoss(SSIM):
    """SSIM used as a loss: returns ``1 - SSIM(x, y)``, so higher similarity gives lower loss."""

    def forward(self, x, y):
        similarity = super().forward(x, y)
        return 1. - similarity


# RELLIE network and its Adam optimizer.
model = RELLIE().to(device)
lr = 1e-4
optimizer = Adam(model.parameters(), lr=lr)

# Loss terms used by the training and evaluation functions.
mse_loss = nn.MSELoss()
ssim_loss = SSIMLoss().to(device)

# Training and evaluation pipeline

In [7]:
def train_one_epoch(model, train_loader, optimizer, epoch=0, loss_weights=(1.5, 0.75, 2.0)):
    """Train ``model`` for one pass over ``train_loader`` and report the average losses.

    Uses the notebook-level ``device``, ``ssim_loss`` and ``mse_loss``.

    Args:
        model: network whose ``forward(low, bright, 'train')`` returns the tuple
            ``(reflectance_low, reflectance_high, illumination_low,
            illumination_high, enhanced_illumination)``.
        train_loader: iterable of dicts holding ``"low"`` and ``"bright"`` tensors.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index; only used in the log line.
        loss_weights: ``(decomposition, reflectance, illumination)`` weights used
            to combine the three loss terms into the optimized total.

    Returns:
        dict with the per-batch averages of the ``"total"``, ``"decomposition"``,
        ``"reflectance"`` and ``"illumination"`` losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    w_dec, w_ref, w_ill = loss_weights
    # The model may have been left in eval mode by an earlier cell; make sure any
    # mode-dependent layers (dropout, batch-norm, ...) run in training mode.
    model.train()

    running_total = 0.0
    running_ref = 0.0
    running_dec = 0.0
    running_ill = 0.0
    num_batches = 0
    for batch in train_loader:
        input_low_light = batch["low"].to(device)
        target_high_light = batch["bright"].to(device)

        # forward pass in 'train' mode (the target image is also fed to the model)
        (reflectance_low_light,
         reflectance_high_light,
         illumination_low_light,
         illumination_high_light,
         enhanced_illumination) = model(input_low_light, target_high_light, 'train')

        # reflectance * illumination should reconstruct each of the two input images
        decomposition_loss = (
            ssim_loss(reflectance_high_light * illumination_high_light, target_high_light)
            + ssim_loss(reflectance_low_light * illumination_low_light, input_low_light)
        )
        # penalize differences between the low- and high-light reflectance maps
        reflectance_loss = mse_loss(reflectance_low_light, reflectance_high_light)
        # enhanced illumination applied to the low-light reflectance should match the target
        illumination_enhance_loss = ssim_loss(enhanced_illumination * reflectance_low_light, target_high_light)

        total_loss = w_dec * decomposition_loss + w_ref * reflectance_loss + w_ill * illumination_enhance_loss

        # .item() already returns a plain Python float detached from the graph
        running_dec += decomposition_loss.item()
        running_ref += reflectance_loss.item()
        running_ill += illumination_enhance_loss.item()
        running_total += total_loss.item()

        # backpropagation
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches; check the training dataset paths")

    avg_total = running_total / num_batches
    avg_dec = running_dec / num_batches
    avg_ref = running_ref / num_batches
    avg_ill = running_ill / num_batches
    print(f'Epoch {epoch + 1}:\t Total Loss = {avg_total}\t d_loss = {avg_dec}\t r_loss = {avg_ref}\t i_loss = {avg_ill}')
    return {
        "total": avg_total,
        "decomposition": avg_dec,
        "reflectance": avg_ref,
        "illumination": avg_ill,
    }


def evaluate(model, test_loader):
    """Report the mean ``ssim_loss`` of ``model``'s predictions on ``test_loader``.

    Uses the notebook-level ``device`` and ``ssim_loss``. The model is switched to
    eval mode for the pass and then restored to whichever mode it was in before.

    Args:
        model: network whose ``forward(low, None, 'eval')`` returns the predicted
            bright image.
        test_loader: iterable of dicts holding ``"low"`` and ``"bright"`` tensors.

    Returns:
        float: the loss averaged over batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    running_total = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                input_low_light = batch["low"].to(device)
                target_high_light = batch["bright"].to(device)

                # forward
                predict_high_light = model(input_low_light, None, 'eval')

                loss = ssim_loss(predict_high_light, target_high_light)
                running_total += loss.item()
                num_batches += 1
    finally:
        # restore the caller's mode so a following training epoch is unaffected
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; check the test dataset paths")

    mean_loss = running_total / num_batches
    print(f"Evaluation on test set: Loss = {mean_loss}")
    return mean_loss


In [None]:
model_dict_folder = "./model_files"
os.makedirs(model_dict_folder, exist_ok=True)

num_epochs = 200
# Index of the checkpoint (epoch_{k}_state.pt) to resume from; only used when load_model is True.
starting_epoch = 0
load_model = False

first_epoch = starting_epoch
if load_model:
    # Checkpoints are written into model_dict_folder below, so read them from there.
    # `weights_only` is an argument of torch.load (load_state_dict does not accept it);
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
    checkpoint_path = os.path.join(model_dict_folder, f'epoch_{starting_epoch}_state.pt')
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    # epoch_{k}_state.pt is saved right after training epoch k, so continue with epoch k + 1.
    first_epoch = starting_epoch + 1

# Learning rate the optimizer was created with; never mutated here, so re-running
# this cell starts from the same schedule.
base_lr = lr
for epoch in range(first_epoch, num_epochs):
    # Halve the learning rate at epochs 1, 21, 41, ... (the `epoch % 20 == 1` rule).
    # Deriving it from `epoch` keeps it correct after a resume, and updating the
    # existing optimizer in place keeps Adam's moment estimates instead of resetting them.
    epoch_lr = base_lr * 0.5 ** ((epoch + 19) // 20)
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr

    train_one_epoch(model, train_loader, optimizer, epoch)
    torch.save(model.state_dict(), os.path.join(model_dict_folder, f'epoch_{epoch}_state.pt'))
    evaluate(model, test_loader)

# Visualize the model's output

In [None]:
num_images_show = 5
# One 3-inch-wide column per image; squeeze=False keeps `axes` 2-D (rows x columns)
# even when num_images_show == 1.
fig, axes = plt.subplots(3, num_images_show, figsize=(3 * num_images_show, 9), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also hides columns left empty if the test set is smaller

was_training = model.training
model.eval()
try:
    # Iterate the loader once and stop after num_images_show batches, so every column
    # shows a different test image (re-creating the iterator on each pass would
    # restart the unshuffled loader and repeat the first image).
    for i, batch in zip(range(num_images_show), test_loader):
        input_low_light = batch["low"].to(device)
        target_high_light = batch["bright"].squeeze(0).to(device)

        with torch.no_grad():
            predict_high_light = model(input_low_light, None, 'eval').squeeze(0)

        predict_high_light_PIL, target_high_light_PIL = train_transform.tensor2PIL(predict_high_light, target_high_light)
        input_low_light_PIL, _ = train_transform.tensor2PIL(input_low_light.squeeze(0), None)

        # Row 1: picture fed into the model
        axes[0, i].imshow(input_low_light_PIL)
        axes[0, i].set_title(f"Input {i+1}")

        # Row 2: picture produced (enhanced) by the model
        axes[1, i].imshow(predict_high_light_PIL)
        axes[1, i].set_title(f"Predicted {i+1}")

        # Row 3: the target picture
        axes[2, i].imshow(target_high_light_PIL)
        axes[2, i].set_title(f"Target {i+1}")
finally:
    # restore the mode the model was in before visualization
    model.train(was_training)
plt.show()