<a href="https://colab.research.google.com/github/tejasbana/DCGANS/blob/main/clean_DCGANS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch 
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.nn as nn
from skimage.color import rgb2lab, lab2rgb, rgb2gray, gray2rgb
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
import random
import cv2
%matplotlib inline

In [None]:
# Download the dataset archive from Google Drive by file id, then unpack it.
# NOTE(review): assumes the download is saved as ./dataset.zip — confirm.
!gdown --id 19N6B89I5ATZaP9s8zkS_IbP4LIm_V5Hu
!unzip ./dataset.zip

# Move the Test split out of /content/dataset into /content/images.
# NOTE(review): presumably done so that ImageFolder over /content/dataset only
# picks up the Train folder; also assumes the working directory is /content — confirm.
!mkdir images
!mv /content/dataset/Test /content/images

In [None]:
# Global settings used by the cells below.
image_size, batch_size = 128, 16      # size passed to T.Resize / images per batch

data_dir = "/content/dataset"
train_dir = data_dir + "/Train"

# Sanity check: number of entries in the Train folder, and its path.
print(len(os.listdir(train_dir)))
print(train_dir)



In [None]:
# ImageFolder treats every sub-directory of data_dir as one class; the class
# labels are ignored by the training loop, only the images are used.
#
# T.Resize(int) only fixes the *shorter* edge and keeps the aspect ratio, so a
# non-square source image would produce a non-square tensor. That breaks batch
# collation and the networks, whose layers are sized for image_size x image_size
# inputs. CenterCrop guarantees the square shape and is a no-op for images that
# are already square after resizing.
train_ds = ImageFolder(
    data_dir,
    transform=T.Compose([
        T.Resize(image_size),
        T.CenterCrop(image_size),
        T.ToTensor(),
    ]),
)

train_loader = DataLoader(train_ds , batch_size , num_workers=3 , pin_memory=True)

**RGB_TENSOR_BATCH to LAB_TENSOR_BATCH**

In [None]:
def convert_RGB_bacth_to_lab_batch(real_images):
  """Convert a batch of RGB image tensors to LAB tensors rescaled towards [0, 1].

  real_images is indexed as (batch, channel, height, width); it is moved to the
  CPU, converted with skimage's rgb2lab and returned in the same layout as a
  float32 tensor whose channels are L, a, b.
  NOTE(review): assumes RGB values are floats in [0, 1] as produced by
  T.ToTensor — confirm against the caller.
  """
  channels_last_rgb = to_cpu(real_images , 'cpu').permute(0, 2, 3, 1).numpy()
  lab = rgb2lab(channels_last_rgb)

  # Bring every channel onto a 0..255 scale, then divide once:
  # L is stretched by 255/100, a and b are shifted up by 128.
  lab[..., 0] *= 255 / 100
  lab[..., 1] += 128
  lab[..., 2] += 128
  lab /= 255

  channels_first_lab = np.transpose(lab, (0, 3, 1, 2))
  return torch.from_numpy(channels_first_lab).type(torch.float32).detach()

# MOVE TO GPU OR CPU

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_cpu(data, device='cpu'):
    """Move a tensor, or every element of a list/tuple, to `device` (CPU by default).

    A list or tuple comes back as a plain list whose elements were moved with
    to_device; anything else is moved with a non-blocking `.to(device)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

def to_device(data, device):
    """Recursively move a tensor, or a list/tuple of tensors, to `device`.

    Lists and tuples (nested ones included) are returned as plain lists; any
    other object is moved with a non-blocking `.to(device)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))

class DeviceDataLoader():
    """Dataloader wrapper: converts each image batch from RGB to LAB, then moves it to a device."""

    def __init__(self, dl, device):
        # dl: the wrapped loader, expected to yield (images, labels) pairs.
        self.dl, self.device = dl, device

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Yield [lab_images, labels] for every batch, already placed on self.device."""
        for images, labels in self.dl:
            lab_images = convert_RGB_bacth_to_lab_batch(images)
            yield to_device((lab_images, labels), self.device)

# Model

In [None]:
class UNet(torch.nn.Module):
  """U-Net generator: 1-channel input -> 2-channel output squashed by tanh.

  Seven double-conv encoder stages (conv1..conv7, a 2x2 max-pool before every
  stage but the first), six transposed-conv upsamplers (up1..up6) and six
  double-conv decoder stages (conv8..conv13) that each consume the upsampled
  features concatenated with the matching encoder output. conv14 is a 1x1
  convolution down to 2 channels. With a 128 x 128 input the bottleneck is 2 x 2.

  is_leaky selects LeakyReLU(0.2) instead of ReLU in the encoder stages; the
  decoder stages always use ReLU.
  """

  # (in_channels, out_channels) of conv1..conv7.
  ENCODER_CHANNELS = ((1, 64), (64, 128), (128, 256), (256, 512),
                      (512, 1024), (1024, 2048), (2048, 1024))
  # (in_channels, out_channels) of up1..up6; each one doubles the spatial size.
  UP_CHANNELS = ((1024, 2048), (2048, 1024), (1024, 512),
                 (512, 256), (256, 128), (128, 64))
  # (in_channels, out_channels) of conv8..conv13; the input is skip + upsampled.
  DECODER_CHANNELS = ((4096, 2048), (2048, 1024), (1024, 512),
                      (512, 256), (256, 128), (128, 64))

  def unet_conv(self , ch_in , ch_out , is_leaky):
    """Return two stacked (3x3 conv, batch norm, activation) stages that keep the spatial size."""
    def activation():
      return nn.LeakyReLU(0.2) if is_leaky else nn.ReLU()

    return nn.Sequential(
        nn.Conv2d(ch_in , ch_out , 3 , padding=1),
        nn.BatchNorm2d(ch_out),
        activation(),
        nn.Conv2d(ch_out , ch_out , 3 , padding=1),
        nn.BatchNorm2d(ch_out),
        activation(),
    )

  def up(self,ch_in,ch_out):
    """Return a stride-2 transposed convolution followed by ReLU (doubles height and width)."""
    upsample = nn.ConvTranspose2d(ch_in , ch_out , kernel_size=3 , stride=2 ,
                                  padding=1 , output_padding=1)
    return nn.Sequential(upsample , nn.ReLU())

  def __init__(self, is_leaky):
    super(UNet,self).__init__()

    # Modules are registered under the same names and in the same order as
    # before (conv1..conv7, pool, up1..up6, conv8..conv13, conv14) so saved
    # state dicts keep loading.
    for idx, (ch_in, ch_out) in enumerate(self.ENCODER_CHANNELS, start=1):
      setattr(self, 'conv{}'.format(idx), self.unet_conv(ch_in, ch_out, is_leaky))

    # Pooling layer shared by all encoder stages
    self.pool = nn.MaxPool2d(2)

    for idx, (ch_in, ch_out) in enumerate(self.UP_CHANNELS, start=1):
      setattr(self, 'up{}'.format(idx), self.up(ch_in, ch_out))

    for idx, (ch_in, ch_out) in enumerate(self.DECODER_CHANNELS, start=8):
      setattr(self, 'conv{}'.format(idx), self.unet_conv(ch_in, ch_out, False))

    # Last layer: 1x1 convolution, 64 -> 2 channels
    self.conv14 = nn.Conv2d(64,2,1)

  def forward(self, x):
    # Encoding path: keep every stage's output for the skip connections.
    skips = [self.conv1(x)]
    for idx in range(2, 8):
      encoder = getattr(self, 'conv{}'.format(idx))
      skips.append(encoder(self.pool(skips[-1])))

    # Decoding path: start from the bottleneck (conv7 output) and walk the
    # remaining encoder outputs from deepest to shallowest.
    x = skips.pop()
    for step in range(1, 7):
      upsample = getattr(self, 'up{}'.format(step))
      decoder = getattr(self, 'conv{}'.format(step + 7))
      x = decoder(torch.cat((skips.pop(), upsample(x)), 1))

    return torch.tanh(self.conv14(x))



# Discriminator

In [None]:
class DNet(torch.nn.Module):
  """Discriminator: 3-channel image (L plus two chroma channels) -> one probability per sample.

  Seven double-conv stages, with a 2x2 max-pool before every stage but the
  first, followed by a linear layer and a sigmoid. The linear layer expects
  2*2*1024 features, i.e. a 128 x 128 input that the six poolings reduce to 2 x 2.
  """

  def unet_conv(self, ch_in , ch_out):
    """Return two stacked (3x3 conv, batch norm, LeakyReLU(0.2)) stages that keep the spatial size."""
    return nn.Sequential(
        nn.Conv2d(ch_in , ch_out , 3, padding=1),
        nn.BatchNorm2d(ch_out),
        nn.LeakyReLU(0.2),
        nn.Conv2d(ch_out , ch_out , 3, padding=1),
        nn.BatchNorm2d(ch_out),
        nn.LeakyReLU(0.2)
    )

  def __init__(self):
    super(DNet,self).__init__()

    # First layer
    self.conv1 = self.unet_conv(3,64)
    # Second layer
    self.conv2 = self.unet_conv(64,64)
    # Third layer
    self.conv3 = self.unet_conv(64,128)
    # Fourth layer
    self.conv4 = self.unet_conv(128,128)
    # Fifth layer
    self.conv5 = self.unet_conv(128,256)
    # Sixth layer
    self.conv6 = self.unet_conv(256,512)
    # Seventh layer
    self.conv7 = self.unet_conv(512,1024)

    # Pooling layer
    self.pool = nn.MaxPool2d(2)

    # Last layer
    self.linear = nn.Linear(2*2*1024 , 1)

  def forward(self,x):
    """Return a (batch, 1) tensor of probabilities in [0, 1]."""
    x1 = self.conv1(x)
    x2 = self.conv2(self.pool(x1))
    x3 = self.conv3(self.pool(x2))
    x4 = self.conv4(self.pool(x3))
    x5 = self.conv5(self.pool(x4))
    x6 = self.conv6(self.pool(x5))
    x7 = self.conv7(self.pool(x6))

    # Flatten the *last* feature map, one row per sample. The previous code
    # reshaped x5 here, which bypassed conv6/conv7 and, because x5 carries
    # 256*8*8 values per sample, produced four output rows for every input.
    # Keeping the batch dimension explicit also turns a wrongly sized input
    # into a clear shape error in the linear layer instead of silently mixing
    # samples.
    features = x7.reshape(x7.size(0), -1)
    return torch.sigmoid(self.linear(features))

# Move to GPU

In [None]:
# Build both networks and place them on the selected device. Using `device`
# rather than a hard-coded .cuda() keeps this cell working on a CPU-only runtime.
discriminator = DNet().to(device)
generator = UNet(True).to(device)         # True -> LeakyReLU(0.2) in the encoder stages

# Wrap the loader so every batch is converted to LAB and moved to `device`.
# The guard keeps a re-run of this cell from wrapping an already wrapped
# loader, which would apply the RGB -> LAB conversion twice.
if not isinstance(train_loader, DeviceDataLoader):
  train_loader = DeviceDataLoader(train_loader , device)

# Test

In [None]:
# Smoke test: push one batch through both networks and print the tensor shapes.
for batch in train_loader:
  img , l = batch
  # The wrapped loader yields LAB batches; channel 0 is L. unsqueeze restores
  # the channel axis that the integer index removed.
  gray_batch = img[:,0,:,:].unsqueeze(1)
  print(gray_batch.shape)
  out = generator(gray_batch)               # previously called the undefined name `unet`
  print(out.shape)
  out = discriminator(torch.cat([gray_batch , out],1))
  print(out.shape)
  break

# Fit

In [None]:
# Both networks are optimised by Adam with identical hyper-parameters.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_settings)
g_optimizer = torch.optim.Adam(generator.parameters(), **adam_settings)

# Binary cross-entropy for the adversarial terms, L1 for the image distance.
d_criterion = nn.BCELoss()
g_criterion_1 = nn.BCELoss()
g_criterion_2 = nn.L1Loss()

def train(epochs):
  """Train the global generator/discriminator pair for `epochs` passes over train_loader.

  Relies on the module-level generator, discriminator, train_loader,
  d_optimizer, g_optimizer, d_criterion, g_criterion_1 and g_criterion_2.
  Every batch from train_loader is a LAB tensor scaled to [0, 1]; channel 0 (L)
  is the conditioning input and channels 1-2 (a, b) are the target.

  Returns (losses_g, losses_d): two lists with one float per epoch, holding the
  adversarial generator loss and the discriminator loss summed over the
  epoch's batches.
  """
  losses_g = []
  losses_d = []
  g_lambda = 100      # weight of the L1 term in the generator loss
  smooth = 0.1        # one-sided label smoothing: real targets are 1 - smooth

  for epoch in range(epochs):
    d_running_loss = 0.0
    g_running_loss = 0.0

    for lab_batch , _ in train_loader:
      # Split into lightness and chroma and shift both from [0, 1] to [-1, 1],
      # the range of the generator's tanh output. The 0:1 slice keeps the
      # channel axis on the lightness tensor.
      l_images = 2 * (lab_batch[:, 0:1, :, :] - 0.5)
      c_images = 2 * (lab_batch[:, 1:, :, :] - 0.5)

      # Targets live on the batch's own device instead of a hard-coded .cuda().
      n_samples = l_images.shape[0]
      target_device = lab_batch.device
      real_targets = torch.full((n_samples,), 1 - smooth, device=target_device)
      fake_targets = torch.zeros(n_samples, device=target_device)
      fooled_targets = torch.ones(n_samples, device=target_device)

      # Fake chroma is generated once and reused by both steps.
      fake_images = generator(l_images)

      # Discriminator step: real and fake pairs, both conditioned on the
      # lightness image. The fake is detached so this backward pass stops at
      # the discriminator; that also removes the need for retain_graph.
      d_optimizer.zero_grad()
      real_preds = discriminator(torch.cat([l_images, c_images] , 1))
      d_real_loss = d_criterion(real_preds.squeeze(1) , real_targets)

      fake_preds = discriminator(torch.cat([l_images, fake_images.detach()] , 1))
      d_fake_loss = d_criterion(fake_preds.squeeze(1) , fake_targets)

      d_loss = d_real_loss + d_fake_loss
      d_loss.backward()
      d_optimizer.step()

      # Generator step: adversarial loss against the updated discriminator
      # plus the weighted L1 distance between fake and target chroma.
      g_optimizer.zero_grad()
      fake_preds = discriminator(torch.cat([l_images , fake_images] , 1))
      g_fake_loss = g_criterion_1(fake_preds.squeeze(1) , fooled_targets)         # generator loss 1

      g_image_distance_loss = g_lambda * g_criterion_2(fake_images , c_images)   # generator loss 2

      g_loss = g_fake_loss + g_image_distance_loss
      g_loss.backward()
      g_optimizer.step()

      # Accumulate plain floats so the history does not hold device tensors.
      # Only the adversarial part of the generator loss is tracked, as before.
      d_running_loss += d_loss.item()
      g_running_loss += g_fake_loss.item()

    print('Epoch : {} , g_epoch_loss : {:.4f} , d_epoch_loss : {:.4f}'.format(epoch,g_running_loss,d_running_loss))
    # The two appends used to be swapped (generator list got the discriminator loss).
    losses_g.append(g_running_loss)
    losses_d.append(d_running_loss)
  return losses_g , losses_d




In [None]:
# Restore previously saved weights from Google Drive. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU
# runtime also loads on a CPU-only one.
discriminator.load_state_dict(torch.load('/content/drive/MyDrive/weights/discriminator.ckpt', map_location=device))
generator.load_state_dict(torch.load('/content/drive/MyDrive/weights/generator.ckpt', map_location=device))

In [None]:
# Train for 5 epochs; `history` holds the (losses_g, losses_d) tuple returned by train().
history = train(5)

In [None]:
# Write both networks' weights to the current working directory.
for ckpt_name, net in (('generator.ckpt', generator), ('discriminator.ckpt', discriminator)):
  torch.save(net.state_dict(), ckpt_name)

In [None]:
def denorm(l_image ,fake_image):
  """Rebuild an RGB image from lightness plus predicted chroma and show it with matplotlib.

  l_image:    CPU tensor with a leading channel axis of size 1, lightness
              scaled to [0, 1] (i.e. L / 100).
  fake_image: CPU tensor with a leading channel axis of size 2, chroma in
              [-1, 1] as produced by the generator's tanh.
  Returns None; the image is drawn with plt.imshow. Neither input is modified.
  """
  lightness = l_image * 100

  # Exact inverse of the training normalisation 2 * ((ab + 128) / 255 - 0.5).
  # The previous code added 0.5 without first halving, so it did not map the
  # [-1, 1] output back onto the a/b range.
  chroma = (fake_image / 2 + 0.5) * 255 - 128

  # torch.cat: the bare name `cat` is not defined anywhere in this notebook.
  lab = torch.cat([lightness, chroma] , 0).detach()
  rgb_image = lab2rgb(lab.permute(1,2,0).numpy())      # channels-last for skimage
  plt.imshow(rgb_image)


In [None]:
# Colorize one dataset image with the trained generator and display the result.
rgb_image , l = train_ds[2222]                                      # take a single image from the dataset

# Reuse the batch helper on a batch of one: returns a float32 (3, H, W) LAB
# tensor with every channel rescaled towards [0, 1], exactly as in training.
lab_image = convert_RGB_bacth_to_lab_batch(rgb_image.unsqueeze(0))[0]

l_image = lab_image[0:1]                          # (1, H, W) lightness in [0, 1], the scale denorm multiplies by 100
# Generator input must use the training normalisation 2 * (L - 0.5), i.e.
# [-1, 1]; the previous code only subtracted 0.5.
gray_image = (2 * (l_image - 0.5)).unsqueeze(0)   # add the batch axis -> (1, 1, H, W)

with torch.no_grad():                             # inference only, no graph needed
  fake_image = generator(to_device(gray_image, device))
fake_image = to_cpu(fake_image).squeeze(0)        # remove the batch axis
print(fake_image.shape)
denorm(l_image ,fake_image)

print(l_image.shape)
# Normalised LAB prediction as a channels-last numpy array, kept for inspection.
pred = torch.cat([l_image ,fake_image] , 0).detach()
pred = pred.permute(1,2,0).numpy()