In [None]:
!pip install git+https://github.com/didriknielsen/survae_flows.git

In [None]:
import numpy as np
from PIL import Image
import cv2
import math
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from survae.distributions import DataParallelDistribution
from survae.distributions import StandardNormal, ConditionalNormal, ConditionalBernoulli, Distribution, ConditionalDistribution
import torch
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from skimage import color
import zipfile
import os
import io
from PIL import Image
import cv2
from pathlib import Path
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import grad

from survae.distributions import StandardNormal, ConditionalNormal, ConditionalBernoulli, Distribution, ConditionalDistribution
from survae.utils import sum_except_batch, mean_except_batch
from torch.distributions import Normal
from torchvision.models import resnet50, resnet18
from torchvision.models.segmentation import fcn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.utils import load_state_dict_from_url

In [None]:
def linedraw(x):
    """Turn a batch of images into a single-channel edge ("line drawing") map.

    Dim 1 is averaged into one plane, that plane is dilated with a 3x3 max
    filter (stride 1, padding 1, so the spatial size is unchanged), and the
    absolute difference between the plane and its dilation is returned.

    Args:
        x: tensor whose dim 1 is averaged away; assumed to be an
           (N, C, H, W) image batch -- TODO confirm against callers.

    Returns:
        Tensor shaped like ``x`` but with dim 1 of size 1.
    """
    gray = torch.mean(x, dim=1, keepdim=True)
    local_max = F.max_pool2d(gray, 3, 1, 1)
    return (gray - local_max).abs()
# Per-channel mean / std shared by `normalize` and its inverse `unnormalize`.
# NOTE(review): these look like the usual ImageNet RGB statistics expected by
# torchvision's pretrained backbones -- confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# normalize: x -> (x - mean) / std
normalize = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)
# unnormalize: Normalize with mean' = -mean/std and std' = 1/std computes
# (x + mean/std) * std = x * std + mean, i.e. the inverse of `normalize`.
unnormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_NORM_MEAN, _NORM_STD)],
    std=[1 / s for s in _NORM_STD],
)

In [None]:
class ANIME(torchvision.datasets.VisionDataset):
    """Unlabelled image dataset read from PNG files below a root directory.

    Training samples are the ``*.png`` files inside sub-directories whose name
    starts with ``0``; the held-out samples are those inside sub-directories
    whose name starts with ``1``. Every image is converted to RGB, resized to
    ``img_size x img_size`` and turned into a tensor by ``ToTensor``.

    Args:
        img_size: side length the images are resized to.
        train: select the ``0*`` (True) or ``1*`` (False) sub-directories.
        root: dataset directory; defaults to the path previously hard-coded.
    """

    def __init__(self, img_size, train=True, root='../input/moeimouto-faces/moeimouto-faces'):
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()])
        super().__init__(root, transform=transform)
        self.img_size = img_size
        pattern = "0*/*.png" if train else "1*/*.png"
        # Path.glob yields entries in arbitrary (filesystem-dependent) order.
        # Sorting makes the index -> file mapping reproducible, which matters
        # for the unshuffled validation loader used to pick preview images.
        self.samples = sorted(Path(self.root).glob(pattern))

    def __getitem__(self, index: int):
        """Return the transformed image tensor at ``index`` (no label)."""
        img = Image.open(str(self.samples[index])).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self) -> int:
        """Number of image files found for the selected split."""
        return len(self.samples)

def get_data_loaders(batch_size, img_size=256):
    """Build the training and validation ``DataLoader`` pair.

    Args:
        batch_size: number of images per batch for both loaders.
        img_size: side length passed to ``ANIME`` for resizing.

    Returns:
        ``(tr_loader, va_loader)``; only the training loader shuffles. Both
        use two worker processes.
    """
    train_set = ANIME(img_size, train=True)
    tr_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_set = ANIME(img_size, train=False)
    va_loader = DataLoader(val_set, batch_size=batch_size, num_workers=2)
    return tr_loader, va_loader

In [None]:
class CustomizedResnet(nn.ModuleDict):
    """ResNet backbone adapted to 1-channel input with an optional FPN-style decoder.

    The network is created with ``build_fn(True)`` and its children are
    registered on this ``ModuleDict`` under their original names. ``conv1`` is
    converted to a single input channel by summing its weights over the
    input-channel dimension.

    Args:
        out_channels: number of output channels of the head.
        build_fn: ``resnet50`` or any other builder; every builder other than
            ``resnet50`` is assumed to have ResNet-18 channel widths.
        fpn: if True, fuse the layer2-4 features top-down and upsample by 8x
            to produce a spatial map; if False, apply a 1x1 conv to the
            globally pooled layer4 features.
        aux_in: channel count of an auxiliary input that is projected with a
            1x1 conv and added to the layer1 output. ``0`` disables it; when
            non-zero, ``forward`` expects an ``(x, z)`` pair.
        ori_out: accepted for interface compatibility; not used by this class.
        spp: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, out_channels=20, build_fn=resnet50, fpn=True, aux_in=0, ori_out=False, spp=False):
        model = build_fn(True)
        # Collapse the RGB stem to one input channel by summing its filters.
        model.conv1.in_channels = 1
        model.conv1.weight.data = model.conv1.weight.data.sum(1, keepdim=True)
        # nn.Module requires the parent constructor to run before attributes
        # are assigned on the instance, so plain attributes are set afterwards.
        super(CustomizedResnet, self).__init__(model.named_children())
        self.aux_in = aux_in
        self.fpn = fpn

        is_r50 = build_fn == resnet50
        # Output widths of layer1..layer4 for the two supported layouts.
        c1, c2, c3, c4 = (256, 512, 1024, 2048) if is_r50 else (64, 128, 256, 512)

        # NOTE: the construction order below determines how the RNG is
        # consumed during weight initialisation; keep it stable.
        if aux_in:
            self.decode = nn.Sequential(nn.Conv2d(aux_in, c1, 1))
        if fpn:
            fpn_dim = 256 if is_r50 else 64
            # 1x1 lateral projections of the layer4 / layer3 / layer2 features.
            self.out4 = nn.Conv2d(c4, fpn_dim, 1)
            self.out3 = nn.Conv2d(c3, fpn_dim, 1)
            self.out2 = nn.Conv2d(c2, fpn_dim, 1)

            # 3x3 smoothing convs applied after each top-down merge.
            self.up3 = nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1)
            self.up2 = nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1)
            # Three (conv, BN, ReLU, 2x upsample) stages halve the width each
            # time and enlarge the map 8x; a last conv maps to out_channels.
            out = [
                nn.Conv2d(fpn_dim, fpn_dim//2, 3, 1, 1), nn.BatchNorm2d(fpn_dim//2), nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(fpn_dim//2, fpn_dim//4, 3, 1, 1), nn.BatchNorm2d(fpn_dim//4), nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(fpn_dim//4, fpn_dim//8, 3, 1, 1), nn.BatchNorm2d(fpn_dim//8), nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(fpn_dim//8, out_channels, 3, 1, 1)
            ]
            self.out = nn.Sequential(*out)
        else:
            # Replaces the backbone's own `fc` entry with a 1x1 conv head.
            self.fc = nn.Conv2d(c4, out_channels, 1)

    def forward(self, x):
        """Run the backbone and the configured head.

        Args:
            x: the 1-channel input, or an ``(input, aux)`` pair when the
               module was built with a non-zero ``aux_in``.

        Returns:
            ``self.out(...)`` of the fused features when ``fpn`` is set,
            otherwise ``self.fc`` applied to the pooled layer4 features.
        """
        if self.aux_in:
            x, z = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        # Inject the projected auxiliary input right after layer1.
        x2 = self.layer2((x + self.decode(z)) if self.aux_in else x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        if self.fpn:
            # Top-down pathway: upsample the coarser map 2x and add the lateral.
            z4 = self.out4(x4)
            z3 = self.up3(self.out3(x3) + F.interpolate(z4, scale_factor=2))
            z2 = self.up2(self.out2(x2) + F.interpolate(z3, scale_factor=2))
            return self.out(z2)
        return self.fc(self.avgpool(x4))


class ConditionalNormalMean(ConditionalNormal):
    """``ConditionalNormal`` whose ``sample`` returns the mean for the context."""

    def sample(self, context):
        """Return ``self.mean(context)`` instead of drawing a random sample."""
        mu = self.mean(context)
        return mu

class VAE(Distribution):
    """Conditional VAE over images ``x`` given a line drawing ``l``.

    Holds a conditional ``prior`` over the latent, an ``encoder`` built on a
    ResNet-18 ``CustomizedResnet`` and a ``decoder`` built on the default
    ``CustomizedResnet``. The ``vae`` flag is stored but not read in this class.
    """

    def __init__(self, prior:ConditionalNormal, latent_size=20, vae=True):
        super().__init__()
        self.prior = prior
        self.vae = vae
        # Encoder net emits 2 * latent_size channels, split along dim 1 by
        # ConditionalNormal; its auxiliary input has 3 channels.
        encoder_net = CustomizedResnet(latent_size*2, resnet18, aux_in=3, fpn=False)
        self.encoder = ConditionalNormal(encoder_net, 1)
        # Decoder net emits 3 * 2 channels; its auxiliary input is the latent.
        decoder_net = CustomizedResnet(3*2, aux_in=latent_size, ori_out=True, spp=True)
        self.decoder = ConditionalNormalMean(decoder_net, 1)

    def log_prob(self, x, l):
        """Return ``log p(z|l) + log p(x|l,z) - log q(z|l,x)`` for one sampled ``z``."""
        pooled = F.avg_pool2d(x, 4, 4)  # 4x down-sampled image fed to the encoder
        z, log_qz = self.encoder.sample_with_log_prob(context=(l, pooled))
        log_px = self.decoder.log_prob(x, context=(l, z))
        log_pz = self.prior.log_prob(z, l)
        return log_pz + log_px - log_qz

    def sample(self, l, num_samples=1):
        """Sample a latent from the prior given ``l`` and decode it.

        ``num_samples`` is accepted but not used.
        """
        latent = self.prior.sample(l)
        return self.decoder.sample(context=(l, latent))

class Discriminator(nn.ModuleDict):
    """ResNet-18 (built with ``resnet18(True)``) with a single-output linear head."""

    def __init__(self):
        backbone = resnet18(True)
        super(Discriminator, self).__init__(backbone.named_children())
        # Swap the backbone's classifier for a 512 -> 1 linear layer.
        self.fc = nn.Linear(512, 1)

    def forward(self, x):
        """Return one score per batch element, shape ``(N, 1)``."""
        feats = x
        for stage in ('conv1', 'bn1', 'relu', 'maxpool',
                      'layer1', 'layer2', 'layer3', 'layer4'):
            feats = self[stage](feats)
        pooled = self.avgpool(feats)
        return self.fc(torch.flatten(pooled, 1))

class RejVAE(VAE):
    """``VAE`` whose objective adds a discriminator (``sampler``) term.

    ``log_prob`` returns the parent's value plus the difference between the
    discriminator score on real images and a log-mean-exp of its scores on
    generated images. A one-sided gradient penalty contributes gradient only.
    """

    def __init__(self, prior, latent_size=20, vae=True):
        super().__init__(prior, latent_size, vae)
        self.sampler = Discriminator()
        # Overwritten on every log_prob call with the mean discriminator
        # score on the generated batch.
        self.register_buffer('rej_prob', torch.tensor(0.5))

    def _grad_pen(self, G: torch.Tensor):
        """One-sided gradient penalty of ``self.sampler`` evaluated at ``G``.

        Args:
            G: batch of points at which the discriminator gradient is taken.
               It is detached first, so no gradient flows back into whatever
               produced it.

        Returns:
            Per-sample ``relu(||d sampler / d G||^2 - 1)``, differentiable
            with respect to the discriminator parameters.
        """
        points = G.detach().requires_grad_()
        lipschitz_grad = grad(
                outputs=self.sampler(points).sum(),
                inputs=points,
                create_graph=True,
                retain_graph=True)[0].reshape(points.shape[0], -1)
        return (torch.sum(lipschitz_grad**2, dim=1) - 1).relu()

    def log_prob(self, x, l):
        """Return the parent ``log_prob`` plus the discriminator term.

        Args:
            x: batch of (normalised) images.
            l: conditioning line drawings for the same batch.
        """
        log_posterior = self.sampler(x).flatten()
        G = super().sample(l)
        # Same value as G, but the gradient reaching the generator is scaled
        # by -0.01 (sign flipped and damped).
        G = 1.01 * G.detach() - 0.01 * G
        log_prior = self.sampler(G).flatten()
        # Penalise the discriminator gradient at random interpolates of
        # real and generated images.
        alpha = torch.rand((x.shape[0], 1, 1, 1), device=x.device)
        grad_pen = self._grad_pen(alpha * x + (1 - alpha) * G)

        self.rej_prob = log_prior.detach().mean()
        # log of the batch mean of exp(score) on generated images.
        log_prior = torch.logsumexp(log_prior - math.log(log_prior.size(0)), 0)
        # grad_pen.detach() - grad_pen is zero in value: the penalty only
        # shapes the gradient.
        rej_log_prob = log_posterior - log_prior + grad_pen.detach() - grad_pen
        # Value is super().log_prob + rej_log_prob; the gradient of the
        # discriminator term is amplified 20000x.
        return super().log_prob(x, l) + rej_log_prob * 20000 - rej_log_prob.detach() * 19999
def get_model(pretrained_backbone=True, vae=True, rej=True) -> VAE:
    """Build the model with a 64-channel latent and a ResNet-18 prior network.

    Args:
        pretrained_backbone: accepted but not used by this function.
        vae: forwarded to the model constructor.
        rej: build a ``RejVAE`` when true, otherwise a plain ``VAE``.
    """
    latent_size = 64
    prior_net = CustomizedResnet(latent_size * 2, resnet18, fpn=False)
    prior = ConditionalNormal(prior_net, 1)
    if rej:
        return RejVAE(prior, latent_size, vae=vae)
    return VAE(prior, latent_size, vae=vae)

In [None]:
from torch.optim.lr_scheduler import MultiStepLR

class LinearWarmupScheduler(MultiStepLR):
    """Linear learning-rate warm-up followed by a multi-step decay.

    For the first ``warmup`` steps the rate rises linearly from
    ``0.1 * base_lr`` towards ``base_lr``. From step ``warmup`` on it is
    ``base_lr * gamma ** k``, where ``k`` is the number of milestones that are
    less than or equal to the current step.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup: step at which the base learning rate is reached.
        milestones: step indices at which the rate is multiplied by ``gamma``.
        gamma: multiplicative decay factor applied at each milestone.
        verbose: forwarded to ``MultiStepLR`` only when truthy.
        last_epoch: index of the last step, ``-1`` to start fresh.
    """

    def __init__(self, optimizer, warmup, milestones, gamma=0.1, verbose=False, last_epoch=-1):
        # Must exist before the base constructor runs: it performs an initial
        # step() which already calls get_lr().
        self.warmup = warmup
        # `verbose` is deprecated in recent PyTorch releases; forward it only
        # when requested so the default path works across versions.
        extra = {'verbose': verbose} if verbose else {}
        super(LinearWarmupScheduler, self).__init__(optimizer, milestones, gamma, last_epoch, **extra)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        step = self.last_epoch
        if step < self.warmup:
            return [0.1 * base_lr + 0.9 * base_lr * step / self.warmup for base_lr in self.base_lrs]
        # Closed form instead of MultiStepLR.get_lr(): the parent derives the
        # rate from the optimizer's *current* value, which would freeze it at
        # the last warm-up value and never reach base_lr.
        passed = sum(count for milestone, count in self.milestones.items() if milestone <= step)
        return [base_lr * self.gamma ** passed for base_lr in self.base_lrs]

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import argparse
import os

# Device used for the model and the training batches: the GPU when CUDA is
# available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def log_img(model, args, wandb, writer):
    """Log preview images: line drawing (left) next to the model output (right).

    Reads the module-level ``X_test`` batch and, in tensorboard mode, the
    module-level step counter ``gIter``.

    Args:
        model: object exposing ``sample(l)``.
        args: configuration whose ``vis_mode`` selects the logging backend.
        wandb: the ``wandb`` module, used when ``vis_mode == 'wandb'``.
        writer: summary writer, used when ``vis_mode == 'tensorboard'``.
    """
    with torch.no_grad():
        l = linedraw(X_test)
        lab = unnormalize(model.sample(l))
        # Inverted 3-channel line drawing and the generated image, joined
        # along the width axis.
        img = torch.cat([1 - l.repeat(1, 3, 1, 1), lab], -1)
        if args.vis_mode == 'tensorboard':
            # `img` is already (N, C, H, W), the default layout of add_images.
            # Tensor.transpose only swaps two dims, so the former 4-argument
            # call raised. Clamp because float images are expected in [0, 1].
            writer.add_images("result", img.clamp(0, 1), gIter)
        elif args.vis_mode == 'wandb':
            wandb.log({'result': [wandb.Image(i) for i in img]})
        else:
            # NOTE(review): `save_plt_img` is not defined in the visible part
            # of this notebook -- confirm it exists before using this mode.
            save_plt_img(img, title='result')

class ARGS:
    """Run configuration with notebook defaults and optional overrides.

    ``ARGS()`` yields the defaults below. Any of them can be replaced by
    keyword, e.g. ``ARGS(lr=1e-4, vis_mode='tensorboard')``. Options are
    stored as instance attributes, so ``vars(args)`` lists all of them.

    Raises:
        TypeError: if an override names an option that does not exist.
    """

    def __init__(self, **overrides):
        self.batch_size = 24
        self.img_size = 256
        self.num_epoch = 64
        self.lr = 0.001
        self.warmup = 1000          # scheduler warm-up length, in steps
        self.vae = True
        self.rej = True             # build RejVAE instead of the plain VAE
        self.vis_mode = 'wandb'     # 'wandb' or 'tensorboard'
        self.param_path = 'models/'
        self.exp_name = 'vae'
        self.adam = False           # False selects SGD with momentum

        # Reject typos instead of silently creating a new, unused option.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"ARGS got unexpected option(s): {', '.join(unknown)}")
        vars(self).update(overrides)

In [None]:
############
##  Data  ##
############
# Hyper-parameters and paths for this run.
args = ARGS()
# Directory the model weights are saved to; created if it does not exist.
os.makedirs(args.param_path, exist_ok=True)

# Seed PyTorch's global RNG before the loaders (and later the model) are built.
torch.manual_seed(0)
tr_loader, va_loader = get_data_loaders(args.batch_size, args.img_size)

In [None]:
#############
##  Model  ##
#############

model = get_model(vae=args.vae, rej=args.rej).to(device)

# Optimizer choice is driven by args.adam: Adam, or SGD with momentum 0.9.
if args.adam:
    optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
else:
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
# Milestones are counted in iterations (epochs * batches per epoch): the rate
# is decayed after 70% and after 90% of the total number of iterations.
sched = LinearWarmupScheduler(optim, args.warmup, [
    args.num_epoch * 7 * len(tr_loader) // 10, args.num_epoch * 9 * len(tr_loader) // 10])

In [None]:
###############
##  Logging  ##
###############

# Exactly one of `writer` / `wandb` is active; the other is None so that
# log_img() can always be called with both names.
if args.vis_mode == 'tensorboard':
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(flush_secs=30)
    wandb = None
elif args.vis_mode == 'wandb':
    import wandb
    # Never hard-code credentials in a notebook. The key is read from the
    # WANDB_API_KEY environment variable; with key=None wandb falls back to
    # its own lookup / interactive prompt. The key that was previously
    # embedded here should be considered leaked and revoked.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    wandb.init(project='linedraw')
    wandb.config.update(args)
    wandb.watch(model)
    writer = None
else:
    # Any other mode: define both names so later cells do not hit NameError.
    wandb = None
    writer = None
# Global iteration counter shared by the training loop and log_img().
gIter = 0

# Fixed validation batch used for every image preview; kept on the same
# device as the model rather than unconditionally on 'cuda'.
X_test = next(iter(va_loader))
X_test = X_test.to(device)

In [None]:
# Iterations at which preview images are logged: every 100 steps up to 900,
# then every 1000 steps. A set keeps the per-step membership test O(1).
log_iters = {100 * i for i in range(1, 10)} | {1000 * i for i in range(1, 1000)}
for epoch in range(args.num_epoch):
    cum_loss = 0.0
    pbar = tqdm(tr_loader)
    for i, img in enumerate(pbar):
        img = img.to(device)
        # The line drawing is extracted from the un-normalised image.
        l = linedraw(img)
        img = normalize(img)
        # Negative log-likelihood averaged over the batch and scaled by the
        # number of values per image (3 * img_size^2).
        loss = -model.log_prob(img, l).mean() / (3 * args.img_size ** 2)
        optim.zero_grad()
        loss.backward()
        optim.step()
        # The scheduler is stepped once per batch, not once per epoch.
        sched.step()
        # Read the scalar once and log plain floats, so no graph-attached
        # tensor is handed to the loggers.
        loss_value = loss.item()
        cum_loss += loss_value
        pbar.set_description_str(f"Epoch {epoch}, nll {cum_loss / (i+1):.4f}")
        if args.vis_mode == 'tensorboard':
            writer.add_scalar("Train/nll", loss_value, gIter)
            if args.rej:
                writer.add_scalar("Train/rej", model.rej_prob.item(), gIter)
        elif args.vis_mode == 'wandb':
            logs = {"Train/nll": loss_value}
            if args.rej:
                logs["Train/rej"] = model.rej_prob.item()
            wandb.log(logs)
        if gIter in log_iters:
            log_img(model, args, wandb, writer)
        gIter += 1

#     # model.eval()
#     with torch.no_grad():
#         cum_loss = 0.0
#         pbar = tqdm(va_loader)
#         for i, img in enumerate(pbar):
#             img = img.to(device)
#             l = linedraw(img)
#             img = normalize(img)
#             loss = -model.log_prob(img, l).mean() / (3 * args.img_size ** 2)
#             cum_loss += loss.item()
#             pbar.set_description_str(f"Test nll {cum_loss / (i+1):.4f}")
#     if args.vis_mode == 'tensorboard':
#         writer.add_scalar("Val/nll", cum_loss / len(va_loader), gIter)
#     elif args.vis_mode == 'wandb':
#         wandb.log({"Val/nll": cum_loss / len(va_loader)})

    # End of epoch: log a preview and overwrite the checkpoint.
    log_img(model, args, wandb, writer)

    torch.save(model.state_dict(), os.path.join(args.param_path, args.exp_name+'_model.pt'))

In [None]:
!rm -r wandb


# 