In [1]:
import os,time,pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from lib import networks,train_history,util,visualizer
import itertools
from torchvision import datasets

In [2]:
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=2)

In [3]:
display_id=1
display_winsize=256
display_ncols=4
display_server='http://localhost'
display_port=8097
display_env='main'
name = 'persona_3'
checkpoints_dir='checkpoints'

vis = visualizer.Visualizer(display_id,display_winsize,display_ncols,display_server,display_port,display_env,
                 name,checkpoints_dir)

create web directory checkpoints/persona_3/web...


In [4]:
g_input_nc = 3
g_output_nc = 3
d_input_nc = 3
d_output_nc = 1

ngf=32
ndf=64
nb = 9
n_downsampling = 4

load_size = 286
fine_size = 256

# original_input_size = 4
batch_size = 1
# progressive_rounds = 6

lambda_A = 10.0
lambda_B = 10.0
lambda_idt = 0.5

lr = 0.0002
beta1 = 0.5
beta2 = 0.999

train_epoch = 100

display_freq = 400
print_freq = 400
save_latest_freq = 400
update_html_freq = 800

In [5]:
G_A = networks.generator(g_input_nc,g_output_nc,ngf,nb,n_downsampling=n_downsampling)
G_B = networks.generator(g_input_nc,g_output_nc,ngf,nb,n_downsampling=n_downsampling)

# G_A = networks.UnetGenerator(g_input_nc, g_output_nc, 7, ngf)
# G_B = networks.UnetGenerator(g_input_nc, g_output_nc, 7, ngf)

D_A = networks.discriminator(d_input_nc,d_output_nc,ndf)
D_B = networks.discriminator(d_input_nc,d_output_nc,ndf)

G_A,G_B,D_A,D_B = G_A.to(device),G_B.to(device),D_A.to(device),D_B.to(device)

print('---------- Networks initialized -------------')
for model_name,model in [('G',G_A),('D',D_A)]:
    num_params = 0
    for param in model.parameters():
        num_params += param.numel()
    print(str.format('{} has {} number of parameters', model_name, num_params))
print('-----------------------------------------------')

initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
---------- Networks initialized -------------
G has 45620867 number of parameters
D has 2764737 number of parameters
-----------------------------------------------


In [6]:
# src_data_path='data/real'
# tgt_data_path = "data/manga"
src_data_path='/data/persona_cyclegan/real'
tgt_data_path =  '/data/persona_cyclegan/anime'

In [7]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((load_size, load_size)),
        transforms.RandomCrop(fine_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1,0.1,0.1,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader_A = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_B = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)


In [11]:
# loss
GAN_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)

def D_loss_criterion(D_decision,device,zeros,trick=True):
    if(zeros):
        if(trick):
            return GAN_loss(D_decision, torch.rand(D_decision.size(), device=device)/10.0)
        return GAN_loss(D_decision, torch.zeros(D_decision.size(), device=device))
    else:
        if(trick):
            return GAN_loss(D_decision, 1-torch.rand(D_decision.size(), device=device)/10.0)
        return GAN_loss(D_decision, torch.ones(D_decision.size(), device=device))

In [9]:
G_optimizer = optim.Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lr, betas=(beta1, beta2))
D_A_optimizer = optim.Adam(D_A.parameters(), lr=lr, betas=(beta1, beta2))
D_B_optimizer = optim.Adam(D_B.parameters(), lr=lr, betas=(beta1, beta2))

In [10]:
epoch_count=1
niter = 3
niter_decay = 100
def lambda_rule(epoch):
    lr_l = 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
    return lr_l
schedulers = [lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) for optimizer in [G_optimizer,D_A_optimizer,D_B_optimizer]]

In [12]:
train_hist = train_history.train_history(['G_gan_loss',
                                          'G_cycle_loss',
                                          'D_A_fake_loss',
                                          'D_A_real_loss',
                                          'D_B_fake_loss',
                                          'D_B_real_loss'                                          
                                          ])

In [13]:
starting_epoch=0

In [None]:
## print('training start!')

num_pool = 80
fake_A_pool = util.ImagePool(num_pool)
fake_B_pool = util.ImagePool(num_pool)

for epoch in range(train_epoch):
    epoch_start_time = time.time()

    for scheduler in schedulers:
        scheduler.step()
        print('learning rate = %.7f'%G_optimizer.param_groups[0]['lr'])
    data_size = min(len(train_loader_A), len(train_loader_B))
    
    for i,((real_A,_),(real_B,_)) in enumerate(zip(train_loader_A, train_loader_B)):
        G_A.train()
        G_B.train()

        iter_start_time = time.time()
        # input image data
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # Train generator G
        
        #fix D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = False
        
        # A -> B
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = D_loss_criterion(D_B_fake_decision,device,zeros=False,trick=False)

        # identity loss
#         G_A_idt_loss = L1_loss(fake_B, real_A) * lambda_A * lambda_idt

        # forward cycle loss
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_loss(recon_A, real_A) * lambda_A

        # B -> A
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = D_loss_criterion(D_A_fake_decision,device,zeros=False,trick=False)

        # identity loss
#         G_B_idt_loss = L1_loss(fake_A, real_B) * lambda_B * lambda_idt

        # backward cycle loss
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_loss(recon_B, real_B) * lambda_B 

        # Back propagation
        G_gan_loss = G_A_loss + G_B_loss
        G_cycle_loss = cycle_A_loss + cycle_B_loss
#         G_idt_loss = G_A_idt_loss + G_B_idt_loss

#         G_loss = G_gan_loss + G_cycle_loss + G_idt_loss
        G_loss = G_gan_loss + G_cycle_loss
        
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()
        
        if i % display_freq == 0:
            save_result = i % update_html_freq == 0
            vis.display_current_results([real_A,fake_B,recon_A],starting_epoch+epoch, i,save_result)
        
        #train D parameters
        for model in [D_A, D_B]:
            for param in model.parameters():
                param.requires_grad = True

        # Train discriminator D_A
        D_A_optimizer.zero_grad()
       
        D_A_real_decision = D_A(real_A)     
        D_A_real_loss = D_loss_criterion(D_A_real_decision,device,zeros=False,trick=False)

        fake_A = G_B(real_B)
        fake_A = fake_A_pool.query(fake_A.detach())
        D_A_fake_decision = D_A(fake_A)     
        D_A_fake_loss = D_loss_criterion(D_A_fake_decision,device,zeros=True,trick=False)

        # Back propagation
        D_A_loss = 0.5*(D_A_fake_loss + D_A_real_loss)
        D_A_loss.backward()
        D_A_optimizer.step()

        # Train discriminator D_B
        D_B_optimizer.zero_grad()

        D_B_real_decision = D_B(real_B)
        D_B_real_loss = D_loss_criterion(D_B_real_decision,device,zeros=False,trick=False)         

        fake_B = G_A(real_A)
        fake_B = fake_B_pool.query(fake_B.detach())
        D_B_fake_decision = D_B(fake_B)
        D_B_fake_loss = D_loss_criterion(D_B_fake_decision,device,zeros=True,trick=False)

        # Back propagation
        D_B_loss = 0.5*(D_B_fake_loss + D_B_real_loss)
        D_B_loss.backward()
        D_B_optimizer.step()

        train_hist.add_params([G_gan_loss,G_cycle_loss,D_A_fake_loss,D_A_real_loss,D_B_fake_loss,D_B_real_loss])

        if i % print_freq == 0:
            t_data = iter_start_time - epoch_start_time
            losses = train_hist.check_current_avg()
            t = (time.time() - iter_start_time) / batch_size
            vis.print_current_losses(starting_epoch+epoch, i, losses, t, t_data)
            if display_id > 0:
                vis.plot_current_losses(starting_epoch+epoch, float(i) / data_size, losses)
        
        if i % save_latest_freq == 0:
            torch.save(G_A.state_dict(), os.path.join(checkpoints_dir,name, 'G_A.pkl'))
            torch.save(G_B.state_dict(), os.path.join(checkpoints_dir,name, 'G_B.pkl')) 
            torch.save(D_A.state_dict(), os.path.join(checkpoints_dir,name, 'D_A.pkl'))
            torch.save(D_B.state_dict(), os.path.join(checkpoints_dir,name, 'D_B.pkl'))
            train_hist.save_train(os.path.join(checkpoints_dir,name,  'train_hist.pkl'))
            
    if (epoch+starting_epoch)%10 == 0:
        torch.save(G_A.state_dict(), os.path.join(checkpoints_dir,name, str(epoch+starting_epoch)+'G_A.pkl'))
        torch.save(G_B.state_dict(), os.path.join(checkpoints_dir,name, str(epoch+starting_epoch)+'G_B.pkl')) 
        torch.save(D_A.state_dict(), os.path.join(checkpoints_dir,name, str(epoch+starting_epoch)+'D_A.pkl'))
        torch.save(D_B.state_dict(), os.path.join(checkpoints_dir,name, str(epoch+starting_epoch)+'D_B.pkl'))

learning rate = 0.0002000
learning rate = 0.0002000
learning rate = 0.0002000
(epoch: 0, iters: 0, time: 1.446, data: 0.039) G_gan_loss: 3.985 G_cycle_loss: 16.153 D_A_fake_loss: 2.067 D_A_real_loss: 2.005 D_B_fake_loss: 1.635 D_B_real_loss: 2.375 
(epoch: 0, iters: 400, time: 0.688, data: 87.727) G_gan_loss: 1.073 G_cycle_loss: 6.795 D_A_fake_loss: 0.488 D_A_real_loss: 0.420 D_B_fake_loss: 0.415 D_B_real_loss: 0.383 
(epoch: 0, iters: 800, time: 0.728, data: 176.547) G_gan_loss: 0.802 G_cycle_loss: 5.340 D_A_fake_loss: 0.251 D_A_real_loss: 0.272 D_B_fake_loss: 0.233 D_B_real_loss: 0.232 
(epoch: 0, iters: 1200, time: 0.612, data: 265.380) G_gan_loss: 0.799 G_cycle_loss: 5.067 D_A_fake_loss: 0.239 D_A_real_loss: 0.269 D_B_fake_loss: 0.244 D_B_real_loss: 0.232 
(epoch: 0, iters: 1600, time: 0.705, data: 353.813) G_gan_loss: 0.796 G_cycle_loss: 4.780 D_A_fake_loss: 0.238 D_A_real_loss: 0.263 D_B_fake_loss: 0.235 D_B_real_loss: 0.234 
(epoch: 0, iters: 2000, time: 0.666, data: 442.605) G_

(epoch: 0, iters: 18800, time: 0.658, data: 4166.151) G_gan_loss: 1.094 G_cycle_loss: 3.320 D_A_fake_loss: 0.179 D_A_real_loss: 0.198 D_B_fake_loss: 0.158 D_B_real_loss: 0.158 
(epoch: 0, iters: 19200, time: 0.756, data: 4254.520) G_gan_loss: 1.059 G_cycle_loss: 3.278 D_A_fake_loss: 0.177 D_A_real_loss: 0.198 D_B_fake_loss: 0.155 D_B_real_loss: 0.163 
(epoch: 0, iters: 19600, time: 0.648, data: 4342.943) G_gan_loss: 1.120 G_cycle_loss: 3.294 D_A_fake_loss: 0.186 D_A_real_loss: 0.199 D_B_fake_loss: 0.147 D_B_real_loss: 0.154 
(epoch: 0, iters: 20000, time: 0.729, data: 4431.548) G_gan_loss: 1.110 G_cycle_loss: 3.340 D_A_fake_loss: 0.185 D_A_real_loss: 0.197 D_B_fake_loss: 0.148 D_B_real_loss: 0.164 
(epoch: 0, iters: 20400, time: 0.665, data: 4520.066) G_gan_loss: 1.127 G_cycle_loss: 3.367 D_A_fake_loss: 0.180 D_A_real_loss: 0.192 D_B_fake_loss: 0.141 D_B_real_loss: 0.154 
learning rate = 0.0002000
learning rate = 0.0002000
learning rate = 0.0002000
(epoch: 1, iters: 0, time: 0.439, dat