In [4]:
import random
import math
import os
import json
from tqdm import tqdm

import torch
from torch import optim, autograd
from torch.utils import data
from torch.nn import functional as F
from torchvision import transforms, datasets, utils

from model import Generator, Discriminator
from dataset import PortraitDataset

# ---- Training configuration ----------------------------------------------------
device = 'cpu'                                  # device for models, noise, labels and data batches
path = '../datasets/danbooru-images/portrait/'  # path to the dataset
iteration = 4500                                # total training iterations
batch = 16                                      # batch size
n_sample = 16                                   # number of samples in each preview grid saved during training
img_size = 128                                  # image size (height = width) for the model and the data pipeline
lb_size = 23                                    # length of the conditioning label vector
r1 = 10                                         # weight of the R1 regularization
path_regularize = 2                             # weight of the path length regularization
path_batch_shrink = 2                           # batch-size reduction factor for the path length regularization
                                                # (reduces memory consumption)
d_reg_every = 16                                # apply R1 regularization every N iterations
g_reg_every = 4                                 # apply path length regularization every N iterations
mixing = 0.9                                    # probability of latent code mixing
# ckpt = None                                   # use None to train from scratch
ckpt = f'checkpoint/{4500:06d}.pt'              # checkpoint to resume training from
lr = 0.002                                      # base learning rate (rescaled per optimizer for lazy regularization)
channel_multiplier = 1                          # channel multiplier factor for the model
latent = 128                                    # dimensionality of the latent code z
n_mlp = 6                                       # passed to Generator; presumably the mapping-network depth
start_iter = 0                                  # first iteration index; overwritten when a checkpoint is loaded
seed = None                                     # set to an int for reproducible runs (None = unseeded)

# Seed every RNG the notebook draws from (Python `random` for labels/mixing,
# torch for noise/latents) so a run can be reproduced when `seed` is set.
if seed is not None:
    random.seed(seed)
    torch.manual_seed(seed)

In [2]:
def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place.

    For every parameter name, ``p1 <- decay * p1 + (1 - decay) * p2``. The update
    goes through ``.data`` so autograd does not record it. ``decay=0`` copies
    ``model2`` into ``model1``. Only parameters are blended, not buffers.
    """
    source_params = dict(model2.named_parameters())
    for name, target in model1.named_parameters():
        target.data.mul_(decay).add_(source_params[name].data, alpha=1 - decay)

def sample_data(loader):
    """Cycle over ``loader`` forever, yielding each batch as an ``(images, labels)`` tuple.

    The loader is re-iterated on every pass rather than having its batches cached.
    Each item must unpack into exactly two elements.
    """
    while True:
        yield from ((images, labels) for images, labels in loader)

def requires_grad(model, flag=True):
    """Enable (``flag=True``) or disable gradient tracking on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(flag)

def make_noise(batch, latent_dim, n_noise, device):
    """Sample standard-normal latent codes.

    Returns:
        A single ``(batch, latent_dim)`` tensor when ``n_noise == 1``. Otherwise,
        a tuple of ``n_noise`` such tensors, drawn with one ``randn`` call and
        split along the first axis.
    """
    if n_noise != 1:
        stacked = torch.randn(n_noise, batch, latent_dim, device=device)
        return stacked.unbind(0)
    return torch.randn(batch, latent_dim, device=device)

def mixing_noise(batch, latent_dim, prob, device):
    """Return the latent-code input for the generator, with optional style mixing.

    With probability ``prob``, returns two latent codes (a tuple from
    ``make_noise``). Otherwise, returns a one-element list. ``random.random()`` is
    only consulted when ``prob > 0``.
    """
    use_mixing = prob > 0 and random.random() < prob
    if not use_mixing:
        return [make_noise(batch, latent_dim, 1, device)]
    return make_noise(batch, latent_dim, 2, device)

def d_logistic_loss(real_pred, fake_pred):
    """Discriminator logistic loss: ``mean(softplus(-D(real))) + mean(softplus(D(fake)))``."""
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def d_r1_loss(real_pred, real_img):
    """R1 penalty: the batch mean of the squared L2 norm of d(sum real_pred)/d real_img.

    ``create_graph=True`` keeps the gradient differentiable so the penalty itself
    can be backpropagated into the discriminator. ``real_img`` must require grad.
    """
    (grad_real,) = autograd.grad(
        outputs=real_pred.sum(),
        inputs=real_img,
        create_graph=True,
    )
    batch_size = grad_real.shape[0]
    per_sample_sq_norm = grad_real.pow(2).reshape(batch_size, -1).sum(1)
    return per_sample_sq_norm.mean()

def g_nonsaturating_loss(fake_pred):
    """Non-saturating generator loss: ``mean(softplus(-D(fake)))``."""
    return F.softplus(-fake_pred).mean()

def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path length regularization term for the generator.

    Projects ``fake_img`` onto random noise scaled by ``1/sqrt(H*W)`` and takes
    the gradient with respect to ``latents``. For each sample, the path length is
    ``sqrt(mean over axis 1 of sum over axis 2 of grad**2)``. The running mean
    moves toward the batch mean by ``decay``.

    Returns:
        ``(penalty, updated_running_mean.detach(), path_lengths)``, where
        ``penalty`` is the mean squared deviation of the path lengths from the
        updated running mean.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    scaled_noise = torch.randn_like(fake_img) / math.sqrt(height * width)
    (latent_grad,) = autograd.grad(
        outputs=(fake_img * scaled_noise).sum(),
        inputs=latents,
        create_graph=True,
    )
    path_lengths = latent_grad.pow(2).sum(2).mean(1).sqrt()
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()
    return path_penalty, path_mean.detach(), path_lengths

# Pool of label vectors that get_random_labels draws conditioning labels from.
with open('sample_labels.json', 'r') as f:
    sample_labels = json.loads(f.read())
total_samples = len(sample_labels)
def get_random_labels(batch, device):
    """Draw ``batch`` label vectors uniformly, with replacement, from ``sample_labels``.

    Returns:
        A float tensor on ``device`` built from the chosen label lists.
    """
    chosen = [
        sample_labels[random.randint(0, total_samples - 1)]
        for _ in range(batch)
    ]
    return torch.Tensor(chosen).float().to(device)

def train(loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    """Run the conditional GAN training loop for steps ``start_iter`` .. ``iteration`` (inclusive).

    Every step:
      * updates the discriminator with the logistic loss on real vs. generated
        images, plus lazy R1 regularization every ``d_reg_every`` steps;
      * updates the generator with the non-saturating loss, plus lazy path length
        regularization every ``g_reg_every`` steps;
      * blends ``generator`` into ``g_ema`` via ``accumulate``.

    Hyperparameters (``iteration``, ``start_iter``, ``batch``, ``latent``, ``mixing``,
    ``r1``, ``path_regularize``, ...) are read from notebook-level globals at call time.

    Args:
        loader: iterable of ``(images, labels)`` batches, cycled forever via ``sample_data``.
        generator: generator being trained.
        discriminator: discriminator being trained.
        g_optim: optimizer for ``generator``.
        d_optim: optimizer for ``discriminator``.
        g_ema: moving-average copy of the generator, updated in place.
        device: device that batches, noise and labels are moved to.

    Side effects:
        Saves a sample grid to ``sample/`` every 100 steps. Saves a checkpoint to
        ``checkpoint/`` every 1000 steps and at the final step. Leaves the
        discriminator with ``requires_grad=False`` and the generator with ``True``.
    """
    loader = sample_data(loader)

    # torch.save / save_image do not create missing directories.
    os.makedirs('sample', exist_ok=True)
    os.makedirs('checkpoint', exist_ok=True)

    # Iterate over absolute step indices so the final step (i == iteration) is
    # reached on fresh runs too. The old `range(iteration)` + offset stopped at
    # iteration - 1 when start_iter == 0, so the final checkpoint was never saved.
    pbar = tqdm(
        range(start_iter, iteration + 1),
        initial=start_iter,
        total=iteration + 1,
        dynamic_ncols=True,
        smoothing=0.01,
    )

    mean_path_length = 0

    # Values reported on steps where the lazy regularizers are skipped.
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    g_module = generator
    d_module = discriminator

    # Per-step EMA decay for g_ema.
    # NOTE(review): 32 looks like a hard-coded batch size carried over from a
    # reference implementation. With batch = 16 the EMA half-life, measured in
    # images, is halved. Confirm whether `batch` was intended here.
    accum = 0.5 ** (32 / (10 * 1000))

    # Fixed latents/labels so the periodic sample grids are comparable across steps.
    sample_z = torch.randn(n_sample, latent, device=device)
    sample_l = get_random_labels(n_sample, device=device)

    for i in pbar:
        real_img, real_lb = next(loader)
        real_img = real_img.to(device)
        real_lb = real_lb.to(device)

        # ---- Discriminator step ------------------------------------------------
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(batch, latent, mixing, device)
        fake_lb = get_random_labels(batch, device)
        fake_img, _ = generator(noise, fake_lb)

        fake_pred = discriminator(fake_img, fake_lb)
        real_pred = discriminator(real_img, real_lb)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict['d'] = d_loss
        loss_dict['real_score'] = real_pred.mean()
        loss_dict['fake_score'] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        d_regularize = i % d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img, real_lb)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Scaled by d_reg_every because the penalty is only applied every
            # d_reg_every steps. `0 * real_pred[0]` keeps real_pred in the graph.
            (r1 / 2 * r1_loss * d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict['r1'] = r1_loss

        # ---- Generator step ----------------------------------------------------
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(batch, latent, mixing, device)
        fake_lb = get_random_labels(batch, device)
        fake_img, _ = generator(noise, fake_lb)
        fake_pred = discriminator(fake_img, fake_lb)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict['g'] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        g_regularize = i % g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, batch // path_batch_shrink)
            noise = mixing_noise(
                path_batch_size, latent, mixing, device
            )
            fake_lb = get_random_labels(path_batch_size, device)
            fake_img, latents = generator(noise, fake_lb, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = path_regularize * g_reg_every * path_loss

            if path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (mean_path_length.item())

        loss_dict['path'] = path_loss
        loss_dict['path_length'] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        # ---- Logging -----------------------------------------------------------
        d_loss_val = loss_dict['d'].mean().item()
        g_loss_val = loss_dict['g'].mean().item()
        r1_val = loss_dict['r1'].mean().item()
        path_loss_val = loss_dict['path'].mean().item()
        real_score_val = loss_dict['real_score'].mean().item()
        fake_score_val = loss_dict['fake_score'].mean().item()
        path_length_val = loss_dict['path_length'].mean().item()

        pbar.set_description(
            (
                f'd: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; '
                f'path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}'
            )
        )

        if i % 100 == 0:
            with torch.no_grad():
                # Preview grids come from the raw generator, not g_ema.
                sample, _ = generator([sample_z], sample_l)
                # NOTE: torchvision >= 0.10 renamed `range` to `value_range`, and
                # newer releases reject `range`. Switch when upgrading.
                utils.save_image(
                    sample,
                    f'sample/{str(i).zfill(6)}.png',
                    nrow=int(n_sample ** 0.5),
                    normalize=True,
                    range=(-1, 1),
                )

        if (i % 1000 == 0) or (i == iteration):
            torch.save(
                {
                    'g': g_module.state_dict(),
                    'd': d_module.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'g_optim': g_optim.state_dict(),
                    'd_optim': d_optim.state_dict(),
                },
                f'checkpoint/{str(i).zfill(6)}.pt',
            )

    print('Done!')

In [5]:
# ---- Models, optimizers, optional checkpoint, data -----------------------------
generator = Generator(
    img_size,
    lb_size,
    latent,
    n_mlp,
    channel_multiplier=channel_multiplier
).to(device)
discriminator = Discriminator(
    img_size,
    lb_size,
    channel_multiplier=channel_multiplier
).to(device)
# Moving-average copy of the generator: starts as an exact copy (decay=0) and stays in eval mode.
g_ema = Generator(
    img_size,
    lb_size,
    latent,
    n_mlp,
    channel_multiplier=channel_multiplier
).to(device)
g_ema.eval()
accumulate(g_ema, generator, 0)

# Lazy-regularization correction: the learning rate is multiplied by N/(N+1)
# and the Adam betas are raised to that power, where N is the regularization interval.
g_reg_ratio = g_reg_every / (g_reg_every + 1)
d_reg_ratio = d_reg_every / (d_reg_every + 1)

g_optim = optim.Adam(
    generator.parameters(),
    lr=lr * g_reg_ratio,
    betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
)
d_optim = optim.Adam(
    discriminator.parameters(),
    lr=lr * d_reg_ratio,
    betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
)

if ckpt is not None:
    print('load model:', ckpt)

    # Load into a separate name so `ckpt` keeps holding the path. Rebinding `ckpt`
    # to the loaded dict made this cell crash when re-run (torch.load on a dict).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt_state = torch.load(ckpt, map_location=lambda storage, loc: storage)

    # Resume the iteration counter from a numeric file name such as 004500.pt.
    # NOTE(review): checkpoints are written after iteration N finishes, so
    # resuming at N repeats that iteration once. Consider N + 1.
    try:
        ckpt_name = os.path.basename(ckpt)
        start_iter = int(os.path.splitext(ckpt_name)[0])
    except ValueError:
        pass

    generator.load_state_dict(ckpt_state['g'])
    discriminator.load_state_dict(ckpt_state['d'])
    g_ema.load_state_dict(ckpt_state['g_ema'])

    g_optim.load_state_dict(ckpt_state['g_optim'])
    d_optim.load_state_dict(ckpt_state['d_optim'])


transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        # Follow the configured model size instead of a hard-coded 128.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel, matching the (-1, 1) range used when saving samples.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    ]
)

dataset = PortraitDataset(
    path,
    transform
)
loader = data.DataLoader(
    dataset,
    batch_size=batch,
    shuffle=True,
    drop_last=True
)

load model: checkpoint/004500.pt


In [11]:
# Train (or resume): `start_iter` was set from the checkpoint file name when one was loaded.
train(
    loader=loader,
    generator=generator,
    discriminator=discriminator,
    g_optim=g_optim,
    d_optim=d_optim,
    g_ema=g_ema,
    device=device,
)

d: 0.9280; g: 1.3966; r1: 0.0056; path: 0.0046; mean path: 0.3242: : 4501it [7:40:05,  7.67s/it]                          

Done!





In [8]:
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: a ``torch.nn.Module``.
        trainable_only: if True (the default, matching the original behavior), count
            only parameters with ``requires_grad=True``; if False, count every parameter.
    """
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )


# Report the total size. `train` finishes with requires_grad(discriminator, False),
# so a trainable-only count run after training would report 0 for the discriminator.
print('parameters number of generator:', count_parameters(generator, trainable_only=False))
print('parameters number of discriminator:', count_parameters(discriminator, trainable_only=False))

parameters number of generator: 1195277
parameters number of discriminator: 1210000


In [10]:
# Truncated sampling (truncation = 0.7) from the EMA generator. The truncation
# target comes from g_ema.mean_latent over 4096 random labels (presumably an
# average of mapped latents — see model.Generator).
z = torch.randn(n_sample, latent, device=device)
l = get_random_labels(n_sample, device=device)

with torch.no_grad():
    mean_latent = g_ema.mean_latent(
        4096,
        get_random_labels(4096, device=device),
    )
    g_ema.eval()
    sample, _ = g_ema(
        [z],
        l,
        truncation_latent=mean_latent,
        truncation=0.7,
    )
    utils.save_image(
        sample,
        f'sample/{80000:06d}.png',
        normalize=True,
        range=(-1, 1),
        nrow=int(n_sample ** 0.5),
    )

In [14]:
# Debug helper data: tag names, where tags[i] names label position i.
with open('tags.json', 'r') as f:
    tags = json.loads(f.read())

def convert2tags(l):
    """Return the names (from ``tags``) of every label position whose value in ``l`` exceeds 0.5."""
    return [tags[position] for position, score in enumerate(l) if score > 0.5]

list(map(convert2tags, l[:16]))

[['blush', 'brown_hair', 'hat', 'red_eyes', 'ribbon', 'twintails'],
 ['brown_eyes',
  'brown_hair',
  'open_mouth',
  'red_eyes',
  'ribbon',
  'smile',
  'twintails'],
 ['blonde_hair', 'hat', 'purple_eyes', 'purple_hair', 'ribbon', 'smile'],
 ['animal_ears', 'brown_hair', 'green_eyes', 'ribbon', 'smile'],
 ['blonde_hair', 'blush', 'green_eyes', 'hat', 'smile'],
 ['bangs', 'pink_hair', 'purple_eyes', 'purple_hair', 'smile'],
 ['animal_ears', 'blue_eyes', 'blush', 'hair_ornament', 'smile'],
 ['brown_eyes', 'brown_hair', 'red_eyes'],
 ['black_hair',
  'blush',
  'hair_ornament',
  'open_mouth',
  'purple_hair',
  'smile'],
 ['blue_eyes',
  'blue_hair',
  'blush',
  'green_hair',
  'open_mouth',
  'red_eyes',
  'smile'],
 ['brown_eyes', 'hat', 'silver_hair'],
 ['animal_ears',
  'bangs',
  'blush',
  'green_eyes',
  'green_hair',
  'hair_ornament',
  'open_mouth',
  'smile',
  'yellow_eyes'],
 ['blush',
  'brown_eyes',
  'brown_hair',
  'hair_ornament',
  'open_mouth',
  'ribbon',
  'smile