<a href="https://colab.research.google.com/github/v0rt3x0/DL_submit/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import ConvTranspose2d
from torch.nn import MSELoss
from torch import concat
import numpy as np
from torch import optim
from PIL import Image
import os
# Print the installed torch and Pillow versions, then the interpreter version.
print(torch.__version__)
print(Image.__version__)
# IPython/Colab shell escape: only valid inside a notebook, not in a plain .py file.
!python --version


In [None]:
# Show which NumPy build the runtime is using.
print(f"{np.__version__}")

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset paths "drive/MyDrive/..." used below resolve
# (assumes the notebook's working directory is /content, the Colab default).
drive.mount(mountpoint='/content/drive')

In [None]:
# The model has 3 components in total: a decomposition network, a reflectance restoration network and an illumination adjustment network. Each is a deep network on its own. Each component gets its own class, which is instantiated later to build the full model (this notebook currently implements only the decomposition and illumination adjustment networks).

In [None]:
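# We opted to use 8 filters instead of the 32 specified by the paper, because our local machine or Colab might not be able to train the full-width model effectively. NOTE: only Illum_adjust defaults to 8 filters; the decompose class below still defaults to the paper's 32.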
class decompose(nn.Module):
    """Decomposition network: split an RGB image into illumination and reflectance.

    The reflectance branch is a small encoder/decoder (two 2x max-pool
    downsamplings followed by two 2x transposed-conv upsamplings); the
    illumination branch is two convolutions on the shared input features.
    The pixel-wise product ``I * R`` (``I`` broadcast over the 3 channels)
    is intended to reconstruct the input image.

    Args:
        filters: Base channel count; the deeper encoder layers use 2x and 4x.
        activation: Non-linearity applied after each convolution except the
            output layers: ``'relu'`` or ``'lrelu'`` (LeakyReLU, slope 0.2).
        kernel_size: Spatial size of the (non-transposed) convolutions; odd
            sizes keep the spatial resolution unchanged.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.

    NOTE: the U-Net style skip connections (concatenating encoder features
    into the decoder) are not wired in; the channel counts below assume no
    concatenation.
    """

    def __init__(self, filters=32, activation='relu', kernel_size=3):
        super().__init__()
        pad = kernel_size // 2  # "same" padding for odd kernel sizes
        # Activations are parameter-free, so one shared module is enough.
        if activation == 'relu':
            self.act = nn.ReLU()
        elif activation == 'lrelu':
            self.act = nn.LeakyReLU(0.2)
        else:
            raise ValueError(f"unsupported activation: {activation!r}")

        self.conv_input = Conv2d(3, filters, kernel_size=kernel_size, padding=pad)

        # Reflectance branch (encoder / decoder).
        self.maxpool_r1 = MaxPool2d(kernel_size=2, stride=2)
        self.conv_r1 = Conv2d(filters, filters * 2, kernel_size=kernel_size, padding=pad)
        self.maxpool_r2 = MaxPool2d(kernel_size=2, stride=2)
        self.conv_r2 = Conv2d(filters * 2, filters * 4, kernel_size=kernel_size, padding=pad)
        self.deconv_r1 = ConvTranspose2d(filters * 4, filters * 2, kernel_size=2, padding=0, stride=2)
        self.conv_r3 = Conv2d(filters * 2, filters * 2, kernel_size=kernel_size, padding=pad)
        self.deconv_r2 = ConvTranspose2d(filters * 2, filters, kernel_size=2, padding=0, stride=2)
        self.conv_r4 = Conv2d(filters, filters, kernel_size=kernel_size, padding=pad)
        self.conv_r5 = Conv2d(filters, 3, kernel_size=kernel_size, padding=pad)
        self.R_out = nn.Sigmoid()

        # Illumination branch.
        self.conv_i1 = Conv2d(filters, filters, kernel_size=kernel_size, padding=pad)
        self.conv_i2 = Conv2d(filters, 1, kernel_size=kernel_size, padding=pad)
        self.I_out = nn.Sigmoid()

    def forward(self, x):
        """Decompose ``x`` into illumination and reflectance.

        Args:
            x: Image tensor, batched ``(N, 3, H, W)`` or unbatched ``(3, H, W)``.

        Returns:
            ``(I_out, R_out)``: a 1-channel illumination map and a 3-channel
            reflectance map, both sigmoid-bounded and with the same spatial
            size as ``x``. The reconstructed image is ``I_out * R_out``
            (pixel-by-pixel).
        """
        act = self.act
        feat = act(self.conv_input(x))

        # Reflectance component R.
        enc1 = act(self.conv_r1(self.maxpool_r1(feat)))
        enc2 = act(self.conv_r2(self.maxpool_r2(enc1)))
        # output_size restores odd sizes that max-pooling floored away, so R
        # keeps the input resolution and stays shape-compatible with I.
        dec1 = self.deconv_r1(enc2, output_size=enc1.shape[-2:])
        dec1 = act(self.conv_r3(dec1))
        dec2 = self.deconv_r2(dec1, output_size=feat.shape[-2:])
        dec2 = act(self.conv_r4(dec2))
        R_out = self.R_out(self.conv_r5(dec2))

        # Illumination component I.
        ill = act(self.conv_i1(feat))
        I_out = self.I_out(self.conv_i2(ill))
        return I_out, R_out

In [None]:
class Illum_adjust(nn.Module):
    """Illumination adjustment network.

    Takes a 1-channel illumination map ``I`` and a brightening ``ratio`` and
    predicts an adjusted 1-channel illumination map of the same size.

    Args:
        filters: Channel count of the hidden convolutions.
        activation: ``'lrelu'`` (LeakyReLU, slope 0.2) or ``'relu'``, applied
            after every hidden convolution.
        kernel_size: Spatial size of the convolutions; odd sizes keep the
            spatial resolution unchanged.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, filters=8, activation='lrelu', kernel_size=3):
        super().__init__()
        pad = kernel_size // 2  # "same" padding so the output matches the input size
        if activation == 'lrelu':
            self.act = nn.LeakyReLU(0.2)
        elif activation == 'relu':
            self.act = nn.ReLU()
        else:
            raise ValueError(f"unsupported activation: {activation!r}")
        # Input has 2 channels: the illumination map and a constant ratio map.
        self.conv_i1 = Conv2d(2, filters, kernel_size=kernel_size, padding=pad)
        self.conv_i2 = Conv2d(filters, filters, kernel_size=kernel_size, padding=pad)
        self.conv_i3 = Conv2d(filters, filters, kernel_size=kernel_size, padding=pad)
        self.conv_i4 = Conv2d(filters, 1, kernel_size=kernel_size, padding=pad)
        self.I_out = nn.Sigmoid()

    def forward(self, I, ratio):
        """Adjust illumination ``I`` by ``ratio``.

        Args:
            I: Illumination map with ONE channel, ``(N, 1, H, W)`` or ``(1, H, W)``.
            ratio: Python number or a tensor broadcastable against ``I``.

        Returns:
            Sigmoid-bounded adjusted illumination with the same shape as ``I``.
        """
        with torch.no_grad():
            ratio_map = torch.ones_like(I) * ratio
        # The channel axis is -3 for both batched and unbatched inputs; the
        # default dim=0 would stack along the batch axis instead.
        x = torch.cat((I, ratio_map), dim=-3)
        x = self.act(self.conv_i1(x))
        x = self.act(self.conv_i2(x))
        x = self.act(self.conv_i3(x))
        return self.I_out(self.conv_i4(x))

In [None]:
class final_network(nn.Module):
    """End-to-end enhancer: decompose, adjust illumination, recombine.

    ``filters``, ``activation`` and ``kernel_size`` are accepted for interface
    compatibility but are not forwarded; both sub-networks are built with
    their own constructor defaults.
    """

    def __init__(self, filters=8, activation='lrelu', kernel_size=3):
        super().__init__()
        self.decomposition = decompose()
        self.illumAdjust = Illum_adjust()

    def forward(self, L, ratio):
        """Enhance image ``L`` using brightening ``ratio``.

        Args:
            L: RGB image tensor, ``(N, 3, H, W)`` or ``(3, H, W)``.
            ratio: Brightening ratio passed to the illumination adjustment net.

        Returns:
            ``(R, I1_3, output)``: reflectance, adjusted illumination repeated
            to 3 channels, and their pixel-wise product (the enhanced image).
        """
        # decompose.forward returns (illumination, reflectance) in that order.
        I, R = self.decomposition(L)
        I1 = self.illumAdjust(I, ratio)
        # Repeat the single illumination channel to match R; dim=-3 is the
        # channel axis for both batched and unbatched tensors.
        I1_3 = torch.cat([I1, I1, I1], dim=-3)
        output = I1_3 * R
        return R, I1_3, output

In [None]:
import torch
from torch import reshape
from torchvision import transforms
convert_tensor = transforms.ToTensor()
def training_decompose(model, *, epochs=50, num_images=320,
                       px="drive/MyDrive/DL dataset",
                       py="drive/MyDrive/DL dataset/Label", lr=0.001):
    """Train the decomposition network on paired low/normal-light images.

    For each index ``i`` in ``1..num_images`` the low-light image is read from
    ``<px>/<i>/x<i>.JPG`` and its label from ``<py>/<i>.JPG``. Both are
    decomposed by ``model`` (which must return ``(I, R)``), and the loss is
    the sum of:

    * reconstruction error of each image from its own ``I * R``, and
    * reflectance consistency between the two exposures of the same scene.

    Illumination smoothness terms are not included.

    Args:
        model: Decomposition network to train in place.
        epochs: Number of passes over the dataset.
        num_images: Number of image pairs (indices start at 1).
        px: Root folder of the low-light images.
        py: Folder of the label images.
        lr: SGD learning rate (momentum 0.9).

    Returns:
        List with the mean training loss of each epoch.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = MSELoss()
    device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        print('\nEpoch : %d' % epoch)
        running_loss = 0.0  # reset per epoch so the reported mean is meaningful
        for i in range(1, num_images + 1):
            # `with` closes the file handle; convert("RGB") guards against
            # grayscale JPGs that would otherwise produce 1-channel tensors.
            with Image.open(os.path.join(px, str(i), "x" + str(i) + ".JPG")) as img:
                tenx = convert_tensor(img.convert("RGB"))
            with Image.open(os.path.join(py, str(i) + ".JPG")) as img:
                teny = convert_tensor(img.convert("RGB"))
            # The reflectance branch downsamples twice by 2; cropping to a
            # multiple of 4 keeps R and I at the same resolution.
            h, w = tenx.shape[-2:]
            tenx = tenx[..., : h - h % 4, : w - w % 4].unsqueeze(0).to(device)
            h, w = teny.shape[-2:]
            teny = teny[..., : h - h % 4, : w - w % 4].unsqueeze(0).to(device)

            optimizer.zero_grad()
            IL, RL = model(tenx)
            IH, RH = model(teny)
            loss = (criterion(IL * RL, tenx)
                    + criterion(IH * RH, teny)
                    + criterion(RL, RH))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_losses.append(running_loss / max(num_images, 1))
        print('loss : %.6f' % epoch_losses[-1])
    return epoch_losses


In [None]:
import torchvision
# Show which torchvision build the runtime is using.
print(f"{torchvision.__version__}")

In [None]:
import torch
from torch import reshape
from torchvision import transforms
convert_tensor = transforms.ToTensor()
def training_illum(model, decom_model=None, *, epochs=50, num_images=320,
                   px="drive/MyDrive/DL dataset",
                   py="drive/MyDrive/DL dataset/Label", lr=0.001):
    """Train the illumination adjustment network on paired images.

    For each pair, 1-channel illumination maps ``I_low`` / ``I_high`` are
    estimated from the low-light image ``<px>/<i>/x<i>.JPG`` and the label
    ``<py>/<i>.JPG``. When ``decom_model`` is given, its first forward output
    is used; otherwise the per-pixel maximum over RGB channels is used as a
    simple stand-in. ``model(I_low, ratio)`` is then regressed onto
    ``I_high``, with ``ratio = mean(I_high) / mean(I_low)`` telling the
    network how much to brighten.

    Args:
        model: Illumination adjustment network to train in place.
        decom_model: Optional trained decomposition network returning ``(I, R)``.
        epochs: Number of passes over the dataset.
        num_images: Number of image pairs (indices start at 1).
        px: Root folder of the low-light images.
        py: Folder of the label images.
        lr: SGD learning rate (momentum 0.9).

    Returns:
        List with the mean training loss of each epoch.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = MSELoss()
    device = next(model.parameters()).device
    model.train()

    def estimate_illumination(img):
        # img is an unbatched (3, H, W) tensor; returns a (1, H, W) map on `device`.
        with torch.no_grad():
            if decom_model is None:
                return img.max(dim=-3, keepdim=True).values.to(device)
            decom_device = next(decom_model.parameters()).device
            return decom_model(img.to(decom_device))[0].to(device)

    epoch_losses = []
    for epoch in range(epochs):
        print('\nEpoch : %d' % epoch)
        running_loss = 0.0  # was never initialised (UnboundLocalError)
        for i in range(1, num_images + 1):
            with Image.open(os.path.join(px, str(i), "x" + str(i) + ".JPG")) as img:
                tenx = convert_tensor(img.convert("RGB"))
            with Image.open(os.path.join(py, str(i) + ".JPG")) as img:
                teny = convert_tensor(img.convert("RGB"))
            I_low = estimate_illumination(tenx)
            I_high = estimate_illumination(teny)
            # Brightening factor for this pair; clamp guards against an all-black input.
            ratio = I_high.mean() / I_low.mean().clamp_min(1e-6)

            optimizer.zero_grad()
            pred = model(I_low, ratio)
            # Center-crop the target in case the network shrinks its input
            # (unpadded convolutions); a no-op when the sizes already match.
            ph, pw = pred.shape[-2:]
            dh = (I_high.shape[-2] - ph) // 2
            dw = (I_high.shape[-1] - pw) // 2
            target = I_high[..., dh:dh + ph, dw:dw + pw]
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_losses.append(running_loss / max(num_images, 1))
        print('loss : %.6f' % epoch_losses[-1])
    return epoch_losses

In [None]:
# Build the two trainable sub-networks with their constructor defaults
# (decomposition first, then illumination adjustment).
decom, illum = decompose(), Illum_adjust()

In [None]:
# Train the decomposition network in place.
training_decompose(model=decom)

In [None]:
# Train the illumination adjustment network in place.
training_illum(model=illum)

In [None]:
final = final_network()
# Plug in the networks trained in the cells above; final_network() on its own
# builds fresh, untrained sub-networks.
final.decomposition = decom
final.illumAdjust = illum
final.eval()
px = "drive/MyDrive/DL dataset"
py = "drive/MyDrive/DL dataset/Label"
x = "x" + str(5) + ".JPG"
y = str(5) + ".JPG"
x1 = os.path.join(px, str(5), x)
y1 = os.path.join(py, y)  # label image path matching x1
# The network takes an image tensor, not a file path.
with Image.open(x1) as img:
    tenx = convert_tensor(img.convert("RGB"))
with torch.no_grad():  # inference only; no autograd graph needed
    result = final(tenx, 0.5)
result