In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
import random
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from torch.autograd import Variable
import tqdm
import itertools
import glob

In [None]:
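# Residual block used in the generator bottleneck: x + F(x), with F = two reflection-padded 3x3 convs + instance norm.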
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: returns ``x + conv_block(x)``.

    ``conv_block`` is ReflectionPad(1) -> Conv3x3 -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> Conv3x3 -> InstanceNorm. Padding 1 with a 3x3 kernel
    keeps height and width, and channels stay ``num_features``, so the skip
    connection is a plain element-wise sum.

    Args:
        num_features: channel count of both the input and the output.
    """

    def __init__(self, num_features):
        super().__init__()

        stages = []
        # Two conv stages; only the first is followed by a ReLU.
        for with_relu in (True, False):
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(num_features, num_features, 3))
            stages.append(nn.InstanceNorm2d(num_features))
            if with_relu:
                stages.append(nn.ReLU(inplace=True))

        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
        

In [None]:
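# ResNet-style generator: 7x7 stem, 2 stride-2 downsampling convs, residual blocks, 2 transposed-conv upsampling layers, 7x7 output conv + tanh.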
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem -> two stride-2 downsampling convs (each halves H/W and
    doubles channels) -> ``n_residual_blocks`` ResidualBlocks -> two
    transposed-conv upsampling layers (each doubles H/W and halves channels)
    -> 7x7 projection to ``output_nc`` channels -> tanh.

    For H and W divisible by 4 the output has the same spatial size as the
    input. Output values are in [-1, 1] because of the final tanh.

    Args:
        input_nc: channels of the input image (default 3).
        output_nc: channels of the generated image (default 3).
        n_residual_blocks: number of ResidualBlocks at the bottleneck (default 9).
        ngf: channels produced by the stem (default 64).

    With the defaults the layer stack, and hence the ``state_dict`` keys, is
    identical to the previous hard-coded version, so saved checkpoints load.
    """

    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=9, ngf=64):
        super(Generator, self).__init__()

        # Initial convolution block; reflection padding 3 + 7x7 kernel keeps H/W.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, 7),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(inplace=True)]

        # Downsampling
        in_features = ngf
        out_features = in_features * 2
        for _ in range(2):
            model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks (channel count unchanged)
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: mirror of the downsampling path.
        out_features = in_features // 2
        for _ in range(2):
            model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                         padding=1, output_padding=1),
                      nn.InstanceNorm2d(out_features),
                      nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        # Output layer. in_features is back to ngf here; using it instead of a
        # literal 64 keeps the layer consistent with the stem width.
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_features, output_nc, 7),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
# Convolutional discriminator: a map of per-location sigmoid scores, averaged into one probability per image.

class Discriminator(nn.Module):
    """Image discriminator that outputs one probability per input image.

    The input must have 3 channels. Four 4x4 conv stages (64, 128, 256, 512
    channels; the first three stride 2, the last stride 1) feed a final 4x4
    conv to 1 channel plus a sigmoid. ``forward`` averages the resulting score
    map over its spatial dimensions and returns a tensor of shape ``(N, 1)``.
    """

    def __init__(self):
        super().__init__()

        def stage(in_ch, out_ch, stride, norm):
            """Conv4x4 -> [InstanceNorm] -> LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=1)]
            if norm:
                layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # The first stage has no normalization.
        layers = (stage(3, 64, 2, norm=False)
                  + stage(64, 128, 2, norm=True)
                  + stage(128, 256, 2, norm=True)
                  + stage(256, 512, 1, norm=True))
        layers += [nn.Conv2d(512, 1, 4, padding=1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        # Global average over H and W, then one column per sample: (N, 1).
        pooled = F.avg_pool2d(score_map, score_map.shape[2:])
        return torch.flatten(pooled, 1)

In [None]:
class FakePool():
    """Fixed-size history of generated samples, queried once per batch.

    ``query`` returns a batch with the same number of samples as its input.
    While fewer than ``max_size`` samples are stored, each incoming sample is
    stored and returned unchanged. Once the buffer is full, each incoming
    sample has a 50% chance of being swapped with a randomly chosen stored
    sample: a clone of the stored one is returned and the new one takes its
    slot. Otherwise the incoming sample is returned as-is.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def query(self, data):
        """Return a batch mixing ``data`` with stored history.

        Iterates ``data.data``, so the returned samples carry no autograd
        history from ``data``.
        """
        batch = []
        for sample in data.data:
            sample = sample.unsqueeze(0)

            # Still filling up: keep and pass through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch.append(sample)
                continue

            # Full buffer: coin flip decides pass-through vs. swap.
            if random.uniform(0, 1) <= 0.5:
                batch.append(sample)
                continue

            slot = random.randint(0, self.max_size - 1)
            batch.append(self.data[slot].clone())
            self.data[slot] = sample

        return Variable(torch.cat(batch))

In [None]:
class Args(object):
    """Plain container for run hyper-parameters.

    Attributes:
        batch_size: input batch size for training.
        test_batch_size: input batch size for testing.
        epochs: number of epochs to train.
        lr: learning rate.
        momentum: momentum value (the Adam optimizers in train() do not read it).
        seed: random seed.
    """

    def __init__(self, batch_size=1, test_batch_size=1,
                 epochs=10, lr=0.001, momentum=0.5, seed=1):
        # Batch sizes
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size

        # Optimisation schedule
        self.epochs = epochs
        self.lr = lr
        self.momentum = momentum

        # Reproducibility
        self.seed = seed

In [None]:
# Unpaired dataset: there are no labels; each item is one image from domain A and one image from domain B.
class PainterDataset(Dataset):
    """Unpaired two-domain image dataset (no labels).

    Collects every ``*.*`` file in ``<img_dir>/<mode>A`` and
    ``<img_dir>/<mode>B``, sorted by path. Item ``idx`` is ``(imgA, imgB)``:
    imgA is the ``idx``-th A file; imgB is a uniformly random B file when
    ``unaligned`` is true, otherwise the ``idx % len(B)``-th B file. The
    length is the number of A files.

    Images are converted to RGB, so grayscale, palette or RGBA files still
    yield 3 channels, matching the 3-channel inputs of the models in this
    notebook. ``transform`` (if given) is applied to each image separately.
    """

    def __init__(self, img_dir, transform=None, unaligned=False, mode='train'):
        self.img_dir = img_dir
        self.transform = transform
        self.unaligned = unaligned

        pathA = os.path.join(self.img_dir, mode + 'A')
        pathB = os.path.join(self.img_dir, mode + 'B')

        self.filesA = sorted(glob.glob(pathA + '/*.*'))
        self.filesB = sorted(glob.glob(pathB + '/*.*'))

    def __len__(self):
        return len(self.filesA)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return a fully loaded RGB copy, and close the file."""
        # convert() returns a new, loaded image, so the lazily opened file
        # can be closed immediately instead of leaking a handle.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        path_a = self.filesA[idx % len(self.filesA)]

        if self.unaligned:
            index_b = random.randint(0, len(self.filesB) - 1)
        else:
            index_b = idx % len(self.filesB)
        path_b = self.filesB[index_b]

        img1 = self._load_rgb(path_a)
        img2 = self._load_rgb(path_b)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2

In [None]:
def prepare_dataset(object, data_dir='datasets/nouveau2roman/'):
    """Build the train and test DataLoaders for the A<->B translation task.

    Args:
        object: hyper-parameter container (e.g. an ``Args`` instance); only
            ``batch_size`` and ``test_batch_size`` are read. The name shadows
            the builtin but is kept so existing callers keep working.
        data_dir: directory containing ``trainA``/``trainB``/``testA``/``testB``.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and pairs A
        and B images randomly (unaligned). The test loader keeps file order,
        so the outputs test() numbers by batch index follow the sorted inputs.
    """
    # Read settings from the argument. The previous body ignored its parameter
    # and read a notebook-global ``args``, so it only worked when a cell had
    # happened to define one.
    args = object

    transforms_ = transforms.Compose([
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1], the range of the generators' tanh output.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Pinned memory only helps (and only avoids a warning) when CUDA is present.
    kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}
    train_dataset = PainterDataset(img_dir=data_dir, transform=transforms_, unaligned=True)
    test_dataset = PainterDataset(img_dir=data_dir, transform=transforms_, mode='test')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

In [None]:
def _discriminator_loss(netD, real, fake, fake_pool, criterion):
    """Return ``0.5 * (real loss + fake loss)`` for one discriminator.

    ``fake`` goes through ``fake_pool`` (history replay) and is detached, so
    no gradient reaches the generators. Targets are built with ``ones_like``
    and ``zeros_like`` so they always match the discriminator output shape.
    """
    pred_real = netD(real)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))

    pred_fake = netD(fake_pool.query(fake).detach())
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

    return (loss_real + loss_fake) * 0.5


def train(args, train_loader, data_dir='datasets/nouveau2roman/'):
    """Train two generators (A->B, B->A) and two discriminators (A, B).

    Generator objective: adversarial (BCE) + cycle-consistency (L1, weight 10)
    + identity (L1, weight 5). The discriminators are trained on real images
    and on generated images replayed through a FakePool.

    Args:
        args: reads ``lr`` and ``epochs``.
        train_loader: yields ``(inputA, inputB)`` image batches.
        data_dir: weights are saved under ``<data_dir>/output`` after every
            epoch, overwriting the previous epoch's files.
    """
    # Fall back to CPU so the function also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define models
    netG_A2B = Generator().to(device)
    netG_B2A = Generator().to(device)
    netD_A = Discriminator().to(device)
    netD_B = Discriminator().to(device)

    # TODO: consider custom weight init, e.g. .apply(weights_init_normal).

    # Loss functions and weights
    criterion_GAN = nn.BCELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    lambda_identity = 5.0
    lambda_cycle = 10.0

    # Optimizers
    optimizer_G = optim.Adam(
        itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
        lr=args.lr, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(netD_A.parameters(), lr=args.lr, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(netD_B.parameters(), lr=args.lr, betas=(0.5, 0.999))

    fakeA_buffer = FakePool()
    fakeB_buffer = FakePool()

    # TODO: consider an lr scheduler here.

    output_dir = os.path.join(data_dir, 'output')
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        # A fresh bar every epoch: a tqdm instance closes (and disables
        # itself) once exhausted, so reusing one hides progress after epoch 1.
        progress_bar = tqdm.tqdm(train_loader, desc='Epoch {}'.format(epoch))

        for inputA, inputB in progress_bar:
            inputA = inputA.to(device)
            inputB = inputB.to(device)

            # ---- Generators ----
            optimizer_G.zero_grad()

            fakeB = netG_A2B(inputA)
            recoveredA = netG_B2A(fakeB)
            fakeA = netG_B2A(inputB)
            recoveredB = netG_A2B(fakeA)

            # Identity loss: a generator fed an image already in its target
            # domain should leave it unchanged.
            loss_identity_A = criterion_identity(netG_B2A(inputA), inputA) * lambda_identity
            loss_identity_B = criterion_identity(netG_A2B(inputB), inputB) * lambda_identity

            # GAN loss: the generators want the discriminators to output "real".
            # ones_like matches D's (N, 1) output; a (batch_size,) target
            # makes BCELoss raise on the size mismatch.
            pred_fakeB = netD_B(fakeB)
            loss_GAN_A2B = criterion_GAN(pred_fakeB, torch.ones_like(pred_fakeB))
            pred_fakeA = netD_A(fakeA)
            loss_GAN_B2A = criterion_GAN(pred_fakeA, torch.ones_like(pred_fakeA))

            # Cycle loss
            loss_cycle_ABA = criterion_cycle(recoveredA, inputA) * lambda_cycle
            loss_cycle_BAB = criterion_cycle(recoveredB, inputB) * lambda_cycle

            loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A
                      + loss_cycle_ABA + loss_cycle_BAB)
            loss_G.backward()
            optimizer_G.step()

            # ---- Discriminator A ----
            # zero_grad also clears the grads the generator step left in netD_A.
            optimizer_D_A.zero_grad()
            loss_D_A = _discriminator_loss(netD_A, inputA, fakeA, fakeA_buffer, criterion_GAN)
            loss_D_A.backward()
            optimizer_D_A.step()

            # ---- Discriminator B ----
            optimizer_D_B.zero_grad()
            loss_D_B = _discriminator_loss(netD_B, inputB, fakeB, fakeB_buffer, criterion_GAN)
            loss_D_B.backward()
            optimizer_D_B.step()

            # .item() replaces .data[0], which raises on 0-dim tensors.
            progress_bar.set_description(
                'Epoch: {} Generator loss: {} Discriminator A loss: {} Discriminator B loss: {}'.format(
                    epoch, loss_G.item(), loss_D_A.item(), loss_D_B.item()))

        # Save models
        torch.save(netG_A2B.state_dict(), os.path.join(output_dir, 'netG_A2B.pth'))
        torch.save(netG_B2A.state_dict(), os.path.join(output_dir, 'netG_B2A.pth'))
        torch.save(netD_A.state_dict(), os.path.join(output_dir, 'netD_A.pth'))
        torch.save(netD_B.state_dict(), os.path.join(output_dir, 'netD_B.pth'))

In [None]:
def test(args, test_loader, data_dir='datasets/nouveau2roman/'):
    """Translate every test batch with the saved generators and write PNGs.

    Loads ``netG_A2B.pth`` and ``netG_B2A.pth`` from ``<data_dir>/output``.
    B->A translations go to ``output/A/NNNN.png`` and A->B translations go to
    ``output/B/NNNN.png``, where NNNN is the 1-based batch index.

    Args:
        args: not read; kept so the signature parallels train().
        test_loader: yields ``(inputA, inputB)`` image batches.
        data_dir: dataset root that contains the ``output`` directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = os.path.join(data_dir, 'output')

    # Define models and load the trained weights (map_location lets
    # GPU-trained checkpoints load on CPU).
    netG_A2B = Generator().to(device)
    netG_B2A = Generator().to(device)
    netG_A2B.load_state_dict(torch.load(os.path.join(output_dir, 'netG_A2B.pth'), map_location=device))
    netG_B2A.load_state_dict(torch.load(os.path.join(output_dir, 'netG_B2A.pth'), map_location=device))
    netG_A2B.eval()
    netG_B2A.eval()

    # save_image does not create missing directories.
    os.makedirs(os.path.join(output_dir, 'A'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'B'), exist_ok=True)

    progress_bar = tqdm.tqdm(test_loader, desc='Validation')
    with torch.no_grad():
        for index, (inputA, inputB) in enumerate(progress_bar):
            inputA = inputA.to(device)
            inputB = inputB.to(device)

            # The generators end in tanh ([-1, 1]), but save_image clamps to
            # [0, 1]. Map back first, or every negative value is clipped to black.
            outB = 0.5 * (netG_A2B(inputA) + 1.0)
            outA = 0.5 * (netG_B2A(inputB) + 1.0)

            save_image(outA, os.path.join(output_dir, 'A', '%04d.png' % (index + 1)))
            save_image(outB, os.path.join(output_dir, 'B', '%04d.png' % (index + 1)))

            progress_bar.set_description(
                'Generated images %04d of %04d' % (index + 1, len(test_loader)))

In [None]:
# Run configuration: hyper-parameters, seeding, then training and testing.
data_dir = 'datasets/nouveau2roman/'
args = Args()

# Apply Args.seed to every RNG the notebook uses. Before this, the seed was
# stored but never applied. Python's `random` drives FakePool and unaligned
# pairing; torch drives weight init and DataLoader shuffling.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

train_loader, test_loader = prepare_dataset(args)

train(args, train_loader)
test(args, test_loader)