# Colab-AdaptiveGAN

Original repo: [GuardSkill/AdaptiveGAN](https://github.com/GuardSkill/AdaptiveGAN)

Differentiable Augmentation: [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans)

My fork: [styler00dollar/Colab-AdaptiveGAN](https://github.com/styler00dollar/Colab-AdaptiveGAN)

In [None]:
# Print the NVIDIA driver / GPU status of this runtime.
!nvidia-smi

In [None]:
#@title git clone and install
# Later cells refer to /content/AdaptiveGAN by absolute path, so anchor the
# clone in /content instead of whatever the current directory happens to be.
%cd /content
!git clone https://github.com/GuardSkill/AdaptiveGAN
%cd /content/AdaptiveGAN
!pip3 install -r requirements.txt
# -p keeps the cell re-runnable when the directory already exists.
!mkdir -p /content/model-checkpoints

In [None]:
%%writefile /content/model-checkpoints/config.yml
#@title config.yml
MODE: 1             # 1: train, 2: test, 3: eval
MASK: 3             # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half), 6: one-to-one image mask
SEED: 10            # random seed
GPU: [0]            # list of gpu ids
DEBUG: 0            # turns on debugging mode
VERBOSE: 0          # turns on verbose mode in the output console

TRAIN_FLIST: "/content/train/train.tflist"
VAL_FLIST: "/content/val/val.tflist"
TEST_FLIST: "/content/val/val.tflist"

TRAIN_MASK_FLIST: "/content/mask_train/mask_train.tflist"
VAL_MASK_FLIST: "/content/mask_val/mask_val.tflist"
TEST_MASK_FLIST: "/content/mask_val/mask_val.tflist"

BLOCKS: 4                     # number of res blocks in each stage
LR: 1e-4                      # learning rate
D2G_LR: 0.1                   # discriminator/generator learning rate ratio
BETA1: 0.0                    # adam optimizer beta1
BETA2: 0.9                    # adam optimizer beta2
BATCH_SIZE: 1                 # input batch size for training #6
INPUT_SIZE: 256               # input image size for training (0 for original size)
MAX_ITERS: 2e6                # maximum number of iterations to train the model
MAX_STEPS: 5000               # maximum number of steps in each epoch
MAX_EPOCHES: 100              # maximum number of epochs
LOADWITHEPOCH: 1              # 1: restore the saved epoch number when loading a model

L1_LOSS_WEIGHT: 1             # l1 loss weight
FM_LOSS_WEIGHT: 10            # feature-matching loss weight
STYLE_LOSS_WEIGHT: 250        # style loss weight
CONTENT_LOSS_WEIGHT: 0.1      # perceptual loss weight
INPAINT_ADV_LOSS_WEIGHT: 0.1  # adversarial loss weight

GAN_LOSS: nsgan               # nsgan | lsgan | hinge
GAN_POOL_SIZE: 0              # fake images pool size

SAVE_INTERVAL: 1000           # how many iterations to wait before saving model (0: never)
SAMPLE_INTERVAL: 1000         # how many iterations to wait before sampling (0: never)
SAMPLE_SIZE: 1                # number of images to sample #12
EVAL_INTERVAL: 20             # How many INTERVAL sample while valuation  (0: never  36000 in places)
LOG_INTERVAL: 10              # how many iterations to wait before logging training status (0: never)

# Training

In [None]:
#@title Differentiable Augmentation (experimental)
%%writefile /content/AdaptiveGAN/src/models.py

# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    """Run the augmentations named in ``policy`` over the batch ``x``.

    ``policy`` is a comma-separated list of keys into ``AUGMENT_FNS``; every
    function registered under a key is applied in list order. An empty policy
    returns ``x`` itself, untouched. When ``channels_first`` is false the
    tensor is permuted with (0, 3, 1, 2) before augmenting and with
    (0, 2, 3, 1) afterwards. A key that is not in ``AUGMENT_FNS`` raises
    ``KeyError``.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)

    for key in policy.split(','):
        for augment in AUGMENT_FNS[key]:
            x = augment(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Shift every sample of ``x`` by its own random offset in [-0.5, 0.5).

    One offset is drawn per entry along dim 0 and broadcast over the
    remaining three dims.
    """
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its dim-1 mean by a factor in [0, 2).

    The mean is taken over dim 1 only; one random factor is drawn per entry
    along dim 0 and broadcast over the other dims.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a factor in [0.5, 1.5).

    The mean is taken over dims 1, 2 and 3 together; one random factor is
    drawn per entry along dim 0.
    """
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset along dims 2 and 3.

    The largest shift per dim is ``int(size * ratio + 0.5)``. ``x`` is padded
    with a one-element zero border on dims 2 and 3 and the gather indices are
    clamped onto that border, so positions shifted in from outside read zero.
    The output has the same shape as ``x``.
    """
    batch, rows, cols = x.size(0), x.size(2), x.size(3)
    max_shift_r = int(rows * ratio + 0.5)
    max_shift_c = int(cols * ratio + 0.5)

    # One offset per sample; the row offset is drawn before the column offset.
    shift_r = torch.randint(-max_shift_r, max_shift_r + 1, size=[batch, 1, 1], device=x.device)
    shift_c = torch.randint(-max_shift_c, max_shift_c + 1, size=[batch, 1, 1], device=x.device)

    idx_b, idx_r, idx_c = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(rows, dtype=torch.long, device=x.device),
        torch.arange(cols, dtype=torch.long, device=x.device),
    )
    # +1 accounts for the padding border added below.
    idx_r = (idx_r + shift_r + 1).clamp(0, rows + 1)
    idx_c = (idx_c + shift_c + 1).clamp(0, cols + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    return channels_last[idx_b, idx_r, idx_c].permute(0, 3, 1, 2)


def rand_cutout(x, ratio=0.5):
    """Zero one random rectangle per sample, shared across dim 1.

    The rectangle measures ``int(size * ratio + 0.5)`` along dims 2 and 3.
    Its position is drawn per sample and its indices are clamped to the
    tensor bounds, so a rectangle near an edge covers fewer positions.
    """
    batch, rows, cols = x.size(0), x.size(2), x.size(3)
    cut_r = int(rows * ratio + 0.5)
    cut_c = int(cols * ratio + 0.5)

    # Row offset is drawn before the column offset.
    start_r = torch.randint(0, rows + (1 - cut_r % 2), size=[batch, 1, 1], device=x.device)
    start_c = torch.randint(0, cols + (1 - cut_c % 2), size=[batch, 1, 1], device=x.device)

    idx_b, idx_r, idx_c = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_r, dtype=torch.long, device=x.device),
        torch.arange(cut_c, dtype=torch.long, device=x.device),
    )
    idx_r = (idx_r + start_r - cut_r // 2).clamp(min=0, max=rows - 1)
    idx_c = (idx_c + start_c - cut_c // 2).clamp(min=0, max=cols - 1)

    keep = torch.ones(batch, rows, cols, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_r, idx_c] = 0
    return x * keep.unsqueeze(1)


# Registry mapping each policy keyword to the augmentation functions that
# DiffAugment applies, in list order, when that keyword appears in a policy.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

# Module-level policy string passed to every DiffAugment call in
# InpaintingModel.process.
policy = 'color,translation,cutout'

import os
import torch
import torch.nn as nn
import torch.optim as optim

from .networks import Discriminator
from .blocks import LinkNet, PyramidNet

from .loss import AdversarialLoss, PerceptualLoss, StyleLoss, GradientLoss


class BaseModel(nn.Module):
    """Checkpoint bookkeeping shared by generator/discriminator models.

    The class itself defines no networks: ``load`` and ``save`` operate on
    ``self.generator`` and ``self.discriminator``, which a subclass has to
    provide.
    """

    def __init__(self, name, config):
        """Store ``name``/``config`` and derive the weight file paths.

        The paths are ``<config.PATH>/<name>_gen.pth`` and
        ``<config.PATH>/<name>_dis.pth``.
        """
        super().__init__()

        self.name = name
        self.config = config
        self.iteration = 0
        self.epoch = None
        self.gen_weights_path = os.path.join(config.PATH, name + '_gen.pth')
        self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')

    def load(self):
        """Restore weights from the two weight paths when the files exist.

        The generator is loaded whenever its file is present; the saved epoch
        is restored only if ``config.LOADWITHEPOCH == 1``. The discriminator
        is loaded only in training mode (``config.MODE == 1``).
        """
        # Without CUDA, keep every tensor on the storage it was read into.
        load_kwargs = {}
        if not torch.cuda.is_available():
            load_kwargs['map_location'] = lambda storage, loc: storage

        if os.path.exists(self.gen_weights_path):
            print('Loading %s generator...' % self.name)
            gen_state = torch.load(self.gen_weights_path, **load_kwargs)
            self.generator.load_state_dict(gen_state['generator'])
            if self.config.LOADWITHEPOCH == 1:
                self.epoch = gen_state['epoch']

        # load discriminator only when training
        in_training_mode = self.config.MODE == 1
        if in_training_mode and os.path.exists(self.dis_weights_path):
            print('Loading %s discriminator...' % self.name)
            dis_state = torch.load(self.dis_weights_path, **load_kwargs)
            self.discriminator.load_state_dict(dis_state['discriminator'])

    def save(self, epoch):
        """Write epoch-tagged generator and discriminator checkpoints.

        Files are named ``<name>_<epoch>_gen.pth`` and
        ``<name>_<epoch>_dis.pth`` and placed next to the corresponding
        weight path. These names differ from the ones ``load`` reads.
        """
        print('\nsaving %s...\n' % self.name)

        gen_file = os.path.join(
            os.path.dirname(self.gen_weights_path),
            self.name + '_%d_gen.pth' % (epoch),
        )
        gen_state = {
            'generator': self.generator.state_dict(),
            'epoch': epoch,
        }
        torch.save(gen_state, gen_file)

        dis_file = os.path.join(
            os.path.dirname(self.dis_weights_path),
            self.name + '_%d_dis.pth' % (epoch),
        )
        dis_state = {'discriminator': self.discriminator.state_dict()}
        torch.save(dis_state, dis_file)


class InpaintingModel(BaseModel):
    """Inpainting generator, its discriminator, their losses and optimizers.

    ``process`` runs one forward pass and returns the losses; ``backward``
    turns those losses into one update of each Adam optimizer.
    """

    def __init__(self, config):
        """Build the networks, loss modules and optimizers from ``config``.

        Reads ``BLOCKS``, ``GAN_LOSS``, ``GPU``, ``LR``, ``D2G_LR``, ``BETA1``
        and ``BETA2`` here; ``PATH`` is read by ``BaseModel.__init__``.
        """
        super(InpaintingModel, self).__init__('InpaintingModel', config)

        # generator input: [rgb(3)]
        in_channel = 3
        generator = PyramidNet(in_channel, config.BLOCKS)
        # discriminator input: [rgb(3)]
        discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')

        params = sum(param.nelement() for param in generator.parameters())
        print("This Generative Model Total params: {}M /  {}K".format(round((params >> 10) / 1024, 2), params >> 10))
        params = sum(param.nelement() for param in discriminator.parameters())
        print("This Adversarial Model Total params: {}M /  {}K".format(round((params >> 10) / 1024, 2), params >> 10))

        l1_loss = nn.L1Loss()
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        # Device ids are renumbered from 0 because main.py restricts the
        # visible GPUs through os.environ (per the original author's note).
        gpus_list_0_start = list(range(len(config.GPU)))
        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, gpus_list_0_start)
            discriminator = nn.DataParallel(discriminator, gpus_list_0_start)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)
        self.add_module('l1_loss', l1_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        # The discriminator learning rate is the generator's scaled by D2G_LR.
        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    def process(self, images, masks):
        """Run one forward pass and compute all losses.

        Increments ``self.iteration``, zeroes both optimizers' gradients and
        evaluates every loss whose config weight is positive. No optimizer
        step is taken here; see ``backward``.

        Returns:
            ``(outputs, gen_loss, dis_loss, logs)`` where ``logs`` maps the
            short loss names to Python floats. Outside training mode every
            key in ``logs`` is prefixed with ``val_``.
        """
        self.iteration += 1

        # zero optimizers
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()

        # process outputs
        outputs = self(images, masks)
        gen_loss = 0

        # discriminator loss: real images versus detached generator outputs,
        # each passed through DiffAugment first.
        dis_input_real = DiffAugment(images, policy=policy)
        dis_input_fake = DiffAugment(outputs.detach(), policy=policy)

        dis_real, dis_real_feat = self.discriminator(dis_input_real)  # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)  # in: [rgb(3)]

        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss = (dis_real_loss + dis_fake_loss) / 2

        # The adversarial and the feature-matching terms both need the
        # discriminator's response to the non-detached output. Compute it
        # when either is enabled, so feature matching also works with the
        # adversarial weight set to 0 (this used to raise NameError).
        use_adv = self.config.INPAINT_ADV_LOSS_WEIGHT > 0
        use_fm = self.config.FM_LOSS_WEIGHT > 0
        gen_fake = None
        gen_fake_feat = None
        if use_adv or use_fm:
            gen_input_fake = DiffAugment(outputs, policy=policy)
            gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)  # in: [rgb(3)]

        # generator adversarial loss
        gen_gan_loss = torch.FloatTensor([0])
        if use_adv:
            gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
            gen_loss += gen_gan_loss

        # generator feature matching loss
        # NOTE(review): real and fake batches receive independently drawn
        # augmentations, so the compared features may not be spatially
        # aligned. Confirm this is intended.
        gen_fm_loss = torch.FloatTensor([0])
        if use_fm:
            gen_fm_loss = 0
            for fake_feat, real_feat in zip(gen_fake_feat, dis_real_feat):
                gen_fm_loss += self.l1_loss(fake_feat, real_feat.detach())
            gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
            gen_loss += gen_fm_loss

        # generator l1 loss, normalised by the mean of (1 - masks)
        # NOTE(review): this divides by zero when masks is all ones.
        gen_l1_loss = torch.FloatTensor([0])
        if self.config.L1_LOSS_WEIGHT > 0:
            gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(1 - masks)
            gen_loss += gen_l1_loss

        # generator perceptual loss
        gen_content_loss = torch.FloatTensor([0])
        if self.config.CONTENT_LOSS_WEIGHT > 0:
            gen_content_loss = self.perceptual_loss(outputs, images)
            gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
            gen_loss += gen_content_loss

        # generator style loss
        gen_style_loss = torch.FloatTensor([0])
        if self.config.STYLE_LOSS_WEIGHT > 0:
            gen_style_loss = self.style_loss(outputs, images)
            gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
            gen_loss += gen_style_loss

        # create logs
        logs = {
            "l_d2": dis_loss.item(),
            "l_g2": gen_gan_loss.item(),
            "l_l1": gen_l1_loss.item(),
            "l_fm": gen_fm_loss.item(),
            "l_per": gen_content_loss.item(),
            "l_sty": gen_style_loss.item(),
        }

        if not self.training:
            logs = {"val_" + key: value for key, value in logs.items()}
        return outputs, gen_loss, dis_loss, logs

    def forward(self, images, masks):
        """Multiply ``images`` by ``masks`` and run the generator on the result.

        Positions where ``masks`` is 0 are therefore zeroed in the generator
        input.
        """
        images_masked = images * masks.float()
        outputs = self.generator(images_masked)  # in: [rgb(3)]
        return outputs

    def backward(self, gen_loss=None, dis_loss=None):
        """Backpropagate the given losses and step the optimizers.

        All backward passes run before any optimizer step. ``gen_loss``
        backpropagates through the discriminator, and an optimizer step
        modifies parameters in place, so stepping the discriminator first
        invalidates the tensors autograd saved for ``gen_loss`` and raises a
        RuntimeError on current PyTorch. A loss passed as ``None`` is skipped.
        """
        if gen_loss is not None:
            gen_loss.backward()
            # gen_loss also left gradients on the discriminator parameters;
            # drop them so the discriminator is updated from dis_loss only.
            self.dis_optimizer.zero_grad()

        if dis_loss is not None:
            dis_loss.backward()
            self.dis_optimizer.step()

        if gen_loss is not None:
            self.gen_optimizer.step()



In [None]:
# Start training from the repository root; --checkpoints points at the
# directory created in the setup cell.
%cd /content/AdaptiveGAN
!python3 train.py --checkpoints /content/model-checkpoints

# Testing

In [None]:
%%writefile /content/model-checkpoints/config.yml
#@title config.yml (testing)
MODE: 2             # 1: train, 2: test, 3: eval
MASK: 3             # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half), 6: one-to-one image mask
SEED: 10            # random seed
GPU: [0]            # list of gpu ids
DEBUG: 0            # turns on debugging mode
VERBOSE: 0          # turns on verbose mode in the output console

TRAIN_FLIST: "/content/train/train.tflist"
VAL_FLIST: "/content/val/val.tflist"
TEST_FLIST: "/content/val/val.tflist"

TRAIN_MASK_FLIST: "/content/mask_train/mask_train.tflist"
VAL_MASK_FLIST: "/content/mask_val/mask_val.tflist"
TEST_MASK_FLIST: "/content/mask_val/mask_val.tflist"

BLOCKS: 4                     # number of res blocks in each stage
LR: 1e-4                      # learning rate
D2G_LR: 0.1                   # discriminator/generator learning rate ratio
BETA1: 0.0                    # adam optimizer beta1
BETA2: 0.9                    # adam optimizer beta2
BATCH_SIZE: 1                 # input batch size for training #6
INPUT_SIZE: 256               # input image size for training (0 for original size)
MAX_ITERS: 2e6                # maximum number of iterations to train the model
MAX_STEPS: 5000               # maximum number of steps in each epoch
MAX_EPOCHES: 100              # maximum number of epochs
LOADWITHEPOCH: 1              # 1: restore the saved epoch number when loading a model

L1_LOSS_WEIGHT: 1             # l1 loss weight
FM_LOSS_WEIGHT: 10            # feature-matching loss weight
STYLE_LOSS_WEIGHT: 250        # style loss weight
CONTENT_LOSS_WEIGHT: 0.1      # perceptual loss weight
INPAINT_ADV_LOSS_WEIGHT: 0.1  # adversarial loss weight

GAN_LOSS: nsgan               # nsgan | lsgan | hinge
GAN_POOL_SIZE: 0              # fake images pool size

SAVE_INTERVAL: 1000           # how many iterations to wait before saving model (0: never)
SAMPLE_INTERVAL: 1000         # how many iterations to wait before sampling (0: never)
SAMPLE_SIZE: 1                # number of images to sample #12
EVAL_INTERVAL: 20             # How many INTERVAL sample while valuation  (0: never  36000 in places)
LOG_INTERVAL: 10              # how many iterations to wait before logging training status (0: never)

In [None]:
# Run test.py from the repository root on a single image/mask pair, using the
# same checkpoint directory that training was launched with.
%cd /content/AdaptiveGAN
!python3 test.py \
  --checkpoints /content/model-checkpoints \
  --input /content/0.jpg \
  --mask /content/mask.png \
  --output /content/output.png