In [5]:
import random
import math
from tqdm import tqdm

import torch
from torch import optim, autograd
from torch.utils import data
from torch.nn import functional as F
from torchvision import transforms, datasets, utils

import model

g_reg_every = 4
lr = 0.002
batch = 16
g_reg_every = 4
d_reg_every = 16
iterate = 1
n_sample = 16
latent = 128
mixing = 0.9
r1 = 10
path_batch_shrink = 2
path_regularize = 2
device = 'cuda'
truncation = 0.7
truncation_mean = 4096

In [6]:
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

def sample_data(loader):
    while True:
        for batch in loader:
            yield batch

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
    return noises

def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    else:
        return [make_noise(batch, latent_dim, 1, device)]

def d_logistic_loss(real_pred, fake_pred):
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)
    return real_loss.mean() + fake_loss.mean()

def d_r1_loss(real_pred, real_img):
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty

def g_nonsaturating_loss(fake_pred):
    loss = F.softplus(-fake_pred).mean()
    return loss

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    grad = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )
    grad = [t.unsqueeze(0) for t in grad]
    grad = torch.cat(grad, 0)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths

def train(loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)
    
    pbar = tqdm(range(iterate), initial=start_iter, dynamic_ncols=True, smoothing=0.01)
    
    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}
    
    g_module = generator
    d_module = discriminator
    
    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = torch.randn(n_sample, latent, device=device)
    
    for idx in pbar:
        i = idx + start_iter

        if i > iterate:
            print('Done!')
            break

        real_img = next(loader)
        real_img = real_img[0].to(device)

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(batch, latent, mixing, device)
        fake_img = generator(noise, truncation)
        fake_pred, _ = discriminator(fake_img)

        real_pred, _ = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred, _ = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (r1 / 2 * r1_loss * d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(batch, latent, mixing, device)
        fake_img = generator(noise, truncation)
        fake_pred, _ = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, batch // path_batch_shrink)
            noise = mixing_noise(
                path_batch_size, latent, mixing, device
            )
            fake_img, latents = generator(noise, truncation, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = path_regularize * g_reg_every * path_loss

            if path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (mean_path_length.item())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        d_loss_val = loss_dict['d'].mean().item()
        g_loss_val = loss_dict['g'].mean().item()
        r1_val = loss_dict['r1'].mean().item()
        path_loss_val = loss_dict['path'].mean().item()
        real_score_val = loss_dict['real_score'].mean().item()
        fake_score_val = loss_dict['fake_score'].mean().item()
        path_length_val = loss_dict['path_length'].mean().item()

        pbar.set_description(
            (
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            )
        )
        '''
        if wandb and args.wandb:
            wandb.log(
                {
                    'Generator': g_loss_val,
                    'Discriminator': d_loss_val,
                    'R1': r1_val,
                    'Path Length Regularization': path_loss_val,
                    'Mean Path Length': mean_path_length,
                    'Real Score': real_score_val,
                    'Fake Score': fake_score_val,
                    'Path Length': path_length_val,
                }
            )
        '''
        if i % 10 == 0:
            with torch.no_grad():
                g_ema.eval()
                sample = g_ema([sample_z], truncation)
                utils.save_image(
                    sample,
                    f'sample/{str(i).zfill(6)}.png',
                    nrow=int(n_sample ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

        if i % 20 == 0:
            torch.save(
                {
                    'g': g_module.state_dict(),
                    'd': d_module.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'g_optim': g_optim.state_dict(),
                    'd_optim': d_optim.state_dict(),
                },
                f'checkpoint/{str(i).zfill(6)}.pt',
            )

In [7]:
generator = model.Generator(
    128, 128, 6, 
).to(device)
discriminator = model.Discriminator(
    128
).to(device)
g_ema = model.Generator(
    128, 128, 6
).to(device)
g_ema.eval()
accumulate(g_ema, generator, 0)

g_reg_ratio = g_reg_every / (g_reg_every + 1)
d_reg_ratio = d_reg_every / (d_reg_every + 1)

g_optim = optim.Adam(
    generator.parameters(),
    lr=lr * g_reg_ratio,
    betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
    discriminator.parameters(),
    lr=lr * d_reg_ratio,
    betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)

transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ]
)

dataset = datasets.ImageFolder(
    '../datasets/danbooru-images/portrait/',
    transform
)
loader = data.DataLoader(
    dataset,
    batch_size=batch,
    shuffle=True,
    drop_last=True
)

RuntimeError: CUDA error: no kernel image is available for execution on the device

In [4]:
import os

args_ckpt = f'checkpoint/{str(0).zfill(6)}.pt'
ckpt = torch.load(args_ckpt)

try:
    ckpt_name = os.path.basename(args_ckpt)
    start_iter = int(os.path.splitext(ckpt_name)[0])
except ValueError:
    pass

generator.load_state_dict(ckpt['g'])
discriminator.load_state_dict(ckpt['d'])
g_ema.load_state_dict(ckpt['g_ema'])

g_optim.load_state_dict(ckpt['g_optim'])
d_optim.load_state_dict(ckpt['d_optim'])

In [6]:
#start_iter = 0
train(loader, generator, discriminator, g_optim, d_optim, g_ema, device)

d: 0.6477; g: 1.9039; r1: 0.0055; path: 0.2080; mean path: 0.1323:  25%|██▌       | 61/240 [8:55:58<26:12:46, 527.18s/it] 


KeyboardInterrupt: 

In [3]:
# debug
#debug = None

len(tags)

23

In [7]:
with torch.no_grad():
    #g_ema.eval()
    '''
    mean_latent = None
    if truncation < 1:
        mean_latent = g_ema.mean_latent(truncation_mean)
    '''
    sample_z = torch.randn(n_sample, latent, device=device)

    sample = generator([sample_z], truncation=truncation)#, truncation_latent=mean_latent)

    utils.save_image(
        sample,
        f'sample/{str(61).zfill(6)}.png',# args.pics
        nrow=int(n_sample ** 0.5),
        normalize=True,
        range=(-1, 1),
    )