In [1]:
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from models.model import Generator, Discriminator
from utils import visualizer
from utils.data import RandomNoiseGenerator,Data
from utils.train_history import train_history

%matplotlib inline

In [2]:
# Run on GPU index 3 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the GPU index is hard-coded; confirm the target machine exposes at least 4 GPUs.
if torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = torch.device('cpu')

# Turn on cuDNN benchmark mode whenever the cuDNN backend is enabled.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True

# Last expression of the cell: displays the selected device.
device

device(type='cuda', index=3)

In [3]:
# ---- network sizes -------------------------------------------------------
latent_size = 1024    # used as latent_size of the generators and as fmap_max of all networks
target_resol = 256    # resolution the networks are built for; training stops at this level
first_resol = 4       # resolution of the first training level
use_sigmoid = False   # NOTE(review): not referenced by any visible cell - confirm it is still needed

# ---- phase lengths (thousands of images, and the derived image counts) ---
train_kimg = 600
transition_kimg = 600
train_img = 1000 * train_kimg              # images shown during a 'stabilize' phase
transition_img = 1000 * transition_kimg    # images shown during a 'fade_in' phase

# ---- optimizer settings --------------------------------------------------
g_lr_max = 0.001      # peak learning rate of the generator optimizer
d_lr_max = 0.001      # peak learning rate of the discriminator optimizer
beta1 = 0             # Adam betas, shared by both optimizers
beta2 = 0.99

# ---- loss weights --------------------------------------------------------
lambda_A = 10         # scales the whole G_A loss
lambda_B = 10         # scales the whole G_B loss
lambda_recon = 0.8    # weight of the L1 reconstruction (cycle) term
lambda_idt = 0.1      # weight of the L1 term between a translated image and its source

# ---- reporting cadence (iterations) --------------------------------------
report_it = 400       # print / plot the averaged losses
show_it = 400         # display sample images
save_it = 400         # write model checkpoints

In [4]:
# Display settings forwarded positionally to visualizer.Visualizer below.
display_id=1
display_winsize=256
display_ncols=4
display_server='http://localhost'
display_port=8097
display_env='pggan'

# Where results and checkpoints of this run are written.
results_dir='results'
project_name = 'pggan1'
project_dir=os.path.join(results_dir,project_name)

# The experiment name passed to the visualizer is project_name; no variable called
# `name` is defined anywhere in this notebook, so using it would fail on a fresh kernel.
# NOTE(review): assumes the 7th positional parameter of Visualizer is the experiment
# name - confirm against utils.visualizer.
vis = visualizer.Visualizer(display_id,display_winsize,display_ncols,display_server,display_port,display_env,
                 project_name,results_dir)

# Create the project directory if needed; exist_ok makes re-running the cell safe.
os.makedirs(project_dir, exist_ok=True)

In [5]:
# Image folders of the two domains: A = real photos, B = anime images.
# NOTE(review): absolute, machine-specific paths - adjust before running elsewhere.
real_dir = '/data/persona_cyclegan/real/trainA'
anime_dir = '/data/persona_cyclegan/anime/trainB'

# One Data source per domain; the training loop draws batches from them via .next(...).
data_A, data_B = Data(real_dir), Data(anime_dir)

In [6]:
# Both generators share one configuration and both discriminators share another.
generator_kwargs = dict(num_channels=3, latent_size=latent_size, resolution=target_resol,
                        fmap_max=latent_size, fmap_base=8192, tanh_at_end=True)
# NOTE(review): sigmoid_at_end is hard-coded to True although the config cell defines
# use_sigmoid = False - confirm which of the two is intended.
discriminator_kwargs = dict(num_channels=3, mbstat_avg='all', resolution=target_resol,
                            fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=True)

# G_A is applied to domain-A images and G_B to domain-B images in the training loop;
# D_A scores domain-A images and D_B scores domain-B images.
G_A = Generator(**generator_kwargs)
G_B = Generator(**generator_kwargs)
D_A = Discriminator(**discriminator_kwargs)
D_B = Discriminator(**discriminator_kwargs)

G_A = G_A.to(device)
G_B = G_B.to(device)
D_A = D_A.to(device)
D_B = D_B.to(device)

# One Adam optimizer over both generators and one over both discriminators.
generator_params = itertools.chain(G_A.parameters(), G_B.parameters())
discriminator_params = itertools.chain(D_A.parameters(), D_B.parameters())
optim_G = optim.Adam(generator_params, g_lr_max, betas=(beta1, beta2))
optim_D = optim.Adam(discriminator_params, d_lr_max, betas=(beta1, beta2))

# Checkpoint file name -> network, consumed by save_models().
all_models = {
    'G_A.pkl': G_A,
    'G_B.pkl': G_B,
    'D_A.pkl': D_A,
    'D_B.pkl': D_B,
}

# Report the parameter count of one generator and one discriminator.
print('---------- Networks initialized -------------')
for model_name, model in (('G', G_A), ('D', D_A)):
    num_params = sum(weight.numel() for weight in model.parameters())
    print('{} has {} number of parameters'.format(model_name, num_params))
print('-----------------------------------------------')

---------- Networks initialized -------------
G has 149458133 number of parameters
D has 66632259 number of parameters
-----------------------------------------------


In [7]:
# Learning-rate schedule lengths, all in thousands of images (kimg).
# The training loop passes them to _rampup / _rampdown_linear together with cur_nimg / 1000.
rampup_kimg = rampdown_kimg = total_kimg = 10000

def _rampup(epoch, rampup_length):
    if epoch < rampup_length:
        p = max(0.0, float(epoch)) / float(rampup_length)
        p = 1.0 - p
        return np.exp(-p*p*5.0)
    else:
        return 1.0

def _rampdown_linear(epoch, num_epochs, rampdown_length):
    if epoch >= num_epochs - rampdown_length:
        return float(num_epochs - epoch) / rampdown_length
    else:
        return 1.0

In [8]:
def get_bs(resolution):
    """Return the batch size to use at the given image resolution.

    With ``level = int(log2(resolution))`` the batch size is 32 up to level 4,
    halves per level down to 8 at level 6, stays 8 at level 7 and then halves
    per level down to a floor of 2 (reached at level 9).

    Args:
        resolution: image side length; the training loop passes powers of two.

    Returns:
        The batch size as an ``int``.
    """
    level = int(np.log2(resolution))
    if level >= 7:
        # High-resolution branch: 8, 4, 2, 2, ...
        return 8 >> min(2, level - 7)
    # Low-resolution branch: 32, ..., 32, 16, 8
    return 32 >> max(0, level - 4)

# Batch size looked up by resolution, for every power of two from 2**2 (4) to 2**10 (1024).
bs_map = dict((2**R, get_bs(2**R)) for R in range(2, 11))

In [10]:
# Names of the tracked losses. The training loop feeds values to add_params() in this
# order: generator adversarial, identity/similarity, reconstruction, then D_A and D_B.
loss_names = [
    'G_gan_loss',
    'G_idt_loss',
    'G_cycle_loss',
    'D_A_loss',
    'D_B_loss',
]
train_hist = train_history(loss_names)

In [11]:
# L1 criterion used by the generator update for its similarity and reconstruction terms.
L1_loss = nn.L1Loss()
L1_loss = L1_loss.to(device)

In [12]:
def save_models(models, folder):
    """Save the ``state_dict()`` of every model into ``folder``.

    Args:
        models: mapping of file name -> object exposing ``state_dict()``.
        folder: existing directory that receives one file per mapping entry.
    """
    for filename in models:
        weights = models[filename].state_dict()
        torch.save(weights, os.path.join(folder, filename))

## Train

In [13]:
# Progressive training schedule.
# For every level R the networks are first trained at a fixed level ('stabilize'),
# then cur_level is ramped linearly from R towards R+1 ('fade_in') before the next
# level starts. Iteration counts are derived from train_img / transition_img and the
# batch size of the level.
to_level = int(np.log2(target_resol))
from_level = int(np.log2(first_resol))

for R in range(from_level-1, to_level):

    batch_size = bs_map[2 ** (R+1)]
    stabilize_its = train_img//batch_size
    phases = {'stabilize':[0, stabilize_its]}
    # The last level already runs at target_resol, the resolution the networks were
    # built with. Fading in from there would push cur_level / cur_resol beyond it,
    # so the final level only gets a 'stabilize' phase.
    if R < to_level-1:
        phases['fade_in'] = [stabilize_its+1, (transition_img+train_img)//batch_size]

    for phase in ['stabilize', 'fade_in']:
        if phase in phases:
            _range = phases[phase]
            from_it = _range[0]
            cur_it = 0                          # iterations done within this phase
            total_it = _range[1]
            remaining_it = total_it-from_it     # length of this phase in iterations
            cur_nimg = _range[0]*batch_size     # images seen so far at this level
            resol = 2 ** (R+1)

            previous_time = time.time()
            phase_start_time = time.time()

            for it in range(from_it, total_it):
                cur_it += 1
                if phase == 'stabilize':
                    cur_level = R
                else:
                    # Fractional level: moves linearly from R to R+1 over the phase.
                    cur_level = R + cur_it/remaining_it
                cur_resol = 2 ** int(np.ceil(cur_level+1))

                # get a batch of real images from each domain
                real_A_cur, real_A_max = data_A.next(batch_size,cur_resol,cur_level)
                real_A_cur, real_A_max = real_A_cur.to(device), real_A_max.to(device)

                real_B_cur, real_B_max = data_B.next(batch_size,cur_resol,cur_level)
                real_B_cur, real_B_max = real_B_cur.to(device), real_B_max.to(device)

                # ===preprocess===
                # The learning-rate coefficient depends only on cur_nimg, so compute it
                # once and apply it to every parameter group of both optimizers.
                lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
                lrate_coef *= _rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
                for param_group in optim_G.param_groups:
                    param_group['lr'] = lrate_coef * g_lr_max
                for param_group in optim_D.param_groups:
                    param_group['lr'] = lrate_coef * d_lr_max

                # ===update D===
                for model in [D_A,D_B]:
                    for param in model.parameters():
                        param.requires_grad = True

                optim_D.zero_grad()

                # D_B: real B images should score 1, translated A->B images 0
                # (squared-error objective). fake_B is detached so that no gradient
                # reaches G_A during the discriminator step.
                fake_B = G_A(real_A_max, cur_level=cur_level)
                d_real_B = D_B(real_B_cur, cur_level=cur_level)
                d_fake_B = D_B(fake_B.detach(), cur_level=cur_level)

                d_real_B_loss = torch.mean((d_real_B-1)**2)
                d_fake_B_loss = torch.mean((d_fake_B-0)**2)

                d_loss_B = 0.5 * (d_real_B_loss + d_fake_B_loss)
                d_loss_B.backward()

                # D_A: same objective for real A images versus translated B->A images.
                fake_A = G_B(real_B_max, cur_level=cur_level)
                d_real_A = D_A(real_A_cur, cur_level=cur_level)
                d_fake_A = D_A(fake_A.detach(), cur_level=cur_level)

                d_real_A_loss = torch.mean((d_real_A-1)**2)
                d_fake_A_loss = torch.mean((d_fake_A-0)**2)

                d_loss_A = 0.5 * (d_real_A_loss + d_fake_A_loss)
                d_loss_A.backward()

                optim_D.step()

                # ===update G===
                # Freeze the discriminators so the generator step does not accumulate
                # gradients in their parameters.
                for model in [D_A,D_B]:
                    for param in model.parameters():
                        param.requires_grad = False

                optim_G.zero_grad()

                # A -> B -> A direction: adversarial term (D_B should score fake_B as 1),
                # L1 similarity between fake_B and its source, L1 reconstruction of A.
                g_A_adv_loss = torch.mean((D_B(fake_B, cur_level=cur_level)-1)**2)

                sim_A_loss = L1_loss(fake_B,real_A_cur) * lambda_idt

                recon_A = G_B(fake_B,cur_level=cur_level)
                recon_A_loss = L1_loss(recon_A,real_A_cur) * lambda_recon

                G_A_loss = (g_A_adv_loss+sim_A_loss+recon_A_loss)*lambda_A
                G_A_loss.backward()

                # B -> A -> B direction, mirrored.
                g_B_adv_loss = torch.mean((D_A(fake_A,cur_level=cur_level)-1)**2)

                sim_B_loss = L1_loss(fake_A,real_B_cur)*lambda_idt

                recon_B = G_A(fake_A,cur_level=cur_level)
                recon_B_loss = L1_loss(recon_B,real_B_cur)*lambda_recon

                G_B_loss = (g_B_adv_loss + sim_B_loss + recon_B_loss)*lambda_B
                G_B_loss.backward()

                # Aggregates for reporting only; they are not back-propagated.
                G_gan_loss = g_A_adv_loss + g_B_adv_loss
                G_idt_loss = sim_A_loss + sim_B_loss
                G_recon_loss = recon_A_loss + recon_B_loss

                optim_G.step()

                cur_nimg += batch_size


                # ===report ===
                # Order must match the loss names given to train_history(...).
                train_hist.add_params([G_gan_loss,G_idt_loss,G_recon_loss,d_loss_A,d_loss_B])
                if it% report_it == 0:
                    phase_time = time.time() - phase_start_time
                    losses = train_hist.check_current_avg()
                    it_time = time.time() - previous_time
                    vis.print_current_losses(cur_level, it, losses, it_time, phase_time)
                    # Second argument: fraction of the current phase completed.
                    # NOTE(review): confirm this matches what plot_current_losses
                    # expects as its progress ratio.
                    vis.plot_current_losses(cur_level, float(cur_it) / remaining_it, losses)
                    previous_time = time.time()
                # ===generate sample images===
                if it % show_it == 0:
                    save_result = it % save_it == 0
                    vis.display_current_results([real_A_cur,fake_B,recon_A],cur_level, it,save_result)

                # ===save model===
                if it % save_it == 0:
                    save_models(all_models,project_dir)

    # Keep a separate snapshot of all networks for every finished level.
    model_folder_at_scale = os.path.join(project_dir,str(R))
    os.makedirs(model_folder_at_scale, exist_ok=True)
    save_models(all_models,model_folder_at_scale)

stabilize phase, 4 resolution, 0 iteration upon 18750
{'G_gan_loss': tensor(2.0849), 'G_idt_loss': tensor(0.0956), 'G_cycle_loss': tensor(0.7171), 'D_A_loss': tensor(0.6170), 'D_B_loss': tensor(0.6948)}
stabilize phase, 4 resolution, 400 iteration upon 18750
{'G_gan_loss': tensor(0.8257), 'G_idt_loss': tensor(0.1201), 'G_cycle_loss': tensor(0.5053), 'D_A_loss': tensor(0.3082), 'D_B_loss': tensor(0.2883)}
stabilize phase, 4 resolution, 800 iteration upon 18750
{'G_gan_loss': tensor(0.5718), 'G_idt_loss': tensor(0.1164), 'G_cycle_loss': tensor(0.5233), 'D_A_loss': tensor(0.2590), 'D_B_loss': tensor(0.2482)}
stabilize phase, 4 resolution, 1200 iteration upon 18750
{'G_gan_loss': tensor(0.5669), 'G_idt_loss': tensor(0.1173), 'G_cycle_loss': tensor(0.5214), 'D_A_loss': tensor(0.2566), 'D_B_loss': tensor(0.2456)}
stabilize phase, 4 resolution, 1600 iteration upon 18750
{'G_gan_loss': tensor(0.5717), 'G_idt_loss': tensor(0.1165), 'G_cycle_loss': tensor(0.5245), 'D_A_loss': tensor(0.2531), 'D_

stabilize phase, 4 resolution, 16000 iteration upon 18750
{'G_gan_loss': tensor(0.8084), 'G_idt_loss': tensor(0.1248), 'G_cycle_loss': tensor(0.5809), 'D_A_loss': tensor(0.1839), 'D_B_loss': tensor(0.2116)}
stabilize phase, 4 resolution, 16400 iteration upon 18750
{'G_gan_loss': tensor(0.8067), 'G_idt_loss': tensor(0.1258), 'G_cycle_loss': tensor(0.5850), 'D_A_loss': tensor(0.1892), 'D_B_loss': tensor(0.2113)}
stabilize phase, 4 resolution, 16800 iteration upon 18750
{'G_gan_loss': tensor(0.8109), 'G_idt_loss': tensor(0.1231), 'G_cycle_loss': tensor(0.5937), 'D_A_loss': tensor(0.1868), 'D_B_loss': tensor(0.2084)}
stabilize phase, 4 resolution, 17200 iteration upon 18750
{'G_gan_loss': tensor(0.8249), 'G_idt_loss': tensor(0.1215), 'G_cycle_loss': tensor(0.6029), 'D_A_loss': tensor(0.1879), 'D_B_loss': tensor(0.2104)}
stabilize phase, 4 resolution, 17600 iteration upon 18750
{'G_gan_loss': tensor(0.8282), 'G_idt_loss': tensor(0.1261), 'G_cycle_loss': tensor(0.5901), 'D_A_loss': tensor(0.



fade_in phase, 16 resolution, 18800 iteration upon 37500
{'G_gan_loss': tensor(0.9441), 'G_idt_loss': tensor(0.1317), 'G_cycle_loss': tensor(0.6997), 'D_A_loss': tensor(0.1995), 'D_B_loss': tensor(0.2177)}
fade_in phase, 16 resolution, 19200 iteration upon 37500
{'G_gan_loss': tensor(1.4482), 'G_idt_loss': tensor(0.2108), 'G_cycle_loss': tensor(1.2014), 'D_A_loss': tensor(0.1695), 'D_B_loss': tensor(0.1380)}
fade_in phase, 16 resolution, 19600 iteration upon 37500
{'G_gan_loss': tensor(1.6664), 'G_idt_loss': tensor(0.1974), 'G_cycle_loss': tensor(1.1874), 'D_A_loss': tensor(0.1182), 'D_B_loss': tensor(0.0632)}
fade_in phase, 16 resolution, 20000 iteration upon 37500
{'G_gan_loss': tensor(1.9079), 'G_idt_loss': tensor(0.2060), 'G_cycle_loss': tensor(1.1563), 'D_A_loss': tensor(0.0714), 'D_B_loss': tensor(0.0236)}
fade_in phase, 16 resolution, 20400 iteration upon 37500
{'G_gan_loss': tensor(1.9954), 'G_idt_loss': tensor(0.2125), 'G_cycle_loss': tensor(1.1507), 'D_A_loss': tensor(0.0468)

fade_in phase, 16 resolution, 34800 iteration upon 37500
{'G_gan_loss': tensor(2.0020), 'G_idt_loss': tensor(0.1403), 'G_cycle_loss': tensor(0.6443), 'D_A_loss': tensor(0.0012), 'D_B_loss': tensor(0.0010)}
fade_in phase, 16 resolution, 35200 iteration upon 37500
{'G_gan_loss': tensor(2.0025), 'G_idt_loss': tensor(0.1380), 'G_cycle_loss': tensor(0.6397), 'D_A_loss': tensor(0.0012), 'D_B_loss': tensor(0.0010)}
fade_in phase, 16 resolution, 35600 iteration upon 37500
{'G_gan_loss': tensor(2.0023), 'G_idt_loss': tensor(0.1390), 'G_cycle_loss': tensor(0.6355), 'D_A_loss': tensor(0.0012), 'D_B_loss': tensor(0.0009)}
fade_in phase, 16 resolution, 36000 iteration upon 37500
{'G_gan_loss': tensor(2.0022), 'G_idt_loss': tensor(0.1383), 'G_cycle_loss': tensor(0.6333), 'D_A_loss': tensor(0.0011), 'D_B_loss': tensor(0.0009)}
fade_in phase, 16 resolution, 36400 iteration upon 37500
{'G_gan_loss': tensor(2.0022), 'G_idt_loss': tensor(0.1382), 'G_cycle_loss': tensor(0.6316), 'D_A_loss': tensor(0.0011)

KeyboardInterrupt: 