# Colab-BasicSR (pytorch lightning)

[This tutorial](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09), [this issue](https://stackoverflow.com/questions/65387967/misconfigurationerror-no-tpu-devices-were-found-even-when-tpu-is-connected-in), and [this Colab](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb#scrollTo=3vKszYf6y1Vv) were very helpful. This Colab supports single-GPU, multi-GPU, and TPU training.

It can use various loss functions and uses the context_encoder discriminator by default. Currently, the only generators available are the inpainting generators from [my BasicSR fork](https://github.com/styler00dollar/Colab-BasicSR).

Not included in this Colab, but included in [my normal BasicSR Colab](https://colab.research.google.com/github/styler00dollar/Colab-BasicSR/blob/master/Colab-BasicSR.ipynb):
- [edge-informed-sisr](https://github.com/knazeri/edge-informed-sisr/blob/master/src/models.py)
- [USRNet](https://github.com/cszn/KAIR/blob/master/models/network_usrnet.py)
- [OFT Dataloader](https://github.com/styler00dollar/Colab-BasicSR/tree/master/codes/data)
- Some loss functions (most are available here)
- DiffAug / Mixup

Included here, but not in the other Colab:
- Custom mask loading
- New discriminators (EfficientNet, ResNeSt, Transformer)
- [AdamP](https://github.com/clovaai/AdamP)

Sidenotes:
- Validation runs at the configured validation frequency and at the end of each epoch
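
The validation schedule from the sidenote can be sketched in plain Python. This is only an illustration, not the notebook's training code: the function name and the `val_every_n_steps` parameter are hypothetical, and it assumes the frequency is counted in global training steps rather than being reset every epoch.

```python
def validation_steps(steps_per_epoch, epochs, val_every_n_steps):
    """Return the 1-based global steps after which validation would run."""
    triggered = []
    for step in range(1, steps_per_epoch * epochs + 1):
        on_frequency = step % val_every_n_steps == 0  # fixed validation frequency
        on_epoch_end = step % steps_per_epoch == 0    # last step of an epoch
        if on_frequency or on_epoch_end:
            triggered.append(step)
    return triggered


# 2 epochs of 5 steps each, validating every 3 steps
print(validation_steps(steps_per_epoch=5, epochs=2, val_every_n_steps=3))
# [3, 5, 6, 9, 10]
```

Steps 3, 6 and 9 come from the frequency, steps 5 and 10 from the epoch ends.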

In [None]:
!nvidia-smi

In [None]:
#@title GPU
# create empty folders
!mkdir -p /content/masks
!mkdir -p /content/validation
!mkdir -p /content/data
!mkdir -p /content/logs/

#!pip install pytorch-lightning -U
# Hotfix, to avoid pytorch-lightning bug
!pip install git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install tensorboardX

In [None]:
#@title TPU  (restart runtime afterwards)
# create empty folders
!mkdir -p /content/masks
!mkdir -p /content/validation
!mkdir -p /content/data
!mkdir -p /content/logs/

#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
#!pip install pytorch-lightning
!pip install lightning-flash

import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

!pip install pytorch-lightning


!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py > /dev/null
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev > /dev/null
!pip install pytorch-lightning > /dev/null

!pip install tensorboardX

In [None]:
#@title Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
print('Google Drive connected.')

Paths:
```
/content/data (rgb data)
/content/masks (1 channel masks, black = mask, white = original image)
/content/validation (images for validation)
/content/validation_output (validation destination, will be created if not present)
/content/test (rgb data)
/content/test_output (test output, will be created if not present)
```
By default, each sample has a 50% chance of getting a random mask and a 50% chance of getting a custom mask. Validation currently does not compute metrics and takes a green-masked LR image as input. Metric code is included, but it still needs a custom dataloader.
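
The 50/50 mask choice and the mask convention (black = mask, white = original image) can be sketched as follows. This is a minimal illustration with nested lists, not the dataloader used in this Colab: `pick_mask_source` and `apply_mask` are hypothetical names, and filling masked pixels with 0 is an assumption made only for the example.

```python
import random


def pick_mask_source(rng, custom_probability=0.5):
    """Decide per sample whether a custom or a randomly generated mask is used."""
    return "custom" if rng.random() < custom_probability else "random"


def apply_mask(image, mask):
    """Keep pixels where the mask is white (255), zero them where it is black (0).

    `image` and `mask` are nested lists with the same height and width.
    """
    return [
        [pixel if m == 255 else 0 for pixel, m in zip(image_row, mask_row)]
        for image_row, mask_row in zip(image, mask)
    ]


rng = random.Random(0)
counts = {"custom": 0, "random": 0}
for _ in range(10_000):
    counts[pick_mask_source(rng)] += 1
print(counts)  # both counts should be close to 5000

image = [[10, 20], [30, 40]]
mask = [[255, 0], [0, 255]]
print(apply_mask(image, mask))  # [[10, 0], [0, 40]]
```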

In [None]:
#@title copy data somehow
!mkdir -p '/content/data/images'
!cp "/content/drive/MyDrive/classification_v3.7z" "/content/data/images/data.7z"
%cd /content/data/images
!7z x "data.7z"
!rm -rf /content/data/images/data.7z
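
After copying the data, it can help to confirm that the folders listed under "Paths" actually contain images before starting a run. The sketch below is an optional check, not part of the original notebook: `count_images` and the list of file extensions are assumptions.

```python
from pathlib import Path

# File extensions treated as images (an assumption, adjust as needed)
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}


def count_images(folder):
    """Count image files below `folder` recursively; a missing folder counts as 0."""
    folder = Path(folder)
    if not folder.is_dir():
        return 0
    return sum(
        1
        for path in folder.rglob("*")
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )


for name in ("data", "masks", "validation"):
    print(name, count_images(Path("/content") / name))
```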

# Optional

In [None]:
# EfficientNet
!pip install efficientnet_pytorch

In [None]:
# AdamP
!pip install adamp

In [None]:
#@title transformer.py
"""
ViT_8_8.py (17-2-20)
https://github.com/VITA-Group/TransGAN/blob/97d4b5b29d237ff4bf1337e2a2cf402a6c8a314c/models/ViT_8_8.py

ViT_helper.py (17-2-20)
https://github.com/VITA-Group/TransGAN/blob/97d4b5b29d237ff4bf1337e2a2cf402a6c8a314c/models/ViT_helper.py

"""

import math
import warnings

import torch
import torch.nn as nn
import pytorch_lightning as pl

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(pl.LightningModule):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

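from itertools import repeat
# torch._six is no longer available in recent PyTorch releases;
# the standard library provides the same abstract base classes
import collections.abc as container_abcs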


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)



def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


# -*- coding: utf-8 -*-
# @Date    : 2019-08-15
# @Author  : Xinyu Gong (xy_gong@tamu.edu)
# @Link    : None
# @Version : 0.0

class matmul(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
    def forward(self, x1, x2):
        x = x1@x2
        return x

def count_matmul(m, x, y):
    num_mul = x[0].numel() * x[1].size(-1)
    # m.total_ops += torch.DoubleTensor([int(num_mul)])
    m.total_ops += torch.DoubleTensor([int(0)])
    

def gelu(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class Mlp(pl.LightningModule):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=gelu, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(pl.LightningModule):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.mat = matmul()

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (self.mat(q, k.transpose(-2, -1))) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = self.mat(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(pl.LightningModule):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=gelu, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

def pixel_upsample(x, H, W):
    B, N, C = x.size()
    assert N == H*W
    x = x.permute(0, 2, 1)
    x = x.view(-1, C, H, W)
    x = nn.PixelShuffle(2)(x)
    B, C, H, W = x.size()
    x = x.view(-1, C, H*W)
    x = x.permute(0,2,1)
    return x, H, W


def _downsample(x):
    # Downsample (Mean Avg Pooling with 2x2 kernel)
    return nn.AvgPool2d(kernel_size=2)(x)


class PatchEmbed(pl.LightningModule):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class HybridEmbed(pl.LightningModule):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        # LightningModule subclasses nn.Module, so this also accepts plain PyTorch backbones
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class TranformerDiscriminator(pl.LightningModule):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=32, patch_size=1, in_chans=3, num_classes=1, embed_dim=64, depth=7,
                 num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_classes = num_classes
        # Honor the embed_dim argument (default 64) instead of overwriting it with a constant
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.depth = depth
        self.patch_size = patch_size
        self.img_size = img_size

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=self.img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0)
        num_patches = (self.img_size // patch_size)**2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(self.depth)])
        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        #self.repr = nn.Linear(embed_dim, representation_size)
        #self.repr_act = nn.Tanh()

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        print("Transformer init complete")

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        #if "None" not in self.args.diff_aug:
        #    x = DiffAugment(x, self.args.diff_aug, True)
        B = x.shape[0]
        x = self.patch_embed(x).flatten(2).permute(0,2,1)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:,0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict
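
To get a feeling for the cost of the transformer discriminator, the sequence length can be worked out from the defaults in `TranformerDiscriminator` (`img_size=32`, `patch_size=1`, `num_heads=4`). The helper names below are hypothetical and the arithmetic is only a back-of-the-envelope sketch that mirrors `num_patches = (img_size // patch_size)**2` plus one class token.

```python
def vit_sequence_length(img_size, patch_size):
    """Number of tokens per image: one per patch plus the class token."""
    patches_per_side = img_size // patch_size
    num_patches = patches_per_side ** 2
    return num_patches + 1  # +1 for cls_token, matching pos_embed's second dimension


def attention_entries_per_block(img_size, patch_size, num_heads):
    """Entries in the attention matrices of one block for a single image."""
    tokens = vit_sequence_length(img_size, patch_size)
    return num_heads * tokens * tokens


print(vit_sequence_length(32, 1))             # 1025
print(attention_entries_per_block(32, 1, 4))  # 4202500
```

Because the attention size grows with the square of the token count, a larger `patch_size` reduces memory use considerably, for example `vit_sequence_length(32, 4)` gives 65 tokens.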

In [None]:
#@title ResNeSt.py
#https://github.com/zhanghang1989/ResNeSt/blob/11eb547225c6b98bdf6cab774fb58dffc53362b1/resnest/torch/splat.py
"""Split-Attention"""

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair

__all__ = ['SplAtConv2d']

class SplAtConv2d(Module):
    """Split-Attention Conv2d
    """
    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, bias=True,
                 radix=2, reduction_factor=4,
                 rectify=False, rectify_avg=False, norm_layer=None,
                 dropblock_prob=0.0, **kwargs):
        super(SplAtConv2d, self).__init__()
        padding = _pair(padding)
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        inter_channels = max(in_channels*radix//reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob
        if self.rectify:
            from rfconv import RFConv2d
            self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
        else:
            self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                               groups=groups*radix, bias=bias, **kwargs)
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = norm_layer(channels*radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = norm_layer(inter_channels)
        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
        if dropblock_prob > 0.0:
            self.dropblock = DropBlock2D(dropblock_prob, 3)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn0(x)
        if self.dropblock_prob > 0.0:
            x = self.dropblock(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # rchannel // self.radix is already an int, so the old version check is unnecessary
            splited = torch.split(x, rchannel // self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = torch.split(atten, rchannel // self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x





#https://github.com/zhanghang1989/ResNeSt/blob/11eb547225c6b98bdf6cab774fb58dffc53362b1/resnest/torch/resnet.py
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""ResNet variants"""
import math
import torch
import torch.nn as nn

#from .splat import SplAtConv2d

__all__ = ['ResNet', 'Bottleneck']

class DropBlock2D(object):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions"""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)

class Bottleneck(nn.Module):
    """ResNet Bottleneck
    """
    # pylint: disable=unused-argument
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        if radix >= 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv2d
            self.conv2 = RFConv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes*4)

        if last_gamma:
            from torch.nn.init import zeros_
            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 0:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    """ResNet Variants
    Parameters
    ----------
    block : Block
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in Semantic Segmentation.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        for Synchronized Cross-GPU BatchNormalization).
    Reference:
        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    # pylint: disable=unused-variable
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm2d):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width*2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super(ResNet, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation==2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.avgpool = GlobalAvgPool2d()
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                if dilation == 1:
                    down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
                                                    ceil_mode=True, count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
                                                    ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        #x = x.view(x.size(0), -1)
        x = torch.flatten(x, 1)
        if self.drop:
            x = self.drop(x)
        x = self.fc(x)

        return x

#https://github.com/zhanghang1989/ResNeSt/blob/master/resnest/torch/resnest.py
import torch
#from .resnet import ResNet, Bottleneck

__all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269']

_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth'

_model_sha256 = {name: checksum for checksum, name in [
    ('528c19ca', 'resnest50'),
    ('22405ba7', 'resnest101'),
    ('75117900', 'resnest200'),
    ('0cc87c48', 'resnest269'),
    ]}

def short_hash(name):
    if name not in _model_sha256:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return _model_sha256[name][:8]

resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for
    name in _model_sha256.keys()
}

def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
    model = ResNet(Bottleneck, [3, 4, 6, 3],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=32, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnest50'], progress=True, check_hash=True))
    return model

def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
    model = ResNet(Bottleneck, [3, 4, 23, 3],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=64, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnest101'], progress=True, check_hash=True))
    return model

def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
    model = ResNet(Bottleneck, [3, 24, 36, 3],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=64, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnest200'], progress=True, check_hash=True))
    return model

def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
    model = ResNet(Bottleneck, [3, 30, 48, 8],
                   radix=2, groups=1, bottleneck_width=64,
                   deep_stem=True, stem_width=64, avg_down=True,
                   avd=True, avd_first=False, **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnest269'], progress=True, check_hash=True))
    return model

# Loss

In [None]:
#@title getting pytorch-loss-functions
%cd /content
!git clone https://github.com/styler00dollar/pytorch-loss-functions pytorchloss
%cd /content/pytorchloss

In [None]:
# restart from here if you reset your notebook
%cd /content/pytorchloss

# Conv

In [None]:
#@title partialconv.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
#
# Source: https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################

import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable

class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False  

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
            
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                        
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                #make sure the value of self.mask_ratio for the entries in the interior (no need for padding) have value 1. If not, you replace with the line below.
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        # if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
        #     self.update_mask.to(input)
        #     self.mask_ratio.to(input)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)


        if self.return_mask:
            return output, self.update_mask
        else:
            return output

In [None]:
#@title deformconv2d.py
import torch.nn as nn
import torchvision.ops as O


class DeformConv2d(nn.Module):
    def __init__(self, in_nc, out_nc, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
        super(DeformConv2d, self).__init__()

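        # the offset branch uses the same geometry (incl. dilation) as the deformable conv,
        # so the predicted offset map matches the output size
        self.conv_offset = nn.Conv2d(in_nc, 2 * (kernel_size**2), kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
        # zero-init: all offsets start at 0, so the layer initially behaves like a regular conv
        self.conv_offset.weight.data.zero_()
        if self.conv_offset.bias is not None:
            self.conv_offset.bias.data.zero_()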

        self.dcn_conv = O.DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        offset = self.conv_offset(x)
        return self.dcn_conv(x, offset=offset)

# Data (for inpainting)

The dataloader can be customized in several ways; the following sections show some examples. The most basic variant, shown here, loads an image and pairs it with a mask that is either read from the mask folder or generated automatically.
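
As a minimal sketch of what the dataset returns, the snippet below applies a mask to an image the same way `masked = sample * mask` does in the code. The tiny 4x4 image and the hand-made hole are assumptions for illustration only, and NumPy stands in for the tensors. It assumes the mask convention used in this notebook: 1 keeps a pixel, 0 marks a hole.

```python
import numpy as np

# Hypothetical 4x4 RGB image in CHW layout with values in [0, 1].
sample = np.full((3, 4, 4), 0.5, dtype=np.float32)

# Mask convention assumed here: 1 = keep pixel, 0 = hole to inpaint.
mask = np.ones((1, 4, 4), dtype=np.float32)
mask[:, 1:3, 1:3] = 0.0  # cut a 2x2 hole in the centre

# The single mask channel broadcasts across all three colour channels.
masked = sample * mask

print(masked.shape)
print(masked[0])
```

The single-channel mask is broadcast over the colour channels, which is why the dataset can return a `(1, H, W)` mask next to a `(3, H, W)` image.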

In [None]:
#@title data.py
import glob
import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms  # needed for transforms.ToTensor() in DS.__getitem__

class DS(Dataset):
    def __init__(self, root, transform=None, size=256):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform
        self.mask_dir = '/content/masks'
        self.files = glob.glob(self.mask_dir + '/**/*.png', recursive=True)
        files_jpg = glob.glob(self.mask_dir + '/**/*.jpg', recursive=True)
        self.files.extend(files_jpg)

        self.size = size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        sample = Image.open(sample_path).convert('RGB')

        if self.transform:
            sample = self.transform(sample)

        # if edges are required
        grayscale = cv2.cvtColor(np.array(sample), cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(grayscale,100,150)
        grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
        edges = torch.from_numpy(edges).unsqueeze(0)

        if random.uniform(0, 1) < 0.5:
          # generating mask automatically with 50% chance
          mask = DS.random_mask(height=self.size, width=self.size)
          mask = torch.from_numpy(mask)

        else:
          # load random mask from folder
          mask = cv2.imread(random.choice(self.files), cv2.IMREAD_UNCHANGED)
          mask = cv2.resize(mask, (self.size,self.size), interpolation=cv2.INTER_NEAREST)

          # flip mask randomly; draw once so both branches test the same value
          flip = random.uniform(0, 1)
          if 0.3 < flip <= 0.66:
            mask = np.flip(mask, axis=0)
          elif flip > 0.66:
            mask = np.flip(mask, axis=1)

          mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)/255

        #sample = torch.from_numpy(sample)
        sample = transforms.ToTensor()(sample)

        # apply mask
        masked = sample * mask
        return masked, mask, sample

        # EdgeConnect
        #return masked, mask, sample, edges, grayscale

        # PRVS
        #return masked, mask, sample, edges

    
    @staticmethod
    def random_mask(height=256, width=256,
                    min_stroke=1, max_stroke=4,
                    min_vertex=1, max_vertex=12,
                    min_brush_width_divisor=16, max_brush_width_divisor=10):
        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2*np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke+1)
        average_length = np.sqrt(height*height + width*width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex+1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)

            for _ in range(num_vertex):
                # a single positional argument would set `low`, not `high`
                angle = np.random.uniform(0, max_angle)
                length = np.clip(np.random.normal(average_length, average_length//2), 0, 2*average_length)
                brush_width = np.random.randint(min_brush_width, max_brush_width+1)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)

                cv2.line(mask, (start_y, start_x), (end_y, end_x), 0., brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return mask.reshape((1,)+mask.shape).astype(np.float32) 



class DS_green_from_mask(Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        #sample = Image.open(sample_path).convert('RGB')
        sample = cv2.imread(sample_path)
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        # if edges are required
        grayscale = cv2.cvtColor(sample, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(grayscale,100,150)
        grayscale = torch.from_numpy(grayscale).unsqueeze(0)
        edges = torch.from_numpy(edges).unsqueeze(0)

        green_mask = 1-np.all(sample == [0,255,0], axis=-1).astype(int)
        green_mask = torch.from_numpy(green_mask).unsqueeze(0)
        sample = torch.from_numpy(sample.astype(np.float32)).permute(2, 0, 1)/255
        sample = sample * green_mask

        # train_batch[0] = masked
        # train_batch[1] = mask
        # train_batch[2] = path
        return sample, green_mask, sample_path 

        # EdgeConnect
        #return sample, green_mask, sample_path, edges, grayscale

        # PRVS
        #return sample, green_mask, sample_path, edges


In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    def __init__(self, training_path: str = './', validation_path: str = './', test_path: str = './', batch_size: int = 5, num_workers: int = 2):
        super().__init__()
        self.training_dir = training_path
        self.validation_dir = validation_path
        self.test_dir = test_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = 256

    def setup(self, stage=None):
        img_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.RandomHorizontalFlip()
            #transforms.ToTensor()
        ])
        
        self.DFNetdataset_train = DS(self.training_dir, img_tf, self.size)
        self.DFNetdataset_validation = DS_green_from_mask(self.validation_dir, img_tf)
        self.DFNetdataset_test = DS_green_from_mask(self.test_dir)

    def train_dataloader(self):
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

# Data (16x16 256px (for inpainting))

This variant reads 16x16 image grids (256 tiles of 256px per file) instead of individual images, because Colab tends to freeze if you upload too many files. Training will be slower, but the dataset consists of far fewer files and Colab is less likely to crash. Only recommended for large datasets, roughly 5-6 digit file counts.
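
The tile lookup can be sketched as follows. This is an illustration rather than the exact code from the cell below: the helper name `crop_random_tile` and its parameters are hypothetical, and a small 8px tile size stands in for 256px so the example stays light.

```python
import random
import numpy as np

def crop_random_tile(grid, tile=256, tiles_per_side=16):
    """Return one tile x tile crop from a grid image stored as an HWC array."""
    # random.randint is inclusive on both ends, hence the -1.
    row = random.randint(0, tiles_per_side - 1)
    col = random.randint(0, tiles_per_side - 1)
    return grid[row * tile:(row + 1) * tile, col * tile:(col + 1) * tile]

# Small stand-in for a 4096x4096 grid: 16x16 tiles of 8x8 pixels each.
grid = np.zeros((16 * 8, 16 * 8, 3), dtype=np.uint8)
tile = crop_random_tile(grid, tile=8)
print(tile.shape)
```

Every call decodes the whole grid file but uses only one tile of it, which is the reason this variant trains more slowly than loading individual images.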

In [None]:
#@title data.py
import glob
import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class DS(Dataset):
    def __init__(self, root, transform=None, size=256):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform
        self.mask_dir = '/content/masks'
        self.files = glob.glob(self.mask_dir + '/**/*.png', recursive=True)
        files_jpg = glob.glob(self.mask_dir + '/**/*.jpg', recursive=True)
        self.files.extend(files_jpg)

        self.size = size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        #sample = Image.open(sample_path).convert('RGB')
        sample = cv2.imread(sample_path)


        x_rand = random.randint(0,15)
        y_rand = random.randint(0,15)

        sample = sample[x_rand*256:(x_rand+1)*256, y_rand*256:(y_rand+1)*256]
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        #sample = torch.from_numpy(sample)

        #if self.transform:
        #    sample = self.transform(sample)

        # if edges are required
        grayscale = cv2.cvtColor(np.array(sample), cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(grayscale,100,150)
        grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
        edges = torch.from_numpy(edges).unsqueeze(0)

        if random.uniform(0, 1) < 0.5:
          # generating mask automatically with 50% chance
          mask = DS.random_mask(height=self.size, width=self.size)
          mask = torch.from_numpy(mask)

        else:
          # load random mask from folder
          mask = cv2.imread(random.choice(self.files), cv2.IMREAD_UNCHANGED)
          mask = cv2.resize(mask, (self.size,self.size), interpolation=cv2.INTER_NEAREST)

          # flip mask randomly; draw once so both branches test the same value
          flip = random.uniform(0, 1)
          if 0.3 < flip <= 0.66:
            mask = np.flip(mask, axis=0)
          elif flip > 0.66:
            mask = np.flip(mask, axis=1)

          mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)/255

        sample = torch.from_numpy(sample).permute(2, 0, 1)/255
        #sample = transforms.ToTensor()(sample)

        # apply mask
        #print(sample.shape)
        #print(mask.shape)
        masked = sample * mask
        return masked, mask, sample

        # EdgeConnect
        #return masked, mask, sample, edges, grayscale

        # PRVS
        #return masked, mask, sample, edges

    
    @staticmethod
    def random_mask(height=256, width=256,
                    min_stroke=1, max_stroke=4,
                    min_vertex=1, max_vertex=12,
                    min_brush_width_divisor=16, max_brush_width_divisor=10):
        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2*np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke+1)
        average_length = np.sqrt(height*height + width*width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex+1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)

            for _ in range(num_vertex):
                # a single positional argument would set `low`, not `high`
                angle = np.random.uniform(0, max_angle)
                length = np.clip(np.random.normal(average_length, average_length//2), 0, 2*average_length)
                brush_width = np.random.randint(min_brush_width, max_brush_width+1)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)

                cv2.line(mask, (start_y, start_x), (end_y, end_x), 0., brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return mask.reshape((1,)+mask.shape).astype(np.float32) 



class DS_green_from_mask(Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        #sample = Image.open(sample_path).convert('RGB')
        sample = cv2.imread(sample_path)
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        # if edges are required
        grayscale = cv2.cvtColor(sample, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(grayscale,100,150)
        grayscale = torch.from_numpy(grayscale).unsqueeze(0)
        edges = torch.from_numpy(edges).unsqueeze(0)

        green_mask = 1-np.all(sample == [0,255,0], axis=-1).astype(int)
        green_mask = torch.from_numpy(green_mask).unsqueeze(0)
        sample = torch.from_numpy(sample.astype(np.float32)).permute(2, 0, 1)/255
        sample = sample * green_mask

        # train_batch[0] = masked
        # train_batch[1] = mask
        # train_batch[2] = path
        return sample, green_mask, sample_path 

        # EdgeConnect
        #return sample, green_mask, sample_path, edges, grayscale

        # PRVS
        #return sample, green_mask, sample_path, edges


In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    def __init__(self, training_path: str = './', validation_path: str = './', test_path: str = './', batch_size: int = 5, num_workers: int = 2):
        super().__init__()
        self.training_dir = training_path
        self.validation_dir = validation_path
        self.test_dir = test_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = 256

    def setup(self, stage=None):
        img_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.RandomHorizontalFlip()
            #transforms.ToTensor()
        ])
        
        self.DFNetdataset_train = DS(self.training_dir, img_tf, self.size)
        self.DFNetdataset_validation = DS_green_from_mask(self.validation_dir, img_tf)
        self.DFNetdataset_test = DS_green_from_mask(self.test_dir)

    def train_dataloader(self):
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

# Data (16x16 256px, batch (for inpainting))

Uses 16x16 image grids to avoid problems with Colab, which tends to freeze if you upload too many files. Optionally, it also assumes that the masks are 4x4 (untested). Only recommended for large datasets, roughly 5-6 digit file counts. To improve speed, ``__getitem__`` returns a whole batch of tiles cut from one grid image instead of a single image. This assumes ``batch_size 1`` and the use of ``batch_size_DL`` in ``Trainer()``. The higher the batch size, the more the training speed benefits from this.

A few benchmarks, with 4k images and non-grid masks.
```
Tesla T4: 
RFR (batch_size=5):
256px:       1-1.08 s/it
4k (normal): 1-1.1 s/it
4k (batch):  1-1.1 s/it

DFNet (batch_size=20):
256px:       2.2-2.6 it/s
4k (normal): 2.8-3.4 s/it
4k (batch):  2.3-2.6 it/s
```
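
The batch-per-item idea described above can be sketched like this. It is a simplified illustration, not the code from the cell below: the helper name `sample_tile_batch` is hypothetical, NumPy stands in for the tensors, and `random.sample` is swapped in for the retry loop to draw distinct tile positions.

```python
import random
import numpy as np

def sample_tile_batch(grid, batch_size, tile=256, tiles_per_side=16):
    """Cut batch_size distinct tiles from one HWC grid image, stacked as NCHW."""
    if batch_size > tiles_per_side ** 2:
        raise ValueError("batch_size exceeds the number of tiles in the grid")
    # random.sample draws without replacement, so no tile is picked twice.
    cells = random.sample(range(tiles_per_side ** 2), batch_size)
    tiles = []
    for cell in cells:
        row, col = divmod(cell, tiles_per_side)
        crop = grid[row * tile:(row + 1) * tile, col * tile:(col + 1) * tile]
        tiles.append(crop.transpose(2, 0, 1))  # HWC -> CHW
    # Scale to [0, 1] like the dataset does.
    return np.stack(tiles).astype(np.float32) / 255

# Small stand-in for a 4096x4096 grid: 16x16 tiles of 8x8 pixels each.
grid = np.zeros((16 * 8, 16 * 8, 3), dtype=np.uint8)
batch = sample_tile_batch(grid, batch_size=5, tile=8)
print(batch.shape)
```

Because the grid file is decoded once and several tiles are reused from it, the cost of reading a 4k image is spread over the whole batch, which is consistent with the benchmark gap between "4k (normal)" and "4k (batch)" for DFNet.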

In [None]:
#@title data.py
import glob
import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class DS(Dataset):
    def __init__(self, root, transform=None, size=256, batch_size_DL = 3):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform
        self.mask_dir = '/content/masks'
        self.files = glob.glob(self.mask_dir + '/**/*.png', recursive=True)
        files_jpg = glob.glob(self.mask_dir + '/**/*.jpg', recursive=True)
        self.files.extend(files_jpg)

        self.size = size
        self.batch_size = batch_size_DL

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        sample = cv2.imread(sample_path)

        #batch_size = 10
        pos_total = []
        self.total_size = 0

        while True:
          # determine random position
          x_rand = random.randint(0,15)
          y_rand = random.randint(0,15)
          
          pos_rand = [x_rand, y_rand]

          if pos_rand not in pos_total:
            pos_total.append(pos_rand)
            self.total_size += 1

          # stop once batch_size unique positions have been collected
          if self.total_size == self.batch_size:
            break

        self.total_size = 0
        for i in pos_total:
          # the first tile starts the batch tensor, later tiles are concatenated to it
          """
          print("pos_total")
          print(pos_total)

          print("i")
          print(i)
          """
          if self.total_size == 0:
            sample_add = sample[i[0]*256:(i[0]+1)*256, i[1]*256:(i[1]+1)*256]
            sample_add = cv2.cvtColor(sample_add, cv2.COLOR_BGR2RGB)
            sample_add = torch.from_numpy(sample_add).permute(2, 0, 1).unsqueeze(0)/255

            # if edges are required
            """
            grayscale = cv2.cvtColor(np.array(sample_add), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(grayscale,100,150)
            grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
            edges = torch.from_numpy(edges).unsqueeze(0)
            """

            self.total_size += 1
          else:
            sample_add2 = sample[i[0]*256:(i[0]+1)*256, i[1]*256:(i[1]+1)*256]
            sample_add2 = cv2.cvtColor(sample_add2, cv2.COLOR_BGR2RGB)
            # if edges are required
            """
            grayscale = cv2.cvtColor(np.array(sample_add2), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(grayscale,100,150)
            grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
            edges = torch.from_numpy(edges).unsqueeze(0)
            """
            sample_add2 = torch.from_numpy(sample_add2).permute(2, 0, 1).unsqueeze(0)/255
            sample_add = torch.cat((sample_add, sample_add2), dim=0)

        # getting mask batch
        
        self.total_size = 0
        for i in range(self.batch_size):
          # randomly generating or loading one mask

          if random.uniform(0, 1) < 0.5:
            # generating mask automatically with 50% chance
            mask = DS.random_mask(height=self.size, width=self.size)
            mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)
            #print("random mask")
            #print(mask.shape)

          else:
            # load random mask from folder
            mask = cv2.imread(random.choice(self.files), cv2.IMREAD_UNCHANGED)
            mask = cv2.resize(mask, (self.size,self.size), interpolation=cv2.INTER_NEAREST)
            
            # flip mask randomly; draw once so that the thresholds give fixed probabilities
            # (about 30% no flip, 36% vertical flip, 34% horizontal flip)
            flip_rand = random.uniform(0, 1)
            if 0.3 < flip_rand <= 0.66:
              mask = np.flip(mask, axis=0)
            elif flip_rand > 0.66:
              mask = np.flip(mask, axis=1)
            mask = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0).unsqueeze(0)
            #print("read mask")
            #print(mask.shape)

          if self.total_size == 0:
            mask_add = mask/255
            self.total_size += 1
          else:
            mask_add2 = mask/255
            mask_add = torch.cat((mask_add, mask_add2), dim=0)
            self.total_size += 1

        # apply mask
        masked = sample_add * mask_add

        return masked, mask_add, sample_add

        # EdgeConnect
        #return masked, mask, sample, edges, grayscale

        # PRVS
        #return masked, mask, sample, edges

    
    @staticmethod
    def random_mask(height=256, width=256,
                    min_stroke=1, max_stroke=4,
                    min_vertex=1, max_vertex=12,
                    min_brush_width_divisor=16, max_brush_width_divisor=10):
        mask = np.ones((height, width))

        min_brush_width = height // min_brush_width_divisor
        max_brush_width = height // max_brush_width_divisor
        max_angle = 2*np.pi
        num_stroke = np.random.randint(min_stroke, max_stroke+1)
        average_length = np.sqrt(height*height + width*width) / 8

        for _ in range(num_stroke):
            num_vertex = np.random.randint(min_vertex, max_vertex+1)
            start_x = np.random.randint(width)
            start_y = np.random.randint(height)

            for _ in range(num_vertex):
                angle = np.random.uniform(max_angle)
                length = np.clip(np.random.normal(average_length, average_length//2), 0, 2*average_length)
                brush_width = np.random.randint(min_brush_width, max_brush_width+1)
                end_x = (start_x + length * np.sin(angle)).astype(np.int32)
                end_y = (start_y + length * np.cos(angle)).astype(np.int32)

                cv2.line(mask, (start_y, start_x), (end_y, end_x), 0., brush_width)

                start_x, start_y = end_x, end_y
        if np.random.random() < 0.5:
            mask = np.fliplr(mask)
        if np.random.random() < 0.5:
            mask = np.flipud(mask)
        return mask.reshape((1,)+mask.shape).astype(np.float32) 



class DS_green_from_mask(Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        #sample = Image.open(sample_path).convert('RGB')
        sample = cv2.imread(sample_path)
        sample = cv2.cvtColor(sample, cv2.COLOR_BGR2RGB)

        # if edges are required
        grayscale = cv2.cvtColor(sample, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(grayscale,100,150)
        grayscale = torch.from_numpy(grayscale).unsqueeze(0)
        edges = torch.from_numpy(edges).unsqueeze(0)

        green_mask = 1-np.all(sample == [0,255,0], axis=-1).astype(int)
        green_mask = torch.from_numpy(green_mask).unsqueeze(0)
        sample = torch.from_numpy(sample.astype(np.float32)).permute(2, 0, 1)/255
        sample = sample * green_mask

        # train_batch[0] = masked
        # train_batch[1] = mask
        # train_batch[2] = path
        return sample, green_mask, sample_path 

        # EdgeConnect
        #return sample, green_mask, sample_path, edges, grayscale

        # PRVS
        #return sample, green_mask, sample_path, edges


In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    def __init__(self, training_path: str = './', validation_path: str = './', test_path: str = './', batch_size: int = 5, batch_size_DL: int = 2, num_workers: int = 2):
        super().__init__()
        self.training_dir = training_path
        self.validation_dir = validation_path
        self.test_dir = test_path
        self.batch_size = batch_size
        self.batch_size_DL = batch_size_DL
        self.num_workers = num_workers
        self.size = 256
    def setup(self, stage=None):
        img_tf = transforms.Compose([
            transforms.Resize(size=self.size),
            transforms.CenterCrop(size=self.size),
            transforms.RandomHorizontalFlip()
            #transforms.ToTensor()
        ])
        
        self.DFNetdataset_train = DS(self.training_dir, img_tf, self.size, batch_size_DL = self.batch_size_DL)
        self.DFNetdataset_validation = DS_green_from_mask(self.validation_dir, img_tf)
        self.DFNetdataset_test = DS_green_from_mask(self.test_dir)

    def train_dataloader(self):
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
#@title CustomTrainClass.py (adding squeeze)
from vic.loss import CharbonnierLoss, GANLoss, GradientPenaltyLoss, HFENLoss, TVLoss, GradientLoss, ElasticLoss, RelativeL1, L1CosineSim, ClipL1, MaskedL1Loss, MultiscalePixelLoss, FFTloss, OFLoss, L1_regularization, ColorLoss, AverageLoss, GPLoss, CPLoss, SPL_ComputeWithTrace, SPLoss, Contextual_Loss, StyleLoss
from vic.perceptual_loss import PerceptualLoss
from metrics import *
from torchvision.utils import save_image
from torch.autograd import Variable

from tensorboardX import SummaryWriter
logdir='/content/'
writer = SummaryWriter(logdir=logdir)

from adamp import AdamP
#from adamp import SGDP

class CustomTrainClass(pl.LightningModule):
  def __init__(self):
    super().__init__()
    ############################
    # generators with one output, no AMP means nan loss during training

    #self.netG = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4, norm_type='null',
    #            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
    #            finalact=None, gaussian_noise=True, plus=False, 
    #            nr=3)


    # DFNet
    self.netG = DFNet(c_img=3, c_mask=1, c_alpha=3,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3, 3, 3, 3, 3, 3, 3, 3],
            blend_layers=[0, 1, 2, 3, 4, 5], conv_type='partial')
    
    # AdaFill
    #self.netG = InpaintNet()

    # MEDFE (batch_size: 1, no AMP)
    #self.netG = MEDFEGenerator()

    # RFR
    # conv_type = partial or deform
    #self.netG = RFRNet(conv_type='partial')

    # LBAM
    #self.netG = LBAMModel(inputChannels=4, outputChannels=3)

    # DMFN
    #self.netG = InpaintingGenerator(in_nc=4, out_nc=3,nf=64,n_res=8,
    #      norm='in', activation='relu')

    # partial
    #self.netG = Model()

    # RN
    #self.netG = G_Net(input_channels=3, residual_blocks=8, threshold=0.8)
    # using rn init to avoid errors
    #RN_arch = rn_initialize_weights(self.netG, scale=0.1)


    ############################

    # generators with two outputs

    # deepfillv1
    #self.netG = InpaintSANet()

    # deepfillv2
    # conv_type = partial or deform
    #self.netG = GatedGenerator(in_channels=4, out_channels=3, 
    #  latent_channels=64, pad_type='zero', activation='lrelu', norm='in', conv_type = 'partial')

    # Adaptive
    # [Warning] Adaptive does not like PatchGAN, Multiscale and ResNet.
    #self.netG = PyramidNet(in_channels=3, residual_blocks=1, init_weights='True')

    ############################
    # exotic generators

    # Pluralistic
    #self.netG = PluralisticGenerator(ngf_E=opt_net['ngf_E'], z_nc_E=opt_net['z_nc_E'], img_f_E=opt_net['img_f_E'], layers_E=opt_net['layers_E'], norm_E=opt_net['norm_E'], activation_E=opt_net['activation_E'],
    #            ngf_G=opt_net['ngf_G'], z_nc_G=opt_net['z_nc_G'], img_f_G=opt_net['img_f_G'], L_G=opt_net['L_G'], output_scale_G=opt_net['output_scale_G'], norm_G=opt_net['norm_G'], activation_G=opt_net['activation_G'])

    
    # EdgeConnect
    #conv_type_edge: 'normal' # normal | partial | deform (has no spectral_norm)
    #self.netG = EdgeConnectModel(residual_blocks_edge=8,
    #        residual_blocks_inpaint=8, use_spectral_norm=True,
    #        conv_type_edge='normal', conv_type_inpaint='normal')

    # FRRN
    #self.netG = FRRNet()

    # PRVS
    #self.netG = PRVSNet()

    # CSA
    #self.netG = InpaintNet(c_img=3, norm='instance', act_en='leaky_relu', 
    #                           act_de='relu')


    weights_init(self.netG, 'kaiming')
    ############################


    # discriminators
    # size refers to input shape of tensor

    self.netD = context_encoder()

    # VGG
    #self.netD = Discriminator_VGG(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN')
    #self.netD = Discriminator_VGG_fea(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D',
    #     arch='ESRGAN', spectral_norm=False, self_attention = False, max_pool=False, poolsize = 4)
    #self.netD = Discriminator_VGG_128_SN()
    #self.netD = VGGFeatureExtractor(feature_layer=34,use_bn=False,use_input_norm=True,device=torch.device('cpu'),z_norm=False)

    # PatchGAN
    #self.netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #    use_sigmoid=False, getIntermFeat=False, patch=True, use_spectral_norm=False)

    # Multiscale
    #self.netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #             use_sigmoid=False, num_D=3, getIntermFeat=False)

    # ResNet
    #self.netD = Discriminator_ResNet_128(in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA')
    #self.netD = ResNet101FeatureExtractor(use_input_norm=True, device=torch.device('cpu'), z_norm=False)
    
    # MINC
    #self.netD = MINCNet()

    # Pixel
    #self.netD = PixelDiscriminator(input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d)

    # EfficientNet
    #from efficientnet_pytorch import EfficientNet
    #self.netD = EfficientNet.from_pretrained('efficientnet-b0')

    # ResNeSt
    # ["resnest50", "resnest101", "resnest200", "resnest269"]
    #self.netD = resnest50(pretrained=True)

    # need fixing
    #FileNotFoundError: [Errno 2] No such file or directory: '../experiments/pretrained_models/VGG16minc_53.pth'
    #self.netD = MINCFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True, device=torch.device('cpu'))

    # Transformer (Warning: uses own init!)
    #self.netD  = TranformerDiscriminator(img_size=256, patch_size=1, in_chans=3, num_classes=1, embed_dim=64, depth=7,
    #             num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
    #             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm)
    

    weights_init(self.netD, 'kaiming')


    # loss functions
    self.l1 = nn.L1Loss()
    l_hfen_type = L1CosineSim()
    self.HFENLoss = HFENLoss(loss_f=l_hfen_type, kernel='log', kernel_size=15, sigma = 2.5, norm = False)
    self.ElasticLoss = ElasticLoss(a=0.2, reduction='mean')
    self.RelativeL1 = RelativeL1(eps=.01, reduction='mean')
    self.L1CosineSim = L1CosineSim(loss_lambda=5, reduction='mean')
    self.ClipL1 = ClipL1(clip_min=0.0, clip_max=10.0)
    self.FFTloss = FFTloss(loss_f = torch.nn.L1Loss, reduction='mean')
    self.OFLoss = OFLoss()
    self.GPLoss = GPLoss(trace=False, spl_denorm=False)
    self.CPLoss = CPLoss(rgb=True, yuv=True, yuvgrad=True, trace=False, spl_denorm=False, yuv_denorm=False)
    self.StyleLoss = StyleLoss()
    self.TVLoss = TVLoss(tv_type='tv', p = 1)
    self.PerceptualLoss = PerceptualLoss(model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], model_path=None)
    layers_weights = {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    self.Contextual_Loss = Contextual_Loss(layers_weights, crop_quarter=False, max_1d_size=100,
        distance_type = 'cosine', b=1.0, band_width=0.5,
        use_vgg = True, net = 'vgg19', calc_type = 'regular')

    self.MSELoss = torch.nn.MSELoss()
    self.L1Loss = nn.L1Loss()

    # metrics
    self.psnr_metric = PSNR()
    self.ssim_metric = SSIM()
    self.ae_metric = AE()
    self.mse_metric = MSE()


  def forward(self, image, masks):
      return self.netG(image, masks)

  #def adversarial_loss(self, y_hat, y):
  #    return F.binary_cross_entropy(y_hat, y)


  def training_step(self, train_batch, batch_idx):
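      # the dataset returns pre-batched tensors of shape [batch_size_DL, C, H, W];
      # the DataLoader adds a leading dimension, which the squeeze below only removes if its size is 1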
      # train_batch[0] = masked
      # train_batch[1] = mask
      # train_batch[2] = original

      # train generator
      ############################
      # generate fake (1 output)
      squeeze0 = torch.squeeze(train_batch[0], 0)
      squeeze1 = torch.squeeze(train_batch[1], 0)
      squeeze2 = torch.squeeze(train_batch[2], 0)
      out = self(squeeze0,squeeze1)

      # masking, taking original content from HR
      out = squeeze0*(squeeze1)+out*(1-squeeze1)

      ############################
      # generate fake (2 outputs)
      #out, other_img = self(train_batch[0],train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # exotic generators
      # CSA
      #coarse_result, out, csa, csa_d = self(train_batch[0],train_batch[1])
      
      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #out, other_img = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])
      
      # PRVS
      #out, _ ,edge_small, edge_big = self.netG(train_batch[0], train_batch[1], train_batch[3])

      # FRRN
      #out, mid_x, mid_mask = self(train_batch[0], train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # loss calculation
      total_loss = 0
      """
      HFENLoss_forward = self.HFENLoss(out, train_batch[0])
      total_loss += HFENLoss_forward
      ElasticLoss_forward = self.ElasticLoss(out, train_batch[0])
      total_loss += ElasticLoss_forward
      RelativeL1_forward = self.RelativeL1(out, train_batch[0])
      total_loss += RelativeL1_forward
      """
      #print("out")
      #print(out.shape)
      #print("squeeze2")
      #print(squeeze2.shape)
      L1CosineSim_forward = 5*self.L1CosineSim(out, squeeze2)
      total_loss += L1CosineSim_forward
      #self.log('loss/L1CosineSim', L1CosineSim_forward)
      writer.add_scalar('loss/L1CosineSim', L1CosineSim_forward, self.trainer.global_step)

      """
      ClipL1_forward = self.ClipL1(out, train_batch[0])
      total_loss += ClipL1_forward
      FFTloss_forward = self.FFTloss(out, train_batch[0])
      total_loss += FFTloss_forward
      OFLoss_forward = self.OFLoss(out)
      total_loss += OFLoss_forward
      GPLoss_forward = self.GPLoss(out, train_batch[0])
      total_loss += GPLoss_forward
      
      CPLoss_forward = 0.1*self.CPLoss(out, train_batch[0])
      total_loss += CPLoss_forward
      

      Contextual_Loss_forward = self.Contextual_Loss(out, train_batch[0])
      total_loss += Contextual_Loss_forward
      self.log('loss/contextual', Contextual_Loss_forward)
      """

      #style_forward = 240*self.StyleLoss(out, train_batch[2])
      #total_loss += style_forward
      #self.log('loss/style', style_forward)

      tv_forward = 0.0000005*self.TVLoss(out)
      total_loss += tv_forward
      #self.log('loss/tv', tv_forward)
      writer.add_scalar('loss/tv', tv_forward, self.trainer.global_step)

      perceptual_forward = 2*self.PerceptualLoss(out, squeeze2)
      total_loss += perceptual_forward
      #self.log('loss/perceptual', perceptual_forward)
      writer.add_scalar('loss/perceptual', perceptual_forward, self.trainer.global_step)







      #########################
      # exotic loss

      # if the model has two outputs, also calculate a loss for the second image
      # example with just l1 loss
      
      #l1_stage1 = self.L1Loss(other_img, train_batch[0])
      #self.log('loss/l1_stage1', l1_stage1)
      #total_loss += l1_stage1


      # CSA Loss
      """
      recon_loss = self.L1Loss(coarse_result, train_batch[2]) + self.L1Loss(out, train_batch[2])
      cons = ConsistencyLoss()
      cons_loss = cons(csa, csa_d, train_batch[2], train_batch[1])
      self.log('loss/recon_loss', recon_loss)
      total_loss += recon_loss
      self.log('loss/cons_loss', cons_loss)
      total_loss += cons_loss
      """

      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #l1_edge = self.L1Loss(other_img, train_batch[3])
      #self.log('loss/l1_edge', l1_edge)
      #total_loss += l1_edge

      # PRVS
      """
      edge_big_l1 = self.L1Loss(edge_big, train_batch[3])
      edge_small_l1 = self.L1Loss(edge_small, torch.nn.functional.interpolate(train_batch[3], scale_factor = 0.5))
      self.log('loss/edge_big_l1', edge_big_l1)
      total_loss += edge_big_l1
      self.log('loss/edge_small_l1', edge_small_l1)
      total_loss += edge_small_l1
      """ 

      # FRRN
      """
      mid_l1_loss = 0
      for idx in range(len(mid_x) - 1):
          mid_l1_loss += self.L1Loss(mid_x[idx] * mid_mask[idx], train_batch[2] * mid_mask[idx])
      self.log('loss/mid_l1_loss', mid_l1_loss)
      total_loss += mid_l1_loss
      """

      #self.log('loss/g_loss', total_loss)
      writer.add_scalar('loss/g_loss', total_loss, self.trainer.global_step)

      #return total_loss
      #########################








      # train discriminator
      # resizing input if needed
      #train_batch[2] = torch.nn.functional.interpolate(train_batch[2], (128,128), align_corners=False, mode='bilinear')
      #out = torch.nn.functional.interpolate(out, (128,128), align_corners=False, mode='bilinear')

      # targets are created on the same device and with the same dtype as out,
      # so no global cuda flag is needed
      valid = torch.ones_like(out)
      fake = torch.zeros_like(out)
      dis_real_loss = self.MSELoss(squeeze2, valid)
      dis_fake_loss = self.MSELoss(out, fake)

      d_loss = (dis_real_loss + dis_fake_loss) / 2
      #self.log('loss/d_loss', d_loss)
      writer.add_scalar('loss/d_loss', d_loss, self.trainer.global_step)

      return total_loss+d_loss

  def configure_optimizers(self):
      #optimizer = torch.optim.Adam(self.netG.parameters(), lr=2e-3)
      optimizer = AdamP(self.netG.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-2)
      #optimizer = SGDP(self.netG.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
      return optimizer

  def validation_step(self, train_batch, train_idx):
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path

    #########################
    # generate fake (one output generator)
    out = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    #########################
    # generate fake (two output generator)
    #out, _ = self(train_batch[0],train_batch[1])

    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])
    #########################
    # CSA
    #_, out, _, _ = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    # EdgeConnect
    # train_batch[3] = edges
    # train_batch[4] = grayscale
    #out, _ = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])

    # PRVS
    #out, _ ,_, _ = self.netG(train_batch[0], train_batch[1], train_batch[3])

    # FRRN
    #out, _, _ = self(train_batch[0], train_batch[1])

    """
    # Validation metrics work, but they need an original source image, which is
    # not implemented. Change dataloader to provide LR and HR if you want metrics.
    self.log('metrics/PSNR', self.psnr_metric(train_batch[2], out))
    self.log('metrics/SSIM', self.ssim_metric(train_batch[2], out))
    self.log('metrics/MSE', self.mse_metric(train_batch[2], out))
    self.log('metrics/LPIPS', self.PerceptualLoss(out, train_batch[2]))
    """

    validation_output = '/content/validation_output/' #@param

    # train_batch[2] can contain multiple paths, depending on the batch_size
    # data is processed as a batch; enumerate gives the index of each image inside it
    for counter, f in enumerate(train_batch[2]):
      filename = os.path.splitext(os.path.basename(f))[0]
      os.makedirs(os.path.join(validation_output, filename), exist_ok=True)
      save_image(out[counter], os.path.join(validation_output, filename, str(self.trainer.global_step) + '.png'))

  def test_step(self, train_batch, train_idx):
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path
    test_output = '/content/test_output/' #@param
    if not os.path.exists(test_output):
      os.makedirs(test_output)

    out = self(train_batch[0].unsqueeze(0),train_batch[1].unsqueeze(0))
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    save_image(out, os.path.join(test_output, os.path.splitext(os.path.basename(train_batch[2]))[0] + '.png'))


# Data (Simple lr/hr folder loader (for 3-channel super resolution))

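Use this loader if you want to provide LR and HR images as separate folders; validation works the same way. An LR image is looked up by the file name of its HR image, so file names must match between the two folders. A random crop is applied if the HR image is larger than hr_size.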
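
The paired random crop described above can be sketched as follows. This is a minimal illustration, not the loader's exact code: `paired_random_crop` is a hypothetical helper, and it assumes that `hr_size` is divisible by `scale`, that the HR image is exactly `scale` times the size of the LR image, and that both are at least as large as the requested crop. The corner is drawn on the LR grid so that both patches cover the same region.

```python
import random

import numpy as np


def paired_random_crop(hr, lr, hr_size, scale):
    """Crop an hr_size x hr_size patch from hr and the matching patch from lr."""
    lr_size = hr_size // scale
    # pick the top-left corner on the LR grid so both crops align exactly
    top = random.randint(0, lr.shape[0] - lr_size)
    left = random.randint(0, lr.shape[1] - lr_size)
    lr_patch = lr[top:top + lr_size, left:left + lr_size]
    hr_patch = hr[top * scale:top * scale + hr_size,
                  left * scale:left * scale + hr_size]
    return hr_patch, lr_patch


# synthetic pair: a 512x384 HR image and its 4x smaller LR counterpart
hr = np.zeros((512, 384, 3), dtype=np.uint8)
lr = np.zeros((128, 96, 3), dtype=np.uint8)
hr_patch, lr_patch = paired_random_crop(hr, lr, hr_size=256, scale=4)
print(hr_patch.shape, lr_patch.shape)
```

Rows are bounded by `shape[0]` and columns by `shape[1]`, which matters for non-square images.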

In [None]:
#@title data.py
import os
import glob
import random

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset


class DS(Dataset):
    def __init__(self, lr_path, hr_path, hr_size, scale):
        self.samples = []
        for hr_path, _, fnames in sorted(os.walk(hr_path)):
            for fname in sorted(fnames):
                path = os.path.join(hr_path, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + hr_path)
        self.hr_size = hr_size
        self.scale = scale
        self.lr_path = lr_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # getting hr image
        hr_path = self.samples[index]
        hr_image = cv2.imread(hr_path)
        hr_image = cv2.cvtColor(hr_image, cv2.COLOR_BGR2RGB)

        # getting lr image
        lr_path = os.path.join(self.lr_path, os.path.basename(hr_path))
        lr_image = cv2.imread(lr_path)
        lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)

        # checking for hr_size limitation
        if hr_image.shape[0] > self.hr_size or hr_image.shape[1] > self.hr_size:
          # image too big, random crop (rows are bounded by shape[0], columns by shape[1])
          random_pos1 = random.randint(0, max(0, hr_image.shape[0]-self.hr_size))
          random_pos2 = random.randint(0, max(0, hr_image.shape[1]-self.hr_size))

          # overwrite hr_image/lr_image so the crop is actually used below
          hr_image = hr_image[random_pos1:random_pos1+self.hr_size, random_pos2:random_pos2+self.hr_size]
          # crop the matching region from the lr image
          lr_size = self.hr_size // self.scale
          lr_pos1 = random_pos1 // self.scale
          lr_pos2 = random_pos2 // self.scale
          lr_image = lr_image[lr_pos1:lr_pos1+lr_size, lr_pos2:lr_pos2+lr_size]

        # to tensor
        hr_image = torch.from_numpy(hr_image).permute(2, 0, 1)/255
        lr_image = torch.from_numpy(lr_image).permute(2, 0, 1)/255

        return lr_image, hr_image


class DS_val(Dataset):
    def __init__(self, lr_path, hr_path):
        self.samples = []
        for hr_path, _, fnames in sorted(os.walk(hr_path)):
            for fname in sorted(fnames):
                path = os.path.join(hr_path, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + hr_path)

        self.lr_path = lr_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # getting hr image
        hr_path = self.samples[index]
        hr_image = cv2.imread(hr_path)

        # getting lr image
        lr_path = os.path.join(self.lr_path, os.path.basename(hr_path))
        lr_image = cv2.imread(lr_path)

        # to tensor
        hr_image = torch.from_numpy(hr_image).permute(2, 0, 1)/255
        lr_image = torch.from_numpy(lr_image).permute(2, 0, 1)/255

        return lr_image, hr_image, lr_path



In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    def __init__(self, dir_lr: str = './',  dir_hr: str = './', val_lr: str = './', val_hr: str = './', batch_size: int = 5, num_workers: int = 2, hr_size = 256, scale = 4):
        super().__init__()

        self.dir_lr = dir_lr
        self.dir_hr = dir_hr

        self.val_lr = val_lr
        self.val_hr = val_hr

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.hr_size = hr_size
        self.scale = scale

    def setup(self, stage=None):
        self.DFNetdataset_train = DS(lr_path=self.dir_lr, hr_path=self.dir_hr, hr_size = self.hr_size, scale = self.scale)
        self.DFNetdataset_validation = DS_val(self.val_lr, self.val_hr)
        self.DFNetdataset_test = DS_val(self.val_lr, self.val_hr)

    def train_dataloader(self):
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

# Data (3x3 400px, batch (for 1-channel super resolution))
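This is for ESRGAN. It reads images as single-channel (grayscale) images and creates LR/HR pairs on the fly by downscaling each HR tile with a randomly chosen interpolation method. Each training image is expected to be a 3x3 grid of 400px tiles and is used for 4x training. Validation uses LR/HR folders.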
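
As a rough sketch of the tile selection: the loader below draws random grid positions in a retry loop until it has enough distinct ones. The example here swaps in `random.sample`, which gives distinct positions directly. `sample_tile_slices` is a hypothetical helper that is not part of the notebook, and it only computes the slice bounds, without reading or resizing any image.

```python
import random


def sample_tile_slices(amount_tiles, tile_size, count):
    """Pick `count` distinct tiles from an amount_tiles x amount_tiles grid."""
    positions = [(row, col) for row in range(amount_tiles) for col in range(amount_tiles)]
    # random.sample never repeats an element, so no retry loop is needed
    chosen = random.sample(positions, count)
    return [
        (slice(row * tile_size, (row + 1) * tile_size),
         slice(col * tile_size, (col + 1) * tile_size))
        for row, col in chosen
    ]


tiles = sample_tile_slices(amount_tiles=3, tile_size=400, count=3)
for rows, cols in tiles:
    # with a NumPy image, the tile would be image[rows, cols]
    print(rows, cols)
```

`random.sample` raises a `ValueError` if `count` exceeds the number of tiles, whereas a retry loop would never terminate in that case.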

In [None]:
#@title data.py
import os
import glob
import random

import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset


class DS(Dataset):
    def __init__(self, root, transform=None, size=256, batch_size_DL = 3, scale=4, image_size=400, amount_tiles=3):
        self.samples = []
        for root, _, fnames in sorted(os.walk(root)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.transform = transform

        #self.size = size
        self.image_size = image_size # how big one tile is
        self.scale = scale
        self.batch_size = batch_size_DL
        self.interpolation_method = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
        self.amount_tiles = amount_tiles

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample_path = self.samples[index]
        sample = cv2.imread(sample_path, cv2.IMREAD_GRAYSCALE)

        
        pos_total = []

        self.total_size = 0 # the current amount of images that got a random position

        while True:
          # determine random position
          x_rand = random.randint(0,self.amount_tiles-1)
          y_rand = random.randint(0,self.amount_tiles-1)
          
          pos_rand = [x_rand, y_rand]

          if pos_rand not in pos_total:
            pos_total.append(pos_rand)
            self.total_size += 1

          # return batchsize
          if self.total_size == self.batch_size:
            break

        self.total_size = 0 # counter for making sure array gets appended if processed images > 1
        
        for i in pos_total:
          # first tile: create the batch tensors; later tiles are concatenated onto them
          if self.total_size == 0:
            # cropping from hr image
            image_hr = sample[i[0]*self.image_size:(i[0]+1)*self.image_size, i[1]*self.image_size:(i[1]+1)*self.image_size]
            # creating lr on the fly
            #image_lr = cv2.resize(image_hr, (int(self.image_size/self.scale), int(self.image_size/self.scale)), interpolation=random.choice(self.interpolation_method))
            image_lr = cv2.resize(image_hr, (int(self.image_size/self.scale), int(self.image_size/self.scale)), interpolation=random.choice(self.interpolation_method))

            """
            print("-----------------------")
            print(i[0]*(self.image_size/self.scale))
            print((i[0]+1)*(self.image_size/self.scale))
             
            print(i[1]*(self.image_size/self.scale))
            print((i[1]+1)*(self.image_size/self.scale))
            """
            #image_lr = image_lr[i[0]*(self.image_size/self.scale):(i[0]+1)*(self.image_size/self.scale), i[1]*(self.image_size/self.scale):(i[1]+1)*(self.image_size/self.scale)]


            # creating torch tensor
            image_hr = torch.from_numpy(image_hr).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255
            image_lr = torch.from_numpy(image_lr).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255
            

            # if edges are required
            """
            grayscale = cv2.cvtColor(np.array(sample_add), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(grayscale,100,150)
            grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
            edges = torch.from_numpy(edges).unsqueeze(0)
            """

            self.total_size += 1
          else:
            # cropping from hr image
            image_hr2 = sample[i[0]*self.image_size:(i[0]+1)*self.image_size, i[1]*self.image_size:(i[1]+1)*self.image_size]
            # creating lr on the fly
            #image_lr2 = cv2.resize(image_hr2, (int(self.image_size/self.scale), int(self.image_size/self.scale)), interpolation=random.choice(self.interpolation_method))
            #image_lr2 = image_lr2[i[0]*(self.image_size/self.scale):(i[0]+1)*(self.image_size/self.scale), i[1]*(self.image_size/self.scale):(i[1]+1)*(self.image_size/self.scale)]
            image_lr2 = cv2.resize(image_hr2, (int(self.image_size/self.scale), int(self.image_size/self.scale)), interpolation=random.choice(self.interpolation_method))



            # if edges are required
            """
            grayscale = cv2.cvtColor(np.array(sample_add2), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(grayscale,100,150)
            grayscale = torch.from_numpy(grayscale).unsqueeze(0)/255
            edges = torch.from_numpy(edges).unsqueeze(0)
            """
            # creating torch tensor
            image_hr2 = torch.from_numpy(image_hr2).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255
            image_hr = torch.cat((image_hr, image_hr2), dim=0)
            
            image_lr2 = torch.from_numpy(image_lr2).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255
            image_lr = torch.cat((image_lr, image_lr2), dim=0)

        return image_lr, image_hr


class DS_val(Dataset):
    def __init__(self, lr_path, hr_path):
        self.samples = []
        # use a separate loop variable so the hr_path argument is not overwritten
        for root, _, fnames in sorted(os.walk(hr_path)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                self.samples.append(path)
        if len(self.samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + hr_path)

        self.lr_path = lr_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # getting hr image
        hr_path = self.samples[index]
        hr_image = cv2.imread(hr_path, cv2.IMREAD_GRAYSCALE)

        # getting lr image
        lr_path = os.path.join(self.lr_path, os.path.basename(hr_path))
        lr_image = cv2.imread(lr_path, cv2.IMREAD_GRAYSCALE)



        hr_image = torch.from_numpy(hr_image).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255
        lr_image = torch.from_numpy(lr_image).unsqueeze(2).permute(2, 0, 1).unsqueeze(0)/255


        return lr_image, hr_image, lr_path
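
The tile cropping used by the training dataset can be looked at in isolation. The following is a minimal sketch, assuming a grid position `(row, col)` like the entries of `pos_total`; the helper name `tile_bounds` is hypothetical and not part of this notebook.

```python
def tile_bounds(pos, image_size, scale):
    """Return HR slice bounds and the LR tile size for one grid position.

    pos is a (row, col) grid index; this helper is an illustrative assumption.
    """
    row, col = pos
    # HR tile: rows [top, bottom) and columns [left, right) of the full image
    top, bottom = row * image_size, (row + 1) * image_size
    left, right = col * image_size, (col + 1) * image_size
    # the LR tile is created by downscaling the HR tile by `scale`
    lr_size = int(image_size / scale)
    return (top, bottom, left, right), (lr_size, lr_size)


hr_bounds, lr_shape = tile_bounds((1, 2), image_size=400, scale=4)
print(hr_bounds, lr_shape)
```

With `image_size=400` and `scale=4`, every HR tile is 400x400 pixels and its LR counterpart 100x100 pixels.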



In [None]:
#@title dataloader.py
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class DFNetDataModule(pl.LightningDataModule):
    def __init__(self, training_path: str = './', val_lr: str = './', val_hr: str = './', batch_size: int = 5, batch_size_DL: int = 2, num_workers: int = 2, hr_size=256, scale = 4, image_size = 400, amount_tiles=3):
        super().__init__()
        self.training_dir = training_path
        self.val_lr = val_lr
        self.val_hr = val_hr

        #self.test_dir = test_path
        self.batch_size = batch_size
        self.batch_size_DL = batch_size_DL
        self.num_workers = num_workers
        self.hr_size = hr_size
        self.scale = scale
        self.image_size = image_size
        self.amount_tiles = amount_tiles

    def setup(self, stage=None):
        self.DFNetdataset_train = DS(self.training_dir, self.hr_size, batch_size_DL = self.batch_size_DL, scale=self.scale, image_size = self.image_size, amount_tiles = self.amount_tiles)
        self.DFNetdataset_validation = DS_val(self.val_lr, self.val_hr)
        self.DFNetdataset_test = DS_val(self.val_lr, self.val_hr)

    def train_dataloader(self):
        return DataLoader(self.DFNetdataset_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.DFNetdataset_validation, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.DFNetdataset_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
#@title CustomTrainClass.py (adding squeeze)
from vic.loss import CharbonnierLoss, GANLoss, GradientPenaltyLoss, HFENLoss, TVLoss, GradientLoss, ElasticLoss, RelativeL1, L1CosineSim, ClipL1, MaskedL1Loss, MultiscalePixelLoss, FFTloss, OFLoss, L1_regularization, ColorLoss, AverageLoss, GPLoss, CPLoss, SPL_ComputeWithTrace, SPLoss, Contextual_Loss, StyleLoss
from vic.perceptual_loss import PerceptualLoss
from metrics import *
from torchvision.utils import save_image
from torch.autograd import Variable

from tensorboardX import SummaryWriter
logdir='/content/'
writer = SummaryWriter(logdir=logdir)

from adamp import AdamP
#from adamp import SGDP

class CustomTrainClass(pl.LightningModule):
  def __init__(self):
    super().__init__()
    ############################
    # generators with one output, no AMP means nan loss during training

    self.netG = RRDBNet(in_nc=1, out_nc=1, nf=64, nb=23, gc=32, upscale=4, norm_type=None,
                act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                finalact=None, gaussian_noise=True, plus=False, 
                nr=3)


    # DFNet
    #self.netG = DFNet(c_img=3, c_mask=1, c_alpha=3,
    #        mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
    #        en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3, 3, 3, 3, 3, 3, 3, 3],
    #        blend_layers=[0, 1, 2, 3, 4, 5], conv_type='partial')
    
    # AdaFill
    #self.netG = InpaintNet()

    # MEDFE (batch_size: 1, no AMP)
    #self.netG = MEDFEGenerator()

    # RFR
    # conv_type = partial or deform
    #self.netG = RFRNet(conv_type='partial')

    # LBAM
    #self.netG = LBAMModel(inputChannels=4, outputChannels=3)

    # DMFN
    #self.netG = InpaintingGenerator(in_nc=4, out_nc=3,nf=64,n_res=8,
    #      norm='in', activation='relu')

    # partial
    #self.netG = Model()

    # RN
    #self.netG = G_Net(input_channels=3, residual_blocks=8, threshold=0.8)
    # using rn init to avoid errors
    #RN_arch = rn_initialize_weights(self.netG, scale=0.1)


    ############################

    # generators with two outputs

    # deepfillv1
    #self.netG = InpaintSANet()

    # deepfillv2
    # conv_type = partial or deform
    #self.netG = GatedGenerator(in_channels=4, out_channels=3, 
    #  latent_channels=64, pad_type='zero', activation='lrelu', norm='in', conv_type = 'partial')

    # Adaptive
    # [Warning] Adaptive does not like PatchGAN, Multiscale and ResNet.
    #self.netG = PyramidNet(in_channels=3, residual_blocks=1, init_weights='True')

    ############################
    # exotic generators

    # Pluralistic
    #self.netG = PluralisticGenerator(ngf_E=opt_net['ngf_E'], z_nc_E=opt_net['z_nc_E'], img_f_E=opt_net['img_f_E'], layers_E=opt_net['layers_E'], norm_E=opt_net['norm_E'], activation_E=opt_net['activation_E'],
    #            ngf_G=opt_net['ngf_G'], z_nc_G=opt_net['z_nc_G'], img_f_G=opt_net['img_f_G'], L_G=opt_net['L_G'], output_scale_G=opt_net['output_scale_G'], norm_G=opt_net['norm_G'], activation_G=opt_net['activation_G'])

    
    # EdgeConnect
    #conv_type_edge: 'normal' # normal | partial | deform (has no spectral_norm)
    #self.netG = EdgeConnectModel(residual_blocks_edge=8,
    #        residual_blocks_inpaint=8, use_spectral_norm=True,
    #        conv_type_edge='normal', conv_type_inpaint='normal')

    # FRRN
    #self.netG = FRRNet()

    # PRVS
    #self.netG = PRVSNet()

    # CSA
    #self.netG = InpaintNet(c_img=3, norm='instance', act_en='leaky_relu', 
    #                           act_de='relu')


    weights_init(self.netG, 'kaiming')
    ############################


    # discriminators
    # size refers to input shape of tensor

    self.netD = context_encoder()

    # VGG
    #self.netD = Discriminator_VGG(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN')
    #self.netD = Discriminator_VGG_fea(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D',
    #     arch='ESRGAN', spectral_norm=False, self_attention = False, max_pool=False, poolsize = 4)
    #self.netD = Discriminator_VGG_128_SN()
    #self.netD = VGGFeatureExtractor(feature_layer=34,use_bn=False,use_input_norm=True,device=torch.device('cpu'),z_norm=False)

    # PatchGAN
    #self.netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #    use_sigmoid=False, getIntermFeat=False, patch=True, use_spectral_norm=False)

    # Multiscale
    #self.netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #             use_sigmoid=False, num_D=3, getIntermFeat=False)

    # ResNet
    #self.netD = Discriminator_ResNet_128(in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA')
    #self.netD = ResNet101FeatureExtractor(use_input_norm=True, device=torch.device('cpu'), z_norm=False)
    
    # MINC
    #self.netD = MINCNet()

    # Pixel
    #self.netD = PixelDiscriminator(input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d)

    # EfficientNet
    #from efficientnet_pytorch import EfficientNet
    #self.netD = EfficientNet.from_pretrained('efficientnet-b0')

    # ResNeSt
    # ["resnest50", "resnest101", "resnest200", "resnest269"]
    #self.netD = resnest50(pretrained=True)

    # need fixing
    #FileNotFoundError: [Errno 2] No such file or directory: '../experiments/pretrained_models/VGG16minc_53.pth'
    #self.netD = MINCFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True, device=torch.device('cpu'))

    # Transformer (Warning: uses own init!)
    #self.netD  = TranformerDiscriminator(img_size=256, patch_size=1, in_chans=3, num_classes=1, embed_dim=64, depth=7,
    #             num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
    #             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm)
    

    weights_init(self.netD, 'kaiming')


    # loss functions
    self.l1 = nn.L1Loss()
    l_hfen_type = L1CosineSim()
    self.HFENLoss = HFENLoss(loss_f=l_hfen_type, kernel='log', kernel_size=15, sigma = 2.5, norm = False)
    self.ElasticLoss = ElasticLoss(a=0.2, reduction='mean')
    self.RelativeL1 = RelativeL1(eps=.01, reduction='mean')
    self.L1CosineSim = L1CosineSim(loss_lambda=5, reduction='mean')
    self.ClipL1 = ClipL1(clip_min=0.0, clip_max=10.0)
    self.FFTloss = FFTloss(loss_f = torch.nn.L1Loss, reduction='mean')
    self.OFLoss = OFLoss()
    self.GPLoss = GPLoss(trace=False, spl_denorm=False)
    self.CPLoss = CPLoss(rgb=True, yuv=True, yuvgrad=True, trace=False, spl_denorm=False, yuv_denorm=False)
    self.StyleLoss = StyleLoss()
    self.TVLoss = TVLoss(tv_type='tv', p = 1)
    self.PerceptualLoss = PerceptualLoss(model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], model_path=None)
    layers_weights = {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    self.Contextual_Loss = Contextual_Loss(layers_weights, crop_quarter=False, max_1d_size=100,
        distance_type = 'cosine', b=1.0, band_width=0.5,
        use_vgg = True, net = 'vgg19', calc_type = 'regular')

    self.MSELoss = torch.nn.MSELoss()
    self.L1Loss = nn.L1Loss()

    # metrics
    self.psnr_metric = PSNR()
    self.ssim_metric = SSIM()
    self.ae_metric = AE()
    self.mse_metric = MSE()


  # inpainting
  #def forward(self, image, masks):
  #    return self.netG(image, masks)

  # super resolution
  def forward(self, image):
    return self.netG(image)

  #def adversarial_loss(self, y_hat, y):
  #    return F.binary_cross_entropy(y_hat, y)


  def training_step(self, train_batch, batch_idx):
      # inpainting:
      # train_batch[0][0] = batch_size
      # train_batch[0] = masked
      # train_batch[1] = mask
      # train_batch[2] = original

      # super resolution
      # train_batch[0] = lr
      # train_batch[1] = hr

      # train generator
      ############################
      # generate fake (1 output)
      squeeze0 = torch.squeeze(train_batch[0], 0)
      squeeze1 = torch.squeeze(train_batch[1], 0)
      #squeeze2 = torch.squeeze(train_batch[2], 0)
      #out = self(squeeze0,squeeze1)

      # masking, taking original content from HR
      #out = squeeze0*(squeeze1)+out*(1-squeeze1)

      ############################
      # generate fake (2 outputs)
      #out, other_img = self(train_batch[0],train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # exotic generators
      # CSA
      #coarse_result, out, csa, csa_d = self(train_batch[0],train_batch[1])
      
      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #out, other_img = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])
      
      # PRVS
      #out, _ ,edge_small, edge_big = self.netG(train_batch[0], train_batch[1], train_batch[3])

      # FRRN
      #out, mid_x, mid_mask = self(train_batch[0], train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # ESRGAN
      out = self.netG(squeeze0, squeeze1)

      ############################
      # loss calculation
      total_loss = 0
      """
      HFENLoss_forward = self.HFENLoss(out, train_batch[0])
      total_loss += HFENLoss_forward
      ElasticLoss_forward = self.ElasticLoss(out, train_batch[0])
      total_loss += ElasticLoss_forward
      RelativeL1_forward = self.RelativeL1(out, train_batch[0])
      total_loss += RelativeL1_forward
      """
      #print("out")
      #print(out.shape)
      #print("squeeze2")
      #print(squeeze2.shape)
      L1CosineSim_forward = 5*self.L1CosineSim(out, squeeze1)
      total_loss += L1CosineSim_forward
      #self.log('loss/L1CosineSim', L1CosineSim_forward)
      writer.add_scalar('loss/L1CosineSim', L1CosineSim_forward, self.trainer.global_step)

      """
      ClipL1_forward = self.ClipL1(out, train_batch[0])
      total_loss += ClipL1_forward
      FFTloss_forward = self.FFTloss(out, train_batch[0])
      total_loss += FFTloss_forward
      OFLoss_forward = self.OFLoss(out)
      total_loss += OFLoss_forward
      GPLoss_forward = self.GPLoss(out, train_batch[0])
      total_loss += GPLoss_forward
      
      CPLoss_forward = 0.1*self.CPLoss(out, train_batch[0])
      total_loss += CPLoss_forward
      

      Contextual_Loss_forward = self.Contextual_Loss(out, train_batch[0])
      total_loss += Contextual_Loss_forward
      self.log('loss/contextual', Contextual_Loss_forward)
      """

      #style_forward = 240*self.StyleLoss(out, train_batch[2])
      #total_loss += style_forward
      #self.log('loss/style', style_forward)

      tv_forward = 0.0000005*self.TVLoss(out)
      total_loss += tv_forward
      #self.log('loss/tv', tv_forward)
      writer.add_scalar('loss/tv', tv_forward, self.trainer.global_step)

      perceptual_forward = 2*self.PerceptualLoss(out, squeeze1)
      total_loss += perceptual_forward
      #self.log('loss/perceptual', perceptual_forward)
      writer.add_scalar('loss/perceptual', perceptual_forward, self.trainer.global_step)







      #########################
      # exotic loss

      # if the model has two outputs, also calculate a loss for the second image
      # example with just an L1 loss
      
      #l1_stage1 = self.L1Loss(other_img, train_batch[0])
      #self.log('loss/l1_stage1', l1_stage1)
      #total_loss += l1_stage1


      # CSA Loss
      """
      recon_loss = self.L1Loss(coarse_result, train_batch[2]) + self.L1Loss(out, train_batch[2])
      cons = ConsistencyLoss()
      cons_loss = cons(csa, csa_d, train_batch[2], train_batch[1])
      self.log('loss/recon_loss', recon_loss)
      total_loss += recon_loss
      self.log('loss/cons_loss', cons_loss)
      total_loss += cons_loss
      """

      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #l1_edge = self.L1Loss(other_img, train_batch[3])
      #self.log('loss/l1_edge', l1_edge)
      #total_loss += l1_edge

      # PRVS
      """
      edge_big_l1 = self.L1Loss(edge_big, train_batch[3])
      edge_small_l1 = self.L1Loss(edge_small, torch.nn.functional.interpolate(train_batch[3], scale_factor = 0.5))
      self.log('loss/edge_big_l1', edge_big_l1)
      total_loss += edge_big_l1
      self.log('loss/edge_small_l1', edge_small_l1)
      total_loss += edge_small_l1
      """ 

      # FRRN
      """
      mid_l1_loss = 0
      for idx in range(len(mid_x) - 1):
          mid_l1_loss += self.L1Loss(mid_x[idx] * mid_mask[idx], train_batch[2] * mid_mask[idx])
      self.log('loss/mid_l1_loss', mid_l1_loss)
      total_loss += mid_l1_loss
      """

      #self.log('loss/g_loss', total_loss)
      writer.add_scalar('loss/g_loss', total_loss, self.trainer.global_step)

      #return total_loss
      #########################








      # train discriminator
      # resizing input if needed
      #train_batch[2] = torch.nn.functional.interpolate(train_batch[2], (128,128), align_corners=False, mode='bilinear')
      #out = torch.nn.functional.interpolate(out, (128,128), align_corners=False, mode='bilinear')

      Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
      valid = Variable(Tensor(out.shape).fill_(1.0), requires_grad=False)
      fake = Variable(Tensor(out.shape).fill_(0.0), requires_grad=False)
      dis_real_loss = self.MSELoss(squeeze1, valid)
      dis_fake_loss = self.MSELoss(out, fake)

      d_loss = (dis_real_loss + dis_fake_loss) / 2
      #self.log('loss/d_loss', d_loss)
      writer.add_scalar('loss/d_loss', d_loss, self.trainer.global_step)

      return total_loss+d_loss

  def configure_optimizers(self):
      #optimizer = torch.optim.Adam(self.netG.parameters(), lr=2e-3)
      optimizer = AdamP(self.netG.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-2)
      #optimizer = SGDP(self.netG.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
      return optimizer

  def validation_step(self, train_batch, train_idx):
    # inpainting
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path

    # super resolution
    # train_batch[0] = lr
    # train_batch[1] = hr
    # train_batch[2] = lr_path

    #########################
    # generate fake (one output generator)
    #out = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    #########################
    # generate fake (two output generator)
    #out, _ = self(train_batch[0],train_batch[1])

    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])
    #########################
    # CSA
    #_, out, _, _ = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    # EdgeConnect
    # train_batch[3] = edges
    # train_batch[4] = grayscale
    #out, _ = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])

    # PRVS
    #out, _ ,_, _ = self.netG(train_batch[0], train_batch[1], train_batch[3])

    # FRRN
    #out, _, _ = self(train_batch[0], train_batch[1])

    # ESRGAN
    out = self.netG(train_batch[0].squeeze(0), train_batch[1].squeeze(0))


    # Validation metrics work, but they need an original source image
    #self.log('metrics/PSNR', self.psnr_metric(train_batch[1], out))
    #self.log('metrics/SSIM', self.ssim_metric(train_batch[1], out))
    #self.log('metrics/MSE', self.mse_metric(train_batch[1], out))
    #self.log('metrics/LPIPS', self.PerceptualLoss(out, train_batch[1]))

    validation_output = '/content/validation_output/' #@param

    # train_batch[2] can contain multiple paths, depending on the batch_size
    # data is processed as a batch; enumerate provides the index of each file's output
    for counter, f in enumerate(train_batch[2]):
      filename = os.path.splitext(os.path.basename(f))[0]
      os.makedirs(os.path.join(validation_output, filename), exist_ok=True)
      save_image(out[counter], os.path.join(validation_output, filename, str(self.trainer.global_step) + '.png'))

  def test_step(self, train_batch, train_idx):
    # inpainting
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path

    # super resolution
    # train_batch[0] = lr
    # train_batch[1] = hr

    test_output = '/content/test_output/' #@param
    if not os.path.exists(test_output):
      os.makedirs(test_output)

    out = self(train_batch[0].unsqueeze(0),train_batch[1].unsqueeze(0))
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    save_image(out, os.path.join(test_output, os.path.splitext(os.path.basename(train_batch[2]))[0] + '.png'))
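
The way `validation_step` derives one output file per input path can be sketched without any tensors. This is a minimal illustration under stated assumptions: the helper name `validation_output_paths` and the example paths are hypothetical, and only the path logic (one folder per input file, one PNG per global step) follows the code above.

```python
import os


def validation_output_paths(paths, out_dir, global_step):
    """Map each input path of a batch to (batch index, output file path)."""
    results = []
    # enumerate keeps the batch index in step with the file path
    for index, path in enumerate(paths):
        name = os.path.splitext(os.path.basename(path))[0]
        results.append((index, os.path.join(out_dir, name, f"{global_step}.png")))
    return results


pairs = validation_output_paths(
    ["/data/lr/a.png", "/data/lr/b.jpg"], "/content/validation_output/", 500
)
for index, target in pairs:
    print(index, target)
```

Each batch index is paired with its own folder, so the image at position `index` of the generator output would be written to the matching path.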


# Model

In [None]:
#@title checkpoint.py
#https://github.com/PyTorchLightning/pytorch-lightning/issues/2534
import os
import pytorch_lightning as pl
import torch  # needed for torch.save below

class CheckpointEveryNSteps(pl.Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
        self,
        save_step_frequency,
        prefix="Checkpoint",
        use_modelcheckpoint_filename=False,
        save_path = '/content/'
    ):
        """
        Args:
            save_step_frequency: how often to save in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
            save_path: directory the checkpoint files are written to
        """
        self.save_step_frequency = save_step_frequency
        self.prefix = prefix
        self.use_modelcheckpoint_filename = use_modelcheckpoint_filename
        self.save_path = save_path

    def on_batch_end(self, trainer: pl.Trainer, _):
        """ Check if we should save a checkpoint after every train batch """
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        if global_step % self.save_step_frequency == 0:
            if self.use_modelcheckpoint_filename:
                filename = trainer.checkpoint_callback.filename
            else:
                filename = f"{self.prefix}_{epoch}_{global_step}.ckpt"
            #ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename)
            ckpt_path = os.path.join(self.save_path, filename)
            trainer.save_checkpoint(ckpt_path)

            # saving normal .pth models
            #https://github.com/PyTorchLightning/pytorch-lightning/issues/4114
            torch.save(trainer.model.netG.state_dict(), os.path.join(self.save_path, f"{self.prefix}_{epoch}_{global_step}_G.pth"))
            torch.save(trainer.model.netD.state_dict(), os.path.join(self.save_path, f"{self.prefix}_{epoch}_{global_step}_D.pth"))

            # run validation once checkpoint was made
            trainer.run_evaluation()


    def on_train_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        global_step = trainer.global_step
        ckpt_path = os.path.join(self.save_path, f"{self.prefix}_{epoch}_{global_step}.ckpt")
        trainer.save_checkpoint(ckpt_path)
        print("Checkpoint " + f"{self.prefix}_{epoch}_{global_step}.ckpt" + " saved.")

#Trainer(callbacks=[CheckpointEveryNSteps()])
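
The step-based schedule of the callback boils down to a modulo check plus a filename pattern. The following is a minimal sketch of that logic only; the helper names `checkpoint_filename` and `steps_with_checkpoint` are hypothetical, and it assumes the step counter starts at 0.

```python
def checkpoint_filename(prefix, epoch, global_step):
    """Build the checkpoint name in the same pattern as the callback."""
    return f"{prefix}_{epoch}_{global_step}.ckpt"


def steps_with_checkpoint(total_steps, save_step_frequency):
    """List the global steps at which the modulo check would trigger a save."""
    return [step for step in range(total_steps) if step % save_step_frequency == 0]


steps = steps_with_checkpoint(total_steps=10, save_step_frequency=4)
names = [checkpoint_filename("Checkpoint", 0, step) for step in steps]
print(steps)
print(names)
```

Under that assumption step 0 also satisfies the check, so a checkpoint (and a validation run) would already happen at the very first step.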

In [None]:
#@title utils.py
from pathlib import Path

import cv2
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


def resize_like(x, target, mode='bilinear'):
    return F.interpolate(x, target.shape[-2:], mode=mode, align_corners=False)


def list2nparray(lst, dtype=None):
    """fast conversion from nested list to ndarray by pre-allocating space"""
    if isinstance(lst, np.ndarray):
        return lst
    assert isinstance(lst, (list, tuple)), 'bad type: {}'.format(type(lst))
    assert lst, 'attempt to convert empty list to np array'
    if isinstance(lst[0], np.ndarray):
        dim1 = lst[0].shape
        assert all(i.shape == dim1 for i in lst)
        if dtype is None:
            dtype = lst[0].dtype
            assert all(i.dtype == dtype for i in lst), \
                'bad dtype: {} {}'.format(dtype, set(i.dtype for i in lst))
    elif isinstance(lst[0], (int, float, complex, np.number)):
        return np.array(lst, dtype=dtype)
    else:
        dim1 = list2nparray(lst[0])
        if dtype is None:
            dtype = dim1.dtype
        dim1 = dim1.shape
    shape = [len(lst)] + list(dim1)
    rst = np.empty(shape, dtype=dtype)
    for idx, i in enumerate(lst):
        rst[idx] = i
    return rst


def get_img_list(path):
    return sorted(list(Path(path).glob('*.png'))) + \
        sorted(list(Path(path).glob('*.jpg'))) + \
        sorted(list(Path(path).glob('*.jpeg')))


def gen_miss(img, mask, output):

    imgs = get_img_list(img)
    masks = get_img_list(mask)
    print('Total images:', len(imgs), len(masks))

    out = Path(output)
    out.mkdir(parents=True, exist_ok=True)

    for i, (img, mask) in tqdm(enumerate(zip(imgs, masks))):
        path = out.joinpath('miss_%04d.png' % (i+1))
        img = cv2.imread(str(img), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(mask), cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, img.shape[:2][::-1])
        mask = mask[..., np.newaxis]
        miss = img * (mask > 127) + 255 * (mask <= 127)
        cv2.imwrite(str(path), miss)

def merge_imgs(dirs, output, row=1, gap=2, res=512):

    image_list = [get_img_list(path) for path in dirs]
    img_count = [len(image) for image in image_list]
    print('Total images:', img_count)
    assert min(img_count) > 0, 'Please check the path of empty folder.'

    output_dir = Path(output)
    output_dir.mkdir(parents=True, exist_ok=True)

    n_img = len(dirs)
    column = (n_img - 1) // row + 1
    print('Row:', row)
    print('Column:', column)

    for i, unit in tqdm(enumerate(zip(*image_list))):
        name = output_dir.joinpath('merge_%04d.png' % i)
        merge = np.ones([
            res*row + (row+1)*gap, res*column + (column+1)*gap, 3], np.uint8) * 255
        for j, img in enumerate(unit):
            r = j // column
            c = j - r * column
            img = cv2.imread(str(img), cv2.IMREAD_COLOR)
            if img.shape[:2] != (res, res):
                img = cv2.resize(img, (res, res))
            start_h, start_w = (r + 1) * gap + r * res, (c + 1) * gap + c * res
            merge[start_h: start_h + res, start_w: start_w + res] = img
        cv2.imwrite(str(name), merge)
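
The grid arithmetic in `merge_imgs` can be checked without loading any images. This is a minimal sketch that reuses the same formulas; the helper name `grid_layout` is hypothetical and not part of this notebook.

```python
def grid_layout(n_img, row=1, gap=2, res=512):
    """Return the canvas size (height, width) and the top-left offset of each image."""
    column = (n_img - 1) // row + 1
    canvas = (res * row + (row + 1) * gap, res * column + (column + 1) * gap)
    offsets = []
    for j in range(n_img):
        r = j // column      # grid row of image j
        c = j - r * column   # grid column of image j
        offsets.append(((r + 1) * gap + r * res, (c + 1) * gap + c * res))
    return canvas, offsets


canvas, offsets = grid_layout(3)
print(canvas, offsets)
```

For three folders in a single row, the canvas is 516 pixels high and 1544 pixels wide, with a 2-pixel gap around every 512x512 image.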


In [None]:
#@title metrics.py (removing lpips import)
%%writefile /content/pytorchloss/metrics.py
#https://github.com/huster-wgm/Pytorch-metrics/blob/master/metrics.py

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
  @Email:  guangmingwu2010@gmail.com \
           guozhilingty@gmail.com
  @Copyright: go-hiroaki & Chokurei
  @License: MIT
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#import lpips

eps = 1e-6

def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data

def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
    return weights


class CFMatrix(object):
    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
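                # allocate on the same device as the inputs so mask indexing works on GPU
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)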
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class OAAcc(object):
    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (tp+tn)/total
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)

        nb_tp_tn = torch.sum(y_true == y_pred).float()
        mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
        performs = None
        return mperforms, performs


class Precision(object):
    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + esp)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Recall(object):
    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + esp)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class F1Score(object):
    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + esp)
            _recall = nb_tp / (nb_tp + nb_fn + esp)
            mperforms = 2 * _precision * _recall / (_precision + _recall + esp)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + esp)
                _recall = nb_tp / (nb_tp + nb_fn + esp)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Kappa(object):
    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):

        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
            mperforms = (Po - Pe) / (1 - Pe + esp)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
                      + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class Jaccard(object):
    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + esp)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols, device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs


class MSE(object):
    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return mean_squared_error, smaller the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        return torch.mean((y_pred - y_true) ** 2)


class PSNR(object):
    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return PSNR in dB (assumes a peak signal value of 1.0), larger the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(1 / mse)


class SSIM(object):
    '''
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        gauss = torch.Tensor([math.exp(-(x - w_size//2)**2/float(2*sigma**2)) for x in range(w_size)])
        return gauss/gauss.sum()

    def create_window(self, w_size, channel=1):
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
        if torch.max(y_pred) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(y_pred) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)

        if full:
            return ret, cs
        return ret


class LPIPS(object):
    '''
    borrowed from https://github.com/richzhang/PerceptualSimilarity
    '''
    def __init__(self, cuda, des="Learned Perceptual Image Patch Similarity", version="0.1"):
        self.des = des
        self.version = version
        self.model = lpips.PerceptualLoss(model='net-lin',net='alex',use_gpu=cuda)

    def __repr__(self):
        return "LPIPS"

    def __call__(self, y_pred, y_true, normalized=True):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            normalized : change [0,1] => [-1,1] (default by LPIPS)
        return LPIPS, smaller the better
        """
        if normalized:
            y_pred = y_pred * 2.0 - 1.0
            y_true = y_true * 2.0 - 1.0
        return self.model.forward(y_pred, y_true)


class AE(object):
    """
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
        return ae.mean(1).mean(1)


if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                if not torch.cuda.is_available():
                    continue  # skip the GPU pass on CPU-only machines
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()

            print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = LPIPS(cuda)
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))

            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)

            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)

            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)

            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)

            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)

            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
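
The accuracy metrics above all reduce to the four confusion counts (tp, fp, tn, fn). As a rough sanity check, the same formulas can be sketched in plain Python on two tiny binary masks. `confusion_metrics` is a hypothetical helper written for this illustration, not part of the notebook, and its `eps` default is an assumption rather than the value used by the cell above.

```python
def confusion_metrics(y_pred, y_true, eps=1e-12):
    """Binary metrics from flat 0/1 lists (illustrative helper only)."""
    pairs = list(zip(y_pred, y_true))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    total = tp + fp + tn + fn

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    # Cohen's kappa: observed agreement Po against chance agreement Pe
    po = (tp + tn) / total
    pe = ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / total ** 2
    kappa = (po - pe) / (1 - pe + eps)

    # Jaccard: intersection / union, where union = tp + fp + fn
    jaccard = tp / (tp + fp + fn + eps)
    return {"precision": precision, "recall": recall, "f1": f1,
            "kappa": kappa, "jaccard": jaccard}


y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 0, 1, 0, 1]
scores = confusion_metrics(y_pred, y_true)
for name, value in scores.items():
    print("{} ==> {:.3f}".format(name, value))
```

For these two masks the counts are tp=3, fp=1, tn=3, fn=1, so precision, recall and F1 come out at 0.75, kappa at 0.5 and Jaccard at 0.6.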


In [None]:
#@title init.py
import torch.nn.init as init

def weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    #Initialize network weights.
    #Parameters:
    #    net (network)       -- network to be initialized
    #    init_type (str)     -- the name of an initialization method: normal | xavier | kaiming | orthogonal
    #    init_gain (float)   -- scaling factor for normal, xavier and orthogonal.

    def init_func(m):
        classname = m.__class__.__name__

        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            init.normal_(m.weight, 0, 0.01)
            init.constant_(m.bias, 0)

    # Apply the initialization function <init_func>
    print('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)
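
`weights_init` relies on `net.apply`, which calls the given function on every submodule and then on the module itself. A minimal sketch of that pattern follows, assuming a toy two-layer network; `init_conv_kaiming` and the layer sizes are made up for this example and only mirror the `'kaiming'` branch above.

```python
import torch.nn as nn
import torch.nn.init as init


def init_conv_kaiming(m):
    # Same check as init_func: match Conv* layers by class name.
    if hasattr(m, 'weight') and m.__class__.__name__.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')


# Toy network, only here to show which modules get visited.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
net.apply(init_conv_kaiming)

visited = [type(m).__name__ for m in net.modules()]
print(visited)
print('first conv weight std: {:.3f}'.format(net[0].weight.std().item()))
```

Only the two `Conv2d` layers are re-initialized here; the container and the activation have no `weight` attribute and are skipped.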

In [None]:
#@title block.py
from collections import OrderedDict

import torch
import torch.nn as nn
#from models.modules.architectures.convolutions.partialconv2d import PartialConv2d #TODO
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d
#from models.networks import weights_init_normal, weights_init_xavier, weights_init_kaiming, weights_init_orthogonal


####################
# Basic blocks
####################

# Swish activation function
def swish_func(x, beta=1.0):
    """
    "Swish: a Self-Gated Activation Function"
    Searching for Activation Functions (https://arxiv.org/abs/1710.05941)
    
    If beta=1, applies the Sigmoid Linear Unit (SiLU) function element-wise.
    If beta=0, Swish becomes the scaled linear function f(x) = x/2.
    As beta -> ∞, the sigmoid component approaches a 0-1 function
      (unit step), and multiplying that by x gives f(x) = max(0, x),
      so Swish approaches the ReLU function.
    
    Including beta, Swish can be loosely viewed as a smooth function that 
      nonlinearly interpolates between the linear function and ReLU.
      The degree of interpolation can be controlled by the model if beta is 
      set as a trainable parameter.
    
    Alt: 1.78718727865 * (x * sigmoid(x) - 0.20662096414)
    """
    
    # In-place implementation, may consume less GPU memory: 
    """ 
    result = x.clone()
    torch.sigmoid_(beta*x)
    x *= result
    return x
    #"""
    
    # Normal out-of-place implementation:
    #"""
    return x * torch.sigmoid(beta * x)
    #"""
    
# Swish module
class Swish(nn.Module):
    
    __constants__ = ['beta', 'slope', 'inplace']
    
    def __init__(self, beta=1.0, slope=1.67653251702, inplace=False):
        """
        Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input
        """
        super(Swish, self).__init__()
        self.inplace = inplace
        # self.beta = beta # user-defined beta parameter, non-trainable
        # self.beta = beta * torch.nn.Parameter(torch.ones(1)) # learnable beta parameter, create a tensor out of beta
        self.beta = torch.nn.Parameter(torch.tensor(beta)) # learnable beta parameter, create a tensor out of beta
        self.beta.requires_grad = True # set requires_grad to True to make it trainable

        self.slope = slope / 2 # user-defined "slope", non-trainable
        # self.slope = slope * torch.nn.Parameter(torch.ones(1)) # learnable slope parameter, create a tensor out of slope
        # self.slope = torch.nn.Parameter(torch.tensor(slope)) # learnable slope parameter, create a tensor out of slope
        # self.slope.requires_grad = True # set requires_grad to True to make it trainable
    
    def forward(self, input):
        """
        # Disabled, using inplace causes:
        # "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"
        if self.inplace:
            input.mul_(torch.sigmoid(self.beta*input))
            return 2 * self.slope * input
        else:
            return 2 * self.slope * swish_func(input, self.beta)
        """
        return 2 * self.slope * swish_func(input, self.beta)


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    # beta: for swish
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'leakyrelu' or act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act_type == 'tanh':  # [-1, 1] range output
        layer = nn.Tanh()
    elif act_type == 'sigmoid':  # [0, 1] range output
        layer = nn.Sigmoid()
    elif act_type == 'swish':
        layer = Swish(beta=beta, inplace=inplace)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class Identity(nn.Module):
    def __init__(self, *kwargs):
        super(Identity, self).__init__()

    def forward(self, x, *kwargs):
        return x


def norm(norm_type, nc):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
        # norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
        # norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    # elif norm_type == 'layer':
    #     return lambda num_features: nn.GroupNorm(1, num_features)
    elif norm_type == 'none':
        layer = Identity()
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def add_spectral_norm(module, use_spectral_norm=False):
    """ Add spectral norm to any module passed if use_spectral_norm = True,
    else, returns the original module without change
    """
    if use_spectral_norm:
        return nn.utils.spectral_norm(module)
    return module


def pad(pad_type, padding):
    """
    helper selecting padding layer
    if padding is 'zero', can be done with conv layers
    """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        layer = nn.ZeroPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        return 'Identity .. \n|' + self.sub.__repr__().replace('\n', '\n|')


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', \
               spectral_norm=False):
    """
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0
    
    if convtype=='PartialConv2D':
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
               dilation=dilation, bias=bias, groups=groups)
    elif convtype=='DeformConv2D':
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
               dilation=dilation, bias=bias, groups=groups)
    elif convtype=='Conv3D':
        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
                dilation=dilation, bias=bias, groups=groups)
    else: #default case is standard 'Conv2D':
        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
                dilation=dilation, bias=bias, groups=groups) #normal conv2d
            
    if spectral_norm:
        c = nn.utils.spectral_norm(c)
    
    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class Mean(nn.Module):
  def __init__(self, dim: list, keepdim=False):
    super().__init__()
    self.dim = dim
    self.keepdim = keepdim

  def forward(self, x):
    return torch.mean(x, self.dim, self.keepdim)


####################
# initialize modules
####################

@torch.no_grad()
def default_init_weights(module_list, init_type='kaiming', scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.
    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        init_type (str): the type of initialization in: 'normal', 'xavier',
            'kaiming' or 'orthogonal'
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1. (for 'kaiming')
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function:
            mean and/or std for 'normal'.
            a and/or mode for 'kaiming'
            gain for 'orthogonal' and xavier
    """
    
    # TODO
    # logger.info('Initialization method [{:s}]'.format(init_type))
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if init_type == 'normal':
                weights_init_normal(m, bias_fill=bias_fill, **kwargs)
            elif init_type == 'xavier':
                weights_init_xavier(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'kaiming':
                weights_init_kaiming(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'orthogonal':
                weights_init_orthogonal(m, bias_fill=bias_fill)
            else:
                raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))



####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.

    Args:
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
            output spatial sizes
        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
            multiplier for spatial size. Has to match input size if it is a tuple.
        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
            Default: ``'nearest'``
        align_corners (bool, optional): if ``True``, the corner pixels of the input
            and output tensors are aligned, and thus preserving the values at
            those pixels. This only has effect when :attr:`mode` is
            ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``False``
    """
    # To prevent warning: nn.Upsample is deprecated
    # https://discuss.pytorch.org/t/which-function-is-better-for-upsampling-upsampling-or-interpolate/21811/8
    # From: https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample
    # Alternative: https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588/2?u=ptrblck
    
    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super(Upsample, self).__init__()
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.size = size
        self.align_corners = align_corners
        # self.interp = nn.functional.interpolate
    
    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        # return self.interp(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
    
    def extra_repr(self):
        if self.scale_factor is not None:
            info = 'scale_factor=' + str(self.scale_factor)
        else:
            info = 'size=' + str(self.size)
        info += ', mode=' + self.mode
        return info

def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                        pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
                        pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)

def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """
    Upconv layer described in https://distill.pub/2016/deconv-checkerboard/
    Example to replace deconvolutions: 
        - from: nn.ConvTranspose2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1)
        - to: upconv_block(in_nc, out_nc,kernel_size=3, stride=1, act_type=None)
    """
    # upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
    upsample = Upsample(scale_factor=upscale_factor, mode=mode) #Updated to prevent the "nn.Upsample is deprecated" Warning
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
                        pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
    return sequential(upsample, conv)

# PPON
def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
    padding = int((kernel_size - 1) / 2) * dilation
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True, dilation=dilation, groups=groups)




####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Noise is sampled in forward() on the input's device and dtype,
        # so no device-bound tensor is needed (works on CPU, GPU and TPU).

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = torch.randn_like(x) * scale
            x = x + sampled_noise
        return x 

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# TODO: Not used:
# https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
class minibatch_std_concat_layer(nn.Module):
    def __init__(self, averaging='all'):
        super(minibatch_std_concat_layer, self).__init__()
        self.averaging = averaging.lower()
        if 'group' in self.averaging:
            self.n = int(self.averaging[5:])
        else:
            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode: %s' % self.averaging
        self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)

    def forward(self, x):
        shape = list(x.size())
        target_shape = list(shape)  # a shallow copy is enough for a flat list of ints
        vals = self.adjusted_std(x, dim=0, keepdim=True)
        if self.averaging == 'all':
            target_shape[1] = 1
            vals = torch.mean(vals, dim=1, keepdim=True)
        elif self.averaging == 'spatial':
            if len(shape) == 4:
                vals = torch.mean(vals, dim=[2, 3], keepdim=True)
        elif self.averaging == 'none':
            target_shape = [target_shape[0]] + [s for s in target_shape[1:]]
        elif self.averaging == 'gpool':
            if len(shape) == 4:
                vals = torch.mean(x, dim=[0, 2, 3], keepdim=True)
        elif self.averaging == 'flat':
            target_shape[1] = 1
            vals = torch.FloatTensor([self.adjusted_std(x)])
        else:                                                           # self.averaging == 'group'
            target_shape[1] = self.n
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
        vals = vals.expand(*target_shape)
        return torch.cat([x, vals], 1)


####################
# Useful blocks
####################

class SelfAttentionBlock(nn.Module):
    """ 
        Implementation of Self attention Block according to paper 
        'Self-Attention Generative Adversarial Networks' (https://arxiv.org/abs/1805.08318)
        Flexible Self Attention (FSA) layer according to paper
        Efficient Super Resolution For Large-Scale Images Using Attentional GAN (https://arxiv.org/pdf/1812.04821.pdf)
          The FSA layer borrows the self attention layer from SAGAN, 
          and wraps it with a max-pooling layer to reduce the size 
          of the feature maps and enable large-size images to fit in memory.
        Used in Generator and Discriminator Networks.
    """

    def __init__(self, in_dim, max_pool=False, poolsize = 4, spectral_norm=False, ret_attention=False): #in_dim = in_feature_maps
        super(SelfAttentionBlock,self).__init__()

        self.in_dim = in_dim
        self.max_pool = max_pool
        self.poolsize = poolsize
        self.ret_attention = ret_attention
        
        if self.max_pool:
            self.pooled = nn.MaxPool2d(kernel_size=self.poolsize, stride=self.poolsize) #kernel_size=4, stride=4
            # Note: can test using strided convolutions instead of MaxPool2d! :
            #upsample_block_num = int(math.log(scale_factor, 2))
            #self.pooled = nn.Conv2d .... strided conv
            # upsample_o = [UpconvBlock(in_channels=in_dim, out_channels=in_dim, upscale_factor=2, mode='bilinear', act_type='leakyrelu') for _ in range(upsample_block_num)]
            ## upsample_o.append(nn.Conv2d(nf, in_nc, kernel_size=9, stride=1, padding=4))
            ## self.upsample_o = nn.Sequential(*upsample_o)

            # self.upsample_o = B.Upsample(scale_factor=self.poolsize, mode='bilinear', align_corners=False) 
            
        self.conv_f = add_spectral_norm(
            nn.Conv1d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1, padding = 0), 
            use_spectral_norm=spectral_norm) #query_conv 
        self.conv_g = add_spectral_norm(
            nn.Conv1d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1, padding = 0), 
            use_spectral_norm=spectral_norm) #key_conv 
        self.conv_h = add_spectral_norm(
            nn.Conv1d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1, padding = 0), 
            use_spectral_norm=spectral_norm) #value_conv 

        self.gamma = nn.Parameter(torch.zeros(1)) # Trainable interpolation parameter
        self.softmax  = nn.Softmax(dim = -1)
        
    def forward(self,input):
        """
            inputs :
                input : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        
        if self.max_pool: #Downscale with Max Pool
            x = self.pooled(input)
        else:
            x = input
            
        batch_size, C, width, height = x.size()
        
        N = width * height
        x = x.view(batch_size, -1, N)
        f = self.conv_f(x) #proj_query  # B X (C//8) X N
        g = self.conv_g(x) #proj_key    # B X (C//8) X N
        h = self.conv_h(x) #proj_value  # B X C X N

        s = torch.bmm(f.permute(0, 2, 1), g) # energy, transpose check
        # get probabilities
        attention = self.softmax(s) #beta #attention # BX (N) X (N) 
        
        out = torch.bmm(h, attention.permute(0,2,1))
        out = out.view(batch_size, C, width, height) 
        
        if self.max_pool: #Upscale to original size
            # out = self.upsample_o(out)
            out = Upsample(size=(input.shape[2],input.shape[3]), mode='bicubic', align_corners=False)(out) #bicubic (PyTorch > 1.0) | bilinear others.
        
        out = self.gamma*out + input #Add original input
        
        if self.ret_attention:
            return out, attention
        else:
            return out
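
The tensor shapes inside `SelfAttentionBlock.forward` can be traced with a small standalone sketch. It is not the class above: it skips max-pooling and spectral norm, uses plain `nn.Conv1d` layers, and the sizes `B, C, H, W` are arbitrary values chosen for illustration. It only shows how the query/key/value projections produce an `N x N` attention map and why the block starts out as an identity mapping (`gamma` is initialised to zero).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 16, 8, 8           # illustrative sizes (assumption)
x = torch.randn(B, C, H, W)

# 1x1 Conv1d projections applied over the flattened spatial positions
conv_f = nn.Conv1d(C, C // 8, kernel_size=1)  # query
conv_g = nn.Conv1d(C, C // 8, kernel_size=1)  # key
conv_h = nn.Conv1d(C, C, kernel_size=1)       # value

N = H * W
flat = x.view(B, C, N)
f, g, h = conv_f(flat), conv_g(flat), conv_h(flat)

energy = torch.bmm(f.permute(0, 2, 1), g)     # B x N x N
attention = torch.softmax(energy, dim=-1)     # each row sums to 1
out = torch.bmm(h, attention.permute(0, 2, 1)).view(B, C, H, W)

gamma = torch.zeros(1)                        # initial value of the trainable gate
y = gamma * out + x                           # equals x until gamma is trained

print(attention.shape, y.shape)
```

Because the attention map has `N * N` entries per sample, memory grows quadratically with the number of pixels, which is the reason the block offers the optional max-pool path.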



In [None]:
#@title spectral_norm.py
"""
spectral_norm.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/spectral_norm.py
"""
'''
Copy from pytorch github repo
Spectral Normalization from https://arxiv.org/abs/1802.05957
'''
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter


class SpectralNorm(object):
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        weight_mat = weight_mat.reshape(height, -1)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u

    def remove(self, module):
        weight = getattr(module, self.name)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight))

    def __call__(self, module, inputs):
        if module.training:
            weight, u = self.compute_weight(module)
            setattr(module, self.name, weight)
            setattr(module, self.name + '_u', u)
        else:
            r_g = getattr(module, self.name + '_orig').requires_grad
            getattr(module, self.name).detach_().requires_grad_(r_g)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        height = weight.size(dim)

        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(
                module,
            (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
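
The power iteration in `SpectralNorm.compute_weight` can be checked in isolation. The sketch below is an illustration under assumptions, not the hook itself: it uses a random `40 x 20` matrix, runs many iterations in one go (the hook runs `n_power_iterations` per forward call and reuses `u` between calls), and compares the estimate with the largest singular value from `torch.linalg.svdvals`, which needs a reasonably recent PyTorch.

```python
import torch
from torch.nn.functional import normalize

torch.manual_seed(0)
W = torch.randn(40, 20)                        # stand-in for a flattened weight
u = normalize(torch.randn(40), dim=0, eps=1e-12)

# Alternate between the right (v) and left (u) singular vector estimates
for _ in range(1000):
    v = normalize(torch.matmul(W.t(), u), dim=0, eps=1e-12)
    u = normalize(torch.matmul(W, v), dim=0, eps=1e-12)

sigma_est = torch.dot(u, torch.matmul(W, v))   # u^T W v
sigma_ref = torch.linalg.svdvals(W)[0]         # exact largest singular value
W_sn = W / sigma_est                           # spectrally normalised weight

print(float(sigma_est), float(sigma_ref))
```

After normalisation the largest singular value of `W_sn` is approximately 1, which is the Lipschitz constraint the discriminator relies on.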


In [None]:
#@title discriminator.py
"""
discriminators.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/discriminators.py
"""
import math
import torch
import torch.nn as nn
import torchvision
#from . import block as B
from torch.nn.utils import spectral_norm as SN
import pytorch_lightning as pl



####################
# Discriminator
####################


# VGG style Discriminator
class Discriminator_VGG(pl.LightningModule):
    def __init__(self, size, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN'):
        super(Discriminator_VGG, self).__init__()

        conv_blocks = []
        conv_blocks.append(conv_block(  in_nc, base_nf, kernel_size=3, stride=1, norm_type=None, \
            act_type=act_type, mode=mode))
        conv_blocks.append(conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode))

        cur_size = size // 2
        cur_nc = base_nf
        while cur_size > 4:
            out_nc = cur_nc * 2 if cur_nc < 512 else cur_nc
            conv_blocks.append(conv_block(cur_nc, out_nc, kernel_size=3, stride=1, norm_type=norm_type, \
                act_type=act_type, mode=mode))
            conv_blocks.append(conv_block(out_nc, out_nc, kernel_size=4, stride=2, norm_type=norm_type, \
                act_type=act_type, mode=mode))
            cur_nc = out_nc
            cur_size //= 2

        self.features = sequential(*conv_blocks)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(cur_nc * cur_size * cur_size, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(cur_nc * cur_size * cur_size, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG style Discriminator with input size 96*96
class Discriminator_VGG_96(pl.LightningModule):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 64
        conv2 = conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 128
        conv4 = conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 256
        conv6 = conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        conv8 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 3, 512
        self.features = sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 3 * 3, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(pl.LightningModule):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x

# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(pl.LightningModule):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 64, 64
        conv2 = conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 128
        conv4 = conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 256
        conv6 = conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 512
        conv8 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 512
        self.features = sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 4 * 4, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 192*192
class Discriminator_VGG_192(pl.LightningModule): #vic in PPON is called Discriminator_192 
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode) # 3-->64
        conv1 = conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 64-->64, 96*96
        # 96, 64
        conv2 = conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 64-->128
        conv3 = conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 128-->128, 48*48
        # 48, 128
        conv4 = conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 128-->256
        conv5 = conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 256-->256, 24*24
        # 24, 256
        conv6 = conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 256-->512
        conv7 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 512-->512 12*12
        # 12, 512
        conv8 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 512-->512
        conv9 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 512-->512 6*6
        # 6, 512
        conv10 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv11 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 3*3
        # 3, 512
        self.features = sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9, conv10, conv11)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 3 * 3, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1)) #vic PPON uses 128 and 128 instead of 100
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# VGG style Discriminator with input size 256*256
class Discriminator_VGG_256(pl.LightningModule):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN'):
        super(Discriminator_VGG_256, self).__init__()
        # features
        # hxw, c
        # 256, 64
        conv0 = conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 128, 64
        conv2 = conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 64, 128
        conv4 = conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 256
        conv6 = conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 512
        conv8 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 512
        conv10 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv11 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode) # 4*4
        # 4, 512
        self.features = sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            conv9, conv10, conv11)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 4 * 4, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(base_nf*8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


####################
# Perceptual Network
####################


# Assume input range is [0, 1]
class VGGFeatureExtractor(pl.LightningModule):
    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu'),
                 z_norm=False): #Note: PPON uses cuda instead of CPU
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            if z_norm: # if input in range [-1,1]
                # x01 = (x + 1) / 2, so (x01 - m) / s == (x - (2*m - 1)) / (2*s)
                mean = torch.Tensor([0.485*2-1, 0.456*2-1, 0.406*2-1]).view(1, 3, 1, 1).to(device)
                std = torch.Tensor([0.229*2, 0.224*2, 0.225*2]).view(1, 3, 1, 1).to(device)
            else: # input in range [0,1]
                mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)                 
                std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output

# Assume input range is [0, 1]
class ResNet101FeatureExtractor(pl.LightningModule):
    def __init__(self, use_input_norm=True, device=torch.device('cpu'), z_norm=False):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            if z_norm: # if input in range [-1,1]
                # x in [-1,1] equals 2*y-1 for y in [0,1], so mean becomes 2*m-1 and std 2*s
                mean = torch.Tensor([2*0.485-1, 2*0.456-1, 2*0.406-1]).view(1, 3, 1, 1).to(device)
                std = torch.Tensor([0.229*2, 0.224*2, 0.225*2]).view(1, 3, 1, 1).to(device)
            else: # input in range [0,1]
                mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
                std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.children())[:8])
        # Freeze the extractor: no gradients are needed for its weights
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class MINCNet(pl.LightningModule):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out


# Assume input range is [0, 1]
class MINCFeatureExtractor(pl.LightningModule):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        # feature_layer, use_bn, use_input_norm and device are currently unused
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # Freeze the extractor: no gradients are needed for its weights
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output


#TODO
# moved from models.modules.architectures.ASRResNet_arch, did not bring the self-attention layer
# VGG style Discriminator with input size 128*128, with feature_maps extraction (self-attention not ported yet)
class Discriminator_VGG_128_fea(pl.LightningModule):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
         arch='ESRGAN', spectral_norm=False, self_attention = False, max_pool=False, poolsize = 4):
        super(Discriminator_VGG_128_fea, self).__init__()
        # features
        # hxw, c
        # 128, 64
        
        # Self-Attention configuration
        '''#TODO
        self.self_attention = self_attention
        self.max_pool = max_pool
        self.poolsize = poolsize
        '''
        
        # Remove BatchNorm2d if using spectral_norm
        if spectral_norm:
            norm_type = None
        
        self.conv0 = conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        self.conv1 = conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        # 64, 64
        self.conv2 = conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        self.conv3 = conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        # 32, 128
        self.conv4 = conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        self.conv5 = conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        # 16, 256
        
        '''#TODO
        if self.self_attention:
            self.FSA = SelfAttentionBlock(in_dim = base_nf*4, max_pool=self.max_pool, poolsize = self.poolsize, spectral_norm=spectral_norm)
        '''

        self.conv6 = conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        self.conv7 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        # 8, 512
        self.conv8 = conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        self.conv9 = conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm)
        # 4, 512
        # self.features = sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
            # conv9)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(512 * 4 * 4, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    #TODO: modify to a listening dictionary like VGG_Model(), can select what maps to use
    def forward(self, x, return_maps=False):
        feature_maps = []
        # x = self.features(x)
        x = self.conv0(x)
        feature_maps.append(x)
        x = self.conv1(x)
        feature_maps.append(x)
        x = self.conv2(x)
        feature_maps.append(x)
        x = self.conv3(x)
        feature_maps.append(x)
        x = self.conv4(x)
        feature_maps.append(x)
        x = self.conv5(x)
        feature_maps.append(x)
        x = self.conv6(x)
        feature_maps.append(x)
        x = self.conv7(x)
        feature_maps.append(x)
        x = self.conv8(x)
        feature_maps.append(x)
        x = self.conv9(x)
        feature_maps.append(x)
        
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        if return_maps:
            return [x, feature_maps]
        return x


class Discriminator_VGG_fea(pl.LightningModule):
    def __init__(self, size, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
         arch='ESRGAN', spectral_norm=False, self_attention = False, max_pool=False, poolsize = 4):
        super(Discriminator_VGG_fea, self).__init__()
        # features
        # hxw, c
        # 128, 64
        
        # Self-Attention configuration
        '''#TODO
        self.self_attention = self_attention
        self.max_pool = max_pool
        self.poolsize = poolsize
        '''
        
        # Remove BatchNorm2d if using spectral_norm
        if spectral_norm:
            norm_type = None

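        # nn.ModuleList registers the blocks as submodules, so their parameters
        # are optimized, saved in checkpoints and moved together with the model
        self.conv_blocks = nn.ModuleList()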
        self.conv_blocks.append(conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm))
        self.conv_blocks.append(conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode, spectral_norm=spectral_norm))

        cur_size = size // 2
        cur_nc = base_nf
        while cur_size > 4:
            out_nc = cur_nc * 2 if cur_nc < 512 else cur_nc
            self.conv_blocks.append(conv_block(cur_nc, out_nc, kernel_size=3, stride=1, norm_type=norm_type, \
                act_type=act_type, mode=mode, spectral_norm=spectral_norm))
            self.conv_blocks.append(conv_block(out_nc, out_nc, kernel_size=4, stride=2, norm_type=norm_type, \
                act_type=act_type, mode=mode, spectral_norm=spectral_norm))
            cur_nc = out_nc
            cur_size //= 2
        
        '''#TODO
        if self.self_attention:
            self.FSA = SelfAttentionBlock(in_dim = base_nf*4, max_pool=self.max_pool, poolsize = self.poolsize, spectral_norm=spectral_norm)
        '''

        # self.features = sequential(*conv_blocks)

        # classifier
        if arch=='PPON':
            self.classifier = nn.Sequential(
                nn.Linear(cur_nc * cur_size * cur_size, 128), nn.LeakyReLU(0.2, True), nn.Linear(128, 1))
        else: #arch='ESRGAN':
            self.classifier = nn.Sequential(
                nn.Linear(cur_nc * cur_size * cur_size, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    #TODO: modify to a listening dictionary like VGG_Model(), can select what maps to use
    def forward(self, x, return_maps=False):
        feature_maps = []
        # x = self.features(x)
        for conv in self.conv_blocks:
            # Fixes incorrect device error
            device = x.device
            conv = conv.to(device)
            x = conv(x)
            feature_maps.append(x)
        
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        if return_maps:
            return [x, feature_maps]
        return x


class NLayerDiscriminator(pl.LightningModule):
    r"""
    PatchGAN discriminator
    https://arxiv.org/pdf/1611.07004v3.pdf
    https://arxiv.org/pdf/1803.07422.pdf

    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
        use_sigmoid=False, getIntermFeat=False, patch=True, use_spectral_norm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int): the number of channels in input images
            ndf (int): the number of filters in the first conv layer
            n_layers (int): the number of conv layers in the discriminator
            norm_layer (nn.Module): normalization layer (if not using Spectral Norm)
            use_sigmoid (bool): append a Sigmoid layer to the output
            getIntermFeat (bool): get intermediate features (unused for now)
            patch (bool): Select between a patch or a linear output
            use_spectral_norm (bool): Select if Spectral Norm will be used
        """
        super(NLayerDiscriminator, self).__init__()
        '''
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        '''

        if use_spectral_norm:
            # disable Instance or Batch Norm if using Spectral Norm
            norm_layer = Identity

        #self.getIntermFeat = getIntermFeat # not used for now
        #TODO: test if there are benefits by incorporating the use of intermediate features from pix2pixHD

        use_bias = False
        kw = 4
        padw = 1 # int(np.ceil((kw-1.0)/2))

        sequence = [add_spectral_norm(
                        nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 
                        use_spectral_norm), 
                    nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                add_spectral_norm(
                    nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 
                    use_spectral_norm),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            add_spectral_norm(
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 
                use_spectral_norm),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        if patch:
            # output patches as results
            sequence += [add_spectral_norm(
                nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), 
                use_spectral_norm)]  # output 1 channel prediction map
        else:
            # linear vector classification output
            sequence += [Mean([1, 2]), nn.Linear(ndf * nf_mult, 1)]
        
        if use_sigmoid:
            sequence += [nn.Sigmoid()]
        
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        """Standard forward."""
        return self.model(x)


class MultiscaleDiscriminator(pl.LightningModule):
    r"""
    Multiscale PatchGAN discriminator
    https://arxiv.org/pdf/1711.11585.pdf

    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        """Construct a pyramid of PatchGAN discriminators
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
            use_sigmoid     -- boolean to use sigmoid in patchGAN discriminators
            num_D (int)     -- number of discriminators/downscales in the pyramid
            getIntermFeat   -- boolean to get intermediate features (unused for now)
        """
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
     
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:                                
                for j in range(n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))                                   
            else:
                setattr(self, 'layer'+str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        if self.getIntermFeat:
            result = [input]
            for i in range(len(model)):
                result.append(model[i](result[-1]))
            return result[1:]
        else:
            return [model(input)]

    def forward(self, input):        
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result


class PixelDiscriminator(pl.LightningModule):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        '''
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        '''
        use_bias = False

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)


"""
models.py (21-12-20)
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/context_encoder/models.py
"""

class context_encoder(pl.LightningModule):
    def __init__(self, channels=3):
        super(context_encoder, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)
        
    def forward(self, img):
        return self.model(img)


"""
discriminators.py (12-2-20)
https://github.com/JoeyBallentine/BasicSR/blob/resnet-discriminator/codes/models/modules/architectures/discriminators.py
"""

# ResNet50 style Discriminator with input size 128*128
class Discriminator_ResNet_128(pl.LightningModule):
    """
    Structure based off of the ResNet50 configuration from this repository:
    https://github.com/bentrevett/pytorch-image-classification
    """
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_ResNet_128, self).__init__()
        # features
        # hxw, c (for layer1..layer4, c is the bottleneck width; each layer outputs block.expansion * c channels)

        self.in_channels = base_nf
        
        # 128, 3
        conv0 = conv_block(in_nc, self.in_channels, kernel_size=7, norm_type=norm_type, act_type=act_type, \
            mode=mode, stride=2, bias=False)
        pool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 32, 64

        layer1 = self.get_resnet_layer(Bottleneck, 3, base_nf) # 32, 64
        layer2 = self.get_resnet_layer(Bottleneck, 4, base_nf*2, stride = 2) # 16, 128
        layer3 = self.get_resnet_layer(Bottleneck, 6, base_nf*4, stride = 2) # 8, 256
        layer4 = self.get_resnet_layer(Bottleneck, 3, base_nf*8, stride = 2) # 4, 512

        avgpool = nn.AdaptiveAvgPool2d((1,1))

        self.features = sequential(conv0, pool0, layer1, layer2, layer3, layer4, avgpool)

        self.classifier = nn.Linear(self.in_channels, 1)

    def get_resnet_layer(self, block, n_blocks, channels, stride = 1):
    
        layers = []
        
        if self.in_channels != block.expansion * channels:
            downsample = True
        else:
            downsample = False
        
        layers.append(block(self.in_channels, channels, stride, downsample))
        
        for i in range(1, n_blocks):
            layers.append(block(block.expansion * channels, channels))

        self.in_channels = block.expansion * channels
            
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class Bottleneck(pl.LightningModule):
    
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()
    
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, 
                               stride = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, 
                               stride = stride, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size = 1,
                               stride = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)
        
        self.relu = nn.ReLU(inplace = True)
        
        if downsample:
            conv = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size = 1, 
                             stride = stride, bias = False)
            bn = nn.BatchNorm2d(self.expansion * out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None
            
        self.downsample = downsample
        
    def forward(self, x):
        
        i = x
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
                
        if self.downsample is not None:
            i = self.downsample(i)
            
        x += i
        x = self.relu(x)
    
        return x
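
As a rough sanity check for the PatchGAN discriminator defined above: with the defaults in `NLayerDiscriminator` (`kw = 4`, `padw = 1`, `n_layers = 3`, `patch=True`) the network applies three stride-2 convolutions followed by two stride-1 convolutions, so it returns a map of per-patch scores instead of a single value. The sketch below only does the standard convolution size arithmetic. `conv_out` and `patchgan_output_size` are hypothetical helper names and are not part of this notebook.

```python
def conv_out(size: int, kernel: int = 4, stride: int = 1, padding: int = 1) -> int:
    """Spatial output size of a convolution (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patchgan_output_size(size: int, n_layers: int = 3) -> int:
    """Side length of the score map for a square input of side `size`."""
    # First conv plus the (n_layers - 1) convs in the loop, all with stride 2.
    for _ in range(n_layers):
        size = conv_out(size, stride=2)
    # Penultimate conv and the final 1-channel conv, both with stride 1.
    for _ in range(2):
        size = conv_out(size, stride=1)
    return size


print(patchgan_output_size(256))  # 256 -> 128 -> 64 -> 32 -> 31 -> 30
```

For a 256x256 input this gives a 30x30 score map, which is why a loss on the PatchGAN output has to compare against a target of the same spatial shape.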






Loss functions and their weights are configured inside ``CustomTrainClass``. The logging path is set through ``logdir`` in ``CustomTrainClass.py``.
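
A minimal sketch of what "loss functions and weights" can mean in practice, assuming the total generator loss is a weighted sum of the individual terms. The helper `combine_losses`, the term names and the numbers are hypothetical and only illustrate the pattern. Inside the notebook the values would be tensors returned by the loss modules, which support the same arithmetic.

```python
def combine_losses(losses, weights):
    """Weighted sum over the loss terms that have a non-zero weight."""
    total = 0.0
    for name, value in losses.items():
        weight = weights.get(name, 0.0)  # terms without a weight are disabled
        if weight:
            total = total + weight * value
    return total


# Hypothetical per-step loss values and weights, for illustration only.
losses = {"l1": 0.5, "hfen": 2.0, "tv": 3.0}
weights = {"l1": 1.0, "hfen": 0.1}  # "tv" has no weight, so it is ignored
total = combine_losses(losses, weights)
print(total)
```

Keeping the weights in one dictionary makes it easy to switch a term off (remove its entry or set it to 0) without touching the training step.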

**Warning**: Don't use AMP with StyleLoss.

**Warning**: Certain combinations of discriminator and generator can produce broken validation images. Train for a short while and check that the validation output is not a solid color.
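
The solid-color check can also be automated. The sketch below is one possible way to do it and is not part of this notebook: `is_solid_color` is a hypothetical helper that takes a sequence of RGB tuples and reports whether every channel stays within a small tolerance. It assumes the pixels were read from a saved validation image beforehand, for example with Pillow via `list(Image.open(path).convert('RGB').getdata())`.

```python
def is_solid_color(pixels, tolerance=2):
    """Return True if every channel varies by at most `tolerance` across all pixels."""
    pixels = list(pixels)
    if not pixels:
        raise ValueError("no pixels given")
    # zip(*pixels) yields one sequence per channel (R, G, B).
    for channel in zip(*pixels):
        if max(channel) - min(channel) > tolerance:
            return False
    return True


flat = [(127, 127, 127)] * 16                              # uniform gray "image"
varied = [(i * 16, 255 - i * 16, 64) for i in range(16)]   # red and green change
print(is_solid_color(flat), is_solid_color(varied))  # True False
```

A small tolerance is used instead of an exact comparison because a collapsed generator often produces a nearly, but not perfectly, uniform image.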

In [None]:
#@title CustomTrainClass.py
from vic.loss import CharbonnierLoss, GANLoss, GradientPenaltyLoss, HFENLoss, TVLoss, GradientLoss, ElasticLoss, RelativeL1, L1CosineSim, ClipL1, MaskedL1Loss, MultiscalePixelLoss, FFTloss, OFLoss, L1_regularization, ColorLoss, AverageLoss, GPLoss, CPLoss, SPL_ComputeWithTrace, SPLoss, Contextual_Loss, StyleLoss
from vic.perceptual_loss import PerceptualLoss
from metrics import *
from torchvision.utils import save_image
from torch.autograd import Variable

from tensorboardX import SummaryWriter
logdir='/content/'
writer = SummaryWriter(logdir=logdir)

from adamp import AdamP
#from adamp import SGDP

class CustomTrainClass(pl.LightningModule):
  def __init__(self):
    super().__init__()
    ############################
    # generators with one output ("no AMP" below means that AMP gave NaN loss during training)
    self.netG = RRDBNet(in_nc=3, out_nc=3, nf=128, nb=8, gc=32, upscale=4, norm_type=None,
                act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                finalact=None, gaussian_noise=True, plus=False, 
                nr=3)

    # DFNet
    #self.netG = DFNet(c_img=3, c_mask=1, c_alpha=3,
    #        mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
    #        en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3, 3, 3, 3, 3, 3, 3, 3],
    #        blend_layers=[0, 1, 2, 3, 4, 5], conv_type='partial')
    
    # AdaFill
    #self.netG = InpaintNet()

    # MEDFE (batch_size: 1, no AMP)
    #self.netG = MEDFEGenerator()

    # RFR
    # conv_type = partial or deform
    # Warning: One test run with deform resulted in NaN errors after ~60k iterations. It is also very slow.
    # 'partial' is recommended, since this is what the official implementation uses.
    #self.netG = RFRNet(conv_type='partial')

    # LBAM
    #self.netG = LBAMModel(inputChannels=4, outputChannels=3)

    # DMFN
    #self.netG = InpaintingGenerator(in_nc=4, out_nc=3,nf=64,n_res=8,
    #      norm='in', activation='relu')

    # partial
    #self.netG = Model()

    # RN
    #self.netG = G_Net(input_channels=3, residual_blocks=8, threshold=0.8)
    # using rn init to avoid errors
    #RN_arch = rn_initialize_weights(self.netG, scale=0.1)


    # DSNet
    #self.netG = DSNet(layer_size=8, input_channels=3, upsampling_mode='nearest')


    #DSNetRRDB
    #self.netG = DSNetRRDB(layer_size=8, input_channels=3, upsampling_mode='nearest',
    #            in_nc=4, out_nc=3, nf=128, nb=8, gc=32, upscale=1, norm_type=None,
    #            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
    #            finalact=None, gaussian_noise=True, plus=False, 
    #            nr=3)


    # DSNetDeoldify
    #self.netG = DSNetDeoldify()

    ############################

    # generators with two outputs

    # deepfillv1
    #self.netG = InpaintSANet()

    # deepfillv2
    # conv_type = partial or deform
    #self.netG = GatedGenerator(in_channels=4, out_channels=3, 
    #  latent_channels=64, pad_type='zero', activation='lrelu', norm='in', conv_type = 'partial')

    # Adaptive
    # [Warning] Adaptive does not work well with the PatchGAN, Multiscale and ResNet discriminators.
    #self.netG = PyramidNet(in_channels=3, residual_blocks=1, init_weights='True')

    ############################
    # exotic generators

    # Pluralistic
    #self.netG = PluralisticGenerator(ngf_E=opt_net['ngf_E'], z_nc_E=opt_net['z_nc_E'], img_f_E=opt_net['img_f_E'], layers_E=opt_net['layers_E'], norm_E=opt_net['norm_E'], activation_E=opt_net['activation_E'],
    #            ngf_G=opt_net['ngf_G'], z_nc_G=opt_net['z_nc_G'], img_f_G=opt_net['img_f_G'], L_G=opt_net['L_G'], output_scale_G=opt_net['output_scale_G'], norm_G=opt_net['norm_G'], activation_G=opt_net['activation_G'])

    
    # EdgeConnect
    #conv_type_edge: 'normal' # normal | partial | deform (has no spectral_norm)
    #self.netG = EdgeConnectModel(residual_blocks_edge=8,
    #        residual_blocks_inpaint=8, use_spectral_norm=True,
    #        conv_type_edge='normal', conv_type_inpaint='normal')

    # FRRN
    #self.netG = FRRNet()

    # PRVS
    #self.netG = PRVSNet()

    # CSA
    #self.netG = InpaintNet(c_img=3, norm='instance', act_en='leaky_relu', 
    #                           act_de='relu')

    # deoldify
    #self.netG = Unet34()

    weights_init(self.netG, 'kaiming')
    ############################


    # discriminators
    # size refers to input shape of tensor

    self.netD = context_encoder()

    # VGG
    #self.netD = Discriminator_VGG(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D', arch='ESRGAN')
    #self.netD = Discriminator_VGG_fea(size=256, in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA', convtype='Conv2D',
    #     arch='ESRGAN', spectral_norm=False, self_attention = False, max_pool=False, poolsize = 4)
    #self.netD = Discriminator_VGG_128_SN()
    #self.netD = VGGFeatureExtractor(feature_layer=34,use_bn=False,use_input_norm=True,device=torch.device('cpu'),z_norm=False)

    # PatchGAN
    #self.netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #    use_sigmoid=False, getIntermFeat=False, patch=True, use_spectral_norm=False)

    # Multiscale
    #self.netD = MultiscaleDiscriminator(input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 
    #             use_sigmoid=False, num_D=3, getIntermFeat=False)

    # ResNet
    #self.netD = Discriminator_ResNet_128(in_nc=3, base_nf=64, norm_type='batch', act_type='leakyrelu', mode='CNA')
    #self.netD = ResNet101FeatureExtractor(use_input_norm=True, device=torch.device('cpu'), z_norm=False)
    
    # MINC
    #self.netD = MINCNet()

    # Pixel
    #self.netD = PixelDiscriminator(input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d)

    # EfficientNet
    #from efficientnet_pytorch import EfficientNet
    #self.netD = EfficientNet.from_pretrained('efficientnet-b0')

    # ResNeSt
    # ["resnest50", "resnest101", "resnest200", "resnest269"]
    #self.netD = resnest50(pretrained=True)

    # needs fixing: the pretrained MINC weights are not found at the expected path
    #FileNotFoundError: [Errno 2] No such file or directory: '../experiments/pretrained_models/VGG16minc_53.pth'
    #self.netD = MINCFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True, device=torch.device('cpu'))

    # Transformer (Warning: uses own init!)
    #self.netD  = TranformerDiscriminator(img_size=256, patch_size=1, in_chans=3, num_classes=1, embed_dim=64, depth=7,
    #             num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
    #             drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm)
    

    weights_init(self.netD, 'kaiming')


    # loss functions
    self.l1 = nn.L1Loss()
    l_hfen_type = L1CosineSim()
    self.HFENLoss = HFENLoss(loss_f=l_hfen_type, kernel='log', kernel_size=15, sigma = 2.5, norm = False)
    self.ElasticLoss = ElasticLoss(a=0.2, reduction='mean')
    self.RelativeL1 = RelativeL1(eps=.01, reduction='mean')
    self.L1CosineSim = L1CosineSim(loss_lambda=5, reduction='mean')
    self.ClipL1 = ClipL1(clip_min=0.0, clip_max=10.0)
    self.FFTloss = FFTloss(loss_f = torch.nn.L1Loss, reduction='mean')
    self.OFLoss = OFLoss()
    self.GPLoss = GPLoss(trace=False, spl_denorm=False)
    self.CPLoss = CPLoss(rgb=True, yuv=True, yuvgrad=True, trace=False, spl_denorm=False, yuv_denorm=False)
    self.StyleLoss = StyleLoss()
    self.TVLoss = TVLoss(tv_type='tv', p = 1)
    self.PerceptualLoss = PerceptualLoss(model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], model_path=None)
    layers_weights = {'conv_1_1': 1.0, 'conv_3_2': 1.0}
    self.Contextual_Loss = Contextual_Loss(layers_weights, crop_quarter=False, max_1d_size=100,
        distance_type = 'cosine', b=1.0, band_width=0.5,
        use_vgg = True, net = 'vgg19', calc_type = 'regular')

    self.MSELoss = torch.nn.MSELoss()
    self.L1Loss = nn.L1Loss()

    # metrics
    self.psnr_metric = PSNR()
    self.ssim_metric = SSIM()
    self.ae_metric = AE()
    self.mse_metric = MSE()


  def forward(self, image, masks):
      return self.netG(image, masks)

  #def adversarial_loss(self, y_hat, y):
  #    return F.binary_cross_entropy(y_hat, y)


  def training_step(self, train_batch, batch_idx):
      # inpainting:
      # train_batch[0][0] = batch_size
      # train_batch[0] = masked
      # train_batch[1] = mask
      # train_batch[2] = original

      # super resolution
      # train_batch[0] = lr
      # train_batch[1] = hr

      # train generator
      ############################
      # generate fake (1 output)
      #out = self(train_batch[0],train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # generate fake (2 outputs)
      #out, other_img = self(train_batch[0],train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      ############################
      # exotic generators
      # CSA
      #coarse_result, out, csa, csa_d = self(train_batch[0],train_batch[1])
      
      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #out, other_img = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])
      
      # PVRS
      #out, _ ,edge_small, edge_big = self.netG(train_batch[0], train_batch[1], train_batch[3])

      # FRRN
      #out, mid_x, mid_mask = self(train_batch[0], train_batch[1])

      # masking, taking original content from HR
      #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

      # deoldify
      #out = self.netG(train_batch[0])


      ############################
      # ESRGAN
      out = self.netG(train_batch[0])
      


      ############################
      # loss calculation
      total_loss = 0
      """
      HFENLoss_forward = self.HFENLoss(out, train_batch[0])
      total_loss += HFENLoss_forward
      ElasticLoss_forward = self.ElasticLoss(out, train_batch[0])
      total_loss += ElasticLoss_forward
      RelativeL1_forward = self.RelativeL1(out, train_batch[0])
      total_loss += RelativeL1_forward
      """
      L1CosineSim_forward = 5*self.L1CosineSim(out, train_batch[1])
      total_loss += L1CosineSim_forward
      #self.log('loss/L1CosineSim', L1CosineSim_forward)
      writer.add_scalar('loss/L1CosineSim', L1CosineSim_forward, self.trainer.global_step)

      """
      ClipL1_forward = self.ClipL1(out, train_batch[0])
      total_loss += ClipL1_forward
      FFTloss_forward = self.FFTloss(out, train_batch[0])
      total_loss += FFTloss_forward
      OFLoss_forward = self.OFLoss(out)
      total_loss += OFLoss_forward
      GPLoss_forward = self.GPLoss(out, train_batch[0])
      total_loss += GPLoss_forward
      
      CPLoss_forward = 0.1*self.CPLoss(out, train_batch[0])
      total_loss += CPLoss_forward
      

      Contextual_Loss_forward = self.Contextual_Loss(out, train_batch[0])
      total_loss += Contextual_Loss_forward
      self.log('loss/contextual', Contextual_Loss_forward)
      """

      #style_forward = 240*self.StyleLoss(out, train_batch[2])
      #total_loss += style_forward
      #self.log('loss/style', style_forward)

      tv_forward = 0.0000005*self.TVLoss(out)
      total_loss += tv_forward
      #self.log('loss/tv', tv_forward)
      writer.add_scalar('loss/tv', tv_forward, self.trainer.global_step)

      perceptual_forward = 2*self.PerceptualLoss(out, train_batch[1])
      total_loss += perceptual_forward
      #self.log('loss/perceptual', perceptual_forward)
      writer.add_scalar('loss/perceptual', perceptual_forward, self.trainer.global_step)







      #########################
      # exotic loss

      # if model has two output, also calculate loss for such an image
      # example with just l1 loss
      
      #l1_stage1 = self.L1Loss(other_img, train_batch[0])
      #self.log('loss/l1_stage1', l1_stage1)
      #total_loss += l1_stage1


      # CSA Loss
      """
      recon_loss = self.L1Loss(coarse_result, train_batch[2]) + self.L1Loss(out, train_batch[2])
      cons = ConsistencyLoss()
      cons_loss = cons(csa, csa_d, train_batch[2], train_batch[1])
      self.log('loss/recon_loss', recon_loss)
      total_loss += recon_loss
      self.log('loss/cons_loss', cons_loss)
      total_loss += cons_loss
      """

      # EdgeConnect
      # train_batch[3] = edges
      # train_batch[4] = grayscale
      #l1_edge = self.L1Loss(other_img, train_batch[3])
      #self.log('loss/l1_edge', l1_edge)
      #total_loss += l1_edge

      # PVRS
      """
      edge_big_l1 = self.L1Loss(edge_big, train_batch[3])
      edge_small_l1 = self.L1Loss(edge_small, torch.nn.functional.interpolate(train_batch[3], scale_factor = 0.5))
      self.log('loss/edge_big_l1', edge_big_l1)
      total_loss += edge_big_l1
      self.log('loss/edge_small_l1', edge_small_l1)
      total_loss += edge_small_l1
      """ 

      # FRRN
      """
      mid_l1_loss = 0
      for idx in range(len(mid_x) - 1):
          mid_l1_loss += self.L1Loss(mid_x[idx] * mid_mask[idx], train_batch[2] * mid_mask[idx])
      self.log('loss/mid_l1_loss', mid_l1_loss)
      total_loss += mid_l1_loss
      """

      #self.log('loss/g_loss', total_loss)
      writer.add_scalar('loss/g_loss', total_loss, self.trainer.global_step)

      #return total_loss
      #########################








      # train discriminator
      # resizing input if needed
      #train_batch[2] = torch.nn.functional.interpolate(train_batch[2], (128,128), align_corners=False, mode='bilinear')
      #out = torch.nn.functional.interpolate(out, (128,128), align_corners=False, mode='bilinear')

      Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
      valid = Variable(Tensor(out.shape).fill_(1.0), requires_grad=False)
      fake = Variable(Tensor(out.shape).fill_(0.0), requires_grad=False)
      dis_real_loss = self.MSELoss(train_batch[1], valid)
      dis_fake_loss = self.MSELoss(out, fake)

      d_loss = (dis_real_loss + dis_fake_loss) / 2
      #self.log('loss/d_loss', d_loss)
      writer.add_scalar('loss/d_loss', d_loss, self.trainer.global_step)

      return total_loss+d_loss

  def configure_optimizers(self):
      #optimizer = torch.optim.Adam(self.netG.parameters(), lr=2e-3)
      optimizer = AdamP(self.netG.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-2)
      #optimizer = SGDP(self.netG.parameters(), lr=0.1, weight_decay=1e-5, momentum=0.9, nesterov=True)
      return optimizer

  def validation_step(self, train_batch, train_idx):
    # inpainting
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path

    # super resolution
    # train_batch[0] = lr
    # train_batch[1] = hr
    # train_batch[2] = lr_path

    #########################
    # generate fake (one output generator)
    #out = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    #########################
    # generate fake (two output generator)
    #out, _ = self(train_batch[0],train_batch[1])

    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])
    #########################
    # CSA
    #_, out, _, _ = self(train_batch[0],train_batch[1])
    # masking, taking original content from HR
    #out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    # EdgeConnect
    # train_batch[3] = edges
    # train_batch[4] = grayscale
    #out, _ = self.netG(train_batch[0], train_batch[3], train_batch[4], train_batch[1])

    # PVRS
    #out, _ ,_, _ = self.netG(train_batch[0], train_batch[1], train_batch[3])

    # FRRN
    #out, _, _ = self(train_batch[0], train_batch[1])

    # deoldify
    #out = self.netG(train_batch[0])

    ############################
    # ESRGAN
    out = self.netG(train_batch[0])

    # Validation metrics work, but they need an original source image.
    # Change dataloader to provide LR and HR if you want metrics.
    self.log('metrics/PSNR', self.psnr_metric(train_batch[1], out))
    self.log('metrics/SSIM', self.ssim_metric(train_batch[1], out))
    self.log('metrics/MSE', self.mse_metric(train_batch[1], out))
    self.log('metrics/LPIPS', self.PerceptualLoss(out, train_batch[1]))

    validation_output = '/content/validation_output/' #@param

    # train_batch[2] can contain multiple files, depending on the batch_size
    # data is processed as a batch, so the index picks each file's own output
    for counter, f in enumerate(train_batch[2]):
      filename_with_extention = os.path.basename(f)
      filename = os.path.splitext(filename_with_extention)[0]

      # one output folder per validation image
      os.makedirs(os.path.join(validation_output, filename), exist_ok=True)

      save_image(out[counter], os.path.join(validation_output, filename, str(self.trainer.global_step) + '.png'))

  def test_step(self, train_batch, train_idx):
    # inpainting
    # train_batch[0] = masked
    # train_batch[1] = mask
    # train_batch[2] = path

    # super resolution
    # train_batch[0] = lr
    # train_batch[1] = hr
    # train_batch[2] = lr_path
    test_output = '/content/test_output/' #@param
    if not os.path.exists(test_output):
      os.makedirs(test_output)

    out = self(train_batch[0].unsqueeze(0),train_batch[1].unsqueeze(0))
    out = train_batch[0]*(train_batch[1])+out*(1-train_batch[1])

    save_image(out, os.path.join(test_output, os.path.splitext(os.path.basename(train_batch[2]))[0] + '.png'))


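The inpainting paths in ``training_step`` and ``validation_step`` blend the generator output back into the input with ``masked*mask + out*(1-mask)``. Below is a minimal sketch of that blend on toy tensors. It assumes the mask is 1 where original pixels are kept and 0 inside the hole, which is the convention that formula implies. ``composite`` is a hypothetical helper name and is not part of this notebook.

```python
import torch


def composite(masked: torch.Tensor, mask: torch.Tensor, generated: torch.Tensor) -> torch.Tensor:
    """Keep input pixels where mask == 1 and generated pixels where mask == 0."""
    return masked * mask + generated * (1 - mask)


# toy 1x1x2x2 example: the right column is the hole (mask == 0)
mask = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])
original = torch.tensor([[[[0.2, 0.4], [0.6, 0.8]]]])
masked = original * mask                     # hole pixels are zeroed out
generated = torch.full_like(original, 0.5)   # stand-in for the generator output

result = composite(masked, mask, generated)
```

Only the hole pixels come from the generator, so losses computed on ``result`` do not penalize regions that were never masked.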
# Inpaint Generators

Only run one of these cells and configure it further inside ``CustomTrainClass``. Running more than one can cause problems; if that happens, restart the notebook.

Sidenote: Some files use ``.type(torch.cuda.FloatTensor)`` to avoid crashing. You could also try ``.type(torch.cuda.HalfTensor)``, which is untested but might help with AMP. ``[no AMP]`` indicates that you get ``loss=nan`` if you actually try to use AMP.
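As a device-agnostic variant of the cast mentioned above, ``.float()`` converts to 32-bit floats without moving the tensor to a specific device. The sketch below assumes the NaNs come from a numerically sensitive step running in half precision. ``normalize_in_float32`` is a hypothetical helper and not code from any of the generators.

```python
import torch


def normalize_in_float32(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Run a sensitive normalisation in float32, then restore the input dtype."""
    orig_dtype = x.dtype
    x32 = x.float()  # stays on the device of x, unlike .type(torch.cuda.FloatTensor)
    mean = x32.mean(dim=(2, 3), keepdim=True)
    var = x32.var(dim=(2, 3), unbiased=False, keepdim=True)
    out = (x32 - mean) / torch.sqrt(var + eps)
    return out.to(orig_dtype)


features = torch.randn(2, 3, 8, 8).half()  # stand-in for half-precision activations
normalized = normalize_in_float32(features)
```

The result has the same dtype and shape as the input, so a helper like this could be dropped around a single unstable operation without touching the rest of the network.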

With one output:

In [None]:
#@title [DSNet](https://github.com/wangning-001/DSNet) (2021)
"""
DSNet.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/DSNet.py

RegionNorm.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/RegionNorm.py

ValidMigration.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/ValidMigration.py

Attention.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/Attention.py

deform_conv.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/deform_conv.py
"""
#from modules.Attention import PixelContextualAttention
#from modules.RegionNorm import RBNModule, RCNModule
#from modules.ValidMigration import ConvOffset2D
#from modules.deform_conv import th_batch_map_offsets, th_generate_grid
from __future__ import absolute_import, division
from scipy.ndimage import map_coordinates as sp_map_coordinates
from torch.autograd import Variable
from torchvision import models
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


def th_flatten(a):
    """Flatten tensor"""
    return a.contiguous().view(a.nelement())


def th_repeat(a, repeats, axis=0):
    """Torch version of np.repeat for 1D"""
    assert len(a.size()) == 1
    return th_flatten(torch.transpose(a.repeat(repeats, 1), 0, 1))


def np_repeat_2d(a, repeats):
    """NumPy helper: stack a 2D array `repeats` times along a new leading axis"""

    assert len(a.shape) == 2
    a = np.expand_dims(a, 0)
    a = np.tile(a, [repeats, 1, 1])
    return a


def th_gather_2d(input, coords):
    inds = coords[:, 0]*input.size(1) + coords[:, 1]
    x = torch.index_select(th_flatten(input), 0, inds)
    return x.view(coords.size(0))


def th_map_coordinates(input, coords, order=1):
    """PyTorch version of scipy.ndimage.map_coordinates
    Note that coords is transposed and only 2D is supported
    Parameters
    ----------
    input : torch.Tensor. shape = (s, s)
    coords : torch.Tensor. shape = (n_points, 2)
    """

    assert order == 1
    input_size = input.size(0)

    coords = torch.clamp(coords, 0, input_size - 1)
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1)
    coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1)

    vals_lt = th_gather_2d(input,  coords_lt.detach())
    vals_rb = th_gather_2d(input,  coords_rb.detach())
    vals_lb = th_gather_2d(input,  coords_lb.detach())
    vals_rt = th_gather_2d(input,  coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())

    vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0]
    vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0]
    mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1]
    return mapped_vals


def sp_batch_map_coordinates(inputs, coords):
    """Reference implementation for batch_map_coordinates"""
    # coords = coords.clip(0, inputs.shape[1] - 1)

    assert (coords.shape[2] == 2)
    height = coords[:,:,0].clip(0, inputs.shape[1] - 1)
    width = coords[:,:,1].clip(0, inputs.shape[2] - 1)
    # rebuild coords from the clipped height/width so they stay inside the input
    coords = np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(width, axis=2)), 2)

    mapped_vals = np.array([
        sp_map_coordinates(input, coord.T, mode='nearest', order=1)
        for input, coord in zip(inputs, coords)
    ])
    return mapped_vals


def th_batch_map_coordinates(input, coords, order=1):
    """Batch version of th_map_coordinates
    Only supports 2D feature maps
    Parameters
    ----------
    input : torch.Tensor. shape = (b, s, s)
    coords : torch.Tensor. shape = (b, n_points, 2)
    Returns
    -------
    torch.Tensor. shape = (b, n_points)
    """

    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    n_coords = coords.size(1)

    # coords = torch.clamp(coords, 0, input_size - 1)

    coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)

    assert (coords.size(1) == n_coords)

    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
    coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)
    idx = th_repeat(torch.arange(0, batch_size), n_coords).long()
    idx = Variable(idx, requires_grad=False)
    if input.is_cuda:
        idx = idx.cuda()

    def _get_vals_by_coords(input, coords):
        indices = torch.stack([
            idx, th_flatten(coords[..., 0]), th_flatten(coords[..., 1])
        ], 1)
        inds = indices[:, 0]*input.size(1)*input.size(2)+ indices[:, 1]*input.size(2) + indices[:, 2]
        vals = th_flatten(input).index_select(0, inds)
        vals = vals.view(batch_size, n_coords)
        return vals

    vals_lt = _get_vals_by_coords(input, coords_lt.detach())
    vals_rb = _get_vals_by_coords(input, coords_rb.detach())
    vals_lb = _get_vals_by_coords(input, coords_lb.detach())
    vals_rt = _get_vals_by_coords(input, coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())
    vals_t = coords_offset_lt[..., 0]*(vals_rt - vals_lt) + vals_lt
    vals_b = coords_offset_lt[..., 0]*(vals_rb - vals_lb) + vals_lb
    mapped_vals = coords_offset_lt[..., 1]* (vals_b - vals_t) + vals_t
    return mapped_vals


def sp_batch_map_offsets(input, offsets):
    """Reference implementation for th_batch_map_offsets"""

    batch_size = input.shape[0]
    input_height = input.shape[1]
    input_width = input.shape[2]

    offsets = offsets.reshape(batch_size, -1, 2)
    grid = np.stack(np.mgrid[:input_height, :input_width], -1).reshape(-1, 2)
    grid = np.repeat([grid], batch_size, axis=0)
    coords = offsets + grid
    # coords = coords.clip(0, input_size - 1)

    mapped_vals = sp_batch_map_coordinates(input, coords)
    return mapped_vals


def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
    grid = np.meshgrid(
        range(input_height), range(input_width), indexing='ij'
    )
    grid = np.stack(grid, axis=-1)
    grid = grid.reshape(-1, 2)

    grid = np_repeat_2d(grid, batch_size)
    grid = torch.from_numpy(grid).type(dtype)
    if cuda:
        grid = grid.cuda()
    return Variable(grid, requires_grad=False)


def th_batch_map_offsets(input, offsets, grid=None, order=1):
    """Batch map offsets into input
    Parameters
    ---------
    input : torch.Tensor. shape = (b, s, s)
    offsets: torch.Tensor. shape = (b, s, s, 2)
    Returns
    -------
    torch.Tensor. shape = (b, s, s)
    """
    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    offsets = offsets.view(batch_size, -1, 2)
    if grid is None:
        grid = th_generate_grid(batch_size, input_height, input_width, offsets.data.type(), offsets.data.is_cuda)

    coords = offsets + grid

    mapped_vals = th_batch_map_coordinates(input, coords)
    return mapped_vals


class SEModule(pl.LightningModule):
    def __init__(self, num_channel, squeeze_ratio=1.0):
        super(SEModule, self).__init__()
        self.sequeeze_mod = nn.AdaptiveAvgPool2d(1)
        self.num_channel = num_channel

        blocks = [nn.Linear(num_channel, int(num_channel * squeeze_ratio)),
                  nn.ReLU(),
                  nn.Linear(int(num_channel * squeeze_ratio), num_channel),
                  nn.Sigmoid()]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        ori = x
        x = self.sequeeze_mod(x)
        x = x.view(x.size(0), 1, self.num_channel)
        x = self.blocks(x)
        x = x.view(x.size(0), self.num_channel, 1, 1)
        x = ori * x
        return x


class ContextualAttentionModule(pl.LightningModule):

    def __init__(self, patch_size=3, propagate_size=3, stride=1):
        super(ContextualAttentionModule, self).__init__()
        self.patch_size = patch_size
        self.propagate_size = propagate_size
        self.stride = stride
        self.prop_kernels = None

    def forward(self, foreground, masks):
        ###assume the masked area has value 1
        bz, nc, w, h = foreground.size()
        if masks.size(3) != foreground.size(3):
            masks = F.interpolate(masks, foreground.size()[2:])
        background = foreground.clone()
        background = background * masks
        background = F.pad(background,
                           [self.patch_size // 2, self.patch_size // 2, self.patch_size // 2, self.patch_size // 2])
        conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size,
                                                                                     self.stride).contiguous().view(bz,
                                                                                                                    nc,
                                                                                                                    -1,
                                                                                                                    self.patch_size,
                                                                                                                    self.patch_size)
        conv_kernels_all = conv_kernels_all.transpose(2, 1)
        output_tensor = []
        for i in range(bz):
            mask = masks[i:i + 1]
            feature_map = foreground[i:i + 1].contiguous()
            # form convolutional kernels
            conv_kernels = conv_kernels_all[i] + 0.0000001
            norm_factor = torch.sum(conv_kernels ** 2, [1, 2, 3], keepdim=True) ** 0.5
            conv_kernels = conv_kernels / norm_factor

            conv_result = F.conv2d(feature_map, conv_kernels, padding=self.patch_size // 2)
            """
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size])
                    self.prop_kernels.requires_grad = False
                    self.prop_kernels = self.prop_kernels.cuda()
                conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))

            """

            # build the propagation kernel on the same device as the features
            # instead of hard-coding .cuda(), so CPU and multi-GPU runs also work
            self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size],
                                           device=conv_result.device)
            self.prop_kernels.requires_grad = False
            conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))
            
            attention_scores = F.softmax(conv_result, dim=1)
            ##propagate the scores
            recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride=1,
                                                      padding=self.patch_size // 2)
            # average the recovered value, at the same time make non-masked area 0
            recovered_foreground = (recovered_foreground * (1 - mask)) / (self.patch_size ** 2)
            # recover the image
            final_output = recovered_foreground + feature_map * mask
            output_tensor.append(final_output)
        return torch.cat(output_tensor, dim=0)


class PixelContextualAttention(pl.LightningModule):

    def __init__(self, inchannel, patch_size_list=[1], propagate_size_list=[3], stride_list=[1]):
        assert isinstance(patch_size_list,
                          list), "patch_size should be a list containing scales, or you should use Contextual Attention to initialize your module"
        assert len(patch_size_list) == len(propagate_size_list) and len(propagate_size_list) == len(
            stride_list), "the input_lists should have same lengths"
        super(PixelContextualAttention, self).__init__()
        for i in range(len(patch_size_list)):
            name = "CA_{:d}".format(i)
            setattr(self, name, ContextualAttentionModule(patch_size_list[i], propagate_size_list[i], stride_list[i]))
        self.num_of_modules = len(patch_size_list)
        self.SqueezeExc = SEModule(inchannel * 2)
        self.combiner = nn.Conv2d(inchannel * 2, inchannel, kernel_size=1)

    def forward(self, foreground, mask):
        outputs = [foreground]
        for i in range(self.num_of_modules):
            name = "CA_{:d}".format(i)
            CA_module = getattr(self, name)
            outputs.append(CA_module(foreground, mask))
        outputs = torch.cat(outputs, dim=1)
        outputs = self.SqueezeExc(outputs)
        outputs = self.combiner(outputs)
        return outputs




class ConvOffset2D(nn.Conv2d):
    """ConvOffset2D

    Convolutional layer responsible for learning the 2D offsets and output the
    deformed feature map using bilinear interpolation

    Note that this layer does not perform convolution on the deformed feature
    map. See get_deform_cnn in cnn.py for usage
    """
    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Init

        Parameters
        ----------
        filters : int
            Number of channel of the input feature map
        init_normal_stddev : float
            Normal kernel initialization
        **kwargs:
            Pass to superclass. See Conv2d layer in pytorch
        """
        self.filters = filters
        self._grid_param = None
        super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs)
        self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))

    def forward(self, x):
        """Return the deformed feature map"""
        x_shape = x.size()
        offsets = super(ConvOffset2D, self).forward(x)

        # offsets: (b*c, h, w, 2)
        offsets = self._to_bc_h_w_2(offsets, x_shape)

        # x: (b*c, h, w)
        x = self._to_bc_h_w(x, x_shape)

        # x_offset: (b*c, h*w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(x))

        # x_offset: (b, c, h, w)
        x_offset = self._to_b_c_h_w(x_offset, x_shape)

        return x_offset

    def _get_grid(self, x):
        batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
        dtype, cuda = x.data.type(), x.data.is_cuda
        if self._grid_param == (batch_size, input_height, input_width, dtype, cuda):
            return self._grid
        self._grid_param = (batch_size, input_height, input_width, dtype, cuda)
        self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        fan_out = weights.size(0)
        fan_in = weights.size(1) * weights.size(2) * weights.size(3)
        w = np.random.normal(0.0, std, (fan_out, fan_in))
        return torch.from_numpy(w.reshape(weights.size()))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2)
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) -> (b, c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))
        return x





class RBNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RBNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t
        if self.training:
            mask_mean = torch.mean(mask_t, (0, 2, 3), True)
            x_mean = torch.mean(input_m, (0, 2, 3), True) / mask_mean
            x_var = torch.mean(((input_m - x_mean) * mask_t) ** 2, (0, 2, 3), True) / mask_mean

            x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * x_mean.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * x_var.data)
        else:
            x_out = self.weight * (input_m - self.running_mean) / torch.sqrt(self.running_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class RCNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RCNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t

        if self.training:
            mask_mean_bn = torch.mean(mask_t, (0, 2, 3), True)
            mean_bn = torch.mean(input_m, (0, 2, 3), True) / mask_mean_bn
            var_bn = torch.mean(((input_m - mean_bn) * mask_t) ** 2, (0, 2, 3), True) / mask_mean_bn

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * mean_bn.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * var_bn.data)
        else:
            mean_bn = torch.autograd.Variable(self.running_mean)
            var_bn = torch.autograd.Variable(self.running_var)

        mask_mean_in = torch.mean(mask_t, (2, 3), True)
        mean_in = torch.mean(input_m, (2, 3), True) / mask_mean_in
        var_in = torch.mean(((input_m - mean_in) * mask_t) ** 2, (2, 3), True) / mask_mean_in

        mask_mean_ln = torch.mean(mask_t, (1, 2, 3), True)
        mean_ln = torch.mean(input_m, (1, 2, 3), True) / mask_mean_ln
        var_ln = torch.mean(((input_m - mean_ln) * mask_t) ** 2, (1, 2, 3), True) / mask_mean_ln

        # explicit dim avoids the implicit-dimension deprecation warning
        mean_weight = F.softmax(self.mean_weight, dim=0)
        var_weight = F.softmax(self.var_weight, dim=0)

        x_mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        x_var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class DSModule(pl.LightningModule):
    def __init__(self, in_ch, out_ch, bn=False, rn=True, sample='none-3', activ='relu',
                 conv_bias=False, defor=True):
        super().__init__()
        if sample == 'down-5':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 5, 2, 2, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(5,2,2)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-7':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 7, 2, 3, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(7, 2, 3)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-3':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 3, 2, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3, 2, 1)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        else:
            self.conv = nn.Conv2d(in_ch+2, out_ch, 3, 1, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3,1,1)
            if defor:
                self.offset0 = ConvOffset2D(in_ch-out_ch+1)
                self.offset1 = ConvOffset2D(out_ch+1)
        self.in_ch = in_ch
        self.out_ch = out_ch

        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if rn:
            # Regional Composite Normalization
            self.rn = RCNModule(out_ch)

            # Regional Batch Normalization
            # self.rn = RBNModule(out_ch)
        if activ == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2, inplace = True)

    def forward(self, input, input_mask):
        if hasattr(self, 'offset'):
            input = torch.cat([input, input_mask[:,:1,:,:]], dim = 1)
            h = self.offset(input)
            h = input*input_mask[:,:1,:,:] + (1-input_mask[:,:1,:,:])*h
            h = self.conv(h)
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask
            h = self.rn(h, h_mask)
        elif hasattr(self, 'offset0'):
            h1_in = torch.cat([input[:,self.in_ch-self.out_ch:,:,:], input_mask[:,1:,:,:]], dim = 1)
            m1_in = input_mask[:,1:,:,:]
            h0 = torch.cat([input[:,:self.in_ch-self.out_ch,:,:], input_mask[:,:1,:,:]], dim = 1)
            h1 = self.offset1(h1_in)
            h1 = m1_in*h1_in + (1-m1_in)*h1
            h = self.conv(torch.cat([h0,h1], dim = 1))
            h = self.rn(h, input_mask[:,:1,:,:])
            h_mask = F.interpolate(input_mask[:,:1,:,:], scale_factor=2, mode='nearest')
        else:
            h = self.conv(torch.cat([input, input_mask[:,:,:,:]], dim = 1))
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask

        if hasattr(self, 'bn'):
            h = self.bn(h)
        if hasattr(self, 'activation'):
            h = self.activation(h)
        return h, h_mask


class DSNet(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, upsampling_mode='nearest'):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        self.enc_1 = DSModule(input_channels, 64, rn=False, sample='down-7', defor = False)
        self.enc_2 = DSModule(64, 128, sample='down-5')
        self.enc_3 = DSModule(128, 256, sample='down-5')
        self.enc_4 = DSModule(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512 + 512, 512, activ='leaky'))
        self.dec_4 = DSModule(512 + 256, 256, activ='leaky')
        self.dec_3 = DSModule(256 + 128, 128, activ='leaky')
        self.dec_2 = DSModule(128 + 64, 64, activ='leaky')
        self.dec_1 = DSModule(64 + input_channels, input_channels,
                              rn=False, activ=None, defor = False)
        self.att = PixelContextualAttention(128)
    def forward(self, input, input_mask):
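        # Follow the module's device instead of hard-coding CUDA, so CPU and TPU runs also work
        input = input.to(self.device, dtype=torch.float32)
        input_mask = input_mask.to(self.device, dtype=torch.float32)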

        input_mask = input_mask[:,0:1,:,:]
        h_dict = {}  # for the output of enc_N
        h_mask_dict = {}  # for the output of enc_N

        h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
                h_dict[h_key_prev], h_mask_dict[h_key_prev])
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]
        h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')

        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)

            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)

            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
            h, h_mask = getattr(self, dec_l_key)(h, h_mask)
            if i == 3:
                h = self.att(h, input_mask[:,:1,:,:])
        #return h, h_mask
        return h
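
The region-normalization `forward` above divides every masked average by the mean of the mask, so only unmasked positions contribute to the statistics. A minimal pure-Python sketch of that step on a flat list, not the module itself: `masked_mean_var` is a hypothetical helper, and it assumes at least one unmasked position (an all-zero mask would divide by zero).

```python
def masked_mean_var(values, mask):
    """Mean and variance over the positions where mask == 1 (hypothetical helper)."""
    n = len(values)
    mask_mean = sum(mask) / n                       # fraction of valid positions
    masked = [v * m for v, m in zip(values, mask)]  # input_m = input * mask_t
    mean = (sum(masked) / n) / mask_mean
    # Deviations are masked again so that holes do not add to the variance
    var = (sum(((x - mean) * m) ** 2 for x, m in zip(masked, mask)) / n) / mask_mean
    return mean, var


mean, var = masked_mean_var([1.0, 2.0, 3.0, 100.0], [1, 1, 1, 0])
print(mean, var)  # mean is 2.0: the masked-out 100.0 is ignored
```

The same pattern is applied three times in the module, only the reduction axes differ (batch, instance and layer statistics).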

In [None]:
#@title [AdaFill_arch.py](https://github.com/ChajinShin/AdaFill-Image_Inpainting) (2021)
"""
network.py (3-2-20)
https://github.com/ChajinShin/AdaFill-Image_Inpainting/blob/main/Model/AdaFill/src/network.py
"""
import torch
import torch.nn as nn
import pytorch_lightning as pl

class ResModule(pl.LightningModule):
    def __init__(self, num_features, normalization):
        super(ResModule, self).__init__()
        self.block = nn.Sequential(
            normalization(num_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1),
            normalization(num_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)



class InpaintNet(pl.LightningModule):
    def __init__(self):
        super(InpaintNet, self).__init__()
        cnum = 64
        num_of_resblock = 8
        normalization_config = 'batch_norm'

        if normalization_config == 'batch_norm':
            normalization = nn.BatchNorm2d
        elif normalization_config == 'instance_norm':
            normalization = nn.InstanceNorm2d
        else:
            raise ValueError("normalization_config must be 'batch_norm' or 'instance_norm'")

        self.enc_conv1 = nn.Conv2d(in_channels=4, out_channels=cnum, kernel_size=3, stride=1, padding=1)  # cnum, 128, 128
        self.enc_norm1 = normalization(num_features=cnum)
        self.enc_activation1 = nn.ReLU(inplace=True)
        self.enc_conv2 = nn.Conv2d(in_channels=cnum, out_channels=2*cnum, kernel_size=3, stride=2, padding=1)  # 2cnum, 64, 64
        self.enc_norm2 = normalization(num_features=2*cnum)
        self.enc_activation2 = nn.ReLU(inplace=True)
        self.enc_conv3 = nn.Conv2d(in_channels=2*cnum, out_channels=4*cnum, kernel_size=3, stride=2, padding=1)  # 4cnum, 32, 32

        res = [ResModule(4*cnum, normalization) for _ in range(num_of_resblock)]
        self.res_module = nn.Sequential(
            *res
        )

        self.dec_norm1 = normalization(4*cnum)
        self.dec_activation1 = nn.ReLU(inplace=True)
        self.dec_conv1 = nn.Conv2d(in_channels=4*cnum, out_channels=2*cnum, kernel_size=3, stride=1, padding=1)    # 2*cnum, 64, 64
        self.dec_norm2 = normalization(num_features=2*cnum)
        self.dec_activation2 = nn.ReLU(inplace=True)
        self.dec_conv2 = nn.Conv2d(in_channels=2*cnum, out_channels=cnum, kernel_size=3, stride=1, padding=1)   # cnum, 128, 128
        self.dec_norm3 = normalization(num_features=cnum)
        self.dec_activation3 = nn.ReLU(inplace=True)
        self.dec_conv3 = nn.Conv2d(in_channels=cnum, out_channels=3, kernel_size=3, stride=1, padding=1)    # 3, 128, 128
        self.tanh = nn.Tanh()

    def forward(self, image, mask):
        x = torch.cat((image, mask), 1)
        # --------- encoder ----------------
        x = self.enc_conv1(x)
        x = self.enc_norm1(x)
        x = self.enc_activation1(x)
        size_1x = [x.size(2), x.size(3)]

        x = self.enc_conv2(x)
        x = self.enc_norm2(x)
        x = self.enc_activation2(x)
        size_2x = [x.size(2), x.size(3)]

        x = self.enc_conv3(x)

        # --------- res module ----------------
        x = self.res_module(x)

        # --------- decoder ----------------
        x = self.dec_norm1(x)
        x = self.dec_activation1(x)
        x = nn.functional.interpolate(x, size=size_2x)
        x = self.dec_conv1(x)

        x = self.dec_norm2(x)
        x = self.dec_activation2(x)
        x = nn.functional.interpolate(x, size=size_1x)
        x = self.dec_conv2(x)

        x = self.dec_norm3(x)
        x = self.dec_activation3(x)
        x = self.dec_conv3(x)
        x = self.tanh(x)
        return x
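
The shape comments in `InpaintNet` (128 → 64 → 32) follow from the standard convolution output-size formula. A small sketch of that arithmetic, assuming a 128×128 input as the comments suggest; `conv_out_size` is a hypothetical helper and not part of the network.

```python
def conv_out_size(n, kernel, stride, padding, dilation=1):
    """Output length of a convolution along one axis (standard formula)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


size = 128  # assumed input resolution, taken from the shape comments above
for stride in (1, 2, 2):  # strides of enc_conv1, enc_conv2, enc_conv3
    size = conv_out_size(size, kernel=3, stride=stride, padding=1)
    print(size)
```

For odd sizes the stride-2 layers round up (65 becomes 33), so doubling would not restore the original resolution. That is presumably why the decoder interpolates to the recorded `size_2x` and `size_1x` instead of using a fixed scale factor.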


In [None]:
#@title [MEDFE_arch.py](https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE) (2020)
"""
Encoder.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/9adf8898a142784976bb3e162a9fd864c224e01e/models/Encoder.py

Decoder.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/9adf8898a142784976bb3e162a9fd864c224e01e/models/Decoder.py

networks.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/9adf8898a142784976bb3e162a9fd864c224e01e/models/networks.py

MEDFE.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/master/models/MEDFE.py

PCconv.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/dd838b01d9786dc2c67de5d71869e5a60da28eb9/models/PCconv.py

Selfpatch.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/dd838b01d9786dc2c67de5d71869e5a60da28eb9/util/Selfpatch.py

util.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/dd838b01d9786dc2c67de5d71869e5a60da28eb9/util/util.py

InnerCos.py (25-12-20)
https://github.com/KumapowerLIU/Rethinking-Inpainting-MEDFE/blob/c7156eab4a9890888fa86e641cd685e21b78c31e/models/InnerCos.py
"""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from torch.autograd import Variable
import collections
import inspect, re
import math
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.utils import save_image
import pytorch_lightning as pl

class InnerCos(pl.LightningModule):
    def __init__(self):
        super(InnerCos, self).__init__()
        self.criterion = nn.L1Loss()
        self.target = None
        self.down_model = nn.Sequential(
            nn.Conv2d(256, 3, kernel_size=1,stride=1, padding=0),
            nn.Tanh()
        )

    def set_target(self, targetde, targetst):
        self.targetst = F.interpolate(targetst, size=(32, 32), mode='bilinear')
        self.targetde = F.interpolate(targetde, size=(32, 32), mode='bilinear')

    def get_target(self):
        return self.target

    def forward(self, in_data):
        loss_co = in_data[1]
        self.ST = self.down_model(loss_co[0])
        self.DE = self.down_model(loss_co[1])
        #self.loss = self.criterion(self.ST, self.targetst)+self.criterion(self.DE, self.targetde)
        self.output = in_data[0]
        return self.output

    def backward(self, retain_graph=True):

        self.loss.backward(retain_graph=retain_graph)
        return self.loss

    def __repr__(self):

        return self.__class__.__name__

# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3,1,1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)

def binary_mask(in_mask, threshold):
    assert in_mask.dim() == 2, "mask must be 2 dimensions"

    output = (in_mask > threshold).float()

    return output

def gussin(v):
    # One normalised 32x32 Gaussian map (std v) per centre pixel, 1024 maps in total
    outk = []
    for i in range(32):
        for k in range(32):

            out = []
            for x in range(32):
                row = []
                for y in range(32):
                    cord_x = i
                    cord_y = k
                    dis_x = np.abs(x - cord_x)
                    dis_y = np.abs(y - cord_y)
                    dis_add = -(dis_x * dis_x + dis_y * dis_y)
                    dis_add = dis_add / (2 * v * v)
                    dis_add = math.exp(dis_add) / (2 * math.pi * v * v)

                    row.append(dis_add)
                out.append(row)

            outk.append(out)

    out = np.array(outk)
    f = out.sum(-1).sum(-1)
    q = []
    for i in range(1024):
        g = out[i] / f[i]
        q.append(g)
    out = np.array(q)
    return torch.from_numpy(out)

def cal_feat_mask(inMask, conv_layers, threshold):
    assert inMask.dim() == 4, "mask must be 4 dimensions"
    assert inMask.size(0) == 1, "the first dimension must be 1 for mask"
    inMask = inMask.float()
    convs = []
    inMask = Variable(inMask, requires_grad = False)
    for id_net in range(conv_layers):
        conv = nn.Conv2d(1,1,4,2,1, bias=False)
        conv.weight.data.fill_(1/16)
        convs.append(conv)
    lnet = nn.Sequential(*convs)
    if inMask.is_cuda:

        lnet = lnet.cuda()
    output = lnet(inMask)
    output = (output > threshold).float().mul_(1)

    return output

def cal_mask_given_mask_thred(img, mask, patch_size, stride, mask_thred):
    assert img.dim() == 3, 'img has to be 3 dimensions!'
    assert mask.dim() == 2, 'mask has to be 2 dimensions!'
    dim = img.dim()
    # math.floor rounds down
    _, H, W = img.size(dim-3), img.size(dim-2), img.size(dim-1)
    nH = int(math.floor((H-patch_size)/stride + 1))
    nW = int(math.floor((W-patch_size)/stride + 1))
    N = nH*nW

    flag = torch.zeros(N).long()
    offsets_tmp_vec = torch.zeros(N).long()
    # the returned data is of list type

    nonmask_point_idx_all = torch.zeros(N).long()

    tmp_non_mask_idx = 0


    mask_point_idx_all = torch.zeros(N).long()

    tmp_mask_idx = 0
    # iterate over every pixel
    for i in range(N):
        h = int(math.floor(i/nW))
        w = int(math.floor(i%nW))
        # print(h, w)
        # crop out the small 1×1 patches one by one
        mask_tmp = mask[h*stride:h*stride + patch_size,
                        w*stride:w*stride + patch_size]


        if torch.sum(mask_tmp) < mask_thred:
            nonmask_point_idx_all[tmp_non_mask_idx] = i
            tmp_non_mask_idx += 1
        else:
            mask_point_idx_all[tmp_mask_idx] = i
            tmp_mask_idx += 1
            flag[i] = 1
            offsets_tmp_vec[i] = -1
    # print(flag)  #checked
    # print(offsets_tmp_vec) # checked

    non_mask_num = tmp_non_mask_idx
    mask_num = tmp_mask_idx

    nonmask_point_idx = nonmask_point_idx_all.narrow(0, 0, non_mask_num)
    mask_point_idx=mask_point_idx_all.narrow(0, 0, mask_num)

    # get flatten_offsets
    flatten_offsets_all = torch.LongTensor(N).zero_()
    for i in range(N):
        offset_value = torch.sum(offsets_tmp_vec[0:i+1])
        if flag[i] == 1:
            offset_value = offset_value + 1
        # print(i+offset_value)
        flatten_offsets_all[i+offset_value] = -offset_value

    flatten_offsets = flatten_offsets_all.narrow(0, 0, non_mask_num)

    # print('flatten_offsets')
    # print(flatten_offsets)   # checked


    # print('nonmask_point_idx')
    # print(nonmask_point_idx)  #checked

    return flag, nonmask_point_idx, flatten_offsets, mask_point_idx


# sp_x: LongTensor
# sp_y: LongTensor
def cal_sps_for_Advanced_Indexing(h, w):
    sp_y = torch.arange(0, w).long()
    sp_y = torch.cat([sp_y]*h)

    lst = []
    for i in range(h):
        lst.extend([i]*w)
    sp_x = torch.from_numpy(np.array(lst))
    return sp_x, sp_y

"""
def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)
"""
def info(object, spacing=10, collapse=1):
    """Print methods and doc strings.
    Takes module, class, list, dictionary, or string."""
    methodList = [e for e in dir(object) if callable(getattr(object, e))]
    processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
    print( "\n".join(["%s %s" %
                     (method.ljust(spacing),
                      processFunc(str(getattr(object, method).__doc__)))
                     for method in methodList]) )

def varname(p):
    for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
        m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
        if m:
            return m.group(1)

def print_numpy(x, val=True, shp=False):
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)



class Selfpatch(object):
    def buildAutoencoder(self, target_img, target_img_2, target_img_3, patch_size=1, stride=1):
        nDim = 3
        assert target_img.dim() == nDim, 'target image must be of dimension 3.'
        C = target_img.size(0)

        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

        patches_features = self._extract_patches(target_img, patch_size, stride)
        patches_features_f = self._extract_patches(target_img_3, patch_size, stride)

        patches_on = self._extract_patches(target_img_2, 1, stride)

        return patches_features_f, patches_features, patches_on

    def build(self, target_img,  patch_size=5, stride=1):
        nDim = 3
        assert target_img.dim() == nDim, 'target image must be of dimension 3.'
        C = target_img.size(0)

        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

        patches_features = self._extract_patches(target_img, patch_size, stride)

        return patches_features

    def _build(self, patch_size, stride, C, target_patches, npatches, normalize, interpolate, type):
        # for each patch, divide by its L2 norm.
        if type == 1:
            enc_patches = target_patches.clone()
            for i in range(npatches):
                enc_patches[i] = enc_patches[i]*(1/(enc_patches[i].norm(2)+1e-8))

            conv_enc = nn.Conv2d(npatches, npatches, kernel_size=1, stride=stride, bias=False, groups=npatches)
            conv_enc.weight.data = enc_patches
            return conv_enc

        # normalize is not needed, it doesn't change the result!
            if normalize:
                raise NotImplementedError

            if interpolate:
                raise NotImplementedError
        else:

            conv_dec = nn.ConvTranspose2d(npatches, C, kernel_size=patch_size, stride=stride, bias=False)
            conv_dec.weight.data = target_patches
            return conv_dec

    def _extract_patches(self, img, patch_size, stride):
        n_dim = 3
        assert img.dim() == n_dim, 'image must be of dimension 3.'
        kH, kW = patch_size, patch_size
        dH, dW = stride, stride
        input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)
        i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(2), input_windows.size(3), input_windows.size(4)
        input_windows = input_windows.permute(1,2,0,3,4).contiguous().view(i_2*i_3, i_1, i_4, i_5)
        patches_all = input_windows
        return patches_all




# SE MODEL
class SELayer(pl.LightningModule):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c, 1, 1)
        y = self.fc(y)
        return x * y.expand_as(x)


class Convnorm(pl.LightningModule):
    def __init__(self, in_ch, out_ch, sample='none-3', activ='leaky'):
        super().__init__()
        self.bn = nn.InstanceNorm2d(out_ch, affine=True)

        if sample == 'down-3':
            self.conv = nn.Conv2d(in_ch, out_ch, 3, 2, 1, bias=False)
        else:
            self.conv = nn.Conv2d(in_ch, out_ch, 3, 1)
        if activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input):
        out = input
        out = self.conv(out)
        out = self.bn(out)
        if hasattr(self, 'activation'):
            out = self.activation(out[0])
        return out


class PCBActiv(pl.LightningModule):
    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='leaky',
                 conv_bias=False, innorm=False, inner=False, outer=False):
        super().__init__()
        if sample == 'same-5':
            self.conv = PartialConv(in_ch, out_ch, 5, 1, 2, bias=conv_bias)
        elif sample == 'same-7':
            self.conv = PartialConv(in_ch, out_ch, 7, 1, 3, bias=conv_bias)
        elif sample == 'down-3':
            self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias)
        else:
            self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias)

        if bn:
            self.bn = nn.InstanceNorm2d(out_ch, affine=True)
        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)
        self.innorm = innorm
        self.inner = inner
        self.outer = outer

    def forward(self, input):
        out = input
        if self.inner:
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])
            out = self.conv(out)
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])

        elif self.innorm:
            out = self.conv(out)
            out[0] = self.bn(out[0])
            out[0] = self.activation(out[0])
        elif self.outer:
            out = self.conv(out)
            out[0] = self.bn(out[0])
        else:
            out = self.conv(out)
            out[0] = self.bn(out[0])
            if hasattr(self, 'activation'):
                out[0] = self.activation(out[0])
        return out


class ConvDown(pl.LightningModule):
    def __init__(self, in_c, out_c, kernel, stride, padding=0, dilation=1, groups=1, bias=False, layers=1, activ=True):
        super().__init__()
        nf_mult = 1
        nums = out_c / 64
        sequence = []

        for i in range(1, layers + 1):
            nf_mult_prev = nf_mult
            if nums == 8:
                if in_c == 512:

                    nf_mult = 1
                else:
                    nf_mult = 2

            else:
                nf_mult = min(2 ** i, 8)
            if kernel != 1:

                if activ == False and layers == 1:
                    sequence += [
                        nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                  kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                        nn.InstanceNorm2d(nf_mult * in_c)
                    ]
                else:
                    sequence += [
                        nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                  kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                        nn.InstanceNorm2d(nf_mult * in_c),
                        nn.LeakyReLU(0.2, True)
                    ]

            else:

                sequence += [
                    nn.Conv2d(in_c, out_c,
                              kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                    nn.InstanceNorm2d(out_c),
                    nn.LeakyReLU(0.2, True)
                ]

            if activ == False:
                if i + 1 == layers:
                    if layers == 2:
                        sequence += [
                            nn.Conv2d(nf_mult * in_c, nf_mult * in_c,
                                      kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                            nn.InstanceNorm2d(nf_mult * in_c)
                        ]
                    else:
                        sequence += [
                            nn.Conv2d(nf_mult_prev * in_c, nf_mult * in_c,
                                      kernel_size=kernel, stride=stride, padding=padding, bias=bias),
                            nn.InstanceNorm2d(nf_mult * in_c)
                        ]
                    break

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


class ConvUp(pl.LightningModule):
    def __init__(self, in_c, out_c, kernel, stride, padding=0, dilation=1, groups=1, bias=False):
        super().__init__()

        self.conv = nn.Conv2d(in_c, out_c, kernel,
                              stride, padding, dilation, groups, bias)
        self.bn = nn.InstanceNorm2d(out_c)
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, size):
        out = F.interpolate(input=input, size=size, mode='bilinear')
        out = self.conv(out)
        out = self.bn(out)
        out = self.relu(out)
        return out


class BASE(pl.LightningModule):
    def __init__(self, inner_nc):
        super(BASE, self).__init__()
        se = SELayer(inner_nc, 16)
        model = [se]
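        gus = gussin(1.5)
        # Non-persistent buffer: moves with the module's device without adding a state_dict key
        self.register_buffer('gus', torch.unsqueeze(gus, 1).double(), persistent=False)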
        self.model = nn.Sequential(*model)
        self.down = nn.Sequential(
            nn.Conv2d(1024, 512, 1, 1, 0, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2)
        )

    def forward(self, x):
        Nonparm = Selfpatch()
        out_32 = self.model(x)
        b, c, h, w = out_32.size()
        gus = self.gus.float()
        gus_out = out_32[0].expand(h * w, c, h, w)
        gus_out = gus * gus_out
        gus_out = torch.sum(gus_out, -1)
        gus_out = torch.sum(gus_out, -1)
        gus_out = gus_out.contiguous().view(b, c, h, w)
        csa2_in = torch.sigmoid(out_32)
        csa2_f = torch.nn.functional.pad(csa2_in, (1, 1, 1, 1))
        csa2_ff = torch.nn.functional.pad(out_32, (1, 1, 1, 1))
        csa2_fff, csa2_f, csa2_conv = Nonparm.buildAutoencoder(csa2_f[0], csa2_in[0], csa2_ff[0], 3, 1)
        csa2_conv = csa2_conv.expand_as(csa2_f)
        csa_a = csa2_conv * csa2_f
        csa_a = torch.mean(csa_a, 1)
        a_c, a_h, a_w = csa_a.size()
        csa_a = csa_a.contiguous().view(a_c, -1)
        csa_a = F.softmax(csa_a, dim=1)
        csa_a = csa_a.contiguous().view(a_c, 1, a_h, a_w)
        out = csa_a * csa2_fff
        out = torch.sum(out, -1)
        out = torch.sum(out, -1)
        out_csa = out.contiguous().view(b, c, h, w)
        out_32 = torch.cat([gus_out, out_csa], 1)
        out_32 = self.down(out_32)
        return out_32


class PartialConv(pl.LightningModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)

        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # mask is not updated
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, inputt):
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T* (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)

        input = inputt[0]
        mask = inputt[1].float().to(input.device)

        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
                output)
        else:
            output_bias = torch.zeros_like(output)

        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0
        mask_sum = output_mask.masked_fill_(no_update_holes.bool(), 1.0)
        output_pre = (output - output_bias) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes.bool(), 0.0)
        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes.bool(), 0.0)
        out = []
        out.append(output)
        out.append(new_mask)
        return out
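

# Worked example for the renormalisation in PartialConv.forward above. This is a
# pure-Python 1D sketch under simplifying assumptions (single channel, stride 1,
# no padding), not the layer itself; `partial_conv_1d_sketch` is a hypothetical helper.
def partial_conv_1d_sketch(values, mask, weights, bias=0.0):
    k = len(weights)
    out, new_mask = [], []
    for start in range(len(values) - k + 1):
        # D(M): number of valid samples under the window
        valid = sum(mask[start + j] for j in range(k))
        if valid == 0:
            # No valid sample seen: the output stays a hole
            out.append(0.0)
            new_mask.append(0.0)
            continue
        # C(M .* X) - C(0): convolution of the masked input without the bias
        conv = sum(weights[j] * values[start + j] * mask[start + j] for j in range(k))
        out.append(conv / valid + bias)
        new_mask.append(1.0)
    return out, new_mask


# Example: partial_conv_1d_sketch([1, 2, 3, 4], [1, 1, 0, 0], [1, 1], bias=0.5)
# gives ([2.0, 2.5, 0.0], [1.0, 1.0, 0.0]); the last window holds no valid sample.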


class PCconv(pl.LightningModule):
    def __init__(self):
        super(PCconv, self).__init__()
        self.down_128 = ConvDown(64, 128, 4, 2, padding=1, layers=2)
        self.down_64 = ConvDown(128, 256, 4, 2, padding=1)
        self.down_32 = ConvDown(256, 256, 1, 1)
        self.down_16 = ConvDown(512, 512, 4, 2, padding=1, activ=False)
        self.down_8 = ConvDown(512, 512, 4, 2, padding=1, layers=2, activ=False)
        self.down_4 = ConvDown(512, 512, 4, 2, padding=1, layers=3, activ=False)
        self.down = ConvDown(768, 256, 1, 1)
        self.fuse = ConvDown(512, 512, 1, 1)
        self.up = ConvUp(512, 256, 1, 1)
        self.up_128 = ConvUp(512, 64, 1, 1)
        self.up_64 = ConvUp(512, 128, 1, 1)
        self.up_32 = ConvUp(512, 256, 1, 1)
        self.base= BASE(512)
        sequence_3 = []
        sequence_5 = []
        sequence_7 = []
        for i in range(5):
            sequence_3 += [PCBActiv(256, 256, innorm=True)]
            sequence_5 += [PCBActiv(256, 256, sample='same-5', innorm=True)]
            sequence_7 += [PCBActiv(256, 256, sample='same-7', innorm=True)]

        self.cov_3 = nn.Sequential(*sequence_3)
        self.cov_5 = nn.Sequential(*sequence_5)
        self.cov_7 = nn.Sequential(*sequence_7)
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, mask):
        mask =  cal_feat_mask(mask, 3, 1)
        # input[2]:256 32 32
        b, c, h, w = input[2].size()
        mask_1 = torch.add(torch.neg(mask.float()), 1)
        mask_1 = mask_1.expand(b, c, h, w)

        x_1 = self.activation(input[0])
        x_2 = self.activation(input[1])
        x_3 = self.activation(input[2])
        x_4 = self.activation(input[3])
        x_5 = self.activation(input[4])
        x_6 = self.activation(input[5])
        # Change the shape of each layer and integrate low-level/high-level features
        x_1 = self.down_128(x_1)
        x_2 = self.down_64(x_2)
        x_3 = self.down_32(x_3)
        x_4 = self.up(x_4, (32, 32))
        x_5 = self.up(x_5, (32, 32))
        x_6 = self.up(x_6, (32, 32))

        # The first three layers are Texture/detail
        # The last three layers are Structure
        x_DE = torch.cat([x_1, x_2, x_3], 1)
        x_ST = torch.cat([x_4, x_5, x_6], 1)

        x_ST = self.down(x_ST)
        x_DE = self.down(x_DE)
        x_ST = [x_ST, mask_1]
        x_DE = [x_DE, mask_1]

        # Multi Scale PConv fill the Details
        x_DE_3 = self.cov_3(x_DE)
        x_DE_5 = self.cov_5(x_DE)
        x_DE_7 = self.cov_7(x_DE)
        x_DE_fuse = torch.cat([x_DE_3[0], x_DE_5[0], x_DE_7[0]], 1)
        x_DE_fi = self.down(x_DE_fuse)

        # Multi Scale PConv fill the Structure
        x_ST_3 = self.cov_3(x_ST)
        x_ST_5 = self.cov_5(x_ST)
        x_ST_7 = self.cov_7(x_ST)
        x_ST_fuse = torch.cat([x_ST_3[0], x_ST_5[0], x_ST_7[0]], 1)
        x_ST_fi = self.down(x_ST_fuse)

        x_cat = torch.cat([x_ST_fi, x_DE_fi], 1)
        x_cat_fuse = self.fuse(x_cat)

        # Feature equalizations
        x_final = self.base(x_cat_fuse)

        # Add back to the input
        x_ST = x_final
        x_DE = x_final
        x_1 = self.up_128(x_DE, (128, 128)) + input[0]
        x_2 = self.up_64(x_DE, (64, 64)) + input[1]
        x_3 = self.up_32(x_DE, (32, 32)) + input[2]
        x_4 = self.down_16(x_ST) + input[3]
        x_5 = self.down_8(x_ST) + input[4]
        x_6 = self.down_4(x_ST) + input[5]

        out = [x_1, x_2, x_3, x_4, x_5, x_6]
        loss = [x_ST_fi, x_DE_fi]
        out_final = [out, loss]
        return out_final



import torch.nn as nn


# Define the resnet block
class ResnetBlock(pl.LightningModule):
    def __init__(self, dim, dilation=1):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=False),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=False),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


# define the Encoder unit
class UnetSkipConnectionEBlock(pl.LightningModule):
    def __init__(self, outer_nc, inner_nc, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_dropout=False):
        super(UnetSkipConnectionEBlock, self).__init__()
        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)

        downrelu = nn.LeakyReLU(0.2, True)

        downnorm = norm_layer(inner_nc, affine=True)
        if outermost:
            down = [downconv]
            model = down
        elif innermost:
            down = [downrelu, downconv]
            model = down
        else:
            down = [downrelu, downconv, downnorm]
            if use_dropout:
                model = down + [nn.Dropout(0.5)]
            else:
                model = down
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Encoder(pl.LightningModule):
    def __init__(self, input_nc, output_nc, ngf=64, res_num=4, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Encoder, self).__init__()

        # construct unet structure
        Encoder_1 = UnetSkipConnectionEBlock(input_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, outermost=True)
        Encoder_2 = UnetSkipConnectionEBlock(ngf, ngf * 2, norm_layer=norm_layer, use_dropout=use_dropout)
        Encoder_3 = UnetSkipConnectionEBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, use_dropout=use_dropout)
        Encoder_4 = UnetSkipConnectionEBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout)
        Encoder_5 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout)
        Encoder_6 = UnetSkipConnectionEBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout, innermost=True)

        blocks = []
        for _ in range(res_num):
            block = ResnetBlock(ngf * 8, 2)
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)

        self.Encoder_1 = Encoder_1
        self.Encoder_2 = Encoder_2
        self.Encoder_3 = Encoder_3
        self.Encoder_4 = Encoder_4
        self.Encoder_5 = Encoder_5
        self.Encoder_6 = Encoder_6

    def forward(self, input):
        y_1 = self.Encoder_1(input)
        y_2 = self.Encoder_2(y_1)
        y_3 = self.Encoder_3(y_2)
        y_4 = self.Encoder_4(y_3)
        y_5 = self.Encoder_5(y_4)
        y_6 = self.Encoder_6(y_5)
        y_7 = self.middle(y_6)

        return y_1, y_2, y_3, y_4, y_5, y_7


import torch.nn as nn
import torch

class UnetSkipConnectionDBlock(pl.LightningModule):
    def __init__(self, inner_nc, outer_nc, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_dropout=False):
        super(UnetSkipConnectionDBlock, self).__init__()
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc, affine=True)
        upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                    kernel_size=4, stride=2,
                                    padding=1)
        if outermost:
            # the outermost block ends in Tanh instead of a norm layer
            model = [uprelu, upconv, nn.Tanh()]
        else:
            # innermost and intermediate blocks share the same layout
            model = [uprelu, upconv, upnorm]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Decoder(pl.LightningModule):
    def __init__(self, input_nc, output_nc, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(Decoder, self).__init__()

        # construct unet structure
        Decoder_1 = UnetSkipConnectionDBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout,
                                             innermost=True)
        Decoder_2 = UnetSkipConnectionDBlock(ngf * 16, ngf * 8, norm_layer=norm_layer, use_dropout=use_dropout)
        Decoder_3 = UnetSkipConnectionDBlock(ngf * 16, ngf * 4, norm_layer=norm_layer, use_dropout=use_dropout)
        Decoder_4 = UnetSkipConnectionDBlock(ngf * 8, ngf * 2, norm_layer=norm_layer, use_dropout=use_dropout)
        Decoder_5 = UnetSkipConnectionDBlock(ngf * 4, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
        Decoder_6 = UnetSkipConnectionDBlock(ngf * 2, output_nc, norm_layer=norm_layer, use_dropout=use_dropout, outermost=True)

        self.Decoder_1 = Decoder_1
        self.Decoder_2 = Decoder_2
        self.Decoder_3 = Decoder_3
        self.Decoder_4 = Decoder_4
        self.Decoder_5 = Decoder_5
        self.Decoder_6 = Decoder_6

    def forward(self, input_1, input_2, input_3, input_4, input_5, input_6):
        y_1 = self.Decoder_1(input_6)
        y_2 = self.Decoder_2(torch.cat([y_1, input_5], 1))
        y_3 = self.Decoder_3(torch.cat([y_2, input_4], 1))
        y_4 = self.Decoder_4(torch.cat([y_3, input_3], 1))
        y_5 = self.Decoder_5(torch.cat([y_4, input_2], 1))
        y_6 = self.Decoder_6(torch.cat([y_5, input_1], 1))
        out = y_6

        return out


class PCblock(pl.LightningModule):
    def __init__(self, stde_list):
        super(PCblock, self).__init__()
        self.pc_block = PCconv()
        innerloss = InnerCos()
        stde_list.append(innerloss)
        loss = [innerloss]
        self.loss = nn.Sequential(*loss)

    def forward(self, input, mask):
        out = self.pc_block(input, mask)
        out = self.loss(out)
        return out


class MEDFEGenerator(pl.LightningModule):
    def __init__(self, input_nc=4, output_nc=3, ngf=64, norm='batch', use_dropout=False, stde_list=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Avoid a mutable default argument: PCblock appends to stde_list, so a shared
        # default list would keep growing across instances.
        if stde_list is None:
            stde_list = []
        self.netEN = Encoder(input_nc=input_nc, output_nc=output_nc, ngf=ngf, norm_layer=norm_layer, use_dropout=use_dropout)
        self.netDE = Decoder(input_nc=input_nc, output_nc=output_nc, ngf=ngf, norm_layer=norm_layer, use_dropout=use_dropout)
        self.netMEDFE = PCblock(stde_list)

    def mask_process(self, mask):
        mask = mask[0][0]
        mask = torch.unsqueeze(mask, 0)
        mask = torch.unsqueeze(mask, 1)
        mask = mask.byte()
        return mask

    def forward(self, images, masks):
        #masks =torch.cat([masks,masks,masks],1)

        fake_p_1, fake_p_2, fake_p_3, fake_p_4, fake_p_5, fake_p_6 = self.netEN(torch.cat([images, masks], 1))
        x_out = self.netMEDFE([fake_p_1, fake_p_2, fake_p_3, fake_p_4, fake_p_5, fake_p_6], masks)
        self.fake_out = self.netDE(x_out[0], x_out[1], x_out[2], x_out[3], x_out[4], x_out[5])

        return self.fake_out
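
The hard-coded sizes in `PCconv.forward` (`(128, 128)`, `(64, 64)`, `(32, 32)`) and the channel counts passed to `ConvDown`/`ConvUp` appear to follow from the encoder's feature pyramid. A minimal sketch of that arithmetic, assuming a 256×256 input (the code above does not state the input size, it is only implied by those constants). `conv_out_size` and `encoder_pyramid` are illustrative helpers, not part of the notebook:

```python
def conv_out_size(size, kernel, stride, padding, dilation=1):
    # Standard Conv2d output-size formula for one spatial dimension.
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def encoder_pyramid(size=256, ngf=64):
    # Channel multipliers of Encoder_1 .. Encoder_6 as configured above.
    multipliers = [1, 2, 4, 8, 8, 8]
    shapes = []
    for mult in multipliers:
        # Every encoder block uses kernel_size=4, stride=2, padding=1,
        # which halves an even spatial size.
        size = conv_out_size(size, kernel=4, stride=2, padding=1)
        shapes.append((ngf * mult, size))
    return shapes


for channels, size in encoder_pyramid():
    print(channels, size, size)
```

Under that assumption the third feature map is 256×32×32, which matches the `# input[2]:256 32 32` comment in `PCconv.forward`. Other input sizes would not line up with the fixed `(32, 32)` targets.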


In [None]:
#@title [RFR_arch.py](https://github.com/jingyuanli001/RFR-Inpainting) (2020)
"""
RFRNet.py (18-12-20)
https://github.com/jingyuanli001/RFR-Inpainting/blob/master/modules/RFRNet.py

partialconv2d.py (18-12-20) # using their partconv2d to avoid dimension errors
https://github.com/jingyuanli001/RFR-Inpainting/blob/master/modules/partialconv2d.py

Attention.py (18-12-20)
https://github.com/jingyuanli001/RFR-Inpainting/blob/master/modules/Attention.py
"""

from torch import nn
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d
import pytorch_lightning as pl

class KnowledgeConsistentAttention(pl.LightningModule):
    def __init__(self, patch_size = 3, propagate_size = 3, stride = 1):
        super(KnowledgeConsistentAttention, self).__init__()
        self.patch_size = patch_size
        self.propagate_size = propagate_size
        self.stride = stride
        self.prop_kernels = None
        self.att_scores_prev = None
        self.masks_prev = None
        self.ratio = nn.Parameter(torch.ones(1))

    def forward(self, foreground, masks):
        bz, nc, h, w = foreground.size()
        if masks.size(3) != foreground.size(3):
            masks = F.interpolate(masks, foreground.size()[2:])
        background = foreground.clone()
        conv_kernels_all = background.view(bz, nc, w * h, 1, 1)
        conv_kernels_all = conv_kernels_all.permute(0, 2, 1, 3, 4)
        output_tensor = []
        att_score = []
        for i in range(bz):
            feature_map = foreground[i:i+1]
            conv_kernels = conv_kernels_all[i] + 0.0000001
            norm_factor = torch.sum(conv_kernels**2, [1, 2, 3], keepdim = True)**0.5
            conv_kernels = conv_kernels/norm_factor

            conv_result = F.conv2d(feature_map, conv_kernels, padding = self.patch_size//2)
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    # create the kernel on the same device as the features instead of hard-coding CUDA
                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size], device=conv_result.device)
                    self.prop_kernels.requires_grad = False
                conv_result = F.avg_pool2d(conv_result, 3, 1, padding = 1)*9
            attention_scores = F.softmax(conv_result, dim = 1)
            if self.att_scores_prev is not None:
                attention_scores = (self.att_scores_prev[i:i+1]*self.masks_prev[i:i+1] + attention_scores * (torch.abs(self.ratio)+1e-7))/(self.masks_prev[i:i+1]+(torch.abs(self.ratio)+1e-7))
            att_score.append(attention_scores)
            feature_map = F.conv_transpose2d(attention_scores, conv_kernels, stride = 1, padding = self.patch_size//2)
            final_output = feature_map
            output_tensor.append(final_output)
        self.att_scores_prev = torch.cat(att_score, dim = 0).view(bz, h*w, h, w)
        self.masks_prev = masks.view(bz, 1, h, w)
        return torch.cat(output_tensor, dim = 0)

class AttentionModule(pl.LightningModule):

    def __init__(self, inchannel, patch_size_list = [1], propagate_size_list = [3], stride_list = [1]):
        assert isinstance(patch_size_list, list), "patch_size should be a list containing scales, or you should use Contextual Attention to initialize your module"
        assert len(patch_size_list) == len(propagate_size_list) and len(propagate_size_list) == len(stride_list), "the input_lists should have same lengths"
        super(AttentionModule, self).__init__()

        self.att = KnowledgeConsistentAttention(patch_size_list[0], propagate_size_list[0], stride_list[0])
        self.num_of_modules = len(patch_size_list)
        self.combiner = nn.Conv2d(inchannel * 2, inchannel, kernel_size = 1)

    def forward(self, foreground, mask):
        outputs = self.att(foreground, mask)
        outputs = torch.cat([outputs, foreground],dim = 1)
        outputs = self.combiner(outputs)
        return outputs



###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
###############################################################################


class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not (defaults to False)
        self.multi_channel = kwargs.pop('multi_channel', False)

        self.return_mask = True

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask=None):

        if mask is not None or self.last_size != (input.data.shape[2], input.data.shape[3]):
            self.last_size = (input.data.shape[2], input.data.shape[3])

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

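        if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
            # Tensor.to() is not in-place, so the result has to be assigned back
            self.update_mask = self.update_mask.to(input)
            self.mask_ratio = self.mask_ratio.to(input)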

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)


        if self.return_mask:
            return output, self.update_mask
        else:
            return output


class Bottleneck(pl.LightningModule):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace = True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out

class RFRModule(pl.LightningModule):
    def __init__(self, layer_size=6, in_channel = 64):
        super(RFRModule, self).__init__()
        self.freeze_enc_bn = False
        self.layer_size = layer_size
        for i in range(3):
            name = 'enc_{:d}'.format(i + 1)
            out_channel = in_channel * 2
            block = [nn.Conv2d(in_channel, out_channel, 3, 2, 1, bias = False),
                     nn.BatchNorm2d(out_channel),
                     nn.ReLU(inplace = True)]
            in_channel = out_channel
            setattr(self, name, nn.Sequential(*block))

        for i in range(3, 6):
            name = 'enc_{:d}'.format(i + 1)
            block = [nn.Conv2d(in_channel, out_channel, 3, 1, 2, dilation = 2, bias = False),
                     nn.BatchNorm2d(out_channel),
                     nn.ReLU(inplace = True)]
            setattr(self, name, nn.Sequential(*block))
        self.att = AttentionModule(512)
        for i in range(5, 3, -1):
            name = 'dec_{:d}'.format(i)
            block = [nn.Conv2d(in_channel + in_channel, in_channel, 3, 1, 2, dilation = 2, bias = False),
                     nn.BatchNorm2d(in_channel),
                     nn.LeakyReLU(0.2, inplace = True)]
            setattr(self, name, nn.Sequential(*block))


        block = [nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias = False),
                 nn.BatchNorm2d(512),
                 nn.LeakyReLU(0.2, inplace = True)]
        self.dec_3 = nn.Sequential(*block)

        block = [nn.ConvTranspose2d(768, 256, 4, 2, 1, bias = False),
                 nn.BatchNorm2d(256),
                 nn.LeakyReLU(0.2, inplace = True)]
        self.dec_2 = nn.Sequential(*block)

        block = [nn.ConvTranspose2d(384, 64, 4, 2, 1, bias = False),
                 nn.BatchNorm2d(64),
                 nn.LeakyReLU(0.2, inplace = True)]
        self.dec_1 = nn.Sequential(*block)

    def forward(self, input, mask):

        h_dict = {}  # for the output of enc_N

        h_dict['h_0']= input

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key] = getattr(self, l_key)(h_dict[h_key_prev])
            h_key_prev = h_key

        h = h_dict[h_key]
        for i in range(self.layer_size - 1, 0, -1):
            enc_h_key = 'h_{:d}'.format(i)
            dec_l_key = 'dec_{:d}'.format(i)
            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h = getattr(self, dec_l_key)(h)
            if i == 3:
                h = self.att(h, mask)
        return h



class RFRNet(pl.LightningModule):
    def __init__(self, conv_type):
        super(RFRNet, self).__init__()

        self.conv_type = conv_type
        if self.conv_type == 'partial':
          self.conv1 = PartialConv2d(3, 64, 7, 2, 3, multi_channel = True, bias = False)
          self.conv2 = PartialConv2d(64, 64, 7, 1, 3, multi_channel = True, bias = False)
          self.conv21 = PartialConv2d(64, 64, 7, 1, 3, multi_channel = True, bias = False)
          self.conv22 = PartialConv2d(64, 64, 7, 1, 3, multi_channel = True, bias = False)
          self.tail1 = PartialConv2d(67, 32, 3, 1, 1, multi_channel = True, bias = False)
          # original code uses conv2d
          self.out = nn.Conv2d(64,3,3,1,1, bias = False)
        elif self.conv_type == 'deform':
          self.conv1 = DeformConv2d(3, 64, 7, 2, 3)
          self.conv2 = DeformConv2d(64, 64, 7, 1, 3)
          self.conv21 = DeformConv2d(64, 64, 7, 1, 3)
          self.conv22 = DeformConv2d(64, 64, 7, 1, 3)
          self.tail1 = DeformConv2d(67, 32, 3, 1, 1)
          # original code uses conv2d
          self.out = nn.Conv2d(64,3,3,1,1, bias = False)
        else:
          raise ValueError("conv_type not found: {}".format(conv_type))

        self.bn1 = nn.BatchNorm2d(64)
        self.bn20 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.RFRModule = RFRModule()
        self.Tconv = nn.ConvTranspose2d(64, 64, 4, 2, 1, bias = False)
        self.bn3 = nn.BatchNorm2d(64)

        self.tail2 = Bottleneck(32,8)

    def forward(self, in_image, mask):
        #in_image = torch.cat((in_image, mask), dim=1)
        mask = torch.cat([mask, mask, mask], 1)
        if self.conv_type == 'partial':
          # cast to float32 on the current device instead of forcing torch.cuda.FloatTensor
          x1, m1 = self.conv1(in_image.float(), mask.float())
        elif self.conv_type == 'deform':
          x1 = self.conv1(in_image)
          m1 = self.conv1(mask)

        x1 = F.relu(self.bn1(x1), inplace = True)

        if self.conv_type == 'partial':
          x1, m1 = self.conv2(x1, m1)
        elif self.conv_type == 'deform':
          x1 = self.conv2(x1)
          m1 = self.conv2(m1)

        x1 = F.relu(self.bn20(x1), inplace = True)
        x2, m2 = x1, m1
        n, c, h, w = x2.size()
        feature_group = [x2.view(n, c, 1, h, w)]
        mask_group = [m2.view(n, c, 1, h, w)]
        self.RFRModule.att.att.att_scores_prev = None
        self.RFRModule.att.att.masks_prev = None

        for i in range(6):
            if self.conv_type == 'partial':
              x2, m2 = self.conv21(x2, m2)
              x2, m2 = self.conv22(x2, m2)
            elif self.conv_type == 'deform':
              x2 = self.conv21(x2)
              m2 = self.conv21(m2)
              x2 = self.conv22(x2)
              m2 = self.conv22(m2)

            x2 = F.leaky_relu(self.bn2(x2), inplace = True)
            x2 = self.RFRModule(x2, m2[:,0:1,:,:])
            x2 = x2 * m2
            feature_group.append(x2.view(n, c, 1, h, w))
            mask_group.append(m2.view(n, c, 1, h, w))
        x3 = torch.cat(feature_group, dim = 2)
        m3 = torch.cat(mask_group, dim = 2)
        amp_vec = m3.mean(dim = 2)
        x3 = (x3*m3).mean(dim = 2) /(amp_vec+1e-7)
        x3 = x3.view(n, c, h, w)
        m3 = m3[:,:,-1,:,:]
        x4 = self.Tconv(x3)
        x4 = F.leaky_relu(self.bn3(x4), inplace = True)
        m4 = F.interpolate(m3, scale_factor = 2)
        x5 = torch.cat([in_image, x4], dim = 1)
        m5 = torch.cat([mask, m4], dim = 1)

        if self.conv_type == 'partial':
          x5, _ = self.tail1(x5, m5)
        elif self.conv_type == 'deform':
          x5 = self.tail1(x5)

        x5 = F.leaky_relu(x5, inplace = True)
        x6 = self.tail2(x5)
        x6 = torch.cat([x5,x6], dim = 1)
        output = self.out(x6)
        return output
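
The mask handling in `PartialConv2d.forward` above can be followed on a tiny example. The sketch below is a pure-Python stand-in, not the layer itself: it assumes a single-channel mask, stride 1 and zero padding, and the helper name `partial_conv_mask_update` is made up for illustration. It mirrors the three steps in the code: count valid pixels per window, derive `mask_ratio` as window size over that count, then clamp the count to get the updated mask.

```python
def partial_conv_mask_update(mask, kernel_size=3, padding=1):
    """mask: list of rows with 0/1 entries (1 = valid pixel)."""
    h, w = len(mask), len(mask[0])
    window = kernel_size * kernel_size  # slide_winsize for a 1-channel mask
    out_h = h + 2 * padding - kernel_size + 1
    out_w = w + 2 * padding - kernel_size + 1
    update_mask, mask_ratio = [], []
    for i in range(out_h):
        um_row, mr_row = [], []
        for j in range(out_w):
            # Count valid pixels under the window (zero padding adds nothing).
            valid = 0
            for di in range(kernel_size):
                for dj in range(kernel_size):
                    y, x = i + di - padding, j + dj - padding
                    if 0 <= y < h and 0 <= x < w:
                        valid += mask[y][x]
            updated = min(max(valid, 0), 1)  # torch.clamp(update_mask, 0, 1)
            um_row.append(updated)
            # Rescale by window / valid, and zero it where nothing was valid.
            mr_row.append(window / (valid + 1e-8) * updated)
        update_mask.append(um_row)
        mask_ratio.append(mr_row)
    return update_mask, mask_ratio


mask = [[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]
new_mask, ratio = partial_conv_mask_update(mask)
for row in new_mask:
    print(row)
```

Each pass grows the valid region by roughly one pixel for a 3×3 window, which is why `RFRNet` can apply its partial convolutions repeatedly to shrink the hole. With `multi_channel=True`, as used in `RFRNet`, the window size in the layer also includes the number of input channels.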


In [None]:
#@title [DMFN_arch.py](https://github.com/Zheng222/DMFN) (2020)
"""
block.py (18-12-20)
https://github.com/Zheng222/DMFN/blob/master/models/block.py

architecture.py (18-12-20)
https://github.com/Zheng222/DMFN/blob/master/models/architecture.py
"""
import torch.nn as nn
import torch
import pytorch_lightning as pl

def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
    padding = int((kernel_size - 1) / 2) * dilation
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=True, dilation=dilation,
                     groups=groups)


def _norm(norm_type, nc):
    norm_type = norm_type.lower()
    if norm_type == 'bn':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'in':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def _activation(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


class conv_block(pl.LightningModule):
    def __init__(self, in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
                 padding=0, norm='in', activation='relu', pad_type='zero'):
        super(conv_block, self).__init__()
        if pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        elif pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        if norm == 'in':
            self.norm = nn.InstanceNorm2d(out_nc, affine=False)
        elif norm == 'bn':
            self.norm = nn.BatchNorm2d(out_nc, affine=True)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported norm type: {}".format(norm)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(negative_slope=0.2)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        self.conv = nn.Conv2d(in_nc, out_nc, kernel_size, stride, 0, dilation, groups, bias)  # padding=0

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class upconv_block(pl.LightningModule):
    def __init__(self, in_nc, out_nc, kernel_size=3, stride=1, bias=True,
                 padding=0, pad_type='zero', norm='none', activation='relu'):
        super(upconv_block, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_nc, out_nc, 4, 2, 1)
        self.act = _activation('relu')
        self.norm = _norm('in', out_nc)

        self.conv = conv_block(out_nc, out_nc, kernel_size, stride, bias=bias, padding=padding, pad_type=pad_type,
                               norm=norm, activation=activation)

    def forward(self, x):
        x = self.act(self.norm(self.deconv(x)))
        x = self.conv(x)
        return x

class ResBlock_new(pl.LightningModule):
    def __init__(self, nc):
        super(ResBlock_new, self).__init__()
        self.c1 = conv_layer(nc, nc // 4, 3, 1)
        self.d1 = conv_layer(nc // 4, nc // 4, 3, 1, 1)  # rate = 1
        self.d2 = conv_layer(nc // 4, nc // 4, 3, 1, 2)  # rate = 2
        self.d3 = conv_layer(nc // 4, nc // 4, 3, 1, 4)  # rate = 4
        self.d4 = conv_layer(nc // 4, nc // 4, 3, 1, 8)  # rate = 8
        self.act = _activation('relu')
        self.norm = _norm('in', nc)
        self.c2 = conv_layer(nc, nc, 3, 1)  # fusion

    def forward(self, x):
        output1 = self.act(self.norm(self.c1(x)))
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d3 = self.d3(output1)
        d4 = self.d4(output1)

        add1 = d1 + d2
        add2 = add1 + d3
        add3 = add2 + d4
        combine = torch.cat([d1, add1, add2, add3], 1)
        output2 = self.c2(self.act(self.norm(combine)))
        output = x + self.norm(output2)
        return output

import torch.nn as nn
#from . import block as B

class InpaintingGenerator(pl.LightningModule):
    def __init__(self, in_nc=4, out_nc=3, nf=64, n_res=8, norm='in', activation='relu'):
        super(InpaintingGenerator, self).__init__()
        self.encoder = nn.Sequential(  # input: [4, 256, 256]
            conv_block(in_nc, nf, 5, stride=1, padding=2, norm='none', activation=activation),  # [64, 256, 256]
            conv_block(nf, 2 * nf, 3, stride=2, padding=1, norm=norm, activation=activation),  # [128, 128, 128]
            conv_block(2 * nf, 2 * nf, 3, stride=1, padding=1, norm=norm, activation=activation),  # [128, 128, 128]
            conv_block(2 * nf, 4 * nf, 3, stride=2, padding=1, norm=norm, activation=activation)  # [256, 64, 64]
        )

        blocks = []
        for _ in range(n_res):
            block = ResBlock_new(4 * nf)
            blocks.append(block)

        self.blocks = nn.Sequential(*blocks)

        self.decoder = nn.Sequential(
            conv_block(4 * nf, 4 * nf, 3, stride=1, padding=1, norm=norm, activation=activation),  # [256, 64, 64]
            upconv_block(4 * nf, 2 * nf, kernel_size=3, stride=1, padding=1, norm=norm, activation='relu'),
            # [128, 128, 128]
            upconv_block(2 * nf, nf, kernel_size=3, stride=1, padding=1, norm=norm, activation='relu'),
            # [64, 256, 256]
            conv_block(nf, out_nc, 3, stride=1, padding=1, norm='none', activation='tanh')  # [3, 256, 256]
        )

    def forward(self, x, mask):
        x = torch.cat([x, mask], dim=1)
        x = self.encoder(x)
        x = self.blocks(x)
        x = self.decoder(x)
        return x
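
`ResBlock_new` adds and concatenates the outputs of four dilated branches, which only works if every branch keeps the spatial size. The padding rule in `conv_layer` appears to be chosen for exactly that. A small sketch of the arithmetic, assuming stride 1 and an odd kernel size. `same_padding`, `conv_out_size` and `effective_kernel` are illustrative names, not functions from the DMFN code:

```python
def same_padding(kernel_size, dilation=1):
    # Mirrors the padding rule used by conv_layer above.
    return int((kernel_size - 1) / 2) * dilation


def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # Standard Conv2d output-size formula for one spatial dimension.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def effective_kernel(kernel_size, dilation):
    # Span covered by a dilated kernel along one axis.
    return dilation * (kernel_size - 1) + 1


for rate in (1, 2, 4, 8):
    pad = same_padding(3, rate)
    out = conv_out_size(64, 3, padding=pad, dilation=rate)
    print(rate, pad, effective_kernel(3, rate), out)
```

For the rates 1, 2, 4 and 8 used by `d1` to `d4`, the branches cover spans of 3, 5, 9 and 17 pixels while the 64×64 feature map stays 64×64, so the sums `add1`, `add2`, `add3` and the final `torch.cat` are shape-compatible.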


In [None]:
#@title [RN_arch.py](https://github.com/geekyutao/RN/) (2020)
"""
networks.py (13-12-20)
https://github.com/geekyutao/RN/blob/a3cf1fccc08f22fcf4b336503a8853748720fd67/networks.py

rn.py (13-12-20)
https://github.com/geekyutao/RN/blob/a3cf1fccc08f22fcf4b336503a8853748720fd67/rn.py

module_util.py (15-12-20)
https://github.com/geekyutao/RN/blob/a3cf1fccc08f22fcf4b336503a8853748720fd67/module_util.py
"""

from torchvision.transforms import *
import logging
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import pytorch_lightning as pl
logger = logging.getLogger('base')

def rn_initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    # note: the second positional argument of init.normal_ is the mean (std stays at 1.0)
                    init.normal_(m.bias, 0.0001)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    init.normal_(m.bias, 0.0001)
            elif isinstance(m, nn.BatchNorm2d):
                try:
                    init.constant_(m.weight, 1)
                    init.normal_(m.bias, 0.0001)
                except Exception:
                    print('This layer has no BN parameters:', m)
    logger.info('RN Initialization method [kaiming]')



class RN_binarylabel(pl.LightningModule):
    def __init__(self, feature_channels):
        super(RN_binarylabel, self).__init__()
        self.bn_norm = nn.BatchNorm2d(feature_channels, affine=False, track_running_stats=False)

    def forward(self, x, label):
        '''
        input:  x: (B,C,M,N), features
                label: (B,1,M,N), 1 for foreground regions, 0 for background regions
        output: _x: (B,C,M,N)
        '''
        label = label.detach()

        rn_foreground_region = self.rn(x * label, label)

        rn_background_region = self.rn(x * (1 - label), 1 - label)

        return rn_foreground_region + rn_background_region

    def rn(self, region, mask):
        '''
        input:  region: (B,C,M,N), 0 for surroundings
                mask: (B,1,M,N), 1 for target region, 0 for surroundings
        output: rn_region: (B,C,M,N)
        '''
        shape = region.size()

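        region_sum = torch.sum(region, dim=[0,2,3])  # (B,C,M,N) -> (C)
        Sr = torch.sum(mask, dim=[0,2,3])    # region area: (1) or (C), depending on the mask channels
        Sr[Sr==0] = 1  # avoid division by zero for an empty region
        mu = (region_sum / Sr)     # per-channel mean over the region: (C)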

        return self.bn_norm(region + (1 - mask) * mu[None,:,None,None]) * \
        (torch.sqrt(Sr / (shape[0] * shape[2] * shape[3])))[None,:,None,None]

class RN_B(pl.LightningModule):
    def __init__(self, feature_channels):
        super(RN_B, self).__init__()
        '''
        input: tensor(features) x: (B,C,M,N)
               condition Mask: (B,1,H,W): 0 for background, 1 for foreground
        return: tensor RN_B(x): (B,C,M,N)
        ---------------------------------------
        args:
            feature_channels: C
        '''
        # RN
        self.rn = RN_binarylabel(feature_channels)    # need no external parameters

        # gamma and beta
        self.foreground_gamma = nn.Parameter(torch.zeros(feature_channels), requires_grad=True)
        self.foreground_beta = nn.Parameter(torch.zeros(feature_channels), requires_grad=True)
        self.background_gamma = nn.Parameter(torch.zeros(feature_channels), requires_grad=True)
        self.background_beta = nn.Parameter(torch.zeros(feature_channels), requires_grad=True)

    def forward(self, x, mask):
        # mask = F.adaptive_max_pool2d(mask, output_size=x.size()[2:])
        mask = F.interpolate(mask, size=x.size()[2:], mode='nearest')   # after down-sampling, there can be all-zero mask

        rn_x = self.rn(x, mask)

        rn_x_foreground = (rn_x * mask) * (1 + self.foreground_gamma[None,:,None,None]) + self.foreground_beta[None,:,None,None]
        rn_x_background = (rn_x * (1 - mask)) * (1 + self.background_gamma[None,:,None,None]) + self.background_beta[None,:,None,None]

        return rn_x_foreground + rn_x_background

class SelfAware_Affine(pl.LightningModule):
    def __init__(self, kernel_size=7):
        super(SelfAware_Affine, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

        self.gamma_conv = nn.Conv2d(1, 1, kernel_size, padding=padding)
        self.beta_conv = nn.Conv2d(1, 1, kernel_size, padding=padding)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)

        x = self.conv1(x)
        importance_map = self.sigmoid(x)

        gamma = self.gamma_conv(importance_map)
        beta = self.beta_conv(importance_map)

        return importance_map, gamma, beta

class RN_L(pl.LightningModule):
    def __init__(self, feature_channels, threshold=0.8):
        super(RN_L, self).__init__()
        '''
        input: tensor(features) x: (B,C,M,N)
        return: tensor RN_L(x): (B,C,M,N)
        ---------------------------------------
        args:
            feature_channels: C
        '''
        # SelfAware_Affine
        self.sa = SelfAware_Affine()
        self.threshold = threshold

        # RN
        self.rn = RN_binarylabel(feature_channels)    # need no external parameters


    def forward(self, x):

        sa_map, gamma, beta = self.sa(x)     # (B,1,M,N)

        # m = sa_map.detach()
        mask = torch.zeros_like(sa_map)  # zeros_like already matches sa_map's device and dtype
        mask[sa_map.detach() >= self.threshold] = 1

        rn_x = self.rn(x, mask.expand(x.size()))

        rn_x = rn_x * (1 + gamma) + beta

        return rn_x


class G_Net(pl.LightningModule):
    def __init__(self, input_channels=3, residual_blocks=8, threshold=0.8):
        super(G_Net, self).__init__()

        # Encoder
        self.encoder_prePad = nn.ReflectionPad2d(3)
        self.encoder_conv1 = nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=7, padding=0)
        self.encoder_in1 = RN_B(feature_channels=64)
        self.encoder_relu1 = nn.ReLU(True)
        self.encoder_conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.encoder_in2 = RN_B(feature_channels=128)
        self.encoder_relu2 = nn.ReLU(True)
        self.encoder_conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.encoder_in3 = RN_B(feature_channels=256)
        self.encoder_relu3 = nn.ReLU(True)


        # Middle
        blocks = []
        for _ in range(residual_blocks):
            # block = ResnetBlock(256, 2, use_spectral_norm=False)
            block = saRN_ResnetBlock(256, dilation=2, threshold=threshold, use_spectral_norm=False)
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)


        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(256, 128*4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            RN_L(128),
            nn.ReLU(True),

            nn.Conv2d(128, 64*4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            RN_L(64),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=input_channels, kernel_size=7, padding=0)

        )


    def encoder(self, x, mask):
        # cast to float32 on the tensors' current device (not tied to CUDA)
        x = x.float()
        mask = mask.float()

        # half precision (not working?)
        #x = x.type(torch.cuda.HalfTensor)
        #mask = mask.type(torch.cuda.HalfTensor)

        x = self.encoder_prePad(x)

        x = self.encoder_conv1(x)
        x = self.encoder_in1(x, mask)
        x = self.encoder_relu1(x)

        x = self.encoder_conv2(x)
        x = self.encoder_in2(x, mask)
        x = self.encoder_relu2(x)

        x = self.encoder_conv3(x)
        x = self.encoder_in3(x, mask)
        x = self.encoder_relu3(x)
        return x

    def forward(self, x, mask):
        #gt = x
        #x = (x * (1 - mask).float()) + mask
        # input mask: 1 for hole, 0 for valid
        x = self.encoder(x, mask)

        x = self.middle(x)

        x = self.decoder(x)

        x = (torch.tanh(x) + 1) / 2
        # x = x*mask+gt*(1-mask)
        return x


class ResnetBlock(pl.LightningModule):
    def __init__(self, dim, dilation=1, use_spectral_norm=True):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        out = x + self.conv_block(x)

        # Remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html

        return out

class saRN_ResnetBlock(pl.LightningModule):
    def __init__(self, dim, dilation, threshold, use_spectral_norm=True):
        super(saRN_ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            RN_L(feature_channels=256, threshold=threshold),
            nn.ReLU(True),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(dim, track_running_stats=False),
            RN_L(feature_channels=dim, threshold=threshold),
        )

    def forward(self, x):
        out = x + self.conv_block(x)
        # skimage.io.imsave('block.png', out[0].detach().permute(1,2,0).cpu().numpy()[:,:,0])

        # Remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html

        return out

def spectral_norm(module, mode=True):
    if mode:
        return nn.utils.spectral_norm(module)

    return module
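
The `rn` method of `RN_binarylabel` above fills everything outside the region with the region mean, applies batch normalization over the whole tensor, and then rescales by `sqrt(Sr / n)`. The following single-channel sketch in plain Python (no PyTorch; the name `region_norm_1ch` and the sample values are illustrative assumptions, not part of the RN code) suggests why this amounts to normalizing the region with its own mean and variance while the surroundings end up at zero.

```python
import math

def region_norm_1ch(values, mask, eps=1e-5):
    """Single-channel, flattened sketch of RN_binarylabel.rn (hypothetical helper)."""
    n = len(values)
    area = max(sum(mask), 1)                      # Sr, guarded against an empty region
    mu = sum(v * m for v, m in zip(values, mask)) / area
    # keep region values, replace the surroundings with the region mean
    filled = [v * m + (1 - m) * mu for v, m in zip(values, mask)]
    # batch-norm statistics over ALL positions (biased variance, as in training mode)
    mean = sum(filled) / n
    var = sum((f - mean) ** 2 for f in filled) / n
    scale = math.sqrt(area / n)                   # undoes the variance shrink caused by the fill
    return [(f - mean) / math.sqrt(var + eps) * scale for f in filled]

values = [4.0, 6.0, 8.0, 100.0, -50.0, 3.0]
mask = [1, 1, 1, 0, 0, 0]                         # first three positions form the region
out = region_norm_1ch(values, mask)
print([round(o, 4) for o in out])
```

Filling with the region mean leaves the overall mean equal to the region mean and shrinks the overall variance by the factor `Sr / n`, which the final multiplication compensates for.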


In [None]:
#@title [DFNet_arch.py](https://github.com/hughplay/DFNet) (2019) [no SWA]
# https://github.com/hughplay/DFNet
# https://github.com/Yukariin/DFNet
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

#from .convolutions import partialconv2d
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d

def resize_like(x, target, mode='bilinear'):
    return F.interpolate(x, target.shape[-2:], mode=mode, align_corners=False)


def get_norm(name, out_channels):
    if name == 'batch':
        norm = nn.BatchNorm2d(out_channels)
    elif name == 'instance':
        norm = nn.InstanceNorm2d(out_channels)
    else:
        norm = None
    return norm


def get_activation(name):
    if name == 'relu':
        activation = nn.ReLU()
    elif name == 'elu':
        activation = nn.ELU()
    elif name == 'leaky_relu':
        activation = nn.LeakyReLU(negative_slope=0.2)
    elif name == 'tanh':
        activation = nn.Tanh()
    elif name == 'sigmoid':
        activation = nn.Sigmoid()
    else:
        activation = None
    return activation


class Conv2dSame(pl.LightningModule):

    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride):
        super().__init__()

        padding = self.conv_same_pad(kernel_size, stride)

        if conv_type == 'normal':
          # original
          if type(padding) is not tuple:
              self.conv = nn.Conv2d(
                  in_channels, out_channels, kernel_size, stride, padding)
          else:
              self.conv = nn.Sequential(
                  nn.ConstantPad2d(padding*2, 0),
                  nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0)
              )

        elif conv_type == 'partial':
          if type(padding) is not tuple:
              self.conv = PartialConv2d(
                  in_channels, out_channels, kernel_size, stride, padding)
          else:
              self.conv = nn.Sequential(
                  nn.ConstantPad2d(padding*2, 0),
                  PartialConv2d(in_channels, out_channels, kernel_size, stride, 0)
              )


        elif conv_type == 'deform':
          if type(padding) is not tuple:
              self.conv = DeformConv2d(
                  in_channels, out_channels, kernel_size, stride, padding)
          else:
              self.conv = nn.Sequential(
                  nn.ConstantPad2d(padding*2, 0),
                  DeformConv2d(in_channels, out_channels, kernel_size, stride, 0)
              )


    def conv_same_pad(self, ksize, stride):
        if (ksize - stride) % 2 == 0:
            return (ksize - stride) // 2
        else:
            left = (ksize - stride) // 2
            right = left + 1
            return left, right

    def forward(self, x):
        return self.conv(x)


class ConvTranspose2dSame(pl.LightningModule):

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()

        padding, output_padding = self.deconv_same_pad(kernel_size, stride)
        self.trans_conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride,
            padding, output_padding)

    def deconv_same_pad(self, ksize, stride):
        pad = (ksize - stride + 1) // 2
        outpad = 2 * pad + stride - ksize
        return pad, outpad

    def forward(self, x):
        return self.trans_conv(x)


class UpBlock(pl.LightningModule):

    def __init__(self, mode='nearest', scale=2, channel=None, kernel_size=4):
        super().__init__()

        self.mode = mode
        if mode == 'deconv':
            self.up = ConvTranspose2dSame(
                channel, channel, kernel_size, stride=scale)
        else:
            def upsample(x):
                return F.interpolate(x, scale_factor=scale, mode=mode)
            self.up = upsample

    def forward(self, x):
        return self.up(x)


class EncodeBlock(pl.LightningModule):

    def __init__(
            self, in_channels, out_channels, conv_type, kernel_size, stride,
            normalization=None, activation=None):
        super().__init__()

        self.c_in = in_channels
        self.c_out = out_channels

        layers = []
        layers.append(
            Conv2dSame(self.c_in, self.c_out, conv_type, kernel_size, stride))
        if normalization:
            layers.append(get_norm(normalization, self.c_out))
        if activation:
            layers.append(get_activation(activation))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class DecodeBlock(pl.LightningModule):

    def __init__(
            self, c_from_up, c_from_down, conv_type, c_out, mode='nearest',
            kernel_size=4, scale=2, normalization='batch', activation='relu'):
        super().__init__()

        self.c_from_up = c_from_up
        self.c_from_down = c_from_down
        self.c_in = c_from_up + c_from_down
        self.c_out = c_out

        self.up = UpBlock(mode, scale, c_from_up, kernel_size=scale)

        layers = []
        layers.append(
            Conv2dSame(self.c_in, self.c_out, conv_type, kernel_size, stride=1))
        if normalization:
            layers.append(get_norm(normalization, self.c_out))
        if activation:
            layers.append(get_activation(activation))
        self.decode = nn.Sequential(*layers)

    def forward(self, x, concat=None):
        out = self.up(x)
        if self.c_from_down > 0:
            out = torch.cat([out, concat], dim=1)
        out = self.decode(out)
        return out


class BlendBlock(pl.LightningModule):

    def __init__(
            self, c_in, c_out, conv_type, ksize_mid=3, norm='batch', act='leaky_relu'):
        super().__init__()
        c_mid = max(c_in // 2, 32)
        self.blend = nn.Sequential(
            Conv2dSame(c_in, c_mid, conv_type, 1, 1),
            get_norm(norm, c_mid),
            get_activation(act),
            Conv2dSame(c_mid, c_out, conv_type, ksize_mid, 1),
            get_norm(norm, c_out),
            get_activation(act),
            Conv2dSame(c_out, c_out, conv_type, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.blend(x)


class FusionBlock(pl.LightningModule):
    def __init__(self, c_feat, conv_type, c_alpha=1):
        super().__init__()
        c_img = 3
        self.map2img = nn.Sequential(
            Conv2dSame(c_feat, c_img, conv_type, 1, 1),
            nn.Sigmoid())
        self.blend = BlendBlock(c_img*2, c_alpha, conv_type)

    def forward(self, img_miss, feat_de):
        img_miss = resize_like(img_miss, feat_de)
        raw = self.map2img(feat_de)
        alpha = self.blend(torch.cat([img_miss, raw], dim=1))
        result = alpha * raw + (1 - alpha) * img_miss
        return result, alpha, raw

from torchvision.utils import save_image

class DFNet(pl.LightningModule):
    def __init__(
            self, c_img=3, c_mask=1, c_alpha=3,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8,
            blend_layers=[0, 1, 2, 3, 4, 5], conv_type = 'normal'):
        super().__init__()

        c_init = c_img + c_mask

        self.n_en = len(en_ksize)
        self.n_de = len(de_ksize)
        assert self.n_en == self.n_de, (
            'The encoder and decoder must have the same number of layers.')
        assert self.n_en >= 1, (
            'The encoder and decoder must have at least 1 layer.')

        assert 0 in blend_layers, 'Layer 0 must be blended.'

        self.en = []
        c_in = c_init
        self.en.append(
            EncodeBlock(c_in, 64, conv_type, en_ksize[0], 2, None, None))
        for k_en in en_ksize[1:]:
            c_in = self.en[-1].c_out
            c_out = min(c_in*2, 512)
            self.en.append(EncodeBlock(
                c_in, c_out, conv_type, k_en, stride=2,
                normalization=norm, activation=act_en))

        # register parameters
        for i, en in enumerate(self.en):
            self.__setattr__('en_{}'.format(i), en)

        self.de = []
        self.fuse = []
        for i, k_de in enumerate(de_ksize):

            c_from_up = self.en[-1].c_out if i == 0 else self.de[-1].c_out
            c_out = c_from_down = self.en[-i-1].c_in
            layer_idx = self.n_de - i - 1

            self.de.append(DecodeBlock(
                c_from_up, c_from_down, conv_type, c_out, mode, k_de, scale=2,
                normalization=norm, activation=act_de))
            if layer_idx in blend_layers:
                self.fuse.append(FusionBlock(c_out, conv_type, c_alpha))
            else:
                self.fuse.append(None)

        # register parameters
        for i, de in enumerate(self.de[::-1]):
            self.__setattr__('de_{}'.format(i), de)
        for i, fuse in enumerate(self.fuse[::-1]):
            if fuse:
                self.__setattr__('fuse_{}'.format(i), fuse)

    def forward(self, img_miss, mask):

        out = torch.cat([img_miss, mask], dim=1)
        out_en = [out]

        for encode in self.en:
            out = encode(out)
            out_en.append(out)

        results = []
        for i, (decode, fuse) in enumerate(zip(self.de, self.fuse)):
            out = decode(out, out_en[-i-2])
            if fuse:
                result, alpha, raw = fuse(img_miss, out)
                results.append(result)
        return results[::-1][0]
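
The padding arithmetic in `Conv2dSame.conv_same_pad` and `ConvTranspose2dSame.deconv_same_pad` above can be checked without PyTorch. The sketch below copies the two formulas into free functions and applies the standard output-size formulas for convolutions and transposed convolutions with dilation 1. The helpers `conv_out_size` and `deconv_out_size` are hypothetical names introduced here for illustration only.

```python
def conv_same_pad(ksize, stride):
    # same arithmetic as Conv2dSame.conv_same_pad
    if (ksize - stride) % 2 == 0:
        return (ksize - stride) // 2
    left = (ksize - stride) // 2
    return left, left + 1          # asymmetric (left, right) padding

def deconv_same_pad(ksize, stride):
    # same arithmetic as ConvTranspose2dSame.deconv_same_pad
    pad = (ksize - stride + 1) // 2
    outpad = 2 * pad + stride - ksize
    return pad, outpad

def conv_out_size(size, ksize, stride):
    # hypothetical helper: output length of a padded convolution (dilation 1)
    pad = conv_same_pad(ksize, stride)
    total = sum(pad) if isinstance(pad, tuple) else 2 * pad
    return (size + total - ksize) // stride + 1

def deconv_out_size(size, ksize, stride):
    # hypothetical helper: output length of a transposed convolution (dilation 1)
    pad, outpad = deconv_same_pad(ksize, stride)
    return (size - 1) * stride - 2 * pad + ksize + outpad

# encoder kernel sizes used by DFNet, all with stride 2
for k in (7, 5, 3):
    print(k, conv_same_pad(k, 2), conv_out_size(256, k, 2), deconv_out_size(128, k, 2))
```

Under these formulas the convolution output equals `size / stride` as long as the input size is a multiple of the stride; for other sizes the result can be one smaller than `ceil(size / stride)`.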


In [None]:
#@title [LBAM_arch.py](https://github.com/Vious/LBAM_Pytorch) (2019)
"""
LBAMModel.py (18-12-20)
https://github.com/Vious/LBAM_Pytorch/blob/master/models/LBAMModel.py

forwardAttentionLayer.py (18-12-20)
https://github.com/Vious/LBAM_Pytorch/blob/98c2ae70486f4ba3ab86d4345e586e7841cfe343/models/forwardAttentionLayer.py

reverseAttentionLayer.py (18-12-20)
https://github.com/Vious/LBAM_Pytorch/blob/98c2ae70486f4ba3ab86d4345e586e7841cfe343/models/reverseAttentionLayer.py

ActivationFunction.py (18-12-20)
https://github.com/Vious/LBAM_Pytorch/blob/98c2ae70486f4ba3ab86d4345e586e7841cfe343/models/ActivationFunction.py
"""

import math
import torch
from torch.nn.parameter import Parameter
from torch import nn
from torchvision import models
import pytorch_lightning as pl

# asymmetric gaussian shaped activation function g_A
class GaussActivation(pl.LightningModule):
    def __init__(self, a, mu, sigma1, sigma2):
        super(GaussActivation, self).__init__()

        self.a = Parameter(torch.tensor(a, dtype=torch.float32))
        self.mu = Parameter(torch.tensor(mu, dtype=torch.float32))
        self.sigma1 = Parameter(torch.tensor(sigma1, dtype=torch.float32))
        self.sigma2 = Parameter(torch.tensor(sigma2, dtype=torch.float32))


    def forward(self, inputFeatures):

        self.a.data = torch.clamp(self.a.data, 1.01, 6.0)
        self.mu.data = torch.clamp(self.mu.data, 0.1, 3.0)
        self.sigma1.data = torch.clamp(self.sigma1.data, 0.5, 2.0)
        self.sigma2.data = torch.clamp(self.sigma2.data, 0.5, 2.0)

        lowerThanMu = inputFeatures < self.mu
        largerThanMu = inputFeatures >= self.mu

        leftValuesActiv = self.a * torch.exp(- self.sigma1 * ( (inputFeatures - self.mu) ** 2 ) )
        leftValuesActiv.masked_fill_(largerThanMu, 0.0)

        rightValueActiv = 1 + (self.a - 1) * torch.exp(- self.sigma2 * ( (inputFeatures - self.mu) ** 2 ) )
        rightValueActiv.masked_fill_(lowerThanMu, 0.0)

        output = leftValuesActiv + rightValueActiv

        return output

# mask updating function; we recommend using an alpha that is larger than 0 and lower than 1.0
class MaskUpdate(pl.LightningModule):
    def __init__(self, alpha):
        super(MaskUpdate, self).__init__()

        self.updateFunc = nn.ReLU(True)
        #self.alpha = Parameter(torch.tensor(alpha, dtype=torch.float32))
        self.alpha = alpha
    def forward(self, inputMaskMap):
        """ self.alpha.data = torch.clamp(self.alpha.data, 0.6, 0.8)
        print(self.alpha) """

        return torch.pow(self.updateFunc(inputMaskMap), self.alpha)


#from models.ActivationFunction import GaussActivation, MaskUpdate
#from models.weightInitial import weights_init


# learnable reverse attention conv
class ReverseMaskConv(pl.LightningModule):
    def __init__(self, inputChannels, outputChannels, kernelSize=4, stride=2,
        padding=1, dilation=1, groups=1, convBias=False):
        super(ReverseMaskConv, self).__init__()

        self.reverseMaskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \
            dilation, groups, bias=convBias)

        #self.reverseMaskConv.apply(weights_init())

        self.activationFuncG_A = GaussActivation(1.1, 1.0, 0.5, 0.5)
        self.updateMask = MaskUpdate(0.8)

    def forward(self, inputMasks):
        maskFeatures = self.reverseMaskConv(inputMasks)

        maskActiv = self.activationFuncG_A(maskFeatures)

        maskUpdate = self.updateMask(maskFeatures)

        return maskActiv, maskUpdate

# learnable reverse attention layer, including features activation and batchnorm
class ReverseAttention(pl.LightningModule):
    def __init__(self, inputChannels, outputChannels, bn=False, activ='leaky', \
        kernelSize=4, stride=2, padding=1, outPadding=0,dilation=1, groups=1,convBias=False, bnChannels=512):
        super(ReverseAttention, self).__init__()

        self.conv = nn.ConvTranspose2d(inputChannels, outputChannels, kernel_size=kernelSize, \
            stride=stride, padding=padding, output_padding=outPadding, dilation=dilation, groups=groups,bias=convBias)

        #self.conv.apply(weights_init())

        if bn:
            self.bn = nn.BatchNorm2d(bnChannels)

        if activ == 'leaky':
            self.activ = nn.LeakyReLU(0.2, False)
        elif activ == 'relu':
            self.activ = nn.ReLU()
        elif activ == 'sigmoid':
            self.activ = nn.Sigmoid()
        elif activ == 'tanh':
            self.activ = nn.Tanh()
        elif activ == 'prelu':
            self.activ = nn.PReLU()
        else:
            pass

    def forward(self, ecFeaturesSkip, dcFeatures, maskFeaturesForAttention):
        nextDcFeatures = self.conv(dcFeatures)

        # note that encoder features come first; it's important to put the forward attention map ahead
        # of the reverse attention map when concatenating, which is done in the LBAM model's forward function
        concatFeatures = torch.cat((ecFeaturesSkip, nextDcFeatures), 1)

        outputFeatures = concatFeatures * maskFeaturesForAttention

        if hasattr(self, 'bn'):
            outputFeatures = self.bn(outputFeatures)
        if hasattr(self, 'activ'):
            outputFeatures = self.activ(outputFeatures)

        return outputFeatures


#from models.ActivationFunction import GaussActivation, MaskUpdate
#from models.weightInitial import weights_init

# learnable forward attention conv layer
class ForwardAttentionLayer(pl.LightningModule):
    def __init__(self, inputChannels, outputChannels, kernelSize, stride,
        padding, dilation=1, groups=1, bias=False):
        super(ForwardAttentionLayer, self).__init__()

        self.conv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, dilation, \
            groups, bias)

        if inputChannels == 4:
            self.maskConv = nn.Conv2d(3, outputChannels, kernelSize, stride, padding, dilation, \
                groups, bias)
        else:
            self.maskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \
                dilation, groups, bias)

        #self.conv.apply(weights_init())
        #self.maskConv.apply(weights_init())

        self.activationFuncG_A = GaussActivation(1.1, 2.0, 1.0, 1.0)
        self.updateMask = MaskUpdate(0.8)

    def forward(self, inputFeatures, inputMasks):
        convFeatures = self.conv(inputFeatures)
        maskFeatures = self.maskConv(inputMasks)
        #convFeatures_skip = convFeatures.clone()

        maskActiv = self.activationFuncG_A(maskFeatures)
        convOut = convFeatures * maskActiv

        maskUpdate = self.updateMask(maskFeatures)

        return convOut, maskUpdate, convFeatures, maskActiv

# forward attention block: attention layer followed by optional batchnorm and activation
class ForwardAttention(pl.LightningModule):
    def __init__(self, inputChannels, outputChannels, bn=False, sample='down-4', \
        activ='leaky', convBias=False):
        super(ForwardAttention, self).__init__()

        if sample == 'down-4':
            self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 4, 2, 1, bias=convBias)
        elif sample == 'down-5':
            self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 5, 2, 2, bias=convBias)
        elif sample == 'down-7':
            self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 7, 2, 3, bias=convBias)
        elif sample == 'down-3':
            self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 2, 1, bias=convBias)
        else:
            self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 1, 1, bias=convBias)

        if bn:
            self.bn = nn.BatchNorm2d(outputChannels)

        if activ == 'leaky':
            self.activ = nn.LeakyReLU(0.2, False)
        elif activ == 'relu':
            self.activ = nn.ReLU()
        elif activ == 'sigmoid':
            self.activ = nn.Sigmoid()
        elif activ == 'tanh':
            self.activ = nn.Tanh()
        elif activ == 'prelu':
            self.activ = nn.PReLU()
        else:
            pass

    def forward(self, inputFeatures, inputMasks):
        features, maskUpdated, convPreF, maskActiv = self.conv(inputFeatures, inputMasks)

        if hasattr(self, 'bn'):
            features = self.bn(features)
        if hasattr(self, 'activ'):
            features = self.activ(features)

        return features, maskUpdated, convPreF, maskActiv


import torch
import torch.nn as nn
from torchvision import models
#from models.forwardAttentionLayer import ForwardAttention
#from models.reverseAttentionLayer import ReverseAttention, ReverseMaskConv
#from models.weightInitial import weights_init

#VGG16 feature extract
class VGG16FeatureExtractor(pl.LightningModule):
    def __init__(self):
        super(VGG16FeatureExtractor, self).__init__()
        vgg16 = models.vgg16(pretrained=False)
        vgg16.load_state_dict(torch.load('./vgg16-397923af.pth', map_location='cpu'))
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

class LBAMModel(pl.LightningModule):
    def __init__(self, inputChannels=4, outputChannels=3):
        super(LBAMModel, self).__init__()

        # default kernel is of size 4x4, stride 2, padding 1,
        # and biases are disabled by default in the ForwardAttention and ReverseAttention classes.
        self.ec1 = ForwardAttention(inputChannels, 64, bn=False)
        self.ec2 = ForwardAttention(64, 128)
        self.ec3 = ForwardAttention(128, 256)
        self.ec4 = ForwardAttention(256, 512)

        for i in range(5, 8):
            name = 'ec{:d}'.format(i)
            setattr(self, name, ForwardAttention(512, 512))

        # reverse mask conv
        self.reverseConv1 = ReverseMaskConv(3, 64)
        self.reverseConv2 = ReverseMaskConv(64, 128)
        self.reverseConv3 = ReverseMaskConv(128, 256)
        self.reverseConv4 = ReverseMaskConv(256, 512)
        self.reverseConv5 = ReverseMaskConv(512, 512)
        self.reverseConv6 = ReverseMaskConv(512, 512)

        self.dc1 = ReverseAttention(512, 512, bnChannels=1024)
        self.dc2 = ReverseAttention(512 * 2, 512, bnChannels=1024)
        self.dc3 = ReverseAttention(512 * 2, 512, bnChannels=1024)
        self.dc4 = ReverseAttention(512 * 2, 256, bnChannels=512)
        self.dc5 = ReverseAttention(256 * 2, 128, bnChannels=256)
        self.dc6 = ReverseAttention(128 * 2, 64, bnChannels=128)
        self.dc7 = nn.ConvTranspose2d(64 * 2, outputChannels, kernel_size=4, stride=2, padding=1, bias=False)

        self.tanh = nn.Tanh()

    def forward(self, inputImgs, masks):
        inputImgs = torch.cat((inputImgs, masks), 1).float()
        masks = torch.cat([masks,masks,masks],1).float()

        ef1, mu1, skipConnect1, forwardMap1 = self.ec1(inputImgs, masks)
        ef2, mu2, skipConnect2, forwardMap2 = self.ec2(ef1, mu1)
        ef3, mu3, skipConnect3, forwardMap3 = self.ec3(ef2, mu2)
        ef4, mu4, skipConnect4, forwardMap4 = self.ec4(ef3, mu3)
        ef5, mu5, skipConnect5, forwardMap5 = self.ec5(ef4, mu4)
        ef6, mu6, skipConnect6, forwardMap6 = self.ec6(ef5, mu5)
        ef7, _, _, _ = self.ec7(ef6, mu6)


        reverseMap1, revMu1 = self.reverseConv1(1 - masks)
        reverseMap2, revMu2 = self.reverseConv2(revMu1)
        reverseMap3, revMu3 = self.reverseConv3(revMu2)
        reverseMap4, revMu4 = self.reverseConv4(revMu3)
        reverseMap5, revMu5 = self.reverseConv5(revMu4)
        reverseMap6, _ = self.reverseConv6(revMu5)

        concatMap6 = torch.cat((forwardMap6, reverseMap6), 1)
        dcFeatures1 = self.dc1(skipConnect6, ef7, concatMap6)

        concatMap5 = torch.cat((forwardMap5, reverseMap5), 1)
        dcFeatures2 = self.dc2(skipConnect5, dcFeatures1, concatMap5)

        concatMap4 = torch.cat((forwardMap4, reverseMap4), 1)
        dcFeatures3 = self.dc3(skipConnect4, dcFeatures2, concatMap4)

        concatMap3 = torch.cat((forwardMap3, reverseMap3), 1)
        dcFeatures4 = self.dc4(skipConnect3, dcFeatures3, concatMap3)

        concatMap2 = torch.cat((forwardMap2, reverseMap2), 1)
        dcFeatures5 = self.dc5(skipConnect2, dcFeatures4, concatMap2)

        concatMap1 = torch.cat((forwardMap1, reverseMap1), 1)
        dcFeatures6 = self.dc6(skipConnect1, dcFeatures5, concatMap1)

        dcFeatures7 = self.dc7(dcFeatures6)

        output = (self.tanh(dcFeatures7) + 1) / 2

        return output
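
The shape of the asymmetric Gaussian activation `g_A` in `GaussActivation` and of the power-law mask update in `MaskUpdate` above is easier to see on scalars. The sketch below re-implements both in plain Python as an illustration. The function names are assumptions introduced here, the default values mirror the ones passed in `ForwardAttentionLayer`, and the parameter clamping from `GaussActivation.forward` is omitted because these defaults already lie inside the clamp ranges.

```python
import math

def gauss_activation(x, a=1.1, mu=2.0, sigma1=1.0, sigma2=1.0):
    """Scalar sketch of g_A: a Gaussian bump with different tails on each side of mu."""
    if x < mu:
        # left branch: decays towards 0 for small mask features
        return a * math.exp(-sigma1 * (x - mu) ** 2)
    # right branch: decays towards 1 for large mask features
    return 1.0 + (a - 1.0) * math.exp(-sigma2 * (x - mu) ** 2)

def mask_update(x, alpha=0.8):
    """Scalar sketch of MaskUpdate: ReLU followed by a power with exponent alpha."""
    return max(x, 0.0) ** alpha

for x in (-1.0, 1.0, 2.0, 3.0, 6.0):
    print(f"x={x:+.1f}  g_A={gauss_activation(x):.4f}  update={mask_update(x):.4f}")
```

Both branches meet at `x = mu` with value `a`, so the activation is continuous. With `alpha` below 1, the update raises mask values between 0 and 1, which gradually fills in the mask across layers.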


In [None]:
#@title [Adaptive_arch.py](https://github.com/GuardSkill/AdaptiveGAN) (2019)
"""
blocks.py (13-12-20)
https://github.com/GuardSkill/AdaptiveGAN/blob/429311f6d22948904429ff1c19b0d953bc26ba81/src/blocks.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import pytorch_lightning as pl

class hswish(pl.LightningModule):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out


class hsigmoid(pl.LightningModule):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out


class SeModule(pl.LightningModule):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            # nn.BatchNorm2d(in_size // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            # nn.BatchNorm2d(in_size),
            hsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)


# class BottleneckBlock(pl.LightningModule):
#     expansion = 2
#
#     def __init__(self, in_channels, out_channels, stride=1, dilation=1, use_spectral_norm=False, downsample=None):
#         super(BottleneckBlock, self).__init__()
#         self.downsample = downsample
#         self.stride = stride
#         self.conv_block = nn.Sequential(
#             nn.ZeroPad2d(dilation),
#             spectral_norm(
#                 nn.Conv2d(in_channels=in_channels, out_channels=out_channels * 2, kernel_size=3, stride=stride,
#                           padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
#             # nn.LeakyReLU(0.2, inplace=False),
#             nn.Tanh(),
#             nn.ZeroPad2d(1),
#             spectral_norm(
#                 nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=3, stride=stride,
#                           padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
#         )
#
#     def forward(self, x):
#         residual = x
#         if self.downsample is not None:
#             residual = self.downsample(x)
#         out = self.conv_block(x) + residual
#         # out = nn.LeakyReLU(0.2, inplace=False)(out)
#         out = nn.Tanh()(out)
#         return out

class Block(pl.LightningModule):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, in_size, expand_size, out_size, stride, dilation=1):
        super(Block, self).__init__()
        self.stride = stride
        self.se = SeModule(out_size)

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.nolinear1 = nn.Tanh()
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=(kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2, dilation=dilation,
                               groups=expand_size, bias=False)
        self.nolinear2 = nn.Tanh()
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
            )

    def forward(self, x):
        out = self.nolinear1(self.conv1(x))
        out = self.nolinear2(self.conv2(out))
        out = self.conv3(out)
        if self.se is not None:
            out = self.se(out)
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class LinkNet(pl.LightningModule):
    def __init__(self, in_channels=3, residual_blocks=1, init_weights=True):
        super(LinkNet, self).__init__()
        self.conv1 =    nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh()
        )
        # self.conv1 =Block(3, 3, 8, 16, 2)
        self.block1 = nn.Sequential(
            *[Block(3, 16, 16, 16, 1) for i in range(residual_blocks)]
            # residual_blocks1   kernel  input  expand output stride dilation
        )
        #  out 16
        self.conv2 = Block(3, 16, 64, 24, 2)
        self.block2 = nn.Sequential(
            *[Block(3, 24, 72, 24, 1) for _ in range(residual_blocks * 2)]  # residual_blocks1
        )
        #  out 24

        self.conv3 = Block(5, 24, 72, 40, 2)
        self.block3 = nn.Sequential(
            *[Block(5, 40, 120, 40, 1) for _ in range(residual_blocks * 3)]  # residual_blocks2
        )
        #  out 40

        self.conv4 = Block(3, 40, 240, 80, 2, 1)
        self.block4 = nn.Sequential(
            *[Block(3, 80, 200, 80, 1, 4) for _ in range(residual_blocks * 4)]  # residual_blocks3
        )
        #  out 80

        self.conv5 = Block(5, 80, 480, 160, 2)
        self.block5 = nn.Sequential(
            *[Block(5, 160, 672, 160, 1, 4) for _ in range(residual_blocks * 4)]
        )
        #  out 160

        # self.up1 = nn.Sequential(
        #     nn.Conv2d(16, 16, 1, 1, 0, bias=True),
        #     nn.Tanh(),
        #     # nn.Upsample(scale_factor=2 << 0, mode='bilinear')
        # )

        self.up2 = nn.Sequential(
            nn.Conv2d(24, 16, 3, 1, 1, bias=False),
            nn.Tanh(),
            nn.Upsample(scale_factor=2 << 0, mode='bilinear')
        )

        self.up3 = nn.Sequential(
            nn.Conv2d(40, 16, 3, 1, 1, bias=False),
            nn.Tanh(),
            nn.Upsample(scale_factor=2 << 1, mode='bilinear')
        )
        self.up4 = nn.Sequential(
            nn.Conv2d(80, 16, 3, 1, 1, bias=False),
            nn.Tanh(),
            nn.Upsample(scale_factor=2 << 2, mode='bilinear')
        )

        self.up5 = nn.Sequential(
            nn.Conv2d(160, 16, 3, 1, 1, bias=False),
            nn.Tanh(),
            nn.Upsample(scale_factor=2 << 3, mode='bilinear')
        )

        self.fusion = nn.Sequential(
            *[Block(5, 80, 160, 80, 1) for _ in range(residual_blocks * 4)],  # 3x3 original:residual_blocks*2
            Block(5, 80, 48, 32, 1)
            # nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        )

        self.block_fusion = nn.Sequential(
            *[Block(5, 32, 48, 32, 1) for _ in range(residual_blocks * 4)]  # 3x3 original:residual_blocks*2
            # nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.final = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1, bias=False)
        # self.final =Block(5, 32, 48, 3,1)      #kernel_size, in_size, expand_size, out_size, stride
        #  out 160

    #     self.init_params()
    #
    # def init_params(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             init.kaiming_normal_(m.weight, mode='fan_out')
    #             if m.bias is not None:
    #                 init.constant_(m.bias, 0)
    #         elif isinstance(m, nn.BatchNorm2d):
    #             init.constant_(m.weight, 1)
    #             init.constant_(m.bias, 0)
    #         elif isinstance(m, nn.Linear):
    #             init.normal_(m.weight, std=0.001)
    #             if m.bias is not None:
    #                 init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.block1(self.conv1(x))
        x2 = self.block2(self.conv2(x1))
        x3 = self.block3(self.conv3(x2))
        x4 = self.block4(self.conv4(x3))
        x5 = self.block5(self.conv5(x4))
        # x=x1+self.up2(x2)+self.up3(x3)+self.up4(x4)+self.up5(x5)
        x = torch.cat([x1, self.up2(x2), self.up3(x3), self.up4(x4), self.up5(x5)], 1)
        x = self.block_fusion(self.fusion(x))
        x = self.final(x)
        out = (torch.tanh(x) + 1) / 2
        return out


class PyramidNet(pl.LightningModule):
    def __init__(self, in_channels=3, residual_blocks=1, init_weights=True):
        super(PyramidNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh()
        )
        self.block1 = nn.Sequential(
            *[Block(3, 16, 16, 16, 1) for i in range(residual_blocks)]
            # residual_blocks1   kernel  input  expand output stride dilation
        )
        #  out 16
        self.conv2 = Block(3, 16, 64, 24, 2)
        self.block2 = nn.Sequential(
            *[Block(3, 24, 72, 24, 1) for _ in range(residual_blocks * 2)]  # residual_blocks1
        )
        #  out 24

        self.conv3 = Block(5, 24, 72, 40, 2)
        self.block3 = nn.Sequential(
            *[Block(5, 40, 120, 40, 1) for _ in range(residual_blocks * 3)]  # residual_blocks2
        )
        #  out 40

        self.conv4 = Block(3, 40, 240, 80, 2, 2)
        self.block4 = nn.Sequential(
            *[Block(3, 80, 200, 80, 1, 4) for _ in range(residual_blocks * 4)]  # residual_blocks3
        )
        #  out 80

        self.conv5 = Block(5, 80, 480, 160, 2)
        self.block5 = nn.Sequential(
            *[Block(5, 160, 672, 160, 1, 4) for _ in range(residual_blocks * 4)]
        )
        #  out 160

        # self.up1 = nn.Sequential(
        #     nn.Conv2d(16, 16, 1, 1, 0, bias=True),
        #     nn.Tanh(),
        #     # nn.Upsample(scale_factor=2 << 0, mode='bilinear')
        # )

        self.channel_reduce4 = nn.Sequential(
            nn.Conv2d(160, 80, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

        self.channel_reduce3 = nn.Sequential(
            nn.Conv2d(80, 40, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

        self.channel_reduce2 = nn.Sequential(
            nn.Conv2d(40, 24, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

        self.channel_reduce1 = nn.Sequential(
            nn.Conv2d(24, 16, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

        self.smooth4 = nn.Conv2d(80, 80, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(40, 40, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(24, 24, kernel_size=3, stride=1, padding=1)
        self.smooth1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

        self.fusion = nn.Sequential(
            *[Block(5, 80, 160, 80, 1) for _ in range(residual_blocks * 4)],  # 3x3 original:residual_blocks*2
            Block(5, 80, 48, 32, 1)
            # nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        )

        self.block_fusion = nn.Sequential(
            *[Block(5, 32, 48, 32, 1) for _ in range(residual_blocks * 4)]  # 3x3 original:residual_blocks*2
            # nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        )
        self.final = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False)
        # self.final =Block(5, 32, 48, 3,1)      #kernel_size, in_size, expand_size, out_size, stride
        #  out 160

    #     self.init_params()
    #
    # def init_params(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             init.kaiming_normal_(m.weight, mode='fan_out')
    #             if m.bias is not None:
    #                 init.constant_(m.bias, 0)
    #         elif isinstance(m, nn.BatchNorm2d):
    #             init.constant_(m.weight, 1)
    #             init.constant_(m.bias, 0)
    #         elif isinstance(m, nn.Linear):
    #             init.normal_(m.weight, std=0.001)
    #             if m.bias is not None:
    #                 init.constant_(m.bias, 0)

    def forward(self, images, masks):
        x = torch.cat((images, masks), dim=1)
        x1 = self.block1(self.conv1(x))
        x2 = self.block2(self.conv2(x1))
        x3 = self.block3(self.conv3(x2))
        x4 = self.block4(self.conv4(x3))
        x5 = self.block5(self.conv5(x4))
        c4 = self.smooth4(self._upsample_add(self.channel_reduce4(x5), x4))
        c3 = self.smooth3(self._upsample_add(self.channel_reduce3(c4), x3))
        c2 = self.smooth2(self._upsample_add(self.channel_reduce2(c3), x2))
        c1 = self.smooth1(self._upsample_add(self.channel_reduce1(c2), x1))
        x = self.final(c1)
        out = (torch.tanh(x) + 1) / 2
        return out

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.
        Args:
          x: (Variable) top feature map to be upsampled.
          y: (Variable) lateral feature map.
        Returns:
          (Variable) added feature map.
        Note: in PyTorch, when the input size is odd, a feature map upsampled
        with `F.interpolate(..., scale_factor=2, mode='nearest')`
        may not match the lateral feature map size.
        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]
        So we upsample bilinearly to the exact size of the lateral map instead.
        '''
        _, _, H, W = y.size()
        # F.upsample is deprecated; F.interpolate with align_corners=False matches its default behavior
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y

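The padding expression in `Block` and the size note in `_upsample_add` both come down to the standard convolution output-size formula. The following is a minimal, framework-free sketch of that arithmetic. The helper names `conv_out_size` and `block_padding` are made up for illustration and are not part of the original code.

```python
def conv_out_size(n, kernel_size, stride=1, padding=0, dilation=1):
    """Output size along one axis, using the formula documented for nn.Conv2d."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (n + 2 * padding - effective_kernel) // stride + 1


def block_padding(kernel_size, dilation=1):
    """Same expression as the padding argument of Block.conv2 (hypothetical helper)."""
    return (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2


# With stride 1 and an odd kernel, the spatial size is preserved,
# with or without dilation.
for k, d in [(3, 1), (5, 1), (3, 4), (5, 4)]:
    p = block_padding(k, d)
    assert conv_out_size(64, k, stride=1, padding=p, dilation=d) == 64

# With stride 2, an odd input rounds up, so a fixed 2x upsample overshoots.
down = conv_out_size(15, 3, stride=2, padding=block_padding(3))
print(down, down * 2)  # 8 16
```

This is why `_upsample_add` resizes to the lateral map's exact `(H, W)` instead of relying on a fixed scale factor.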

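The `hsigmoid` and `hswish` modules above are piecewise-linear stand-ins for sigmoid-style gating. As a rough illustration, here is a scalar, pure-Python version of the same formulas. The function names mirror the classes, but this sketch is not the module code and `relu6` here is a hand-written helper.

```python
import math


def relu6(x):
    """Clamp x to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def hsigmoid(x):
    """relu6(x + 3) / 6: 0 below -3, 1 above 3, linear in between."""
    return relu6(x + 3.0) / 6.0


def hswish(x):
    """x * hsigmoid(x): close to 0 for very negative x, close to x for large x."""
    return x * hsigmoid(x)


def sigmoid(x):
    """Reference logistic sigmoid, for comparison only."""
    return 1.0 / (1.0 + math.exp(-x))


for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"{x:+.1f}  hsigmoid={hsigmoid(x):.3f}  sigmoid={sigmoid(x):.3f}  hswish={hswish(x):+.3f}")
```

In `SeModule`, the `hsigmoid` output is the per-channel gate that multiplies the feature map.
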
In [None]:
#@title [partial_arch.py](https://github.com/jacobaustin123/pytorch-inpainting-partial-conv) (2018)
"""
model.py (24-12-20)
https://github.com/jacobaustin123/pytorch-inpainting-partial-conv/blob/master/model.py
"""

from torch import nn, cuda
from torchvision import models
import torch
import torch.nn as nn
import torch.nn.functional as F
#from .convolutions import partialconv2d
import pytorch_lightning as pl

class PartialLayer(pl.LightningModule):
    def __init__(self, in_size, out_size, kernel_size, stride, non_linearity="relu", bn=True, multi_channel=False):
        super(PartialLayer, self).__init__()

        self.conv = PartialConv2d(in_size, out_size, kernel_size, stride, return_mask=True, padding=(kernel_size - 1) // 2, multi_channel=multi_channel, bias=not bn)

        self.bn = nn.BatchNorm2d(out_size) if bn else None

        if non_linearity == "relu":
            self.non_linearity = nn.ReLU()
        elif non_linearity == "leaky":
            self.non_linearity = nn.LeakyReLU(negative_slope=0.2)
        elif non_linearity == 'sigmoid':
            self.non_linearity = nn.Sigmoid()
        elif non_linearity == 'tanh':
            self.non_linearity = nn.Tanh()
        elif non_linearity is None:
            self.non_linearity = None
        else:
            raise ValueError("unexpected value for non_linearity")

    def forward(self, x, mask_in=None, return_mask=True):
        x, mask = self.conv(x, mask_in=mask_in)

        if self.bn:
            x = self.bn(x)

        if self.non_linearity:
            x = self.non_linearity(x)

        if return_mask:
            return x, mask
        else:
            return x



class Model(pl.LightningModule):
    def __init__(self, freeze_bn=False):
        super(Model, self).__init__()

        self.freeze_bn = freeze_bn # freeze bn layers for fine tuning

        self.conv1 = PartialLayer(3, 64, 7, 2) # encoder for UNET,  use relu for encoder
        self.conv2 = PartialLayer(64, 128, 5, 2)
        self.conv3 = PartialLayer(128, 256, 5, 2)
        self.conv4 = PartialLayer(256, 512, 3, 2)
        self.conv5 = PartialLayer(512, 512, 3, 2)
        self.conv6 = PartialLayer(512, 512, 3, 2)
        self.conv7 = PartialLayer(512, 512, 3, 2)
        self.conv8 = PartialLayer(512, 512, 3, 2)

        self.conv9 = PartialLayer(2 * 512, 512, 3, 1, non_linearity="leaky", multi_channel=True) # decoder for UNET
        self.conv10 = PartialLayer(2 * 512, 512, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv11 = PartialLayer(2 * 512, 512, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv12 = PartialLayer(2 * 512, 512, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv13 = PartialLayer(512 + 256, 256, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv14 = PartialLayer(256 + 128, 128, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv15 = PartialLayer(128 + 64, 64, 3, 1, non_linearity="leaky", multi_channel=True)
        self.conv16 = PartialLayer(64 + 3, 3, 3, 1, non_linearity="tanh", bn=False, multi_channel=True)

    def concat(self, input, prev):
        return torch.cat([F.interpolate(input, scale_factor=2), prev], dim=1)

    def repeat(self, mask, size1, size2):
        return torch.cat([mask[:,0].unsqueeze(1).repeat(1, size1, 1, 1), mask[:,1].unsqueeze(1).repeat(1, size2, 1, 1)], dim=1)

    def forward(self, x, mask):
        # cast to float32 without forcing a CUDA device, so CPU and TPU runs also work
        x1, mask1 = self.conv1(x.float(), mask_in=mask.float())
        x2, mask2 = self.conv2(x1, mask_in=mask1)
        x3, mask3 = self.conv3(x2, mask_in=mask2)
        x4, mask4 = self.conv4(x3, mask_in=mask3)
        x5, mask5 = self.conv5(x4, mask_in=mask4)
        x6, mask6 = self.conv6(x5, mask_in=mask5)
        x7, mask7 = self.conv7(x6, mask_in=mask6)
        x8, mask8 = self.conv8(x7, mask_in=mask7)

        x9, mask9 = self.conv9(self.concat(x8, x7), mask_in=self.repeat(self.concat(mask8, mask7), 512, 512))
        x10, mask10 = self.conv10(self.concat(x9, x6), mask_in=self.repeat(self.concat(mask9, mask6), 512, 512))
        x11, mask11 = self.conv11(self.concat(x10, x5), mask_in=self.repeat(self.concat(mask10, mask5), 512, 512))
        x12, mask12 = self.conv12(self.concat(x11, x4), mask_in=self.repeat(self.concat(mask11, mask4), 512, 512))
        x13, mask13 = self.conv13(self.concat(x12, x3), mask_in=self.repeat(self.concat(mask12, mask3), 512, 256))
        x14, mask14 = self.conv14(self.concat(x13, x2), mask_in=self.repeat(self.concat(mask13, mask2), 256, 128))
        x15, mask15 = self.conv15(self.concat(x14, x1), mask_in=self.repeat(self.concat(mask14, mask1), 128, 64))
        out, mask16 = self.conv16(self.concat(x15, x), mask_in=self.repeat(self.concat(mask15, mask), 64, 3))

        return out

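The partial-conv `Model` above stacks eight stride-2 encoder layers, and its decoder upsamples by exactly 2 before concatenating with each skip. The sketch below traces the spatial sizes to show which input sizes keep those skips aligned. It assumes `PartialConv2d` follows the usual `nn.Conv2d` size formula; `encoder_sizes` and `skips_line_up` are hypothetical helper names.

```python
def conv_out_size(n, kernel_size, stride, padding):
    """Output size along one axis for a standard (non-dilated) convolution."""
    return (n + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride) of conv1..conv8 in Model; PartialLayer pads with (kernel_size - 1) // 2.
ENCODER = [(7, 2), (5, 2), (5, 2), (3, 2), (3, 2), (3, 2), (3, 2), (3, 2)]


def encoder_sizes(n):
    """Spatial size of the input followed by the size after each encoder layer."""
    sizes = [n]
    for k, s in ENCODER:
        sizes.append(conv_out_size(sizes[-1], k, s, (k - 1) // 2))
    return sizes


def skips_line_up(n):
    """True if every level is exactly twice the next one, as the 2x upsampling in concat() needs."""
    sizes = encoder_sizes(n)
    return all(2 * small == big for big, small in zip(sizes, sizes[1:]))


print(encoder_sizes(256))  # [256, 128, 64, 32, 16, 8, 4, 2, 1]
print(skips_line_up(256), skips_line_up(512), skips_line_up(300))  # True True False
```

Under these assumptions, training and validation crops should have a height and width that are multiples of 256.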

The generators below return two outputs. If you decide to use one of them, make sure the first-stage loss is calculated inside ``CustomTrainClass``. Custom loss combinations are possible, but currently the first-stage output (the other image) is mostly supervised with just an L1 loss.
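
As a rough sketch of what such a first-stage term could look like, the snippet below adds a weighted L1 loss on the first-stage image to the L1 loss on the final image. The function name `two_stage_l1`, the `coarse_weight` parameter and the argument order are assumptions for illustration; how ``CustomTrainClass`` actually wires its losses is not shown here.

```python
import torch
import torch.nn.functional as F


def two_stage_l1(coarse, refined, target, coarse_weight=1.0):
    """L1 on the final output plus a weighted L1 on the first-stage output (hypothetical helper)."""
    loss_refined = F.l1_loss(refined, target)
    loss_coarse = F.l1_loss(coarse, target)
    return loss_refined + coarse_weight * loss_coarse


# Dummy tensors stand in for the two generator outputs and the ground truth.
target = torch.rand(2, 3, 64, 64)
coarse = torch.rand(2, 3, 64, 64, requires_grad=True)
refined = torch.rand(2, 3, 64, 64, requires_grad=True)

loss = two_stage_l1(coarse, refined, target, coarse_weight=0.5)
loss.backward()  # gradients flow to both outputs
print(loss.item())
```

Other loss functions can be swapped in for either term in the same way.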

With two outputs:

In [None]:
#@title [deepfillv2_arch.py](https://github.com/zhaoyuzhi/deepfillv2) (2019)
"""
network.py (15-12-20)
https://github.com/zhaoyuzhi/deepfillv2/blob/master/deepfillv2/network.py

network_module.py (15-12-20)
https://github.com/zhaoyuzhi/deepfillv2/blob/master/deepfillv2/network_module.py
"""

#from network_module import *
#from .convolutions import partialconv2d
from torch import nn
from torch.nn import Parameter
from torch.nn import functional as F
import logging
import torch
import torch.nn as nn
import torch.nn.init as init
logger = logging.getLogger('base')
import pytorch_lightning as pl

#-----------------------------------------------
#                Normal ConvBlock
#-----------------------------------------------
class Conv2dLayer(pl.LightningModule):
    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False):
        super(Conv2dLayer, self).__init__()
        # Initialize the padding scheme
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == 'ln':
            self.norm = LayerNorm(out_channels)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace = True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace = True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers
        if sn:
            # conv_type is ignored when spectral normalization is enabled
            self.conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation))
        else:
            if conv_type == 'normal':
                self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
            elif conv_type == 'partial':
                self.conv2d = PartialConv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
            else:
                raise NotImplementedError("conv_type [{}] is not implemented".format(conv_type))

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class TransposeConv2dLayer(pl.LightningModule):
    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = False, scale_factor = 2):
        super(TransposeConv2dLayer, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(in_channels, out_channels, conv_type, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')
        x = self.conv2d(x)
        return x

#-----------------------------------------------
#                Gated ConvBlock
#-----------------------------------------------
class GatedConv2d(pl.LightningModule):
    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'reflect', activation = 'lrelu', norm = 'none', sn = False):
        super(GatedConv2d, self).__init__()
        # Initialize the padding scheme
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == 'ln':
            self.norm = LayerNorm(out_channels)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace = True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace = True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers
        if sn:
            self.conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation))
            self.mask_conv2d = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation))
        else:
            if conv_type == 'normal':
                self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
                self.mask_conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
            elif conv_type == 'partial':
                self.conv2d = PartialConv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
                self.mask_conv2d = PartialConv2d(in_channels, out_channels, kernel_size, stride, padding = 0, dilation = dilation)
            else:
                raise NotImplementedError("conv_type [{}] is not implemented".format(conv_type))
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        x = conv * gated_mask
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x

class TransposeGatedConv2d(pl.LightningModule):
    def __init__(self, in_channels, out_channels, conv_type, kernel_size, stride = 1, padding = 0, dilation = 1, pad_type = 'zero', activation = 'lrelu', norm = 'none', sn = True, scale_factor = 2):
        super(TransposeGatedConv2d, self).__init__()
        # Initialize the conv scheme
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2d(in_channels, out_channels, conv_type, kernel_size, stride, padding, dilation, pad_type, activation, norm, sn)

    def forward(self, x):
        x = F.interpolate(x, scale_factor = self.scale_factor, mode = 'nearest')
        x = self.gated_conv2d(x)
        return x

# ----------------------------------------
#               Layer Norm
# ----------------------------------------
class LayerNorm(pl.LightningModule):
    def __init__(self, num_features, eps = 1e-8, affine = True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # layer norm
        shape = [-1] + [1] * (x.dim() - 1)                                  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # if it is learnable
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)                          # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x

#-----------------------------------------------
#                  SpectralNorm
#-----------------------------------------------
def l2normalize(v, eps = 1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(pl.LightningModule):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


def deepfillv2_weights_init(net, init_type = 'kaiming', init_gain = 0.02):
    # Initialize network weights.
    # Parameters:
    #    net (network)       -- network to be initialized
    #    init_type (str)     -- the name of an initialization method: normal | xavier | kaiming | orthogonal
    #    init_gain (float)   -- scaling factor for normal, xavier and orthogonal.

    def init_func(m):
        classname = m.__class__.__name__

        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain = init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a = 0, mode = 'fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain = init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            init.normal_(m.weight, 0, 0.01)
            init.constant_(m.bias, 0)

    # Apply the initialization function <init_func>
    logger.info('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)

#-----------------------------------------------
#                   Generator
#-----------------------------------------------
# Input: masked image + mask
# Output: filled image

#https://github.com/zhaoyuzhi/deepfillv2/blob/62dad2c601400e14d79f4d1e090c2effcb9bf3eb/deepfillv2/train.py
class GatedGenerator(pl.LightningModule):
    def __init__(self, in_channels = 4, out_channels = 3, latent_channels = 64, pad_type = 'zero', activation = 'lrelu', norm = 'in', conv_type = 'normal'):
        super(GatedGenerator, self).__init__()

        self.coarse = nn.Sequential(
            # encoder
            GatedConv2d(in_channels, latent_channels, conv_type, 7, 1, 3, pad_type = pad_type, activation = activation, norm = 'none'),
            GatedConv2d(latent_channels, latent_channels * 2, conv_type, 4, 2, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 2, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 4, 2, 1, pad_type = pad_type, activation = activation, norm = norm),
            # Bottleneck
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 2, dilation = 2, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 4, dilation = 4, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 8, dilation = 8, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 16, dilation = 16, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            # decoder
            TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 2, latent_channels * 2, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels, out_channels, conv_type, 7, 1, 3, pad_type = pad_type, activation = 'tanh', norm = 'none')
        )
        self.refinement = nn.Sequential(
            # encoder
            GatedConv2d(in_channels, latent_channels, conv_type, 7, 1, 3, pad_type = pad_type, activation = activation, norm = 'none'),
            GatedConv2d(latent_channels, latent_channels * 2, conv_type, 4, 2, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 2, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 4, 2, 1, pad_type = pad_type, activation = activation, norm = norm),
            # Bottleneck
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 2, dilation = 2, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 4, dilation = 4, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 8, dilation = 8, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 16, dilation = 16, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 4, latent_channels * 4, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            # decoder
            TransposeGatedConv2d(latent_channels * 4, latent_channels * 2, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels * 2, latent_channels * 2, conv_type, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            TransposeGatedConv2d(latent_channels * 2, latent_channels, 3, 1, 1, pad_type = pad_type, activation = activation, norm = norm),
            GatedConv2d(latent_channels, out_channels, conv_type, 7, 1, 3, pad_type = pad_type, activation = 'tanh', norm = 'none')
        )


    def forward(self, img, mask):
        # img: entire img
        # mask: 1 for mask region; 0 for unmask region
        # 1 - mask: unmask
        # img * (1 - mask): ground truth unmask region
        # Coarse
        #print(img.shape, mask.shape)
        first_masked_img = img * (1 - mask) + mask
        first_in = torch.cat((first_masked_img, mask), 1)       # in: [B, 4, H, W]
        first_out = self.coarse(first_in)                       # out: [B, 3, H, W]
        # Refinement
        second_masked_img = img * (1 - mask) + first_out * mask
        second_in = torch.cat((second_masked_img, mask), 1)     # in: [B, 4, H, W]
        second_out = self.refinement(second_in)                 # out: [B, 3, H, W]
        #return first_out, second_out
        #return second_out
        return second_out, first_out
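
The two-stage compositing in `GatedGenerator.forward` can be illustrated without PyTorch. The following is a minimal pure-Python sketch, not part of the original deepfillv2 code: `composite` is a hypothetical helper, the pixel values are made up, and flat lists stand in for tensors.

```python
def composite(img, mask, fill):
    """Keep img where mask == 0 and use fill where mask == 1."""
    return [i * (1 - m) + f * m for i, m, f in zip(img, mask, fill)]

img = [0.2, 0.4, 0.6, 0.8]   # ground-truth pixels
mask = [0, 1, 1, 0]          # 1 = hole to inpaint, 0 = known pixel

# Stage 1 input: holes are filled with the constant 1, as in `img * (1 - mask) + mask`
first_masked = composite(img, mask, [1.0] * len(img))

# Stage 2 input: holes are filled with the coarse prediction (made-up values here)
first_out = [0.0, 0.5, 0.5, 0.0]
second_masked = composite(img, mask, first_out)

print(first_masked)
print(second_masked)
```

Known pixels pass through both stages unchanged, so the refinement network only has to improve the region the coarse network filled in.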


In [None]:
#@title [deepfillv1_arch.py](https://github.com/avalonstrel/GatedConvolution_pytorch) (2018)
"""
networks.py (12-12-20)
https://github.com/avalonstrel/GatedConvolution_pytorch/blob/master/models/networks.py

sa_gan.py (13-12-20)
https://github.com/avalonstrel/GatedConvolution_pytorch/blob/master/models/sa_gan.py

spectral.py (13-12-20)
https://github.com/avalonstrel/GatedConvolution_pytorch/blob/master/models/spectral.py
"""

import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import pytorch_lightning as pl

def get_pad(in_,  ksize, stride, atrous=1):
    out_ = np.ceil(float(in_)/stride)
    return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)


class GatedConv2dWithActivation(pl.LightningModule):
    r"""
    Gated convolution layer with activation (default activation: LeakyReLU)
    Params: same as conv2d
    Input: the feature map from the previous layer, "I"
    Output: \phi(f(I)) * \sigmoid(g(I))
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedConv2dWithActivation, self).__init__()
        self.batch_norm = batch_norm
        self.activation = activation
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
    def gated(self, mask):
        #return torch.clamp(mask, -1, 1)
        return self.sigmoid(mask)
    def forward(self, input):
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        else:
            return x

class GatedDeConv2dWithActivation(pl.LightningModule):
    r"""
    Gated deconvolution layer with activation (default activation: LeakyReLU)
    resize + conv
    Params: same as conv2d
    Input: the feature map from the previous layer, "I"
    Output: \phi(f(I)) * \sigmoid(g(I))
    """
    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, batch_norm=True,activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedDeConv2dWithActivation, self).__init__()
        self.conv2d = GatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input):
        # Upsample by the configured factor, then apply the gated convolution
        x = F.interpolate(input, scale_factor=self.scale_factor)
        return self.conv2d(x)

class SNGatedConv2dWithActivation(pl.LightningModule):
    """
    Gated convolution with spectral normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedConv2dWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.activation = activation
        self.batch_norm = batch_norm
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def gated(self, mask):
        return self.sigmoid(mask)
        #return torch.clamp(mask, -1, 1)

    def forward(self, input):
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        else:
            return x
class SNGatedDeConv2dWithActivation(pl.LightningModule):
    r"""
    Gated deconvolution layer with spectral normalization and activation (default activation: LeakyReLU)
    resize + conv
    Params: same as conv2d
    Input: the feature map from the previous layer, "I"
    Output: \phi(f(I)) * \sigmoid(g(I))
    """
    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, batch_norm=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedDeConv2dWithActivation, self).__init__()
        self.conv2d = SNGatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input):
        # Upsample by the configured factor, then apply the gated convolution
        x = F.interpolate(input, scale_factor=self.scale_factor)
        return self.conv2d(x)

class SNConvWithActivation(pl.LightningModule):
    """
    Convolution with spectral normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNConvWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.activation = activation
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
    def forward(self, input):
        x = self.conv2d(input)
        if self.activation is not None:
            return self.activation(x)
        else:
            return x



import torch
from torch.optim.optimizer import Optimizer

from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(pl.LightningModule):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)





import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
#from .spectral import SpectralNorm
#from .networks import GatedConv2dWithActivation, GatedDeConv2dWithActivation, SNConvWithActivation, get_pad
class Self_Attn(pl.LightningModule):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation,with_attn=False):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.with_attn = with_attn
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1) #
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X N X (C//8)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X (C//8) X N
        energy =  torch.bmm(proj_query,proj_key) # B X N X N
        attention = self.softmax(energy) # B X N X N, normalized over the last dimension
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)

        out = self.gamma*out + x
        if self.with_attn:
            return out ,attention
        else:
            return out

class SAGenerator(pl.LightningModule):
    """Generator."""

    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(SAGenerator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num # 8
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

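        # forward() unpacks (out, attention), so the attention map must be returned
        self.attn1 = Self_Attn( 128, 'relu', with_attn=True)
        self.attn2 = Self_Attn( 64,  'relu', with_attn=True)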

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out=self.l1(z)
        out=self.l2(out)
        out=self.l3(out)
        out,p1 = self.attn1(out)
        out=self.l4(out)
        out,p2 = self.attn2(out)
        out=self.last(out)

        return out, p1, p2

class InpaintSANet(pl.LightningModule):
    """
    Inpaint generator. The input should be 5*256*256: 3*256*256 for the masked image, 1*256*256 for the mask, and 1*256*256 for the guidance
    """
    def __init__(self, n_in_channel=5):
        super(InpaintSANet, self).__init__()
        cnum = 32
        self.coarse_net = nn.Sequential(
            #input is 5*256*256, but the network is fully convolutional, so the input can be larger than 256
            GatedConv2dWithActivation(n_in_channel, cnum, 5, 1, padding=get_pad(256, 5, 1)),
            # downsample 128
            GatedConv2dWithActivation(cnum, 2*cnum, 4, 2, padding=get_pad(256, 4, 2)),
            GatedConv2dWithActivation(2*cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            #downsample to 64
            GatedConv2dWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 4, 2)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            # atrous convolution
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            #Self_Attn(4*cnum, 'relu'),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            # upsample
            GatedDeConv2dWithActivation(2, 4*cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            #Self_Attn(2*cnum, 'relu'),
            GatedConv2dWithActivation(2*cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            GatedDeConv2dWithActivation(2, 2*cnum, cnum, 3, 1, padding=get_pad(256, 3, 1)),

            GatedConv2dWithActivation(cnum, cnum//2, 3, 1, padding=get_pad(256, 3, 1)),
            #Self_Attn(cnum//2, 'relu'),
            GatedConv2dWithActivation(cnum//2, 3, 3, 1, padding=get_pad(256, 3, 1), activation=None)
        )

        self.refine_conv_net = nn.Sequential(
            # input is 5*256*256
            GatedConv2dWithActivation(n_in_channel, cnum, 5, 1, padding=get_pad(256, 5, 1)),
            # downsample
            GatedConv2dWithActivation(cnum, cnum, 4, 2, padding=get_pad(256, 4, 2)),
            GatedConv2dWithActivation(cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            # downsample
            GatedConv2dWithActivation(2*cnum, 2*cnum, 4, 2, padding=get_pad(128, 4, 2)),
            GatedConv2dWithActivation(2*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=2, padding=get_pad(64, 3, 1, 2)),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=4, padding=get_pad(64, 3, 1, 4)),
            #Self_Attn(4*cnum, 'relu'),
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=8, padding=get_pad(64, 3, 1, 8)),

            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, dilation=16, padding=get_pad(64, 3, 1, 16))
        )
        self.refine_attn = Self_Attn(4*cnum, 'relu', with_attn=False)
        self.refine_upsample_net = nn.Sequential(
            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),

            GatedConv2dWithActivation(4*cnum, 4*cnum, 3, 1, padding=get_pad(64, 3, 1)),
            GatedDeConv2dWithActivation(2, 4*cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            GatedConv2dWithActivation(2*cnum, 2*cnum, 3, 1, padding=get_pad(128, 3, 1)),
            GatedDeConv2dWithActivation(2, 2*cnum, cnum, 3, 1, padding=get_pad(256, 3, 1)),

            GatedConv2dWithActivation(cnum, cnum//2, 3, 1, padding=get_pad(256, 3, 1)),
            #Self_Attn(cnum, 'relu'),
            GatedConv2dWithActivation(cnum//2, 3, 3, 1, padding=get_pad(256, 3, 1), activation=None),
        )


    def forward(self, imgs, masks, img_exs=None):
        # Coarse
        #masked_imgs =  imgs * (1 - masks) + masks

        if img_exs is None:
            input_imgs = torch.cat([imgs, masks, torch.full_like(masks, 1.)], dim=1)
        else:
            input_imgs = torch.cat([imgs, img_exs, masks, torch.full_like(masks, 1.)], dim=1)
        #print(input_imgs.size(), imgs.size(), masks.size())
        x = self.coarse_net(input_imgs)
        x = torch.clamp(x, -1., 1.)
        coarse_x = x
        # Refine
        masked_imgs = imgs * (1 - masks) + coarse_x * masks
        if img_exs is None:
            input_imgs = torch.cat([masked_imgs, masks, torch.full_like(masks, 1.)], dim=1)
        else:
            input_imgs = torch.cat([masked_imgs, img_exs, masks, torch.full_like(masks, 1.)], dim=1)
        x = self.refine_conv_net(input_imgs)
        x= self.refine_attn(x)
        #print(x.size(), attention.size())
        x = self.refine_upsample_net(x)
        x = torch.clamp(x, -1., 1.)

        return x, coarse_x
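
The padding values that `get_pad` produces for `InpaintSANet` can be checked by hand. The sketch below repeats the same formula with `math.ceil` so it runs without NumPy or PyTorch. `conv_out_size` is a hypothetical helper, not part of the original repository, that applies the standard convolution output-size formula.

```python
import math

def get_pad(in_size, ksize, stride, atrous=1):
    """Same formula as get_pad in the cell above, written with math.ceil."""
    out_size = math.ceil(in_size / stride)
    return int(((out_size - 1) * stride + atrous * (ksize - 1) + 1 - in_size) / 2)

def conv_out_size(in_size, ksize, stride, padding, dilation=1):
    """Output size of a convolution along one axis (standard Conv2d formula)."""
    return (in_size + 2 * padding - dilation * (ksize - 1) - 1) // stride + 1

# (input size, kernel, stride, dilation) combinations used by InpaintSANet
for in_size, k, s, d in [(256, 5, 1, 1), (256, 4, 2, 1), (64, 3, 1, 2), (64, 3, 1, 16)]:
    p = get_pad(in_size, k, s, d)
    print(in_size, k, s, d, "-> padding", p, "output", conv_out_size(in_size, k, s, p, d))
```

With these paddings, stride-1 layers keep the spatial size and stride-2 layers halve it, which is why the layer comments refer to 256, 128 and 64.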


Special:

These architectures need extra attention because they return multiple outputs, take custom inputs, or require a custom loss calculation. This Colab assumes you configure them correctly inside ``CustomTrainClass``. Example usage is [here](https://github.com/styler00dollar/Colab-BasicSR/blob/master/codes/models/inpaint_model.py).
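
One way to handle generators with a varying number of outputs is sketched below in plain Python. This is an illustration only and not how ``CustomTrainClass`` is implemented: `split_outputs`, `total_loss` and `extra_weight` are hypothetical names, and floats stand in for tensors. It relies on the convention used by `GatedGenerator` and `InpaintSANet`, which return the refined result first and the coarse result second.

```python
def split_outputs(outputs):
    """Return (final, extras) whether the generator returned one value or a tuple."""
    if isinstance(outputs, (tuple, list)):
        return outputs[0], list(outputs[1:])
    return outputs, []

def total_loss(outputs, target, loss_fn, extra_weight=1.0):
    """Sum the loss on the final output and a weighted loss on each extra output."""
    final, extras = split_outputs(outputs)
    loss = loss_fn(final, target)
    for extra in extras:
        loss = loss + extra_weight * loss_fn(extra, target)
    return loss

def l1(a, b):
    """Toy L1 loss on plain floats, standing in for a tensor loss."""
    return abs(a - b)

print(total_loss(0.5, 1.0, l1))          # single-output generator
print(total_loss((0.5, 0.25), 1.0, l1))  # (refined, coarse) pair
```

Networks with custom inputs or more unusual outputs would still need their own branch in the training step.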

In [None]:
#@title [Pluralistic_arch.py](https://github.com/lyndonzheng/Pluralistic-Inpainting) (should work, but ``CustomTrainClass`` is not ready) (2019)
"""
network.py (13-12-20)
https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/master/model/network.py

external_function.py (13-12-20)
https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/master/model/external_function.py

base_function.py (13-12-20)
https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/master/model/base_function.py

task.py (16-12-20)
https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/1ca1855615fed8b686ca218c6494f455860f9996/util/task.py
"""

from PIL import Image
from random import randint
from torch import nn
from torch.nn import Parameter
from torch.nn import init
from torch.optim import lr_scheduler
import copy
import cv2
import functools
import logging
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
logger = logging.getLogger('base')
import pytorch_lightning as pl

###################################################################
# multi scale for image generation
###################################################################


def scale_img(img, size):
    scaled_img = F.interpolate(img, size=size, mode='bilinear', align_corners=True)
    return scaled_img


def scale_pyramid(img, num_scales):
    scaled_imgs = [img]

    s = img.size()

    h = s[2]
    w = s[3]

    for i in range(1, num_scales):
        ratio = 2**i
        nh = h // ratio
        nw = w // ratio
        scaled_img = scale_img(img, size=[nh, nw])
        scaled_imgs.append(scaled_img)

    scaled_imgs.reverse()
    return scaled_imgs


####################################################################################################
# spectral normalization layer to decouple the magnitude of a weight tensor
####################################################################################################

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(pl.LightningModule):
    """
    spectral normalization
    code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN'
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


####################################################################################################
# neural style transform loss from neural_style_tutorial of pytorch
####################################################################################################


def GramMatrix(input):
    s = input.size()
    features = input.view(s[0], s[1], s[2]*s[3])
    features_t = torch.transpose(features, 1, 2)
    G = torch.bmm(features, features_t).div(s[1]*s[2]*s[3])
    return G


def img_crop(input, size=224):
    input_cropped = F.interpolate(input, size=(size, size), mode='bilinear', align_corners=True)
    return input_cropped


class Normalization(pl.LightningModule):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, input):
        return (input-self.mean) / self.std


######################################################################################
# base function for network structure
######################################################################################


def pluralistic_init_weights(net, init_type='normal', gain=0.02):
    """Get different initial method for the network weights"""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv')!=-1 or classname.find('Linear')!=-1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)

    #print('initialize network with %s' % init_type)
    logger.info('Initialization method [{:s}]'.format(init_type))
    net.apply(init_func)


def get_norm_layer(norm_type='batch'):
    """Get the normalization layer for the networks"""
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_nonlinearity_layer(activation_type='PReLU'):
    """Get the activation layer for the networks"""
    if activation_type == 'ReLU':
        nonlinearity_layer = nn.ReLU()
    elif activation_type == 'SELU':
        nonlinearity_layer = nn.SELU()
    elif activation_type == 'LeakyReLU':
        nonlinearity_layer = nn.LeakyReLU(0.1)
    elif activation_type == 'PReLU':
        nonlinearity_layer = nn.PReLU()
    else:
        raise NotImplementedError('activation layer [%s] is not found' % activation_type)
    return nonlinearity_layer

def print_network(net):
    """print the network"""
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('total number of parameters: %.3f M' % (num_params/1e6))


def init_net(net, init_type='normal', activation='relu', gpu_ids=None):
    """Print the network structure and initialize the network weights"""
    print_network(net)

    if gpu_ids:
        assert(torch.cuda.is_available())
        net.cuda()
        net = torch.nn.DataParallel(net, gpu_ids)
    pluralistic_init_weights(net, init_type)
    return net


def _freeze(*args):
    """freeze the network for forward process"""
    for module in args:
        if module:
            for p in module.parameters():
                p.requires_grad = False


def _unfreeze(*args):
    """ unfreeze the network for parameter update"""
    for module in args:
        if module:
            for p in module.parameters():
                p.requires_grad = True


def spectral_norm(module, use_spect=True):
    """use a spectral normalization layer to stabilize the training process"""
    if use_spect:
        return SpectralNorm(module)
    else:
        return module


def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs):
    """use coord convolution layer to add position information"""
    if use_coord:
        return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs)
    else:
        return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)


######################################################################################
# Network basic function
######################################################################################
class AddCoords(pl.LightningModule):
    """
    Add Coords to a tensor
    """
    def __init__(self, with_r=False):
        super(AddCoords, self).__init__()
        self.with_r = with_r

    def forward(self, x):
        """
        :param x: shape (batch, channel, x_dim, y_dim)
        :return: shape (batch, channel+2, x_dim, y_dim)
        """
        B, _, x_dim, y_dim = x.size()

        # coord calculate
        # NOTE: both coordinate maps come out as (B, 1, y_dim, x_dim), so the
        # concatenation below only lines up with x when x_dim == y_dim (square inputs)
        xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x)
        yy_channel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x)
        # normalization
        xx_channel = xx_channel.float() / (x_dim-1)
        yy_channel = yy_channel.float() / (y_dim-1)
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        ret = torch.cat([x, xx_channel, yy_channel], dim=1)

        if self.with_r:
            rr = torch.sqrt(xx_channel ** 2 + yy_channel ** 2)
            ret = torch.cat([ret, rr], dim=1)

        return ret


class CoordConv(pl.LightningModule):
    """
    CoordConv operation
    """
    def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs):
        super(CoordConv, self).__init__()
        self.addcoords = AddCoords(with_r=with_r)
        input_nc = input_nc + 2
        if with_r:
            input_nc = input_nc + 1
        self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)

    def forward(self, x):
        ret = self.addcoords(x)
        ret = self.conv(ret)

        return ret


class ResBlock(pl.LightningModule):
    """
    Define a residual block for different sample types
    """
    def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 sample_type='none', use_spect=False, use_coord=False):
        super(ResBlock, self).__init__()

        hidden_nc = output_nc if hidden_nc is None else hidden_nc
        self.sample = True
        if sample_type == 'none':
            self.sample = False
        elif sample_type == 'up':
            output_nc = output_nc * 4
            self.pool = nn.PixelShuffle(upscale_factor=2)
        elif sample_type == 'down':
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            raise NotImplementedError('sample type [%s] is not found' % sample_type)

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0}

        self.conv1 = coord_conv(input_nc, hidden_nc, use_spect, use_coord, **kwargs)
        self.conv2 = coord_conv(hidden_nc, output_nc, use_spect, use_coord, **kwargs)
        self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_short)

        if norm_layer is None:
            self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2,)
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2,)

        self.shortcut = nn.Sequential(self.bypass,)

    def forward(self, x):
        if self.sample:
            out = self.pool(self.model(x)) + self.pool(self.shortcut(x))
        else:
            out = self.model(x) + self.shortcut(x)

        return out


class ResBlockEncoderOptimized(pl.LightningModule):
    """
    Define an Encoder block for the first layer of the discriminator and representation network
    """
    def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(), use_spect=False, use_coord=False):
        super(ResBlockEncoderOptimized, self).__init__()

        kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0}

        self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs)
        self.conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs)
        self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_short)

        if norm_layer is None:
            self.model = nn.Sequential(self.conv1, nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.model = nn.Sequential(self.conv1, norm_layer(output_nc), nonlinearity, self.conv2, nn.AvgPool2d(kernel_size=2, stride=2))

        self.shortcut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), self.bypass)

    def forward(self, x):
        out = self.model(x) + self.shortcut(x)

        return out


class ResBlockDecoder(pl.LightningModule):
    """
    Define a decoder block
    """
    def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 use_spect=False, use_coord=False):
        super(ResBlockDecoder, self).__init__()

        hidden_nc = output_nc if hidden_nc is None else hidden_nc

        self.conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect)
        self.conv2 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)
        self.bypass = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)

        if norm_layer is None:
            self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2,)
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2,)

        self.shortcut = nn.Sequential(self.bypass)

    def forward(self, x):
        out = self.model(x) + self.shortcut(x)

        return out


class Output(pl.LightningModule):
    """
    Define the output layer
    """
    def __init__(self, input_nc, output_nc, kernel_size = 3, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
                 use_spect=False, use_coord=False):
        super(Output, self).__init__()

        kwargs = {'kernel_size': kernel_size, 'padding':0, 'bias': True}

        self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs)

        if norm_layer is None:
            self.model = nn.Sequential(nonlinearity, nn.ReflectionPad2d(int(kernel_size/2)), self.conv1, nn.Tanh())
        else:
            self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1, nn.Tanh())

    def forward(self, x):
        out = self.model(x)

        return out


class Auto_Attn(pl.LightningModule):
    """ Short+Long attention Layer"""

    def __init__(self, input_nc, norm_layer=nn.BatchNorm2d):
        super(Auto_Attn, self).__init__()
        self.input_nc = input_nc

        self.query_conv = nn.Conv2d(input_nc, input_nc // 4, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.alpha = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

        self.model = ResBlock(int(input_nc*2), input_nc, input_nc, norm_layer=norm_layer, use_spect=True)

    def forward(self, x, pre=None, mask=None):
        """
        inputs :
            x : input feature maps( B X C X W X H)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        B, C, W, H = x.size()
        proj_query = self.query_conv(x).view(B, -1, W * H)  # B X (C/4) X N
        proj_key = proj_query  # B X (C/4) X N (the key shares the query projection)

        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)  # B X N X N
        attention = self.softmax(energy)  # B X N X N
        proj_value = x.view(B, -1, W * H)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, W, H)

        out = self.gamma * out + x

        if pre is not None:
            # using long distance attention layer to copy information from valid regions
            context_flow = torch.bmm(pre.view(B, -1, W*H), attention.permute(0, 2, 1)).view(B, -1, W, H)
            context_flow = self.alpha * (1-mask) * context_flow + (mask) * pre
            out = self.model(torch.cat([out, context_flow], dim=1))

        return out, attention



import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import copy


####################################################################################################
# spectral normalization layer to decouple the magnitude of a weight tensor
####################################################################################################

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(pl.LightningModule):
    """
    spectral normalization
    code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN'
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


####################################################################################################
# adversarial loss for different gan mode
####################################################################################################


class GANLoss(pl.LightningModule):
    """Define different GAN objectives.
    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, hinge, and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label of a fake image
        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode == 'wgangp':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def __call__(self, prediction, target_is_real, is_disc=False):
        """Calculate loss given Discriminator's output and ground truth labels.
        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real images or fake images
            is_disc (bool) - - whether the loss is computed for the discriminator update (only used by hinge and wgangp)
        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            labels = (self.real_label if target_is_real else self.fake_label).expand_as(prediction).type_as(prediction)
            loss = self.loss(prediction, labels)
        elif self.gan_mode in ['hinge', 'wgangp']:
            if is_disc:
                if target_is_real:
                    prediction = -prediction
                if self.gan_mode == 'hinge':
                    loss = self.loss(1 + prediction).mean()
                elif self.gan_mode == 'wgangp':
                    loss = prediction.mean()
            else:
                loss = -prediction.mean()
        return loss


def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)            -- the constant used in formula (||gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss
    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':   # either use real images, fake images, or a linear interpolation of two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            alpha = alpha.type_as(real_data)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp        # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


####################################################################################################
# neural style transfer loss from neural_style_tutorial of pytorch
####################################################################################################


def ContentLoss(input, target):
    target = target.detach()
    loss = F.l1_loss(input, target)
    return loss


def GramMatrix(input):
    s = input.size()
    features = input.view(s[0], s[1], s[2]*s[3])
    features_t = torch.transpose(features, 1, 2)
    G = torch.bmm(features, features_t).div(s[1]*s[2]*s[3])
    return G


def StyleLoss(input, target):
    target = GramMatrix(target).detach()
    input = GramMatrix(input)
    loss = F.l1_loss(input, target)
    return loss


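def img_crop(input, size=224):
    # despite the name, this resizes (does not crop) the input to size x size
    # F.upsample is deprecated; F.interpolate takes the same arguments
    input_cropped = F.interpolate(input, size=(size, size), mode='bilinear', align_corners=True)
    return input_cropped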


class Normalization(pl.LightningModule):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, input):
        return (input-self.mean) / self.std


class get_features(pl.LightningModule):
    def __init__(self, cnn):
        super(get_features, self).__init__()

        vgg = copy.deepcopy(cnn)

        self.conv1 = nn.Sequential(vgg[0], vgg[1], vgg[2], vgg[3], vgg[4])
        self.conv2 = nn.Sequential(vgg[5], vgg[6], vgg[7], vgg[8], vgg[9])
        self.conv3 = nn.Sequential(vgg[10], vgg[11], vgg[12], vgg[13], vgg[14], vgg[15], vgg[16])
        self.conv4 = nn.Sequential(vgg[17], vgg[18], vgg[19], vgg[20], vgg[21], vgg[22], vgg[23])
        self.conv5 = nn.Sequential(vgg[24], vgg[25], vgg[26], vgg[27], vgg[28], vgg[29], vgg[30])

    def forward(self, input, layers):
        input = img_crop(input)
        output = []
        for i in range(1, layers):
            layer = getattr(self, 'conv'+str(i))
            input = layer(input)
            output.append(input)
        return output


##############################################################################################################
# Network function
##############################################################################################################
def define_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
             use_coord=False, init_type='orthogonal', gpu_ids=[]):

    net = ResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord)

    return init_net(net, init_type, activation, gpu_ids)


def define_g(output_nc=3, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
             use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):

    net = ResGenerator(output_nc, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)

    return init_net(net, init_type, activation, gpu_ids)


def define_d(input_nc=3, ndf=64, img_f=512, layers=6, norm='none', activation='LeakyReLU', use_spect=True, use_coord=False,
             use_attn=True,  model_type='ResDis', init_type='orthogonal', gpu_ids=[]):

    if model_type == 'ResDis':
        net = ResDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)
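    elif model_type == 'PatchDis':
        net = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)
    else:
        # without this branch an unknown model_type fails later with an unbound `net`
        raise NotImplementedError('discriminator model type [%s] is not found' % model_type)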

    return init_net(net, init_type, activation, gpu_ids)


#############################################################################################################
# Network structure
#############################################################################################################
class ResEncoder(pl.LightningModule):
    """
    ResNet Encoder Network
    :param input_nc: number of channels in input
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    """
    def __init__(self, input_nc=3, ngf=64, z_nc=128, img_f=1024, L=6, layers=6, norm='none', activation='ReLU',
                 use_spect=True, use_coord=False):
        super(ResEncoder, self).__init__()

        self.layers = layers
        self.z_nc = z_nc
        self.L = L

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # encoder part
        self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)

        mult = 1
        for i in range(layers-1):
            mult_prev = mult
            mult = min(2 ** (i + 1), img_f // ngf)
            block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
            setattr(self, 'encoder' + str(i), block)

        # inference part
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'infer_prior' + str(i), block)

        self.posterior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
        self.prior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)

    def forward(self, img_m, img_c=None):
        """
        :param img_m: image with mask regions I_m
        :param img_c: complement of I_m, the mask regions
        :return distribution: distribution of mask regions, for training we have two paths, testing one path
        :return feature: the conditional feature f_m, and the previous f_pre for auto context attention
        """

        if img_c is not None:
            img = torch.cat([img_m, img_c], dim=0)
        else:
            img = img_m

        # encoder part
        out = self.block0(img)
        feature = [out]
        for i in range(self.layers-1):
            model = getattr(self, 'encoder' + str(i))
            out = model(out)
            feature.append(out)

        # infer part
        # during training we have two paths; during testing we only have one path
        if img_c is not None:
            distribution = self.two_paths(out)
            return distribution, feature
        else:
            distribution = self.one_path(out)
            return distribution, feature

    def one_path(self, f_in):
        """one path for baseline training or testing"""
        f_m = f_in
        distribution = []

        # infer state
        for i in range(self.L):
            infer_prior = getattr(self, 'infer_prior' + str(i))
            f_m = infer_prior(f_m)

        # get distribution
        o = self.prior(f_m)
        q_mu, q_std = torch.split(o, self.z_nc, dim=1)
        distribution.append([q_mu, F.softplus(q_std)])

        return distribution

    def two_paths(self, f_in):
        """two paths for the training"""
        f_m, f_c = f_in.chunk(2)
        distributions = []

        # get distribution
        o = self.posterior(f_c)
        p_mu, p_std = torch.split(o, self.z_nc, dim=1)
        distribution = self.one_path(f_m)
        distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])

        return distributions


class ResGenerator(pl.LightningModule):
    """
    ResNet Generator Network
    :param output_nc: number of channels in output
    :param ngf: base filter channel
    :param z_nc: latent channels
    :param img_f: the largest feature channels
    :param L: Number of refinements of density
    :param layers: down and up sample layers
    :param norm: normalization function 'instance, batch, group'
    :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
    :param output_scale: Different output scales
    """
    def __init__(self, output_nc=3, ngf=64, z_nc=128, img_f=1024, L=1, layers=6, norm='batch', activation='ReLU',
                 output_scale=1, use_spect=True, use_coord=False, use_attn=True):
        super(ResGenerator, self).__init__()

        self.layers = layers
        self.L = L
        self.output_scale = output_scale
        self.use_attn = use_attn

        norm_layer = get_norm_layer(norm_type=norm)
        nonlinearity = get_nonlinearity_layer(activation_type=activation)
        # latent z to feature
        mult = min(2 ** (layers-1), img_f // ngf)
        self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)

        # transform
        for i in range(self.L):
            block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
            setattr(self, 'generator' + str(i), block)

        # decoder part
        for i in range(layers):
            mult_prev = mult
            mult = min(2 ** (layers - i - 1), img_f // ngf)
            if i > layers - output_scale:
                # upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            else:
                # upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
                upconv = ResBlockDecoder(ngf * mult_prev , ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
            setattr(self, 'decoder' + str(i), upconv)
            # output part
            if i > layers - output_scale - 1:
                outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
                setattr(self, 'out' + str(i), outconv)
            # short+long term attention part
            if i == 1 and use_attn:
                attn = Auto_Attn(ngf*mult, None)
                setattr(self, 'attn' + str(i), attn)

    def forward(self, z, f_m=None, f_e=None, mask=None):
        """
        ResNet Generator Network
        :param z: latent vector
        :param f_m: feature of valid regions for conditional VAE-GAN
        :param f_e: previous encoder feature for short+long term attention layer
        :return results: different scale generation outputs
        """

        f = self.generator(z)
        for i in range(self.L):
            generator = getattr(self, 'generator' + str(i))
            f = generator(f)

        # the features come from mask regions and valid regions, we directly add them together
        out = f_m + f
        results= []
        attn = 0
        for i in range(self.layers):
            model = getattr(self, 'decoder' + str(i))
            out = model(out)
            if i == 1 and self.use_attn:
                # auto attention
                model = getattr(self, 'attn' + str(i))
                out, attn = model(out, f_e, mask)
            if i > self.layers - self.output_scale - 1:
                model = getattr(self, 'out' + str(i))
                output = model(out)
                results.append(output)
                out = torch.cat([out, output], dim=1)

        return results, attn

# https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/1ca1855615fed8b686ca218c6494f455860f9996/model/pluralistic_model.py
# https://github.com/lyndonzheng/Pluralistic-Inpainting/blob/1ca1855615fed8b686ca218c6494f455860f9996/util/task.py
class PluralisticGenerator(pl.LightningModule):
    def __init__(self, ngf_E=32, z_nc_E=128, img_f_E=128, layers_E=5, norm_E='none', activation_E='LeakyReLU',
                 ngf_G=32, z_nc_G=128, img_f_G=128, L_G=0, output_scale_G=1, norm_G='instance', activation_G='LeakyReLU', train_paths='two'):
        super().__init__()
        self.net_E = ResEncoder(ngf=ngf_E, z_nc=z_nc_E, img_f=img_f_E, layers=layers_E, norm=norm_E, activation=activation_E)
        self.net_G = ResGenerator(ngf=ngf_G, z_nc=z_nc_G, img_f=img_f_G, L=L_G, layers=5, output_scale=output_scale_G,
                                      norm=norm_G, activation=activation_G)
        self.train_paths = train_paths

    def get_distribution(self, distributions, mask):
        """Calculate encoder distribution for img_m, img_c"""
        # get distribution
        sum_valid = (torch.mean(mask.view(mask.size(0), -1), dim=1) - 1e-5).view(-1, 1, 1, 1)
        m_sigma = 1 / (1 + ((sum_valid - 0.8) * 8).exp_())
        p_distribution, q_distribution, kl_rec, kl_g = 0, 0, 0, 0
        self.distribution = []
        for distribution in distributions:
            p_mu, p_sigma, q_mu, q_sigma = distribution
            # the assumption distribution for different mask regions
            m_distribution = torch.distributions.Normal(torch.zeros_like(p_mu), m_sigma * torch.ones_like(p_sigma))
            # m_distribution = torch.distributions.Normal(torch.zeros_like(p_mu), torch.ones_like(p_sigma))
            # the post distribution from mask regions
            p_distribution = torch.distributions.Normal(p_mu, p_sigma)
            p_distribution_fix = torch.distributions.Normal(p_mu.detach(), p_sigma.detach())
            # the prior distribution from valid region
            q_distribution = torch.distributions.Normal(q_mu, q_sigma)

            # kl divergence
            kl_rec += torch.distributions.kl_divergence(m_distribution, p_distribution)
            if self.train_paths == "one":
                kl_g += torch.distributions.kl_divergence(m_distribution, q_distribution)
            elif self.train_paths == "two":
                kl_g += torch.distributions.kl_divergence(p_distribution_fix, q_distribution)
            self.distribution.append([torch.zeros_like(p_mu), m_sigma * torch.ones_like(p_sigma), p_mu, p_sigma, q_mu, q_sigma])

        return p_distribution, q_distribution, kl_rec, kl_g

    def get_G_inputs(self, p_distribution, q_distribution, f, mask):
        """Process the encoder feature and distributions for generation network"""
        f_m = torch.cat([f[-1].chunk(2)[0], f[-1].chunk(2)[0]], dim=0)
        f_e = torch.cat([f[2].chunk(2)[0], f[2].chunk(2)[0]], dim=0)
        scale_mask = scale_img(mask, size=[f_e.size(2), f_e.size(3)])
        mask = torch.cat([scale_mask.chunk(3, dim=1)[0], scale_mask.chunk(3, dim=1)[0]], dim=0)
        z_p = p_distribution.rsample()
        z_q = q_distribution.rsample()
        z = torch.cat([z_p, z_q], dim=0)
        return z, f_m, f_e, mask

    def forward(self, images, img_inverted, masks):
        distributions, f = self.net_E(images, img_inverted)
        p_distribution, q_distribution, kl_rec, kl_g = self.get_distribution(distributions, masks)
        z, f_m, f_e, mask = self.get_G_inputs(p_distribution, q_distribution, f, masks)
        results, attn = self.net_G(z, f_m, f_e, mask)

        self.img_rec = []
        self.img_g = []
        for result in results:
            img_rec, img_g = result.chunk(2)
            self.img_rec.append(img_rec)
            self.img_g.append(img_g)

        return self.img_g[-1].detach(), kl_rec, kl_g


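Sidenote on the `SpectralNorm` wrapper in the cell above: on every forward pass it runs `power_iterations` steps of power iteration on the weight (flattened to 2D) and divides the weight by the estimated largest singular value. The snippet below is a minimal standalone sketch of that idea, not the notebook's implementation. `estimate_sigma` is a hypothetical helper, it starts from a fixed all-ones vector instead of the random `u` used by the class, and it assumes a PyTorch version that provides `torch.linalg.svdvals` for the comparison.

```python
import torch


def l2normalize(v, eps=1e-12):
    # same normalization helper as in the cell above
    return v / (v.norm() + eps)


def estimate_sigma(weight, power_iterations=1):
    """Estimate the largest singular value of `weight` flattened to 2D.

    Hypothetical helper for illustration; needs power_iterations >= 1.
    """
    w = weight.reshape(weight.shape[0], -1)
    # assumption: deterministic start vector (the class above draws u from a normal)
    u = l2normalize(torch.ones(w.shape[0]))
    for _ in range(power_iterations):
        v = l2normalize(torch.mv(w.t(), u))
        u = l2normalize(torch.mv(w, v))
    # u^T W v approximates the largest singular value
    return u.dot(torch.mv(w, v))


torch.manual_seed(0)
conv_weight = torch.randn(8, 4, 3, 3)  # (out_channels, in_channels, kH, kW)

sigma = estimate_sigma(conv_weight, power_iterations=200)
exact = torch.linalg.svdvals(conv_weight.reshape(8, -1))[0]

# dividing by sigma gives a weight whose largest singular value is close to 1
normalized = conv_weight / sigma
normalized_top = torch.linalg.svdvals(normalized.reshape(8, -1))[0]
```

With many iterations the estimate approaches the exact value. The class above only does one iteration per forward pass by default and relies on `u` and `v` being carried over between calls.
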
In [None]:
#@title [EdgeConnect_arch.py](https://github.com/knazeri/edge-connect) (2019)
"""
edge_connect.py (12-12-20)
https://github.com/knazeri/edge-connect/blob/master/src/edge_connect.py
"""

import torch
import torch.nn as nn
import os
import torch.optim as optim

#from models.modules.architectures.convolutions.partialconv2d import PartialConv2d
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d
import pytorch_lightning as pl
from torchvision.utils import save_image

class InpaintGenerator(pl.LightningModule):
    def __init__(self, residual_blocks=8, init_weights=True, conv_type='deform'):
        super(InpaintGenerator, self).__init__()

        if conv_type == 'normal':
          self.encoder = nn.Sequential(
              nn.ReflectionPad2d(3),
              nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
              nn.InstanceNorm2d(64, track_running_stats=False),
              nn.ReLU(True),

              nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(128, track_running_stats=False),
              nn.ReLU(True),

              nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(256, track_running_stats=False),
              nn.ReLU(True)
          )
        elif conv_type == 'partial':
          self.encoder = nn.Sequential(
              nn.ReflectionPad2d(3),
              PartialConv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
              nn.InstanceNorm2d(64, track_running_stats=False),
              nn.ReLU(True),

              PartialConv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(128, track_running_stats=False),
              nn.ReLU(True),

              PartialConv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(256, track_running_stats=False),
              nn.ReLU(True)
          )
        elif conv_type == 'deform':
          self.encoder = nn.Sequential(
              nn.ReflectionPad2d(3),
              DeformConv2d(in_nc=4, out_nc=64, kernel_size=7, padding=0),
              nn.InstanceNorm2d(64, track_running_stats=False),
              nn.ReLU(True),

              DeformConv2d(in_nc=64, out_nc=128, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(128, track_running_stats=False),
              nn.ReLU(True),

              DeformConv2d(in_nc=128, out_nc=256, kernel_size=4, stride=2, padding=1),
              nn.InstanceNorm2d(256, track_running_stats=False),
              nn.ReLU(True)
          )
        else:
          raise ValueError(f"Unsupported conv_type {conv_type!r}; expected 'normal', 'partial' or 'deform'")

        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock(256, 2)
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0),
        )


    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = (torch.tanh(x) + 1) / 2

        return x


class EdgeGenerator(pl.LightningModule):
    def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True, conv_type='normal'):
        super(EdgeGenerator, self).__init__()

        if conv_type == 'normal':
          self.encoder = nn.Sequential(
              nn.ReflectionPad2d(3),
              spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
              nn.InstanceNorm2d(64, track_running_stats=False),
              nn.ReLU(True),

              spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
              nn.InstanceNorm2d(128, track_running_stats=False),
              nn.ReLU(True),

              spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
              nn.InstanceNorm2d(256, track_running_stats=False),
              nn.ReLU(True)
          )
        elif conv_type == 'partial':
          self.encoder = nn.Sequential(
              nn.ReflectionPad2d(3),
              spectral_norm(PartialConv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
              nn.InstanceNorm2d(64, track_running_stats=False),
              nn.ReLU(True),

              spectral_norm(PartialConv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
              nn.InstanceNorm2d(128, track_running_stats=False),
              nn.ReLU(True),

              spectral_norm(PartialConv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
              nn.InstanceNorm2d(256, track_running_stats=False),
              nn.ReLU(True)
          )
        elif conv_type == 'deform':
            # without spectral_norm
            self.encoder = nn.Sequential(
                nn.ReflectionPad2d(3),
                DeformConv2d(in_nc=3, out_nc=64, kernel_size=7, padding=0),
                nn.InstanceNorm2d(64, track_running_stats=False),
                nn.ReLU(True),

                DeformConv2d(in_nc=64, out_nc=128, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(128, track_running_stats=False),
                nn.ReLU(True),

                DeformConv2d(in_nc=128, out_nc=256, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(256, track_running_stats=False),
                nn.ReLU(True)
            )
        else:
            raise ValueError(f"Unsupported conv_type {conv_type!r}; expected 'normal', 'partial' or 'deform'")

        blocks = []
        for _ in range(residual_blocks):
            block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
            blocks.append(block)

        self.middle = nn.Sequential(*blocks)

        self.decoder = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True),

            spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
        )


    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = torch.sigmoid(x)
        return x


class ResnetBlock(pl.LightningModule):
    def __init__(self, dim, dilation=1, use_spectral_norm=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.ReLU(True),

            nn.ReflectionPad2d(1),
            spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        out = x + self.conv_block(x)

        # Remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html

        return out


def spectral_norm(module, mode=True):
    if mode:
        return nn.utils.spectral_norm(module)

    return module

class EdgeConnectModel(pl.LightningModule):
    def __init__(self, residual_blocks_edge=8, residual_blocks_inpaint=8, use_spectral_norm=True, conv_type_edge='normal', conv_type_inpaint='normal'):
        super().__init__()
        self.EdgeGenerator = EdgeGenerator(residual_blocks=residual_blocks_edge, use_spectral_norm=use_spectral_norm, conv_type=conv_type_edge)
        self.InpaintGenerator = InpaintGenerator(residual_blocks=residual_blocks_inpaint, conv_type=conv_type_inpaint)

    def forward(self, images, edges, grayscale, masks):
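        # Cast to float32 on the model's own device instead of hard-coding CUDA,
        # so the module also runs on CPU and TPU.
        device = next(self.parameters()).device
        images = images.to(device=device, dtype=torch.float32)
        edges = edges.to(device=device, dtype=torch.float32)
        grayscale = grayscale.to(device=device, dtype=torch.float32)
        masks = masks.to(device=device, dtype=torch.float32)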

        # edge
        edges_masked = (edges * masks)
        grayscale_masked = grayscale * masks

        inputs = torch.cat((grayscale_masked, edges_masked, masks), dim=1)
        outputs_edge = self.EdgeGenerator(inputs)                                      # in: [grayscale(1) + edge(1) + mask(1)]

        # inpaint
        images_masked = (images * masks).float() + (1-masks)
        inputs = torch.cat((images_masked, outputs_edge), dim=1)
        outputs = self.InpaintGenerator(inputs)                                    # in: [rgb(3) + edge(1)]
        return outputs, outputs_edge
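
Both EdgeConnect generators downsample twice with 4x4 stride-2 convolutions and upsample twice with matching transposed convolutions. The sketch below traces one spatial dimension through that layout using the output-size formulas documented for `nn.Conv2d` and `nn.ConvTranspose2d`. The function names are illustrative and not part of this notebook, and reflection padding is modelled simply as adding `2 * pad` to the size:

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of a Conv2d along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deconv_out(size, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    """Output length of a ConvTranspose2d along one spatial axis."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def generator_sizes(size, residual_blocks=8):
    """Trace one spatial dimension through encoder, middle and decoder."""
    trace = [size]
    # encoder: ReflectionPad2d(3) + 7x7 conv, then two 4x4 stride-2 convs
    size = conv_out(size + 2 * 3, kernel=7)
    trace.append(size)
    for _ in range(2):
        size = conv_out(size, kernel=4, stride=2, padding=1)
        trace.append(size)
    # middle: each ResnetBlock(256, 2) pads by its dilation, so the size is kept
    for _ in range(residual_blocks):
        size = conv_out(size + 2 * 2, kernel=3, dilation=2)
        size = conv_out(size + 2 * 1, kernel=3)
    trace.append(size)
    # decoder: two 4x4 stride-2 transposed convs, then ReflectionPad2d(3) + 7x7 conv
    for _ in range(2):
        size = deconv_out(size, kernel=4, stride=2, padding=1)
        trace.append(size)
    size = conv_out(size + 2 * 3, kernel=7)
    trace.append(size)
    return trace


print(generator_sizes(256))  # [256, 256, 128, 64, 64, 128, 256, 256]
print(generator_sizes(250))  # [250, 250, 125, 62, 62, 124, 248, 248]
```

The second call shows why inputs whose sides are not multiples of 4 come back at a different size than they went in.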


In [None]:
#@title [FRRN_arch.py](https://github.com/ZongyuGuo/Inpainting_FRRN/) (2019)
"""
networks.py (18-12-20)
https://github.com/ZongyuGuo/Inpainting_FRRN/blob/master/src/networks.py
"""

import torch
import torch.nn as nn
#from .convolutions import partialconv2d
import pytorch_lightning as pl

class FRRNet(pl.LightningModule):
    def __init__(self, block_num=16):
        super(FRRNet, self).__init__()
        self.block_num = block_num
        self.dilation_num = block_num // 2
        blocks = []
        for _ in range(self.block_num):
            blocks.append(FRRBlock())
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, mask):
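        # Cast to float32 on the model's own device instead of hard-coding CUDA.
        device = next(self.parameters()).device
        x = x.to(device=device, dtype=torch.float32)
        mask = mask.to(device=device, dtype=torch.float32)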

        mid_x = []
        mid_m = []

        mask_new = mask
        for index in range(self.dilation_num):
            x, _ = self.blocks[index * 2](x, mask_new, mask)
            x, mask_new = self.blocks[index * 2 + 1](x, mask_new, mask)
            mid_x.append(x)
            mid_m.append(mask_new)

        return x, mid_x, mid_m


class FRRBlock(pl.LightningModule):
    def __init__(self):
        super(FRRBlock, self).__init__()
        self.full_conv1 = PConvLayer(3,  32, kernel_size=5, stride=1, padding=2, use_norm=False)
        self.full_conv2 = PConvLayer(32, 32, kernel_size=5, stride=1, padding=2, use_norm=False)
        self.full_conv3 = PConvLayer(32, 3,  kernel_size=5, stride=1, padding=2, use_norm=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.branch_conv1 = PConvLayer(3,   64,  kernel_size=3, stride=2, padding=1, use_norm=False)
        self.branch_conv2 = PConvLayer(64,  96,  kernel_size=3, stride=2, padding=1)
        self.branch_conv3 = PConvLayer(96,  128, kernel_size=3, stride=2, padding=1)
        self.branch_conv4 = PConvLayer(128, 96,  kernel_size=3, stride=1, padding=1, act='LeakyReLU')
        self.branch_conv5 = PConvLayer(96,  64,  kernel_size=3, stride=1, padding=1, act='LeakyReLU')
        self.branch_conv6 = PConvLayer(64,  3,   kernel_size=3, stride=1, padding=1, act='Tanh')

    def forward(self, input, mask, mask_ori):
        x = input
        out_f, mask_f = self.full_conv1(x, mask)
        out_f, mask_f = self.full_conv2(out_f, mask_f)
        out_f, mask_f = self.full_conv3(out_f, mask_f)

        out_b, mask_b = self.branch_conv1(x, mask)
        out_b, mask_b = self.branch_conv2(out_b, mask_b)
        out_b, mask_b = self.branch_conv3(out_b, mask_b)

        out_b = self.upsample(out_b)
        mask_b = self.upsample(mask_b)
        out_b, mask_b = self.branch_conv4(out_b, mask_b)
        out_b = self.upsample(out_b)
        mask_b = self.upsample(mask_b)
        out_b, mask_b = self.branch_conv5(out_b, mask_b)
        out_b = self.upsample(out_b)
        mask_b = self.upsample(mask_b)
        out_b, mask_b = self.branch_conv6(out_b, mask_b)

        mask_new = mask_f * mask_b
        out = (out_f * mask_new + out_b * mask_new) / 2 * (1 - mask_ori) + input
        #out = (out_f * mask_new + out_b * mask_new) / 2 * (1 - mask_ori) + input * mask_ori
        return out, mask_new



class PConvLayer(pl.LightningModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, act='ReLU', use_norm=True):
        super(PConvLayer, self).__init__()
        self.conv = PartialConv2d(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=kernel_size, stride=stride, padding=padding, return_mask=True)
        self.norm = nn.InstanceNorm2d(out_channels, track_running_stats=False)
        self.use_norm = use_norm
        if act == 'ReLU':
            self.act = nn.ReLU(True)
        elif act == 'LeakyReLU':
            self.act = nn.LeakyReLU(0.2, True)
        elif act == 'Tanh':
            self.act = nn.Tanh()

    def forward(self, x, mask):

        x, mask_update = self.conv(x, mask)
        if self.use_norm:
            x = self.norm(x)
        x = self.act(x)
        return x, mask_update
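
The last line of `FRRBlock.forward` is a masked residual update. A scalar sketch of the same expression, one pixel at a time, assuming the mask convention 1 = known pixel and 0 = hole (the function name `frr_blend` is illustrative and not part of this notebook):

```python
def frr_blend(inp, out_f, out_b, mask_new, mask_ori):
    """Scalar version of the residual update at the end of FRRBlock.forward."""
    return (out_f * mask_new + out_b * mask_new) / 2 * (1 - mask_ori) + inp


# Known pixel (mask_ori = 1): the residual is switched off and the input passes through.
print(frr_blend(inp=0.8, out_f=0.5, out_b=0.25, mask_new=1.0, mask_ori=1.0))  # 0.8

# Hole pixel covered by both branches: the mean of the two predictions is added.
print(frr_blend(inp=0.0, out_f=0.5, out_b=0.25, mask_new=1.0, mask_ori=0.0))  # 0.375

# Hole pixel that neither updated mask has reached yet: nothing is added.
print(frr_blend(inp=0.0, out_f=0.5, out_b=0.25, mask_new=0.0, mask_ori=0.0))  # 0.0
```

Because `mask_new` grows from block to block, later blocks add residuals to pixels that earlier blocks could not reach.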


In [None]:
#@title [PRVS_arch.py](https://github.com/jingyuanli001/PRVS-Image-Inpainting) (2019)
"""
model.py (18-12-20)
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/model.py

PRVSNet.py (18-12-20)
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/modules/PRVSNet.py

partialconv2d.py (18-12-20) # using their partconv2d to avoid dimension errors
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/modules/partialconv2d.py

PConvLayer.py (18-12-20)
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/modules/PConvLayer.py

VSRLayer.py (18-12-20)
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/modules/VSRLayer.py

Attention.py (18-12-20)
https://github.com/jingyuanli001/PRVS-Image-Inpainting/blob/master/modules/Attention.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision import models
#from .convolutions import partialconv2d
import pytorch_lightning as pl

###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu (guilinl@nvidia.com)
###############################################################################

class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        self.multi_channel = kwargs.pop('multi_channel', False)

        # always returns the updated mask by default; popping the keyword keeps
        # callers that pass return_mask=... from breaking nn.Conv2d.__init__
        self.return_mask = kwargs.pop('return_mask', True)

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask=None):

        if mask is not None or self.last_size != (input.data.shape[2], input.data.shape[3]):
            self.last_size = (input.data.shape[2], input.data.shape[3])

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

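        if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
            # Tensor.to() is not in-place, so the result has to be assigned back
            self.update_mask = self.update_mask.to(input)
            self.mask_ratio = self.mask_ratio.to(input)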

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)


        if self.return_mask:
            return output, self.update_mask
        else:
            return output



PartialConv = PartialConv2d

class AttentionModule(pl.LightningModule):

    def __init__(self, patch_size = 3, propagate_size = 3, stride = 1):
        super(AttentionModule, self).__init__()
        self.patch_size = patch_size
        self.propagate_size = propagate_size
        self.stride = stride
        self.prop_kernels = None

    def forward(self, foreground, masks):
        ###assume the masked area has value 1
        bz, nc, w, h = foreground.size()
        if masks.size(3) != foreground.size(3):
            masks = F.interpolate(masks, foreground.size()[2:])
        background = foreground.clone()
        background = background * masks
        background = F.pad(background, [self.patch_size//2, self.patch_size//2, self.patch_size//2, self.patch_size//2])
        conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size, self.stride).contiguous().view(bz, nc, -1, self.patch_size, self.patch_size)
        conv_kernels_all = conv_kernels_all.transpose(2, 1)
        output_tensor = []
        for i in range(bz):
            mask = masks[i:i+1]
            feature_map = foreground[i:i+1]
            #form convolutional kernels
            conv_kernels = conv_kernels_all[i] + 0.0000001
            norm_factor = torch.sum(conv_kernels**2, [1, 2, 3], keepdim = True)**0.5
            conv_kernels = conv_kernels/norm_factor

            conv_result = F.conv2d(feature_map, conv_kernels, padding = self.patch_size//2)
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    # create the kernel on the same device/dtype as the features instead of hard-coding CUDA
                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size], device=conv_result.device, dtype=conv_result.dtype)
                    self.prop_kernels.requires_grad = False
                conv_result = F.conv2d(conv_result, self.prop_kernels, stride = 1, padding = 1, groups = conv_result.size(1))
            attention_scores = F.softmax(conv_result, dim = 1)
            ##propagate the scores
            recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride = 1, padding = self.patch_size//2)
            #average the recovered value, at the same time make non-masked area 0
            recovered_foreground = (recovered_foreground * (1 - mask))/(self.patch_size ** 2)
            #recover the image
            final_output = recovered_foreground + feature_map * mask
            output_tensor.append(final_output)
        return torch.cat(output_tensor, dim = 0)


class Bottleneck(pl.LightningModule):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out

# Named PRVSEdgeGenerator so it does not shadow the EdgeConnect EdgeGenerator defined in an earlier cell.
class PRVSEdgeGenerator(pl.LightningModule):
    def __init__(self, in_channels_feature, kernel_s = 3, add_last_edge = True):
        super(PRVSEdgeGenerator, self).__init__()

        self.p_conv = PartialConv2d(in_channels_feature + 1, 64, kernel_size = kernel_s, stride = 1, padding = kernel_s // 2, multi_channel = True, bias = False)

        self.edge_resolver = Bottleneck(64, 16)
        self.out_layer = nn.Conv2d(64, 1, 1, bias = False)

    def forward(self, in_x, mask):
        x, mask_updated = self.p_conv(in_x, mask)
        x = self.edge_resolver(x)
        edge_out = self.out_layer(x)
        return edge_out, mask_updated

class VSRLayer(pl.LightningModule):
    def __init__(self, in_channel, out_channel, stride = 2, kernel_size = 3, batch_norm = True, activation = "ReLU", deconv = False):
        super(VSRLayer, self).__init__()
        self.edge_generator = PRVSEdgeGenerator(in_channel, kernel_s = kernel_size)
        self.feat_rec = PartialConv(in_channel+1, out_channel, stride = stride, kernel_size = kernel_size, padding = kernel_size//2, multi_channel = True)
        if deconv:
            self.deconv = nn.ConvTranspose2d(out_channel, out_channel, 4, 2, 1)
        else:
            self.deconv = lambda x:x

        if batch_norm:
            self.batchnorm = nn.BatchNorm2d(out_channel)
        else:
            self.batchnorm = lambda x:x

        self.stride = stride

        if activation == "ReLU":
            self.activation = nn.ReLU(True)
        elif activation == "Leaky":
            self.activation = nn.LeakyReLU(0.2, True)
        else:
            self.activation = lambda x:x

    def forward(self, feat_in, mask_in, edge_in):
        edge_in = F.interpolate(edge_in, size = feat_in.size()[2:])
        edge_updated, mask_updated = self.edge_generator(torch.cat([feat_in, edge_in], dim = 1), torch.cat([mask_in, mask_in[:,:1,:,:]], dim = 1))
        edge_reconstructed = edge_in * mask_in[:,:1,:,:] + edge_updated * (mask_updated[:,:1,:,:] - mask_in[:,:1,:,:])
        feat_out, feat_mask = self.feat_rec(torch.cat([feat_in, edge_reconstructed], dim = 1), torch.cat([mask_in, mask_updated[:,:1,:,:]], dim = 1))
        feat_out = self.deconv(feat_out)
        feat_out = self.batchnorm(feat_out)
        feat_out = self.activation(feat_out)
        mask_updated = F.interpolate(mask_updated, size = feat_out.size()[2:])
        feat_mask = F.interpolate(feat_mask, size = feat_out.size()[2:])
        return feat_out, feat_mask*mask_updated[:,0:1,:,:], edge_reconstructed


class PConvLayer(pl.LightningModule):
    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False, deconv = False):
        super().__init__()
        if sample == 'down-5':
            self.conv = PartialConv(in_ch, out_ch, 5, 2, 2, bias=conv_bias, multi_channel = True)
        elif sample == 'down-7':
            self.conv = PartialConv(in_ch, out_ch, 7, 2, 3, bias=conv_bias, multi_channel = True)
        elif sample == 'down-3':
            self.conv = PartialConv(in_ch, out_ch, 3, 2, 1, bias=conv_bias, multi_channel = True)
        else:
            self.conv = PartialConv(in_ch, out_ch, 3, 1, 1, bias=conv_bias, multi_channel = True)
        if deconv:
            self.deconv = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1, bias = conv_bias)
        else:
            self.deconv = None
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, input_mask):
        h, h_mask = self.conv(input, input_mask)
        if self.deconv is not None:
            h = self.deconv(h)
        if hasattr(self, 'bn'):
            h = self.bn(h)
        if hasattr(self, 'activation'):
            h = self.activation(h)
        h_mask = F.interpolate(h_mask, size = h.size()[2:])
        return h, h_mask

class VGG16FeatureExtractor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        vgg16 = models.vgg16(pretrained=True)
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

class PRVSNet(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, att = False):
        super().__init__()
        self.layer_size = layer_size
        self.enc_1 = VSRLayer(3, 64, kernel_size = 7)
        self.enc_2 = VSRLayer(64, 128, kernel_size = 5)
        self.enc_3 = PConvLayer(128, 256, sample='down-5')
        self.enc_4 = PConvLayer(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PConvLayer(512, 512, sample='down-3'))
        self.deconv = nn.ConvTranspose2d(512, 512, 4, 2, 1)
        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PConvLayer(512 + 512, 512, activ='leaky', deconv = True))
        self.dec_4 = PConvLayer(512 + 256, 256, activ='leaky', deconv = True)
        if att:
            self.att = AttentionModule()  # defined earlier in this cell; there is no separate Attention module here
        else:
            self.att = lambda x:x
        self.dec_3 = PConvLayer(256 + 128, 128, activ='leaky', deconv = True)
        self.dec_2 = VSRLayer(128 + 64, 64, stride = 1, activation='leaky', deconv = True)
        self.dec_1 = VSRLayer(64 + input_channels, 64, stride = 1, activation = None, batch_norm = False)
        self.resolver = Bottleneck(64,16)
        self.output = nn.Conv2d(128, 3, 1)

    def forward(self, input, input_mask, input_edge):
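        # Cast to float32 on the model's own device instead of hard-coding CUDA.
        device = next(self.parameters()).device
        input = input.to(device=device, dtype=torch.float32)
        input_mask = input_mask.to(device=device, dtype=torch.float32)
        input_edge = input_edge.to(device=device, dtype=torch.float32)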

        input = input * input_mask[:,0:1,:,:]
        input_edge = input_edge * input_mask[:,0:1,:,:]
        input_mask = torch.cat([input_mask]*3, dim = 1)
        input_mask = input_mask[:,:3,:,:]


        h_dict = {}  # for the output of enc_N
        h_mask_dict = {}  # for the output of enc_N
        h_edge_list = []
        h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask
        edge = input_edge

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            if i not in [1, 2]:
                h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(h_dict[h_key_prev], h_mask_dict[h_key_prev])
            else:
                h_dict[h_key], h_mask_dict[h_key], edge = getattr(self, l_key)(h_dict[h_key_prev], h_mask_dict[h_key_prev], edge)
                h_edge_list.append(edge)
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]
        h = self.deconv(h)
        h_mask = F.interpolate(h_mask, scale_factor = 2)

        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim = 1)
            h = torch.cat([h, h_dict[enc_h_key]], dim = 1)
            if i not in [2, 1]:
                h, h_mask = getattr(self, dec_l_key)(h, h_mask)
            else:
                edge = h_edge_list[i-1]
                h, h_mask, edge = getattr(self, dec_l_key)(h, h_mask, edge)
                h_edge_list.append(edge)
            if i == 4:
                h = self.att(h)
        h_out = self.resolver(h)
        h_out = torch.cat([h_out, h], dim = 1)
        h_out = self.output(h_out)
        return h_out, h_mask, h_edge_list[-2], h_edge_list[-1]
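
The `PartialConv2d` class in this cell rescales each output by `slide_winsize / (update_mask + 1e-8)` and zeroes positions whose window contains no valid pixel. A 1-D, single-channel, bias-free sketch of that bookkeeping in plain Python. The function `partial_conv1d` is illustrative and not part of this notebook, it assumes an odd kernel length, and it uses the mask convention 1 = valid, 0 = hole:

```python
def partial_conv1d(x, mask, weight, eps=1e-8):
    """1-D partial convolution with zero padding, mirroring PartialConv2d.forward."""
    k = len(weight)
    pad = k // 2
    # the input is multiplied by the mask before the convolution
    xp = [0.0] * pad + [xi * mi for xi, mi in zip(x, mask)] + [0.0] * pad
    mp = [0.0] * pad + list(mask) + [0.0] * pad
    out, new_mask = [], []
    for i in range(len(x)):
        raw = sum(w * v for w, v in zip(weight, xp[i:i + k]))
        valid = sum(mp[i:i + k])             # update_mask before clamping
        ratio = k / (valid + eps)            # slide_winsize / (update_mask + 1e-8)
        clamped = min(max(valid, 0.0), 1.0)  # torch.clamp(update_mask, 0, 1)
        out.append(raw * ratio * clamped)    # zero wherever the window saw no valid pixel
        new_mask.append(clamped)
    return out, new_mask


out, new_mask = partial_conv1d(
    x=[1.0, 1.0, 1.0, 1.0],
    mask=[1.0, 1.0, 0.0, 0.0],
    weight=[1.0, 1.0, 1.0],
)
print(out)
print(new_mask)
```

With a constant input of 1 and an all-ones kernel, every position that sees at least one valid pixel comes out close to 3 (the full window sum), and the updated mask grows by one position into the hole.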


In [None]:
#@title [CSA_arch.py](https://github.com/Yukariin/CSA_pytorch) (2019)
"""
model.py (13-12-20)
https://github.com/Yukariin/CSA_pytorch/blob/master/model.py
"""
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

def get_norm(name, out_channels):
    if name == 'batch':
        norm = nn.BatchNorm2d(out_channels)
    elif name == 'instance':
        norm = nn.InstanceNorm2d(out_channels)
    else:
        norm = None
    return norm


def get_act(name):
    if name == 'relu':
        activation = nn.ReLU(inplace=True)
    elif name == 'elu':
        activation = nn.ELU(inplace=True)
    elif name == 'leaky_relu':
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    elif name == 'tanh':
        activation = nn.Tanh()
    elif name == 'sigmoid':
        activation = nn.Sigmoid()
    else:
        activation = None
    return activation


class CoarseEncodeBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class CoarseDecodeBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)


class CoarseNet(pl.LightningModule):
    def __init__(self, c_img=3,
                 norm='instance', act_en='leaky_relu', act_de='relu'):
        super().__init__()

        cnum = 64

        self.en_1 = nn.Conv2d(c_img, cnum, 4, 2, padding=1)
        self.en_2 = CoarseEncodeBlock(cnum, cnum*2, 4, 2, normalization=norm, activation=act_en)
        self.en_3 = CoarseEncodeBlock(cnum*2, cnum*4, 4, 2, normalization=norm, activation=act_en)
        self.en_4 = CoarseEncodeBlock(cnum*4, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_5 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_6 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_7 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_en)
        self.en_8 = CoarseEncodeBlock(cnum*8, cnum*8, 4, 2, activation=act_en)

        self.de_8 = CoarseDecodeBlock(cnum*8, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_7 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_6 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_5 = CoarseDecodeBlock(cnum*8*2, cnum*8, 4, 2, normalization=norm, activation=act_de)
        self.de_4 = CoarseDecodeBlock(cnum*8*2, cnum*4, 4, 2, normalization=norm, activation=act_de)
        self.de_3 = CoarseDecodeBlock(cnum*4*2, cnum*2, 4, 2, normalization=norm, activation=act_de)
        self.de_2 = CoarseDecodeBlock(cnum*2*2, cnum, 4, 2, normalization=norm, activation=act_de)
        self.de_1 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*2, c_img, 4, 2, padding=1),
            get_act('tanh'))

    def forward(self, x):
        out_1 = self.en_1(x)
        out_2 = self.en_2(out_1)
        out_3 = self.en_3(out_2)
        out_4 = self.en_4(out_3)
        out_5 = self.en_5(out_4)
        out_6 = self.en_6(out_5)
        out_7 = self.en_7(out_6)
        out_8 = self.en_8(out_7)

        dout_8 = self.de_8(out_8)
        dout_8_out_7 = torch.cat([dout_8, out_7], 1)
        dout_7 = self.de_7(dout_8_out_7)
        dout_7_out_6 = torch.cat([dout_7, out_6], 1)
        dout_6 = self.de_6(dout_7_out_6)
        dout_6_out_5 = torch.cat([dout_6, out_5], 1)
        dout_5 = self.de_5(dout_6_out_5)
        dout_5_out_4 = torch.cat([dout_5, out_4], 1)
        dout_4 = self.de_4(dout_5_out_4)
        dout_4_out_3 = torch.cat([dout_4, out_3], 1)
        dout_3 = self.de_3(dout_4_out_3)
        dout_3_out_2 = torch.cat([dout_3, out_2], 1)
        dout_2 = self.de_2(dout_3_out_2)
        dout_2_out_1 = torch.cat([dout_2, out_1], 1)
        dout_1 = self.de_1(dout_2_out_1)

        return dout_1


class RefineEncodeBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, in_channels, 4, 2, dilation=2, padding=3))
        if normalization:
            # The first conv keeps in_channels, so its norm layer must match that width.
            layers.append(get_norm(normalization, in_channels))

        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.Conv2d(in_channels, out_channels, 3, 1, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class RefineDecodeBlock(pl.LightningModule):
    def __init__(self, in_channels, out_channels,
                 normalization=None, activation=None):
        super().__init__()

        layers = []
        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(in_channels, out_channels, 3, 1, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))

        if activation:
            layers.append(get_act(activation))
        layers.append(
            nn.ConvTranspose2d(out_channels, out_channels, 4, 2, padding=1))
        if normalization:
            layers.append(get_norm(normalization, out_channels))
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)


class RefineNet(pl.LightningModule):
    def __init__(self, c_img=3,
                 norm='instance', act_en='leaky_relu', act_de='relu'):
        super().__init__()

        c_in = c_img + c_img
        cnum = 64

        self.en_1 = nn.Conv2d(c_in, cnum, 3, 1, padding=1)
        self.en_2 = RefineEncodeBlock(cnum, cnum*2, normalization=norm, activation=act_en)
        self.en_3 = RefineEncodeBlock(cnum*2, cnum*4, normalization=norm, activation=act_en)
        self.en_4 = RefineEncodeBlock(cnum*4, cnum*8, normalization=norm, activation=act_en)
        self.en_5 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_6 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_7 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_8 = RefineEncodeBlock(cnum*8, cnum*8, normalization=norm, activation=act_en)
        self.en_9 = nn.Sequential(
            get_act(act_en),
            nn.Conv2d(cnum*8, cnum*8, 4, 2, padding=1))

        self.de_9 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*8, cnum*8, 4, 2, padding=1),
            get_norm(norm, cnum*8))
        self.de_8 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_7 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_6 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_5 = RefineDecodeBlock(cnum*8*2, cnum*8, normalization=norm, activation=act_de)
        self.de_4 = RefineDecodeBlock(cnum*8*2, cnum*4, normalization=norm, activation=act_de)
        self.de_3 = RefineDecodeBlock(cnum*4*2, cnum*2, normalization=norm, activation=act_de)
        self.de_2 = RefineDecodeBlock(cnum*2*2, cnum, normalization=norm, activation=act_de)

        self.de_1 = nn.Sequential(
            get_act(act_de),
            nn.ConvTranspose2d(cnum*2, c_img, 3, 1, padding=1))


    def forward(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        out_1 = self.en_1(x)
        out_2 = self.en_2(out_1)
        out_3 = self.en_3(out_2)
        out_4 = self.en_4(out_3)
        out_5 = self.en_5(out_4)
        out_6 = self.en_6(out_5)
        out_7 = self.en_7(out_6)
        out_8 = self.en_8(out_7)
        out_9 = self.en_9(out_8)

        dout_9 = self.de_9(out_9)
        dout_9_out_8 = torch.cat([dout_9, out_8], 1)
        dout_8 = self.de_8(dout_9_out_8)
        dout_8_out_7 = torch.cat([dout_8, out_7], 1)
        dout_7 = self.de_7(dout_8_out_7)
        dout_7_out_6 = torch.cat([dout_7, out_6], 1)
        dout_6 = self.de_6(dout_7_out_6)
        dout_6_out_5 = torch.cat([dout_6, out_5], 1)
        dout_5 = self.de_5(dout_6_out_5)
        dout_5_out_4 = torch.cat([dout_5, out_4], 1)
        dout_4 = self.de_4(dout_5_out_4)
        dout_4_out_3 = torch.cat([dout_4, out_3], 1)
        dout_3 = self.de_3(dout_4_out_3)
        dout_3_out_2 = torch.cat([dout_3, out_2], 1)
        dout_2 = self.de_2(dout_3_out_2)
        dout_2_out_1 = torch.cat([dout_2, out_1], 1)
        dout_1 = self.de_1(dout_2_out_1)

        return dout_1, out_4, dout_5


class CSA(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        return x


class InpaintNet(pl.LightningModule):
    def __init__(self, c_img=3, norm='instance', act_en='leaky_relu', act_de='relu'):
        super().__init__()

        self.coarse = CoarseNet(c_img=c_img, norm=norm, act_en=act_en, act_de=act_de)
        self.refine = RefineNet(c_img=c_img, norm=norm, act_en=act_en, act_de=act_de)

    def forward(self, image, mask):
        out_c = self.coarse(image)
        out_c = image * (1. - mask) + out_c * mask
        out_r, csa, csa_d = self.refine(out_c, image)
        return out_c, out_r, csa, csa_d
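
Both `CoarseNet` and `RefineNet` halve the resolution eight times before decoding, and every decoder stage is concatenated with the mirrored encoder output, so the input size has to survive all eight halvings. The following is a minimal sketch of that bookkeeping for `CoarseNet`, assuming the output-size formulas documented for `nn.Conv2d` and `nn.ConvTranspose2d`. The helper names `conv_out`, `deconv_out`, `coarse_sizes` and `skips_align` are hypothetical and not part of this notebook.

```python
def conv_out(n, k, s, p, d=1):
    """Output length of nn.Conv2d along one axis (documented PyTorch formula)."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


def deconv_out(n, k, s, p, d=1, output_padding=0):
    """Output length of nn.ConvTranspose2d along one axis."""
    return (n - 1) * s - 2 * p + d * (k - 1) + output_padding + 1


def coarse_sizes(n):
    """Spatial size after each CoarseNet encoder stage, then after each decoder stage."""
    enc = []
    for _ in range(8):  # en_1 .. en_8: kernel 4, stride 2, padding 1
        n = conv_out(n, 4, 2, 1)
        enc.append(n)
    dec = []
    for _ in range(8):  # de_8 .. de_1: kernel 4, stride 2, padding 1
        n = deconv_out(n, 4, 2, 1)
        dec.append(n)
    return enc, dec


def skips_align(n):
    """True if every torch.cat([dout_k, out_{k-1}], 1) would see matching sizes."""
    enc, dec = coarse_sizes(n)
    if min(enc) < 1:  # the bottleneck collapsed to zero pixels
        return False
    return all(dec[i] == enc[6 - i] for i in range(7))


enc, dec = coarse_sizes(256)
print(enc)  # [128, 64, 32, 16, 8, 4, 2, 1]
print(dec)  # [2, 4, 8, 16, 32, 64, 128, 256]
print(skips_align(256), skips_align(300), skips_align(128))  # True False False
```

Under these assumptions, sizes such as 256 or 512 pass, 128 collapses to zero at the bottleneck, and a size like 300 produces mismatched skip connections. The dilated conv in `RefineEncodeBlock` (kernel 4, stride 2, dilation 2, padding 3) also halves even sizes: `conv_out(256, 4, 2, 3, d=2)` gives 128.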


In [None]:
#@title CSA_loss.py
"""
# needs to be merged with loss.py
loss.py (16-12-20)
https://github.com/Yukariin/CSA_pytorch/blob/master/loss.py
"""
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms


def denorm(x):
    out = (x + 1) / 2 # [-1,1] -> [0,1]
    return out.clamp_(0, 1)


class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()

        vgg16 = models.vgg16(pretrained=True)

        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])
        self.enc_4 = nn.Sequential(*vgg16.features[17:23])

        #print(self.enc_1)
        #print(self.enc_2)
        #print(self.enc_3)
        #print(self.enc_4)

        # fix the encoder
        for i in range(4):
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]


class ConsistencyLoss(nn.Module):
    def __init__(self):
        super().__init__()

        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        self.vgg = VGG16FeatureExtractor()
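        # Move VGG to the GPU only when one is present, so the module also builds on CPU/TPU runtimes.
        if torch.cuda.is_available():
            self.vgg.cuda()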

        self.l2 = nn.MSELoss()

    def forward(self, csa, csa_d, target, mask):
        # https://pytorch.org/docs/stable/torchvision/models.html
        # Pre-trained VGG16 model expect input images normalized in the same way.
        # The images have to be loaded in to a range of [0, 1]
        # and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
        t = denorm(target) # [-1,1] -> [0,1]
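        # Note: only the first sample of the batch (t[0]) is normalized and passed to VGG,
        # so the VGG target has batch size 1 and is broadcast against csa / csa_d.
        t = self.normalize(t[0]) # BxCxHxW -> CxHxW -> normalize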
        t = t.unsqueeze(0) # CxHxW -> BxCxHxW

        vgg_gt = self.vgg(t)
        vgg_gt = vgg_gt[-1]

        mask_r = F.interpolate(mask, size=csa.size()[2:])

        lossvalue = self.l2(csa*mask_r, vgg_gt*mask_r) + self.l2(csa_d*mask_r, vgg_gt*mask_r)
        return lossvalue
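
`ConsistencyLoss` only works because three tensors line up: the VGG16 features of the target, `csa` (the `out_4` of `RefineNet`) and `csa_d` (the `dout_5` of `RefineNet`). Here is a small sketch of that arithmetic in plain Python, assuming an input whose height and width are multiples of 256. `vgg_input_value` and `feature_shapes` are hypothetical helper names, and `cnum` is written as a parameter only for illustration (it is hard-coded to 64 in `RefineNet`).

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def vgg_input_value(x, channel):
    """Map one generator-range value in [-1, 1] to the value VGG16 sees for a channel."""
    x01 = min(max((x + 1) / 2, 0.0), 1.0)  # same as denorm(): rescale, then clamp
    return (x01 - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel]


def feature_shapes(h, w, cnum=64):
    """(channels, height, width) of the three tensors compared in ConsistencyLoss."""
    vgg_relu4_3 = (512, h // 8, w // 8)  # features[:23] contains three 2x2 max-pools
    csa = (cnum * 8, h // 8, w // 8)     # out_4: three stride-2 encode blocks after en_1
    csa_d = (cnum * 8, h // 8, w // 8)   # dout_5: the mirrored decoder stage
    return vgg_relu4_3, csa, csa_d


print(feature_shapes(256, 256))
# ((512, 32, 32), (512, 32, 32), (512, 32, 32))
print(vgg_input_value(-1.0, 0))  # black pixel, red channel: -0.485 / 0.229
```

This is also why the mask is resized with `F.interpolate(mask, size=csa.size()[2:])` before it is multiplied in: it has to be brought down to the same 1/8 resolution as the feature maps.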


Broken generators:

Generators that are not included here since I can't seem to make them work properly:

PenNet [no AMP] (2019): [researchmm/PEN-Net-for-Inpainting](https://github.com/researchmm/PEN-Net-for-Inpainting/)
- Always outputs white for some reason.

CRA [no AMP] (2019): [wangyx240/High-Resolution-Image-Inpainting-GAN](https://github.com/wangyx240/High-Resolution-Image-Inpainting-GAN)
- Tends to tint the output pink.

Global [no AMP] (2020): [SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting](https://github.com/SayedNadim/Global-and-Local-Attention-Based-Free-Form-Image-Inpainting)
- Always outputs white for some reason.

crfill (2020): [zengxianyu/crfill](https://github.com/zengxianyu/crfill)
- The instructions and code are unclear, which leads to broken results. The training code is unreleased, which makes a correct implementation harder.

---------------------------

Non-PyTorch generators:

co-mod-gan (2021): [zsyzzsoft/co-mod-gan](https://github.com/zsyzzsoft/co-mod-gan)
- Has a web demo and a (broken) link to a Docker image. Relies on TensorFlow / StyleGAN2 code.

Diverse-Structure-Inpainting (2021): [USTC-JialunPeng/Diverse-Structure-Inpainting](https://github.com/USTC-JialunPeng/Diverse-Structure-Inpainting)
- Uses TensorFlow 1.

R-MNet (2021): [Jireh-Jam/R-MNet-Inpainting-keras](https://github.com/Jireh-Jam/R-MNet-Inpainting-keras)
- Not sure whether it offers much that is new or interesting.


Hypergraphs (2021): [GouravWadhwa/Hypergraphs-Image-Inpainting](https://github.com/GouravWadhwa/Hypergraphs-Image-Inpainting)
- Uses a custom conv layer (implemented in TensorFlow). It sounds interesting, but I got errors when I tried to port it to PyTorch.

PEPSI (2019): [Forty-lock/PEPSI-Fast_image_inpainting_with_parallel_decoding_network](https://github.com/Forty-lock/PEPSI-Fast_image_inpainting_with_parallel_decoding_network)
- The network that dcpV2 uses.

Region (2019): [vickyFox/Region-wise-Inpainting](https://github.com/vickyFox/Region-wise-Inpainting)

---------------------------

PyTorch generators that I never tested:

VCNET (2020): [birdortyedi/vcnet-blind-image-inpainting](https://github.com/birdortyedi/vcnet-blind-image-inpainting)
- Blind image inpainting without masks.

DMFA (2020): [mprzewie/dmfa_inpainting](https://github.com/mprzewie/dmfa_inpainting)


GIN (2020): [rlct1/gin-sg](https://github.com/rlct1/gin-sg) and [rlct1/gin](https://github.com/rlct1/gin)


StructureFlow (2019): [RenYurui/StructureFlow](https://github.com/RenYurui/StructureFlow)
- Needs special files.

GMCNN (2018): [shepnerd/inpainting_gmcnn](https://github.com/shepnerd/inpainting_gmcnn)
- The network that dcpV1 used, if I recall correctly.

ShiftNet (2018): [Zhaoyi-Yan/Shift-Net_pytorch](https://github.com/Zhaoyi-Yan/Shift-Net_pytorch)

---------------------------

No training code:

SC-FEGAN (2019): [run-youngjoo/SC-FEGAN](https://github.com/run-youngjoo/SC-FEGAN)

# Upscale Generators

In [None]:
#@title ESRGAN_arch.py (dataloader not implemented)
"""
RRDBNet_arch.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/RRDBNet_arch.py
"""
import math
import torch
import torch.nn as nn
#import torchvision
#from . import block as B
import functools
#from . import spectral_norm as SN


####################
# RRDBNet Generator (original architecture)
####################

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', \
            finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=gc, stride=1, bias=1, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, \
            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        # Note: this option adds new parameters to the architecture, another option is to use "outm" in the forward
        outact = act(finalact) if finalact else None
        
        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        x = self.model(x)
        
        if outm=='scaltanh': # limit output range to [-1,1] range with tanh and rescale to [0,1] Idea from: https://github.com/goldhuang/SRGAN-PyTorch/blob/master/model.py
            return (torch.tanh(x) + 1.0) / 2.0
        elif outm=='tanh': # limit output to [-1,1] range
            return torch.tanh(x)
        elif outm=='sigmoid': # limit output to [0,1] range
            return torch.sigmoid(x)
        elif outm=='clamp':
            return torch.clamp(x, min=0.0, max=1.0)
        else: #Default, no cap for the output
            return x

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()
        # This is for backwards compatibility with existing models
        if nr == 3:
            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
        else:
            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
            self.RDBs = nn.Sequential(*RDB_list)

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            out = self.RDB1(x)
            out = self.RDB2(out)
            out = self.RDB3(out)
        else:
            out = self.RDBs(x)
        return out * 0.2 + x

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. 
            {Rakotonirina} and A. {Rasoanaivo}
    
    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel, 
            i.e. intermediate channels).
        convtype (str): the type of convolution to use. Default: 'Conv2D'
        gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new 
            trainable parameters)
        plus (bool): enable the additional residual paths from ESRGAN+ 
            (adds trainable parameters)
    '''

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()
        
        ## +
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        ## +

        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            x2 = x2 + self.conv1x1(x) #+
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2 #+
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        if self.noise:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x


####################
# RRDBNet Generator (modified/"new" architecture)
####################


class MRRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(MRRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDBM, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out

class ResidualDenseBlock_5CM(nn.Module):
    '''
    Residual Dense Block
    '''
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5CM, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], scale=0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDBM(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDBM, self).__init__()
        self.RDB1 = ResidualDenseBlock_5CM(nf, gc)
        self.RDB2 = ResidualDenseBlock_5CM(nf, gc)
        self.RDB3 = ResidualDenseBlock_5CM(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
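
The dense blocks above follow a simple channel rule: each conv sees the block input plus every earlier conv output, and the last conv maps everything back to `nf` channels before the result is scaled by 0.2 and added to the input. A minimal sketch of that arithmetic in plain Python follows. `dense_block_channels` and `residual_scale` are hypothetical helper names, and `beta` is just a label for the constant 0.2 used in the code.

```python
def dense_block_channels(nf=64, gc=32):
    """(in_channels, out_channels) for conv1..conv5 of a five-conv residual dense block."""
    pairs = [(nf + i * gc, gc) for i in range(4)]  # conv1..conv4 each emit gc channels
    pairs.append((nf + 4 * gc, nf))                # conv5 fuses everything back to nf
    return pairs


def residual_scale(branch, skip, beta=0.2):
    """Element-wise `branch * beta + skip`, the scaling used by the RDB and the RRDB."""
    return [b * beta + s for b, s in zip(branch, skip)]


print(dense_block_channels())
# [(64, 32), (96, 32), (128, 32), (160, 32), (192, 64)]
print(residual_scale([10.0, -5.0], [1.0, 1.0]))
```

Because the block ends with `nf` channels again, any number of these blocks can be chained, which is what `RRDB` and `RRDBM` do with three of them.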


# Other Generators

In [None]:
#@title [Deoldify](https://github.com/alfagao/DeOldify) (2018) (dataloader not implemented)

"""
torch_imports.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/torch_imports.py

conv_learner.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/conv_learner.py

model.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/model.py

modules.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fasterai/modules.py

generators.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fasterai/generators.py
"""

import math
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.nn.utils.spectral_norm import spectral_norm
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
import pytorch_lightning as pl
from torchvision.models import vgg16_bn, vgg19_bn

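# fastai's `children` helper is not defined in this cell, so list the submodules directly.
def vgg16(pre): return list(vgg16_bn(pre).children())[0]
def vgg19(pre): return list(vgg19_bn(pre).children())[0]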

def cut_model(m, cut):
    return list(m.children())[:cut] if cut else [m]
"""
model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg16:[0,22], vgg19:[0,22],
    resnext50:[8,6], resnext101:[8,6], resnext101_64:[8,6],
    wrn:[8,6], inceptionresnet_2:[-2,9], inception_4:[-1,9],
    dn121:[0,7], dn161:[0,7], dn169:[0,7], dn201:[0,7],
}
"""

model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg16:[0,22], vgg19:[0,22],
}



class ConvBlock(pl.LightningModule):
    def __init__(self, ni:int, no:int, ks:int=3, stride:int=1, pad:int=None, actn:bool=True, 
            bn:bool=True, bias:bool=True, sn:bool=False, leakyReLu:bool=False, self_attention:bool=False,
            inplace_relu:bool=True):
        super().__init__()   
        if pad is None: pad = ks//2//stride

        if sn:
            layers = [spectral_norm(nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias))]
        else:
            layers = [nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias)]
        if actn:
            layers.append(nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu))
        if bn:
            layers.append(nn.BatchNorm2d(no))
        if self_attention:
            layers.append(SelfAttention(no, 1))

        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)


class UpSampleBlock(pl.LightningModule):
    @staticmethod
    def _conv(ni:int, nf:int, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        layers = [ConvBlock(ni, nf, ks=ks, sn=sn, bn=bn, actn=False, leakyReLu=leakyReLu)]
        return nn.Sequential(*layers)

    @staticmethod
    def _icnr(x:torch.Tensor, scale:int=2):
        init=nn.init.kaiming_normal_
        new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
        subkernel = torch.zeros(new_shape)
        subkernel = init(subkernel)
        subkernel = subkernel.transpose(0, 1)
        subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                                subkernel.shape[1], -1)
        kernel = subkernel.repeat(1, 1, scale ** 2)
        transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
        kernel = kernel.contiguous().view(transposed_shape)
        kernel = kernel.transpose(0, 1)
        return kernel

    def __init__(self, ni:int, nf:int, scale:int=2, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        super().__init__()
        layers = []
        assert (math.log(scale,2)).is_integer()

        for i in range(int(math.log(scale,2))):
            layers += [UpSampleBlock._conv(ni, nf*4,ks=ks, bn=bn, sn=sn, leakyReLu=leakyReLu), 
                nn.PixelShuffle(2)]
            if bn:
                layers += [nn.BatchNorm2d(nf)]

            ni = nf
                       
        self.sequence = nn.Sequential(*layers)
        self._icnr_init()
        
    def _icnr_init(self):
        conv_shuffle = self.sequence[0][0].seq[0]
        kernel = UpSampleBlock._icnr(conv_shuffle.weight)
        conv_shuffle.weight.data.copy_(kernel)
    
    def forward(self, x):
        return self.sequence(x)


class UnetBlock(pl.LightningModule):
    def __init__(self, up_in:int , x_in:int , n_out:int, bn:bool=True, sn:bool=False, leakyReLu:bool=False, 
            self_attention:bool=False, inplace_relu:bool=True):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = ConvBlock(x_in,  x_out,  ks=1, bn=False, actn=False, sn=sn, inplace_relu=inplace_relu)
        self.tr_conv = UpSampleBlock(up_in, up_out, 2, bn=bn, sn=sn, leakyReLu=leakyReLu)
        self.relu = nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu)
        out_layers = []
        if bn: 
            out_layers.append(nn.BatchNorm2d(n_out))
        if self_attention:
            out_layers.append(SelfAttention(n_out))
        self.out = nn.Sequential(*out_layers)
        
        
    def forward(self, up_p:torch.Tensor, x_p:torch.Tensor):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        x = torch.cat([up_p,x_p], dim=1)
        x = self.relu(x)
        return self.out(x)

class SaveFeatures():
    features=None
    def __init__(self, m:pl.LightningModule): 
        self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): 
        self.features = output
    def remove(self): 
        self.hook.remove()

class SelfAttention(pl.LightningModule):
    def __init__(self, in_channel:int, gain:int=1):
        super().__init__()
        self.query = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
        self.key = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
        self.value = self._spectral_init(nn.Conv1d(in_channel, in_channel, 1), gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def _spectral_init(self, module:nn.Module, gain:int=1):
        nn.init.kaiming_uniform_(module.weight, gain)
        if module.bias is not None:
            module.bias.data.zero_()

        return spectral_norm(module)

    def forward(self, input:torch.Tensor):
        shape = input.shape
        flatten = input.view(shape[0], shape[1], -1)
        query = self.query(flatten).permute(0, 2, 1)
        key = self.key(flatten)
        value = self.value(flatten)
        query_key = torch.bmm(query, key)
        attn = F.softmax(query_key, 1)
        attn = torch.bmm(value, attn)
        attn = attn.view(*shape)
        out = self.gamma * attn + input
        return out


class GeneratorModule(ABC, nn.Module):
    def __init__(self):
        super().__init__()
    
    def set_trainable(self, trainable:bool):
        set_trainable(self, trainable)

    @abstractmethod
    def get_layer_groups(self, precompute:bool=False)->[]:
        pass

    @abstractmethod
    def forward(self, x_in:torch.Tensor, max_render_sz:int=400):
        pass
        
    def freeze_to(self, n:int):
        c=self.get_layer_groups()
        for l in c:     set_trainable(l, False)
        for l in c[n:]: set_trainable(l, True)

    def get_device(self):
        return next(self.parameters()).device


class AbstractUnet(GeneratorModule): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__()
        assert (math.log(scale,2)).is_integer()
        self.rn, self.lr_cut = self._get_pretrained_resnet_base()
        ups = self._get_decoding_layers(nf_factor=nf_factor, scale=scale)
        self.relu = nn.ReLU()
        self.up1 = ups[0]
        self.up2 = ups[1]
        self.up3 = ups[2]
        self.up4 = ups[3]
        self.up5 = ups[4]
        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=True), nn.Tanh())

    @abstractmethod
    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        pass

    @abstractmethod
    def _get_decoding_layers(self, nf_factor:int, scale:int):
        pass

    #Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x:torch.Tensor, target:torch.Tensor, total_padh:int, total_padw:int)-> torch.Tensor:
        h = x.shape[2] 
        w = x.shape[3]

        target_h = target.shape[2]*2
        target_w = target.shape[3]*2

        if h<target_h or w<target_w:
            padh = target_h-h if target_h > h else 0
            total_padh = total_padh + padh
            padw = target_w-w if target_w > w else 0
            total_padw = total_padw + padw
            return (F.pad(x, (0,padw,0,padh), "reflect",0), total_padh, total_padw)

        return (x, total_padh, total_padw)

    def _remove_padding(self, x:torch.Tensor, padh:int, padw:int)->torch.Tensor:
        if padw == 0 and padh == 0:
            return x 
        
        target_h = x.shape[2]-padh
        target_w = x.shape[3]-padw
        return x[:,:,:target_h, :target_w]

    def _encode(self, x:torch.Tensor):
        x = self.rn[0](x)
        x = self.rn[1](x)
        x = self.rn[2](x)
        enc0 = x
        x = self.rn[3](x)
        x = self.rn[4](x)
        enc1 = x
        x = self.rn[5](x)
        enc2 = x
        x = self.rn[6](x)
        enc3 = x
        x = self.rn[7](x)
        return (x, enc0, enc1, enc2, enc3)

    def _decode(self, x:torch.Tensor, enc0:torch.Tensor, enc1:torch.Tensor, enc2:torch.Tensor, enc3:torch.Tensor):
        padh = 0
        padw = 0
        x = self.relu(x)
        enc3, padh, padw = self._pad(enc3, x, padh, padw)
        x = self.up1(x, enc3)
        enc2, padh, padw  = self._pad(enc2, x, padh, padw)
        x = self.up2(x, enc2)
        enc1, padh, padw  = self._pad(enc1, x, padh, padw)
        x = self.up3(x, enc1)
        enc0, padh, padw  = self._pad(enc0, x, padh, padw)
        x = self.up4(x, enc0)
        #This is a bit too much padding being removed, but I 
        #haven't yet figured out a good way to determine what 
        #exactly should be removed.  This is consistently more 
        #than enough though.
        x = self.up5(x)
        x = self.out(x)
        x = self._remove_padding(x, padh, padw)
        return x

    def forward(self, x:torch.Tensor):
        x, enc0, enc1, enc2, enc3 = self._encode(x)
        x = self._decode(x, enc0, enc1, enc2, enc3)
        return x
    
    def get_layer_groups(self, precompute:bool=False)->[]:
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        return lgs + [children(self)[1:]]
    
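    def close(self):
        # self.sfs (saved-feature hooks) is never set in this class, so fall back to an empty list
        for sf in getattr(self, 'sfs', []): 
            sf.remove()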


class Unet34(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet34
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 


class Unet101(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet101
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 

class Unet152(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet152
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 


# Experimental Generators
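
The deformable-convolution helpers in the DSNetRRDB cell (`th_map_coordinates` and its batched variant) sample a feature map at fractional coordinates with bilinear interpolation. Here is a minimal pure-Python sketch of that idea for a single point, just to make the corner and offset logic easier to follow. The function name `bilinear_sample` and the list-of-lists input are assumptions made for illustration and are not part of the notebook code.

```python
import math


def bilinear_sample(grid, row, col):
    """Sample a 2D list-of-lists `grid` at a fractional (row, col) position.

    Hypothetical helper for illustration only; it mirrors the order of
    operations used by th_map_coordinates (interpolate along the first
    axis, then along the second).
    """
    n_rows, n_cols = len(grid), len(grid[0])

    # Clamp the coordinates so that out-of-range points use the border values.
    row = min(max(row, 0.0), n_rows - 1)
    col = min(max(col, 0.0), n_cols - 1)

    r0, c0 = math.floor(row), math.floor(col)  # "lt" corner
    r1, c1 = math.ceil(row), math.ceil(col)    # "rb" corner
    dr, dc = row - r0, col - c0                # fractional offsets from "lt"

    # Interpolate along the first axis at both columns ...
    at_c0 = grid[r0][c0] + (grid[r1][c0] - grid[r0][c0]) * dr
    at_c1 = grid[r0][c1] + (grid[r1][c1] - grid[r0][c1]) * dr
    # ... then blend the two results along the second axis.
    return at_c0 + (at_c1 - at_c0) * dc


grid = [[0.0, 1.0],
        [2.0, 3.0]]
center = bilinear_sample(grid, 0.5, 0.5)    # halfway between all four corners
clamped = bilinear_sample(grid, 5.0, -3.0)  # clamped to the bottom-left corner
```

The tensor versions do the same thing for many points at once by gathering the four corner values with `index_select` instead of indexing one element at a time.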

In [None]:
#@title DSNetRRDB
#@markdown Combining DSNet and RRDB, where DSNet is the first stage and RRDB is the second stage.
"""
DSNet.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/DSNet.py

RegionNorm.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/RegionNorm.py

ValidMigration.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/ValidMigration.py

Attention.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/Attention.py

deform_conv.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/deform_conv.py
"""
#from modules.Attention import PixelContextualAttention
#from modules.RegionNorm import RBNModule, RCNModule
#from modules.ValidMigration import ConvOffset2D
#from modules.deform_conv import th_batch_map_offsets, th_generate_grid
from __future__ import absolute_import, division
# scipy.ndimage.interpolation is a deprecated namespace; import from scipy.ndimage directly
from scipy.ndimage import map_coordinates as sp_map_coordinates
from torch.autograd import Variable
from torchvision import models
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


def th_flatten(a):
    """Flatten tensor"""
    return a.contiguous().view(a.nelement())


def th_repeat(a, repeats, axis=0):
    """Torch version of np.repeat for 1D"""
    assert len(a.size()) == 1
    return th_flatten(torch.transpose(a.repeat(repeats, 1), 0, 1))


def np_repeat_2d(a, repeats):
    """NumPy helper: tile a 2D array `repeats` times along a new leading axis"""

    assert len(a.shape) == 2
    a = np.expand_dims(a, 0)
    a = np.tile(a, [repeats, 1, 1])
    return a


def th_gather_2d(input, coords):
    inds = coords[:, 0]*input.size(1) + coords[:, 1]
    x = torch.index_select(th_flatten(input), 0, inds)
    return x.view(coords.size(0))


def th_map_coordinates(input, coords, order=1):
    """PyTorch version of scipy.ndimage.map_coordinates
    Note that coords is transposed and only 2D is supported
    Parameters
    ----------
    input : torch.Tensor. shape = (s, s)
    coords : torch.Tensor. shape = (n_points, 2)
    """

    assert order == 1
    input_size = input.size(0)

    coords = torch.clamp(coords, 0, input_size - 1)
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1)
    coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1)

    vals_lt = th_gather_2d(input,  coords_lt.detach())
    vals_rb = th_gather_2d(input,  coords_rb.detach())
    vals_lb = th_gather_2d(input,  coords_lb.detach())
    vals_rt = th_gather_2d(input,  coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())

    vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0]
    vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0]
    mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1]
    return mapped_vals


def sp_batch_map_coordinates(inputs, coords):
    """Reference implementation for batch_map_coordinates"""
    # coords = coords.clip(0, inputs.shape[1] - 1)

    assert (coords.shape[2] == 2)
    height = coords[:,:,0].clip(0, inputs.shape[1] - 1)
    width = coords[:,:,1].clip(0, inputs.shape[2] - 1)
    # Keep the clipped coordinates (the result of np.concatenate was previously discarded)
    coords = np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(width, axis=2)), 2)

    mapped_vals = np.array([
        sp_map_coordinates(input, coord.T, mode='nearest', order=1)
        for input, coord in zip(inputs, coords)
    ])
    return mapped_vals


def th_batch_map_coordinates(input, coords, order=1):
    """Batch version of th_map_coordinates
    Only supports 2D feature maps
    Parameters
    ----------
    input : torch.Tensor. shape = (b, s, s)
    coords : torch.Tensor. shape = (b, n_points, 2)
    Returns
    -------
    torch.Tensor. shape = (b, n_points)
    """

    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    n_coords = coords.size(1)

    # coords = torch.clamp(coords, 0, input_size - 1)

    coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)

    assert (coords.size(1) == n_coords)

    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
    coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)
    # Build the batch index directly on the input's device (no hard-coded CUDA, no deprecated Variable)
    idx = th_repeat(torch.arange(0, batch_size, device=input.device), n_coords).long()

    def _get_vals_by_coords(input, coords):
        indices = torch.stack([
            idx, th_flatten(coords[..., 0]), th_flatten(coords[..., 1])
        ], 1)
        inds = indices[:, 0]*input.size(1)*input.size(2)+ indices[:, 1]*input.size(2) + indices[:, 2]
        vals = th_flatten(input).index_select(0, inds)
        vals = vals.view(batch_size, n_coords)
        return vals

    vals_lt = _get_vals_by_coords(input, coords_lt.detach())
    vals_rb = _get_vals_by_coords(input, coords_rb.detach())
    vals_lb = _get_vals_by_coords(input, coords_lb.detach())
    vals_rt = _get_vals_by_coords(input, coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())
    vals_t = coords_offset_lt[..., 0]*(vals_rt - vals_lt) + vals_lt
    vals_b = coords_offset_lt[..., 0]*(vals_rb - vals_lb) + vals_lb
    mapped_vals = coords_offset_lt[..., 1]* (vals_b - vals_t) + vals_t
    return mapped_vals


def sp_batch_map_offsets(input, offsets):
    """Reference implementation for th_batch_map_offsets"""

    batch_size = input.shape[0]
    input_height = input.shape[1]
    input_width = input.shape[2]

    offsets = offsets.reshape(batch_size, -1, 2)
    grid = np.stack(np.mgrid[:input_height, :input_width], -1).reshape(-1, 2)
    grid = np.repeat([grid], batch_size, axis=0)
    coords = offsets + grid
    # coords = coords.clip(0, input_size - 1)

    mapped_vals = sp_batch_map_coordinates(input, coords)
    return mapped_vals


def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
    grid = np.meshgrid(
        range(input_height), range(input_width), indexing='ij'
    )
    grid = np.stack(grid, axis=-1)
    grid = grid.reshape(-1, 2)

    grid = np_repeat_2d(grid, batch_size)
    grid = torch.from_numpy(grid).type(dtype)
    if cuda:
        grid = grid.cuda()
    return Variable(grid, requires_grad=False)


def th_batch_map_offsets(input, offsets, grid=None, order=1):
    """Batch map offsets into input
    Parameters
    ---------
    input : torch.Tensor. shape = (b, s, s)
    offsets: torch.Tensor. shape = (b, s, s, 2)
    Returns
    -------
    torch.Tensor. shape = (b, s, s)
    """
    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    offsets = offsets.view(batch_size, -1, 2)
    if grid is None:
        grid = th_generate_grid(batch_size, input_height, input_width, offsets.data.type(), offsets.data.is_cuda)

    coords = offsets + grid

    mapped_vals = th_batch_map_coordinates(input, coords)
    return mapped_vals


class SEModule(pl.LightningModule):
    def __init__(self, num_channel, squeeze_ratio=1.0):
        super(SEModule, self).__init__()
        self.sequeeze_mod = nn.AdaptiveAvgPool2d(1)
        self.num_channel = num_channel

        blocks = [nn.Linear(num_channel, int(num_channel * squeeze_ratio)),
                  nn.ReLU(),
                  nn.Linear(int(num_channel * squeeze_ratio), num_channel),
                  nn.Sigmoid()]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        ori = x
        x = self.sequeeze_mod(x)
        x = x.view(x.size(0), 1, self.num_channel)
        x = self.blocks(x)
        x = x.view(x.size(0), self.num_channel, 1, 1)
        x = ori * x
        return x


class ContextualAttentionModule(pl.LightningModule):

    def __init__(self, patch_size=3, propagate_size=3, stride=1):
        super(ContextualAttentionModule, self).__init__()
        self.patch_size = patch_size
        self.propagate_size = propagate_size
        self.stride = stride
        self.prop_kernels = None

    def forward(self, foreground, masks):
        ###assume the masked area has value 1
        bz, nc, w, h = foreground.size()
        if masks.size(3) != foreground.size(3):
            masks = F.interpolate(masks, foreground.size()[2:])
        background = foreground.clone()
        background = background * masks
        background = F.pad(background,
                           [self.patch_size // 2, self.patch_size // 2, self.patch_size // 2, self.patch_size // 2])
        conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size,
                                                                                     self.stride).contiguous().view(bz,
                                                                                                                    nc,
                                                                                                                    -1,
                                                                                                                    self.patch_size,
                                                                                                                    self.patch_size)
        conv_kernels_all = conv_kernels_all.transpose(2, 1)
        output_tensor = []
        for i in range(bz):
            mask = masks[i:i + 1]
            feature_map = foreground[i:i + 1].contiguous()
            # form convolutional kernels
            conv_kernels = conv_kernels_all[i] + 0.0000001
            norm_factor = torch.sum(conv_kernels ** 2, [1, 2, 3], keepdim=True) ** 0.5
            conv_kernels = conv_kernels / norm_factor

            conv_result = F.conv2d(feature_map, conv_kernels, padding=self.patch_size // 2)
            """
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size])
                    self.prop_kernels.requires_grad = False
                    self.prop_kernels = self.prop_kernels.cuda()
                conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))

            """

            # Create the propagation kernel on the same device/dtype as the scores instead of calling .cuda()
            self.prop_kernels = torch.ones(
                [conv_result.size(1), 1, self.propagate_size, self.propagate_size],
                device=conv_result.device, dtype=conv_result.dtype)
            conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))
            
            attention_scores = F.softmax(conv_result, dim=1)
            ##propagate the scores
            recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride=1,
                                                      padding=self.patch_size // 2)
            # average the recovered value, at the same time make non-masked area 0
            recovered_foreground = (recovered_foreground * (1 - mask)) / (self.patch_size ** 2)
            # recover the image
            final_output = recovered_foreground + feature_map * mask
            output_tensor.append(final_output)
        return torch.cat(output_tensor, dim=0)


class PixelContextualAttention(pl.LightningModule):

    def __init__(self, inchannel, patch_size_list=[1], propagate_size_list=[3], stride_list=[1]):
        assert isinstance(patch_size_list,
                          list), "patch_size should be a list containing scales, or you should use Contextual Attention to initialize your module"
        assert len(patch_size_list) == len(propagate_size_list) and len(propagate_size_list) == len(
            stride_list), "the input_lists should have same lengths"
        super(PixelContextualAttention, self).__init__()
        for i in range(len(patch_size_list)):
            name = "CA_{:d}".format(i)
            setattr(self, name, ContextualAttentionModule(patch_size_list[i], propagate_size_list[i], stride_list[i]))
        self.num_of_modules = len(patch_size_list)
        self.SqueezeExc = SEModule(inchannel * 2)
        self.combiner = nn.Conv2d(inchannel * 2, inchannel, kernel_size=1)

    def forward(self, foreground, mask):
        outputs = [foreground]
        for i in range(self.num_of_modules):
            name = "CA_{:d}".format(i)
            CA_module = getattr(self, name)
            outputs.append(CA_module(foreground, mask))
        outputs = torch.cat(outputs, dim=1)
        outputs = self.SqueezeExc(outputs)
        outputs = self.combiner(outputs)
        return outputs




class ConvOffset2D(nn.Conv2d):
    """ConvOffset2D

    Convolutional layer responsible for learning the 2D offsets and output the
    deformed feature map using bilinear interpolation

    Note that this layer does not perform convolution on the deformed feature
    map. See get_deform_cnn in cnn.py for usage
    """
    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Init

        Parameters
        ----------
        filters : int
            Number of channels of the input feature map
        init_normal_stddev : float
            Normal kernel initialization
        **kwargs:
            Pass to superclass. See Conv2d layer in pytorch
        """
        self.filters = filters
        self._grid_param = None
        super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs)
        self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))

    def forward(self, x):
        """Return the deformed feature map"""
        x_shape = x.size()
        offsets = super(ConvOffset2D, self).forward(x)

        # offsets: (b*c, h, w, 2)
        offsets = self._to_bc_h_w_2(offsets, x_shape)

        # x: (b*c, h, w)
        x = self._to_bc_h_w(x, x_shape)

        # x_offset: (b*c, h, w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(self,x))

        # x_offset: (b, c, h, w)
        x_offset = self._to_b_c_h_w(x_offset, x_shape)

        return x_offset

    @staticmethod
    def _get_grid(self, x):
        batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
        dtype, cuda = x.data.type(), x.data.is_cuda
        if self._grid_param == (batch_size, input_height, input_width, dtype, cuda):
            return self._grid
        self._grid_param = (batch_size, input_height, input_width, dtype, cuda)
        self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        fan_out = weights.size(0)
        fan_in = weights.size(1) * weights.size(2) * weights.size(3)
        w = np.random.normal(0.0, std, (fan_out, fan_in))
        return torch.from_numpy(w.reshape(weights.size()))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2)
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) -> (b, c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))
        return x





class RBNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RBNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t
        if self.training:
            mask_mean = torch.mean(mask_t, (0, 2, 3), True)
            x_mean = torch.mean(input_m, (0, 2, 3), True) / mask_mean
            x_var = torch.mean(((input_m - x_mean) * mask_t) ** 2, (0, 2, 3), True) / mask_mean

            x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * x_mean.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * x_var.data)
        else:
            x_out = self.weight * (input_m - self.running_mean) / torch.sqrt(self.running_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class RCNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RCNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t

        if self.training:
            mask_mean_bn = torch.mean(mask_t, (0, 2, 3), True)
            mean_bn = torch.mean(input_m, (0, 2, 3), True) / mask_mean_bn
            var_bn = torch.mean(((input_m - mean_bn) * mask_t) ** 2, (0, 2, 3), True) / mask_mean_bn

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * mean_bn.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * var_bn.data)
        else:
            mean_bn = torch.autograd.Variable(self.running_mean)
            var_bn = torch.autograd.Variable(self.running_var)

        mask_mean_in = torch.mean(mask_t, (2, 3), True)
        mean_in = torch.mean(input_m, (2, 3), True) / mask_mean_in
        var_in = torch.mean(((input_m - mean_in) * mask_t) ** 2, (2, 3), True) / mask_mean_in

        mask_mean_ln = torch.mean(mask_t, (1, 2, 3), True)
        mean_ln = torch.mean(input_m, (1, 2, 3), True) / mask_mean_ln
        var_ln = torch.mean(((input_m - mean_ln) * mask_t) ** 2, (1, 2, 3), True) / mask_mean_ln

        # Pass dim explicitly; the implicit-dim form of F.softmax is deprecated
        mean_weight = F.softmax(self.mean_weight, dim=0)
        var_weight = F.softmax(self.var_weight, dim=0)

        x_mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        x_var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class DSModule(pl.LightningModule):
    def __init__(self, in_ch, out_ch, bn=False, rn=True, sample='none-3', activ='relu',
                 conv_bias=False, defor=True):
        super().__init__()
        if sample == 'down-5':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 5, 2, 2, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(5,2,2)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-7':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 7, 2, 3, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(7, 2, 3)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-3':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 3, 2, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3, 2, 1)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        else:
            self.conv = nn.Conv2d(in_ch+2, out_ch, 3, 1, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3,1,1)
            if defor:
                self.offset0 = ConvOffset2D(in_ch-out_ch+1)
                self.offset1 = ConvOffset2D(out_ch+1)
        self.in_ch = in_ch
        self.out_ch = out_ch

        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if rn:
            # Regional Composite Normalization
            self.rn = RCNModule(out_ch)

            # Regional Batch Normalization
            # self.rn = RBNModule(out_ch)
        if activ == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2, inplace = True)

    def forward(self, input, input_mask):
        if hasattr(self, 'offset'):
            input = torch.cat([input, input_mask[:,:1,:,:]], dim = 1)
            h = self.offset(input)
            h = input*input_mask[:,:1,:,:] + (1-input_mask[:,:1,:,:])*h
            h = self.conv(h)
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask
            h = self.rn(h, h_mask)
        elif hasattr(self, 'offset0'):
            h1_in = torch.cat([input[:,self.in_ch-self.out_ch:,:,:], input_mask[:,1:,:,:]], dim = 1)
            m1_in = input_mask[:,1:,:,:]
            h0 = torch.cat([input[:,:self.in_ch-self.out_ch,:,:], input_mask[:,:1,:,:]], dim = 1)
            h1 = self.offset1(h1_in)
            h1 = m1_in*h1_in + (1-m1_in)*h1
            h = self.conv(torch.cat([h0,h1], dim = 1))
            h = self.rn(h, input_mask[:,:1,:,:])
            h_mask = F.interpolate(input_mask[:,:1,:,:], scale_factor=2, mode='nearest')
        else:
            h = self.conv(torch.cat([input, input_mask[:,:,:,:]], dim = 1))
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask

        if hasattr(self, 'bn'):
            h = self.bn(h)
        if hasattr(self, 'activation'):
            h = self.activation(h)
        return h, h_mask


class DSNet(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, upsampling_mode='nearest'):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        self.enc_1 = DSModule(input_channels, 64, rn=False, sample='down-7', defor = False)
        self.enc_2 = DSModule(64, 128, sample='down-5')
        self.enc_3 = DSModule(128, 256, sample='down-5')
        self.enc_4 = DSModule(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512 + 512, 512, activ='leaky'))
        self.dec_4 = DSModule(512 + 256, 256, activ='leaky')
        self.dec_3 = DSModule(256 + 128, 128, activ='leaky')
        self.dec_2 = DSModule(128 + 64, 64, activ='leaky')
        self.dec_1 = DSModule(64 + input_channels, input_channels,
                              rn=False, activ=None, defor = False)
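        self.att = PixelContextualAttention(128)

    def forward(self, input, input_mask):
        # Cast to float32 on the model's own device instead of hard-coding CUDA
        device = next(self.parameters()).device
        input = input.to(device=device, dtype=torch.float32)
        input_mask = input_mask.to(device=device, dtype=torch.float32)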

        input_mask = input_mask[:,0:1,:,:]
        h_dict = {}  # for the output of enc_N
        h_mask_dict = {}  # for the output of enc_N

        h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
                h_dict[h_key_prev], h_mask_dict[h_key_prev])
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]
        h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')

        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)

            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)

            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
            h, h_mask = getattr(self, dec_l_key)(h, h_mask)
            if i == 3:
                h = self.att(h, input_mask[:,:1,:,:])
        #return h, h_mask
        return h


"""
block.py (9-3-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/block.py
"""

def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)

"""
RRDBNet_arch.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/RRDBNet_arch.py
"""
import math
import torch
import torch.nn as nn
#import torchvision
#from . import block as B
import functools
#from . import spectral_norm as SN


####################
# RRDBNet Generator (original architecture)
####################

class RRDBNet(pl.LightningModule):
    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', \
            finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        # Use the gc argument for the growth channels of every RRDB.
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=gc, stride=1, bias=1, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, \
            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        # Note: this option adds new parameters to the architecture, another option is to use "outm" in the forward
        outact = act(finalact) if finalact else None
        
        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        x = self.model(x)
        
        # 'scaltanh' limits the output to [-1, 1] with tanh and rescales it to [0, 1].
        # Idea from: https://github.com/goldhuang/SRGAN-PyTorch/blob/master/model.py
        if outm == 'scaltanh':
            return (torch.tanh(x) + 1.0) / 2.0
        elif outm=='tanh': # limit output to [-1,1] range
            return torch.tanh(x)
        elif outm=='sigmoid': # limit output to [0,1] range
            return torch.sigmoid(x)
        elif outm=='clamp':
            return torch.clamp(x, min=0.0, max=1.0)
        else: #Default, no cap for the output
            return x

class RRDB(pl.LightningModule):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()
        # This is for backwards compatibility with existing models
        if nr == 3:
            self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
            self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
            self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, \
                    norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, \
                    gaussian_noise=gaussian_noise, plus=plus)
        else:
            RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
                                              norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
                                              gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
            self.RDBs = nn.Sequential(*RDB_list)

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            out = self.RDB1(x)
            out = self.RDB2(out)
            out = self.RDB3(out)
        else:
            out = self.RDBs(x)
        return out * 0.2 + x

class ResidualDenseBlock_5C(pl.LightningModule):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            Rakotonirina and A. Rasoanaivo
    
    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel, 
            i.e. intermediate channels).
        convtype (str): the type of convolution to use. Default: 'Conv2D'
        gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new 
            trainable parameters)
        plus (bool): enable the additional residual paths from ESRGAN+ 
            (adds trainable parameters)
    '''

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()
        
        ## +
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        ## +

        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, \
            spectral_norm=spectral_norm)

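    def forward(self, x):
        # Dense connections: each conv sees the input plus all earlier outputs.
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1 is not None:
            x2 = x2 + self.conv1x1(x)  # ESRGAN+ extra residual path
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1 is not None:
            x4 = x4 + x2  # ESRGAN+ extra residual path
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual scaling by 0.2, optionally followed by ESRGAN+ Gaussian noise.
        if self.noise is not None:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x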


####################
# RRDBNet Generator (modified/"new" architecture)
####################


class MRRDBNet(pl.LightningModule):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(MRRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDBM, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out

class ResidualDenseBlock_5CM(pl.LightningModule):
    '''
    Residual Dense Block
    '''
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5CM, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], scale=0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDBM(pl.LightningModule):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDBM, self).__init__()
        self.RDB1 = ResidualDenseBlock_5CM(nf, gc)
        self.RDB2 = ResidualDenseBlock_5CM(nf, gc)
        self.RDB3 = ResidualDenseBlock_5CM(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x



class DSNetRRDB(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, upsampling_mode='nearest',
                in_nc=4, out_nc=3, nf=128, nb=8, gc=32, upscale=1, norm_type=None,
                act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
                finalact=None, gaussian_noise=True, plus=False, 
                nr=3):
      super(DSNetRRDB, self).__init__()
      self.netG1 = DSNet(layer_size=layer_size, input_channels=input_channels, upsampling_mode=upsampling_mode)
      self.netG2 = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, gc=gc, upscale=upscale, norm_type=norm_type,
                act_type=act_type, mode=mode, upsample_mode=upsample_mode, convtype=convtype,
                finalact=finalact, gaussian_noise=gaussian_noise, plus=plus, 
                nr=nr)
    def forward(self, input, input_mask):
      result1 = self.netG1(input, input_mask)

      result1 = input*input_mask+result1*(1-input_mask)

      concat = torch.cat((result1, input_mask), dim=1)
      result2 = self.netG2(concat)
      return result2
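
The two-stage forward pass above can be illustrated with a minimal sketch. It uses plain Python lists in place of tensors so that it runs without PyTorch or a GPU, and it assumes the mask convention implied by `input*input_mask+result1*(1-input_mask)`: 1 marks a known pixel and 0 marks a hole. The helper names `composite` and `append_mask_channel` are hypothetical and are not part of this notebook.

```python
def composite(image, mask, prediction):
    """Blend per pixel: image * mask + prediction * (1 - mask)."""
    return [
        [x * m + p * (1 - m) for x, m, p in zip(img_row, mask_row, pred_row)]
        for img_row, mask_row, pred_row in zip(image, mask, prediction)
    ]


def append_mask_channel(channels, mask):
    """Mimic torch.cat((result1, input_mask), dim=1) for a single sample."""
    return channels + [mask]


# One 2x2 sample; the three colour channels are identical for brevity.
image = [[0.2, 0.4], [0.6, 0.8]]
mask = [[1, 0], [1, 1]]                 # the top-right pixel is the hole
prediction = [[0.9, 0.5], [0.9, 0.9]]   # hypothetical first-stage output

filled = composite(image, mask, prediction)
stage_two_input = append_mask_channel([filled, filled, filled], mask)

print(filled)                # known pixels are kept, the hole comes from prediction
print(len(stage_two_input))  # number of channels handed to the second stage
```

Because the mask is concatenated as an extra channel, the second stage receives four input channels, which is why `in_nc` defaults to 4 while `input_channels` stays at 3.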


In [None]:
#@title DSNetDeoldify
#@markdown Combines DSNet and Deoldify: DSNet is the first stage and Deoldify is the second stage.
"""
DSNet.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/DSNet.py

RegionNorm.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/RegionNorm.py

ValidMigration.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/ValidMigration.py

Attention.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/Attention.py

deform_conv.py (6-3-20)
https://github.com/wangning-001/DSNet/blob/afa174a8f8e4fbdeff086fb546c83c871e959141/modules/deform_conv.py
"""
#from modules.Attention import PixelContextualAttention
#from modules.RegionNorm import RBNModule, RCNModule
#from modules.ValidMigration import ConvOffset2D
#from modules.deform_conv import th_batch_map_offsets, th_generate_grid
from __future__ import absolute_import, division
from scipy.ndimage import map_coordinates as sp_map_coordinates
from torch.autograd import Variable
from torchvision import models
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


def th_flatten(a):
    """Flatten tensor"""
    return a.contiguous().view(a.nelement())


def th_repeat(a, repeats, axis=0):
    """Torch version of np.repeat for 1D"""
    assert len(a.size()) == 1
    return th_flatten(torch.transpose(a.repeat(repeats, 1), 0, 1))


def np_repeat_2d(a, repeats):
    """NumPy helper: tile a 2D array `repeats` times along a new leading axis"""

    assert len(a.shape) == 2
    a = np.expand_dims(a, 0)
    a = np.tile(a, [repeats, 1, 1])
    return a


def th_gather_2d(input, coords):
    inds = coords[:, 0]*input.size(1) + coords[:, 1]
    x = torch.index_select(th_flatten(input), 0, inds)
    return x.view(coords.size(0))


def th_map_coordinates(input, coords, order=1):
    """Torch version of scipy.ndimage.map_coordinates
    Note that coords is transposed and only 2D is supported
    Parameters
    ----------
    input : torch.Tensor. shape = (s, s)
    coords : torch.Tensor. shape = (n_points, 2)
    """

    assert order == 1
    input_size = input.size(0)

    coords = torch.clamp(coords, 0, input_size - 1)
    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[:, 0], coords_rb[:, 1]], 1)
    coords_rt = torch.stack([coords_rb[:, 0], coords_lt[:, 1]], 1)

    vals_lt = th_gather_2d(input,  coords_lt.detach())
    vals_rb = th_gather_2d(input,  coords_rb.detach())
    vals_lb = th_gather_2d(input,  coords_lb.detach())
    vals_rt = th_gather_2d(input,  coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())

    vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0]
    vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0]
    mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1]
    return mapped_vals


def sp_batch_map_coordinates(inputs, coords):
    """Reference implementation for batch_map_coordinates"""
    # coords = coords.clip(0, inputs.shape[1] - 1)

    assert (coords.shape[2] == 2)
    height = coords[:,:,0].clip(0, inputs.shape[1] - 1)
    width = coords[:,:,1].clip(0, inputs.shape[2] - 1)
    # Reassemble the clipped (height, width) coordinates.
    coords = np.concatenate((np.expand_dims(height, axis=2), np.expand_dims(width, axis=2)), 2)

    mapped_vals = np.array([
        sp_map_coordinates(input, coord.T, mode='nearest', order=1)
        for input, coord in zip(inputs, coords)
    ])
    return mapped_vals


def th_batch_map_coordinates(input, coords, order=1):
    """Batch version of th_map_coordinates
    Only supports 2D feature maps
    Parameters
    ----------
    input : torch.Tensor. shape = (b, s, s)
    coords : torch.Tensor. shape = (b, n_points, 2)
    Returns
    -------
    torch.Tensor. shape = (b, n_points)
    """

    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    n_coords = coords.size(1)

    # coords = torch.clamp(coords, 0, input_size - 1)

    coords = torch.cat((torch.clamp(coords.narrow(2, 0, 1), 0, input_height - 1), torch.clamp(coords.narrow(2, 1, 1), 0, input_width - 1)), 2)

    assert (coords.size(1) == n_coords)

    coords_lt = coords.floor().long()
    coords_rb = coords.ceil().long()
    coords_lb = torch.stack([coords_lt[..., 0], coords_rb[..., 1]], 2)
    coords_rt = torch.stack([coords_rb[..., 0], coords_lt[..., 1]], 2)
    # Batch index for every coordinate, created on the same device as the input.
    idx = th_repeat(torch.arange(0, batch_size, device=input.device), n_coords).long()

    def _get_vals_by_coords(input, coords):
        indices = torch.stack([
            idx, th_flatten(coords[..., 0]), th_flatten(coords[..., 1])
        ], 1)
        inds = indices[:, 0]*input.size(1)*input.size(2)+ indices[:, 1]*input.size(2) + indices[:, 2]
        vals = th_flatten(input).index_select(0, inds)
        vals = vals.view(batch_size, n_coords)
        return vals

    vals_lt = _get_vals_by_coords(input, coords_lt.detach())
    vals_rb = _get_vals_by_coords(input, coords_rb.detach())
    vals_lb = _get_vals_by_coords(input, coords_lb.detach())
    vals_rt = _get_vals_by_coords(input, coords_rt.detach())

    coords_offset_lt = coords - coords_lt.type(coords.data.type())
    vals_t = coords_offset_lt[..., 0]*(vals_rt - vals_lt) + vals_lt
    vals_b = coords_offset_lt[..., 0]*(vals_rb - vals_lb) + vals_lb
    mapped_vals = coords_offset_lt[..., 1]* (vals_b - vals_t) + vals_t
    return mapped_vals


def sp_batch_map_offsets(input, offsets):
    """Reference implementation for th_batch_map_offsets"""

    batch_size = input.shape[0]
    input_height = input.shape[1]
    input_width = input.shape[2]

    offsets = offsets.reshape(batch_size, -1, 2)
    grid = np.stack(np.mgrid[:input_height, :input_width], -1).reshape(-1, 2)
    grid = np.repeat([grid], batch_size, axis=0)
    coords = offsets + grid
    # coords = coords.clip(0, input_size - 1)

    mapped_vals = sp_batch_map_coordinates(input, coords)
    return mapped_vals


def th_generate_grid(batch_size, input_height, input_width, dtype, cuda):
    grid = np.meshgrid(
        range(input_height), range(input_width), indexing='ij'
    )
    grid = np.stack(grid, axis=-1)
    grid = grid.reshape(-1, 2)

    grid = np_repeat_2d(grid, batch_size)
    grid = torch.from_numpy(grid).type(dtype)
    if cuda:
        grid = grid.cuda()
    return Variable(grid, requires_grad=False)


def th_batch_map_offsets(input, offsets, grid=None, order=1):
    """Batch map offsets into input
    Parameters
    ---------
    input : torch.Tensor. shape = (b, s, s)
    offsets: torch.Tensor. shape = (b, s, s, 2)
    Returns
    -------
    torch.Tensor. shape = (b, s*s)
    """
    batch_size = input.size(0)
    input_height = input.size(1)
    input_width = input.size(2)

    offsets = offsets.view(batch_size, -1, 2)
    if grid is None:
        grid = th_generate_grid(batch_size, input_height, input_width, offsets.data.type(), offsets.data.is_cuda)

    coords = offsets + grid

    mapped_vals = th_batch_map_coordinates(input, coords)
    return mapped_vals


class SEModule(pl.LightningModule):
    def __init__(self, num_channel, squeeze_ratio=1.0):
        super(SEModule, self).__init__()
        self.sequeeze_mod = nn.AdaptiveAvgPool2d(1)
        self.num_channel = num_channel

        blocks = [nn.Linear(num_channel, int(num_channel * squeeze_ratio)),
                  nn.ReLU(),
                  nn.Linear(int(num_channel * squeeze_ratio), num_channel),
                  nn.Sigmoid()]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        ori = x
        x = self.sequeeze_mod(x)
        x = x.view(x.size(0), 1, self.num_channel)
        x = self.blocks(x)
        x = x.view(x.size(0), self.num_channel, 1, 1)
        x = ori * x
        return x


class ContextualAttentionModule(pl.LightningModule):

    def __init__(self, patch_size=3, propagate_size=3, stride=1):
        super(ContextualAttentionModule, self).__init__()
        self.patch_size = patch_size
        self.propagate_size = propagate_size
        self.stride = stride
        self.prop_kernels = None

    def forward(self, foreground, masks):
        ###assume the masked area has value 1
        bz, nc, h, w = foreground.size()
        if masks.size(3) != foreground.size(3):
            masks = F.interpolate(masks, foreground.size()[2:])
        background = foreground.clone()
        background = background * masks
        background = F.pad(background,
                           [self.patch_size // 2, self.patch_size // 2, self.patch_size // 2, self.patch_size // 2])
        conv_kernels_all = background.unfold(2, self.patch_size, self.stride).unfold(3, self.patch_size,
                                                                                     self.stride).contiguous().view(bz,
                                                                                                                    nc,
                                                                                                                    -1,
                                                                                                                    self.patch_size,
                                                                                                                    self.patch_size)
        conv_kernels_all = conv_kernels_all.transpose(2, 1)
        output_tensor = []
        for i in range(bz):
            mask = masks[i:i + 1]
            feature_map = foreground[i:i + 1].contiguous()
            # form convolutional kernels
            conv_kernels = conv_kernels_all[i] + 0.0000001
            norm_factor = torch.sum(conv_kernels ** 2, [1, 2, 3], keepdim=True) ** 0.5
            conv_kernels = conv_kernels / norm_factor

            conv_result = F.conv2d(feature_map, conv_kernels, padding=self.patch_size // 2)
            """
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    self.prop_kernels = torch.ones([conv_result.size(1), 1, self.propagate_size, self.propagate_size])
                    self.prop_kernels.requires_grad = False
                    self.prop_kernels = self.prop_kernels.cuda()
                conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))

            """

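            # Build the propagation kernels on the same device and dtype as the
            # scores instead of hard-coding CUDA, so this also runs on CPU and TPU.
            self.prop_kernels = torch.ones(
                [conv_result.size(1), 1, self.propagate_size, self.propagate_size],
                device=conv_result.device, dtype=conv_result.dtype)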
            conv_result = F.conv2d(conv_result, self.prop_kernels, stride=1, padding=1, groups=conv_result.size(1))
            
            attention_scores = F.softmax(conv_result, dim=1)
            ##propagate the scores
            recovered_foreground = F.conv_transpose2d(attention_scores, conv_kernels, stride=1,
                                                      padding=self.patch_size // 2)
            # average the recovered value, at the same time make non-masked area 0
            recovered_foreground = (recovered_foreground * (1 - mask)) / (self.patch_size ** 2)
            # recover the image
            final_output = recovered_foreground + feature_map * mask
            output_tensor.append(final_output)
        return torch.cat(output_tensor, dim=0)


class PixelContextualAttention(pl.LightningModule):

    def __init__(self, inchannel, patch_size_list=[1], propagate_size_list=[3], stride_list=[1]):
        assert isinstance(patch_size_list,
                          list), "patch_size should be a list containing scales, or you should use Contextual Attention to initialize your module"
        assert len(patch_size_list) == len(propagate_size_list) and len(propagate_size_list) == len(
            stride_list), "the input_lists should have same lengths"
        super(PixelContextualAttention, self).__init__()
        for i in range(len(patch_size_list)):
            name = "CA_{:d}".format(i)
            setattr(self, name, ContextualAttentionModule(patch_size_list[i], propagate_size_list[i], stride_list[i]))
        self.num_of_modules = len(patch_size_list)
        self.SqueezeExc = SEModule(inchannel * 2)
        self.combiner = nn.Conv2d(inchannel * 2, inchannel, kernel_size=1)

    def forward(self, foreground, mask):
        outputs = [foreground]
        for i in range(self.num_of_modules):
            name = "CA_{:d}".format(i)
            CA_module = getattr(self, name)
            outputs.append(CA_module(foreground, mask))
        outputs = torch.cat(outputs, dim=1)
        outputs = self.SqueezeExc(outputs)
        outputs = self.combiner(outputs)
        return outputs




class ConvOffset2D(nn.Conv2d):
    """ConvOffset2D

    Convolutional layer responsible for learning the 2D offsets and output the
    deformed feature map using bilinear interpolation

    Note that this layer does not perform convolution on the deformed feature
    map. See get_deform_cnn in cnn.py for usage
    """
    def __init__(self, filters, init_normal_stddev=0.01, **kwargs):
        """Init

        Parameters
        ----------
        filters : int
            Number of channel of the input feature map
        init_normal_stddev : float
            Normal kernel initialization
        **kwargs:
            Pass to superclass. See Conv2d layer in pytorch
        """
        self.filters = filters
        self._grid_param = None
        super(ConvOffset2D, self).__init__(self.filters, self.filters*2, 3, padding=1, bias=False, **kwargs)
        self.weight.data.copy_(self._init_weights(self.weight, init_normal_stddev))

    def forward(self, x):
        """Return the deformed feature map"""
        x_shape = x.size()
        offsets = super(ConvOffset2D, self).forward(x)

        # offsets: (b*c, h, w, 2)
        offsets = self._to_bc_h_w_2(offsets, x_shape)

        # x: (b*c, h, w)
        x = self._to_bc_h_w(x, x_shape)

        # x_offset: (b*c, h*w)
        x_offset = th_batch_map_offsets(x, offsets, grid=self._get_grid(x))

        # x_offset: (b, c, h, w)
        x_offset = self._to_b_c_h_w(x_offset, x_shape)

        return x_offset

    def _get_grid(self, x):
        batch_size, input_height, input_width = x.size(0), x.size(1), x.size(2)
        dtype, cuda = x.data.type(), x.data.is_cuda
        if self._grid_param == (batch_size, input_height, input_width, dtype, cuda):
            return self._grid
        self._grid_param = (batch_size, input_height, input_width, dtype, cuda)
        self._grid = th_generate_grid(batch_size, input_height, input_width, dtype, cuda)
        return self._grid

    @staticmethod
    def _init_weights(weights, std):
        fan_out = weights.size(0)
        fan_in = weights.size(1) * weights.size(2) * weights.size(3)
        w = np.random.normal(0.0, std, (fan_out, fan_in))
        return torch.from_numpy(w.reshape(weights.size()))

    @staticmethod
    def _to_bc_h_w_2(x, x_shape):
        """(b, 2c, h, w) -> (b*c, h, w, 2)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]), 2)
        return x

    @staticmethod
    def _to_bc_h_w(x, x_shape):
        """(b, c, h, w) -> (b*c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3]))
        return x

    @staticmethod
    def _to_b_c_h_w(x, x_shape):
        """(b*c, h, w) -> (b, c, h, w)"""
        x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3]))
        return x





class RBNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RBNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t
        if self.training:
            mask_mean = torch.mean(mask_t, (0, 2, 3), True)
            x_mean = torch.mean(input_m, (0, 2, 3), True) / mask_mean
            x_var = torch.mean(((input_m - x_mean) * mask_t) ** 2, (0, 2, 3), True) / mask_mean

            x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * x_mean.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * x_var.data)
        else:
            x_out = self.weight * (input_m - self.running_mean) / torch.sqrt(self.running_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class RCNModule(pl.LightningModule):
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                     'running_mean', 'running_var', 'num_batches_tracked']

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True, track_running_stats=True):
        super(RCNModule, self).__init__()
        self.num_features = num_features
        self.track_running_stats = track_running_stats
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
            self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            nn.init.uniform_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input, mask_t):
        input_m = input * mask_t

        if self.training:
            mask_mean_bn = torch.mean(mask_t, (0, 2, 3), True)
            mean_bn = torch.mean(input_m, (0, 2, 3), True) / mask_mean_bn
            var_bn = torch.mean(((input_m - mean_bn) * mask_t) ** 2, (0, 2, 3), True) / mask_mean_bn

            self.running_mean.mul_(self.momentum)
            self.running_mean.add_((1 - self.momentum) * mean_bn.data)
            self.running_var.mul_(self.momentum)
            self.running_var.add_((1 - self.momentum) * var_bn.data)
        else:
            # torch.autograd.Variable is deprecated; buffers can be used directly
            mean_bn = self.running_mean
            var_bn = self.running_var

        mask_mean_in = torch.mean(mask_t, (2, 3), True)
        mean_in = torch.mean(input_m, (2, 3), True) / mask_mean_in
        var_in = torch.mean(((input_m - mean_in) * mask_t) ** 2, (2, 3), True) / mask_mean_in

        mask_mean_ln = torch.mean(mask_t, (1, 2, 3), True)
        mean_ln = torch.mean(input_m, (1, 2, 3), True) / mask_mean_ln
        var_ln = torch.mean(((input_m - mean_ln) * mask_t) ** 2, (1, 2, 3), True) / mask_mean_ln

        # explicit dim: the weights are 1-D vectors of length 3 (instance, layer, batch)
        mean_weight = F.softmax(self.mean_weight, dim=0)
        var_weight = F.softmax(self.var_weight, dim=0)

        x_mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        x_var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        x_out = self.weight * (input_m - x_mean) / torch.sqrt(x_var + self.eps) + self.bias
        return x_out * mask_t + input * (1 - mask_t)


class DSModule(pl.LightningModule):
    def __init__(self, in_ch, out_ch, bn=False, rn=True, sample='none-3', activ='relu',
                 conv_bias=False, defor=True):
        super().__init__()
        if sample == 'down-5':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 5, 2, 2, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(5,2,2)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-7':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 7, 2, 3, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(7, 2, 3)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        elif sample == 'down-3':
            self.conv = nn.Conv2d(in_ch+1, out_ch, 3, 2, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3, 2, 1)
            if defor:
                self.offset = ConvOffset2D(in_ch+1)
        else:
            self.conv = nn.Conv2d(in_ch+2, out_ch, 3, 1, 1, bias=conv_bias)
            self.updatemask = nn.MaxPool2d(3,1,1)
            if defor:
                self.offset0 = ConvOffset2D(in_ch-out_ch+1)
                self.offset1 = ConvOffset2D(out_ch+1)
        self.in_ch = in_ch
        self.out_ch = out_ch

        if bn:
            self.bn = nn.BatchNorm2d(out_ch)
        if rn:
            # Regional Composite Normalization
            self.rn = RCNModule(out_ch)

            # Regional Batch Normalization
            # self.rn = RBNModule(out_ch)
        if activ == 'relu':
            self.activation = nn.ReLU(inplace = True)
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2, inplace = True)

    def forward(self, input, input_mask):
        if hasattr(self, 'offset'):
            input = torch.cat([input, input_mask[:,:1,:,:]], dim = 1)
            h = self.offset(input)
            h = input*input_mask[:,:1,:,:] + (1-input_mask[:,:1,:,:])*h
            h = self.conv(h)
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask
            h = self.rn(h, h_mask)
        elif hasattr(self, 'offset0'):
            h1_in = torch.cat([input[:,self.in_ch-self.out_ch:,:,:], input_mask[:,1:,:,:]], dim = 1)
            m1_in = input_mask[:,1:,:,:]
            h0 = torch.cat([input[:,:self.in_ch-self.out_ch,:,:], input_mask[:,:1,:,:]], dim = 1)
            h1 = self.offset1(h1_in)
            h1 = m1_in*h1_in + (1-m1_in)*h1
            h = self.conv(torch.cat([h0,h1], dim = 1))
            h = self.rn(h, input_mask[:,:1,:,:])
            h_mask = F.interpolate(input_mask[:,:1,:,:], scale_factor=2, mode='nearest')
        else:
            h = self.conv(torch.cat([input, input_mask[:,:,:,:]], dim = 1))
            h_mask = self.updatemask(input_mask[:,:1,:,:])
            h = h*h_mask

        if hasattr(self, 'bn'):
            h = self.bn(h)
        if hasattr(self, 'activation'):
            h = self.activation(h)
        return h, h_mask


class DSNet(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, upsampling_mode='nearest'):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        self.enc_1 = DSModule(input_channels, 64, rn=False, sample='down-7', defor = False)
        self.enc_2 = DSModule(64, 128, sample='down-5')
        self.enc_3 = DSModule(128, 256, sample='down-5')
        self.enc_4 = DSModule(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, DSModule(512 + 512, 512, activ='leaky'))
        self.dec_4 = DSModule(512 + 256, 256, activ='leaky')
        self.dec_3 = DSModule(256 + 128, 128, activ='leaky')
        self.dec_2 = DSModule(128 + 64, 64, activ='leaky')
        self.dec_1 = DSModule(64 + input_channels, input_channels,
                              rn=False, activ=None, defor = False)
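        self.att = PixelContextualAttention(128)

    def forward(self, input, input_mask):
        # cast to float32 on the module's device instead of forcing CUDA, so CPU and TPU runs also work
        input = input.to(device=self.device, dtype=torch.float32)
        input_mask = input_mask.to(device=self.device, dtype=torch.float32)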

        input_mask = input_mask[:,0:1,:,:]
        h_dict = {}  # for the output of enc_N
        h_mask_dict = {}  # for the output of enc_N

        h_dict['h_0'], h_mask_dict['h_0'] = input, input_mask

        h_key_prev = 'h_0'
        for i in range(1, self.layer_size + 1):
            l_key = 'enc_{:d}'.format(i)
            h_key = 'h_{:d}'.format(i)
            h_dict[h_key], h_mask_dict[h_key] = getattr(self, l_key)(
                h_dict[h_key_prev], h_mask_dict[h_key_prev])
            h_key_prev = h_key

        h_key = 'h_{:d}'.format(self.layer_size)
        h, h_mask = h_dict[h_key], h_mask_dict[h_key]
        h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')

        for i in range(self.layer_size, 0, -1):
            enc_h_key = 'h_{:d}'.format(i - 1)
            dec_l_key = 'dec_{:d}'.format(i)

            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)

            h = torch.cat([h, h_dict[enc_h_key]], dim=1)
            h_mask = torch.cat([h_mask, h_mask_dict[enc_h_key]], dim=1)
            h, h_mask = getattr(self, dec_l_key)(h, h_mask)
            if i == 3:
                h = self.att(h, input_mask[:,:1,:,:])
        #return h, h_mask
        return h





#@title [Deoldify](https://github.com/alfagao/DeOldify) (2018) (dataloader not implemented)

"""
torch_imports.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/torch_imports.py

conv_learner.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/conv_learner.py

model.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fastai/model.py

modules.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fasterai/modules.py

generators.py (9-3-20)
https://github.com/alfagao/DeOldify/blob/bc9d4562bf2014f5268f5c616ae31873577d9fde/fasterai/generators.py
"""

import math
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.nn.utils.spectral_norm import spectral_norm
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from torchvision.models import vgg16_bn, vgg19_bn

# first child of the torchvision VGG model is its convolutional feature extractor
def vgg16(pre): return list(vgg16_bn(pre).children())[0]
def vgg19(pre): return list(vgg19_bn(pre).children())[0]

def cut_model(m, cut):
    return list(m.children())[:cut] if cut else [m]
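
# Minimal stand-ins for the fastai 0.7 helpers that the classes below call
# (children, set_trainable, split_by_idxs). They are not defined in this cell;
# these versions are an assumption written for this port, not the original code.
import torch.nn as nn

def children(m):
    # direct submodules as a list; lists/tuples are passed through unchanged
    return m if isinstance(m, (list, tuple)) else list(m.children())

def set_trainable(m, trainable):
    # accepts a module or a list of modules and toggles requires_grad on every parameter
    for module in (m if isinstance(m, (list, tuple)) else [m]):
        for p in module.parameters():
            p.requires_grad = trainable

def split_by_idxs(seq, idxs):
    # yields seq[0:i0], seq[i0:i1], ..., seq[i_last:]
    last = 0
    for idx in idxs:
        yield seq[last:idx]
        last = idx
    yield seq[last:]
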
"""
model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg16:[0,22], vgg19:[0,22],
    resnext50:[8,6], resnext101:[8,6], resnext101_64:[8,6],
    wrn:[8,6], inceptionresnet_2:[-2,9], inception_4:[-1,9],
    dn121:[0,7], dn161:[0,7], dn169:[0,7], dn201:[0,7],
}
"""

model_meta = {
    resnet18:[8,6], resnet34:[8,6], resnet50:[8,6], resnet101:[8,6], resnet152:[8,6],
    vgg16:[0,22], vgg19:[0,22],
}



class ConvBlock(pl.LightningModule):
    def __init__(self, ni:int, no:int, ks:int=3, stride:int=1, pad:int=None, actn:bool=True, 
            bn:bool=True, bias:bool=True, sn:bool=False, leakyReLu:bool=False, self_attention:bool=False,
            inplace_relu:bool=True):
        super().__init__()   
        if pad is None: pad = ks//2//stride

        if sn:
            layers = [spectral_norm(nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias))]
        else:
            layers = [nn.Conv2d(ni, no, ks, stride, padding=pad, bias=bias)]
        if actn:
            layers.append(nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu))
        if bn:
            layers.append(nn.BatchNorm2d(no))
        if self_attention:
            layers.append(SelfAttention(no, 1))

        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)


class UpSampleBlock(pl.LightningModule):
    @staticmethod
    def _conv(ni:int, nf:int, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        layers = [ConvBlock(ni, nf, ks=ks, sn=sn, bn=bn, actn=False, leakyReLu=leakyReLu)]
        return nn.Sequential(*layers)

    @staticmethod
    def _icnr(x:torch.Tensor, scale:int=2):
        init=nn.init.kaiming_normal_
        new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:])
        subkernel = torch.zeros(new_shape)
        subkernel = init(subkernel)
        subkernel = subkernel.transpose(0, 1)
        subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                                subkernel.shape[1], -1)
        kernel = subkernel.repeat(1, 1, scale ** 2)
        transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:])
        kernel = kernel.contiguous().view(transposed_shape)
        kernel = kernel.transpose(0, 1)
        return kernel

    def __init__(self, ni:int, nf:int, scale:int=2, ks:int=3, bn:bool=True, sn:bool=False, leakyReLu:bool=False):
        super().__init__()
        layers = []
        assert (math.log(scale,2)).is_integer()

        for i in range(int(math.log(scale,2))):
            layers += [UpSampleBlock._conv(ni, nf*4,ks=ks, bn=bn, sn=sn, leakyReLu=leakyReLu), 
                nn.PixelShuffle(2)]
            if bn:
                layers += [nn.BatchNorm2d(nf)]

            ni = nf
                       
        self.sequence = nn.Sequential(*layers)
        self._icnr_init()
        
    def _icnr_init(self):
        conv_shuffle = self.sequence[0][0].seq[0]
        kernel = UpSampleBlock._icnr(conv_shuffle.weight)
        conv_shuffle.weight.data.copy_(kernel)
    
    def forward(self, x):
        return self.sequence(x)


class UnetBlock(pl.LightningModule):
    def __init__(self, up_in:int , x_in:int , n_out:int, bn:bool=True, sn:bool=False, leakyReLu:bool=False, 
            self_attention:bool=False, inplace_relu:bool=True):
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = ConvBlock(x_in,  x_out,  ks=1, bn=False, actn=False, sn=sn, inplace_relu=inplace_relu)
        self.tr_conv = UpSampleBlock(up_in, up_out, 2, bn=bn, sn=sn, leakyReLu=leakyReLu)
        self.relu = nn.LeakyReLU(0.2, inplace=inplace_relu) if leakyReLu else nn.ReLU(inplace=inplace_relu)
        out_layers = []
        if bn: 
            out_layers.append(nn.BatchNorm2d(n_out))
        if self_attention:
            out_layers.append(SelfAttention(n_out))
        self.out = nn.Sequential(*out_layers)
        
        
    def forward(self, up_p:torch.Tensor, x_p:torch.Tensor):
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        x = torch.cat([up_p,x_p], dim=1)
        x = self.relu(x)
        return self.out(x)

class SaveFeatures():
    features=None
    def __init__(self, m:nn.Module):
        self.hook = m.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output): 
        self.features = output
    def remove(self): 
        self.hook.remove()

class SelfAttention(pl.LightningModule):
    def __init__(self, in_channel:int, gain:int=1):
        super().__init__()
        self.query = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
        self.key = self._spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),gain=gain)
        self.value = self._spectral_init(nn.Conv1d(in_channel, in_channel, 1), gain=gain)
        self.gamma = nn.Parameter(torch.tensor(0.0))

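    def _spectral_init(self, module:nn.Module, gain:int=1):
        # note: the second positional argument of kaiming_uniform_ is `a` (the negative slope), so `gain` is used as `a` here
        nn.init.kaiming_uniform_(module.weight, gain)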
        if module.bias is not None:
            module.bias.data.zero_()

        return spectral_norm(module)

    def forward(self, input:torch.Tensor):
        shape = input.shape
        flatten = input.view(shape[0], shape[1], -1)
        query = self.query(flatten).permute(0, 2, 1)
        key = self.key(flatten)
        value = self.value(flatten)
        query_key = torch.bmm(query, key)
        attn = F.softmax(query_key, 1)
        attn = torch.bmm(value, attn)
        attn = attn.view(*shape)
        out = self.gamma * attn + input
        return out


class GeneratorModule(ABC, nn.Module):
    def __init__(self):
        super().__init__()
    
    def set_trainable(self, trainable:bool):
        set_trainable(self, trainable)

    @abstractmethod
    def get_layer_groups(self, precompute:bool=False)->[]:
        pass

    @abstractmethod
    def forward(self, x_in:torch.Tensor, max_render_sz:int=400):
        pass
        
    def freeze_to(self, n:int):
        c=self.get_layer_groups()
        for l in c:     set_trainable(l, False)
        for l in c[n:]: set_trainable(l, True)

    def get_device(self):
        return next(self.parameters()).device


class AbstractUnet(GeneratorModule): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__()
        assert (math.log(scale,2)).is_integer()
        self.rn, self.lr_cut = self._get_pretrained_resnet_base()
        ups = self._get_decoding_layers(nf_factor=nf_factor, scale=scale)
        self.relu = nn.ReLU()
        self.up1 = ups[0]
        self.up2 = ups[1]
        self.up3 = ups[2]
        self.up4 = ups[3]
        self.up5 = ups[4]
        self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=True), nn.Tanh())

    @abstractmethod
    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        pass

    @abstractmethod
    def _get_decoding_layers(self, nf_factor:int, scale:int):
        pass

    #Gets around irritating inconsistent halving coming from resnet
    def _pad(self, x:torch.Tensor, target:torch.Tensor, total_padh:int, total_padw:int)-> torch.Tensor:
        h = x.shape[2] 
        w = x.shape[3]

        target_h = target.shape[2]*2
        target_w = target.shape[3]*2

        if h<target_h or w<target_w:
            padh = target_h-h if target_h > h else 0
            total_padh = total_padh + padh
            padw = target_w-w if target_w > w else 0
            total_padw = total_padw + padw
            return (F.pad(x, (0,padw,0,padh), "reflect",0), total_padh, total_padw)

        return (x, total_padh, total_padw)

    def _remove_padding(self, x:torch.Tensor, padh:int, padw:int)->torch.Tensor:
        if padw == 0 and padh == 0:
            return x 
        
        target_h = x.shape[2]-padh
        target_w = x.shape[3]-padw
        return x[:,:,:target_h, :target_w]

    def _encode(self, x:torch.Tensor):
        x = self.rn[0](x)
        x = self.rn[1](x)
        x = self.rn[2](x)
        enc0 = x
        x = self.rn[3](x)
        x = self.rn[4](x)
        enc1 = x
        x = self.rn[5](x)
        enc2 = x
        x = self.rn[6](x)
        enc3 = x
        x = self.rn[7](x)
        return (x, enc0, enc1, enc2, enc3)

    def _decode(self, x:torch.Tensor, enc0:torch.Tensor, enc1:torch.Tensor, enc2:torch.Tensor, enc3:torch.Tensor):
        padh = 0
        padw = 0
        x = self.relu(x)
        enc3, padh, padw = self._pad(enc3, x, padh, padw)
        x = self.up1(x, enc3)
        enc2, padh, padw  = self._pad(enc2, x, padh, padw)
        x = self.up2(x, enc2)
        enc1, padh, padw  = self._pad(enc1, x, padh, padw)
        x = self.up3(x, enc1)
        enc0, padh, padw  = self._pad(enc0, x, padh, padw)
        x = self.up4(x, enc0)
        #This is a bit too much padding being removed, but I 
        #haven't yet figured out a good way to determine what 
        #exactly should be removed.  This is consistently more 
        #than enough though.
        x = self.up5(x)
        x = self.out(x)
        x = self._remove_padding(x, padh, padw)
        return x

    def forward(self, x:torch.Tensor):
        x, enc0, enc1, enc2, enc3 = self._encode(x)
        x = self._decode(x, enc0, enc1, enc2, enc3)
        return x
    
    def get_layer_groups(self, precompute:bool=False)->[]:
        lgs = list(split_by_idxs(children(self.rn), [self.lr_cut]))
        return lgs + [children(self)[1:]]
    
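    def close(self):
        # self.sfs (SaveFeatures hooks) is never set in this class, so fall back to an empty list
        for sf in getattr(self, 'sfs', []):
            sf.remove()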


class Unet34(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet34
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,128,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 


class Unet101(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet101
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 

class Unet152(AbstractUnet): 
    def __init__(self, nf_factor:int=1, scale:int=1):
        super().__init__(nf_factor=nf_factor, scale=scale)

    def _get_pretrained_resnet_base(self, layers_cut:int=0):
        f = resnet152
        cut,lr_cut = model_meta[f]
        cut-=layers_cut
        layers = cut_model(f(True), cut)
        return nn.Sequential(*layers), lr_cut

    def _get_decoding_layers(self, nf_factor:int, scale:int):
        self_attention=True
        bn=True
        sn=True
        leakyReLu=False
        layers = []
        layers.append(UnetBlock(2048,1024,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,512,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,256,512*nf_factor, sn=sn, self_attention=self_attention, leakyReLu=leakyReLu, bn=bn))
        layers.append(UnetBlock(512*nf_factor,64,256*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn))
        layers.append(UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn))
        return layers 



class DSNetDeoldify(pl.LightningModule):
    def __init__(self, layer_size=8, input_channels=3, upsampling_mode='nearest'):
      super(DSNetDeoldify, self).__init__()
      self.netG1 = DSNet(layer_size=layer_size, input_channels=input_channels, upsampling_mode=upsampling_mode)
      self.netG2 = Unet34()
    def forward(self, input, input_mask):
      result1 = self.netG1(input, input_mask)

      result1 = input*input_mask+result1*(1-input_mask)

      #concat = torch.cat((result1, input_mask), dim=1)
      #result2 = self.netG2(concat)

      result2 = self.netG2(result1)
      return result2

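The last step of `DSNetDeoldify.forward` blends the first generator's output with the input before the second generator runs: where the mask is 1 the input is kept, where it is 0 the generated pixels are used. A minimal sketch of that blend on toy tensors follows; `composite` is a hypothetical helper name and the tensor values are made up for illustration.

```python
import torch

def composite(known, generated, mask):
    # keep the input where mask == 1, use the generator output where mask == 0
    return known * mask + generated * (1 - mask)

known = torch.full((1, 3, 2, 2), 0.5)
generated = torch.full((1, 3, 2, 2), -1.0)
# mask of shape (1, 1, 2, 2); it broadcasts over the three colour channels
mask = torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])

out = composite(known, generated, mask)
print(out[0, 0])
```

The left column of every channel keeps the input value and the right column takes the generated value.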

# Training

In [None]:
#@title delete validation, logs and checkpoints if needed
%cd /content/
!sudo rm -rf /content/validation_output
!sudo rm -rf /content/lightning_logs
!sudo rm -rf /content/logs
#!mkdir /content/logs/
!find . -name "*.ckpt" -type f -delete

Info about ``logger=None``: Logging is done outside of Lightning, with tensorboardX inside the training loop, so that values are logged per iteration instead of per epoch. Be aware that the loss weights and the overall training configuration are located in ``CustomTrainClass.py``. If you want to use another generator or change parameters, edit that file.
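
A minimal sketch of what iteration-based logging can look like, assuming a writer object with an `add_scalar(tag, value, global_step)` method such as the one `tensorboardX.SummaryWriter` provides. `ListWriter`, `log_losses` and the loss names are hypothetical stand-ins so that the example runs without tensorboardX; the actual logging code in ``CustomTrainClass.py`` may differ.

```python
class ListWriter:
    """In-memory stand-in that mimics the add_scalar call shape of a summary writer."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, global_step):
        self.records.append((tag, float(value), global_step))


def log_losses(writer, losses, global_step):
    # one scalar per loss term, indexed by iteration instead of epoch
    for name, value in losses.items():
        writer.add_scalar(f"loss/{name}", value, global_step)


# in the notebook this would be a tensorboardX.SummaryWriter pointing at /content/logs
writer = ListWriter()
for step in range(3):
    log_losses(writer, {"g_total": 1.0 / (step + 1), "d_total": 0.5}, step)

print(len(writer.records), "scalars logged")
```

Because the step index is passed explicitly, the curves stay aligned with iterations even when an epoch contains very few or very many batches.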

In [None]:
#@title Training
#@markdown Normal training assumes that both a generator and a discriminator are used. If you only use a generator, edit ``checkpoint.py`` and comment out the line where the discriminator pth gets saved.

#@markdown Settings depend on the dataloader. Not everything is active at the same time.
import torch
import pytorch_lightning as pl
%cd /content/

dir_lr = '/content/lr/' #@param
dir_hr = '/content/hr/' #@param
val_lr =  '/content/val_lr/'#@param
val_hr = '/content/val_hr/' #@param
num_workers = 1 #@param
hr_size = 256 #@param
scale = 4 #@param
batch_size = 1 #@param
batch_size_DL = 20 #@param
gpus=1 #@param
max_epochs = 100 #@param
progress_bar_refresh_rate = 20 #@param
default_root_dir='/content/' #@param
save_path='/content/' #@param
save_step_frequency = 100 #@param
tpu_cores = 8 #@param
#@markdown For batch dataloader
image_size=400 #@param
amount_tiles=3 #@param
#############################################
# Dataloader
#############################################
# Inpainting
# normal training
#dm = DFNetDataModule(batch_size=batch_size, training_path = dir_hr, validation_path = val_lr)
# tiled dataloader (batch return)
#dm = DFNetDataModule(training_path = dir_hr, validation_path = val_lr, batch_size=batch_size, num_workers=num_workers, batch_size_DL=batch_size_DL)

# Super Resolution
# lr/hr dataloader
dm = DFNetDataModule(batch_size=batch_size, dir_lr = dir_lr, dir_hr = dir_hr, val_lr = val_lr, val_hr = val_hr, num_workers = num_workers, hr_size = hr_size, scale = scale)
# batch
#dm = DFNetDataModule(batch_size=batch_size, training_path = dir_lr, val_lr = val_lr, val_hr = val_hr, num_workers=num_workers, batch_size_DL=batch_size_DL, hr_size=hr_size, scale = scale, image_size=image_size, amount_tiles=amount_tiles)
#############################################


#############################################
# Loading a Model
#############################################
model = CustomTrainClass()

#@markdown Loading a pretrain pth
pretrain_path = None #@param
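if pretrain_path is not None:
  # load the generator weights into the freshly created model (the trainer does not exist yet at this point)
  model.netG.load_state_dict(torch.load(pretrain_path))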

#@markdown For resuming training
checkpoint_path = '/content/Checkpoint_1_500.ckpt' #@param

# load from checkpoint (optional) (uses the checkpoint as a pretrained model and disregards the other saved parameters)
#model = model.load_from_checkpoint(checkpoint_path) # start training from checkpoint. Warning: apparently global_step is reset to zero, so validation images get overwritten; you could manually add an offset


# continue training with checkpoint (does restore values) (optional)
# https://github.com/PyTorchLightning/pytorch-lightning/issues/2613
# https://pytorch-lightning.readthedocs.io/en/0.6.0/pytorch_lightning.trainer.training_io.html
# https://github.com/PyTorchLightning/pytorch-lightning/issues/4333
# dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'callbacks', 'optimizer_states', 'lr_schedulers', 'state_dict', 'hparams_name', 'hyper_parameters'])

# To use DDP for local multi-GPU training, you need to add find_unused_parameters=True inside the DDP command
"""
model = model.load_from_checkpoint(checkpoint_path)
trainer = pl.Trainer(resume_from_checkpoint=checkpoint_path, logger=None, gpus=1, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=100, save_path='/content/')])
checkpoint = torch.load(checkpoint_path)
trainer.checkpoint_connector.restore(checkpoint, on_gpu=True)
trainer.checkpoint_connector.restore_training_state(checkpoint)
pl.Trainer.global_step = checkpoint['global_step']
pl.Trainer.epoch = checkpoint['epoch']
"""
#############################################



#############################################
# Training
#############################################
# GPU
# Also maybe useful:
# auto_scale_batch_size='binsearch'
# stochastic_weight_avg=True

# Warning: stochastic_weight_avg **can cause crashing after an epoch**. Test first whether training survives the start of the next epoch. Not all generators are tested.
trainer = pl.Trainer(logger=None, gpus=gpus, max_epochs=max_epochs, progress_bar_refresh_rate=progress_bar_refresh_rate, default_root_dir=default_root_dir, callbacks=[CheckpointEveryNSteps(save_step_frequency=save_step_frequency, save_path=save_path)])
# 2+ GPUS (locally, not inside Google Colab)
# Recommended: Pytorch 1.8+. 1.7.1 seems to have dataloader issues and ddp only works if code is run within console.
#trainer = pl.Trainer(logger=None, gpus=2, distributed_backend='dp', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=100, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision, 'O2' = Almost FP16, 'O3' = FP16)
# https://nvidia.github.io/apex/amp.html?highlight=opt_level#o1-mixed-precision-recommended-for-typical-use
#trainer = pl.Trainer(logger=None, gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(logger=None, tpu_cores=tpu_cores, max_epochs=max_epochs, progress_bar_refresh_rate=progress_bar_refresh_rate, default_root_dir=default_root_dir, callbacks=[CheckpointEveryNSteps(save_step_frequency=save_step_frequency, save_path=save_path)])
#############################################

trainer.fit(model, dm)

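For resuming, the commented-out block in the training cell reads `global_step` and `epoch` back from the checkpoint file. A small sketch of that lookup on a toy dictionary saved with `torch.save` is shown below. The values and the temporary file are made up for illustration, and a real Lightning checkpoint contains more keys (see the `dict_keys` comment in the cell above).

```python
import os
import tempfile

import torch

# toy stand-in for a Lightning checkpoint; only two of the real keys are filled in
toy_checkpoint = {"epoch": 1, "global_step": 500, "state_dict": {}}

path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")
torch.save(toy_checkpoint, path)

# map_location="cpu" lets the file be inspected without a GPU
checkpoint = torch.load(path, map_location="cpu")
resume_epoch = checkpoint["epoch"]
resume_step = checkpoint["global_step"]
print(f"resume at epoch {resume_epoch}, global step {resume_step}")
```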
# Testing 

In [None]:
#@title testing the model
dm = DS_green_from_mask('/content/test')
model = CustomTrainClass()
# GPU
#trainer = pl.Trainer(gpus=1, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# GPU with AMP (amp_level='O1' = mixed precision)
trainer = pl.Trainer(gpus=1, precision=16, amp_level='O1', max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
# TPU
#trainer = pl.Trainer(tpu_cores=8, max_epochs=10, progress_bar_refresh_rate=20, default_root_dir='/content/', callbacks=[CheckpointEveryNSteps(save_step_frequency=1000, save_path='/content/')])
trainer.test(model, dm, ckpt_path='/content/Checkpoint_0_0.ckpt')

# Misc

In [None]:
#@title creating 16x16 images
import cv2
import numpy
import glob
rootdir = '/content/data' #@param {type:"string"}
destination_dir = "/content/4k/" #@param {type:"string"}

files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files.extend(files_jpg)
err_files=[]

filepos = 0
img_cnt = 0
tmp_img = numpy.zeros((4096,4096, 3), dtype=numpy.uint8)
# stop once fewer than 16*16 images remain, instead of running into an IndexError
while filepos + 16*16 <= len(files):
  for i in range(16):
    for j in range(16):
      image = cv2.imread(files[filepos])
      filepos += 1

      image = cv2.resize(image, (256,256))

      tmp_img[i*256:(i+1)*256, j*256:(j+1)*256] = image
  #cv2.imwrite("/content/4k/"+str(img_cnt)+".png", tmp_img)
  cv2.imwrite(destination_dir+str(img_cnt)+".jpg", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])

  img_cnt += 1

In [None]:
#@title creating 16x16 images (with skip)
import cv2
import numpy
import glob
import shutil
import tqdm
import os
rootdir = '/content/data' #@param {type:"string"}
destination_dir = "/content/4k/" #@param {type:"string"}
broken_dir = '/content/opencv_fail/' #@param {type:"string"}
 
files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files_jpeg = glob.glob(rootdir + '/**/*.jpeg', recursive=True)
files_webp = glob.glob(rootdir + '/**/*.webp', recursive=True)
files.extend(files_jpg)
files.extend(files_jpeg)
files.extend(files_webp)
err_files=[]

filepos = 0
img_cnt = 0
filename_cnt = 0
tmp_img = numpy.zeros((4096,4096, 3))

os.makedirs(destination_dir, exist_ok=True)
os.makedirs(broken_dir, exist_ok=True)

with tqdm.tqdm(total=len(files)) as pbar:
  # stop at the end of the file list; an incomplete last grid is not written
  while filepos < len(files):
    path = files[filepos]
    filepos += 1
    image = cv2.imread(path)

    if image is not None:
      # fill the 16x16 grid column by column
      i = img_cnt % 16
      j = img_cnt // 16

      image = cv2.resize(image, (256,256))
      tmp_img[i*256:(i+1)*256, j*256:(j+1)*256] = image
      img_cnt += 1
    else:
      # OpenCV could not decode this file: move it to broken_dir and skip it
      broken_path = os.path.join(broken_dir, os.path.basename(path))
      print(path)
      print(broken_path)
      shutil.move(path, broken_path)

    if img_cnt == 256:
      cv2.imwrite(destination_dir+str(filename_cnt)+".jpg", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
      filename_cnt += 1
      img_cnt = 0
    pbar.update(1)
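
The index arithmetic in the cell above (`i = img_cnt % 16`, `j = img_cnt // 16`) fills the grid column by column. As a small sketch of that mapping on its own (the function name `grid_slot` is made up for illustration), the slices for a given tile index can be computed like this:

```python
def grid_slot(img_cnt, grid=16, tile=256):
    """Return (row_slice, col_slice) of the img_cnt-th tile, filled column by column."""
    i = img_cnt % grid   # row index inside the grid
    j = img_cnt // grid  # column index inside the grid
    return slice(i * tile, (i + 1) * tile), slice(j * tile, (j + 1) * tile)

# tile 17 is the second tile of the second column
rows, cols = grid_slot(17)
print(rows, cols)  # slice(256, 512, None) slice(256, 512, None)
```

The same function would describe the 3x3 layout with `grid=3` and `tile=image_size`.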

In [None]:
#@title creating 3x3 grayscale images (with skip)
import cv2
import numpy
import glob
import shutil
import tqdm
import os
import random
rootdir = '/media/veracrypt1/Font/png/' #@param {type:"string"}
destination_dir = "/media/veracrypt1/Font/hr_400/" #@param {type:"string"}
broken_dir = '/media/veracrypt1/Font/broken/' #@param {type:"string"}
 
files = glob.glob(rootdir + '/**/*.png', recursive=True)
files_jpg = glob.glob(rootdir + '/**/*.jpg', recursive=True)
files_jpeg = glob.glob(rootdir + '/**/*.jpeg', recursive=True)
files_webp = glob.glob(rootdir + '/**/*.webp', recursive=True)
files.extend(files_jpg)
files.extend(files_jpeg)
files.extend(files_webp)
err_files=[]

image_size = 400 #@param

filepos = 0
img_cnt = 0
filename_cnt = 0
tmp_img = numpy.zeros((image_size*3,image_size*3))

interpolation_method = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]

os.makedirs(destination_dir, exist_ok=True)
os.makedirs(broken_dir, exist_ok=True)

with tqdm.tqdm(total=len(files)) as pbar:
  # stop at the end of the file list; an incomplete last grid is not written
  while filepos < len(files):
    path = files[filepos]
    filepos += 1
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

    # check for a failed read before resizing, cv2.resize cannot handle None
    if image is not None:
      # fill the 3x3 grid column by column
      i = img_cnt % 3
      j = img_cnt // 3

      # resize once to image_size with a randomly chosen interpolation method
      image = cv2.resize(image, (image_size,image_size), interpolation=random.choice(interpolation_method))
      tmp_img[i*image_size:(i+1)*image_size, j*image_size:(j+1)*image_size] = image
      img_cnt += 1
    else:
      # OpenCV could not decode this file: move it to broken_dir and skip it
      broken_path = os.path.join(broken_dir, os.path.basename(path))
      print(path)
      print(broken_path)
      shutil.move(path, broken_path)

    if img_cnt == 9:
      #cv2.imwrite(destination_dir+str(filename_cnt)+".png", tmp_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
      cv2.imwrite(destination_dir+str(filename_cnt)+".png", tmp_img)
      filename_cnt += 1
      img_cnt = 0
    pbar.update(1)

In [None]:
#@title convert to onnx
#@markdown Make sure the input dimensions are correct. A runtime restart may be needed if it complains about ``TypeError: forward() missing 1 required positional argument``. Make sure you only run the required cells.
import torch
checkpoint_path = '/content/Checkpoint_0_0.ckpt' #@param
output_path = '/content/output.onnx' #@param
# load_from_checkpoint is a classmethod that returns a new model with the checkpoint weights
model = CustomTrainClass.load_from_checkpoint(checkpoint_path)
# dummy input used for tracing; adjust the shape to what the generator expects
dummy_input = torch.randn(1, 1, 64, 64)

model.to_onnx(output_path, input_sample=dummy_input)

In [None]:
#@title copy pasting data to create artificial dataset for debugging
import shutil
from random import random
from tqdm import tqdm
for i in tqdm(range(5000)):
  shutil.copy("/content/4k/0.jpg", "/content/4k/"+str(random())+".jpg")
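
Random floats as file names work for a quick test, but sequential names are easier to inspect and cannot collide. The following is a minimal sketch of that variant, not the notebook's own method: the helper `duplicate_file` and the `copy_00000` naming scheme are assumptions for illustration, and the demo writes a placeholder file into a temporary directory so it runs without any dataset.

```python
import os
import shutil
import tempfile

def duplicate_file(src, dst_dir, copies):
    """Copy src into dst_dir `copies` times under sequential, collision-free names."""
    os.makedirs(dst_dir, exist_ok=True)
    ext = os.path.splitext(src)[1]  # keep the original extension, e.g. ".jpg"
    created = []
    for n in range(copies):
        dst = os.path.join(dst_dir, f"copy_{n:05d}{ext}")
        shutil.copy(src, dst)
        created.append(dst)
    return created

# demo on a throwaway file so the sketch runs anywhere
workdir = tempfile.mkdtemp()
src = os.path.join(workdir, "0.jpg")
with open(src, "wb") as f:
    f.write(b"placeholder bytes, not a real image")

copies = duplicate_file(src, os.path.join(workdir, "debug"), 5)
print(len(copies))
```

In the Colab, `src` would be something like `/content/4k/0.jpg` and `dst_dir` the training data folder.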