In [34]:
import os, time, pickle, json
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from lib import utils, networks
import itertools
from torchvision import datasets

### Parameters 

In [35]:
#input channel for generator
in_ngc=3
#output channel for generator
out_ngc=3
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1
batch_size=8
ngf=64
ndf=32
#the number of resnet block layer for generator
nb=5
#input size
input_size=64
n_downsampling=3

train_epoch=20
#Discriminator learning rate, default=0.0002
lrD=0.0002
#Generator learning rate, default=0.0002
lrG=0.0002
#lambda for loss
lambdaA=1
lambdaB=1
lambda_cycle = 1
lambda_idt = 1
decay_epoch = 10

#beta1 for Adam optimizer
beta1=0.5
#beta2 for Adam optimizer
beta2=0.999

In [36]:
#results path
project_result_path='cycleGAN_3'
cycle_A_path = os.path.join(project_result_path, 'Cycle_G_A')
cycle_B_path = os.path.join(project_result_path, 'Cycle_G_B')
if not os.path.isdir(cycle_A_path):
    os.makedirs(cycle_A_path)
if not os.path.isdir(cycle_B_path):
    os.makedirs(cycle_B_path)

#data path
data_path = 'data'
src_data_path= os.path.join(data_path,'src_data_path_new')
tgt_data_path= os.path.join(data_path,'tgt_data_path')

In [37]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

In [38]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader_A = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_B = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)


In [39]:
# network
G_A = networks.cyclegan_generator1(in_ngc, out_ngc, ngf, nb, n_downsampling)
G_B = networks.cyclegan_generator1(in_ngc, out_ngc, ngf, nb, n_downsampling)
D_A = networks.discriminator(in_ndc, out_ndc, ndf)
D_B = networks.discriminator(in_ndc, out_ndc, ndf)

G_A.to(device)
G_B.to(device)
D_A.to(device)
D_B.to(device);

In [40]:
# loss
MSE_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)

In [41]:
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lrG, betas=(beta1, beta2))
D_A_optimizer = optim.Adam(D_A.parameters(), lr=lrD, betas=(beta1, beta2))
D_B_optimizer = optim.Adam(D_B.parameters(), lr=lrD, betas=(beta1, beta2))

In [42]:
train_hist = {}
train_hist['per_epoch_time'] = []
train_hist['total_time'] = []
train_hist['G_loss_one_epoch']=[]
train_hist['G_gan_loss_one_epoch']=[]
train_hist['G_cycle_loss_one_epoch']=[]
train_hist['G_idt_loss_one_epoch']=[]
train_hist['D_A_loss_one_epoch']=[]
train_hist['D_B_loss_one_epoch']=[]

### Load existing model parameters

In [11]:
G_A.load_state_dict(torch.load(os.path.join(project_result_path, 'G_A.pkl')))
G_B.load_state_dict(torch.load(os.path.join(project_result_path, 'G_B.pkl')))
D_A.load_state_dict(torch.load(os.path.join(project_result_path, 'D_A.pkl')))
D_B.load_state_dict(torch.load(os.path.join(project_result_path, 'D_B.pkl')))

### Load train hist

In [12]:
train_hist_path = os.path.join(project_result_path, 'train_hist.json')
with open(train_hist_path, 'r') as file:
    train_hist = json.load(file)

FileNotFoundError: [Errno 2] No such file or directory: 'cycleGAN_1_results/train_hist.json'

### Train

In [43]:
#change the starting_epoch if needed
starting_epoch = 0

In [44]:
print('training start!')
start_time = time.time()
num_pool = 50
fake_A_pool = utils.ImagePool(num_pool)
fake_B_pool = utils.ImagePool(num_pool)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
    print("==> Epoch {}/{}".format(starting_epoch+epoch + 1, starting_epoch+train_epoch))
    if (epoch + 1) > decay_epoch:
        D_A_optimizer.param_groups[0]['lr'] -= lrD / 10
        D_B_optimizer.param_groups[0]['lr'] -= lrD / 10
        G_optimizer.param_groups[0]['lr'] -= lrG / 10
    
    G_losses = []
    G_gan_losses = []
    G_cycle_losses = []
    G_idt_losses = []

    D_A_losses = []
    D_B_losses = []
    for (real_A,_),(real_B,_) in zip(train_loader_A, train_loader_B):
        G_A.train()
        G_B.train()

        # input image data
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        #fix D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = False
        
        # Train generator G
        # A -> B
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = MSE_loss(D_B_fake_decision, torch.ones(D_B_fake_decision.size(), device=device))

        # identity loss
        G_A_idt_loss = L1_loss(fake_B, real_A) * lambdaA * lambda_idt
        
        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_loss(recon_A, real_A) * lambdaA * lambda_cycle

        # B -> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = MSE_loss(D_A_fake_decision, torch.ones(D_A_fake_decision.size(), device=device))
        
        # identity loss
        G_B_idt_loss = L1_loss(fake_A, real_B) * lambdaB * lambda_idt
        
        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_loss(recon_B, real_B) * lambdaB * lambda_cycle
        
        # Back propagation
        G_gan_loss = G_A_loss + G_B_loss
        G_gan_losses.append(G_gan_loss)
        G_cycle_loss = cycle_A_loss + cycle_B_loss
        G_cycle_losses.append(G_cycle_loss)
#         G_idt_loss = G_A_idt_loss + G_B_idt_loss
#         G_idt_losses.append(G_idt_loss)

#         G_loss = G_gan_loss + G_cycle_loss + G_idt_loss
        G_loss = G_gan_loss + G_cycle_loss
        G_losses.append(G_loss)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        #train D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = True
                
        # Train discriminator D_A
        D_A_real_decision = D_A(real_A)
        D_A_real_loss = MSE_loss(D_A_real_decision, 1-torch.rand(D_A_real_decision.size(), device=device)/10.0)
        fake_A = fake_A_pool.query(fake_A.detach())
        D_A_fake_decision = D_A(fake_A)
        D_A_fake_loss = MSE_loss(D_A_fake_decision, torch.rand(D_A_fake_decision.size(), device=device)/10.0)

        # Back propagation
        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_losses.append(D_A_loss)
        D_A_optimizer.zero_grad()
        D_A_loss.backward()
        D_A_optimizer.step()

        # Train discriminator D_B
        D_B_real_decision = D_B(real_B)
        D_B_real_loss = MSE_loss(D_B_real_decision, 1-torch.rand(D_B_real_decision.size(), device=device)/10.0)
        fake_B = fake_B_pool.query(fake_B.detach())
        D_B_fake_decision = D_B(fake_B)
        D_B_fake_loss = MSE_loss(D_B_fake_decision, torch.rand(D_B_fake_decision.size(), device=device)/10.0)

        # Back propagation
        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_losses.append(D_B_loss)
        D_B_optimizer.zero_grad()
        D_B_loss.backward()
        D_B_optimizer.step()
    
    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)
    
    G_loss_avg = torch.mean(torch.FloatTensor(G_losses))
    G_gan_loss_avg = torch.mean(torch.FloatTensor(G_gan_losses))
    G_cycle_loss_avg = torch.mean(torch.FloatTensor(G_cycle_losses))
#     G_idt_loss_avg = torch.mean(torch.FloatTensor(G_idt_losses))
    D_A_loss_avg = torch.mean(torch.FloatTensor(D_A_losses))
    D_B_loss_avg =  torch.mean(torch.FloatTensor(D_B_losses))
    
    train_hist['G_loss_one_epoch'].append(G_loss_avg)
    train_hist['G_gan_loss_one_epoch'].append(G_gan_loss_avg)
    train_hist['G_cycle_loss_one_epoch'].append(G_cycle_loss_avg)
#     train_hist['G_idt_loss_one_epoch'].append(G_idt_loss_avg)

    train_hist['D_A_loss_one_epoch'].append(D_A_loss_avg)
    train_hist['D_B_loss_one_epoch'].append(D_B_loss_avg)
    
    print(
    '[%d/%d] - time: %.2f, G loss: %.3f, gan loss: %.3f, cycle loss: %.3f, D_A loss: %.3f, D_B loss: %.3f' 
    % ((starting_epoch+epoch+1),starting_epoch+train_epoch, per_epoch_time, G_loss_avg, G_gan_loss_avg, G_cycle_loss_avg, D_A_loss_avg, D_B_loss_avg))
    
    #Save image result
    with torch.no_grad():
        G_A.eval()
        G_B.eval()
        for n, (x, _) in enumerate(train_loader_A):
            x = x.to(device)
            G_A_result = G_A(x)
            G_A_recon = G_B(G_A_result)
            result = torch.cat((x[0], G_A_result[0], G_A_recon[0]), 2)
            path = os.path.join(project_result_path, 'Cycle_G_A', str(epoch+starting_epoch) + '_epoch_'  + '_train_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break

        for n, (x,_) in enumerate(train_loader_B):
            x = x.to(device)
            G_B_result = G_B(x)
            G_B_recon = G_A(G_B_result)
            result = torch.cat((x[0],G_B_result[0],G_B_recon[0]),2)
            path = os.path.join(project_result_path,'Cycle_G_B',str(epoch+starting_epoch) + '_epoch_' +'_train_'+str(n+1)+'.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break
                
        torch.save(G_A.state_dict(), os.path.join(project_result_path, 'G_A.pkl'))
        torch.save(G_B.state_dict(), os.path.join(project_result_path, 'G_B.pkl')) 
        torch.save(D_A.state_dict(), os.path.join(project_result_path, 'D_A.pkl'))
        torch.save(D_B.state_dict(), os.path.join(project_result_path, 'D_B.pkl'))
        with open(os.path.join(project_result_path,  'train_hist.pkl'), 'wb') as f:
            pickle.dump(train_hist, f)

training start!
==> Epoch 1/20
[1/20] - time: 269.02, G loss: 2.510, gan loss: 1.033, cycle loss: 0.618, idt loss: 0.858, D_A loss: 0.054, D_B loss: 0.110
==> Epoch 2/20
[2/20] - time: 269.63, G loss: 2.261, gan loss: 0.808, cycle loss: 0.555, idt loss: 0.898, D_A loss: 0.097, D_B loss: 0.118
==> Epoch 3/20


KeyboardInterrupt: 