In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img,img_to_array
import torch
import os
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
class CelebDataset(Dataset):
    """Image-folder dataset yielding a grayscale input and a colour target.

    Every item is a dict with two numpy arrays produced by Keras'
    ``load_img`` / ``img_to_array``:

    * ``'color'`` -- the image loaded with ``load_img``'s default colour mode
      and resized to ``out_size``.
    * ``'bw'``    -- the same file loaded with ``color_mode='grayscale'`` and
      resized to ``inp_size[0] x inp_size[1]``.

    Args:
        root_dir: Directory containing the image files.
        inp_size: Size of the grayscale input; only the first two entries
            (height, width) are used.
        dataset_size: Maximum number of files to expose (``None`` = all).
        transforms: Optional callable applied to the sample dict.
        out_size: (height, width) of the colour target. Defaults to the
            previously hard-coded 216x216.

    Raises:
        OSError: at construction time if ``root_dir`` cannot be listed.
    """

    def __init__(self,root_dir,inp_size,dataset_size,transforms=None,out_size=(216,216)):
        self.root_dir=root_dir
        self.inp_size=inp_size
        self.transforms=transforms
        self.dataset_size=dataset_size
        self.out_size=out_size
        # List the directory exactly once. os.listdir returns entries in
        # arbitrary order, so sort to get a stable index -> file mapping, and
        # avoid re-scanning a very large folder on every item access.
        self._file_names=sorted(os.listdir(root_dir))[:dataset_size]

    def __len__(self):
        return len(self._file_names)

    def __getitem__(self,idx):
        # Samplers may hand over a 0-d tensor instead of a plain int.
        if torch.is_tensor(idx):
            idx=idx.tolist()

        img_name=os.path.join(self.root_dir,self._file_names[idx])
        color=img_to_array(load_img(img_name,target_size=(self.out_size[0],self.out_size[1])))
        bw=img_to_array(load_img(img_name,target_size=(self.inp_size[0],self.inp_size[1]),color_mode='grayscale'))
        sample={'color': color, 'bw': bw}

        if self.transforms:
            sample=self.transforms(sample)

        return sample
    
class ToTensor(object):
    """Turn a ``{'color', 'bw'}`` sample of numpy arrays into torch tensors.

    ``'bw'`` is moved from channel-last (H, W, C) to channel-first (C, H, W);
    ``'color'`` keeps its layout and is cast to ``torch.long`` so it can serve
    as class-index targets.

    Args:
        device: Device the tensors are created on. ``None`` (default) selects
            CUDA when it is available and falls back to the CPU otherwise, so
            the transform no longer crashes on CPU-only machines.
    """

    def __init__(self, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def __call__(self, sample):
        color, bw = sample['color'], sample['bw']
        bw=np.transpose(bw,(2,0,1))
        return {'bw': torch.tensor(bw,device=self.device),
                'color': torch.tensor(color,dtype=torch.long,device=self.device)}
    
# Sample-level preprocessing pipeline handed to CelebDataset.
# NOTE: this rebinds the name `transforms`, which until this line refers to the
# imported torchvision module; later cells use the name for this pipeline.
transforms=transforms.Compose(transforms=[ToTensor()])

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
    """Reduced U-Net that maps a 1-channel image to three 256-way logit maps.

    Layout (all 3x3 convolutions are unpadded, so each one trims 2 pixels):

    * encoder level 1: conv1 -> conv2 (64 ch, kept as skip ``c1``) -> down1
    * encoder level 2: conv5 -> conv6 (256 ch, kept as skip ``c3``) -> down3
    * bottleneck:      conv9 -> conv10 (1024 ch)
    * decoder level 2: up2 -> concat(c3) -> conv13 -> conv14
    * decoder level 1: up4 -> concat(c1) -> conv17 -> conv18
    * heads:           outR / outG / outB, 1x1 convolutions with 256 outputs

    The gaps in the layer numbering (conv3/4, conv7/8, conv11/12, conv15/16,
    down2, down4, up1, up3) belong to levels of a deeper variant that are not
    part of this model. The remaining attribute names are kept unchanged
    because they are the keys of the saved ``state_dict``.

    A 256x256 input produces 216x216 outputs. Other input sizes work as long
    as every unpadded convolution still has at least a 3x3 input.
    """

    def __init__(self):
        super(Network, self).__init__()

        # ---- encoder level 1 ----
        self.conv1=nn.Conv2d(in_channels=1,out_channels=64,kernel_size=3)
        self.batchn1=nn.BatchNorm2d(num_features=64)
        self.relu1=nn.ReLU(inplace=True)

        self.conv2=nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3)
        self.batchn2=nn.BatchNorm2d(num_features=64)
        self.relu2=nn.ReLU(inplace=True)

        # Strided convolution: halves the resolution and squeezes to 1 channel.
        self.down1=nn.Conv2d(in_channels=64,out_channels=1,kernel_size=2,stride=2)

        # ---- encoder level 2 ----
        self.conv5=nn.Conv2d(in_channels=1,out_channels=256,kernel_size=3)
        self.batchn5=nn.BatchNorm2d(num_features=256)
        self.relu5=nn.ReLU(inplace=True)

        self.conv6=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3)
        self.batchn6=nn.BatchNorm2d(num_features=256)
        self.relu6=nn.ReLU(inplace=True)

        self.down3=nn.Conv2d(in_channels=256,out_channels=1,kernel_size=2,stride=2)

        # ---- bottleneck ----
        self.conv9=nn.Conv2d(in_channels=1,out_channels=1024,kernel_size=3)
        self.batchn9=nn.BatchNorm2d(num_features=1024)
        self.relu9=nn.ReLU(inplace=True)

        self.conv10=nn.Conv2d(in_channels=1024,out_channels=1024,kernel_size=3)
        self.batchn10=nn.BatchNorm2d(num_features=1024)
        self.relu10=nn.ReLU(inplace=True)

        # ---- decoder level 2 (input to conv13 = 256 upsampled + 256 skip) ----
        self.up2=nn.ConvTranspose2d(in_channels=1024,out_channels=256,kernel_size=2,stride=2)

        self.conv13=nn.Conv2d(in_channels=512,out_channels=256,kernel_size=3)
        self.batchn13=nn.BatchNorm2d(num_features=256)
        self.relu13=nn.ReLU(inplace=True)

        self.conv14=nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3)
        self.batchn14=nn.BatchNorm2d(num_features=256)
        self.relu14=nn.ReLU(inplace=True)

        # ---- decoder level 1 (input to conv17 = 64 upsampled + 64 skip) ----
        self.up4=nn.ConvTranspose2d(in_channels=256,out_channels=64,kernel_size=2,stride=2)

        self.conv17=nn.Conv2d(in_channels=128,out_channels=64,kernel_size=3)
        self.batchn17=nn.BatchNorm2d(num_features=64)
        self.relu17=nn.ReLU(inplace=True)

        self.conv18=nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3)
        self.batchn18=nn.BatchNorm2d(num_features=64)
        self.relu18=nn.ReLU(inplace=True)

        # ---- per-channel classification heads (256 outputs per pixel) ----
        self.outR=nn.Conv2d(in_channels=64,out_channels=256,kernel_size=1)
        self.outG=nn.Conv2d(in_channels=64,out_channels=256,kernel_size=1)
        self.outB=nn.Conv2d(in_channels=64,out_channels=256,kernel_size=1)

    @staticmethod
    def _center_crop(skip,target):
        """Crop the spatial centre of ``skip`` to the height/width of ``target``.

        For a 256x256 network input this reproduces the former fixed slices
        (``4:118`` for ``c3`` and ``16:236`` for ``c1``).

        Raises:
            ValueError: if ``skip`` is spatially smaller than ``target``.
        """
        out_h,out_w=target.shape[2],target.shape[3]
        extra_h=skip.shape[2]-out_h
        extra_w=skip.shape[3]-out_w
        if extra_h<0 or extra_w<0:
            raise ValueError('skip feature map {0} is smaller than target {1}'.format(
                tuple(skip.shape[2:]),tuple(target.shape[2:])))
        top,left=extra_h//2,extra_w//2
        return skip[:,:,top:top+out_h,left:left+out_w]

    def forward(self,t):
        """Run the network.

        Args:
            t: Float tensor of shape (N, 1, H, W).

        Returns:
            Tuple ``(R, G, B)`` of tensors shaped (N, 256, H', W').
        """
        # encoder level 1
        t=self.relu1(self.batchn1(self.conv1(t)))
        c1=self.relu2(self.batchn2(self.conv2(t)))
        t=self.down1(c1)

        # encoder level 2
        t=self.relu5(self.batchn5(self.conv5(t)))
        c3=self.relu6(self.batchn6(self.conv6(t)))
        t=self.down3(c3)

        # bottleneck
        t=self.relu9(self.batchn9(self.conv9(t)))
        t=self.relu10(self.batchn10(self.conv10(t)))

        # decoder level 2: upsample, then append the centre of the skip map
        t=self.up2(t)
        t=torch.cat((t,self._center_crop(c3,t)),dim=1)
        t=self.relu13(self.batchn13(self.conv13(t)))
        t=self.relu14(self.batchn14(self.conv14(t)))

        # decoder level 1
        t=self.up4(t)
        t=torch.cat((t,self._center_crop(c1,t)),dim=1)
        t=self.relu17(self.batchn17(self.conv17(t)))
        t=self.relu18(self.batchn18(self.conv18(t)))

        R=self.outR(t)
        G=self.outG(t)
        B=self.outB(t)

        return R,G,B
        

In [None]:
from torch import optim

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing on the first CUDA call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
color_model=Network()
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
# strict=False tolerates key mismatches, so report them rather than hiding them.
_load_result=color_model.load_state_dict(torch.load('/kaggle/input/umodels/UNET_Pytorch.pkl',map_location=device),strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print('Checkpoint key mismatch - missing: {0}, unexpected: {1}'.format(
        _load_result.missing_keys,_load_result.unexpected_keys))
color_model.to(device)
# Inference mode by default; switch to color_model.train() before optimising
# so that BatchNorm statistics are updated.
color_model.eval()

EPOCHS=200
BATCH_SIZE=18
LEARNING_RATE=0.00001
optimizer=optim.Adam(color_model.parameters(), lr=LEARNING_RATE)
# 256x256 grayscale inputs; the dataset's default colour target is 216x216.
celeb_dataset=CelebDataset(root_dir='/kaggle/input/celebrities-100k/100k/100k',inp_size=(256,256,3),dataset_size=100000,transforms=transforms)
celeb_dataloader=DataLoader(celeb_dataset,batch_size=BATCH_SIZE,shuffle=True)

In [None]:
# Fine-tune the colouriser: each colour channel is treated as a 256-way
# per-pixel classification problem and the three losses are summed.
for epoch in range(EPOCHS):
    # The model may still be in eval() mode from an earlier cell; BatchNorm
    # only updates its running statistics in training mode.
    color_model.train()
    epoch_loss=0.0
    n_batches=0
    for i,sample in enumerate(celeb_dataloader):
        X,y=sample['bw'],sample['color']
        optimizer.zero_grad()
        R,G,B=color_model(X)

        # y is (N, H, W, 3); the last axis selects the target channel.
        Rloss=F.cross_entropy(R,y[:,:,:,0])
        Gloss=F.cross_entropy(G,y[:,:,:,1])
        Bloss=F.cross_entropy(B,y[:,:,:,2])

        loss=Rloss+Gloss+Bloss
        loss.backward()
        optimizer.step()

        # cross_entropy already averages over batch and pixels, so accumulate
        # the plain value and average over batches at the end of the epoch.
        batch_loss=loss.item()
        epoch_loss+=batch_loss
        n_batches+=1
        if(i%100==0):
            print(i,' : ',batch_loss,sep='')

    print('        Epoch: {0}, Loss: {1}'.format(epoch,epoch_loss/max(n_batches,1)))

# Leave the model in inference mode for the cells that follow.
color_model.eval()

In [None]:
# Persist only the parameters/buffers (state_dict), not the whole module object.
torch.save(obj=color_model.state_dict(),f='/kaggle/working/UNET_Pytorch.pkl')

In [None]:
import torchvision
import cv2
from PIL import Image
from skimage import data, color, io
def get_color_img(color_model,n_samples,root_dir='/kaggle/input/celebrities-100k/100k/100k'):
    """Colourise the first ``n_samples`` images of ``root_dir`` and preview them.

    Two figures are shown: the grayscale network inputs and the ground-truth
    colour targets.

    Args:
        color_model: Model returning ``(R, G, B)`` logit maps of shape
            (N, 256, H, W).
        n_samples: Number of images to process.
        root_dir: Image folder; defaults to the CelebA 100k Kaggle input.

    Returns:
        Tuple ``(y_pred, gray)`` of int32 numpy arrays. ``y_pred`` is
        (N, H, W, 3) and holds the arg-max class per pixel and channel.
        ``gray`` is the 216x216 grayscale batch with a trailing axis of three
        identical copies appended; assuming the dataset yields ``'bw'`` as
        (1, h, w), its shape is (N, 1, h, w, 3).
    """
    # Build both views locally instead of relying on a notebook-level dataset:
    # 256x256 feeds the model, 216x216 matches the model's output resolution.
    model_inputs=CelebDataset(root_dir=root_dir,inp_size=(256,256,3),dataset_size=n_samples,transforms=transforms)
    display_inputs=CelebDataset(root_dir=root_dir,inp_size=(216,216,3),dataset_size=n_samples,transforms=transforms)

    sample=next(iter(DataLoader(model_inputs,batch_size=n_samples)))
    plot_sample=next(iter(DataLoader(display_inputs,batch_size=n_samples)))

    X,y=sample['bw'],sample['color']

    # Inference only: force eval mode (restored afterwards) and skip autograd.
    was_training=color_model.training
    color_model.eval()
    try:
        with torch.no_grad():
            R,G,B=color_model(X)
    finally:
        color_model.train(was_training)

    # argmax over the class axis; softmax is monotonic, so it is not needed.
    channels=[torch.argmax(logits,dim=1) for logits in (R,G,B)]
    y_pred=torch.stack(channels,dim=-1).cpu().numpy().astype(np.int32)

    X=X.cpu()
    y=y.cpu()
    batch=X.shape[0]  # may be smaller than n_samples if the folder is short

    def _show_row(images,**imshow_kwargs):
        # One figure with all images side by side.
        plt.figure(figsize=(15,15))
        for col,image in enumerate(images):
            plt.subplot(1,len(images),col+1)
            plt.imshow(image,**imshow_kwargs)
        plt.show()

    _show_row([X[i][0] for i in range(batch)],cmap='gray')
    _show_row([y[i] for i in range(batch)])

    v=np.expand_dims(plot_sample['bw'].cpu().numpy(),axis=-1)
    return y_pred,np.repeat(v,3,axis=-1).astype(np.int32)
    
    
# Render the predicted colourisations side by side and save them as demo.jpg.
samples,X=get_color_img(color_model,4)
fig=plt.figure(figsize=(15,15))
for i in range(samples.shape[0]):
    plt.subplot(1,samples.shape[0],i+1)
    b=samples[i]/255  # class indices 0..255 -> RGB floats in [0, 1]
    # A faint copy of the grayscale input is added on top. The sum can exceed
    # 1.0, so clip explicitly: same picture matplotlib would draw, without its
    # "Clipping input data" warning.
    plt.imshow(np.clip(b+X[i][0]/1000,0.0,1.0))
plt.savefig('/kaggle/working/demo.jpg')
plt.show()