In [2]:
import numpy as np
import os
from PIL import Image
import argparse

from semanticGAN_code.models.stylegan2_seg import GeneratorSeg
from semanticGAN_code.models.encoder_model import FPNEncoder

import torch
from torch import optim
import torch.nn.functional as F
import math
from torchvision import transforms
from semanticGAN_code.models import lpips

In [3]:
def mask2rgb(args, mask):
    """Turn a tensor of class indices into an RGB tensor via a fixed colour table.

    Args:
        args: object exposing ``dataset_name``; only ``'celeba-mask'`` has a
            colour table (8 rows, so valid indices are 0..7).
        mask: integer index tensor with three dimensions (the lookup adds a
            trailing colour dimension which is then moved to position 1) --
            presumably laid out as (batch, height, width); confirm with caller.

    Returns:
        Float tensor on the same device as ``mask`` with the colour channel at
        dimension 1 and values in the 0-255 range.

    Raises:
        Exception: if ``args.dataset_name`` has no colour table.
    """
    if args.dataset_name != 'celeba-mask':
        raise Exception('No such a dataloader!')

    # Build the table on the mask's device: F.embedding requires the indices
    # and the weight to live on the same device, and masks produced from a
    # CUDA model would otherwise fail against a CPU table.
    color_table = torch.tensor(
        [
            [0, 0, 0],
            [0, 0, 205],
            [132, 112, 255],
            [25, 25, 112],
            [187, 255, 255],
            [102, 205, 170],
            [227, 207, 87],
            [142, 142, 56],
        ],
        dtype=torch.float,
        device=mask.device,
    )

    # Embedding lookup appends the RGB triple as the last dimension; move it
    # to dimension 1 so the result is channel-first.
    rgb_tensor = F.embedding(mask, color_table).permute(0, 3, 1, 2)
    return rgb_tensor

def make_mask(args, tensor, threshold=0.5):
    """Convert raw segmentation output into a channel-last uint8 NumPy mask.

    Args:
        args: object exposing ``seg_dim`` (and whatever ``mask2rgb`` reads
            when ``seg_dim`` is not 1).
        tensor: 4-D segmentation output with the class/channel axis at
            dimension 1 (the result is permuted with ``(0, 2, 3, 1)``).
        threshold: probability cut-off used only in the single-channel case.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with the channel axis last. For
        ``seg_dim == 1`` the values are 0 or 255; otherwise they are the RGB
        colours produced by ``mask2rgb`` for the arg-max class.
    """
    if args.seg_dim == 1:
        # Binary case: sigmoid probability strictly above the threshold is
        # foreground (1 -> 255), everything else background (0).
        foreground = torch.sigmoid(tensor) > threshold
        mask_cpu = foreground.to(dtype=tensor.dtype).to('cpu').mul(255)
    else:
        # Multi-class case: pick the most likely class per pixel and colour it.
        class_index = torch.argmax(tensor, dim=1)
        mask_cpu = mask2rgb(args, class_index).to('cpu')

    # Shared tail: 8-bit values, channel-first -> channel-last, then NumPy.
    return mask_cpu.type(torch.uint8).permute(0, 2, 3, 1).numpy()

def make_image(tensor):
    """Convert a channel-first image tensor in [-1, 1] to a uint8 NumPy array.

    Values outside [-1, 1] are clamped, the range is rescaled to [0, 255],
    and the channel axis is moved last (``permute(0, 2, 3, 1)``), so the
    input must be 4-D.

    The input tensor is left untouched: ``detach()`` returns a view sharing
    the caller's storage, so the clamp must be out-of-place or it would
    silently overwrite the caller's data.

    Args:
        tensor: 4-D tensor with the channel axis at dimension 1.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with the channel axis last.
    """
    return (
        tensor.detach()
        .clamp(min=-1, max=1)   # out-of-place: do not mutate the caller's tensor
        .add(1)
        .div(2)
        .mul(255)
        .type(torch.uint8)      # truncates toward zero (e.g. 127.5 -> 127)
        .permute(0, 2, 3, 1)
        .to('cpu')
        .numpy()
    )


def overlay_img_and_mask(args, img_pil, mask_pil, alpha=0.3):
    """Blend a mask on top of an image.

    Both inputs are converted to RGBA and combined with ``Image.blend``.

    Args:
        args: not used by this function; kept in the signature for callers.
        img_pil: image object providing ``convert``; first blend operand.
        mask_pil: image object providing ``convert``; second blend operand.
        alpha: blend factor forwarded unchanged to ``Image.blend``.

    Returns:
        The blended image returned by ``Image.blend``.
    """
    rgba_image, rgba_mask = (layer.convert('RGBA') for layer in (img_pil, mask_pil))
    return Image.blend(rgba_image, rgba_mask, alpha)

def noise_regularize(noises):
    """Penalise spatial self-correlation of noise maps at several resolutions.

    For every noise tensor, the squared mean of the product between the map
    and a copy of itself rolled by one position (along dimension 3 and along
    dimension 2) is accumulated. The map is then 2x2 average-pooled and the
    process repeats until the side length is 8 or less.

    Args:
        noises: iterable of 4-D tensors. The side length is read from
            dimension 2 and reused for dimension 3, so maps must be square,
            and the side must stay even while it is larger than 8. Any
            leading batch/channel extent is supported; the maps are folded
            into the leading dimension when downsampling.

    Returns:
        Scalar tensor with the accumulated penalty, or the int ``0`` when
        ``noises`` is empty.
    """
    loss = 0

    for noise in noises:
        size = noise.shape[2]

        while True:
            loss = (
                loss
                + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
                + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
            )

            if size <= 8:
                break

            # 2x2 average pooling expressed as reshape + mean. The leading
            # dimension is inferred (-1) instead of being hard-coded to 1 so
            # that batches of noise maps do not fail to reshape.
            noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
            noise = noise.mean([3, 5])
            size //= 2

    return loss


def noise_normalize_(noises):
    """Standardise every noise tensor in place.

    Each tensor is shifted by its own mean and divided by its own standard
    deviation (as computed by ``Tensor.std()``). The update is applied
    through ``.data``, so it modifies the stored values directly.

    Args:
        noises: iterable of tensors to normalise in place.

    Returns:
        None.
    """
    for current in noises:
        # Statistics are taken before the in-place update below.
        center, spread = current.mean(), current.std()
        current.data.sub_(center).div_(spread)

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning rate with a linear ramp-up and a cosine ramp-down.

    Args:
        t: schedule position -- presumably the fraction of optimisation
            completed, in [0, 1]; confirm with the caller.
        initial_lr: peak learning rate.
        rampdown: width (in units of ``t``) of the final cosine decay; the
            decay factor is below 1 only while ``1 - t < rampdown``.
        rampup: width (in units of ``t``) of the initial linear warm-up; the
            warm-up factor is below 1 only while ``t < rampup``.

    Returns:
        ``initial_lr`` scaled by the product of both factors.
    """
    # Cosine decay: 1 until the last `rampdown` stretch, then eases to 0 at t == 1.
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)

    # Linear warm-up from 0 at t == 0 to 1 at t == rampup.
    warmup_factor = min(1, t / rampup)

    return initial_lr * (cosine_factor * warmup_factor)

def get_transformation(args):
    """Return the preprocessing pipeline for ``args.dataset_name``.

    * ``'celeba-mask'`` and ``'isic'``: ``ToTensor`` followed by a
      three-channel ``Normalize`` with mean 0.5 and std 0.5.
    * ``'cxr'``: ``HistogramEqualization``, ``AdjustGamma(0.5)``,
      ``ToTensor`` and a single-channel ``Normalize`` with mean 0.5 / std 0.5.

    Args:
        args: object exposing ``dataset_name``.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        Exception: for any other dataset name.
    """
    dataset = args.dataset_name

    if dataset in ('celeba-mask', 'isic'):
        # Both three-channel datasets share exactly the same pipeline.
        half = (0.5, 0.5, 0.5)
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(half, half),
        ])

    if dataset == 'cxr':
        # NOTE(review): HistogramEqualization and AdjustGamma are not defined
        # or imported in the visible part of this notebook -- confirm they are
        # brought into scope elsewhere, otherwise this branch raises NameError.
        return transforms.Compose([
            HistogramEqualization(),
            AdjustGamma(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    raise Exception('No such a dataloader!')

In [4]:
def _build_parser():
    """Create the argument parser holding every option of this notebook.

    Options are registered in a fixed order (it determines the ``--help``
    listing); defaults and help strings are defined here and nowhere else.
    """
    cli = argparse.ArgumentParser()

    def option(flag, kind, default, help_text=None, **extra):
        # Typed option that always has a default value.
        cli.add_argument(flag, type=kind, default=default, help=help_text, **extra)

    def switch(flag, help_text=None):
        # Boolean flag: False unless present on the command line.
        cli.add_argument(flag, action='store_true', help=help_text)

    # Dataset and model configuration.
    option('--dataset_name', str, 'celeba-mask',
           'segmentation dataloader name [celeba-mask|cxr|isic]')
    option('--size', int, 256)
    option('--channel_multiplier', int, 2)
    option('--image_mode', str, 'RGB')
    option('--batch_size', int, 1)
    option('--seg_dim', int, 8)
    option('--gpu_ids', int, [0], nargs='+')

    # Latent-code handling.
    switch('--mean_init', 'initialize latent code with mean')
    switch('--no_noises')
    switch('--w_plus', 'optimize in w+ space, otherwise w space')

    # What to write out.
    switch('--save_latent')
    switch('--save_steps', 'if to save intermediate optimization results')

    # Truncation.
    option('--truncation', float, 1,
           'truncation tricky, trade-off between quality and diversity')
    option('--truncation_mean', int, 4096)

    # Optimisation schedule, noise settings and loss weights.
    option('--lr_rampup', float, 0.05)
    option('--lr_rampdown', float, 0.25)
    option('--lr', float, 0.1)
    option('--noise', float, 0.05)
    option('--noise_ramp', float, 0.75)
    option('--step', int, 100,
           'optimization steps [100-500 should give good results]')
    option('--noise_regularize', float, 1e2)
    option('--lambda_mse', float, 0.1)
    option('--lambda_mean', float, 0.01)
    option('--lambda_label', float, 1.0)
    option('--lambda_encoder', float, 1e-3)
    option('--lambda_encoder_init', float, 0.0)

    return cli


parser = _build_parser()

# parse_known_args() collects unrecognised command-line entries in `unknown`
# instead of aborting -- presumably so the arguments passed to the notebook
# kernel do not break parsing.
args, unknown = parser.parse_known_args()

In [5]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA instead of failing on `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Generator settings that are not exposed as command-line options; they are
# passed positionally to GeneratorSeg further down.
args.latent = 512
args.n_mlp = 8

# Not used in the visible part of the notebook -- presumably filled in later
# when truncation is enabled; confirm against the cells that follow.
tru_mean_latent = None

# Forwarded to FPNEncoder as `input_dim`.
d_input_dim = 3

In [6]:
# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint saved from a GPU can still be opened on a host without that device.
checkpoint = torch.load('gan_enc_14k.pt', map_location=device)

In [7]:
# Segmentation generator, configured from the parsed options plus the
# latent size / MLP depth set in the configuration cell.
g_ema = GeneratorSeg(
    args.size,
    args.latent,
    args.n_mlp,
    seg_dim=args.seg_dim,
    image_mode=args.image_mode,
    channel_multiplier=args.channel_multiplier,
)
g_ema = g_ema.to(device)
# strict=False: keys missing from, or unexpected in, the checkpoint are
# tolerated rather than raising.
g_ema.load_state_dict(checkpoint['g_ema'], strict=False)
g_ema.eval()

# Encoder whose number of latents is taken from the generator.
encoder = FPNEncoder(input_dim=d_input_dim, n_latent=g_ema.n_latent)
encoder = encoder.to(device)
encoder.load_state_dict(checkpoint['e'])
# Kept as the final expression so the cell still displays the module summary.
encoder.eval()

FPNEncoder(
  (FPN_module): FPN(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_sta