In [1]:
import torch,os
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from utils.data import RandomNoiseGenerator,Data
from utils.train_history import train_history
import itertools
from models.model import Generator, Discriminator
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=3)

In [3]:
latent_size = 1024
target_resol = 256
first_resol = 4
use_sigmoid = False
train_kimg = 600
train_img = train_kimg*1000
transition_kimg = 600
transition_img = transition_kimg*1000

g_lr_max = 0.001
d_lr_max = 0.001

beta1 = 0
beta2 = 0.99

sample_freq = 100
save_freq = 100
report_freq = 100

lambda_A=10
lambda_B=10
lambda_recon=0.8
lambda_idt=0.1

In [4]:
result_dir = 'results'

In [5]:
real_dir = '/data/persona_cyclegan/real/trainA'
anime_dir = '/data/persona_cyclegan/anime/trainB'

data_A = Data(real_dir)
data_B = Data(anime_dir)

In [6]:
G_A = Generator(num_channels=3, latent_size=latent_size, resolution=target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=True)
G_B = Generator(num_channels=3, latent_size=latent_size, resolution=target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=True)

D_A = Discriminator(num_channels=3, mbstat_avg='all', resolution=target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=True)
D_B = Discriminator(num_channels=3, mbstat_avg='all', resolution=target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=True)

# print(G_A)
# print(D_A)
G_A,G_B,D_A,D_B = G_A.to(device),G_B.to(device),D_A.to(device),D_B.to(device)
optim_G = optim.Adam(itertools.chain(G_A.parameters(),G_B.parameters()), g_lr_max, betas=(beta1, beta2))
optim_D = optim.Adam(itertools.chain(D_A.parameters(),D_B.parameters()), d_lr_max, betas=(beta1, beta2))

  kaiming_normal(layer.weight, a=gain)


In [7]:
rampup_kimg = 10000
rampdown_kimg = 10000
total_kimg = 10000

def _rampup(epoch, rampup_length):
    if epoch < rampup_length:
        p = max(0.0, float(epoch)) / float(rampup_length)
        p = 1.0 - p
        return np.exp(-p*p*5.0)
    else:
        return 1.0

def _rampdown_linear(epoch, num_epochs, rampdown_length):
    if epoch >= num_epochs - rampdown_length:
        return float(num_epochs - epoch) / rampdown_length
    else:
        return 1.0

In [8]:
def get_bs(resolution):
    R = int(np.log2(resolution))
    if R < 7:
        bs = 32 / 2**(max(0, R-4))
    else:
        bs = 8 / 2**(min(2, R-7))
    return int(bs)

bs_map = {2**R: get_bs(2**R) for R in range(2, 11)}

In [9]:
train_hist = train_history(['G_loss','D_loss'])

In [10]:
L1_loss = nn.L1Loss().to(device) 

## Train

In [11]:
# real_A = data_A.next(batch_size,cur_resol,cur_level)
# real_A = real_A.to(device)
# real_A.get_device()
# real_A.size()

In [12]:
to_level = int(np.log2(target_resol))
from_level = int(np.log2(first_resol))

for R in range(from_level-1, to_level):
    batch_size = bs_map[2 ** (R+1)]
    phases = {'stabilize':[0, train_img//batch_size], 'fade_in':[train_img//batch_size+1, (transition_img+train_img)//batch_size]}

    for phase in ['stabilize', 'fade_in']:
        if phase in phases:
            _range = phases[phase]
            from_it = _range[0]
            total_it = _range[1]
            cur_nimg = _range[0]*batch_size
            resol = 2 ** (R+1)
            for it in range(from_it, total_it):
                if phase == 'stabilize':
                    cur_level = R
                else:
                    cur_level = R + total_it/float(from_it)
                cur_resol = 2 ** int(np.ceil(cur_level+1))

                # get a batch noise and real images
                real_A_cur, real_A_max = data_A.next(batch_size,cur_resol,cur_level)
                real_A_cur, real_A_max = real_A_cur.to(device), real_A_max.to(device)
                
                real_B_cur, real_B_max = data_B.next(batch_size,cur_resol,cur_level)
                real_B_cur, real_B_max = real_B_cur.to(device), real_B_max.to(device)
                # ===preprocess===
                for param_group in optim_G.param_groups:
                    lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
                    lrate_coef *= _rampdown_linear(cur_nimg / 1000.0,total_kimg, rampdown_kimg)
                    param_group['lr'] = lrate_coef * g_lr_max
                for param_group in optim_D.param_groups:
                    lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
                    lrate_coef *= _rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
                    param_group['lr'] = lrate_coef * d_lr_max

                # ===update D===
                for model in [D_A,D_B]:
                    for param in model.parameters():
                        param.requires_grad = True
                        
                optim_D.zero_grad()

                fake_B = G_A(real_A_max, cur_level=cur_level)
                d_real_B = D_B(real_B_cur, cur_level=cur_level)
                d_fake_B = D_B(fake_B.detach(), cur_level=cur_level)

                d_real_B_loss = torch.mean((d_real_B-1)**2)
                d_fake_B_loss = torch.mean((d_fake_B-0)**2)
                 
                d_loss_B = 0.5 * (d_real_B_loss + d_fake_B_loss)
                d_loss_B.backward()
 
                fake_A = G_B(real_B_max, cur_level=cur_level)
                d_real_A = D_A(real_A_cur, cur_level=cur_level)
                d_fake_A = D_A(fake_A.detach(), cur_level=cur_level)

                d_real_A_loss = torch.mean((d_real_A-1)**2)
                d_fake_A_loss = torch.mean((d_fake_A-0)**2)
                 
                d_loss_A = 0.5 * (d_real_A_loss + d_fake_A_loss)
                d_loss_A.backward()
                
                optim_D.step()

                # ===update G===
                for model in [D_A,D_B]:
                    for param in model.parameters():
                        param.requires_grad = False
                        
                optim_G.zero_grad()
                
                d_fake_A = D_B(fake_B, cur_level=cur_level)
                d_fake_A_loss = torch.mean((d_fake_A-1)**2)
                
                sim_A_loss = L1_loss(fake_B,real_A_cur) * lambda_idt
                
                recon_A = G_B(fake_B,cur_level=cur_level)
                recon_A_loss = L1_loss(recon_A,real_A_cur) * lambda_recon
                
                G_A_loss = (d_fake_A_loss+sim_A_loss+recon_A_loss)*lambda_A
                G_A_loss.backward()
                
                d_fake_B = D_A(fake_A,cur_level=cur_level)
                d_fake_B_loss = torch.mean((d_fake_B-1)**2)
                
                sim_B_loss = L1_loss(fake_A,real_B_cur)*lambda_idt
                
                recon_B = G_A(fake_A,cur_level=cur_level)
                recon_B_loss = L1_loss(recon_B,real_B_cur)*lambda_recon
                
                G_B_loss = (d_fake_B_loss + sim_B_loss + recon_B_loss)*lambda_B
                G_B_loss.backward()

                optim_G.step()
                
                cur_nimg += batch_size
                
#                 # ===report ===
#                 train_hist.add_params([g_loss,d_loss])
#                 if it% report_freq == 0:
#                     print('%s phase, %d resolution, %d iteration upon %d'%(phase, cur_resol, it, total_it))
#                     print(train_hist.check_current_avg())

#                 # ===generate sample images===
#                 samples = []
#                 if it % sample_freq == 0:
#                     img_to_save = (fake.detach()[0].cpu().numpy().transpose(1,2,0)+1)/2
#                     plt.imsave(os.path.join(result_dir, '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6))), img_to_save)

#                 # ===save model===
#                 if it % save_freq == 0:
#                     g_file = os.path.join(result_dir,'G.pth')
#                     d_file = os.path.join(result_dir,'D.pth')
#                     torch.save(G.state_dict(), g_file)
#                     torch.save(D.state_dict(), d_file)


RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:0 and input b is on cuda:3