In [None]:
# Downloads the pytorch/xla env-setup script and runs it for the selected VERSION
# (presumably installs torch_xla for TPU use; `#@param` is a Colab form annotation).
# NOTE(review): the torch_xla imports in the next cell are commented out, so this
# install looks unused by the rest of the notebook — confirm before keeping it.
VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
#!pip install -U git+https://github.com/albumentations-team/albumentations
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
from torch.autograd import Variable
import numpy as np
import os 
import pathlib
from torch.utils.data import Dataset,DataLoader
import time
import matplotlib.pyplot as plt
from torchvision import transforms
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import datetime
import sys
#import torch_xla
#import torch_xla.core.xla_model as xm

In [None]:
# Recursively collect every .jpg under each input dataset directory (as pathlib.Path objects).
image_paths_flickr = list(pathlib.Path('../input/flickr-image-dataset').rglob('*.jpg'))
image_paths_landmark = list(pathlib.Path('../input/landmark-recognition-2020/train').rglob('*.jpg'))

In [None]:
"{},{}".format(len(image_paths_flickr), len(image_paths_landmark))  # image counts: flickr, landmark

In [None]:
# Merge both sources into a single list of plain string paths.
all_image_paths = [str(p) for p in image_paths_flickr + image_paths_landmark]
# Fail early with an actionable message rather than an IndexError here or in the Dataset.
assert all_image_paths, "No .jpg files found under ../input - check the dataset paths."
f"{len(all_image_paths)} {all_image_paths[0]}"

In [None]:
def show_image(arr):
    """Display a channel-first image array with matplotlib.

    Axis 0 (channels) is moved to the last position before calling
    ``plt.imshow``, which expects channel-last data. Returns None.
    """
    channel_last = np.moveaxis(arr, source=0, destination=-1)
    plt.imshow(channel_last)


In [None]:
# Load the last collected image, convert OpenCV's BGR order to RGB, then make it channel-first.
img = cv2.cvtColor(cv2.imread(all_image_paths[-1]), cv2.COLOR_BGR2RGB)
img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
print(img.shape)
show_image(img)

In [None]:
# Grayscale copy of `img`, replicated to 3 channels so it keeps the (3, H, W) layout.
gimg = cv2.cvtColor(img.transpose(1, 2, 0), cv2.COLOR_RGB2GRAY)
gimg = np.repeat(gimg[np.newaxis], 3, axis=0)  # (H, W) -> (3, H, W)

print(gimg.shape)
show_image(gimg)

In [None]:
# Shared preprocessing for colour and grayscale images: resize to 256x256,
# apply A.Normalize() with its default arguments, then convert to a torch tensor.
# NOTE(review): this rebinds `transforms`, shadowing the torchvision import above.
transforms = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize(),
    ToTensorV2(),
])

In [None]:
class ColorDataset(Dataset):
    """Paired (colour, grayscale) images for colourisation training.

    Each item is read with OpenCV, converted BGR -> RGB, and a 3-channel
    grayscale copy is derived from it. Both arrays go through ``transform``
    (called as ``transform(image=arr)["image"]``) and are returned as
    ``(colour, gray)``.

    Args:
        img_list: sequence of image file paths.
        transform: callable used as ``transform(image=arr)["image"]``. It is
            applied separately to each image, so any random augmentation would
            desynchronise the pair.
    """

    def __init__(self, img_list, transform):
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Read from this dataset's own list (previously the global all_image_paths).
        path = self.img_list[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None for missing/unreadable files.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        gimg = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        # Replicate the single gray channel so both images have 3 channels.
        gimg = np.stack((gimg, gimg, gimg), axis=-1)

        img = self.transform(image=img)["image"]
        gimg = self.transform(image=gimg)["image"]

        return img, gimg
    
            

In [None]:
Cds = ColorDataset(img_list=all_image_paths, transform=transforms)  # training dataset over all paths

In [None]:
Cds[0][0].size()  # shape of the first sample's colour tensor

In [None]:
def weights_init_normal(m):
    """Re-initialise a single module in place (signature fits ``nn.Module.apply``).

    Any module whose class name contains "Conv" (including ConvTranspose2d)
    gets weights ~ N(0, 0.02). Modules whose class name contains
    "BatchNorm2d" get weights ~ N(1, 0.02) and a zero bias. All other modules
    are left untouched.

    NOTE(review): not applied to netG/netD in the visible cells.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           U-NET
##############################


class UNetDown(nn.Module):
    """Encoder stage of the U-Net.

    4x4 stride-2 conv without bias (halves H and W for even sizes), then
    optional InstanceNorm, LeakyReLU(0.2), and optional Dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.InstanceNorm2d(out_size)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder stage of the U-Net.

    4x4 stride-2 transposed conv without bias (doubles H and W), then
    InstanceNorm, ReLU, and optional Dropout. The result is concatenated with
    ``skip_input`` along the channel dimension, so the output has
    ``out_size + skip_input.size(1)`` channels.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False)
        stack = [upsample, nn.InstanceNorm2d(out_size), nn.ReLU(inplace=True)]
        if dropout:
            stack.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stack)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)


class generator(nn.Module):
    """U-Net generator with encoder-to-decoder skip connections.

    Eight UNetDown stages each halve H and W (a 256x256 input reaches 1x1 at
    the bottleneck). Seven UNetUp stages then double it again, each
    concatenating the mirrored encoder activation. A final
    upsample + asymmetric pad + 4x4 conv + Tanh produces ``out_channels``
    maps at the input resolution with values in (-1, 1).

    Submodules are registered as ``down1..down8``, ``up1..up7`` and ``final``,
    in that order, so ``state_dict`` keys match saved checkpoints.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # (in_size, out_size, extra kwargs) for down1 ... down8
        encoder_spec = [
            (in_channels, 64, dict(normalize=False)),
            (64, 128, {}),
            (128, 256, {}),
            (256, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(dropout=0.5)),
            (512, 512, dict(normalize=False, dropout=0.5)),
        ]
        for stage, (c_in, c_out, opts) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", UNetDown(c_in, c_out, **opts))

        # (in_size, out_size, extra kwargs) for up1 ... up7; from up2 on, in_size is
        # doubled because the previous stage concatenated its skip connection.
        decoder_spec = [
            (512, 512, dict(dropout=0.5)),
            (1024, 512, dict(dropout=0.5)),
            (1024, 512, dict(dropout=0.5)),
            (1024, 512, dict(dropout=0.5)),
            (1024, 256, {}),
            (512, 128, {}),
            (256, 64, {}),
        ]
        for stage, (c_in, c_out, opts) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", UNetUp(c_in, c_out, **opts))

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Encoder: keep every activation, they become the skip connections.
        encoded = []
        h = x
        for stage in range(1, 9):
            h = getattr(self, f"down{stage}")(h)
            encoded.append(h)
        # Decoder: start from the bottleneck (down8) and pair up_k with down_(8-k).
        h = encoded.pop()
        for stage in range(1, 8):
            h = getattr(self, f"up{stage}")(h, encoded.pop())
        return self.final(h)


##############################
#        Discriminator
##############################


class discriminator(nn.Module):
    """Conditional patch discriminator.

    The candidate image and the condition image are concatenated along the
    channel axis, so the first conv sees ``2 * in_channels`` channels. Four
    stride-2 conv blocks follow (each halves H and W), then a 4x4 conv that
    maps to a single-channel grid of per-patch scores. A 256x256 input yields
    a 16x16 score map.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def downsample(c_in, c_out, normalization=True):
            """4x4 stride-2 conv, optional InstanceNorm, LeakyReLU(0.2)."""
            block = [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1)]
            if normalization:
                block.append(nn.InstanceNorm2d(c_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        widths = (in_channels * 2, 64, 128, 256, 512)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the first block skips normalization.
            layers += downsample(c_in, c_out, normalization=stage != 0)
        # Left/top zero pad keeps the final 4x4 conv at the same spatial size.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        # Channel-wise concat of the image and its condition image.
        return self.model(torch.cat((img_A, img_B), dim=1))

In [None]:
# Build both networks with their default 3-channel configuration.
# NOTE(review): weights_init_normal is defined but not applied in the visible cells.
netG, netD = generator(), discriminator()

In [None]:
# Training hyper-parameters, optimizers, device and loss functions.
lr = 0.001
batch_size = 32
n_epochs = 1
img_height, img_width = 256, 256
# NOTE(review): the optimizers are built before the models are moved to `device`
# in the next cell; PyTorch's optim docs recommend moving the model first.
optim_G = torch.optim.Adam(netG.parameters(), lr=lr)
optim_D = torch.optim.Adam(netD.parameters(), lr=lr)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
criterion_GAN = torch.nn.MSELoss()        # adversarial loss on the discriminator's patch scores
criterion_pixelwise = torch.nn.L1Loss()   # reconstruction loss between generated and real colour
# The discriminator halves H and W four times, so its score map is (1, H/16, W/16).
patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)
lambda_pixel = 100  # weight of the L1 term relative to the adversarial term


In [None]:
# Move both networks (and the loss modules) onto the selected device.
netG, netD = netG.to(device), netD.to(device)
criterion_GAN.to(device)
criterion_pixelwise.to(device)


In [None]:
# Containers for per-batch losses and timing statistics.
train_hist = {
    key: [] for key in ('D_losses', 'G_losses', 'per_epoch_ptimes', 'total_ptime')
}

In [None]:
imdl = DataLoader(dataset=Cds, batch_size=batch_size, shuffle=True)  # shuffled (colour, gray) batches

In [None]:
def sample_images(batches_done):
    """Plot gray input, generated colourisation and ground truth for a few training samples.

    Draws one fresh batch from the (shuffled) training loader ``imdl`` and runs
    ``netG`` on the gray images without building an autograd graph. It then shows
    up to three samples side by side. Each panel stacks, top to bottom, the gray
    input, the generated image and the real colour image, min-max rescaled
    jointly to [0, 1] for display.

    Args:
        batches_done: global batch index, shown in the figure title.
    """
    color, gray = next(iter(imdl))  # ColorDataset yields (colour, gray)
    with torch.no_grad():  # sampling only: avoid autograd memory/overhead
        real_A = gray.to(device, dtype=torch.float32)
        real_B = color.to(device, dtype=torch.float32)
        fake_B = netG(real_A)
        # Concatenate along height (dim -2): each sample becomes [gray; fake; real].
        img_sample = torch.cat((real_A, fake_B, real_B), -2).cpu().numpy().astype(np.float32)
    img_sample -= img_sample.min()
    peak = img_sample.max()
    if peak > 0:  # guard against a constant image (division by zero)
        img_sample /= peak
    img_sample = img_sample.transpose(0, 2, 3, 1)  # NCHW -> NHWC for imshow
    n_show = min(3, img_sample.shape[0])  # a small/last batch may have < 3 samples
    plt.figure(figsize=[10, 20])
    for col in range(n_show):
        plt.subplot(1, 3, col + 1)
        plt.imshow(img_sample[col])
    plt.suptitle(f"Samples after {batches_done} batches")
    plt.show()

In [None]:
# Legacy float tensor type on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor


In [None]:
# Alternate one generator step and one discriminator step per batch.
# G minimises the MSE adversarial loss on D's patch scores plus lambda_pixel * L1
# to the real colour image; D learns to score (real colour, gray) pairs as 1
# and (generated colour, gray) pairs as 0. Losses and timings go into train_hist.
prev_time = time.time()
train_start = prev_time
sample_interval = 500
checkpoint_interval = 338  # NOTE(review): not used by the loop below
for epoch in range(n_epochs):
    epoch_start = time.time()
    for i, batch in enumerate(imdl):

        # Model inputs: ColorDataset yields (colour, gray)
        real_A = batch[1].to(device, dtype=torch.float32)  # gray condition
        real_B = batch[0].to(device, dtype=torch.float32)  # colour target

        # Adversarial targets shaped like D's output map (N, *patch):
        # ones mean "real", zeros mean "fake". They need no gradient.
        valid = torch.ones((real_A.size(0), *patch), device=device)
        fake = torch.zeros((real_A.size(0), *patch), device=device)

        # ------------------
        #  Train Generators
        # ------------------

        optim_G.zero_grad()

        # GAN loss: G wants D to score its output as real
        fake_B = netG(real_A)
        pred_fake = netD(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise loss
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        loss_G.backward()
        optim_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also discards the gradients loss_G.backward() accumulated in netD.
        optim_D.zero_grad()

        # Real loss
        pred_real = netD(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss (detached so this step does not backpropagate into G)
        pred_fake = netD(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optim_D.step()

        # --------------
        #  Log Progress
        # --------------

        d_loss, g_loss = loss_D.item(), loss_G.item()
        train_hist['D_losses'].append(d_loss)
        train_hist['G_losses'].append(g_loss)

        # Approximate time left, extrapolated from the last batch's duration
        batches_done = epoch * len(imdl) + i
        batches_left = n_epochs * len(imdl) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(imdl),
                d_loss,
                g_loss,
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)
        # Periodic checkpoint; the file name depends only on n_epochs, so each save overwrites the last.
        if batches_done % 5000 == 0:
            torch.save(netG.state_dict(), f'colourizeGen{n_epochs}.pt')
            torch.save(netD.state_dict(), f'colourizeDis{n_epochs}.pt')
            print("model saved ")

    train_hist['per_epoch_ptimes'].append(time.time() - epoch_start)

train_hist['total_ptime'].append(time.time() - train_start)
            
            

  #  if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
   #     # Save model checkpoints
     #   torch.save(netG.state_dict(), f'colourizeGen{epoch+20}.pth')
     #   torch.save(netD.state_dict(), f'colourizeDis{epoch+20}.pth')

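<a href="./colourizeGen1.pt">Download the generator checkpoint (colourizeGen1.pt — the name follows colourizeGen{n_epochs}.pt, so update this link if n_epochs changes)</a>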

## Save final model weights

In [None]:
# Save the final weights. Uses the same names as the in-loop checkpoints
# (n_epochs, not the undefined n_epoch), so this overwrites them.
torch.save(netG.state_dict(), f'colourizeGen{n_epochs}.pt')
torch.save(netD.state_dict(), f'colourizeDis{n_epochs}.pt')