In [1]:
import os, time, pickle
from lib import networks, utils
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

### Hyper-parameters

In [3]:
#input channel for generator
in_ngc=3
#output channel for generator
out_ngc=3
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1

batch_size=2
#number of filters in the first layer of generator
ngf=32
#number of filters in the first layer of discriminator
ndf=32
#the number of resnet block layer for generator
nb=9
#input size
input_size=128
train_epoch=200

#Discriminator learning rate, default=0.0002
lrD=0.0002
#Generator learning rate, default=0.0002
lrG=0.0002
#lambda for content loss
con_lambda=100
#beta1 for Adam optimizer
beta1=0.5
#beta2 for Adam optimizer
beta2=0.999

n_downsampling = 3

In [4]:
project_name = 'pix2pix_facades_5'
result_path = project_name+'_results'
data_name = 'data/facades'

# results save path
if not os.path.isdir(result_path):
    os.makedirs(result_path)
#ensure data folder exists
if not os.path.isdir(data_name):
    os.makedirs(data_name)
    print("data folder does not exist!!")


In [5]:
# data_loader
train_transform = transforms.Compose([
        transforms.Resize((input_size, 2*input_size)),
        transforms.ColorJitter(0.1,0.1,0.1,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader = utils.data_load(data_name, 'train', train_transform, batch_size, shuffle=False, drop_last=True)
test_loader = utils.data_load(data_name, 'test', train_transform, batch_size, shuffle=False, drop_last=True)
# train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(data_name, train_transform), batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# data_loader if data augmentation needed
train_transform = transforms.Compose([
        transforms.RandomResizedCrop (size = input_size,scale= (0.7,1.0)),
        transforms.Resize((input_size, 2*input_size)),
        transforms.RandomGrayscale(),
        transforms.RandomVerticalFlip(),      
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(data_name, train_transform), batch_size=batch_size, shuffle=True, drop_last=True)

In [12]:
# network

G = networks.generator(in_ngc, out_ngc, ngf, nb)
D = networks.discriminator(in_ndc, out_ndc, ndf)
# D = networks.wgan_discriminator(in_ndc, ndf, input_size, n_downsampling)

G.to(device)
D.to(device)
G.train()
D.train();
# print('---------- Networks initialized -------------')
# utils.print_network(G)
# utils.print_network(D)
# print('-----------------------------------------------')

In [7]:
# loss
# GAN_loss = nn.BCELoss().to(device)
GAN_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)
def D_loss_criterion(D_decision):
    return D_decision.mean()

In [13]:
# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))
G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer, milestones=[train_epoch // 2, train_epoch // 4 * 3], gamma=0.1)
D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer, milestones=[train_epoch // 2, train_epoch // 4 * 3], gamma=0.1)

In [14]:
train_hist = {}
train_hist['Disc_loss'] = []
train_hist['Gen_loss'] = []
train_hist['Con_loss'] = []
train_hist['per_epoch_time'] = []
train_hist['Gen_loss_one_epoch']=[]
train_hist['Disc_loss_one_epoch']=[]
train_hist['Con_loss_one_epoch']=[]

In [15]:
#starting_epoch is used to avoid overriding of the previously generated results
starting_epoch = 0

In [16]:
print('training start!')
start_time = time.time()
num_pool = 50
fake_pool = utils.ImagePool(num_pool)
# real = torch.ones(batch_size, 1, input_size // 4, input_size // 4).to(device)
# fake = torch.zeros(batch_size, 1, input_size // 4, input_size // 4).to(device)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
    G.train()
    G_scheduler.step()
    D_scheduler.step()
    Disc_losses = []
    Gen_losses = []
    Con_losses = []
    for d, _ in train_loader:
        # x is the image at left, the source image
        # y is the image at right, the target image
        y = d[:, :, :, :input_size]
        x = d[:, :, :, input_size:]
        x, y, d = x.to(device), y.to(device), d.to(device)
        
        # train D
        for param in D.parameters():
            param.requires_grad = True
#             param.data.clamp_(-0.005,0.005)
        D_optimizer.zero_grad()

        D_real = D(d)
#         D_real_loss = D_loss_criterion(D_real)
        D_real_loss = GAN_loss(D_real, 1-torch.rand(D_real.size(),device = device)/8.0)
#         D_real_loss = GAN_loss(D_real, torch.ones(D_real.size(),device = device))

        with torch.no_grad():
            G_ = G(x)
        G_ = fake_pool.query(G_.detach())
        generated = torch.cat((G_,x), 2)
        D_fake = D(generated)
#         D_fake_loss = D_loss_criterion(D_fake)
        D_fake_loss = GAN_loss(D_fake, torch.rand(D_fake.size(),device = device)/8.0)
#         D_fake_loss = GAN_loss(D_fake, torch.zeros(D_fake.size(),device = device))

        Disc_loss = 0.5 * (D_real_loss + D_fake_loss)
#         Disc_loss = D_fake_loss - D_real_loss
        Disc_losses.append(Disc_loss.item())
        train_hist['Disc_loss'].append(Disc_loss.item())
        Disc_loss.backward()
        D_optimizer.step()

        # train G
        for param in D.parameters():
            param.requires_grad = False
        G_optimizer.zero_grad()

        G_ = G(x)
        generated = torch.cat((G_,x), 2)
        D_fake = D(generated)
#         D_fake_loss = D_loss_criterion(D_fake)
        D_fake_loss = GAN_loss(D_fake, torch.ones(D_fake.size(),device = device))

        Con_loss = con_lambda * L1_loss(G_, y)

        Gen_loss = D_fake_loss + Con_loss
#         Gen_loss = -D_fake_loss + Con_loss

        Gen_losses.append(D_fake_loss.item())
        train_hist['Gen_loss'].append(D_fake_loss.item())
        Con_losses.append(Con_loss.item())
        train_hist['Con_loss'].append(Con_loss.item())

        Gen_loss.backward()
        G_optimizer.step()


    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)
    
    Gen_loss_avg = torch.mean(torch.FloatTensor(Gen_losses))
    Con_loss_avg = torch.mean(torch.FloatTensor(Con_losses))
    Disc_loss_avg =  torch.mean(torch.FloatTensor(Disc_losses))
    
    train_hist['Gen_loss_one_epoch'].append(Gen_loss_avg)
    train_hist['Disc_loss_one_epoch'].append(Disc_loss_avg)
    train_hist['Con_loss_one_epoch'].append(Con_loss_avg)
    
    print(
    '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f' % ((starting_epoch + epoch + 1), train_epoch, per_epoch_time, Disc_loss_avg, Gen_loss_avg, Con_loss_avg))

    with torch.no_grad():
        G.eval()
        for n, (d, _) in enumerate(train_loader):
            y = d[:, :, :, :input_size]
            x = d[:, :, :, input_size:]            
            x,y = x.to(device),y.to(device)
            G_recon = G(x)
            result = torch.cat((x[0], G_recon[0],y[0]), 2)
            path = os.path.join(result_path, str(starting_epoch+epoch+1) + '_epoch_' + project_name + '_train_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 2:
                break
        for n, (d, _) in enumerate(test_loader):
            y = d[:, :, :, :input_size]
            x = d[:, :, :, input_size:]            
            x,y = x.to(device),y.to(device)
            G_recon = G(x)
            result = torch.cat((x[0], G_recon[0],y[0]), 2)
            path = os.path.join(result_path, str(starting_epoch+epoch+1) + '_epoch_' + project_name + '_test_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 1:
                break
                
        torch.save(G.state_dict(), os.path.join(result_path, 'generator_latest.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, 'discriminator_latest.pkl'))
        with open(os.path.join(result_path,  'train_hist.pkl'), 'wb') as f:
            pickle.dump(train_hist, f)


training start!
[1/200] - time: 15.45, Disc loss: 0.050, Gen loss: 0.638, Con loss: 36.560
[2/200] - time: 14.84, Disc loss: 0.027, Gen loss: 0.726, Con loss: 35.021
[3/200] - time: 14.57, Disc loss: 0.021, Gen loss: 0.749, Con loss: 34.048
[4/200] - time: 14.55, Disc loss: 0.027, Gen loss: 0.736, Con loss: 33.520
[5/200] - time: 14.52, Disc loss: 0.094, Gen loss: 0.551, Con loss: 33.682
[6/200] - time: 14.55, Disc loss: 0.020, Gen loss: 0.761, Con loss: 33.086
[7/200] - time: 14.55, Disc loss: 0.018, Gen loss: 0.769, Con loss: 33.106
[8/200] - time: 14.56, Disc loss: 0.029, Gen loss: 0.729, Con loss: 33.137
[9/200] - time: 14.61, Disc loss: 0.017, Gen loss: 0.768, Con loss: 32.719
[10/200] - time: 14.77, Disc loss: 0.034, Gen loss: 0.708, Con loss: 33.014
[11/200] - time: 14.55, Disc loss: 0.029, Gen loss: 0.704, Con loss: 32.570
[12/200] - time: 14.54, Disc loss: 0.096, Gen loss: 0.530, Con loss: 32.676
[13/200] - time: 15.59, Disc loss: 0.023, Gen loss: 0.729, Con loss: 32.412
[14/2

[109/200] - time: 14.56, Disc loss: 0.016, Gen loss: 0.753, Con loss: 14.125
[110/200] - time: 14.83, Disc loss: 0.016, Gen loss: 0.754, Con loss: 13.963
[111/200] - time: 14.54, Disc loss: 0.015, Gen loss: 0.755, Con loss: 13.854
[112/200] - time: 14.53, Disc loss: 0.015, Gen loss: 0.755, Con loss: 13.810
[113/200] - time: 14.57, Disc loss: 0.015, Gen loss: 0.754, Con loss: 13.830
[114/200] - time: 14.77, Disc loss: 0.015, Gen loss: 0.752, Con loss: 13.653
[115/200] - time: 15.06, Disc loss: 0.016, Gen loss: 0.748, Con loss: 13.720
[116/200] - time: 14.55, Disc loss: 0.015, Gen loss: 0.752, Con loss: 13.728
[117/200] - time: 14.56, Disc loss: 0.015, Gen loss: 0.751, Con loss: 13.351
[118/200] - time: 14.55, Disc loss: 0.015, Gen loss: 0.750, Con loss: 13.562
[119/200] - time: 14.54, Disc loss: 0.015, Gen loss: 0.750, Con loss: 13.520
[120/200] - time: 14.53, Disc loss: 0.015, Gen loss: 0.751, Con loss: 13.560
[121/200] - time: 14.56, Disc loss: 0.015, Gen loss: 0.751, Con loss: 13.382

In [18]:
train_transform = transforms.Compose([
        transforms.Resize((input_size, 2*input_size)),
        transforms.ColorJitter(0.1,0.1,0.1,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

test_loader = torch.utils.data.DataLoader(datasets.ImageFolder('data/eye_test', train_transform), batch_size=batch_size, shuffle=True, drop_last=True)

test_result_path = 'pix2pix_eyes_test_results'

# results save path
if not os.path.isdir(test_result_path):
    os.makedirs(test_result_path)

with torch.no_grad():
    G.eval()
    for n, (d, _) in enumerate(test_loader):
#         x = d[:, :, :, :input_size]
#         y = d[:, :, :, input_size:]            
#         x = x.to(device)
#         y = y.to(device)
        d = d.to(device)
        G_recon = G(d)
        result = torch.cat((d[0], G_recon[0]), 2)
        path = os.path.join(test_result_path, str(starting_epoch+epoch+1) + '_epoch_' + project_name + '_test_' + str(n + 1) + '.png')
        plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)