In [1]:
import os, time, pickle
from lib import networks, utils
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

In [2]:
# Preferred GPU for this experiment. The index is only honoured when the host
# actually exposes that many GPUs; otherwise fall back to the first GPU so the
# later `.to(device)` calls do not fail with an invalid device ordinal.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device('cuda:%d' % gpu_index)
else:
    device = torch.device('cpu')
# Let cuDNN auto-tune its algorithms when it is enabled.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=3)

### Hyper-parameters

In [3]:
# --- channel counts -------------------------------------------------------
# Generator: number of input / output channels.
in_ngc = 3
out_ngc = 3
# Discriminator: number of input / output channels.
in_ndc = 6
out_ndc = 1

# --- model / data sizes ---------------------------------------------------
batch_size = 1
# Filters in the first layer of the generator.
ngf = 64
# Filters in the first layer of the discriminator.
ndf = 32
# Number of resnet blocks for the (currently unused) resnet generator.
# nb=9
# Side length, in pixels, of one square image.
input_size = 64
train_epoch = 300

# --- optimisation ---------------------------------------------------------
# Learning rates for the discriminator and the generator (default 0.0002 each).
lrD = 0.0002
lrG = 0.0002
# Weight of the content loss.
con_lambda = 5
# Adam optimizer coefficients.
beta1 = 0.5
beta2 = 0.999

# n_downsampling = 3

In [4]:
# Experiment name; also used as a prefix for everything written to disk.
project_name = 'pix2pix_mouth_1'
result_path = project_name+'_results'
data_name = '/data/pix2pix/mouth_train'
# data_name = 'data/combine_w_blur_curated'
test_data_name = 'data/mouth_test'

# Results folder: create it on first use.
os.makedirs(result_path, exist_ok=True)

# Data folders must already exist. Creating an empty folder here would only
# postpone the failure to the dataset construction and leave a stray directory
# behind, so stop immediately with an explicit message instead.
for required_dir in (data_name, test_data_name):
    if not os.path.isdir(required_dir):
        raise FileNotFoundError("data folder does not exist!! (" + required_dir + ")")


In [5]:
# data_loader
# Each training image holds a (source | target) pair side by side, hence the
# doubled width. The colour jitter is applied to the combined image, so both
# halves of a pair receive the same perturbation.
train_transform = transforms.Compose([
    transforms.Resize((input_size, 2*input_size)),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Test images are single source images. Only deterministic preprocessing is
# applied: random colour jitter is a training-time augmentation and would make
# the saved previews differ from run to run for the same input.
test_transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(data_name, train_transform),
    batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(test_data_name, test_transform),
    batch_size=1, shuffle=True, drop_last=True)

In [6]:
# Build the generator / discriminator pair.
# Alternative architectures that were tried:
#   G = networks.generator(in_ngc, out_ngc, ngf, nb)
#   D = networks.wgan_discriminator(in_ndc, ndf, input_size, n_downsampling)
G = networks.UnetGenerator(in_ngc, out_ngc, 6, ngf)
D = networks.discriminator(in_ndc, out_ndc, ndf, False)

# Move both networks to the selected device and put them in training mode.
for net in (G, D):
    net.to(device)
    net.train()

# Report the total number of parameters of each network.
print('---------- Networks initialized -------------')
for name, model in (('G', G), ('D', D)):
    num_params = sum(param.numel() for param in model.parameters())
    print('{} has {} number of parameters'.format(name, num_params))
print('-----------------------------------------------')

---------- Networks initialized -------------
G has 29244035 number of parameters
D has 1129249 number of parameters
-----------------------------------------------


In [7]:
# Loss criteria used by the training cell.
# Adversarial criterion: mean squared error between the discriminator output and
# an all-ones / all-zeros target (a least-squares style GAN objective).
# The binary cross-entropy variant is kept for reference:
# GAN_loss = nn.BCELoss().to(device)
GAN_loss = nn.MSELoss().to(device)
# Content criterion: L1 distance between the generated image and the target image.
L1_loss = nn.L1Loss().to(device)
# Unused alternative discriminator criterion (mean of the raw decision):
# def D_loss_criterion(D_decision):
#     return D_decision.mean()

In [8]:
# Adam optimizers: one per network, sharing the same beta coefficients.
adam_betas = (beta1, beta2)
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=adam_betas)
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=adam_betas)

# Both learning rates are multiplied by 0.1 at 1/2 and at 3/4 of the training run.
lr_milestones = [train_epoch // 2, train_epoch // 4 * 3]
G_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=G_optimizer, milestones=list(lr_milestones), gamma=0.1)
D_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=D_optimizer, milestones=list(lr_milestones), gamma=0.1)

In [9]:
# Training history. The first three lists are appended to once per training
# step, the remaining ones once per epoch. Every key gets its own empty list.
train_hist = {
    key: []
    for key in (
        'Disc_loss',
        'Gen_loss',
        'Con_loss',
        'per_epoch_time',
        'Gen_loss_one_epoch',
        'Disc_loss_one_epoch',
        'Con_loss_one_epoch',
    )
}

In [10]:
# starting_epoch is added to the loop's epoch counter in the training cell when it
# builds the progress line and the preview image file names, so a continued run
# does not overwrite the previews generated by an earlier run.
# Leave it at 0 for a fresh run.
starting_epoch = 0

In [None]:
def _save_preview(image, epoch_label, split, index):
    """Save one preview image into result_path as a PNG.

    image       -- CHW tensor; values are assumed to lie in [-1, 1] and are mapped to [0, 1]
    epoch_label -- epoch number used as the file name prefix
    split       -- 'train' or 'test', becomes part of the file name
    index       -- 1-based index of the sample within this preview round
    """
    path = os.path.join(result_path, str(epoch_label) + '_epoch_' + project_name + '_' + split + '_' + str(index) + '.png')
    plt.imsave(path, (image.cpu().numpy().transpose(1, 2, 0) + 1) / 2)


print('training start!')
start_time = time.time()
num_pool = 50
# Pool of previously generated samples from which the discriminator's fake batch is drawn.
fake_pool = utils.ImagePool(num_pool)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
    # The preview step below switches G to eval mode; switch back for training.
    G.train()
    Disc_losses = []
    Gen_losses = []
    Con_losses = []
    for d, _ in train_loader:
        # x is the image at left, the source image
        # y is the image at right, the target image
        x = d[:, :, :, :input_size]
        y = d[:, :, :, input_size:input_size*2]
        x, y = x.to(device), y.to(device)

        # ---- train D ----
        for param in D.parameters():
            param.requires_grad = True
        D_optimizer.zero_grad()

        # Real pair: (target, source) stacked along the channel axis.
        real = torch.cat((y, x), 1)
        D_real = D(real)
        D_real_loss = GAN_loss(D_real, torch.ones(D_real.size(), device=device))

        # Fake pair: the generator is not updated here, so no graph is needed.
        with torch.no_grad():
            G_ = G(x)
        generated = torch.cat((G_, x), 1)
        generated = fake_pool.query(generated.detach())
        D_fake = D(generated)
        D_fake_loss = GAN_loss(D_fake, torch.zeros(D_fake.size(), device=device))

        Disc_loss = 0.5*(D_real_loss + D_fake_loss)
        Disc_losses.append(Disc_loss.item())
        train_hist['Disc_loss'].append(Disc_loss.item())
        Disc_loss.backward()
        D_optimizer.step()

        # ---- train G ----
        # Freeze D so the generator step does not accumulate gradients in it.
        for param in D.parameters():
            param.requires_grad = False
        G_optimizer.zero_grad()

        G_ = G(x)
        generated = torch.cat((G_, x), 1)
        D_fake = D(generated)
        # The generator wants D to label its output as real.
        D_fake_loss = GAN_loss(D_fake, torch.ones(D_fake.size(), device=device))

        Con_loss = con_lambda * L1_loss(G_, y)

        Gen_loss = D_fake_loss + Con_loss

        # Only the adversarial part is logged as "Gen loss"; the content part is logged separately.
        Gen_losses.append(D_fake_loss.item())
        train_hist['Gen_loss'].append(D_fake_loss.item())
        Con_losses.append(Con_loss.item())
        train_hist['Con_loss'].append(Con_loss.item())

        Gen_loss.backward()
        G_optimizer.step()

    # Advance the LR schedules once per epoch, after this epoch's optimizer steps.
    # Stepping before any optimizer.step() shifts the schedule and triggers a
    # warning on PyTorch >= 1.1.
    G_scheduler.step()
    D_scheduler.step()

    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)

    # Store plain floats (not tensors) so the pickled history is lightweight
    # and can be read back without any tensor handling.
    Gen_loss_avg = torch.mean(torch.FloatTensor(Gen_losses)).item()
    Con_loss_avg = torch.mean(torch.FloatTensor(Con_losses)).item()
    Disc_loss_avg = torch.mean(torch.FloatTensor(Disc_losses)).item()

    train_hist['Gen_loss_one_epoch'].append(Gen_loss_avg)
    train_hist['Disc_loss_one_epoch'].append(Disc_loss_avg)
    train_hist['Con_loss_one_epoch'].append(Con_loss_avg)

    # 1-based epoch number, continuing from a previous run when starting_epoch > 0.
    epoch_label = starting_epoch + epoch + 1
    print(
    '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f' % (epoch_label, starting_epoch + train_epoch, per_epoch_time, Disc_loss_avg, Gen_loss_avg, Con_loss_avg))

    # Every 10 epochs, save two train previews and two test previews.
    if epoch % 10 == 0:
        with torch.no_grad():
            G.eval()
            for n, (d, _) in enumerate(train_loader):
                x = d[:, :, :, :input_size]
                y = d[:, :, :, input_size:input_size*2]
                x, y = x.to(device), y.to(device)
                G_recon = G(x)
                # source | generated | target, concatenated along the width
                _save_preview(torch.cat((x[0], G_recon[0], y[0]), 2), epoch_label, 'train', n + 1)
                if n == 1:
                    break
            for n, (d, _) in enumerate(test_loader):
                d = d.to(device)
                G_recon = G(d)
                # source | generated, concatenated along the width
                _save_preview(torch.cat((d[0], G_recon[0]), 2), epoch_label, 'test', n + 1)
                if n == 1:
                    break

    # Checkpoint every 40 epochs and always after the final epoch, so the fully
    # trained weights are never lost. The label includes starting_epoch so a
    # resumed run does not overwrite earlier checkpoints (names are unchanged
    # when starting_epoch is 0).
    is_last_epoch = (epoch == train_epoch - 1)
    if epoch % 40 == 0 or is_last_epoch:
        checkpoint_label = str(starting_epoch + epoch)
        torch.save(G.state_dict(), os.path.join(result_path, checkpoint_label + '_generator_latest.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, checkpoint_label + '_discriminator_latest.pkl'))
        with open(os.path.join(result_path, 'train_hist.pkl'), 'wb') as f:
            pickle.dump(train_hist, f)


training start!
[1/300] - time: 15.12, Disc loss: 0.173, Gen loss: 0.492, Con loss: 1.437
[2/300] - time: 15.20, Disc loss: 0.164, Gen loss: 0.525, Con loss: 1.459
[3/300] - time: 14.92, Disc loss: 0.160, Gen loss: 0.537, Con loss: 1.476
[4/300] - time: 14.88, Disc loss: 0.158, Gen loss: 0.515, Con loss: 1.455
[5/300] - time: 15.04, Disc loss: 0.160, Gen loss: 0.518, Con loss: 1.467
[6/300] - time: 14.88, Disc loss: 0.140, Gen loss: 0.610, Con loss: 1.466
[7/300] - time: 14.95, Disc loss: 0.123, Gen loss: 0.633, Con loss: 1.493
[8/300] - time: 14.93, Disc loss: 0.113, Gen loss: 0.643, Con loss: 1.473
[9/300] - time: 16.84, Disc loss: 0.112, Gen loss: 0.662, Con loss: 1.478
[10/300] - time: 15.17, Disc loss: 0.078, Gen loss: 0.767, Con loss: 1.472
[11/300] - time: 15.01, Disc loss: 0.063, Gen loss: 0.808, Con loss: 1.438
[12/300] - time: 14.71, Disc loss: 0.075, Gen loss: 0.765, Con loss: 1.474
[13/300] - time: 15.86, Disc loss: 0.070, Gen loss: 0.785, Con loss: 1.447
[14/300] - time: 1