In [4]:
import os, time, pickle
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from lib import utils, networks, train_history
import itertools
from torchvision import datasets
from torch.autograd import Variable, grad

### Parameters 

In [66]:
#input channel for generator
in_ngc=3
#output channel for generator
out_ngc=3
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1
batch_size=10
ngf=32
ndf=32
#the number of resnet block layer for generator
nb=4
#input size
input_size=128
ng_downsampling=4
nd_downsampling=3

train_epoch=20
#Discriminator learning rate, default=0.0002
lrD=0.0002
#Generator learning rate, default=0.0002
lrG=0.0002
#lambda for loss
lambdaA=10
lambdaB=10
lambda_cycle = 1
# lambda_idt = 1
decay_epoch = 10

# wgan number of critics
n_critic = 5

#beta1 for Adam optimizer
# beta1=0
beta1 = 0.5
#beta2 for Adam optimizer
# beta2=0.9
# beta2=0.999
beta2=0.99

In [67]:
#results path
project_result_path='cycleGAN_8'
cycle_A_path = os.path.join(project_result_path, 'Cycle_G_A')
cycle_B_path = os.path.join(project_result_path, 'Cycle_G_B')
if not os.path.isdir(cycle_A_path):
    os.makedirs(cycle_A_path)
if not os.path.isdir(cycle_B_path):
    os.makedirs(cycle_B_path)

#data path
data_path = 'data'
src_data_path= os.path.join(data_path,'src_data_path_new')
tgt_data_path= os.path.join(data_path,'tgt_data_path')

In [50]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

In [68]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1,0.1,0.1,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader_A = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_B = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)


In [69]:
# network
G_A = networks.cyclegan_generator1(in_ngc, out_ngc, ngf, nb, ng_downsampling)
G_B = networks.cyclegan_generator1(in_ngc, out_ngc, ngf, nb, ng_downsampling)
# D_A = networks.wgan_discriminator(in_ndc, ndf, input_size, nd_downsampling)
# D_B = networks.wgan_discriminator(in_ndc, ndf, input_size, nd_downsampling)
D_A = networks.discriminator(in_ndc, out_ndc, ndf)
D_B = networks.discriminator(in_ndc, out_ndc, ndf)


G_A.to(device)
G_B.to(device)
D_A.to(device)
D_B.to(device);

In [70]:
# loss
# GAN_loss = nn.BCELoss().to(device)
GAN_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)

def D_loss_criterion(D_decision,device,zeros,trick=True):
    if(zeros):
        if(trick):
            return GAN_loss(D_decision, torch.rand(D_decision.size(), device=device)/10.0)
        return GAN_loss(D_decision, torch.zeros(D_decision.size(), device=device))
    else:
        if(trick):
            return GAN_loss(D_decision, 1-torch.rand(D_decision.size(), device=device)/10.0)
        return GAN_loss(D_decision, torch.ones(D_decision.size(), device=device))

# def D_loss_criterion(D_decision):
#     return D_decision.mean()

In [71]:
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lrG, betas=(beta1, beta2))
D_A_optimizer = optim.Adam(D_A.parameters(), lr=lrD, betas=(beta1, beta2))
D_B_optimizer = optim.Adam(D_B.parameters(), lr=lrD, betas=(beta1, beta2))

In [72]:
train_hist = train_history.train_history(['per_epoch_time',
                                          'G_gan_loss',
                                          'G_cycle_loss',
                                          'D_A_fake_loss',
                                          'D_A_real_loss',
                                          'D_B_fake_loss',
                                          'D_B_real_loss'                                          
                                          ])

### Load existing model parameters

In [10]:
G_A.load_state_dict(torch.load(os.path.join(project_result_path, 'G_A.pkl')))
G_B.load_state_dict(torch.load(os.path.join(project_result_path, 'G_B.pkl')))
D_A.load_state_dict(torch.load(os.path.join(project_result_path, 'D_A.pkl')))
D_B.load_state_dict(torch.load(os.path.join(project_result_path, 'D_B.pkl')))

### Load train hist

In [None]:
train_hist.load_train(os.path.join(project_result_path, 'train_hist.pkl'))

### Train

In [73]:
#change the starting_epoch if needed
starting_epoch = 0

In [74]:
def WGAN_calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, real_data.nelement()//batch_size).contiguous()
    alpha = alpha.view(batch_size, 3, input_size, input_size)
    alpha = alpha.to(device)
#     alpha = torch.rand(real_data.size(), device=device)
    interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
    interpolates = interpolates.to(device)
    interpolates.requires_grad_(True)   

    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1) 
    gradient_penalty = 10* ((1-(gradients+1e-16).norm(2, dim=1)) ** 2).mean()
    return gradient_penalty

def DRAGAN_calc_gradient_penalty(netD, X):
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(batch_size, X.nelement()//batch_size).contiguous()
    alpha = alpha.view(batch_size, 3, input_size, input_size)
    alpha = alpha.to(device)
    x_hat = alpha * X.data + (1 - alpha) * (X.data + 0.5 * X.data.std() * torch.rand(X.size(),device=device))
    x_hat.to(device)
    x_hat.requires_grad_(True)
    pred_hat = netD(x_hat)
    gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(device),
            create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients_penalty = 10 * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    gradients_penalty.backward(retain_graph = True)

In [None]:
print('training start!')
start_time = time.time()
num_pool = 50
fake_A_pool = utils.ImagePool(num_pool)
fake_B_pool = utils.ImagePool(num_pool)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
    print("==> Epoch {}/{}".format(starting_epoch+epoch + 1, starting_epoch+train_epoch))
    if (epoch + 1) > decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= lrD / 10
        D_B_optimizer.param_groups[0]['lr'] -= lrD / 10
        G_optimizer.param_groups[0]['lr'] -= lrG / 10
    
    G_gan_losses = []
    G_cycle_losses = []

    D_A_real_losses = []
    D_A_fake_losses = []
    D_B_real_losses = []
    D_B_fake_losses = []

    for (real_A,_),(real_B,_) in zip(train_loader_A, train_loader_B):
        G_A.train()
        G_B.train()

        # input image data
        real_A = real_A.to(device)
        real_B = real_B.to(device)
#         if(i%5 == 0):
        #fix D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = False

        # Train generator G
        # A -> B
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = D_loss_criterion(D_B_fake_decision,device,zeros=False,trick=False) * lambdaA
#         G_A_loss = D_loss_criterion(D_B_fake_decision)  * lambdaA

        # identity loss
#         G_A_idt_loss = L1_loss(fake_B, real_A) * lambdaA * lambda_idt

        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_loss(recon_A, real_A) * lambdaA * lambda_cycle

        # B -> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = D_loss_criterion(D_A_fake_decision,device,zeros=False,trick=False) * lambdaB
#         G_B_loss = D_loss_criterion(D_A_fake_decision) * lambdaB

        # identity loss
#         G_B_idt_loss = L1_loss(fake_A, real_B) * lambdaB * lambda_idt

        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_loss(recon_B, real_B) * lambdaB * lambda_cycle

        # Back propagation
        G_gan_loss = G_A_loss + G_B_loss
        G_gan_losses.append(G_gan_loss)
        G_cycle_loss = cycle_A_loss + cycle_B_loss
        G_cycle_losses.append(G_cycle_loss)
#         G_idt_loss = G_A_idt_loss + G_B_idt_loss
#         G_idt_losses.append(G_idt_loss)

#         G_loss = G_gan_loss + G_cycle_loss + G_idt_loss
#         G_loss = -G_gan_loss + G_cycle_loss
        G_loss = G_gan_loss + G_cycle_loss
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()
#         else:
        #train D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = True
#                     param.data.clamp_(-0.01,0.01)

        # Train discriminator D_A
        D_A_optimizer.zero_grad()

        fake_A = G_B(real_B)
#         D_A_gradient_penalty = WGAN_calc_gradient_penalty(D_A,real_A,fake_A)
#         D_A_gradient_penalty = DRAGAN_calc_gradient_penalty(D_A,real_A)
        DRAGAN_calc_gradient_penalty(D_A,real_A)
        
        D_A_real_decision = D_A(real_A)     
#         D_A_real_loss = D_loss_criterion(D_A_real_decision)
        D_A_real_loss = D_loss_criterion(D_A_real_decision,device,zeros=False,trick=True)
        D_A_real_losses.append(D_A_real_loss)


        fake_A = fake_A_pool.query(fake_A.detach())
        D_A_fake_decision = D_A(fake_A)     
#         D_A_fake_loss = D_loss_criterion(D_A_fake_decision)
        D_A_fake_loss = D_loss_criterion(D_A_fake_decision,device,zeros=True,trick=True)
        D_A_fake_losses.append(D_A_fake_loss)

        # Back propagation
#         D_A_loss = D_A_fake_loss - D_A_real_loss + D_A_gradient_penalty
        D_A_loss = D_A_fake_loss + D_A_real_loss
#         D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_loss.backward()
        D_A_optimizer.step()

        # Train discriminator D_B
        D_B_optimizer.zero_grad()

        fake_B = G_A(real_A)
#         D_B_gradient_penalty = WGAN_calc_gradient_penalty(D_B,real_B,fake_B)
        DRAGAN_calc_gradient_penalty(D_B,real_B)

        D_B_real_decision = D_B(real_B)
#         D_B_real_loss = D_loss_criterion(D_B_real_decision)
        D_B_real_loss = D_loss_criterion(D_B_real_decision,device,zeros=False,trick=True)
        D_B_real_losses.append(D_B_real_loss)          

        fake_B = fake_B_pool.query(fake_B.detach())
        D_B_fake_decision = D_B(fake_B)
#         D_B_fake_loss = D_loss_criterion(D_B_fake_decision)
        D_B_fake_loss = D_loss_criterion(D_B_fake_decision,device,zeros=True,trick=True)
        D_B_fake_losses.append(D_B_fake_loss)         

        # Back propagation
#         D_B_loss = D_B_fake_loss - D_B_real_loss + D_B_gradient_penalty
        D_B_loss = D_B_fake_loss + D_B_real_loss
#         D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_loss.backward()
        D_B_optimizer.step()
     
    train_params = []
    per_epoch_time = time.time() - epoch_start_time
    train_params.append(per_epoch_time)
    for loss in [G_gan_losses,G_cycle_losses,D_A_fake_losses,D_A_real_losses,D_B_fake_losses,D_B_real_losses]:
        train_params.append(torch.mean(torch.FloatTensor(loss)))
    
    train_hist.add_params(train_params)
    print(str.format('{}/{} ',starting_epoch+epoch+1,starting_epoch+train_epoch) + train_hist.get_last_param_str())
    
    #Save image result
    with torch.no_grad():
        G_A.eval()
        G_B.eval()
        for n, (x, _) in enumerate(train_loader_A):
            x = x.to(device)
            G_A_result = G_A(x)
            G_A_recon = G_B(G_A_result)
            result = torch.cat((x[0], G_A_result[0], G_A_recon[0]), 2)
            path = os.path.join(project_result_path, 'Cycle_G_A', str(epoch+starting_epoch) + '_epoch_'  + '_train_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break

        for n, (x,_) in enumerate(train_loader_B):
            x = x.to(device)
            G_B_result = G_B(x)
            G_B_recon = G_A(G_B_result)
            result = torch.cat((x[0],G_B_result[0],G_B_recon[0]),2)
            path = os.path.join(project_result_path,'Cycle_G_B',str(epoch+starting_epoch) + '_epoch_' +'_train_'+str(n+1)+'.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break
                
        torch.save(G_A.state_dict(), os.path.join(project_result_path, 'G_A.pkl'))
        torch.save(G_B.state_dict(), os.path.join(project_result_path, 'G_B.pkl')) 
        torch.save(D_A.state_dict(), os.path.join(project_result_path, 'D_A.pkl'))
        torch.save(D_B.state_dict(), os.path.join(project_result_path, 'D_B.pkl'))
        train_hist.save_train(os.path.join(project_result_path,  'train_hist.pkl'))

training start!
==> Epoch 1/20
1/20 per_epoch_time:649.226,G_gan_loss:7.903,G_cycle_loss:5.511,D_A_fake_loss:0.191,D_A_real_loss:0.552,D_B_fake_loss:0.275,D_B_real_loss:0.439,
==> Epoch 2/20
