In [1]:
import os, time, pickle, argparse, networks, utils
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from edge_promoting import edge_promoting
import itertools

### Parameters 

In [2]:
# ---- Run identity and data locations ----
# project name; used as the prefix of the '<name>_results' output folder
name = 'project_name'
# source-domain (A) dataset folder, resolved under 'data/'
src_data = 'src_data_path'
# target-domain (B) dataset folder, resolved under 'data/'
tgt_data = 'tgt_data_path'
# path to the pre-trained VGG19 weights
vgg_model = 'pre_trained_VGG19_model_path/vgg19.pth'

# ---- Network shapes ----
# generator input / output channels
in_ngc = 3
out_ngc = 3
# discriminator input / output channels
in_ndc = 3
out_ndc = 1
# width arguments handed to networks.generator / networks.discriminator
ngf = 64
ndf = 32
# number of resnet block layers in each generator
nb = 8
# images are resized to input_size x input_size before entering the networks
input_size = 256

# ---- Optimisation schedule ----
batch_size = 8
train_epoch = 10
# learning rates are lowered only once (epoch + 1) > decay_epoch;
# with decay_epoch equal to train_epoch no decay step is taken in this run
decay_epoch = 10
# discriminator / generator learning rates (default 0.0002)
lrD = 0.0002
lrG = 0.0002
# Adam optimizer betas
beta1 = 0.5
beta2 = 0.999

# ---- Loss weights ----
# multipliers applied to the forward (A) and backward (B) cycle-consistency L1 losses
lambdaA = 10
lambdaB = 10

In [3]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Turn on cuDNN benchmark mode, but only when cuDNN itself is enabled.
cudnn_backend = torch.backends.cudnn
if cudnn_backend.enabled:
    cudnn_backend.benchmark = True

In [4]:
device

device(type='cuda')

In [6]:
# results save path
# One sub-folder per translation direction holds the sample images written during training.
# exist_ok=True makes the cell safe to re-run and removes the check-then-create race of
# testing os.path.isdir() before calling os.makedirs().
for result_subdir in ('Cycle_G_A', 'Cycle_G_B'):
    os.makedirs(os.path.join(name + '_results', result_subdir), exist_ok=True)

In [7]:
#setup source and target folder
# Create data/<dataset>/<split> for both datasets and both splits.
# exist_ok=True replaces the four isdir()/makedirs() pairs: it is re-runnable and has no
# window in which another process can create the folder between the check and the call.
# Creation order (target first, 'train' before 'test') matches the original cell.
for dataset_dir in (tgt_data, src_data):
    for split in ('train', 'test'):
        os.makedirs(os.path.join('data', dataset_dir, split), exist_ok=True)

In [5]:
# data_loader
# Preprocessing shared by every loader: resize to a square of side input_size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
preprocess_steps = [
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocess_steps)

# Dataset roots for the source (A) and target (B) domains.
src_root = os.path.join('data', src_data)
tgt_root = os.path.join('data', tgt_data)

# Training loaders use the configured batch size; the test loader yields one image per batch.
train_loader_A = utils.data_load(src_root, 'train', transform, batch_size, shuffle=True, drop_last=True)
train_loader_B = utils.data_load(tgt_root, 'train', transform, batch_size, shuffle=True, drop_last=True)
test_loader_A = utils.data_load(src_root, 'test', transform, 1, shuffle=True, drop_last=True)

In [6]:
# network
# G_A translates A -> B and G_B translates B -> A (see the training loop); D_A and D_B are
# the discriminators applied to domain-A and domain-B images respectively.
generator_args = (in_ngc, out_ngc, ngf, nb)
discriminator_args = (in_ndc, out_ndc, ndf)

# Construction order is kept (G_A, G_B, D_A, D_B).
G_A = networks.generator(*generator_args)
G_B = networks.generator(*generator_args)
D_A = networks.discriminator(*discriminator_args)
D_B = networks.discriminator(*discriminator_args)

# Move every network onto the selected device.
for net in (G_A, G_B, D_A):
    net.to(device)
# Left as the cell's final expression so the notebook still displays the D_B module summary.
D_B.to(device)

discriminator(
  (convs): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): LeakyReLU(negative_slope=0.2, inplace)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): LeakyReLU(negative_slope=0.2, inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (11): LeakyReLU(negative_slope=0.2, inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, trac

### Load existing model parameters

In [7]:
# Restore the generator weights saved by a previous run under '<name>_results'.
# map_location=device remaps the stored tensors onto the device selected for this session,
# so a checkpoint written on a GPU machine also loads on a CPU-only one.
# torch.load unpickles the file, so only load checkpoints you produced or trust.
# NOTE(review): the training loop also saves D_A.pkl / D_B.pkl, but the discriminators are
# not restored here and therefore restart from their initial weights - confirm this is intended.
G_A.load_state_dict(torch.load(os.path.join(name + '_results', 'G_A.pkl'), map_location=device))
G_B.load_state_dict(torch.load(os.path.join(name + '_results', 'G_B.pkl'), map_location=device))

### End of loading parameters

In [8]:
# loss
# MSE_loss scores the discriminator decisions against all-ones / all-zeros targets;
# L1_loss compares reconstructed images with the real ones for the cycle terms.
MSE_loss, L1_loss = (criterion.to(device) for criterion in (nn.MSELoss(), nn.L1Loss()))

In [9]:
# Adam settings shared by all three optimizers.
adam_betas = (beta1, beta2)

# Both generators are updated by a single optimizer: their parameters are chained into one
# iterable, so they end up together in param_groups[0].
generator_params = itertools.chain(G_A.parameters(), G_B.parameters())
G_optimizer = optim.Adam(generator_params, lr=lrG, betas=adam_betas)

# Each discriminator gets its own optimizer.
D_A_optimizer = optim.Adam(D_A.parameters(), lr=lrD, betas=adam_betas)
D_B_optimizer = optim.Adam(D_B.parameters(), lr=lrD, betas=adam_betas)

In [10]:
# Training history: every key starts with its own empty list.
history_keys = (
    'G_loss',
    'D_A_loss',
    'D_B_loss',
    'per_epoch_time',
    'total_time',
    'G_loss_one_epoch',
    'D_A_loss_one_epoch',
    'D_B_loss_one_epoch',
)
train_hist = {key: [] for key in history_keys}

In [11]:
print('training start!')
start_time = time.time()
num_pool = 50
# History buffers of generated images that the discriminators are trained on.
fake_A_pool = utils.ImagePool(num_pool)
fake_B_pool = utils.ImagePool(num_pool)


def _set_requires_grad(models, flag):
    """Set requires_grad to `flag` on every parameter of every model in `models`."""
    for model in models:
        for param in model.parameters():
            param.requires_grad = flag


def _mean_or_nan(values):
    """Return the arithmetic mean of a list of floats, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def _save_cycle_samples(loader, first_G, second_G, sub_dir, epoch_tag, split, recon_in_middle, max_images=5):
    """Save side-by-side sample strips for the first `max_images` batches of `loader`.

    For each batch, the first image is translated by `first_G` and mapped back by `second_G`.
    The strip is (input, reconstruction, translation) when `recon_in_middle` is true and
    (input, translation, reconstruction) otherwise, concatenated along tensor axis 2.
    Files go to '<name>_results/<sub_dir>/<epoch_tag>_epoch__<split>_<n>.png'.
    Callers are expected to wrap this in torch.no_grad() and put the generators in eval mode.
    """
    for n, (x, _) in enumerate(itertools.islice(loader, max_images)):
        x = x.to(device)
        translated = first_G(x)
        recon = second_G(translated)
        if recon_in_middle:
            panels = (x[0], recon[0], translated[0])
        else:
            panels = (x[0], translated[0], recon[0])
        result = torch.cat(panels, 2)
        path = os.path.join(name + '_results', sub_dir,
                            str(epoch_tag) + '_epoch_' + '_' + split + '_' + str(n + 1) + '.png')
        # Undo the mean 0.5 / std 0.5 normalisation before writing the image.
        plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)


for epoch in range(train_epoch):
    epoch_start_time = time.time()
    print("==> Epoch {}/{}".format(epoch + 1, train_epoch))
    # Step the learning rates down by a tenth of their initial value per epoch past decay_epoch.
    # With decay_epoch == train_epoch (the configured values) this branch never runs.
    if (epoch + 1) > decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= lrD / 10
        D_B_optimizer.param_groups[0]['lr'] -= lrD / 10
        G_optimizer.param_groups[0]['lr'] -= lrG / 10

    # The generators are switched to eval mode at the end of every epoch for sampling,
    # so put them back into train mode once per epoch rather than once per batch.
    G_A.train()
    G_B.train()

    # Per-iteration losses stored as plain floats, so no autograd history is kept alive.
    G_losses = []
    D_A_losses = []
    D_B_losses = []
    for (real_A, _), (real_B, _) in zip(train_loader_A, train_loader_B):
        # input image data
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # ---- Train generators ----
        # Freeze BOTH discriminators before they are used in the generator objective, so the
        # generator backward pass does not compute gradients for discriminator weights.
        # (The original loop iterated D_A.parameters() twice and never froze D_B.)
        _set_requires_grad([D_A, D_B], False)

        # A -> B
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = MSE_loss(D_B_fake_decision, torch.ones_like(D_B_fake_decision))

        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_loss(recon_A, real_A) * lambdaA

        # B -> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = MSE_loss(D_A_fake_decision, torch.ones_like(D_A_fake_decision))

        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_loss(recon_B, real_B) * lambdaB

        # Back propagation
        G_loss = G_A_loss + G_B_loss + cycle_A_loss + cycle_B_loss
        G_losses.append(G_loss.item())
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # ---- Train discriminators ----
        _set_requires_grad([D_A, D_B], True)

        # Train discriminator D_A
        D_A_real_decision = D_A(real_A)
        D_A_real_loss = MSE_loss(D_A_real_decision, torch.ones_like(D_A_real_decision))
        # detach(): the discriminator update must not reach back into the generator graph,
        # and the pool should not hold on to that graph either.
        fake_A = fake_A_pool.query(fake_A.detach())
        D_A_fake_decision = D_A(fake_A)
        D_A_fake_loss = MSE_loss(D_A_fake_decision, torch.zeros_like(D_A_fake_decision))

        # Back propagation
        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_losses.append(D_A_loss.item())
        D_A_optimizer.zero_grad()
        D_A_loss.backward()
        D_A_optimizer.step()

        # Train discriminator D_B
        D_B_real_decision = D_B(real_B)
        D_B_real_loss = MSE_loss(D_B_real_decision, torch.ones_like(D_B_real_decision))
        fake_B = fake_B_pool.query(fake_B.detach())
        D_B_fake_decision = D_B(fake_B)
        D_B_fake_loss = MSE_loss(D_B_fake_decision, torch.zeros_like(D_B_fake_decision))

        # Back propagation
        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_losses.append(D_B_loss.item())
        D_B_optimizer.zero_grad()
        D_B_loss.backward()
        D_B_optimizer.step()

    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)

    # Epoch averages as plain Python floats (NaN if the loaders produced no batches).
    G_loss_avg = _mean_or_nan(G_losses)
    D_A_loss_avg = _mean_or_nan(D_A_losses)
    D_B_loss_avg = _mean_or_nan(D_B_losses)

    train_hist['G_loss_one_epoch'].append(G_loss_avg)
    train_hist['D_A_loss_one_epoch'].append(D_A_loss_avg)
    train_hist['D_B_loss_one_epoch'].append(D_B_loss_avg)

    print(
    '[%d/%d] - time: %.2f, G loss: %.3f, D_A loss: %.3f, D_B loss: %.3f' % ((epoch + 1), train_epoch, per_epoch_time, G_loss_avg, D_A_loss_avg, D_B_loss_avg))

    #Save image result
    with torch.no_grad():
        G_A.eval()
        G_B.eval()
        # NOTE(review): Cycle_G_A files are labelled epoch + 6 while Cycle_G_B files use
        # epoch + 1, and the two directions order their panels differently. Both are kept
        # exactly as before; the +6 is presumably an offset from an earlier run - confirm.
        _save_cycle_samples(train_loader_A, G_A, G_B, 'Cycle_G_A', epoch + 6, 'train', True)
        _save_cycle_samples(test_loader_A, G_A, G_B, 'Cycle_G_A', epoch + 6, 'test', True)
        _save_cycle_samples(train_loader_B, G_B, G_A, 'Cycle_G_B', epoch + 1, 'train', False)

        # Checkpoint all four networks after every epoch (each epoch overwrites the last).
        torch.save(G_A.state_dict(), os.path.join(name + '_results', 'G_A.pkl'))
        torch.save(G_B.state_dict(), os.path.join(name + '_results', 'G_B.pkl'))
        torch.save(D_A.state_dict(), os.path.join(name + '_results', 'D_A.pkl'))
        torch.save(D_B.state_dict(), os.path.join(name + '_results', 'D_B.pkl'))

# Record the wall-clock duration of the whole run (start_time was previously never used).
train_hist['total_time'].append(time.time() - start_time)


training start!
==> Epoch 1/10
[1/10] - time: 2072.28, G loss: 2.618, D_A loss: 0.231, D_B loss: 0.157
==> Epoch 2/10
[2/10] - time: 2046.40, G loss: 2.837, D_A loss: 0.187, D_B loss: 0.132
==> Epoch 3/10
[3/10] - time: 2042.19, G loss: 2.962, D_A loss: 0.177, D_B loss: 0.051
==> Epoch 4/10
[4/10] - time: 2043.79, G loss: 3.078, D_A loss: 0.162, D_B loss: 0.021
==> Epoch 5/10
[5/10] - time: 2044.19, G loss: 3.102, D_A loss: 0.148, D_B loss: 0.030
==> Epoch 6/10
[6/10] - time: 2044.24, G loss: 3.105, D_A loss: 0.141, D_B loss: 0.000
==> Epoch 7/10
[7/10] - time: 2044.32, G loss: 3.099, D_A loss: 0.129, D_B loss: 0.000
==> Epoch 8/10
[8/10] - time: 2044.93, G loss: 3.125, D_A loss: 0.110, D_B loss: 0.000
==> Epoch 9/10
[9/10] - time: 2060.48, G loss: 3.089, D_A loss: 0.112, D_B loss: 0.000
==> Epoch 10/10
[10/10] - time: 2061.34, G loss: 2.880, D_A loss: 0.161, D_B loss: 0.000
