# Colab-CSA (pytorch)
Original repo: [KumapowerLIU/CSA-inpainting](https://github.com/KumapowerLIU/CSA-inpainting)

Yukariin's repo: [Yukariin/CSA_pytorch](https://github.com/Yukariin/CSA_pytorch)
Thank you Yukariin for fixing the broken original.

Differentiable Augmentation: [mit-han-lab/data-efficient-gans](https://github.com/mit-han-lab/data-efficient-gans)

My fork: [styler00dollar/Colab-CSA-pytorch](https://github.com/styler00dollar/Colab-CSA-pytorch)

In [None]:
# Print the NVIDIA driver/GPU status for the current runtime.
!nvidia-smi

In [None]:
#@title Install
# Clone the CSA_pytorch sources into the current working directory.
# NOTE(review): later cells use /content/CSA_pytorch, so this assumes the
# notebook starts in /content — confirm.
!git clone https://github.com/Yukariin/CSA_pytorch
%cd /content/
# Download the Miniconda 4.5.4 installer (-c resumes a partial download),
# make it executable and install it into /usr/local
# (-b: batch mode, -f: do not error if the prefix exists, -p: install prefix).
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh
!chmod +x Miniconda3-4.5.4-Linux-x86_64.sh
!bash ./Miniconda3-4.5.4-Linux-x86_64.sh -b -f -p /usr/local
# Install PyTorch 1.1 together with cudatoolkit and torchvision from the
# pytorch channel, then ipykernel, into that conda installation.
!conda install pytorch==1.1 cudatoolkit torchvision -c pytorch -y
!conda install ipykernel -y
# Additional packages installed through pip.
!pip install tensorboardX
!pip install opencv-python

# Training

Warning: Checkpoints are saved into a subfolder named `ckpt` inside `save_dir`. If you only create `save_dir` and not `save_dir/ckpt`, training will crash when it tries to save.

In [None]:
#@title Create empty folders
# -p creates missing parent folders and does not fail when a folder already
# exists, so this cell can be re-run without errors.
!mkdir -p /content/root
!mkdir -p /content/save_dir/ckpt

In [None]:
#@title Adding Differentiable Augmentation in loss.py (experimental)
%%writefile /content/CSA_pytorch/loss.py 
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in ``policy`` to ``x``.

    Args:
        x: Input tensor. It is permuted with 4-axis permutations when
            ``channels_first`` is False, so it is expected to be 4-D.
        policy: Comma-separated keys of ``AUGMENT_FNS``. An empty string
            disables augmentation and ``x`` is returned unchanged.
        channels_first: If False, ``x`` is permuted with ``(0, 3, 1, 2)``
            before the augmentations and back with ``(0, 2, 3, 1)`` after.

    Returns:
        The augmented tensor (made contiguous), or ``x`` itself when
        ``policy`` is empty.

    Raises:
        KeyError: If a policy name is not a key of ``AUGMENT_FNS``.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)

    # Policies are applied in the order they are listed.
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset per sample to ``x``.

    The offset is drawn uniformly from [-0.5, 0.5) with shape
    ``(x.size(0), 1, 1, 1)`` and broadcast over the remaining axes.
    """
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its mean over axis 1.

    The scale is one random factor per sample, drawn uniformly from
    [0, 2); the mean over axis 1 is kept unchanged.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its mean over axes 1, 2 and 3.

    The scale is one random factor per sample, drawn uniformly from
    [0.5, 1.5); the per-sample mean is kept unchanged.
    """
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.125):
    """Shift every sample by a random integer offset along axes 2 and 3.

    Args:
        x: 4-D tensor indexed as (batch, channel, axis 2, axis 3).
        ratio: Maximum shift as a fraction of the size of each axis
            (rounded to the nearest integer).

    Returns:
        A tensor of the same shape as ``x``; positions shifted in from
        outside the input are filled with zeros.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    max_dx = int(height * ratio + 0.5)
    max_dy = int(width * ratio + 0.5)

    # One integer shift per sample and axis, in [-max, max].
    dx = torch.randint(-max_dx, max_dx + 1, size=[batch, 1, 1], device=x.device)
    dy = torch.randint(-max_dy, max_dy + 1, size=[batch, 1, 1], device=x.device)

    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(height, dtype=torch.long, device=x.device),
        torch.arange(width, dtype=torch.long, device=x.device),
    )
    # The +1 accounts for the one-element zero border added below; clamping
    # sends every out-of-range index onto that border.
    idx_x = (idx_x + dx + 1).clamp(0, height + 1)
    idx_y = (idx_y + dy + 1).clamp(0, width + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    channels_last = padded.permute(0, 2, 3, 1).contiguous()
    return channels_last[idx_b, idx_x, idx_y].permute(0, 3, 1, 2)


def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangular region per sample.

    Args:
        x: 4-D tensor indexed as (batch, channel, axis 2, axis 3).
        ratio: Size of the rectangle as a fraction of the size of each
            axis (rounded to the nearest integer).

    Returns:
        ``x`` multiplied by a 0/1 mask that is shared across axis 1.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)

    # One random rectangle position per sample.
    pos_x = torch.randint(0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device)
    pos_y = torch.randint(0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device)

    idx_b, idx_x, idx_y = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=x.device),
        torch.arange(cut_h, dtype=torch.long, device=x.device),
        torch.arange(cut_w, dtype=torch.long, device=x.device),
    )
    # Rectangles that stick out of the input are clamped onto its border.
    idx_x = (idx_x + pos_x - cut_h // 2).clamp(min=0, max=height - 1)
    idx_y = (idx_y + pos_y - cut_w // 2).clamp(min=0, max=width - 1)

    keep = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    keep[idx_b, idx_x, idx_y] = 0
    return x * keep.unsqueeze(1)


# Maps each policy name accepted by DiffAugment to the augmentation
# functions that are applied, in list order, for that name.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

# Comma-separated AUGMENT_FNS keys passed to DiffAugment by calc_gan_loss.
policy = 'color,translation,cutout'

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms


def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clamp the result to [0, 1].

    Returns a new tensor; ``x`` itself is not modified.
    """
    return ((x + 1) / 2).clamp(0, 1)


class VGG16FeatureExtractor(nn.Module):
    """Frozen feature extractor built from torchvision's pretrained VGG16.

    ``vgg16.features`` is split into four consecutive stages, registered
    as ``enc_1`` .. ``enc_4``; their parameters do not receive gradients.
    """

    def __init__(self):
        super().__init__()

        features = models.vgg16(pretrained=True).features

        # Boundaries of the four consecutive slices of ``features``.
        bounds = [0, 5, 10, 17, 23]
        for stage, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            encoder = nn.Sequential(*features[start:stop])
            # fix the encoder
            for weight in encoder.parameters():
                weight.requires_grad = False
            setattr(self, 'enc_{:d}'.format(stage), encoder)

    def forward(self, image):
        """Run ``image`` through the stages in sequence.

        Returns:
            A list with the output of ``enc_1`` .. ``enc_4``, where each
            stage receives the output of the previous one.
        """
        outputs = []
        current = image
        for stage in range(1, 5):
            encoder = getattr(self, 'enc_{:d}'.format(stage))
            current = encoder(current)
            outputs.append(current)
        return outputs
    

class ConsistencyLoss(nn.Module):
    """Masked L2 loss between two feature maps and VGG16 features of the target.

    The target image is passed through ``VGG16FeatureExtractor`` and the
    output of its last stage is compared against ``csa`` and ``csa_d``.
    """

    def __init__(self):
        super().__init__()

        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.vgg = VGG16FeatureExtractor()
        self.l2 = nn.MSELoss()

    def forward(self, csa, csa_d, target, mask):
        """Compute the consistency loss.

        Args:
            csa: First feature map compared against the target's VGG features.
            csa_d: Second feature map compared against the same features.
            target: Batch of target images (BxCxHxW) in the range [-1, 1].
            mask: Mask tensor; it is resized to the spatial size of ``csa``.

        Returns:
            Sum of the two masked MSE terms.
        """
        # https://pytorch.org/docs/stable/torchvision/models.html
        # Pre-trained VGG16 model expect input images normalized in the same way.
        # The images have to be loaded in to a range of [0, 1]
        # and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
        t = denorm(target) # [-1,1] -> [0,1]
        # Normalize every image of the batch. Using only t[0] would compare all
        # samples against the features of the first target when B > 1.
        # Per-image calls keep this working with Normalize versions that only
        # accept CxHxW tensors.
        t = torch.stack([self.normalize(img) for img in t]) # B x (CxHxW) -> BxCxHxW

        vgg_gt = self.vgg(t)[-1]

        mask_r = F.interpolate(mask, size=csa.size()[2:])

        lossvalue = self.l2(csa*mask_r, vgg_gt*mask_r) + self.l2(csa_d*mask_r, vgg_gt*mask_r)
        return lossvalue


def calc_gan_loss(discriminator, output, target):
    """Compute generator and discriminator losses with DiffAugment.

    Args:
        discriminator: Callable invoked as ``discriminator(a, b)``.
        output: Generated images.
        target: Ground-truth images.

    Returns:
        Tuple ``(g_loss, d_loss)``.
    """
    # DiffAugment is meant to augment the images that are fed to the
    # discriminator (both real and fake), not the discriminator's predictions.
    # NOTE(review): output and target draw independent random parameters;
    # confirm this is acceptable for how the discriminator uses its two inputs.
    output = DiffAugment(output, policy=policy)
    target = DiffAugment(target, policy=policy)

    y_pred_fake = discriminator(output, target)
    y_pred = discriminator(target, output)

    # Each prediction is measured relative to the mean of the opposite one.
    g_loss = (torch.mean((y_pred - torch.mean(y_pred_fake) + 1.) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) - 1.) ** 2))/2
    d_loss = (torch.mean((y_pred - torch.mean(y_pred_fake) - 1.) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) + 1.) ** 2))/2

    return g_loss, d_loss

In [None]:
# Run train.py from inside the cloned repository.
# NOTE(review): presumably --root is where the training images are read from
# and --save_dir is where checkpoints are written — confirm in train.py.
%cd /content/CSA_pytorch
!python /content/CSA_pytorch/train.py --root /content/root --save_dir /content/save_dir

# Testing 

In [None]:
# Run test.py on one image/mask pair and write the result to output.png.
# NOTE(review): G_10000.pth must exist in save_dir/ckpt; adjust the file name
# to a checkpoint that training actually produced.
!python /content/CSA_pytorch/test.py --image /content/image.png \
--mask /content/mask.png --output /content/output.png \
--checkpoint /content/save_dir/ckpt/G_10000.pth