In [1]:
import torch,os,time
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from utils.data import RandomNoiseGenerator,Data
from utils.train_history import train_history
from utils import visualizer,util
import itertools
from models.model import Generator, Discriminator
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Run on GPU index 3 when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:3') if torch.cuda.is_available() else torch.device('cpu')
# Turn on cuDNN benchmark mode only when the cuDNN backend is enabled.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
# Bare expression so the notebook cell echoes the selected device.
device

device(type='cuda', index=3)

In [3]:
# ---- network sizes ----
latent_size = 512       # used as both latent_size and fmap_max when the networks are built
first_resol = 4         # resolution of the first training scale
target_resol = 256      # resolution of the last training scale
use_sigmoid = False     # NOTE(review): not referenced by any visible cell

# ---- schedule: images shown per phase at every scale ----
train_kimg = 600                            # stabilize phase, in thousands of images
transition_kimg = 600                       # fade-in phase, in thousands of images
train_img = 1000 * train_kimg
transition_img = 1000 * transition_kimg

# ---- Adam settings shared by the generator and discriminator optimizers ----
g_lr_max = 1e-3
d_lr_max = 1e-3
beta1 = 0
beta2 = 0.99

# ---- loss weights ----
lambda_A = 10           # scales the summed G_A objective
lambda_B = 10           # scales the summed G_B objective
lambda_recon = 0.8      # weight on the L1 reconstruction (cycle) terms
lambda_idt = 0.1        # weight on the L1 input-vs-translation terms

# ---- reporting cadence, in iterations ----
report_it = show_it = save_it = 400

In [4]:
# Display settings; all of them are forwarded positionally to visualizer.Visualizer below.
display_id=1
display_winsize=256
display_ncols=4
display_server='http://localhost'
display_port=8097
display_env='pggan'

# Checkpoints written by the training loop go to <results_dir>/<project_name>.
results_dir='results'
project_name = 'pggan1'
project_dir=os.path.join(results_dir,project_name)

vis = visualizer.Visualizer(display_id,display_winsize,display_ncols,display_server,display_port,display_env,
                 project_name,results_dir)

# exist_ok=True replaces the isdir()/makedirs() pair: it stays a no-op when the
# directory already exists and cannot fail if the directory appears in between.
os.makedirs(project_dir, exist_ok=True)

create web directory results/pggan1/web...


In [5]:
# Image folders for the two domains: A = real, B = anime.
real_dir = '/data/persona_cyclegan/real/trainA'
anime_dir = '/data/persona_cyclegan/anime/trainB'
# Alternative folders (uncomment to use them instead):
# real_dir = '/data/persona_cyclegan/real_test'
# anime_dir = '/data/persona_cyclegan/anime_test'

# One loader per domain; the training loop pulls batches from them with .next(...).
data_A, data_B = Data(real_dir), Data(anime_dir)

In [6]:
# Two generators, built with identical settings. In the training loop G_A maps
# domain-A images to B and G_B maps domain-B images to A.
G_A, G_B = (
    Generator(num_channels=3, latent_size=latent_size, resolution=target_resol,
              fmap_max=latent_size, fmap_base=8192, tanh_at_end=True)
    for _ in range(2)
)

# Two discriminators; D_A scores domain-A images and D_B scores domain-B images.
# NOTE(review): sigmoid_at_end is hard-coded to True although the config cell sets
# use_sigmoid = False -- confirm which one is intended.
D_A, D_B = (
    Discriminator(num_channels=3, mbstat_avg='all', resolution=target_resol,
                  fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=True)
    for _ in range(2)
)

print(G_A)
print(D_A)

G_A = G_A.to(device)
G_B = G_B.to(device)
D_A = D_A.to(device)
D_B = D_B.to(device)

# A single optimizer updates both generators; each discriminator has its own.
adam_betas = (beta1, beta2)
optim_G = optim.Adam(list(G_A.parameters()) + list(G_B.parameters()), lr=g_lr_max, betas=adam_betas)
optim_D_A = optim.Adam(D_A.parameters(), lr=d_lr_max, betas=adam_betas)
optim_D_B = optim.Adam(D_B.parameters(), lr=d_lr_max, betas=adam_betas)

# Checkpoint file name -> network; consumed by save_models().
all_models = {
    'G_A.pkl': G_A,
    'G_B.pkl': G_B,
    'D_A.pkl': D_A,
    'D_B.pkl': D_B,
}

# Report the parameter count of one generator and one discriminator.
print('---------- Networks initialized -------------')
for model_name, model in (('G', G_A), ('D', D_A)):
    num_params = sum(weights.numel() for weights in model.parameters())
    print('{} has {} number of parameters'.format(model_name, num_params))
print('-----------------------------------------------')

Generator(
  (output_layer): GSelectLayer(
    (pre): ModuleList(
      (0): Sequential(
        (0): ReflectionPad2d((3, 3, 3, 3))
        (1): Conv2d(3, 512, kernel_size=(7, 7), stride=(1, 1))
        (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace)
      )
      (1): Sequential(
        (0): ReflectionPad2d((3, 3, 3, 3))
        (1): Conv2d(3, 512, kernel_size=(7, 7), stride=(1, 1))
        (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace)
      )
      (2): Sequential(
        (0): ReflectionPad2d((3, 3, 3, 3))
        (1): Conv2d(3, 512, kernel_size=(7, 7), stride=(1, 1))
        (2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): ReLU(inplace)
      )
      (3): Sequential(
        (0): ReflectionPad2d((3, 3, 3, 3))
        (1): Conv2d(3, 512, kernel_size=(7, 7), stride=(1, 1))
        (2): Ins

---------- Networks initialized -------------
G has 45463253 number of parameters
D has 21573187 number of parameters
-----------------------------------------------


In [7]:
# Learning-rate schedule lengths in thousands of images. They are only read by the
# commented-out schedule inside the training loop.
rampup_kimg = rampdown_kimg = total_kimg = 10000

def _rampup(epoch, rampup_length):
    if epoch < rampup_length:
        p = max(0.0, float(epoch)) / float(rampup_length)
        p = 1.0 - p
        return np.exp(-p*p*5.0)
    else:
        return 1.0

def _rampdown_linear(epoch, num_epochs, rampdown_length):
    if epoch >= num_epochs - rampdown_length:
        return float(num_epochs - epoch) / rampdown_length
    else:
        return 1.0

In [8]:
def get_bs(resolution):
    """Return the batch size used when training at `resolution` pixels.

    With level = int(log2(resolution)): levels below 7 start from 32 and halve
    once per level above 4; levels 7 and up start from 8 and halve once per
    level above 7, at most twice.
    """
    level = int(np.log2(resolution))
    if level >= 7:
        base, halvings = 8, min(2, level - 7)
    else:
        base, halvings = 32, max(0, level - 4)
    return int(base / 2 ** halvings)


# Batch size for every power-of-two resolution from 4 to 1024.
bs_map = {1 << exponent: get_bs(1 << exponent) for exponent in range(2, 11)}

In [9]:
# Names of the loss series to track. The training loop feeds values in this same
# order: [G_gan_loss, G_idt_loss, G_recon_loss, d_loss_A, d_loss_B].
train_hist = train_history(
    ['G_gan_loss', 'G_idt_loss', 'G_cycle_loss', 'D_A_loss', 'D_B_loss']
)

In [10]:
# L1 criterion used by the generator update for the similarity and reconstruction terms.
L1_loss = nn.L1Loss()
L1_loss = L1_loss.to(device)

In [11]:
def save_models(models, folder):
    """Write the state_dict of every network in `models` into `folder`.

    Args:
        models: mapping of file name -> module; each key becomes the file name.
        folder: existing directory that receives the files.
    """
    for filename in models:
        weights = models[filename].state_dict()
        destination = os.path.join(folder, filename)
        torch.save(weights, destination)

## Train

In [12]:
# Progressive training. Scale R trains at resolution 2**(R+1): first a 'stabilize'
# phase at fixed level R, then a 'fade_in' phase during which cur_level moves
# linearly from R towards R+1.
to_level = int(np.log2(target_resol))
from_level = int(np.log2(first_resol))


def _set_requires_grad(models, flag):
    """Switch gradient tracking on or off for every parameter of `models`."""
    for model in models:
        for param in model.parameters():
            param.requires_grad = flag


for R in range(from_level-1, to_level):
    batch_size = bs_map[2 ** (R+1)]
    stabilize_its = train_img//batch_size
    # Iteration ranges [start, stop) for each phase at this scale.
    phases = {'stabilize':[0, stabilize_its],
              'fade_in':[stabilize_its+1, (transition_img+train_img)//batch_size]}
    if R == to_level-1:
        # The last scale already runs at target_resol. Fading in here would raise
        # cur_level above to_level-1 and make cur_resol equal 2*target_resol,
        # a resolution the networks were not built for.
        del phases['fade_in']
    print('starting scale %d, batch_size is %d, and the iterations in each phase is %d'%(R,batch_size,stabilize_its))
#     for phase in ['fade_in']:
    for phase in ['stabilize', 'fade_in']:
        if phase not in phases:
            continue

        # Fresh pools of previously generated images for every phase.
        num_pool = 80
        fake_A_pool = util.ImagePool(num_pool)
        fake_B_pool = util.ImagePool(num_pool)

        from_it, total_it = phases[phase]
        cur_it = 0
        remaining_it = total_it-from_it
        cur_nimg = from_it*batch_size
        resol = 2 ** (R+1)

        previous_time = time.time()
        phase_start_time = time.time()

        for it in range(from_it, total_it):
            cur_it += 1
            if phase == 'stabilize':
                cur_level = R
            else:
                # Fractional level; reaches R+1 on the last fade-in iteration.
                cur_level = R + cur_it/remaining_it
            cur_resol = 2 ** int(np.ceil(cur_level+1))

            # get a batch of real images from each domain
            real_A = data_A.next(batch_size,cur_resol,cur_level)
            real_A = real_A.to(device)

            real_B = data_B.next(batch_size,cur_resol,cur_level)
            real_B = real_B.to(device)
            # ===preprocess===
            # Optional learning-rate schedule, currently disabled:
#             for param_group in optim_G.param_groups:
#                 lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
#                 lrate_coef *= _rampdown_linear(cur_nimg / 1000.0,total_kimg, rampdown_kimg)
#                 param_group['lr'] = lrate_coef * g_lr_max

#             for param_group in optim_D_A.param_groups:
#                 lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
#                 lrate_coef *= _rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
#                 param_group['lr'] = lrate_coef * d_lr_max
#             for param_group in optim_D_B.param_groups:
#                 lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
#                 lrate_coef *= _rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
#                 param_group['lr'] = lrate_coef * d_lr_max

            # ===update D===
            _set_requires_grad([D_A, D_B], True)

            optim_D_A.zero_grad()
            optim_D_B.zero_grad()

            # D_B: real B -> 1, pooled fake B -> 0 (least-squares objective).
            # The fake is detached so no gradient reaches G_A in this step.
            fake_B = G_A(real_A, cur_level=cur_level)
            fake_B_from_pool = fake_B_pool.query(fake_B.detach())
            d_real_B = D_B(real_B, cur_level=cur_level)
            d_fake_B = D_B(fake_B_from_pool, cur_level=cur_level)

            d_real_B_loss = torch.mean((d_real_B-1)**2)
            d_fake_B_loss = torch.mean((d_fake_B-0)**2)

            d_loss_B = 0.5 * (d_real_B_loss + d_fake_B_loss)
            d_loss_B.backward()
            optim_D_B.step()

            # D_A: same objective on domain A.
            fake_A = G_B(real_B, cur_level=cur_level)
            fake_A_from_pool = fake_A_pool.query(fake_A.detach())
            d_real_A = D_A(real_A, cur_level=cur_level)
            d_fake_A = D_A(fake_A_from_pool, cur_level=cur_level)

            d_real_A_loss = torch.mean((d_real_A-1)**2)
            d_fake_A_loss = torch.mean((d_fake_A-0)**2)

            d_loss_A = 0.5 * (d_real_A_loss + d_fake_A_loss)
            d_loss_A.backward()
            optim_D_A.step()

            # ===update G===
            # Freeze the discriminators so the generator losses leave no grads on them.
            _set_requires_grad([D_A, D_B], False)

            optim_G.zero_grad()

            # A -> B -> A: fool D_B, stay close to the input, reconstruct the input.
            adv_B_loss = torch.mean((D_B(fake_B, cur_level=cur_level)-1)**2)
            sim_A_loss = L1_loss(fake_B,real_A) * lambda_idt
            recon_A = G_B(fake_B,cur_level=cur_level)
            recon_A_loss = L1_loss(recon_A,real_A) * lambda_recon

            G_A_loss = (adv_B_loss+sim_A_loss+recon_A_loss)*lambda_A
            G_A_loss.backward()

            # B -> A -> B: the mirrored objective.
            adv_A_loss = torch.mean((D_A(fake_A,cur_level=cur_level)-1)**2)
            sim_B_loss = L1_loss(fake_A,real_B)*lambda_idt
            recon_B = G_A(fake_A,cur_level=cur_level)
            recon_B_loss = L1_loss(recon_B,real_B)*lambda_recon

            G_B_loss = (adv_A_loss + sim_B_loss + recon_B_loss)*lambda_B
            G_B_loss.backward()

            # Aggregates for logging only; not back-propagated.
            G_gan_loss = adv_B_loss + adv_A_loss
            G_idt_loss = sim_A_loss + sim_B_loss
            G_recon_loss = recon_A_loss + recon_B_loss

            optim_G.step()

            cur_nimg += batch_size

            # Plot x-axis position: stabilize covers [R, R+0.5], fade_in covers
            # [R+0.5, R+1]. Progress is measured against this phase's own length
            # (remaining_it) so that fade_in spans its whole half.
            cur_scale = R+cur_it/remaining_it/2
            if phase == 'fade_in':
                cur_scale+=0.5

            # ===report ===
            # NOTE(review): the losses are passed as tensors; if train_history keeps
            # references to them, logging detached floats would be lighter -- confirm
            # what add_params expects before changing.
            train_hist.add_params([G_gan_loss,G_idt_loss,G_recon_loss,d_loss_A,d_loss_B])
            if it% report_it == 0:
                phase_time = time.time() - phase_start_time
                losses = train_hist.check_current_avg()
                it_time = time.time() - previous_time
                vis.print_current_losses(cur_level, it, losses, it_time, phase_time)
                vis.plot_current_losses(cur_scale, losses)
                previous_time = time.time()
            # ===generate sample images===
            if it % show_it == 0:
                save_result = it % save_it == 0
                vis.display_current_results([real_A,fake_B,recon_A],cur_scale,save_result)

            # ===save model===
            if it % save_it == 0:
                save_models(all_models,project_dir)

    # Keep a separate checkpoint of all networks for every finished scale.
    model_folder_at_scale = os.path.join(project_dir,str(R))
    os.makedirs(model_folder_at_scale, exist_ok=True)
    save_models(all_models,model_folder_at_scale)

starting scale 1, batch_size is 32, and the iterations in each phase is 18750


NameError: name 'optim_D' is not defined