In [8]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline


# (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime

In this tutorial, we describe how to convert a model defined
in PyTorch into the ONNX format and then run it with ONNX Runtime.

ONNX Runtime is a performance-focused engine for ONNX models,
which inferences efficiently across multiple platforms and hardware
(Windows, Linux, and Mac and on both CPUs and GPUs).
ONNX Runtime has proved to considerably increase performance over
multiple models as explained [here](https://cloudblogs.microsoft.com/opensource/2019/05/22/onnx-runtime-machine-learning-inferencing-0-4-release)_

For this tutorial, you will need to install [ONNX](https://github.com/onnx/onnx)_
and [ONNX Runtime](https://github.com/microsoft/onnxruntime)_.
You can get binary builds of ONNX and ONNX Runtime with
``pip install onnx onnxruntime``.
Note that ONNX Runtime is compatible with Python versions 3.5 to 3.7.

``NOTE``: This tutorial needs PyTorch master branch which can be installed by following
the instructions [here](https://github.com/pytorch/pytorch#from-source)_


In [9]:
# Some standard imports
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

Super-resolution is a way of increasing the resolution of images, videos
and is widely used in image processing or video editing. For this
tutorial, we will use a small super-resolution model.

First, let's create a ``SuperResolution`` model in PyTorch.
This model uses the efficient sub-pixel convolution layer described in
["Real-Time Single Image and Video Super-Resolution Using an Efficient
Sub-Pixel Convolutional Neural Network" - Shi et al](https://arxiv.org/abs/1609.05158)_
for increasing the resolution of an image by an upscale factor.
The model expects the Y component of the ``YCbCr`` of an image as an input, and
outputs the upscaled Y component in super resolution.

[The
model](https://github.com/pytorch/examples/blob/master/super_resolution/model.py)_
comes directly from PyTorch's examples without modification:




In [10]:
# SWIN MODEL model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from itertools import repeat
import collections.abc


# -----------------------------------------------------------------------------------
# Code borrowed from: pytorch-image-models, 
# https://github.com/rwightman/pytorch-image-models 
# Originally Written by Ross Wightman.
# -----------------------------------------------------------------------------------

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    '''
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    '''
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

class DropPath(nn.Module):
    '''
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    '''
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

to_2tuple = _ntuple(2)

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    '''
    Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    '''
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

# -----------------------------------------------------------------------------------
# Code borrowed from: SwinIR: Image Restoration Using Swin Transformer, 
# https://arxiv.org/abs/2108.10257, https://github.com/JingyunLiang/SwinIR
# Originally Written by Jingyun Liang, Modified by Ricard Lado.
# -----------------------------------------------------------------------------------

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    '''
    Args:
        x: (B, H, W, C)H
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    '''
    B, H, W, C = x.shape
    #   Debug
    #print("B, H, W, C")
    #print(window_size)
    #print(x.shape)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    '''
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    '''
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    ''' 
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    '''

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # indexing='ij'  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index', relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        '''
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        '''
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    ''' 
    Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    '''

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer('attn_mask', attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, 'input feature has wrong size'

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
               f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}'

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    ''' 
    Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    '''

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        '''
        x: B, H*W, C
        '''
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.'

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f'input_resolution={self.input_resolution}, dim={self.dim}'

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    ''' 
    A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    '''

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class RSTB(nn.Module):
    '''
    Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    '''

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops


class PatchEmbed(nn.Module):
    ''' 
    Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    '''

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops


class PatchUnEmbed(nn.Module):
    '''
    Image to Patch Unembedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    '''

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B Ph*Pw C
        return x

    def flops(self):
        flops = 0
        return flops

###
#LB-VMM
###

def _make_layer(block, in_planes, out_planes, num_layers, kernel_size=3, stride=1): 
    layers = []
    for i in range(num_layers):
        layers.append(block(in_planes, out_planes, kernel_size, stride))
    return nn.Sequential(*layers)

class ResBlock(nn.Module):
    def __init__(self, in_planes, output_planes, kernel_size=3, stride=1):
        super(ResBlock, self).__init__()
        p = (kernel_size-1)//2
        self.pad1 = nn.ReflectionPad2d(p)
        self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size,
             stride=stride, bias=False)
        self.pad2 = nn.ReflectionPad2d(p)
        self.conv2 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size,
             stride=stride, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.relu(self.conv1(self.pad1(x)))
        y = self.conv2(self.pad2(y))
        return y + x

class ConvBlock(nn.Module):
    def __init__(self, in_planes, output_planes, kernel_size=7, stride=1):
        super(ConvBlock, self).__init__()
        p=3
        self.pad1 = nn.ReflectionPad2d(p)
        self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size,
             stride=stride, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv1(self.pad1(x)))

class ConvBlockAfter(nn.Module):
    def __init__(self, in_planes, output_planes, kernel_size=3, stride=1):
        super(ConvBlockAfter, self).__init__()
        p=1
        self.pad1 = nn.ReflectionPad2d(p)
        self.conv1 = nn.Conv2d(in_planes, output_planes, kernel_size=kernel_size,
             stride=stride, bias=False)

    def forward(self, x):
        return self.conv1(self.pad1(x))

class Manipulator(nn.Module):
    def __init__(self, num_resblk, embed_dim):
        super(Manipulator, self).__init__()
        self.convblks = _make_layer(ConvBlock, embed_dim, embed_dim, 1, kernel_size=7, stride=1)
        self.convblks_after = _make_layer(ConvBlockAfter, embed_dim, embed_dim, 1, kernel_size=3, stride=1)
        self.resblks = _make_layer(ResBlock, embed_dim, embed_dim, num_resblk, kernel_size=3, stride=1)
        
    def forward(self, x_a, x_b, amp):
        diff = x_b - x_a
        diff = self.convblks(diff)
        diff = (amp - 1.0) * diff
        diff = self.convblks_after(diff)
        diff = self.resblks(diff)      

        return x_b + diff

### Model
class STBVMM(nn.Module):
    '''
    STBVMM model class

    Args:
        img_size (int | tuple(int)): Input image size. Default 384
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 192
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        img_range (int): Image range. 1. or 255.
        resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
        manipulator_num_resblk (int): Number of residual blocks of the maipulator. Default: 1
    '''

    def __init__(self, img_size=384, patch_size=1, in_chans=3,
                 embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
                 window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, img_range=1., resi_connection='1conv',
                 manipulator_num_resblk = 1,
                 **kwargs):
                 
        super(STBVMM, self).__init__()
        img_size = img_size // 8
        num_in_ch = in_chans
        num_out_ch = in_chans
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.window_size = window_size

        #################### Shallow Feature Extraction ########################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 8, 8, 0) #Downsample x8

        ###################### Deep Feature Extraction #########################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed_dfe = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed_dfe.num_patches
        patches_resolution = self.patch_embed_dfe.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed_dfe = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed_dfe = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed_dfe, std=.02)

        self.pos_drop_dfe = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr_dfe = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers_dfe = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr_dfe[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection

                         )
            self.layers_dfe.append(layer)
        self.norm_dfe = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body_dfe = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body_dfe = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
        
        ############################## Manipulator #############################
        self.manipulator = Manipulator(manipulator_num_resblk, embed_dim)

        ############### Mixed Magnified Transformer Block (MMTB) ###############
        # split image into non-overlapping patches
        self.patch_embed_mmsa = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed_mmsa.num_patches
        patches_resolution = self.patch_embed_mmsa.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed_mmsa = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed_mmsa = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed_mmsa, std=.02)

        self.pos_drop_mmsa = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr_mmsa = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers_mmsa = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr_mmsa[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection

                         )
            self.layers_mmsa.append(layer)
        self.norm_mmsa = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body_mmsa = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body_mmsa = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
        
        ######################## Decoder Reconstruction ########################
        self.upsample_conv = nn.ConvTranspose2d(embed_dim, embed_dim, 8, 8, 0)
        self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        #Init weights
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    
    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features_dfe(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed_dfe(x)
        if self.ape:
            x = x + self.absolute_pos_embed_dfe
        x = self.pos_drop_dfe(x)

        for layer in self.layers_dfe:
            x = layer(x, x_size)

        x = self.norm_dfe(x)  # B L C
        x = self.patch_unembed_dfe(x, x_size)

        return x
    
    def forward_features_mmsa(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed_mmsa(x)
        if self.ape:
            x = x + self.absolute_pos_embed_mmsa
        x = self.pos_drop_mmsa(x)

        for layer in self.layers_mmsa:
            x = layer(x, x_size)

        x = self.norm_mmsa(x)  # B L C
        x = self.patch_unembed_mmsa(x, x_size)

        return x

    def forward(self, a, b, amp, c = None):
        if a.shape != b.shape:
            raise RuntimeError('Image size mismatch')
        a = self.check_image_size(a)
        b = self.check_image_size(b)
        
        self.mean = self.mean.type_as(a)

        a = (a - self.mean) * self.img_range
        b = (b - self.mean) * self.img_range

        # Forward
        ## Shallow Feature Extractor
        a_first = self.conv_first(a)
        b_first = self.conv_first(b)

        ## Deep Feature Extractor
        res_a = self.conv_after_body_dfe(self.forward_features_dfe(a_first)) + a_first
        res_b = self.conv_after_body_dfe(self.forward_features_dfe(b_first)) + b_first

        ## Manipulator
        m = self.manipulator(res_a, res_b, amp)

        ## Mixed Magnified Transformer Block
        res_m = self.conv_after_body_mmsa(self.forward_features_mmsa(m)) + m

        ## Decoder Reconstruction
        y_hat = self.conv_last(self.upsample_conv(res_m))
        
        ## Extract features for c if training and not None
        if c != None and self.training:
            if b.shape != c.shape:
                raise RuntimeError('Image C size mismatch')
            c = self.check_image_size(c)

            c = (c - self.mean) * self.img_range

            ## Shallow Feature Extractor
            c_first = self.conv_first(c)

            ## Deep Feature Extractor
            res_c = self.conv_after_body_dfe(self.forward_features_dfe(c_first)) + c_first

            return y_hat, res_a, res_b, res_c

        else:
            return y_hat, res_a, res_b, None

if __name__ == '__main__':
    model = STBVMM(img_size=384, patch_size=1, in_chans=3,
                 embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, img_range=1., resi_connection='1conv',
                 manipulator_num_resblk = 1)
    
    model.eval()

    a = torch.randn((1,3,384,384))
    b = torch.randn((1,3,384,384))

    output = model(a, b, .2)

    print(output[0].shape)
    print(output[3])

    model.train()

    output = model(a, b, .2, b)

    print(output[0].shape)
    print(output[3].shape)

torch_model = STBVMM(img_size=384, patch_size=1, in_chans=3,
                 embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, img_range=1., resi_connection='1conv',
                 manipulator_num_resblk = 1)
# Create the super-resolution model by using the above model definition.
# torch_model = SuperResolutionNet(upscale_factor=3)

torch.Size([1, 3, 384, 384])
None
torch.Size([1, 3, 384, 384])
torch.Size([1, 48, 48, 48])


Ordinarily, you would now train this model; however, for this tutorial,
we will instead download some pretrained weights. Note that this model
was not trained fully for good accuracy and is used here for
demonstration purposes only.

It is important to call ``torch_model.eval()`` or ``torch_model.train(False)``
before exporting the model, to turn the model to inference mode.
This is required since operators like dropout or batchnorm behave
differently in inference and training mode.




In [11]:
import os
import torch
import torch.nn as nn
from models.model import STBVMM


model = STBVMM(img_size=384, patch_size=1, in_chans=3,
                embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
                window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                use_checkpoint=False, img_range=1., resi_connection='1conv',
                manipulator_num_resblk=1).to("cpu")

# Initialize the model
if torch.cuda.is_available():
    torch_model.cuda()

# Load model weights from the checkpoint
checkpoint = torch.load('ckpt/ckpt_e10.pth.tar')
model.load_state_dict(checkpoint['state_dict'], strict= False)

# torch_model.load_state_dict(checkpoint['state_dict'])

# Set the model to inference mode
model.eval()

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
# if torch.cuda.is_available():
#     map_location = None
# model.load_state_dict(model_zoo.load_url(model, map_location=map_location))
# set the model to inference mode
model.eval()

STBVMM(
  (conv_first): Conv2d(3, 192, kernel_size=(8, 8), stride=(8, 8))
  (patch_embed_dfe): PatchEmbed(
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  )
  (patch_unembed_dfe): PatchUnEmbed()
  (pos_drop_dfe): Dropout(p=0.0, inplace=False)
  (layers_dfe): ModuleList(
    (0): RSTB(
      (residual_group): BasicLayer(
        dim=192, input_resolution=(48, 48), depth=6
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            dim=192, input_resolution=(48, 48), num_heads=6, window_size=8, shift_size=0, mlp_ratio=2.0
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              dim=192, window_size=(8, 8), num_heads=6
              (qkv): Linear(in_features=192, out_features=576, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=192, out_features=192, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
     

Exporting a model in PyTorch works via tracing or scripting. This
tutorial will use as an example a model exported by tracing.
To export a model, we call the ``torch.onnx.export()`` function.
This will execute the model, recording a trace of what operators
are used to compute the outputs.
Because ``export`` runs the model, we need to provide an input
tensor ``x``. The values in this can be random as long as it is the
right type and size.
Note that the input size will be fixed in the exported ONNX graph for
all the input's dimensions, unless specified as a dynamic axes.
In this example we export the model with an input of batch_size 1,
but then specify the first dimension as dynamic in the ``dynamic_axes``
parameter in ``torch.onnx.export()``.
The exported model will thus accept inputs of size [batch_size, 1, 224, 224]
where batch_size can be variable.

To learn more details about PyTorch's export interface, check out the
[torch.onnx documentation](https://pytorch.org/docs/master/onnx.html)_.




In [12]:
model.eval()

a = torch.randn((1,3,384,384))
b = torch.randn((1,3,384,384))

output = model(a, b, .2)

print(output[0].shape)
print(output[3])

model.train()


output = model(a, b, .2, b)

print(output[0].shape)
print(output[3].shape)

torch.Size([1, 3, 384, 384])
None
torch.Size([1, 3, 384, 384])
torch.Size([1, 192, 48, 48])


In [17]:
# Input to the model
# x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
a = torch.randn((1,3,384,384))
b = torch.randn((1,3,384,384))
amp = torch.tensor(1.0)  # replace 1.0 with an appropriate dummy value


# Export the model
torch.onnx.export(model,                    # model being run
                  (a, b, amp),                   # model input (or a tuple for multiple inputs)
                  "model_STB.onnx",         # where to save the model (can be a file or file-like object)
                  export_params=True,       # store the trained parameter weights inside the model file
                  opset_version=10,         # the ONNX version to export the model to
                  do_constant_folding=True, # whether to execute constant folding for optimization
                  input_names = ['input_a', 'input_b'],   # the model's input names
                  output_names = ['output1', 'output2', 'output3', 'output4'], # the model's output names
                  dynamic_axes={'input_a' : {0 : 'batch_size'},    # variable length axes
                                'input_b' : {0 : 'batch_size'},
                                'output1' : {0 : 'batch_size'},
                                'output2' : {0 : 'batch_size'},
                                'output3' : {0 : 'batch_size'},
                                'output4' : {0 : 'batch_size'}})

verbose: False, log level: Level.ERROR



SymbolicValueError: Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. Please try opset version 11.  [Caused by the value '1168 defined in (%1168 : int[] = prim::ListConstruct(%1167, %1166, %1167, %1135), scope: models.model.STBVMM::
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.] 

    Inputs:
        #0: 1167 defined in (%1167 : Long(device=cpu) = onnx::Constant[value={0}](), scope: models.model.STBVMM::
    )  (type 'Tensor')
        #1: 1166 defined in (%1166 : Long(requires_grad=0, device=cpu) = onnx::Sub(%1151, %1165), scope: models.model.STBVMM:: # /Users/raoulritter/STB-VMM/models/model.py:953:0
    )  (type 'Tensor')
        #2: 1167 defined in (%1167 : Long(device=cpu) = onnx::Constant[value={0}](), scope: models.model.STBVMM::
    )  (type 'Tensor')
        #3: 1135 defined in (%1135 : Long(requires_grad=0, device=cpu) = onnx::Sub(%1120, %1134), scope: models.model.STBVMM:: # /Users/raoulritter/STB-VMM/models/model.py:952:0
    )  (type 'Tensor')
    Outputs:
        #0: 1168 defined in (%1168 : int[] = prim::ListConstruct(%1167, %1166, %1167, %1135), scope: models.model.STBVMM::
    )  (type 'List[int]')

We also computed ``torch_out``, the output after of the model,
which we will use to verify that the model we exported computes
the same values when run in ONNX Runtime.

But before verifying the model's output with ONNX Runtime, we will check
the ONNX model with ONNX API.
First, ``onnx.load("super_resolution.onnx")`` will load the saved model and
will output a ``onnx.ModelProto`` structure (a top-level file/container format for bundling a ML model.
For more information [onnx.proto documentation](https://github.com/onnx/onnx/blob/master/onnx/onnx.proto)_.).
Then, ``onnx.checker.check_model(onnx_model)`` will verify the model's structure
and confirm that the model has a valid schema.
The validity of the ONNX graph is verified by checking the model's
version, the graph's structure, as well as the nodes and their inputs
and outputs.




In [None]:
import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

Now let's compute the output using ONNX Runtime's Python APIs.
This part can normally be done in a separate process or on another
machine, but we will continue in the same process so that we can
verify that ONNX Runtime and PyTorch are computing the same value
for the network.

In order to run the model with ONNX Runtime, we need to create an
inference session for the model with the chosen configuration
parameters (here we use the default config).
Once the session is created, we evaluate the model using the run() API.
The output of this call is a list containing the outputs of the model
computed by ONNX Runtime.




In [None]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

ModuleNotFoundError: No module named 'onnxruntime'

We should see that the output of PyTorch and ONNX Runtime runs match
numerically with the given precision (``rtol=1e-03`` and ``atol=1e-05``).
As a side-note, if they do not match then there is an issue in the
ONNX exporter, so please contact us in that case.




## Running the model on an image using ONNX Runtime




So far we have exported a model from PyTorch and shown how to load it
and run it in ONNX Runtime with a dummy tensor as an input.



For this tutorial, we will use a famous cat image used widely which
looks like below

.. figure:: /_static/img/cat_224x224.jpg
   :alt: cat




First, let's load the image, preprocess it using standard PIL
python library. Note that this preprocessing is the standard practice of
processing data for training/testing neural networks.

We first resize the image to fit the size of the model's input (224x224).
Then we split the image into its Y, Cb, and Cr components.
These components represent a grayscale image (Y), and
the blue-difference (Cb) and red-difference (Cr) chroma components.
The Y component being more sensitive to the human eye, we are
interested in this component which we will be transforming.
After extracting the Y component, we convert it to a tensor which
will be the input of our model.




In [None]:
from PIL import Image
import torchvision.transforms as transforms

img = Image.open("./_static/img/cat.jpg")

resize = transforms.Resize([224, 224])
img = resize(img)

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)

Now, as a next step, let's take the tensor representing the
grayscale resized cat image and run the super-resolution model in
ONNX Runtime as explained previously.




In [None]:
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]

At this point, the output of the model is a tensor.
Now, we'll process the output of the model to construct back the
final output image from the output tensor, and save the image.
The post-processing steps have been adopted from PyTorch
implementation of super-resolution model
[here](https://github.com/pytorch/examples/blob/master/super_resolution/super_resolve.py)_.




In [None]:
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# get the output image follow post-processing step from PyTorch implementation
final_img = Image.merge(
    "YCbCr", [
        img_out_y,
        img_cb.resize(img_out_y.size, Image.BICUBIC),
        img_cr.resize(img_out_y.size, Image.BICUBIC),
    ]).convert("RGB")

# Save the image, we will compare this with the output image from mobile device
final_img.save("./_static/img/cat_superres_with_ort.jpg")

.. figure:: /_static/img/cat_superres_with_ort.jpg
   :alt: output\_cat


ONNX Runtime being a cross platform engine, you can run it across
multiple platforms and on both CPUs and GPUs.

ONNX Runtime can also be deployed to the cloud for model inferencing
using Azure Machine Learning Services. More information [here](https://docs.microsoft.com/en-us/azure/machine-learning/service/concept-onnx)_.

More information about ONNX Runtime's performance [here](https://github.com/microsoft/onnxruntime#high-performance)_.


For more information about ONNX Runtime [here](https://github.com/microsoft/onnxruntime)_.


