# Converting RRDB (pth to onnx)

Take a regular PyTorch model checkpoint (``.pth``) and convert it to ONNX (``.onnx``).

In [None]:
!nvidia-smi

# Model

In [None]:
#@title utils.py
from pathlib import Path

import cv2
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


def resize_like(x, target, mode='bilinear'):
    """Resample ``x`` so its last two dimensions match those of ``target``.

    Args:
        x: tensor to resample.
        target: tensor whose trailing two dimensions give the output size.
        mode: interpolation mode passed to ``F.interpolate``.

    Returns:
        The interpolated tensor (``align_corners`` is always ``False``).
    """
    wanted_size = target.shape[-2:]
    return F.interpolate(x, size=wanted_size, mode=mode, align_corners=False)


def list2nparray(lst, dtype=None):
    """Convert a (possibly nested) list/tuple into an ndarray using a pre-allocated buffer.

    Args:
        lst: an ndarray (returned unchanged), or a non-empty list/tuple whose
            items are ndarrays, scalars, or further nested lists/tuples.
        dtype: dtype of the result; when ``None`` it is taken from the first
            item (all ndarray items must then share that dtype).

    Returns:
        An ndarray whose first axis has ``len(lst)`` entries.
    """
    if isinstance(lst, np.ndarray):
        return lst
    assert isinstance(lst, (list, tuple)), 'bad type: {}'.format(type(lst))
    assert lst, 'attempt to convert empty list to np array'

    first = lst[0]
    if isinstance(first, np.ndarray):
        item_shape = first.shape
        assert all(arr.shape == item_shape for arr in lst)
        if dtype is None:
            dtype = first.dtype
            assert all(arr.dtype == dtype for arr in lst), \
                'bad dtype: {} {}'.format(dtype, set(arr.dtype for arr in lst))
    elif isinstance(first, (int, float, complex, np.number)):
        # Flat list of scalars: numpy handles this directly.
        return np.array(lst, dtype=dtype)
    else:
        # Nested container: convert the first item to learn shape and dtype.
        probe = list2nparray(first)
        if dtype is None:
            dtype = probe.dtype
        item_shape = probe.shape

    out = np.empty([len(lst), *item_shape], dtype=dtype)
    for position, item in enumerate(lst):
        out[position] = item
    return out


def get_img_list(path):
    """List the image files directly inside ``path``.

    Returns a list of ``Path`` objects: all ``*.png`` files (sorted), then all
    ``*.jpg`` files (sorted), then all ``*.jpeg`` files (sorted).
    """
    folder = Path(path)
    found = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        found.extend(sorted(folder.glob(pattern)))
    return found


def gen_miss(img, mask, output):
    """Write masked copies of images, painting the masked-out area white.

    Images from folder ``img`` and masks from folder ``mask`` are paired in
    the order returned by ``get_img_list``. Each mask is read as grayscale
    and resized to the image size; pixels whose mask value is > 127 keep the
    image value, all others become 255. Results are written to ``output``
    as ``miss_0001.png``, ``miss_0002.png``, ...

    Args:
        img: folder containing the source images.
        mask: folder containing the masks.
        output: destination folder (created if missing).
    """
    img_paths = get_img_list(img)
    mask_paths = get_img_list(mask)
    print('Total images:', len(img_paths), len(mask_paths))

    out_dir = Path(output)
    out_dir.mkdir(parents=True, exist_ok=True)

    for idx, (img_path, mask_path) in tqdm(enumerate(zip(img_paths, mask_paths))):
        target = out_dir.joinpath('miss_%04d.png' % (idx + 1))
        picture = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        matte = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        # cv2.resize takes (width, height), i.e. the reversed array shape.
        height, width = picture.shape[:2]
        matte = cv2.resize(matte, (width, height))
        # Trailing axis so the mask broadcasts over the colour channels.
        matte = matte[..., np.newaxis]
        miss = picture * (matte > 127) + 255 * (matte <= 127)
        cv2.imwrite(str(target), miss)

def merge_imgs(dirs, output, row=1, gap=2, res=512):
    """Tile corresponding images from several folders into one grid image.

    The k-th image of every folder in ``dirs`` is placed on a white canvas
    with ``row`` rows; the number of columns is derived from ``len(dirs)``.
    Every tile is resized to ``res`` x ``res`` when needed and tiles are
    separated (and surrounded) by ``gap`` pixels. Grids are written to
    ``output`` as ``merge_0000.png``, ``merge_0001.png``, ...

    Args:
        dirs: folders to read images from (via ``get_img_list``).
        output: destination folder (created if missing).
        row: number of grid rows.
        gap: spacing in pixels between tiles and around the border.
        res: side length of each square tile.
    """
    per_dir = [get_img_list(folder) for folder in dirs]
    counts = [len(files) for files in per_dir]
    print('Total images:', counts)
    assert min(counts) > 0, 'Please check the path of empty folder.'

    out_dir = Path(output)
    out_dir.mkdir(parents=True, exist_ok=True)

    column = (len(dirs) - 1) // row + 1
    print('Row:', row)
    print('Column:', column)

    canvas_h = res * row + (row + 1) * gap
    canvas_w = res * column + (column + 1) * gap
    for idx, group in tqdm(enumerate(zip(*per_dir))):
        target = out_dir.joinpath('merge_%04d.png' % idx)
        canvas = np.ones([canvas_h, canvas_w, 3], np.uint8) * 255
        for slot, img_path in enumerate(group):
            # Tiles fill the grid row by row.
            r, c = divmod(slot, column)
            tile = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
            if tile.shape[:2] != (res, res):
                tile = cv2.resize(tile, (res, res))
            top = (r + 1) * gap + r * res
            left = (c + 1) * gap + c * res
            canvas[top: top + res, left: left + res] = tile
        cv2.imwrite(str(target), canvas)


In [None]:
#@title block.py
from collections import OrderedDict

import torch
import torch.nn as nn
#from models.modules.architectures.convolutions.partialconv2d import PartialConv2d #TODO
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d
#from models.networks import weights_init_normal, weights_init_xavier, weights_init_kaiming, weights_init_orthogonal


####################
# Basic blocks
####################

# Swish activation function
def swish_func(x, beta=1.0):
    """Apply the Swish activation ``x * sigmoid(beta * x)`` element-wise.

    "Swish: a Self-Gated Activation Function" /
    Searching for Activation Functions (https://arxiv.org/abs/1710.05941)

    - beta = 1 gives the Sigmoid Linear Unit (SiLU).
    - beta = 0 gives the scaled linear function f(x) = x / 2.
    - As beta -> infinity the sigmoid gate approaches a unit step, so the
      function approaches ReLU, max(0, x).

    Swish can therefore be seen as a smooth interpolation between a linear
    function and ReLU; the degree of interpolation can be learned when
    ``beta`` is a trainable parameter.

    Alt: 1.78718727865 * (x * sigmoid(x) - 0.20662096414)

    Args:
        x: input tensor.
        beta: gate sharpness (number or tensor broadcastable with ``x``).

    Returns:
        A new tensor; ``x`` is not modified.
    """
    # Out-of-place on purpose: an in-place variant could save GPU memory
    # but is not used here.
    gate = torch.sigmoid(beta * x)
    return x * gate
    
# Swish module
class Swish(nn.Module):
    """Swish activation module with a learnable ``beta`` and a fixed output scale.

    The forward pass computes ``slope * x * sigmoid(beta * x)`` (the stored
    ``self.slope`` is ``slope / 2`` and is multiplied by 2 again in
    ``forward``).

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Args:
        beta: initial value of the trainable ``beta`` parameter.
        slope: user-defined, non-trainable output scale.
        inplace: stored for API compatibility only; the in-place path is
            disabled because it breaks autograd ("one of the variables needed
            for gradient computation has been modified by an inplace
            operation").
    """

    __constants__ = ['beta', 'slope', 'inplace']

    def __init__(self, beta=1.0, slope=1.67653251702, inplace=False):
        # Must be the two-argument form (or plain super()): `super(Swish)`
        # alone never runs nn.Module.__init__, after which assigning the
        # Parameter below raises AttributeError.
        super(Swish, self).__init__()
        self.inplace = inplace
        # Learnable beta. float() keeps the tensor floating point even when an
        # int is passed; nn.Parameter already has requires_grad=True.
        self.beta = torch.nn.Parameter(torch.tensor(float(beta)))

        self.slope = slope / 2  # user-defined "slope", non-trainable

    def forward(self, input):
        """Return ``2 * self.slope * swish_func(input, self.beta)``."""
        return 2 * self.slope * swish_func(input, self.beta)


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """Build an activation layer from its (case-insensitive) name.

    Args:
        act_type: one of 'relu', 'leakyrelu'/'lrelu', 'prelu', 'tanh',
            'sigmoid' or 'swish'.
        inplace: forwarded to ReLU, LeakyReLU and Swish.
        neg_slope: negative slope for LeakyReLU and initial value for PReLU.
        n_prelu: ``num_parameters`` for PReLU.
        beta: initial beta for Swish.

    Returns:
        The activation module.

    Raises:
        NotImplementedError: if the name is not recognised.
    """
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind in ('leakyrelu', 'lrelu'):
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    if kind == 'tanh':  # [-1, 1] range output
        return nn.Tanh()
    if kind == 'sigmoid':  # [0, 1] range output
        return nn.Sigmoid()
    if kind == 'swish':
        return Swish(beta=beta, inplace=inplace)
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))


class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its first argument unchanged.

    Extra positional arguments to the constructor and to ``forward`` are
    accepted and ignored.
    """

    def __init__(self, *kwargs):
        super().__init__()

    def forward(self, x, *kwargs):
        return x


def norm(norm_type, nc):
    """Return a normalization layer.

    Args:
        norm_type (str): name of the normalization layer, case-insensitive:
            'batch' | 'instance' | 'none'.
        nc (int): number of feature channels.

    Returns:
        nn.Module:
            - 'batch': ``nn.BatchNorm2d(nc, affine=True)``
            - 'instance': ``nn.InstanceNorm2d(nc, affine=False)``
            - 'none': a pass-through ``nn.Identity()`` module

    Raises:
        NotImplementedError: for any other name.
    """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        # Previously this branch defined an unused helper and left `layer`
        # unbound, so the final return raised UnboundLocalError.
        layer = nn.Identity()
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def add_spectral_norm(module, use_spectral_norm=False):
    """Optionally wrap ``module`` with ``nn.utils.spectral_norm``.

    Returns the spectral-normalised module when ``use_spectral_norm`` is
    truthy, otherwise the original module untouched.
    """
    return nn.utils.spectral_norm(module) if use_spectral_norm else module


def pad(pad_type, padding):
    """Build a 2D padding layer from its (case-insensitive) name.

    Args:
        pad_type: 'reflect', 'replicate' or 'zero' ('zero' padding can also
            be done by the conv layer itself).
        padding: padding size; ``0`` means no layer is needed.

    Returns:
        The padding module, or ``None`` when ``padding == 0``.

    Raises:
        NotImplementedError: for an unknown ``pad_type`` (non-zero padding).
    """
    kind = pad_type.lower()
    if padding == 0:
        return None

    layer_types = {
        'reflect': nn.ReflectionPad2d,
        'replicate': nn.ReplicationPad2d,
        'zero': nn.ZeroPad2d,
    }
    if kind not in layer_types:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))
    return layer_types[kind](padding)


def get_valid_padding(kernel_size, dilation):
    """Return the padding ``(k_eff - 1) // 2`` for a dilated kernel.

    ``k_eff`` is the effective kernel size after dilation:
    ``kernel_size + (kernel_size - 1) * (dilation - 1)``.
    """
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (effective_kernel - 1) // 2


class ConcatBlock(nn.Module):
    """Concatenate a submodule's output to its input along dimension 1."""

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        branch = self.sub(x)
        return torch.cat((x, branch), dim=1)

    def __repr__(self):
        inner = repr(self.sub).replace('\n', '\n|')
        return 'Identity .. \n|' + inner


class ShortcutBlock(nn.Module):
    """Residual wrapper: element-wise sum of a submodule's output and its input."""

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        residual = self.sub(x)
        return x + residual

    def __repr__(self):
        inner = repr(self.sub).replace('\n', '\n|')
        return 'Identity + \n|' + inner


def sequential(*args):
    """Build a flattened ``nn.Sequential`` from the given modules.

    - A single argument is returned as-is (no wrapper needed), except that
      an ``OrderedDict`` is rejected.
    - Otherwise nested ``nn.Sequential`` arguments are unwrapped into their
      children, other ``nn.Module`` arguments are kept, and anything else
      (e.g. ``None``) is dropped.

    Raises:
        NotImplementedError: if the only argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # No sequential is needed.

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', \
               spectral_norm=False):
    """Conv layer with padding, optional normalization and activation.

    mode:
        'CNA' / 'CNAC' --> Pad -> Conv -> Act. The normalization step is
            currently disabled for these modes, so ``norm_type`` is ignored.
        'NAC' --> Norm -> Act -> Pad -> Conv
            (Identity Mappings in Deep Residual Networks, ECCV16)

    Args:
        in_nc, out_nc: input / output channel counts.
        kernel_size, stride, dilation, groups, bias: forwarded to the conv.
        pad_type: 'zero' pads inside the conv; any other non-empty value
            builds a separate padding layer via ``pad``; falsy means none.
        norm_type: normalization name for ``norm`` (used in 'NAC' mode only).
        act_type: activation name for ``act``; falsy disables it.
        convtype: 'PartialConv2D', 'DeformConv2D', 'Conv3D'; anything else
            selects a standard ``nn.Conv2d``. The partial and deformable
            classes must already be in scope.
        spectral_norm: wrap the conv with ``nn.utils.spectral_norm``.

    Returns:
        The layers combined by ``sequential``.
    """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    valid_padding = get_valid_padding(kernel_size, dilation)
    if pad_type and pad_type != 'zero':
        p = pad(pad_type, valid_padding)
    else:
        p = None
    # Zero padding is handled by the convolution itself.
    conv_padding = valid_padding if pad_type == 'zero' else 0

    conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=conv_padding,
                       dilation=dilation, bias=bias, groups=groups)
    if convtype == 'PartialConv2D':
        c = PartialConv2d(in_nc, out_nc, **conv_kwargs)
    elif convtype == 'DeformConv2D':
        c = DeformConv2d(in_nc, out_nc, **conv_kwargs)
    elif convtype == 'Conv3D':
        c = nn.Conv3d(in_nc, out_nc, **conv_kwargs)
    else:  # default case is standard 'Conv2D'
        c = nn.Conv2d(in_nc, out_nc, **conv_kwargs)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        # Normalization intentionally left out here (was: norm(norm_type, out_nc)).
        n = None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            # Important: with
            #   input----ReLU(inplace)----Conv--+----output
            #          |________________________|
            # an in-place activation would modify the shortcut input and
            # give a wrong output, so build an out-of-place one.
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly built copies of the same block.

    Args:
        basic_block (nn.module): nn.module class (or factory) for the block.
        num_basic_block (int): number of blocks to create.
        **kwarg: keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: the stacked blocks.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)


class Mean(nn.Module):
    """Module wrapper around ``torch.mean`` over fixed dimensions.

    Args:
        dim: dimensions to reduce.
        keepdim: whether the reduced dimensions are retained with size 1.
    """

    def __init__(self, dim: list, keepdim=False):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x):
        return x.mean(dim=self.dim, keepdim=self.keepdim)


####################
# initialize modules
####################

@torch.no_grad()
def default_init_weights(module_list, init_type='kaiming', scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every submodule of every given module is passed to the matching
    ``weights_init_*`` function, which must be available in the enclosing
    scope.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module):
            Modules to be initialized.
        init_type (str): the type of initialization in: 'normal', 'xavier',
            'kaiming' or 'orthogonal'
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1. (for 'xavier' and 'kaiming')
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function:
            mean and/or std for 'normal'.
            a and/or mode for 'kaiming'
            gain for xavier
            (not forwarded for 'orthogonal')

    Raises:
        NotImplementedError: if ``init_type`` is unknown (raised when the
            first submodule is visited).
    """

    # TODO
    # logger.info('Initialization method [{:s}]'.format(init_type))
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            # One single if/elif chain: with 'normal' as a separate `if`,
            # it fell through to the final `else` and always raised.
            if init_type == 'normal':
                weights_init_normal(m, bias_fill=bias_fill, **kwargs)
            elif init_type == 'xavier':
                weights_init_xavier(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'kaiming':
                weights_init_kaiming(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'orthogonal':
                weights_init_orthogonal(m, bias_fill=bias_fill)
            else:
                raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))



####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

    Thin module wrapper around ``nn.functional.interpolate``, used instead of
    ``nn.Upsample`` to avoid the "nn.Upsample is deprecated" warning.
    See: https://discuss.pytorch.org/t/which-function-is-better-for-upsampling-upsampling-or-interpolate/21811/8
    Based on: https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample
    Alternative: https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588/2?u=ptrblck

    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.

    Args:
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
            output spatial sizes
        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
            multiplier for spatial size. Has to match input size if it is a
            tuple. Stored as float(s); a falsy value is stored as ``None``.
        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
            Default: ``'nearest'``
        align_corners (bool, optional): if ``True``, the corner pixels of the
            input and output tensors are aligned, and thus preserving the
            values at those pixels. This only has effect when :attr:`mode` is
            ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``None``
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super().__init__()
        if isinstance(scale_factor, tuple):
            normalised = tuple(float(value) for value in scale_factor)
        elif scale_factor:
            normalised = float(scale_factor)
        else:
            normalised = None
        self.scale_factor = normalised
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        if self.scale_factor is None:
            head = 'size=' + str(self.size)
        else:
            head = 'scale_factor=' + str(self.scale_factor)
        return head + ', mode=' + self.mode

def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                        pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle (sub-pixel convolution) upsampling layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    A conv without norm/activation expands the channels to
    ``out_nc * upscale_factor ** 2``, ``nn.PixelShuffle`` rearranges them,
    then the optional norm and activation layers follow.
    """
    expanded_nc = out_nc * (upscale_factor ** 2)
    stages = [
        conv_block(in_nc, expanded_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype),
        nn.PixelShuffle(upscale_factor),
    ]
    stages.append(norm(norm_type, out_nc) if norm_type else None)
    stages.append(act(act_type) if act_type else None)
    return sequential(*stages)

def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """
    Upsample-then-convolve layer described in https://distill.pub/2016/deconv-checkerboard/
    Example to replace deconvolutions:
        - from: nn.ConvTranspose2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1)
        - to: upconv_block(in_nc, out_nc,kernel_size=3, stride=1, act_type=None)

    For ``convtype == 'Conv3D'`` the scale factor becomes
    ``(1, upscale_factor, upscale_factor)``.
    """
    if convtype == 'Conv3D':
        scale = (1, upscale_factor, upscale_factor)
    else:
        scale = upscale_factor
    # Local Upsample module instead of nn.Upsample, to prevent the
    # "nn.Upsample is deprecated" warning.
    return sequential(
        Upsample(scale_factor=scale, mode=mode),
        conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype),
    )

# PPON
def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
    """Build a biased ``nn.Conv2d`` padded by ``int((kernel_size - 1) / 2) * dilation``."""
    half_kernel = int((kernel_size - 1) / 2)
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=half_kernel * dilation,
        bias=True,
        dilation=dilation,
        groups=groups,
    )




####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    """Add input-relative Gaussian noise during training.

    In training mode (and when ``sigma != 0``) the output is
    ``x + N(0, 1) * sigma * x``; otherwise the input is returned unchanged.

    Args:
        sigma: relative standard deviation of the noise.
        is_relative_detach: if True, the scale uses ``x.detach()``, so no
            gradient flows through the noise magnitude.
    """

    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Scalar template that fixes the noise dtype (float32). It is no
        # longer moved to CUDA here: that made construction fail on CPU-only
        # hosts, and the noise is now created on the input's device instead.
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            # Sample on x's device so the module works on CPU and GPU alike.
            standard_normal = torch.randn(x.size(), dtype=self.noise.dtype, device=x.device)
            sampled_noise = standard_normal * scale
            x = x + sampled_noise
        return x

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` from ``in_planes`` to ``out_planes`` channels."""
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise


# TODO: Not used:
# https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
class minibatch_std_concat_layer(nn.Module):
    """Append minibatch standard-deviation statistics as extra channels.

    The standard deviation of the input over the batch dimension is reduced
    according to ``averaging`` and concatenated to the input along dim 1.

    Args:
        averaging (str): case-insensitive mode, one of
            'all'     -- mean over channels; adds 1 channel
            'flat'    -- a single std over the whole tensor; adds 1 channel
            'spatial' -- (4D input) mean over dims 2 and 3; adds C channels
            'none'    -- the per-position std as-is; adds C channels
            'gpool'   -- (4D input) mean of the input over dims 0, 2, 3;
                         adds C channels
            'group<n>' -- e.g. 'group4'; see the note in ``forward``.
    """

    def __init__(self, averaging='all'):
        super(minibatch_std_concat_layer, self).__init__()
        self.averaging = averaging.lower()
        if 'group' in self.averaging:
            self.n = int(self.averaging[5:])
        else:
            # The message previously had no %s placeholder, so a bad mode
            # raised TypeError instead of this AssertionError.
            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], \
                'Invalid averaging mode: %s' % self.averaging

    def adjusted_std(self, x, **kwargs):
        """Standard deviation of ``x`` (with 1e-8 added under the square root).

        ``kwargs`` are forwarded to both ``torch.mean`` calls (e.g. ``dim``,
        ``keepdim``); with none given a 0-dim tensor is returned.
        """
        return torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)

    def forward(self, x):
        shape = list(x.size())
        target_shape = list(shape)  # independent copy; was copy.deepcopy without the import
        vals = self.adjusted_std(x, dim=0, keepdim=True)
        if self.averaging == 'all':
            target_shape[1] = 1
            vals = torch.mean(vals, dim=1, keepdim=True)
        elif self.averaging == 'spatial':
            if len(shape) == 4:
                vals = torch.mean(vals, dim=[2, 3], keepdim=True)
        elif self.averaging == 'none':
            pass  # keep the per-position std and the full target shape
        elif self.averaging == 'gpool':
            if len(shape) == 4:
                vals = torch.mean(x, dim=[0, 2, 3], keepdim=True)
        elif self.averaging == 'flat':
            target_shape[1] = 1
            # Stay on x's device and dtype (torch.FloatTensor([...]) forced CPU float32).
            vals = self.adjusted_std(x).reshape(1)
        else:                                                           # self.averaging == 'group'
            target_shape[1] = self.n
            # NOTE(review): the final view only succeeds when
            # (C // n) * H * W == n; the intended grouping semantics are
            # unclear from this code -- confirm before relying on this mode.
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
        vals = vals.expand(*target_shape)
        return torch.cat([x, vals], 1)


####################
# Useful blocks
####################

class SelfAttentionBlock(nn.Module):
    """
        Self attention block following
        'Self-Attention Generative Adversarial Networks' (https://arxiv.org/abs/1805.08318)
        with the Flexible Self Attention (FSA) option from
        Efficient Super Resolution For Large-Scale Images Using Attentional GAN (https://arxiv.org/pdf/1812.04821.pdf):
          the FSA layer wraps the SAGAN self attention layer with a
          max-pooling layer that shrinks the feature maps so large images
          fit in memory, then upsamples the result back to the input size.
        Used in Generator and Discriminator Networks.

        Args:
            in_dim: number of input feature maps (channels).
            max_pool: pool the input before attention and upsample afterwards.
            poolsize: kernel size and stride of the max pooling.
            spectral_norm: apply spectral normalization to the three 1x1 convs.
            ret_attention: also return the attention map from ``forward``.
    """

    def __init__(self, in_dim, max_pool=False, poolsize = 4, spectral_norm=False, ret_attention=False): #in_dim = in_feature_maps
        super().__init__()

        self.in_dim = in_dim
        self.max_pool = max_pool
        self.poolsize = poolsize
        self.ret_attention = ret_attention

        if self.max_pool:
            # Note: strided convolutions could be tried here instead of MaxPool2d.
            self.pooled = nn.MaxPool2d(kernel_size=self.poolsize, stride=self.poolsize)

        def projection(out_channels):
            # 1x1 Conv1d over the flattened positions, optionally spectral-normalised.
            conv = nn.Conv1d(in_channels=in_dim, out_channels=out_channels, kernel_size=1, padding=0)
            return add_spectral_norm(conv, use_spectral_norm=spectral_norm)

        self.conv_f = projection(in_dim // 8)  # query_conv
        self.conv_g = projection(in_dim // 8)  # key_conv
        self.conv_h = projection(in_dim)       # value_conv

        self.gamma = nn.Parameter(torch.zeros(1))  # Trainable interpolation parameter
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
            inputs :
                input : input feature maps (4D: batch x channels x two spatial dims)
            returns :
                out : gamma * self attention value + input feature
                attention (only if ret_attention): B X N X N, where N is the
                    product of the two spatial dims after optional pooling
        """
        # Downscale with max pooling when enabled.
        feats = self.pooled(input) if self.max_pool else input

        batch_size, channels, dim_a, dim_b = feats.size()
        n_positions = dim_a * dim_b
        flat = feats.view(batch_size, -1, n_positions)

        query = self.conv_f(flat)  # B x C//8 x N
        key = self.conv_g(flat)    # B x C//8 x N
        value = self.conv_h(flat)  # B x C x N

        energy = torch.bmm(query.permute(0, 2, 1), key)  # B x N x N
        attention = self.softmax(energy)                 # probabilities over the last dim

        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, dim_a, dim_b)

        if self.max_pool:
            # Upscale back to the original spatial size.
            restore = Upsample(size=(input.shape[2], input.shape[3]), mode='bicubic', align_corners=False)
            out = restore(out)

        out = self.gamma * out + input  # Add original input

        if self.ret_attention:
            return out, attention
        return out



In [None]:
#@title spectral_norm.py
"""
spectral_norm.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/spectral_norm.py
"""
'''
Copy from pytorch github repo
Spectral Normalization from https://arxiv.org/abs/1802.05957
'''
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter


class SpectralNorm(object):
    """Forward-pre-hook object implementing spectral normalization of one weight.

    After ``apply`` the module holds the raw weight as parameter
    ``<name>_orig``, the current normalized weight as buffer ``<name>`` and
    the power-iteration vector as buffer ``<name>_u``.
    """

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        """Store the hook configuration.

        Args:
            name: name of the weight attribute to normalize.
            n_power_iterations: power iterations per update; must be positive.
            dim: weight dimension moved to the front before flattening to 2D.
            eps: epsilon passed to ``normalize``.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        """Return ``(weight / sigma, u)`` for the module's ``<name>_orig`` weight.

        ``sigma`` is the spectral-norm estimate from the power iteration and
        ``u`` the updated left vector.
        """
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        # Flatten to a 2D matrix of shape (height, everything else).
        weight_mat = weight_mat.reshape(height, -1)
        # The power iteration itself is excluded from autograd.
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        # sigma is computed outside no_grad, from the (non-detached) weight.
        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u

    def remove(self, module):
        """Undo the reparameterization on ``module``.

        Deletes ``<name>``, ``<name>_u`` and ``<name>_orig`` and registers
        the current value of ``<name>`` as a regular parameter again.
        """
        weight = getattr(module, self.name)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight))

    def __call__(self, module, inputs):
        """Forward pre-hook: refresh the normalized weight before ``forward``.

        In training mode the weight and ``u`` are recomputed and stored on
        the module; otherwise the existing weight is detached in place and
        its ``requires_grad`` flag is set to match ``<name>_orig``.
        """
        if module.training:
            weight, u = self.compute_weight(module)
            setattr(module, self.name, weight)
            setattr(module, self.name + '_u', u)
        else:
            r_g = getattr(module, self.name + '_orig').requires_grad
            getattr(module, self.name).detach_().requires_grad_(r_g)

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install spectral normalization on ``module`` and return the hook object."""
        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        height = weight.size(dim)

        # Random unit vector with one entry per slice along `dim`.
        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # buffer, which will cause weight to be included in the state dict
        # and also supports nn.init due to shared storage.
        module.register_buffer(fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators
    (critics) in Generative Adversarial Networks (GANs) by rescaling the
    weight tensor with the spectral norm :math:`\sigma` of the weight matrix,
    estimated with the power iteration method. A weight tensor with more
    than 2 dimensions is reshaped to 2D for the power iteration. This is
    implemented via a hook that recomputes the spectral norm and rescales
    the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations used
            to calculate the spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        transposed_convs = (
            torch.nn.ConvTranspose1d,
            torch.nn.ConvTranspose2d,
            torch.nn.ConvTranspose3d,
        )
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Strip the spectral-norm hook registered for one parameter of a module.

    Args:
        module (nn.Module): module whose forward pre-hooks are searched
        name (str, optional): name of the weight parameter the hook was
            registered for

    Returns:
        The same ``module`` object, with the matching hook removed.

    Raises:
        ValueError: if no ``SpectralNorm`` pre-hook with this ``name`` is found.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    registered = module._forward_pre_hooks
    # First (key, hook) pair that is a SpectralNorm hook for `name`, if any.
    found = next(
        ((key, hook) for key, hook in registered.items()
         if isinstance(hook, SpectralNorm) and hook.name == name),
        None)
    if found is None:
        raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))

    key, hook = found
    # Call the hook's own remove() before dropping it from the registry.
    hook.remove(module)
    del registered[key]
    return module


In [None]:
#@title ESRGAN_arch.py
"""
RRDBNet_arch.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/RRDBNet_arch.py
"""
import math
import torch
import torch.nn as nn
#import torchvision
#from . import block as B
import functools
#from . import spectral_norm as SN


####################
# RRDBNet Generator (original architecture)
####################

class RRDBNet(nn.Module):
    """RRDB generator (original ESRGAN layout).

    The network is a single flattened ``nn.Sequential`` stored in
    ``self.model``: first conv -> shortcut around (``nb`` RRDB blocks +
    LR conv) -> upsampling block(s) -> two HR convs -> optional final
    activation.

    Args:
        in_nc (int): input channels of the first convolution.
        out_nc (int): output channels of the last convolution.
        nf (int): number of feature channels used between the first and
            last convolutions.
        nb (int): number of RRDB blocks in the trunk.
        nr (int): forwarded to every ``RRDB`` as ``nr``.
        gc (int): growth channels, forwarded to every ``RRDB``.
        upscale (int): upscaling factor. ``3`` uses one upsample block with
            factor 3; any other value uses ``int(log2(upscale))`` blocks
            with the upsample block's default factor.
        norm_type (str or None): forwarded to the RRDB blocks and LR conv.
        act_type (str): activation name forwarded to the sub-blocks.
        mode (str): conv block mode used for the LR conv (the RRDB blocks
            are always built with ``'CNA'``).
        upsample_mode (str): ``'upconv'`` or ``'pixelshuffle'``.
        convtype (str): convolution type forwarded to the sub-blocks.
        finalact (str or None): if set, name of an activation appended to
            the end of the model.
        gaussian_noise (bool): forwarded to every ``RRDB``.
        plus (bool): forwarded to every ``RRDB``.

    Raises:
        NotImplementedError: if ``upsample_mode`` is not recognised.
    """

    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', \
            finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log2(upscale))
        if upscale == 3:
            n_upscale = 1

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        # The caller's growth-channel count is forwarded (previously the
        # blocks were always built with gc=32 and the argument was ignored).
        # NOTE(review): the RRDB blocks are still built with mode='CNA'
        # regardless of `mode` -- confirm whether that is intended.
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=gc, stride=1, bias=1, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, \
            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            # A single block; it is unpacked into its children below.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        # Note: this option adds new parameters to the architecture, another option is to use "outm" in the forward
        outact = act(finalact) if finalact else None

        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x, outm=None):
        """Run the generator and optionally squash the output range.

        Args:
            x: input tensor fed to ``self.model``.
            outm (str or None): output mapping. ``'scaltanh'`` applies tanh
                and rescales to [0, 1]; ``'tanh'`` limits to [-1, 1];
                ``'sigmoid'`` limits to [0, 1]; ``'clamp'`` clamps to
                [0, 1]; anything else returns the raw model output.
        """
        x = self.model(x)

        if outm == 'scaltanh':
            # Idea from: https://github.com/goldhuang/SRGAN-PyTorch/blob/master/model.py
            return (torch.tanh(x) + 1.0) / 2.0
        elif outm == 'tanh':
            return torch.tanh(x)
        elif outm == 'sigmoid':
            return torch.sigmoid(x)
        elif outm == 'clamp':
            return torch.clamp(x, min=0.0, max=1.0)
        else:  # Default, no cap for the output
            return x

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains ``nr`` ResidualDenseBlock_5C modules and adds the chain's output,
    scaled by 0.2, back onto the block input. All other constructor
    arguments are forwarded unchanged to every ResidualDenseBlock_5C.
    '''

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(RRDB, self).__init__()

        def build_rdb():
            # Every dense block in this RRDB shares the same configuration.
            return ResidualDenseBlock_5C(
                nf, kernel_size, gc, stride, bias, pad_type,
                norm_type, act_type, mode, convtype,
                spectral_norm=spectral_norm,
                gaussian_noise=gaussian_noise, plus=plus)

        if nr == 3:
            # Individually named attributes are kept for backwards
            # compatibility with existing models.
            self.RDB1 = build_rdb()
            self.RDB2 = build_rdb()
            self.RDB3 = build_rdb()
        else:
            self.RDBs = nn.Sequential(*[build_rdb() for _ in range(nr)])

    def forward(self, x):
        if not hasattr(self, 'RDB1'):
            out = self.RDBs(x)
        else:
            out = self.RDB3(self.RDB2(self.RDB1(x)))
        return out * 0.2 + x

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}

    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        convtype (str): the type of convolution to use. Default: 'Conv2D'
        gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
            trainable parameters)
        plus (bool): enable the additional residual paths from ESRGAN+
            (adds trainable parameters)

    The remaining arguments are forwarded to ``conv_block`` for each of the
    five convolutions. The fifth convolution always uses a kernel size of 3
    and has no activation when ``mode == 'CNA'``.
    '''

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()

        # ESRGAN+ extras (both default to disabled)
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        # Settings shared by all five conv blocks.
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type,
                      mode=mode, convtype=convtype, spectral_norm=spectral_norm)

        # conv1..conv4: input grows by `gc` channels per step (dense connections).
        for idx in range(4):
            setattr(self, 'conv%d' % (idx + 1),
                    conv_block(nf + idx * gc, gc, kernel_size, stride,
                               act_type=act_type, **shared))

        last_act = None if mode == 'CNA' else act_type
        self.conv5 = conv_block(nf + 4 * gc, nf, 3, stride, act_type=last_act, **shared)

    def forward(self, x):
        use_plus = self.conv1x1 is not None

        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if use_plus:
            x2 = x2 + self.conv1x1(x)  # ESRGAN+ extra residual path
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if use_plus:
            x4 = x4 + x2  # ESRGAN+ extra residual path
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))

        if self.noise is not None:
            return self.noise(x5.mul(0.2) + x)
        return x5 * 0.2 + x


####################
# RRDBNet Generator (modified/"new" architecture)
####################


class MRRDBNet(nn.Module):
    """RRDB generator, modified ("new") architecture.

    Built from plain 3x3 ``nn.Conv2d`` layers and a trunk of ``nb`` RRDBM
    blocks. The forward pass always applies two nearest-neighbour x2
    upsampling steps.

    Args:
        in_nc (int): input channels of the first convolution.
        out_nc (int): output channels of the last convolution.
        nf (int): number of feature channels.
        nb (int): number of RRDBM blocks in the trunk.
        gc (int): growth channels passed to each RRDBM block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(MRRDBNet, self).__init__()

        def conv3x3(c_in, c_out):
            # 3x3 conv, stride 1, padding 1, with bias
            return nn.Conv2d(c_in, c_out, 3, 1, 1, bias=True)

        self.conv_first = conv3x3(in_nc, nf)
        self.RRDB_trunk = make_layer(functools.partial(RRDBM, nf=nf, gc=gc), nb)
        self.trunk_conv = conv3x3(nf, nf)
        #### upsampling
        self.upconv1 = conv3x3(nf, nf)
        self.upconv2 = conv3x3(nf, nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nc)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        # Long skip connection around the RRDB trunk.
        fea = shallow + self.trunk_conv(self.RRDB_trunk(shallow))

        # Two x2 nearest-neighbour upsampling stages, each followed by conv + LeakyReLU.
        for upconv in (self.upconv1, self.upconv2):
            fea = torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')
            fea = self.lrelu(upconv(fea))

        return self.conv_last(self.lrelu(self.HRconv(fea)))

class ResidualDenseBlock_5CM(nn.Module):
    '''
    Residual Dense Block

    Five 3x3 convolutions with dense connections: each convolution receives
    the block input concatenated with the outputs of all previous
    convolutions. The last convolution's output is scaled by 0.2 and added
    to the block input.

    Args:
        nf (int): channels of the block input/output.
        gc (int): growth channel, i.e. intermediate channels.
        bias (bool): whether the convolutions use a bias term.
    '''
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5CM, self).__init__()
        # conv1..conv4 each emit `gc` channels; their input widens by `gc` per step.
        for idx in range(4):
            setattr(self, 'conv%d' % (idx + 1),
                    nn.Conv2d(nf + idx * gc, gc, 3, 1, 1, bias=bias))
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], scale=0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        feats = [x, x1]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        # No activation after the final convolution.
        x5 = self.conv5(torch.cat(feats, 1))
        return x5 * 0.2 + x

class RRDBM(nn.Module):
    '''Residual in Residual Dense Block

    Three ResidualDenseBlock_5CM modules applied in sequence; the result is
    scaled by 0.2 and added to the block input.
    '''

    def __init__(self, nf, gc=32):
        super(RRDBM, self).__init__()
        for attr in ('RDB1', 'RDB2', 'RDB3'):
            setattr(self, attr, ResidualDenseBlock_5CM(nf, gc))

    def forward(self, x):
        out = self.RDB3(self.RDB2(self.RDB1(x)))
        return out * 0.2 + x


# Getting model

In [None]:
#@title
%cd /content/
#!wget https://b2.fromtheexchange.space/file/space-fromtheexchange-b2/colab/model-conversion/v1/RRDB_ESRGAN_x4.pth
#!gdown --id 19qzrIGWUW3eu5troIjHzhCIbTZONSRjX

#%cd /content/
!wget https://f002.backblazeb2.com/file/ESRGAN/_RELEASE/4x_cat_patch_325000_G.pth

# Converting
Output will be in ``/content/output.onnx``.

In [None]:
%cd /content/
# setting needed params, norm_type will be forced to None
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4, norm_type='null',
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
            finalact=None, gaussian_noise=True, plus=False, 
            nr=3)

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

dummy_input = Variable(torch.randn(1, 3, 64, 64)) # don't set it too high, will run out of RAM
model_path = '/content/4x_cat_patch_325000_G.pth' #@param
state_dict = torch.load(model_path)

model.load_state_dict(state_dict)
model = model.eval()

#torch.onnx.export(model, dummy_input, "output.onnx")
torch.onnx.export(model, dummy_input, "output.onnx", opset_version=11)

# Misc

# torch.jit.save

In [None]:
#@title torch.jit.save
model_path = '/content/4x_cat_patch_325000_G.pth' #@param
# The model is traced with a CPU tensor below, so map the checkpoint
# tensors to CPU on load; this also lets the cell run without CUDA.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
model = model.eval()

# Trace the model with a random example input and save the traced module.
traced_model = torch.jit.trace(model, torch.randn(1,3,60,60))
torch.jit.save(traced_model, 'traces.pt')

# torch.jit.script

Using ``torch.jit.script`` to generate the picture. ``torch.jit.script`` does not like certain parts of the code, so I just removed them to make ``torch.jit.script`` usable.

In [None]:
#@title block.py (removing if self.training and self.sigma != 0:)
from collections import OrderedDict

import torch
import torch.nn as nn
#from models.modules.architectures.convolutions.partialconv2d import PartialConv2d #TODO
#from models.modules.architectures.convolutions.deformconv2d import DeformConv2d
#from models.networks import weights_init_normal, weights_init_xavier, weights_init_kaiming, weights_init_orthogonal


####################
# Basic blocks
####################

# Swish activation funtion
def swish_func(x, beta=1.0):
    """
    "Swish: a Self-Gated Activation Function"
    Searching for Activation Functions (https://arxiv.org/abs/1710.05941)

    Computes ``x * sigmoid(beta * x)`` element-wise, out of place.

    - beta = 1: the Sigmoid Linear Unit (SiLU).
    - beta = 0: the sigmoid is constantly 0.5, so f(x) = x / 2.
    - beta -> inf: the sigmoid approaches a unit step, so the function
      approaches ReLU.

    With beta as a trainable parameter, Swish can loosely be viewed as a
    smooth interpolation between a linear function and ReLU whose degree is
    controlled by the model.

    Alt: 1.78718727865 * (x * sigmoid(x) - 0.20662096414)

    Args:
        x: input tensor.
        beta: scale applied to ``x`` inside the sigmoid (number or tensor).

    Returns:
        Tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(beta * x)
    return x * gate
    
# Swish module
class Swish(nn.Module):
    """Swish activation module with a learnable ``beta``.

    Computes ``2 * (slope / 2) * swish_func(input, beta)``.

    Shape:
    - Input: (N, *) where * means, any number of additional
      dimensions
    - Output: (N, *), same shape as the input

    Args:
        beta (float): initial value of the learnable beta parameter.
        slope (float): user-defined, non-trainable output scale; stored
            halved in ``self.slope``.
        inplace (bool): stored on the module but not used by ``forward``.
            The in-place variant is disabled because it caused
            "RuntimeError: one of the variables needed for gradient
            computation has been modified by an inplace operation".
    """

    __constants__ = ['beta', 'slope', 'inplace']

    def __init__(self, beta=1.0, slope=1.67653251702, inplace=False):
        # nn.Module must be initialised before a Parameter can be assigned;
        # `super(Swish).__init__()` (without `self`) did not do that.
        super(Swish, self).__init__()
        self.inplace = inplace
        # Learnable beta parameter, created as a tensor out of `beta`.
        # The attribute is spelled `requires_grad`; the former
        # `requiresGrad` assignment only created an unrelated attribute.
        self.beta = torch.nn.Parameter(torch.tensor(beta), requires_grad=True)

        self.slope = slope / 2  # user-defined "slope", non-trainable

    def forward(self, input):
        return 2 * self.slope * swish_func(input, self.beta)


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """Build an activation layer from its (case-insensitive) name.

    Args:
        act_type (str): one of 'relu', 'leakyrelu' / 'lrelu', 'prelu',
            'tanh', 'sigmoid', 'swish'.
        inplace (bool): passed to ReLU, LeakyReLU and Swish.
        neg_slope (float): negative slope for leakyrelu and init of prelu.
        n_prelu (int): num_parameters for prelu.
        beta (float): beta for swish.

    Returns:
        The constructed activation module.

    Raises:
        NotImplementedError: if the name is not recognised.
    """
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind in ('leakyrelu', 'lrelu'):
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    if kind == 'tanh':  # [-1, 1] range output
        return nn.Tanh()
    if kind == 'sigmoid':  # [0, 1] range output
        return nn.Sigmoid()
    if kind == 'swish':
        return Swish(beta=beta, inplace=inplace)
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))


class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its first argument unchanged.

    Extra positional arguments to the constructor and to ``forward`` are
    accepted and ignored.
    """

    def __init__(self, *kwargs):
        super().__init__()

    def forward(self, x, *kwargs):
        return x


def norm(norm_type, nc):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none
            (case-insensitive)
        nc (int)        -- number of channels passed to the normalization layer
    For BatchNorm, we use learnable affine parameters.
    For InstanceNorm, we do not use learnable affine parameters.
    For 'none', a pass-through identity module is returned.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the names above.
    """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if norm_type == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    if norm_type == 'none':
        # Previously this branch only defined an unused local function and
        # then hit `return layer` with `layer` unbound (UnboundLocalError).
        return nn.Identity()
    raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))


def add_spectral_norm(module, use_spectral_norm=False):
    """Optionally wrap a module with ``nn.utils.spectral_norm``.

    Returns the result of ``nn.utils.spectral_norm(module)`` when
    ``use_spectral_norm`` is truthy, otherwise the module as passed in.
    """
    return nn.utils.spectral_norm(module) if use_spectral_norm else module


def pad(pad_type, padding):
    """
    Build a 2D padding layer from its (case-insensitive) name.

    Supported names: 'reflect', 'replicate', 'zero' ('zero' padding can
    also be done by the conv layer itself).

    Returns None when ``padding`` is 0, otherwise the padding module.

    Raises:
        NotImplementedError: for an unknown name with non-zero padding.
    """
    kind = pad_type.lower()
    if padding == 0:
        return None

    layer_types = {
        'reflect': nn.ReflectionPad2d,
        'replicate': nn.ReplicationPad2d,
        'zero': nn.ZeroPad2d,
    }
    if kind not in layer_types:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))
    return layer_types[kind](padding)


def get_valid_padding(kernel_size, dilation):
    """Return ``(k_eff - 1) // 2`` for the dilated kernel extent ``k_eff``.

    ``k_eff = kernel_size + (kernel_size - 1) * (dilation - 1)``.
    """
    gaps = (kernel_size - 1) * (dilation - 1)  # holes inserted by dilation
    effective_kernel = kernel_size + gaps
    return (effective_kernel - 1) // 2


class ConcatBlock(nn.Module):
    """Concatenate the output of a submodule to its input along dim 1."""

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        branch = self.sub(x)
        return torch.cat((x, branch), dim=1)

    def __repr__(self):
        # Prefix every continuation line of the submodule's repr with '|'.
        sub_repr = repr(self.sub).replace('\n', '\n|')
        return 'Identity .. \n|' + sub_repr


class ShortcutBlock(nn.Module):
    """Elementwise-sum the output of a submodule with its input."""

    def __init__(self, submodule):
        super().__init__()
        self.sub = submodule

    def forward(self, x):
        return x + self.sub(x)

    def __repr__(self):
        # Prefix every continuation line of the submodule's repr with '|'.
        sub_repr = repr(self.sub).replace('\n', '\n|')
        return 'Identity + \n|' + sub_repr


def sequential(*args):
    """Build a flattened ``nn.Sequential`` from the given items.

    - A single argument is returned as-is (no wrapper is needed), except
      that an ``OrderedDict`` raises ``NotImplementedError``.
    - Otherwise, ``nn.Sequential`` arguments are unwrapped into their
      children, other ``nn.Module`` arguments are kept, and anything else
      (e.g. ``None``) is dropped.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # No sequential is needed.

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', \
               spectral_norm=False):
    """
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)

    In this version the normalization layer is always omitted for modes
    containing 'CNA' (``norm_type`` is ignored there); it is only built in
    'NAC' mode.

    'zero' padding is handled by the convolution itself; any other padding
    type is added as a separate layer in front of the convolution.
    """
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)

    valid_padding = get_valid_padding(kernel_size, dilation)
    # Explicit padding layer only for non-'zero' padding types.
    p = pad(pad_type, valid_padding) if pad_type and pad_type != 'zero' else None
    conv_kwargs = dict(
        kernel_size=kernel_size, stride=stride,
        padding=valid_padding if pad_type == 'zero' else 0,
        dilation=dilation, bias=bias, groups=groups)

    # The class is looked up lazily so that only the selected one has to exist.
    if convtype == 'PartialConv2D':
        conv_cls = PartialConv2d
    elif convtype == 'DeformConv2D':
        conv_cls = DeformConv2d
    elif convtype == 'Conv3D':
        conv_cls = nn.Conv3d
    else:  # default case is standard 'Conv2D'
        conv_cls = nn.Conv2d
    c = conv_cls(in_nc, out_nc, **conv_kwargs)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        # Normalization is disabled here: no norm layer is ever inserted.
        n = None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly built copies of the same block.

    Args:
        basic_block (nn.module): nn.module class (or factory) for the basic
            block; called once per layer with ``**kwarg``.
        num_basic_block (int): number of blocks. (n_layers)
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*[basic_block(**kwarg) for _ in range(num_basic_block)])


class Mean(nn.Module):
    """Module wrapper around ``torch.mean`` with fixed ``dim`` / ``keepdim``."""

    def __init__(self, dim: list, keepdim=False):
        super().__init__()
        self.dim, self.keepdim = dim, keepdim

    def forward(self, x):
        reduced = torch.mean(x, self.dim, self.keepdim)
        return reduced


####################
# initialize modules
####################

@torch.no_grad()
def default_init_weights(module_list, init_type='kaiming', scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every submodule (``module.modules()``) of every entry in ``module_list``
    is passed to the initializer selected by ``init_type``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
            A non-list value is wrapped in a single-element list.
        init_type (str): the type of initialization in: 'normal', 'xavier',
            'kaiming' or 'orthogonal'
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1. (passed for 'xavier' and 'kaiming')
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function:
            mean and/or std for 'normal'.
            a and/or mode for 'kaiming'
            gain for 'orthogonal' and xavier
            (not forwarded to the 'orthogonal' initializer)

    Raises:
        NotImplementedError: if ``init_type`` is not recognised (raised when
            the first submodule is visited).
    """

    # TODO
    # logger.info('Initialization method [{:s}]'.format(init_type))
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            # A single if/elif chain: 'normal' used to be tested with a
            # separate `if`, so it fell through to the final `else` and
            # raised NotImplementedError after initializing.
            if init_type == 'normal':
                weights_init_normal(m, bias_fill=bias_fill, **kwargs)
            elif init_type == 'xavier':
                weights_init_xavier(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'kaiming':
                weights_init_kaiming(m, scale=scale, bias_fill=bias_fill, **kwargs)
            elif init_type == 'orthogonal':
                weights_init_orthogonal(m, bias_fill=bias_fill)
            else:
                raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))



####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

    Thin module wrapper around ``nn.functional.interpolate`` so that
    interpolation can be used inside ``nn.Sequential`` without the
    "nn.Upsample is deprecated" warning.

    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.

    Args:
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
            output spatial sizes
        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
            multiplier for spatial size. Has to match input size if it is a tuple.
            Stored as float(s); a falsy scalar is stored as ``None``.
        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
            Default: ``'nearest'``
        align_corners (bool, optional): if ``True``, the corner pixels of the input
            and output tensors are aligned, and thus preserving the values at
            those pixels. This only has effect when :attr:`mode` is
            ``'linear'``, ``'bilinear'``, or ``'trilinear'``. Default: ``None``
    """
    # References:
    # https://discuss.pytorch.org/t/which-function-is-better-for-upsampling-upsampling-or-interpolate/21811/8
    # https://pytorch.org/docs/stable/_modules/torch/nn/modules/upsampling.html#Upsample
    # https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588/2?u=ptrblck

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super().__init__()
        if isinstance(scale_factor, tuple):
            normalized = tuple(float(factor) for factor in scale_factor)
        elif scale_factor:
            normalized = float(scale_factor)
        else:
            normalized = None
        self.scale_factor = normalized
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners)

    def extra_repr(self):
        if self.scale_factor is None:
            head = 'size=' + str(self.size)
        else:
            head = 'scale_factor=' + str(self.scale_factor)
        return head + ', mode=' + self.mode

def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                        pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    Sequence: conv producing ``out_nc * upscale_factor**2`` channels (no
    norm, no activation) -> ``nn.PixelShuffle(upscale_factor)`` -> optional
    norm -> optional activation.
    """
    expanded_nc = out_nc * (upscale_factor ** 2)
    conv = conv_block(in_nc, expanded_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
    shuffle = nn.PixelShuffle(upscale_factor)

    # Norm and activation are applied after the shuffle, on `out_nc` channels.
    norm_layer = norm(norm_type, out_nc) if norm_type else None
    act_layer = act(act_type) if act_type else None
    return sequential(conv, shuffle, norm_layer, act_layer)

def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """
    Upconv layer described in https://distill.pub/2016/deconv-checkerboard/
    Example to replace deconvolutions:
        - from: nn.ConvTranspose2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1)
        - to: upconv_block(in_nc, out_nc,kernel_size=3, stride=1, act_type=None)

    Sequence: ``Upsample`` (interpolation ``mode``) -> ``conv_block``.
    For ``convtype == 'Conv3D'`` the scale factor becomes
    ``(1, upscale_factor, upscale_factor)``.
    """
    if convtype == 'Conv3D':
        scale = (1, upscale_factor, upscale_factor)
    else:
        scale = upscale_factor
    # The local Upsample module is used instead of nn.Upsample to prevent
    # the "nn.Upsample is deprecated" warning.
    resize = Upsample(scale_factor=scale, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type,
                      convtype=convtype)
    return sequential(resize, conv)

# PPON
def conv_layer(in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1):
    """Build a biased ``nn.Conv2d`` with padding ``int((kernel_size - 1) / 2) * dilation``."""
    half_kernel = int((kernel_size - 1) / 2)
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, stride,
        padding=half_kernel * dilation,
        bias=True, dilation=dilation, groups=groups)




####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    """ESRGAN+ gaussian noise layer, with the noise injection disabled.

    In this variant ``forward`` returns its input unchanged; the
    training-time noise injection was removed so that the model can be
    processed by ``torch.jit.script``. The constructor arguments are still
    stored on the module.

    Args:
        sigma (float): stored as ``self.sigma`` (relative noise scale used
            by the disabled code).
        is_relative_detach (bool): stored as ``self.is_relative_detach``.
    """

    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Use the GPU when one is available, but fall back to the CPU so
        # that the module can also be constructed on CPU-only machines
        # (an unconditional 'cuda' device fails there).
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.noise = torch.tensor(0, dtype=torch.float).to(device)

    def forward(self, x):
        # Disabled noise injection, kept for reference:
        #   if self.training and self.sigma != 0:
        #       scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
        #       sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
        #       x = x + sampled_noise
        return x

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``."""
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return layer


# TODO: Not used:
# https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
class minibatch_std_concat_layer(nn.Module):
    """Concatenate a minibatch statistic to the input as extra channels.

    ``forward`` computes a statistic of ``x`` (selected by ``averaging``),
    expands it to the input's batch/spatial size and concatenates it to
    ``x`` along dim 1.

    Args:
        averaging (str): case-insensitive mode name.
            - 'all': std over the batch, averaged over channels (1 channel).
            - 'spatial': std over the batch, averaged over dims 2 and 3 for
              4D input.
            - 'none': std over the batch, unchanged.
            - 'gpool': for 4D input, mean of ``x`` over dims 0, 2 and 3.
            - 'flat': std over all elements of ``x`` (1 channel).
            - 'group<N>': ``N`` output channels (see note in ``forward``).

    Raises:
        AssertionError: if ``averaging`` is not one of the modes above.
    """

    def __init__(self, averaging='all'):
        super(minibatch_std_concat_layer, self).__init__()
        self.averaging = averaging.lower()
        if 'group' in self.averaging:
            self.n = int(self.averaging[5:])
        else:
            # The message now has a placeholder; previously the `%` operator
            # itself raised TypeError instead of reporting the mode.
            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], \
                'Invalid averaging mode %s' % self.averaging

    @staticmethod
    def adjusted_std(x, **kwargs):
        """Standard deviation of ``x`` (epsilon 1e-8 added under the sqrt).

        ``kwargs`` are passed to both ``torch.mean`` calls (e.g. ``dim``,
        ``keepdim``).
        """
        return torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)

    def forward(self, x):
        shape = list(x.size())
        # Shape the statistic is expanded to; a list of ints, so a shallow
        # copy is enough (the `copy` module is not needed).
        target_shape = list(shape)
        vals = self.adjusted_std(x, dim=0, keepdim=True)
        if self.averaging == 'all':
            target_shape[1] = 1
            vals = torch.mean(vals, dim=1, keepdim=True)
        elif self.averaging == 'spatial':
            if len(shape) == 4:
                vals = torch.mean(vals, dim=[2, 3], keepdim=True)
        elif self.averaging == 'none':
            pass  # per-element batch std, expanded over the batch below
        elif self.averaging == 'gpool':
            if len(shape) == 4:
                vals = torch.mean(x, dim=[0, 2, 3], keepdim=True)
        elif self.averaging == 'flat':
            target_shape[1] = 1
            # Single scalar statistic, kept on x's device (torch.FloatTensor
            # always allocated on the CPU); detached like the original
            # scalar-to-tensor conversion.
            vals = self.adjusted_std(x).detach().reshape(1)
        else:                                                           # self.averaging == 'group'
            target_shape[1] = self.n
            # NOTE(review): the view after the mean only succeeds when
            # (channels // n) * height * width == n; the reduction is kept
            # as written upstream -- confirm the intended grouping.
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
        vals = vals.expand(*target_shape)
        return torch.cat([x, vals], 1)


####################
# Useful blocks
####################

class SelfAttentionBlock(nn.Module):
    """
        Self-attention over the spatial positions of a feature map.

        Based on the attention layer of 'Self-Attention Generative
        Adversarial Networks' (https://arxiv.org/abs/1805.08318).

        With max_pool=True it behaves like the Flexible Self Attention (FSA)
        layer of 'Efficient Super Resolution For Large-Scale Images Using
        Attentional GAN' (https://arxiv.org/pdf/1812.04821.pdf): the input is
        max-pooled before attention and the result is upsampled back to the
        input's spatial size, which reduces the size of the attention map.

        Args:
            in_dim: number of input feature maps (channels).
            max_pool: pool the input before computing attention.
            poolsize: kernel size and stride of the max-pool.
            spectral_norm: forwarded to add_spectral_norm for the projections.
            ret_attention: also return the attention map from forward().
    """

    def __init__(self, in_dim, max_pool=False, poolsize = 4, spectral_norm=False, ret_attention=False):
        super().__init__()

        self.in_dim = in_dim
        self.max_pool = max_pool
        self.poolsize = poolsize
        self.ret_attention = ret_attention

        if self.max_pool:
            # Downscaling step of the FSA variant. A strided convolution could
            # be tried here in place of max-pooling.
            self.pooled = nn.MaxPool2d(kernel_size=self.poolsize, stride=self.poolsize)

        # 1x1 projections; query and key use a reduced channel count.
        reduced_dim = in_dim // 8
        self.conv_f = add_spectral_norm(
            nn.Conv1d(in_dim, reduced_dim, kernel_size=1, padding=0),
            use_spectral_norm=spectral_norm)  # query
        self.conv_g = add_spectral_norm(
            nn.Conv1d(in_dim, reduced_dim, kernel_size=1, padding=0),
            use_spectral_norm=spectral_norm)  # key
        self.conv_h = add_spectral_norm(
            nn.Conv1d(in_dim, in_dim, kernel_size=1, padding=0),
            use_spectral_norm=spectral_norm)  # value

        # Trainable weight of the attention branch; starts at zero.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,input):
        """
            inputs :
                input : input feature maps (B X C X W X H)
            returns :
                out : gamma * self attention value + input feature
                attention : B X N X N, N being the number of (pooled)
                    spatial positions; only returned when ret_attention is set
        """
        feats = self.pooled(input) if self.max_pool else input

        batch, channels, width, height = feats.size()
        positions = width * height

        # Flatten the two spatial dims so the Conv1d projections apply.
        flat = feats.view(batch, -1, positions)
        query = self.conv_f(flat)
        key = self.conv_g(flat)
        value = self.conv_h(flat)

        energy = torch.bmm(query.permute(0, 2, 1), key)
        attention = self.softmax(energy)  # B X N X N

        attended = torch.bmm(value, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, width, height)

        if self.max_pool:
            # Bring the pooled result back to the original spatial size.
            resize = Upsample(size=(input.shape[2], input.shape[3]), mode='bicubic', align_corners=False)
            attended = resize(attended)

        result = self.gamma * attended + input

        if self.ret_attention:
            return result, attention
        return result



In [None]:
#@title ESRGAN_arch.py (removing if statements)
"""
RRDBNet_arch.py (12-2-20)
https://github.com/victorca25/BasicSR/blob/master/codes/models/modules/architectures/RRDBNet_arch.py
"""
import math
import torch
import torch.nn as nn
#import torchvision
#from . import block as B
import functools
#from . import spectral_norm as SN


####################
# RRDBNet Generator (original architecture)
####################

class RRDBNet(nn.Module):
    """RRDB-based generator (original ESRGAN architecture).

    The network is assembled as one sequential model: a first feature
    convolution, a shortcut-wrapped trunk of ``nb`` RRDB blocks followed by a
    convolution, the upsampling stage(s), two final convolutions and an
    optional final activation.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nf: number of feature channels used throughout the network.
        nb: number of RRDB blocks in the trunk.
        nr: forwarded to each RRDB (number of dense blocks per RRDB).
        gc: growth channels, forwarded to each RRDB.
        upscale: upscaling factor. ``3`` builds a single upsample block that
            is passed the value 3; any other value builds
            ``int(log2(upscale))`` upsample blocks.
        norm_type, act_type, convtype: forwarded to the building blocks.
        mode: forwarded to the convolution that closes the trunk (the RRDB
            blocks themselves are always built with ``mode='CNA'``).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.
        finalact: name of an optional final activation; falsy disables it.
        gaussian_noise, plus: ESRGAN+ options forwarded to each RRDB.

    Raises:
        NotImplementedError: if ``upsample_mode`` is not a known mode.
    """

    def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', \
            finalact=None, gaussian_noise=False, plus=False):
        super(RRDBNet, self).__init__()
        # log2 is exact for powers of two, unlike log(x, 2).
        n_upscale = int(math.log2(upscale))
        if upscale == 3:
            n_upscale = 1

        fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
        # Honour the caller's growth-channel setting (previously fixed at 32
        # regardless of the ``gc`` argument; the default is unchanged).
        rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=gc, stride=1, bias=1, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, \
            gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
        LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)

        if upsample_mode == 'upconv':
            upsample_block = upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
        HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
        HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)

        # Note: this option adds new parameters to the architecture.
        outact = act(finalact) if finalact else None

        self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1, outact)

    def forward(self, x):
        # The output is returned uncapped. Output-range options (scaled tanh,
        # tanh, sigmoid, clamp) are deliberately left out of this forward so
        # that it contains no conditional branches.
        return self.model(x)

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains ``nr`` ResidualDenseBlock_5C modules and adds the input back to
    their output scaled by 0.2. All other arguments are passed through to
    every ResidualDenseBlock_5C.
    '''

    def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super().__init__()

        def build_rdb():
            # Every dense block of this RRDB shares the same configuration.
            return ResidualDenseBlock_5C(
                nf=nf, kernel_size=kernel_size, gc=gc, stride=stride,
                bias=bias, pad_type=pad_type, norm_type=norm_type,
                act_type=act_type, mode=mode, convtype=convtype,
                spectral_norm=spectral_norm, gaussian_noise=gaussian_noise,
                plus=plus)

        if nr == 3:
            # Named attributes keep parameter names compatible with
            # existing models.
            self.RDB1 = build_rdb()
            self.RDB2 = build_rdb()
            self.RDB3 = build_rdb()
        else:
            self.RDBs = nn.Sequential(*[build_rdb() for _ in range(nr)])

    def forward(self, x):
        if hasattr(self, 'RDB1'):
            dense = self.RDB3(self.RDB2(self.RDB1(x)))
        else:
            dense = self.RDBs(x)
        return dense * 0.2 + x

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block built from 5 convolutions.
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}

    Each convolution receives the block input concatenated with the outputs
    of all previous convolutions; the last output is scaled by 0.2 and added
    to the block input.

    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        convtype (str): the type of convolution to use. Default: 'Conv2D'
        gaussian_noise (bool): create the ESRGAN+ gaussian noise module
            (no new trainable parameters)
        plus (bool): create the 1x1 convolution for the additional residual
            paths from ESRGAN+ (adds trainable parameters)

    Note: in this version of forward() the ESRGAN+ branches are disabled, so
    ``self.noise`` and ``self.conv1x1`` are created but not applied.
    '''

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', \
            spectral_norm=False, gaussian_noise=False, plus=False):
        super().__init__()

        # ESRGAN+ extras
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        # Options common to all five convolutions.
        shared = dict(bias=bias, pad_type=pad_type, norm_type=norm_type,
                      mode=mode, convtype=convtype, spectral_norm=spectral_norm)

        self.conv1 = conv_block(nf, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv2 = conv_block(nf + gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv3 = conv_block(nf + 2 * gc, gc, kernel_size, stride, act_type=act_type, **shared)
        self.conv4 = conv_block(nf + 3 * gc, gc, kernel_size, stride, act_type=act_type, **shared)

        # In 'CNA' mode the closing convolution gets no activation.
        last_act = None if mode == 'CNA' else act_type
        self.conv5 = conv_block(nf + 4 * gc, nf, 3, stride, act_type=last_act, **shared)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(torch.cat([x, feat1], dim=1))
        # (ESRGAN+ conv1x1 residual on feat2 is disabled here.)
        feat3 = self.conv3(torch.cat([x, feat1, feat2], dim=1))
        feat4 = self.conv4(torch.cat([x, feat1, feat2, feat3], dim=1))
        # (ESRGAN+ residual feat4 + feat2 is disabled here.)
        fused = self.conv5(torch.cat([x, feat1, feat2, feat3, feat4], dim=1))
        # (ESRGAN+ gaussian noise on the output is disabled here.)
        return fused * 0.2 + x


####################
# RRDBNet Generator (modified/"new" architecture)
####################


class MRRDBNet(nn.Module):
    """RRDB generator in the modified ("new") layout.

    A first convolution feeds a trunk of ``nb`` RRDBM blocks whose output is
    added back to the first features. Two stages of nearest-neighbour x2
    interpolation, each followed by a convolution and LeakyReLU, then
    upsample the result before the two final convolutions.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nf: number of feature channels.
        nb: number of RRDBM blocks in the trunk.
        gc: growth channels of the dense blocks.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super().__init__()
        block_factory = functools.partial(RRDBM, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.RRDB_trunk = make_layer(block_factory, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        # Convolutions applied after each x2 interpolation.
        self.upconv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, kernel_size=3, stride=1, padding=1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        deep = self.trunk_conv(self.RRDB_trunk(shallow))
        feat = shallow + deep

        # First x2 stage.
        enlarged = torch.nn.functional.interpolate(feat, scale_factor=2, mode='nearest')
        feat = self.lrelu(self.upconv1(enlarged))
        # Second x2 stage.
        enlarged = torch.nn.functional.interpolate(feat, scale_factor=2, mode='nearest')
        feat = self.lrelu(self.upconv2(enlarged))

        return self.conv_last(self.lrelu(self.HRconv(feat)))

class ResidualDenseBlock_5CM(nn.Module):
    '''
    Residual Dense Block

    Five 3x3 convolutions, each fed the block input concatenated with all
    earlier outputs. The first four are followed by LeakyReLU; the fifth
    output is scaled by 0.2 and added to the input.

    Args:
        nf: number of input/output feature channels.
        gc: growth channel, i.e. intermediate channels.
        bias: whether the convolutions use a bias term.
    '''
    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(convs, scale=0.1)

    def forward(self, x):
        feat1 = self.lrelu(self.conv1(x))
        feat2 = self.lrelu(self.conv2(torch.cat([x, feat1], dim=1)))
        feat3 = self.lrelu(self.conv3(torch.cat([x, feat1, feat2], dim=1)))
        feat4 = self.lrelu(self.conv4(torch.cat([x, feat1, feat2, feat3], dim=1)))
        fused = self.conv5(torch.cat([x, feat1, feat2, feat3, feat4], dim=1))
        return fused * 0.2 + x

class RRDBM(nn.Module):
    '''Residual in Residual Dense Block

    Runs three ResidualDenseBlock_5CM modules in sequence and adds the input
    back to their output scaled by 0.2.
    '''

    def __init__(self, nf, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5CM(nf=nf, gc=gc)
        self.RDB2 = ResidualDenseBlock_5CM(nf=nf, gc=gc)
        self.RDB3 = ResidualDenseBlock_5CM(nf=nf, gc=gc)

    def forward(self, x):
        dense = self.RDB3(self.RDB2(self.RDB1(x)))
        return dense * 0.2 + x


In [None]:
# Load a pretrained RRDB generator, run it on one image and save the result.
model_path = '/content/4x_cat_patch_325000_G.pth' #@param
# map_location='cpu' lets checkpoints saved from a GPU session load on a
# CPU-only runtime; the model below is built on the CPU anyway.
state_dict = torch.load(model_path, map_location='cpu')
model = torch.jit.script(RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4, norm_type='null',
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
            finalact=None, gaussian_noise=False, plus=False, 
            nr=3))

image_path = "/content/image.png" #@param
image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals an unreadable/missing file by returning None.
    raise FileNotFoundError('could not read image: {}'.format(image_path))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# HWC uint8 -> 1 x C x H x W float in [0, 1].
image = torch.from_numpy(image).unsqueeze(0).permute(0,3,1,2).float()/255

# Uncomment to inspect the parameter names stored in the checkpoint:
#for k, v in state_dict.items():
#  print(k)

model.load_state_dict(state_dict)


model = model.eval()

# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    #out = model(torch.randn(1, 3, 64, 64))
    out = model(image)
print(out.shape)

from torchvision.utils import save_image
save_image(out, "out.png")