In [15]:
from PIL import Image
import glob
import cv2 
import numpy as np
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from torch.optim import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import random
import os
from os import listdir
from os.path import isfile, join

In [145]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
PATH = '/Dataset'
PATH2 = '/models/'
batch_size = 4
lrG = 0.0002
lrD = 0.0002
beta1 = 0.5
beta2 = 0.999
L1lambda = 100
num_epochs = 15

In [17]:
transform = albumentations.Compose([
    albumentations.Resize(256, 256), 
    #albumentations.RandomCrop(224, 224),
    #albumentations.HorizontalFlip(), # Same with transforms.RandomHorizontalFlip()
    albumentations.pytorch.transforms.ToTensor()
])

In [18]:
#import glob
#import cv2 as cv
#path = glob.glob("/path/to/folder/*.jpg")
#cv_img = []
#for img in path:
#    n = cv.imread(img)
#    cv_img.append(n)

In [52]:
onlyfiles = [file for file in glob.glob(os.getcwd()+PATH+"/*") if file.endswith('.png') or file.endswith('.PNG')]
images = np.empty(len(onlyfiles), dtype=object)
for n in range(0, len(onlyfiles)):
  images[n] = cv2.imread( join(PATH, onlyfiles[n]) )

In [139]:
class CustomDataset(Dataset):

    def __init__(self, file, transform=None):
        self.file = file
        self.transform = transform
        
    def __len__(self):
        return len(self.file)

    def __getitem__(self, idx):
        img = self.file[idx]
        sketch_image = cv2.cvtColor(img[:,:256,:], cv2.COLOR_BGR2RGBA)
        real_image = cv2.cvtColor(img[:,256:,:], cv2.COLOR_BGR2RGBA)
        if self.transform:
            augmented1 = self.transform(image=sketch_image) 
            image1 = augmented1['image']
            augmented2 = self.transform(image= real_image) 
            image2 = augmented2['image']
        return image1, image2

In [140]:
dataset = CustomDataset(images, transform)

In [141]:
dataloader =DataLoader(dataset, batch_size=1, shuffle=True)

In [56]:
# U-net 구조를 만들어보자

class Generator(nn.Module):
  def __init__(self, ngf=64): ### Question -> why self.conv_bn no?
    super(Generator, self).__init__()
    self.conv1 = nn.Conv2d(4, ngf, kernel_size=4, stride=2, padding=1) #(3, 256, 256)->(64, 128, 128)
    self.conv2 = nn.Conv2d(ngf, ngf*2, 4, 2, 1) #(64, 128, 128)->(128, 64, 64)
    self.conv2_bn = nn.BatchNorm2d(ngf*2) 
    self.conv3 = nn.Conv2d(ngf*2, ngf*4, 4, 2, 1) #(128, 64, 64)->(256, 32, 32)
    self.conv3_bn = nn.BatchNorm2d(ngf*4) 
    self.conv4 = nn.Conv2d(ngf*4, ngf*8, 4, 2, 1) #(256, 32, 32)->(512, 16, 16)
    self.conv4_bn = nn.BatchNorm2d(ngf*8) 
    self.conv5 = nn.Conv2d(ngf*8, ngf*8, 4, 2, 1) #(512, 16, 16)->(512, 8, 8)
    self.conv5_bn = nn.BatchNorm2d(ngf*8) 
    self.conv6 = nn.Conv2d(ngf*8, ngf*8, 4, 2, 1) #(512, 8, 8)->(512, 4, 4)
    self.conv6_bn = nn.BatchNorm2d(ngf*8) 
    self.conv7 = nn.Conv2d(ngf*8, ngf*8, 4, 2, 1) #(512, 4, 4)->(512, 2, 2)
    self.conv7_bn = nn.BatchNorm2d(ngf*8) 
    self.conv8 = nn.Conv2d(ngf*8, ngf*8, 4, 2, 1) #(512, 2, 2)->(512, 1, 1)
    self.conv8_bn = nn.BatchNorm2d(ngf*8) 

    self.deconv0 = nn.ConvTranspose2d(ngf*8, ngf*8, 4, 2, 1) #(512, 1, 1)->(512, 2, 2)  # concat 했으니까 수정해야함
    self.deconv0_bn = nn.BatchNorm2d(ngf*8) 
    self.deconv1 = nn.ConvTranspose2d(ngf*8*2, ngf*8, 4, 2, 1) #(512, 2, 2)->(512, 4, 4)
    self.deconv1_bn = nn.BatchNorm2d(ngf*8) 
    self.deconv2 = nn.ConvTranspose2d(ngf*8*2, ngf*8, 4, 2, 1) #(512, 4, 4)->(512, 8, 8)
    self.deconv2_bn = nn.BatchNorm2d(ngf*8) 
    self.deconv3 = nn.ConvTranspose2d(ngf*8*2, ngf*8, 4, 2, 1) #(512, 8, 8)->(512, 16, 16)
    self.deconv3_bn = nn.BatchNorm2d(ngf*8) 
    self.deconv4 = nn.ConvTranspose2d(ngf*8*2, ngf*4, 4, 2, 1) #(512, 16, 16)->(256, 32, 32)
    self.deconv4_bn = nn.BatchNorm2d(ngf*4) 
    self.deconv5 = nn.ConvTranspose2d(ngf*4*2, ngf*2, 4, 2, 1) #(256, 32, 32)->(128, 64, 64)
    self.deconv5_bn = nn.BatchNorm2d(ngf*2) 
    self.deconv6 = nn.ConvTranspose2d(ngf*2*2, ngf, 4, 2, 1) #(128, 64, 64)->(64, 128, 128)
    self.deconv6_bn = nn.BatchNorm2d(ngf) 
    self.deconv7 = nn.ConvTranspose2d(ngf*2, 4, 4, 2, 1) #(64, 128, 128)->(3, 256, 256)

    self.leaky = nn.LeakyReLU(0.2, True)
    self.relu = nn.ReLU(True)
    self.drop = nn.Dropout(0.5)

  def forward(self, input): # (3, 256, 256) <- input
    x1 = self.conv1(input) #(64, 128, 128)<- x1

    x2 = self.leaky(x1)
    x2 = self.conv2(x1)
    x2 = self.conv2_bn(x2)

    x3 = self.leaky(x2)   
    x3 = self.conv3(x3)
    x3 = self.conv3_bn(x3)

    x4 = self.leaky(x3)  
    x4 = self.conv4(x4)
    x4 = self.conv4_bn(x4)  

    x5 = self.leaky(x4)
    x5 = self.conv5(x5)    
    x5 = self.conv5_bn(x5)  

    x6 = self.leaky(x5)
    x6 = self.conv6(x6)    
    x6 = self.conv6_bn(x6)

    x7 = self.leaky(x6)
    x7 = self.conv7(x7)
    x7 = self.conv7_bn(x7)

    x8 = self.leaky(x7)
    x8 = self.conv8(x7)

    y1 = self.relu(x8)
    y1 = self.deconv0(y1)
    y1 = self.deconv0_bn(y1)
    y1 = self.drop(y1)
    y1 = torch.cat([y1,x7], dim=1)

    y2 = self.relu(y1)
    y2 = self.deconv1(y2)
    y2 = self.deconv1_bn(y2)
    y2 = self.drop(y2)
    y2 = torch.cat([y2,x6], dim=1)

    y3 = self.relu(y2)
    y3 = self.deconv2(y3)
    y3 = self.deconv2_bn(y3)
    y3 = self.drop(y3)
    y3 = torch.cat([y3,x5], dim=1)

    y4 = self.relu(y3)
    y4 = self.deconv3(y4)
    y4 = self.deconv3_bn(y4)
    y4 = torch.cat([y4,x4], dim=1)

    y5 = self.relu(y4)
    y5 = self.deconv4(y5)
    y5 = self.deconv4_bn(y5)
    y5 = torch.cat([y5,x3], dim=1)

    y6 = self.relu(y5)
    y6 = self.deconv5(y6)
    y6 = self.deconv5_bn(y6)
    y6 = torch.cat([y6,x2], dim=1)

    y7 = self.relu(y6)
    y7 = self.deconv6(y7)
    y7 = self.deconv6_bn(y7)
    y7 = torch.cat([y7,x1], dim=1)

    y8 = self.relu(y7)
    y8 = self.deconv7(y8)

    output = nn.Tanh()(y8)

    return output

In [57]:
class Discriminator(nn.Module): # 70*70 Patch
  def __init__(self, ndf=64):
    super(Discriminator, self).__init__()
    self.conv1 = nn.Conv2d(8, ndf, 4, 2, 1) 
    self.conv2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
    self.conv2_bn = nn.BatchNorm2d(ndf*2)
    self.conv3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
    self.conv3_bn = nn.BatchNorm2d(ndf*4)
    self.conv4 = nn.Conv2d(ndf*4, ndf*8, 4, 1, 1)
    self.conv4_bn = nn.BatchNorm2d(ndf*8)
    self.conv5 = nn.Conv2d(ndf*8, 1, 4, 1, 1)

    self.leaky = nn.LeakyReLU(0.2, True)

  def forward(self, input, label):
    x = torch.cat([input,label], dim=1)
    x = self.conv1(x)
    x = self.leaky(x)
    x = self.conv2(x)
    x = self.conv2_bn(x)
    x = self.leaky(x)
    x = self.conv3(x)
    x = self.conv3_bn(x)
    x = self.leaky(x)
    x = self.conv4(x)
    x = self.conv4_bn(x)
    x = self.leaky(x)
    x = self.conv5(x)
    output = nn.Sigmoid()(x)

    return output

In [58]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [123]:
def testpic():
    onlyfilesT = [file for file in glob.glob(os.getcwd()+"/TestDataset/*") if file.endswith('.png') or file.endswith('.PNG')]
    imagesT = np.empty(len(onlyfilesT), dtype=object)
    print(onlyfilesT)
    for n in range(0, len(onlyfilesT)):
      imagesT[n] = cv2.imread(onlyfilesT[n])

    dataloaderT = DataLoader(CustomDataset(imagesT, transform), batch_size=1, shuffle=True)
    
    fig = plt.figure(figsize=(8,8))
    plt.axis("off")

    num = 0
    for i, data in enumerate(dataloaderT):
        if(num>=1) :
            break
        num+=1
        x = data[0]
        y = data[1]

        pic = netG(x)
        pic = np.transpose(pic.detach().numpy()[0], (1, 2, 0))
        #print(pic.shape)
        plt.imshow(pic)
        plt.show()

In [59]:
netG = Generator().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)

# Print the model
print(netG)

Generator(
  (conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv4_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv5_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv6_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv7): Conv2d(512, 512, kernel_size=(4, 

In [60]:
netD = Discriminator().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)

Discriminator(
  (conv1): Conv2d(8, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (conv3_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (conv4_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  (leaky): LeakyReLU(negative_slope=0.2, inplace=True)
)


In [61]:
real_label = 1.
fake_label = 0.
img_list = []
G_loss = []
D_loss = []


In [62]:
def fit(num_epochs=15):
  print("Starting Training Loop...")
  optimizerG = Adam(netG.parameters(), lr=lrG, betas=(beta1, beta2))
  optimizerD = Adam(netD.parameters(), lr=lrD, betas=(beta1, beta2))
  iters = 0
  for epoch in range(num_epochs):
    print(f"EPOCH{epoch+1}:")
    train_one_epoch(dataloader, netG, netD, optimizerG, optimizerD, epoch)


def train_one_epoch(dataloader, netG, netD, optimizerG, optimizerD, epoch, iters=0):
  for i, data in enumerate(dataloader):

        netD.zero_grad()
        sketch, real = data
        sketch, real = sketch.to(device), real.to(device)
        
        D_real = netD(sketch, real).view(-1)
        # Calculate loss on all-real batch
        label = torch.full((D_real.size(0),), real_label, device=device)
        errD_real = nn.BCELoss()(D_real, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = D_real.mean().item()
      
        # Generate fake image batch with G
        fake = netG(sketch)
        label.fill_(fake_label)
        # Classify all fake batch with D
        D_fake = netD(sketch, fake.detach()).view(-1)
        
        # Calculate D's loss on the all-fake batch
        errD_fake = nn.BCELoss()(D_fake, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = D_fake.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        D_output = netD(sketch, fake).view(-1)
        G_output = netG(sketch)
        # Calculate G's loss based on this output
        errG = nn.BCELoss()(D_output, label)+ nn.L1Loss()(G_output, real)*L1lambda
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = D_output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_loss.append(errG.item())
        D_loss.append(errD.item())
        
        if(iters%100==0) :
            testpic()

        # Check how the generator is doing
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(sketch).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1


In [63]:
fit(1000)

Starting Training Loop...
EPOCH1:
[0/15][0/3]	Loss_D: 1.7118	Loss_G: 79.4221	D(x): 0.3825	D(G(z)): 0.3645 / 0.4645
EPOCH2:
[1/15][0/3]	Loss_D: 1.2754	Loss_G: 70.4430	D(x): 0.6170	D(G(z)): 0.4877 / 0.3964
EPOCH3:
[2/15][0/3]	Loss_D: 0.9760	Loss_G: 62.2135	D(x): 0.6254	D(G(z)): 0.3533 / 0.3101
EPOCH4:
[3/15][0/3]	Loss_D: 0.9120	Loss_G: 53.4814	D(x): 0.6661	D(G(z)): 0.3063 / 0.2780
EPOCH5:
[4/15][0/3]	Loss_D: 0.5451	Loss_G: 47.0552	D(x): 0.7639	D(G(z)): 0.2224 / 0.1659


In [148]:
torch.save(netG.state_dict(), 'G1.pth')
torch.save(netD.state_dict(), 'D1.pth')