In [1]:
import torch,os
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from utils.data import RandomNoiseGenerator,Data
from utils.train_history import train_history
from models.model import Generator, Discriminator
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda', index=3)

In [4]:
latent_size = 512
target_resol = 256
first_resol = 4
use_sigmoid = False
train_kimg = 600
transition_kimg = 600

g_lr_max = 0.001
d_lr_max = 0.001

beta1 = 0
beta2 = 0.99

sample_freq = 100
save_freq = 100
report_freq = 100

In [5]:
train_dir = '/data/persona_cyclegan/anime/trainB'
result_dir = 'results'

In [5]:
data = Data(train_dir)
noise = RandomNoiseGenerator(latent_size, 'gaussian')

In [7]:
G = Generator(num_channels=3, latent_size=latent_size, resolution=target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=True)
D = Discriminator(num_channels=3, mbstat_avg='all', resolution=target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=use_sigmoid)
print(G)
print(D)
G,D = G.to(device),D.to(device)
optim_G = optim.Adam(G.parameters(), g_lr_max, betas=(beta1, beta2))
optim_D = optim.Adam(D.parameters(), d_lr_max, betas=(beta1, beta2))

Generator(
  (output_layer): GSelectLayer(
    (pre): PixelNormLayer(eps = 1e-08)
    (chain): ModuleList(
      (0): Sequential(
        (0): ReshapeLayer()
        (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(3, 3), bias=False)
        (2): WScaleLayer(incoming = Conv2d)
        (3): LeakyReLU(negative_slope=0.2)
        (4): PixelNormLayer(eps = 1e-08)
        (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (6): WScaleLayer(incoming = Conv2d)
        (7): LeakyReLU(negative_slope=0.2)
        (8): PixelNormLayer(eps = 1e-08)
      )
      (1): Sequential(
        (0): Upsample(scale_factor=2, mode=nearest)
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (2): WScaleLayer(incoming = Conv2d)
        (3): LeakyReLU(negative_slope=0.2)
        (4): PixelNormLayer(eps = 1e-08)
        (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   

In [7]:
rampup_kimg = 10000
rampdown_kimg = 10000
total_kimg = 10000

def _rampup(epoch, rampup_length):
    if epoch < rampup_length:
        p = max(0.0, float(epoch)) / float(rampup_length)
        p = 1.0 - p
        return np.exp(-p*p*5.0)
    else:
        return 1.0

def _rampdown_linear(epoch, num_epochs, rampdown_length):
    if epoch >= num_epochs - rampdown_length:
        return float(num_epochs - epoch) / rampdown_length
    else:
        return 1.0

In [8]:
def get_bs(resolution):
    R = int(np.log2(resolution))
    if R < 7:
        bs = 32 / 2**(max(0, R-4))
    else:
        bs = 8 / 2**(min(2, R-7))
    return int(bs)

bs_map = {2**R: get_bs(2**R) for R in range(2, 11)}

In [9]:
train_hist = train_history(['G_loss','D_loss'])

## Train

In [10]:
to_level = int(np.log2(target_resol))
from_level = int(np.log2(first_resol))

train_img = int(train_kimg * 1000)
transition_img = int(transition_kimg * 1000)

for R in range(from_level-1, to_level):
    batch_size = bs_map[2 ** (R+1)]
    phases = {'stabilize':[0, train_img//batch_size], 'fade_in':[train_img//batch_size+1, (transition_img+train_img)//batch_size]}

    for phase in ['stabilize', 'fade_in']:
        if phase in phases:
            _range = phases[phase]
            from_it = _range[0]
            total_it = _range[1]
            cur_nimg = _range[0]*batch_size
            resol = 2 ** (R+1)
            for it in range(from_it, total_it):
                if phase == 'stabilize':
                    cur_level = R
                else:
                    cur_level = R + total_it/float(from_it)
                cur_resol = 2 ** int(np.ceil(cur_level+1))

                # get a batch noise and real images
                z = noise(batch_size)
                real = data.next(batch_size,cur_resol,cur_level)

                # ===preprocess===
                z = Variable(torch.from_numpy(z))
                z,real = z.to(device),real.to(device)
                for param_group in optim_G.param_groups:
                    lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
                    lrate_coef *= _rampdown_linear(cur_nimg / 1000.0,total_kimg, rampdown_kimg)
                    param_group['lr'] = lrate_coef * g_lr_max
                for param_group in optim_D.param_groups:
                    lrate_coef = _rampup(cur_nimg / 1000.0, rampup_kimg)
                    lrate_coef *= _rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
                    param_group['lr'] = lrate_coef * d_lr_max

                # ===update D===
                optim_D.zero_grad()

                fake = G(z, cur_level=cur_level)
                d_real = D(real, cur_level=cur_level, gdrop_strength=0)
                d_fake = D(fake.detach(), cur_level=cur_level)
                
#                 d_adv_loss_fake = MSE_Loss(d_fake, False) * 0.1
#                 d_adv_loss_real = MSE_Loss(d_real, True)
                
                d_adv_loss_fake = torch.mean((d_fake-0)**2) * 0.1
                d_adv_loss_real = torch.mean((d_real-1)**2)
                
                
                d_loss = 0.5 * (d_adv_loss_real + d_adv_loss_fake)
                d_loss.backward()
                optim_D.step()

                # ===update G===
                optim_G.zero_grad()
                d_fake = D(fake, cur_level=cur_level)
                g_loss = torch.mean((d_fake-1)**2)
                g_loss.backward()
                optim_G.step()
                
                cur_nimg += batch_size
                # ===report ===
                train_hist.add_params([g_loss,d_loss])
                if it% report_freq == 0:
                    print('%s phase, %d resolution, %d iteration upon %d'%(phase, cur_resol, it, total_it))
                    print(train_hist.check_current_avg())

                # ===generate sample images===
                samples = []
                if it % sample_freq == 0:
                    img_to_save = (fake.detach()[0].cpu().numpy().transpose(1,2,0)+1)/2
                    plt.imsave(os.path.join(result_dir, '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6))), img_to_save)

                # ===save model===
                if it % save_freq == 0:
                    g_file = os.path.join(result_dir,'G.pth')
                    d_file = os.path.join(result_dir,'D.pth')
                    torch.save(G.state_dict(), g_file)
                    torch.save(D.state_dict(), d_file)


RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:0 and input b is on cuda:3