In [None]:
import torch
import numpy as np
 
# Detect whether a CUDA device is present and report the device that will be used.
train_on_gpu = torch.cuda.is_available()

print('CUDA is available!  Training on GPU ...'
      if train_on_gpu
      else 'CUDA is not available.  Training on CPU ...')

CUDA is available!  Training on GPU ...


In [None]:
from google.colab import drive
# Colab-only: mount Google Drive at /content/drive so files stored there are reachable
# from this runtime. Skip this cell when running outside Google Colab.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch.optim as optim
#from model import *
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

def load_dataset(data_path='path_to_train_set', batch_size=1, shuffle=True, num_workers=0):
    """Build the training DataLoader over an ``ImageFolder`` directory.

    Args:
        data_path: root directory in the layout ``torchvision.datasets.ImageFolder``
            expects (one sub-directory per class). Defaults to the original placeholder.
        batch_size: samples per batch. Other code in this notebook was written with
            batch_size=1 in mind; verify it before raising this.
        shuffle: reshuffle the samples on every pass over the loader.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, class_indices)`` batches,
        with images converted by ``transforms.ToTensor()``.
    """
    train_dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.ToTensor(),
    )
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

def load_val(data_path='path_to_val_set', batch_size=1, shuffle=False, num_workers=0):
    """Build the validation/inference DataLoader over an ``ImageFolder`` directory.

    Shuffling is off by default. The inference loop names its output files by loader
    position, and evaluation gains nothing from random order. With shuffle=False the
    same index always refers to the same source image across runs.

    Args:
        data_path: root directory in the layout ``torchvision.datasets.ImageFolder``
            expects. Defaults to the original placeholder.
        batch_size: samples per batch (the inference code squeezes the batch dim,
            so it assumes 1 — verify before changing).
        shuffle: pass True to restore randomised order.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, class_indices)`` batches,
        with images converted by ``transforms.ToTensor()``.
    """
    val_dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.ToTensor(),
    )
    return torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

  

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np

class DecomNet(nn.Module):
    """Convolutional network that maps an image to a 4-channel output map.

    The input is augmented with its per-pixel maximum over channels (so a 3-channel
    image becomes 4 channels). It then passes through:
      * a shallow branch (conv0),
      * a main branch with one stride-2 downsampling and a transposed-conv upsampling,
      * a skip connection at full resolution.
    ``forward`` returns raw pre-activation values; the training code applies a sigmoid
    and splits them into ``out[:, 0:3]`` (R) and ``out[:, 3:4]`` (L).

    Args:
        channel: width of the main feature maps; the shallow branch uses
            ``channel // 2``.
        kernel_size: kernel size of the regular convs. The first wide conv uses
            ``kernel_size * 3``. Padding is derived from the kernel size so every
            non-strided conv preserves spatial size (odd kernel sizes assumed).
        is_Training: unused; kept for backward compatibility.
    """

    def __init__(self, channel=64, kernel_size=3, is_Training=True):
        super(DecomNet, self).__init__()

        half = channel // 2
        pad = kernel_size // 2            # "same" padding for odd kernels
        wide_kernel = kernel_size * 3
        wide_pad = wide_kernel // 2

        # Layer names and shapes are kept as-is so existing state_dicts still load.
        self.conv0 = nn.Conv2d(4, half, kernel_size, padding=pad)
        self.conv = nn.Conv2d(4, channel, wide_kernel, padding=wide_pad)
        self.conv1 = nn.Conv2d(channel, channel, kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(channel, channel * 2, kernel_size, stride=2, padding=pad)
        self.conv3 = nn.Conv2d(channel * 2, channel * 2, kernel_size, padding=pad)
        self.conv4 = nn.ConvTranspose2d(channel * 2, channel, kernel_size, stride=2, padding=pad)
        self.conv4_1 = nn.Conv2d(channel, channel, kernel_size, stride=1, padding=pad)
        self.conv5 = nn.Conv2d(channel * 2, channel, kernel_size, padding=pad)
        self.conv6 = nn.Conv2d(3 * half, channel, kernel_size, padding=pad)
        self.conv7 = nn.Conv2d(channel, 4, kernel_size, padding=pad)

    @staticmethod
    def upsample(input, size=None, scale_factor=None):
        """Nearest-neighbour resize (replacement for deprecated ``F.upsample_nearest``)."""
        return F.interpolate(input, size=size, scale_factor=scale_factor, mode='nearest')

    def forward(self, x):
        """Run the network; ``x`` is assumed (B, 3, H, W). Returns (B, 4, H, W)."""
        # Append the per-pixel channel maximum as an extra input channel.
        x_max = torch.max(x, dim=1, keepdim=True).values
        x = torch.cat((x, x_max), dim=1)

        shallow = F.relu(self.conv0(x))
        x = self.conv(x)                 # no activation after the wide conv
        x = F.relu(self.conv1(x))
        skip = x
        skip_size = skip.shape[2:]

        x = F.relu(self.conv2(x))        # stride 2: downsample
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))        # transposed conv: upsample
        # The transposed conv's output size need not match the skip exactly
        # (e.g. 2h - 1 rows), so resize to the skip's spatial size.
        x = self.upsample(x, size=skip_size)
        x = F.relu(self.conv4_1(x))
        x = torch.cat((x, skip), dim=1)
        x = F.relu(self.conv5(x))

        x = torch.cat((x, shallow), dim=1)
        x = self.conv6(x)
        return self.conv7(x)

def rgb_to_grayscale(tensor):
    """Differentiable RGB -> single-channel luma conversion.

    Uses the ITU-R 601-2 weights (0.299, 0.587, 0.114), the same ones PIL's "L"
    conversion uses. Unlike a round-trip through PIL, this keeps the autograd graph
    intact, avoids 8-bit quantisation, and stays on the input's device.

    Args:
        tensor: image(s) with channels at dim -3: (C, H, W) or (B, C, H, W), or a
            plain (H, W) map. Channels beyond the third (e.g. alpha) are ignored.
            Inputs with fewer than 3 channels are returned as their first channel.
            Non-floating inputs (e.g. uint8) are scaled by 1/255 first, like
            ``ToTensor``.

    Returns:
        Float tensor shaped like the input but with a single channel at dim -3,
        i.e. (1, H, W) or (B, 1, H, W).
    """
    if not tensor.is_floating_point():
        tensor = tensor.float().div(255)
    if tensor.dim() == 2:
        return tensor.unsqueeze(0)
    if tensor.shape[-3] < 3:
        return tensor[..., :1, :, :]
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=tensor.dtype, device=tensor.device)
    return (tensor[..., :3, :, :] * weights.view(3, 1, 1)).sum(dim=-3, keepdim=True)

def smooth(I, R):
    """Edge-aware smoothness penalty on ``I``, relaxed where ``R`` has strong edges.

    Each gradient magnitude of ``I`` is weighted by ``exp(-10 * |grad gray(R)|)``.
    The weight shrinks as the grayscale gradient of ``R`` grows, so edges in ``R``
    are penalised less in ``I``. PyTorch port of the TensorFlow expression kept in
    the comment at the end of this function.

    Args:
        I: map to smooth; assumed (B, 1, H, W) — TODO confirm for other callers.
        R: guidance image, (B, 3, H, W) or a single (3, H, W) image.

    Returns:
        Scalar tensor: mean of the weighted x- and y-gradient magnitudes of ``I``.
    """
    if R.dim() == 3:
        R = R.unsqueeze(0)
    # Grayscale each sample separately, restore the batch dim, and match I's device
    # (instead of assuming CUDA).
    R_gray = torch.stack([rgb_to_grayscale(sample) for sample in R]).to(I.device)
    weight_x = torch.exp(-10 * gradient(R_gray, "x"))
    weight_y = torch.exp(-10 * gradient(R_gray, "y"))
    return torch.mean(gradient(I, "x") * weight_x + gradient(I, "y") * weight_y)
    #return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.gradient(input_R, "y")))

def gradient(input_tensor, direction):
    """Absolute finite-difference gradient, matching TensorFlow RetinexNet's ``gradient``.

    The TF reference convolves with the 2x2 kernel ``[[0, 0], [-1, 1]]`` (x) or its
    spatial transpose ``[[0, -1], [0, 1]]`` (y), using 'SAME' padding. For a 2x2
    kernel at stride 1, that pads one zero row at the bottom and one zero column on
    the right. PyTorch weights are laid out (out, in, kH, kW), so the kernel is
    (1, 1, 2, 2) and "y" transposes the last two dims.

    Args:
        input_tensor: (B, C, H, W) tensor; each channel is differentiated independently.
        direction: "x" (horizontal differences) or "y" (vertical differences).

    Returns:
        Non-negative tensor with the same shape, dtype and device as ``input_tensor``.

    Raises:
        ValueError: if ``direction`` is not "x" or "y".
    """
    kernel_x = torch.tensor(
        [[0.0, 0.0], [-1.0, 1.0]],
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    ).view(1, 1, 2, 2)

    if direction == "x":
        kernel = kernel_x
    elif direction == "y":
        kernel = kernel_x.transpose(2, 3)
    else:
        raise ValueError("direction must be 'x' or 'y', got {!r}".format(direction))

    # One copy of the kernel per channel, applied with groups=C (depthwise).
    channels = input_tensor.shape[1]
    kernel = kernel.repeat(channels, 1, 1, 1)
    # TF 'SAME' padding for an even 2x2 kernel: zero-pad right and bottom by one.
    padded = F.pad(input_tensor, (0, 1, 0, 1))
    return torch.abs(F.conv2d(padded, kernel, stride=1, groups=channels))

def lowLightLoss(input_im, R, L, im_eq, recon_eq_weight=0.1, illum_smooth_weight=0.1,
                 refl_smooth_weight=0.01):
    """Total decomposition loss for one batch.

    Terms:
      * reconstruction: mean |R * L - input_im| (L broadcast over R's channels);
      * equalisation: mean |max_c(R) - im_eq|, tying R's channel-max to ``im_eq``;
      * illumination smoothness: ``smooth(L, R)``;
      * reflectance smoothness: mean x+y gradient magnitude of grayscale(R).

    Args:
        input_im: input batch, assumed (B, 3, H, W).
        R: (B, 3, H, W) tensor (sigmoid of the first three network outputs).
        L: (B, 1, H, W) tensor (sigmoid of the fourth network output).
        im_eq: target for R's channel-max, broadcastable to (B, 1, H, W).
        recon_eq_weight: weight of the equalisation term (original: 0.1).
        illum_smooth_weight: weight of the illumination-smoothness term (original: 0.1).
        refl_smooth_weight: weight of the reflectance-smoothness term (original: 0.01).

    Returns:
        Scalar loss tensor.
    """
    # L has one channel, so broadcasting is equivalent to cat((L, L, L), dim=1).
    recon_loss = torch.mean(torch.abs(R * L - input_im))

    R_max = torch.max(R, dim=1, keepdim=True).values
    recon_eq_loss = torch.mean(torch.abs(R_max - im_eq))

    # Grayscale once per sample (works for any batch size) and keep R's device.
    R_gray = torch.stack([rgb_to_grayscale(sample) for sample in R]).to(R.device)
    # gradient() already returns absolute values.
    refl_smooth_loss = torch.mean(gradient(R_gray, "x") + gradient(R_gray, "y"))

    illum_smooth_loss = smooth(L, R)

    return (recon_loss
            + illum_smooth_weight * illum_smooth_loss
            + recon_eq_weight * recon_eq_loss
            + refl_smooth_weight * refl_smooth_loss)




In [None]:
# Instantiate the decomposition network with its default sizes; nn.Module.cuda()
# moves the parameters to the GPU and returns the module itself.
model = DecomNet().cuda()
print(model)

DecomNet(
  (conv0): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv): Conv2d(4, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv4_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)


In [None]:
import torch
# Adam over all model parameters; weight_decay adds an L2 penalty to the gradients.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4,
    amsgrad=False,
)
n_epochs = 30  # NOTE(review): the training cell assigns its own n_epochs before looping

In [None]:

import numpy as np
from PIL import Image
from pylab import *
import cv2
from PIL import ImageFilter

def histeq(im, nbr_bins=256):
    """Histogram-equalise a numpy array.

    Each value is mapped through the normalised cumulative histogram (``nbr_bins``
    bins), so the result lies in [0, 1] and has the same shape as ``im``.
    """
    values = im.flatten()
    counts, edges = np.histogram(values, nbr_bins, density=True)
    cdf = np.cumsum(counts)
    cdf = cdf / cdf[-1]  # normalise so the last bin maps to 1.0
    # Look up each value's CDF entry using the left bin edges as sample points.
    equalised = np.interp(values, edges[:-1], cdf)
    return equalised.reshape(im.shape)

In [None]:
n_epochs = 100  # number of full passes over the training set (overrides any earlier value)

# Build the loader once; with shuffle=True the DataLoader reshuffles on every epoch.
train_loader = load_dataset()
# Run on whichever device the model's parameters live on.
device = next(model.parameters()).device

model.train() # prep model for training

for epoch in range(n_epochs):
    # Sum of per-sample losses over this epoch.
    train_loss = 0.0

    for batch_idx, (data, _target) in enumerate(train_loader):
        data = data.to(device)

        # Histogram-equalised per-pixel channel max of each input (computed in numpy),
        # used by lowLightLoss as the target for the channel max of R.
        max_channel = torch.max(data, dim=1, keepdim=True).values  # (B, 1, H, W)
        im_eq = torch.stack([
            torch.from_numpy(histeq(sample.cpu().numpy())) for sample in max_channel
        ]).float().to(device)

        optimizer.zero_grad()
        output = model(data)
        R = torch.sigmoid(output[:, 0:3, :, :])
        L = torch.sigmoid(output[:, 3:4, :, :])

        loss = lowLightLoss(data, R, L, im_eq)
        loss.backward()
        optimizer.step()

        # loss is a mean over the batch; weight by batch size before accumulating.
        train_loss += loss.item() * data.size(0)

    # Average loss per training sample (guarding against an empty dataset).
    train_loss = train_loss / max(len(train_loader.dataset), 1)

    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss
        ))
    print("Saving Epoch " + str(epoch))
    path = "model_" + str(epoch) + ".pt"
    torch.save(model.state_dict(), path)
    print("Saved!")



Epoch: 0 	Training Loss: 7.217404
Saving Epoch 0
Saved!
Epoch: 1 	Training Loss: 2.997780
Saving Epoch 1
Saved!
Epoch: 2 	Training Loss: 2.046865
Saving Epoch 2
Saved!
Epoch: 3 	Training Loss: 1.895502
Saving Epoch 3
Saved!
Epoch: 4 	Training Loss: 1.843407
Saving Epoch 4
Saved!
Epoch: 5 	Training Loss: 1.755164
Saving Epoch 5
Saved!
Epoch: 6 	Training Loss: 1.692268
Saving Epoch 6
Saved!
Epoch: 7 	Training Loss: 1.690748
Saving Epoch 7
Saved!
Epoch: 8 	Training Loss: 1.672582
Saving Epoch 8
Saved!
Epoch: 9 	Training Loss: 1.625369
Saving Epoch 9
Saved!
Epoch: 10 	Training Loss: 1.654988
Saving Epoch 10
Saved!
Epoch: 11 	Training Loss: 1.654690
Saving Epoch 11
Saved!
Epoch: 12 	Training Loss: 1.693094
Saving Epoch 12
Saved!
Epoch: 13 	Training Loss: 1.613992
Saving Epoch 13
Saved!
Epoch: 14 	Training Loss: 1.608643
Saving Epoch 14
Saved!
Epoch: 15 	Training Loss: 1.623326
Saving Epoch 15
Saved!
Epoch: 16 	Training Loss: 1.573311
Saving Epoch 16
Saved!
Epoch: 17 	Training Loss: 1.625413

In [None]:
# Restore weights from a checkpoint; the training loop writes files named model_<epoch>.pt.
checkpoint_file = 'model_85.pt'
model.load_state_dict(torch.load(checkpoint_file))

<All keys matched successfully>

In [None]:
from IPython.display import Image 
import os

# Output directory for R images (<i>.png) and the matching inputs (<i>_.png).
data_root = 'path_to_inference_dir'
os.makedirs(data_root, exist_ok=True)

model.eval()
device = next(model.parameters()).device

# Inference only: disable autograd so no computation graph is kept per image.
with torch.no_grad():
    for i, (data, _target) in enumerate(load_val()):
        data = data.to(device)
        output = model(data)
        R = torch.sigmoid(output[:, 0:3, :, :])

        # Assumes batch_size == 1 (squeeze the batch dim before converting to PIL).
        img = transforms.functional.to_pil_image(R.cpu().squeeze(0), mode=None)
        img_o = transforms.functional.to_pil_image(data.cpu().squeeze(0), mode=None)

        display(img)
        img.save(os.path.join(data_root, str(i) + ".png"), "PNG")
        display(img_o)
        img_o.save(os.path.join(data_root, str(i) + "_.png"), "PNG")