In [1]:
import os, time, pickle, json
from lib import networks, utils
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# When cuDNN is enabled, switch on its benchmark mode.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True

In [3]:
# Bare expression: shows the selected torch.device as this cell's output.
device

device(type='cuda')

### Hyper-parameters

In [20]:
# ---- Network shape ----
# Generator: number of input / output channels.
in_ngc = 3
out_ngc = 3
# Discriminator: number of input / output channels.
in_ndc = 3
out_ndc = 1
# Passed as the third argument to networks.generator / networks.discriminator
# (presumably the base filter count of each network -- confirm in lib.networks).
ngf = 64
ndf = 32
# Number of resnet block layers in the generator.
nb = 8

# ---- Data and schedule ----
batch_size = 7
# Input size: images are resized to input_size x (2 * input_size).
input_size = 256
# Number of training epochs.
train_epoch = 20

# ---- Optimisation ----
# Learning rates for the generator and the discriminator (default=0.0002).
lrG = 0.0002
lrD = 0.0002
# Lambda: weight applied to the content loss.
con_lambda = 1
# Beta1 / beta2 for the Adam optimizers.
beta1 = 0.5
beta2 = 0.999

In [5]:
project_name = 'pix2pix_1'
data_name = 'pix2pix_data'
result_path = '{}_results'.format(project_name)

# Results directory (sample images and checkpoints): create it when absent.
os.makedirs(result_path, exist_ok=True)

# Data directory: when it is missing, create it (empty) and report that it
# did not exist.
if not os.path.isdir(data_name):
    os.makedirs(data_name)
    print("data folder does not exist!!")


data folder does not exist!!


In [6]:
# data_loader
# Every image is resized to input_size x (2 * input_size) (height x width),
# converted to a tensor and normalised per channel with mean 0.5 / std 0.5.
# The training loop later splits each image at column `input_size`.
train_transform = transforms.Compose([
    transforms.Resize((input_size, 2 * input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Images are read from `data_name`; the folder labels returned by ImageFolder
# are ignored by the training loop.
train_dataset = datasets.ImageFolder(data_name, transform=train_transform)

# drop_last=True keeps every batch at exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [16]:
# network
# Build the generator and the discriminator and move them to the chosen device.
G = networks.generator(in_ngc, out_ngc, ngf, nb).to(device)
D = networks.discriminator(in_ndc, out_ndc, ndf).to(device)

# Put both networks in training mode.
for net in (G, D):
    net.train()

# Print both architectures between two banner lines.
print('---------- Networks initialized -------------')
for net in (G, D):
    utils.print_network(net)
print('-----------------------------------------------')

---------- Networks initialized -------------
generator(
  (down_convs): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): ReLU(inplace)
  )
  (resnet_blocks): Sequential(
    (0): resnet_block(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1_norm): InstanceNorm2d(256,

In [17]:
# loss
# BCE scores the discriminator outputs against the real/fake targets;
# L1 is the content loss between the generated image and the target image.
BCE_loss, L1_loss = (criterion.to(device) for criterion in (nn.BCELoss(), nn.L1Loss()))

In [18]:
# Adam optimizer
# One Adam optimizer per network, sharing the same beta coefficients.
adam_betas = (beta1, beta2)
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=adam_betas)
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=adam_betas)

# Both learning rates are multiplied by 0.1 at each milestone epoch:
# train_epoch // 2 and train_epoch // 4 * 3.
lr_milestones = [train_epoch // 2, train_epoch // 4 * 3]
G_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=G_optimizer, milestones=list(lr_milestones), gamma=0.1)
D_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=D_optimizer, milestones=list(lr_milestones), gamma=0.1)

# History container for a pre-training stage (not filled in by this cell).
pre_train_hist = {
    'Recon_loss': [],
    'per_epoch_time': [],
    'total_time': [],
}

In [19]:
# starting_epoch is an offset added to the epoch index in the training log line
# and in the names of the saved sample images, so that a later run does not
# overwrite previously generated results.
starting_epoch = 0

In [11]:
# Training history: per-iteration losses, per-epoch mean losses and timings.
# Every key starts out with its own empty list.
train_hist = {
    key: []
    for key in (
        'Disc_loss',
        'Gen_loss',
        'Con_loss',
        'per_epoch_time',
        'total_time',
        'Gen_loss_one_epoch',
        'Disc_loss_one_epoch',
        'Con_loss_one_epoch',
    )
}

In [21]:
# Adversarial training. Every batch updates the discriminator once and then the
# generator once; every epoch logs the mean losses, writes a few sample images
# and checkpoints both networks into `result_path`.
print('training start!')
start_time = time.time()

# BCE targets: all ones for "real", all zeros for "fake". Their shape has to
# match D's output -- assumed here to be
# (batch_size, 1, input_size // 4, input_size // 4); confirm against
# networks.discriminator. A fixed batch dimension is safe because the loader is
# built with drop_last=True.
real = torch.ones(batch_size, 1, input_size // 4, input_size // 4, device=device)
fake = torch.zeros(batch_size, 1, input_size // 4, input_size // 4, device=device)

for epoch in range(train_epoch):
    epoch_start_time = time.time()
    # The sampling step at the end of each epoch leaves G in eval mode.
    G.train()
    Disc_losses = []
    Gen_losses = []
    Con_losses = []
    for d, _ in train_loader:
        # Split along the width axis: columns [0, input_size) are the
        # generator input, the remaining columns are the target image.
        x = d[:, :, :, :input_size].to(device)
        y = d[:, :, :, input_size:].to(device)

        # One generator forward pass serves both updates below; G's weights
        # do not change between the discriminator step and the generator step.
        G_ = G(x)

        # train D
        D_optimizer.zero_grad()

        D_real = D(y)
        D_real_loss = BCE_loss(D_real, real)

        # detach(): the discriminator loss must not back-propagate into G.
        D_fake = D(G_.detach())
        D_fake_loss = BCE_loss(D_fake, fake)

        Disc_loss = D_real_loss + D_fake_loss
        disc_value = Disc_loss.item()
        Disc_losses.append(disc_value)
        train_hist['Disc_loss'].append(disc_value)

        Disc_loss.backward()
        D_optimizer.step()

        # train G
        G_optimizer.zero_grad()

        # Score the same generated batch with the freshly updated
        # discriminator, this time against the "real" target.
        D_fake = D(G_)
        D_fake_loss = BCE_loss(D_fake, real)

        Con_loss = con_lambda * L1_loss(G_, y)

        Gen_loss = D_fake_loss + Con_loss

        # 'Gen_loss' records only the adversarial term; the weighted content
        # term is recorded separately under 'Con_loss'.
        adv_value = D_fake_loss.item()
        con_value = Con_loss.item()
        Gen_losses.append(adv_value)
        train_hist['Gen_loss'].append(adv_value)
        Con_losses.append(con_value)
        train_hist['Con_loss'].append(con_value)

        Gen_loss.backward()
        G_optimizer.step()

    # Advance the LR schedules once per epoch, after this epoch's optimizer
    # updates. PyTorch >= 1.1 expects scheduler.step() to follow
    # optimizer.step(); stepping first skips the first value of the schedule.
    G_scheduler.step()
    D_scheduler.step()

    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)

    # Epoch means, stored as plain Python floats rather than 0-d tensors so
    # the history can be plotted and serialized directly.
    Gen_loss_avg = torch.mean(torch.FloatTensor(Gen_losses)).item()
    Con_loss_avg = torch.mean(torch.FloatTensor(Con_losses)).item()
    Disc_loss_avg = torch.mean(torch.FloatTensor(Disc_losses)).item()

    train_hist['Gen_loss_one_epoch'].append(Gen_loss_avg)
    train_hist['Disc_loss_one_epoch'].append(Disc_loss_avg)
    train_hist['Con_loss_one_epoch'].append(Con_loss_avg)

    print(
    '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f' % ((starting_epoch + epoch + 1), train_epoch, per_epoch_time, Disc_loss_avg, Gen_loss_avg, Con_loss_avg))

    with torch.no_grad():
        G.eval()
        # Save input | output side by side for the first sample of five
        # batches drawn from the (shuffled) training loader.
        for n, (d, _) in enumerate(train_loader):
            x = d[:, :, :, :input_size].to(device)
            G_recon = G(x)
            result = torch.cat((x[0], G_recon[0]), 2)
            path = os.path.join(result_path, str(starting_epoch+epoch+1) + '_epoch_' + project_name + '_train_' + str(n + 1) + '.png')
            # Map values from [-1, 1] back to [0, 1] and move channels last.
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break

        # Overwrite the "latest" checkpoints after every epoch.
        torch.save(G.state_dict(), os.path.join(result_path, 'generator_latest.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, 'discriminator_latest.pkl'))

total_time = time.time() - start_time
train_hist['total_time'].append(total_time)

print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])), train_epoch, total_time))

training start!
[1/20] - time: 4.88, Disc loss: 1.474, Gen loss: 0.741, Con loss: 0.566
[2/20] - time: 4.85, Disc loss: 1.408, Gen loss: 0.750, Con loss: 0.515
[3/20] - time: 4.87, Disc loss: 1.387, Gen loss: 0.710, Con loss: 0.474
[4/20] - time: 4.90, Disc loss: 1.384, Gen loss: 0.704, Con loss: 0.472
[5/20] - time: 4.85, Disc loss: 1.385, Gen loss: 0.703, Con loss: 0.460
[6/20] - time: 4.87, Disc loss: 1.385, Gen loss: 0.703, Con loss: 0.463
[7/20] - time: 4.84, Disc loss: 1.386, Gen loss: 0.702, Con loss: 0.459
[8/20] - time: 4.87, Disc loss: 1.385, Gen loss: 0.701, Con loss: 0.465
[9/20] - time: 4.89, Disc loss: 1.385, Gen loss: 0.701, Con loss: 0.454
[10/20] - time: 4.86, Disc loss: 1.385, Gen loss: 0.702, Con loss: 0.464
[11/20] - time: 4.94, Disc loss: 1.385, Gen loss: 0.702, Con loss: 0.464
[12/20] - time: 4.82, Disc loss: 1.384, Gen loss: 0.703, Con loss: 0.460
[13/20] - time: 4.88, Disc loss: 1.385, Gen loss: 0.702, Con loss: 0.458
[14/20] - time: 4.85, Disc loss: 1.384, Gen 