In [2]:
import datetime
import itertools
import math
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transform
from torch.autograd import Variable
from torch.utils.data import Dataset, dataloader
from torchvision import datasets
from torchvision.utils import save_image

In [3]:
epoch = 0
n_epochs = 200 
dataset_name = 'facade'
batch_size = 1
lr = 0.0002
b1 = 0.5
b2 = 0.999
decay_epoch = 100
sample_interval = 50
n_cpu = 8
img_height = 256
img_width = 256
channels = 3
checkpoint_interval = 1


In [4]:
# Output folders. `exist_ok=True` makes the cell safe to re-run.
os.makedirs("images", exist_ok=True)
os.makedirs("saved_model", exist_ok=True)
# The resume cell reads checkpoints from saved_models/<dataset_name>/, which
# the two calls above never create; make sure that directory exists as well.
os.makedirs(os.path.join("saved_models", dataset_name), exist_ok=True)

In [36]:
def weight_init_normal(m, std=0.02):
    """Initialise convolution and batch-norm parameters from a normal distribution.

    Designed to be passed to ``nn.Module.apply`` so it is called once per
    sub-module; modules whose class name matches neither pattern are left
    untouched.

    Args:
        m: The module to initialise.
        std: Standard deviation of the normal distribution. Defaults to 0.02,
            the value used by pix2pix (the previous hard-coded 0.2 was ten
            times larger).

    Behaviour:
        * class name contains ``'Conv'``: weight ~ N(0, std).
        * class name contains ``'BatchNorm2d'``: weight ~ N(1, std), bias = 0.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Guard against modules that merely have "Conv" in their name but
        # carry no weight tensor of their own.
        if getattr(m, 'weight', None) is not None:
            torch.nn.init.normal_(m.weight.data, 0.0, std)
    elif classname.find('BatchNorm2d') != -1:
        # With affine=False both weight and bias are None.
        if getattr(m, 'weight', None) is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, std)
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

In [55]:
class UnetDown(nn.Module):
    """Down-sampling block: stride-2 conv -> optional InstanceNorm -> LeakyReLU -> optional Dropout.

    The 4x4 convolution with stride 2 and padding 1 halves the spatial size.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: When true, insert ``nn.InstanceNorm2d`` after the convolution.
        dropout: Dropout probability; a falsy value (0.0) disables dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UnetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` and return the down-sampled feature map."""
        # Must go through ``self``: there is no module-level ``model``.
        return self.model(x)
    
    

In [56]:
class UnetUp(nn.Module):
    """Up-sampling block: stride-2 transposed conv -> InstanceNorm -> ReLU -> optional Dropout.

    The result is concatenated with a skip connection along the channel
    dimension, so the output has ``out_size + skip_input.shape[1]`` channels.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced by the transposed convolution.
        dropout: Dropout probability; a falsy value (0.0) disables dropout.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UnetUp, self).__init__()

        layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
                  nn.InstanceNorm2d(out_size),
                  nn.ReLU(inplace=True)]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    # ``forward`` has to be a method of the class; when nested inside
    # ``__init__`` it is only a local function and the module is not callable.
    def forward(self, x, skip_input):
        """Up-sample ``x`` and concatenate ``skip_input`` on the channel axis."""
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x

In [57]:
class GeneratorUnet(nn.Module):
    """U-Net generator: eight down-sampling blocks, seven up-sampling blocks with skips.

    Each ``UnetUp`` concatenates its output with the mirrored encoder feature
    map, which is why the decoder input widths (1024, 512, 256) are twice the
    preceding block's output width. The final stage up-samples by 2, pads,
    convolves down to ``out_channel`` channels and applies ``tanh``.

    Args:
        in_channel: Number of channels of the input image.
        out_channel: Number of channels of the generated image.
    """

    def __init__(self, in_channel=3, out_channel=3):
        super(GeneratorUnet, self).__init__()

        # Encoder: every block halves the spatial resolution.
        self.down1 = UnetDown(in_channel, 64, normalize=False)
        self.down2 = UnetDown(64, 128)
        self.down3 = UnetDown(128, 256)
        self.down4 = UnetDown(256, 512, dropout=0.5)
        self.down5 = UnetDown(512, 512, dropout=0.5)
        self.down6 = UnetDown(512, 512, dropout=0.5)
        self.down7 = UnetDown(512, 512, dropout=0.5)
        self.down8 = UnetDown(512, 512, dropout=0.5, normalize=False)

        # Decoder: every block doubles the resolution and appends a skip.
        self.up1 = UnetUp(512, 512, dropout=0.5)
        self.up2 = UnetUp(1024, 512, dropout=0.5)
        self.up3 = UnetUp(1024, 512, dropout=0.5)
        self.up4 = UnetUp(1024, 512, dropout=0.5)
        self.up5 = UnetUp(1024, 256)
        self.up6 = UnetUp(512, 128)
        self.up7 = UnetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channel, 4, padding=1),
            nn.Tanh())

    # ``forward`` must be defined at class level; nested inside ``__init__``
    # it was never registered and calling the generator raised an error.
    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, and return the image."""
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)
                
                            

In [58]:
class Discriminator(nn.Module):
    """Patch discriminator scoring an image pair.

    The two images are concatenated on the channel axis (hence
    ``in_channels * 2`` input channels), passed through four stride-2
    convolution blocks and a final 4x4 convolution that produces a
    one-channel score map whose spatial size is the input size divided by 16.

    Args:
        in_channels: Number of channels of EACH of the two input images.
    """

    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_fil, out_fil, normalize=True):
            """Return the layers of one stride-2 down-sampling block."""
            layers = [nn.Conv2d(in_fil, out_fil, 4, 2, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_fil))
            # Slope 0.2 matches the LeakyReLU used in UnetDown; the default
            # of 0.01 was being picked up by accident.
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            # The fourth positional argument of Conv2d is the stride, not the
            # padding. Without padding=1 the map shrinks to (H/16 - 2), which
            # no longer matches the (1, H/16, W/16) target patch shape.
            nn.Conv2d(512, 1, 4, padding=1, bias=False))

    def forward(self, img_A, img_B):
        """Concatenate the pair on the channel axis and return the patch score map."""
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

In [59]:
import glob

from PIL import Image

class ImageDataset(Dataset):
    """Dataset of paired images stored side by side in a single file.

    Every file under ``<root>/<mode>/`` is split down the middle: the left
    half becomes ``"A"`` and the right half becomes ``"B"``. When
    ``mode == "train"`` the files under ``<root>/test/`` are appended to the
    training files.

    Args:
        root: Dataset root directory.
        transform: Either ``None`` (images are returned as PIL images), a
            callable applied to each half, or a list/tuple of transforms that
            is composed with ``torchvision.transforms.Compose``.
        mode: Name of the sub-directory of ``root`` to read.
    """

    def __init__(self, root, transform=None, mode="train"):
        if isinstance(transform, (list, tuple)):
            # Imported locally because the ``transform`` parameter shadows the
            # notebook-level ``torchvision.transforms`` alias of the same name.
            from torchvision.transforms import Compose
            transform = Compose(list(transform))
        self.transform = transform

        self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))
        if mode == "train":
            self.files.extend(sorted(glob.glob(os.path.join(root, "test") + "/*.*")))

    def __getitem__(self, index):
        """Return ``{"A": left_half, "B": right_half}`` for the file at ``index``.

        Indices wrap around the file list. Both halves are mirrored together
        with probability 0.5.

        Raises:
            IndexError: If no image files were found.
        """
        if not self.files:
            raise IndexError("ImageDataset is empty: no image files were found")

        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        half = w // 2  # integer pixel boundary between the two halves
        img_A = img.crop((0, 0, half, h))
        img_B = img.crop((half, 0, w, h))

        # Flip both halves together so the pair stays aligned.
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")

        if self.transform is not None:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        """Return the number of image files."""
        return len(self.files)

In [61]:
# Use the GPU whenever PyTorch reports one as available.
cuda = torch.cuda.is_available()

# Loss functions: mean-squared error for the adversarial term and L1 for the
# pixel-wise term. `lambda_pixel` is presumably the weight of the L1 term in
# the generator loss -- confirm against the training loop.
lambda_pixel = 100
crieterion_pixelwise = torch.nn.L1Loss()
crieterion_GAN = torch.nn.MSELoss()

# Shape of one target patch map: one channel, image size divided by 16
# (presumably the discriminator's output size -- confirm).
patch = (1, img_height // 16, img_width // 16)

# Instantiate the two networks (generator first, then discriminator).
generator = GeneratorUnet()
discriminator = Discriminator()

# Move networks and loss modules to the GPU when one is in use.
if cuda:
    crieterion_GAN.cuda()
    crieterion_pixelwise.cuda()
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    

In [62]:
if epoch != 0:
    # Resume: restore both networks from the checkpoints written for `epoch`.
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a
    # CPU-only machine; load_state_dict then copies the values onto whatever
    # device each network already lives on.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # you trust.
    generator.load_state_dict(
        torch.load("saved_models/%s/generator_%d.pth" % (dataset_name, epoch),
                   map_location="cpu"))
    # The discriminator must read its OWN checkpoint; loading the generator
    # file here fails on mismatched state-dict keys.
    # NOTE(review): file name assumed to mirror the generator's -- confirm
    # against the cell that saves checkpoints.
    discriminator.load_state_dict(
        torch.load("saved_models/%s/discriminator_%d.pth" % (dataset_name, epoch),
                   map_location="cpu"))
else:
    # Fresh run: initialise every sub-module of both networks.
    generator.apply(weight_init_normal)
    discriminator.apply(weight_init_normal)
    

In [63]:
# One Adam optimiser per network, sharing the learning rate and beta
# coefficients from the configuration cell.
optimizer_G = torch.optim.Adam(
    params=generator.parameters(),
    betas=(b1, b2),
    lr=lr,
)
optimizer_D = torch.optim.Adam(
    params=discriminator.parameters(),
    betas=(b1, b2),
    lr=lr,
)