In [1]:
# !pip install einops

In [2]:
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers

from tensorflow.keras import backend as K

import timeit

2022-09-30 15:30:31.640939: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class BlockImages(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        bs, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )

        grid_height, grid_width = h // patch_size[0], w // patch_size[1]

        x = layers.Reshape(
            (grid_height * patch_size[0], grid_width * patch_size[1], num_channels)
        )(x)

        x = layers.Reshape(
            (-1, grid_height * grid_width, patch_size[0] * patch_size[1], num_channels)
        )(x)

        return x

In [4]:
class UnblockImages(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        num_channels = K.int_shape(x)[-1]

        x = layers.Reshape(
            (grid_size[0] * grid_size[1], patch_size[0] * patch_size[1], num_channels)
        )(x)

        x = layers.Reshape(
            (grid_size[0] * patch_size[0], grid_size[1] * patch_size[1], num_channels)
        )(x)

        return x

In [5]:
import functools

# Pre-configured convolution constructors shared by all MAXIM blocks. Call any
# of them with `filters=...` (plus optional Conv2D kwargs such as `use_bias` or
# `name`) to obtain a Keras layer instance.

# Stride-1 convolutions; "same" padding keeps the spatial size unchanged.
Conv3x3 = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(3, 3),
)
Conv1x1 = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(1, 1),
)

# Learned 2x resampling: the stride-2 transposed conv doubles H and W, the
# stride-2 conv halves them (rounding up).
ConvT_up = functools.partial(
    layers.Conv2DTranspose,
    padding="same",
    strides=(2, 2),
    kernel_size=(2, 2),
)
Conv_down = functools.partial(
    layers.Conv2D,
    padding="same",
    strides=(2, 2),
    kernel_size=(4, 4),
)

In [6]:
def MlpBlock(
    mlp_dim: int,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "mlp_block",
):
    """Build a one-hidden-layer MLP that acts on the last (channel) axis.

    Pipeline: Dense(mlp_dim) -> GELU -> Dropout -> Dense(input width).

    Args:
      mlp_dim: width of the hidden layer.
      dropout_rate: dropout rate applied after the activation.
      use_bias: whether the two Dense layers use biases.
      name: accepted for API symmetry; not used to name any layer.

    Returns:
      A callable `apply(x)` whose output has the same last dimension as `x`.
    """

    def apply(x):
        out_dim = K.int_shape(x)[-1]
        hidden = tf.nn.gelu(layers.Dense(mlp_dim, use_bias=use_bias)(x))
        hidden = layers.Dropout(dropout_rate)(hidden)
        # Project back to the original channel count.
        return layers.Dense(out_dim, use_bias=use_bias)(hidden)

    return apply

In [7]:
def UpSampleRatio(
    num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
):
    """Resize features spatially by `ratio` (> 0), then project channels.

    Args:
      num_channels: output channels of the 1x1 projection.
      ratio: spatial scale factor; values < 1 shrink, values > 1 enlarge.
      use_bias: whether the 1x1 convolution uses a bias.
      name: prefix for the projection layer's name.

    Returns:
      A callable mapping (n, h, w, c) to
      (n, int(h * ratio), int(w * ratio), num_channels).
    """

    def apply(x):
        shape = K.int_shape(x)
        h, w = shape[1], shape[2]

        # Plain Python ints (truncated, as tf.cast to int32 also does) keep the
        # Resizing layer's config serializable; a tf.Tensor stored as
        # height/width cannot be JSON-serialized by get_config().
        x = layers.Resizing(height=int(h * ratio), width=int(w * ratio))(x)

        x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_point_conv")(x)
        return x

    return apply

In [8]:
def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Channel attention via a squeeze-and-excitation block.

    Reference: https://arxiv.org/abs/1709.01507

    Args:
      num_channels: channels of the input (and of the per-channel gates).
      reduction: bottleneck ratio of the squeeze step.
      use_bias: whether the 1x1 convolutions use biases.
      name: accepted for API symmetry; not used to name any layer.

    Returns:
      A callable `apply(x)` returning `x` rescaled channel-wise by gates in (0, 1).
    """

    def apply(x):
        # Squeeze: one descriptor per channel from a global spatial average.
        pooled = layers.GlobalAvgPool2D(keepdims=True)(x)
        squeezed = tf.nn.relu(
            Conv1x1(filters=num_channels // reduction, use_bias=use_bias)(pooled)
        )
        # Excitation: expand back to num_channels and squash into (0, 1).
        gates = tf.nn.sigmoid(
            Conv1x1(filters=num_channels, use_bias=use_bias)(squeezed)
        )
        return x * gates

    return apply

In [9]:
def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block.

    Residual branch: LayerNorm -> Conv3x3 -> LeakyReLU -> Conv3x3 -> channel
    attention (CALayer); the branch output is added to the block input.

    Args:
      num_channels: channels of both convolutions (must match the input for
        the residual add).
      reduction: squeeze ratio of the channel-attention layer.
      lrelu_slope: negative slope of the LeakyReLU.
      use_bias: whether convolutions use biases.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` returning a tensor shaped like `x`.
    """

    def apply(x):
        residual = x
        y = layers.LayerNormalization(name=f"{name}_LayerNorm")(x)
        y = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(y)
        y = tf.nn.leaky_relu(y, alpha=lrelu_slope)
        y = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(y)
        attend = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name="channel_attention",
        )
        return attend(y) + residual

    return apply

In [10]:
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """gMLP-style spatial gating unit that mixes along axis -3.

    The input channels are split in half into (u, v). v is layer-normalized and
    then linearly mixed along axis -3. For blocked images shaped
    (..., blocks, pixels, channels), axis -3 is the block/grid axis. The mixed
    v then gates u as ``u * (v + 1)``. To mix a different axis, swap it into
    position -3 first.

    Args:
      use_bias: whether the mixing Dense layer uses a bias.
      name: prefix for the LayerNormalization layer's name.

    Returns:
      A callable `apply(x)` whose output has half of `x`'s channels.
    """

    def apply(x):
        mix_len = K.int_shape(x)[-3]  # size of the axis being mixed
        gated, gate = tf.split(x, 2, axis=-1)
        gate = layers.LayerNormalization(name=f"{name}_intermediate_layernorm")(gate)
        # Dense acts on the last axis, so swap the mixing axis there and back.
        gate = SwapAxes()(gate, -1, -3)
        gate = layers.Dense(mix_len, use_bias=use_bias)(gate)
        gate = SwapAxes()(gate, -1, -3)
        return gated * (gate + 1.0)

    return apply

In [11]:
def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Grid gMLP layer: global token mixing over a fixed (gh, gw) grid.

    The feature map is cut into gh x gw cells of size (h // gh, w // gw). A
    gated MLP then mixes along the cell-index axis (via GridGatingUnit), and
    the result is added back residually before un-blocking.

    Args:
      grid_size: (gh, gw) number of grid cells along height and width.
      use_bias: whether Dense layers use biases.
      factor: channel expansion of the input projection before gating.
      dropout_rate: dropout applied to the gMLP branch output.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` for a 4D (n, h, w, c) input; h and w are assumed
      divisible by gh and gw.
    """

    def apply(x):
        shape = K.int_shape(x)
        h, w, num_channels = shape[1], shape[2], shape[3]
        gh, gw = grid_size
        cell = (h // gh, w // gw)  # pixels per grid cell

        blocked = BlockImages()(x, patch_size=cell)

        branch = layers.LayerNormalization(name=f"{name}_LayerNorm")(blocked)
        branch = layers.Dense(
            num_channels * factor, use_bias=use_bias, name=f"{name}_in_project"
        )(branch)
        branch = tf.nn.gelu(branch)
        branch = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(
            branch
        )
        branch = layers.Dense(
            num_channels, use_bias=use_bias, name=f"{name}_out_project"
        )(branch)
        branch = layers.Dropout(dropout_rate)(branch)

        merged = blocked + branch
        return UnblockImages()(merged, grid_size=(gh, gw), patch_size=cell)

    return apply

In [12]:
class SwapAxes(layers.Layer):
    """Keras layer that interchanges two axes of its input.

    A thin wrapper around ``tf.experimental.numpy.swapaxes``. The two axes are
    supplied on every call rather than fixed at construction time.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        """Return `x` with axes `axis_one` and `axis_two` swapped."""
        tnp = tf.experimental.numpy
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

In [13]:
def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
    """gMLP-style spatial gating unit that mixes along axis -2.

    The input channels are split in half into (u, v). v is layer-normalized and
    then linearly mixed along axis -2. For blocked images shaped
    (..., blocks, pixels, channels), axis -2 is the within-block pixel axis.
    The mixed v then gates u as ``u * (v + 1)``. To mix a different axis, swap
    it into position -2 first.

    Args:
      use_bias: whether the mixing Dense layer uses a bias.
      name: prefix for the LayerNormalization layer's name.

    Returns:
      A callable `apply(x)` whose output has half of `x`'s channels.
    """

    def apply(x):
        mix_len = K.int_shape(x)[-2]  # size of the axis being mixed
        gated, gate = tf.split(x, 2, axis=-1)
        gate = layers.LayerNormalization(name=f"{name}_intermediate_layernorm")(gate)
        # Dense acts on the last axis, so swap the mixing axis there and back.
        gate = SwapAxes()(gate, -1, -2)
        gate = layers.Dense(mix_len, use_bias=use_bias)(gate)
        gate = SwapAxes()(gate, -1, -2)
        return gated * (gate + 1.0)

    return apply

In [14]:
def BlockGmlpLayer(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Block gMLP layer that performs local mixing of tokens.

    The feature map is cut into blocks of size (fh, fw). A gated MLP then mixes
    the pixels inside each block (via BlockGatingUnit), and the result is added
    back residually before un-blocking.

    Args:
      block_size: (fh, fw) size of each local block.
      use_bias: whether Dense layers use biases.
      factor: channel expansion of the input projection before gating.
      dropout_rate: dropout applied to the gMLP branch output.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` for a 4D (n, h, w, c) input; h and w are assumed
      divisible by fh and fw.
    """

    def apply(x):
        shape = K.int_shape(x)
        h, w, num_channels = shape[1], shape[2], shape[3]
        fh, fw = block_size
        gh, gw = h // fh, w // fw
        x = BlockImages()(x, patch_size=(fh, fw))

        # MLP2: Local (block) mixing part, provides within-block communication.
        y = layers.LayerNormalization(name=f"{name}_LayerNorm")(x)
        # Expand by `factor`, as GridGmlpLayer does. The gating unit halves the
        # channels, so without the expansion `factor` would be silently ignored.
        y = layers.Dense(
            num_channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(y)
        y = tf.nn.gelu(y)
        y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
        y = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(y)
        y = layers.Dropout(dropout_rate)(y)
        x = x + y
        x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
        return x

    return apply

In [15]:
def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """Multi-axis gated MLP block (residual, split-head).

    Steps:
      1. Normalize the input, project it to `input_proj_factor` times its
         channels, and apply GELU.
      2. Split the channels into two heads. One head goes through
         GridGmlpLayer (global mixing); the other goes through BlockGmlpLayer
         (local mixing).
      3. Concatenate the heads, project back to the input width, apply
         dropout, and add the block input.

    Args:
      block_size: (fh, fw) for the local (block) head.
      grid_size: (gh, gw) for the global (grid) head.
      block_gmlp_factor: expansion factor passed to BlockGmlpLayer.
      grid_gmlp_factor: expansion factor passed to GridGmlpLayer.
      input_proj_factor: channel expansion of the shared input projection.
      use_bias: whether Dense layers use biases.
      dropout_rate: dropout rate for the heads and the output projection.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` returning a tensor shaped like `x`.
    """

    def apply(x):
        residual = x
        num_channels = K.int_shape(x)[3]

        x = layers.LayerNormalization(name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            int(num_channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x)

        global_head, local_head = tf.split(x, 2, axis=-1)

        global_head = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(global_head)

        local_head = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(local_head)

        merged = tf.concat([global_head, local_head], axis=-1)
        merged = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(merged)
        merged = layers.Dropout(dropout_rate)(merged)
        return merged + residual

    return apply

In [16]:
def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block, used in the bottlenecks.

    Residual branch: LayerNorm -> channel MLP (MlpBlock) -> channel attention
    (CALayer). The branch output is added to the block input.

    Args:
      num_channels: hidden width of the MLP and channel count for CALayer.
      reduction: squeeze ratio of the channel-attention layer.
      use_bias: whether sub-layers use biases.
      dropout_rate: dropout rate inside the MLP.
      name: prefix for the LayerNormalization layer's name.

    Returns:
      A callable `apply(x)` returning `x + branch(x)`.
    """

    def apply(x):
        normed = layers.LayerNormalization(name=f"{name}_LayerNorm")(x)
        channel_mlp = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name="channel_mixing",
        )
        attention = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name="channel_attention",
        )
        return x + attention(channel_mlp(normed))

    return apply

In [17]:
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Bottleneck block: repeated multi-axis gMLP + RDCAB with a long skip.

    The input is projected to `features` channels with a 1x1 conv. Then
    `num_groups` rounds of (ResidualSplitHeadMultiAxisGmlpLayer -> RDCAB) are
    applied, and the projected input is added back at the end.

    Args:
      features: channel width inside the bottleneck.
      block_size: (fh, fw) for the local gMLP head.
      grid_size: (gh, gw) for the global gMLP head.
      num_groups: number of (gMLP, RDCAB) rounds.
      block_gmlp_factor, grid_gmlp_factor, input_proj_factor: forwarded to the
        multi-axis gMLP layer.
      channels_reduction: squeeze ratio of RDCAB's channel attention.
      dropout_rate: dropout rate for the gMLP layers.
      use_bias: whether sub-layers use biases.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` returning a tensor with `features` channels.
    """

    def apply(x):
        """Project, run the mixing groups, and add the long skip connection."""
        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_input_proj")(x)
        long_skip = x

        for group_idx in range(num_groups):
            mixer = ResidualSplitHeadMultiAxisGmlpLayer(
                grid_size=grid_size,
                block_size=block_size,
                grid_gmlp_factor=grid_gmlp_factor,
                block_gmlp_factor=block_gmlp_factor,
                input_proj_factor=input_proj_factor,
                use_bias=use_bias,
                dropout_rate=dropout_rate,
                name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group_idx}",
            )
            # Channel-mixing part, which provides within-patch communication.
            channel_block = RDCAB(
                num_channels=features,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1_{group_idx}",
            )
            x = channel_block(mixer(x))

        return x + long_skip

    return apply

In [18]:
def UNetEncoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    use_cross_gating: bool = False,
    name: str = "unet_encoder",
):
    """Encoder block in MAXIM.

    Steps:
      1. Optionally concatenate `skip` onto `x` and project to `num_channels`
         with a 1x1 conv.
      2. Run `num_groups` rounds of an optional multi-axis gMLP (when
         `use_global_mlp` is set) followed by RCAB, then add the long residual.
      3. Cross-gate with `enc + dec` when both are given.
      4. Optionally downsample with a strided conv.

    Note:
      `lrelu_slope` and `use_cross_gating` are accepted but not read in this
      block. Cross-gating is applied whenever both `enc` and `dec` are
      supplied.

    Returns:
      A callable `apply(x, skip=None, enc=None, dec=None)`. With `downsample`
      it returns `(downsampled, features)`; otherwise it returns `features`.
    """

    def apply(x, skip=None, enc=None, dec=None):
        feats = x if skip is None else tf.concat([x, skip], axis=-1)

        # Convolution-in; keep a handle for the long residual.
        feats = Conv1x1(filters=num_channels, use_bias=use_bias)(feats)
        long_residual = feats

        for group in range(num_groups):
            if use_global_mlp:
                feats = ResidualSplitHeadMultiAxisGmlpLayer(
                    grid_size=grid_size,
                    block_size=block_size,
                    grid_gmlp_factor=grid_gmlp_factor,
                    block_gmlp_factor=block_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    use_bias=use_bias,
                    dropout_rate=dropout_rate,
                    name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group}",
                )(feats)
            feats = RCAB(
                num_channels=num_channels,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1{group}",
            )(feats)

        feats = feats + long_residual

        if enc is not None and dec is not None:
            feats, _ = CrossGatingBlock(
                features=num_channels,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=dropout_rate,
                input_proj_factor=input_proj_factor,
                upsample_y=False,
                use_bias=use_bias,
                name=f"{name}_cross_gating_block",
            )(feats, enc + dec)

        if not downsample:
            return feats
        return Conv_down(filters=num_channels, use_bias=use_bias)(feats), feats

    return apply

In [19]:
def UNetDecoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    name: str = "unet_decoder",
):
    """Decoder block in MAXIM.

    Upsamples `x` 2x with a transposed convolution. It then runs a
    non-downsampling UNetEncoderBlock on the result, with `bridge` used as the
    concatenated skip input.

    Args:
      num_channels: channels of the upsampled and output features.
      block_size, grid_size: gMLP block/grid sizes for the inner encoder block.
      num_groups, lrelu_slope, block_gmlp_factor, grid_gmlp_factor,
        channels_reduction, use_global_mlp, dropout_rate, use_bias: forwarded
        to the inner encoder block.
      input_proj_factor: input projection factor for the multi-axis gMLP;
        forwarded to the inner encoder block (it was previously dropped, so
        non-default values had no effect).
      downsample: accepted for API symmetry; not used (the inner block never
        downsamples).
      name: prefix for the inner block's layer names.

    Returns:
      A callable `apply(x, bridge=None)` returning the decoded feature map.
    """

    def apply(x, bridge=None):
        x = ConvT_up(filters=num_channels, use_bias=use_bias)(x)
        x = UNetEncoderBlock(
            num_channels=num_channels,
            num_groups=num_groups,
            lrelu_slope=lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=block_gmlp_factor,
            grid_gmlp_factor=grid_gmlp_factor,
            input_proj_factor=input_proj_factor,
            channels_reduction=channels_reduction,
            use_global_mlp=use_global_mlp,
            dropout_rate=dropout_rate,
            downsample=False,
            use_bias=use_bias,
            name=f"{name}_unet_encoder",
        )(x, skip=bridge)

        return x

    return apply

In [20]:
def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Compute gating weights for the cross-gating MLP block.

    Steps:
      1. Normalize and project the input to `input_proj_factor` times its
         channels, apply GELU, and split the channels in half.
      2. The first half is blocked on a fixed (gh, gw) grid and Dense-mixed
         across grid cells (axis -3).
      3. The second half is blocked into fixed (fh, fw) blocks and Dense-mixed
         within each block (axis -2).
      4. Both halves are un-blocked, concatenated, projected back to the input
         width, and passed through dropout.

    Args:
      features: accepted for API symmetry; not read here.
      block_size: (fh, fw) for the local branch.
      grid_size: (gh, gw) for the global branch.
      input_proj_factor: channel expansion of the input projection.
      dropout_rate: dropout rate on the output.
      use_bias: whether Dense layers use biases.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x)` returning a tensor with the same channels as `x`.
    """

    def apply(x):
        shape = K.int_shape(x)
        h, w, num_channels = shape[1], shape[2], shape[3]

        def mix_along(t, axis):
            # Dense over `axis`: swap it to the end, project, swap it back.
            length = K.int_shape(t)[axis]
            t = SwapAxes()(t, -1, axis)
            t = layers.Dense(length, use_bias=use_bias)(t)
            return SwapAxes()(t, -1, axis)

        # input projection
        x = layers.LayerNormalization(name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            num_channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x)
        grid_half, block_half = tf.split(x, 2, axis=-1)

        # Global branch: fixed grid, mix across grid cells.
        gh, gw = grid_size
        cell = (h // gh, w // gw)
        grid_half = BlockImages()(grid_half, patch_size=cell)
        grid_half = mix_along(grid_half, -3)
        grid_half = UnblockImages()(grid_half, grid_size=(gh, gw), patch_size=cell)

        # Local branch: fixed block size, mix within each block.
        fh, fw = block_size
        n_blocks = (h // fh, w // fw)
        block_half = BlockImages()(block_half, patch_size=(fh, fw))
        block_half = mix_along(block_half, -2)
        block_half = UnblockImages()(
            block_half, grid_size=n_blocks, patch_size=(fh, fw)
        )

        out = tf.concat([grid_half, block_half], axis=-1)
        out = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(out)
        return layers.Dropout(dropout_rate)(out)

    return apply

In [21]:
def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Cross-gating MLP block.

    Each input produces spatial gating weights (via GetSpatialGatingWeights)
    that modulate the *other* input: Y is gated by weights computed from X, and
    X by weights computed from Y. Both branches are residual.

    Args:
      features: channels that `x` is projected to with a 1x1 conv (and the
        channel count of the upsampled `y` when `upsample_y` is set).
      block_size: (fh, fw) block size for the local gating branch.
      grid_size: (gh, gw) grid size for the global gating branch.
      dropout_rate: dropout rate after the output projections and inside the
        gating-weight computation.
      input_proj_factor: channel expansion used inside GetSpatialGatingWeights
        (previously accepted but not forwarded).
      upsample_y: if True, `y` is first upsampled 2x with a transposed conv.
      use_bias: whether conv/Dense layers use biases.
      name: prefix for the named sub-layers.

    Returns:
      A callable `apply(x, y)` returning `(x_out, y_out)`, where
      `y_out = proj(y * gx) + shortcut_y` and
      `x_out = proj(x * gy) + y_out + shortcut_x`.
    """

    def apply(x, y):
        # Upscale Y signal, y is the gating signal.
        if upsample_y:
            y = ConvT_up(filters=features, use_bias=use_bias)(y)

        x = Conv1x1(filters=features, use_bias=use_bias)(x)
        num_channels = K.int_shape(x)[3]

        y = Conv1x1(filters=num_channels, use_bias=use_bias)(y)

        shortcut_x = x
        shortcut_y = y

        # Get gating weights from X
        x = layers.LayerNormalization(name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(
            x
        )
        x = tf.nn.gelu(x)
        gx = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            input_proj_factor=input_proj_factor,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Get gating weights from Y
        y = layers.LayerNormalization(name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(
            y
        )
        y = tf.nn.gelu(y)
        gy = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            input_proj_factor=input_proj_factor,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Apply cross gating: X = X * GY, Y = Y * GX
        y = y * gx
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(
            y
        )
        y = layers.Dropout(dropout_rate)(y)
        y = y + shortcut_y

        x = x * gy  # gating x using y
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(
            x
        )
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + shortcut_x  # get all aggregated signals
        return x, y

    return apply

In [22]:
def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet

    Args:
      num_channels: channel count of the features passed to the next stage.
      output_channels: channel count of the restored image.
      use_bias: whether convolutions use biases.
      name: accepted for API symmetry; not used to name any layer.

    Returns:
      A callable `apply(x, x_image)`; see its docstring.
    """

    def apply(x, x_image):
        """Produce next-stage features and this stage's restored image.

        Args:
          x: decoder features, shaped like (h, w, c) per example.
          x_image: the stage input image; it is added to the prediction only
            when `output_channels == 3`.

        Returns:
          `(features, image)`: attention-refined features for the next stage,
          and the restored image at the current stage.
        """
        feats = Conv3x3(filters=num_channels, use_bias=use_bias)(x)

        # Restored image X_s; residual onto the input only for 3-channel outputs.
        image = Conv3x3(filters=output_channels, use_bias=use_bias)(x)
        if output_channels == 3:
            image = image + x_image

        # Attention maps derived from the restored image gate the features,
        # which are then added back onto the decoder input.
        attention = tf.nn.sigmoid(Conv3x3(filters=num_channels, use_bias=use_bias)(image))
        return feats * attention + x, image

    return apply

In [23]:
def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """The MAXIM model function with multi-stage and multi-scale supervision.
    For more model details, please check the CVPR paper:
    MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)
    Attributes:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling levels of the model.
      num_stages: how many stages to use. It also affects the output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/mlp layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
        layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many stages are specified as high-res stages. The
        rest (depth - high_res_stages) are called low_res_stages.
      block_size_hr: the block_size parameter for high-res stages.
      block_size_lr: the block_size parameter for low-res stages.
      grid_size_hr: the grid_size parameter for high-res stages.
      grid_size_lr: the grid_size parameter for low-res stages (also used by
        the bottleneck blocks).
      num_bottleneck_blocks: how many bottleneck blocks.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for SE layer.
      num_outputs: the output channels.
      dropout_rate: Dropout rate.
    Returns:
      The output contains a list of arrays consisting of multi-stage multi-scale
      outputs. For example, if num_stages = num_supervision_scales = 3 (the
      model used in the paper), the output specs are: outputs =
      [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
       [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
       [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
      The final output can be retrieved by outputs[-1][-1].
    """

    def scale_sizes(i):
        """Return (block_size, grid_size) for resolution level `i` (0 = full res)."""
        # Use larger block/grid sizes at high-res levels, smaller ones otherwise.
        if i < high_res_stages:
            return block_size_hr, grid_size_hr
        return block_size_lr, grid_size_lr

    def apply(x):
        input_shape = K.int_shape(x)
        h, w = input_shape[1], input_shape[2]  # input image spatial size

        # Multi-scale input images: full resolution first, then nearest-neighbour
        # downsampled copies at 1/2, 1/4, ... for each extra supervision scale.
        shortcuts = [x]
        for i in range(1, num_supervision_scales):
            # Keras `Resizing` selects the kernel via `interpolation=`; it has
            # no `method` argument (passing one raises a TypeError).
            resizing_layer = layers.Resizing(
                height=h // (2**i), width=w // (2**i), interpolation="nearest"
            )
            shortcuts.append(resizing_layer(x))

        # store outputs from all stages and all scales
        # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
        #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
        outputs_all = []
        sam_features, encs_prev, decs_prev = [], [], []

        for idx_stage in range(num_stages):
            # Input convolution, get multi-scale input features
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2**i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])

                # Later stages fuse input features with the previous stage's SAM
                # features. Those were appended coarsest-first, so pop() yields
                # the finest scale first, matching i = 0, 1, ...
                if idx_stage > 0:
                    if use_cross_gating:
                        block_size, grid_size = scale_sizes(i)
                        x_scale, _ = CrossGatingBlock(
                            features=(2**i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        x_scale = Conv1x1(
                            filters=(2**i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)

            # start encoder blocks
            encs = []
            x = x_scales[0]  # First full-scale input feature

            for i in range(depth):  # 0, 1, 2
                block_size, grid_size = scale_sizes(i)
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # UNet Encoder block
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None

                x, bridge = UNetEncoderBlock(
                    num_channels=(2**i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                # Cache skip signals
                encs.append(bridge)

            # Global MLP bottleneck blocks (lowest resolution).
            for i in range(num_bottleneck_blocks):
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=grid_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)
            # cache global feature for cross-gating
            global_feature = x

            # start cross gating. Use multi-scale feature fusion
            skip_features = []
            for i in reversed(range(depth)):  # 2, 1, 0
                block_size, grid_size = scale_sizes(i)

                # Resample every encoder output to level i and stack them.
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2**i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_upsample_ratio_{i}_{j}_encoder",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )

                # Use cross-gating to cross modulate features
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2**i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    skips = Conv1x1(filters=(2**i) * features, use_bias=use_bias)(
                        signal
                    )
                    skips = Conv3x3(filters=(2**i) * features, use_bias=use_bias)(
                        skips
                    )

                skip_features.append(skips)

            # start decoder. Multi-scale feature fusion of cross-gated features
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                block_size, grid_size = scale_sizes(i)

                # get multi-scale skip signals from cross-gating block
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2**i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_upsample_ratio_{i}_{j}_decoder",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )

                # Decoder block
                x = UNetDecoderBlock(
                    num_channels=(2**i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                # Cache decoder features for later-stage's usage
                decs.append(x)

                # output conv, if not final stage, use supervised-attention-block.
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:  # not last stage, apply SAM
                        sam, output = SAM(
                            num_channels=(2**i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:  # Last stage, apply output convolutions
                        output = Conv3x3(
                            num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        output = output + shortcuts[i]
                        outputs.append(output)
            # Cache encoder and decoder features for later-stage's usage
            encs_prev = encs[::-1]
            decs_prev = decs

            # Store outputs
            outputs_all.append(outputs)
        return outputs_all

    return apply

In [24]:
# Size variants accepted by Model(): variant -> (features, num_stages, name).
# All other hyperparameters are shared and live in _SHARED_VARIANT_CONFIG.
_VARIANT_SPECS = {
    # params: 6.108515000000001 M, GFLOPS: 93.163716608
    "S-1": (32, 1, "s1"),
    # params: 13.35383 M, GFLOPS: 206.743273472
    "S-2": (32, 2, "s2"),
    # params: 20.599145 M, GFLOPS: 320.32194560000005
    "S-3": (32, 3, "s3"),
    # params: 19.361219000000002 M, 308.495712256 GFLOPs
    "M-1": (64, 1, "m1"),
    # params: 40.83911 M, 675.25541888 GFLOPs
    "M-2": (64, 2, "m2"),
    # params: 62.317001 M, 1042.014666752 GFLOPs
    "M-3": (64, 3, "m3"),
}

# Hyperparameters identical across every entry of _VARIANT_SPECS.
_SHARED_VARIANT_CONFIG = {
    "depth": 3,
    "num_groups": 2,
    "num_bottleneck_blocks": 2,
    "block_gmlp_factor": 2,
    "grid_gmlp_factor": 2,
    "input_proj_factor": 2,
    "channels_reduction": 4,
}


def _variant_config(variant):
    """Return (fresh MAXIM kwargs dict, model-name prefix) for `variant`.

    Raises KeyError for an unknown variant name.
    """
    features, num_stages, name = _VARIANT_SPECS[variant]
    config = dict(_SHARED_VARIANT_CONFIG, features=features, num_stages=num_stages)
    return config, name


def Model(variant=None, training=False, **kw):
    """Build a Keras MAXIM model, optionally from a named size variant.

    Args:
      variant: MAXIM variant name, one of 'S-1' | 'S-2' | 'S-3' | 'M-1' |
          'M-2' | 'M-3'. When given, its hyperparameters are used as defaults
          for every key not already supplied in ``kw``. When None, ``kw`` alone
          configures MAXIM and the model is named "maxim_model".
      training: Currently unused; kept for API compatibility. Finer control
          over which MAXIM layers are trainable would have to be added here.
      **kw: Keyword arguments forwarded to ``MAXIM``; they take precedence
          over the variant defaults.
    Returns:
      A ``keras.Model`` wrapping ``MAXIM(**kw)`` applied to a (256, 256, 3)
      input, named "<variant name>_model" (e.g. "s2_model").
    Raises:
      KeyError: if ``variant`` is not one of the names listed above.
    """
    # Previously `config` was only bound inside the `if`, so variant=None
    # crashed with NameError when building the model name.
    model_name = "maxim"
    if variant is not None:
        config, model_name = _variant_config(variant)
        # Explicit keyword arguments override the variant defaults.
        for key, value in config.items():
            kw.setdefault(key, value)

    inputs = keras.Input((256, 256, 3))
    maxim_model = MAXIM(**kw)
    outputs = maxim_model(inputs)
    return keras.Model(inputs, outputs, name=f"{model_name}_model")

In [25]:
# def mod_padding_symmetric(image, factor=64):
#   """Padding the image to be divided by factor."""
#   height, width = image.shape[0], image.shape[1]
#   height_pad, width_pad = ((height + factor) // factor) * factor, (
#       (width + factor) // factor) * factor
#   padh = height_pad - height if height % factor != 0 else 0
#   padw = width_pad - width if width % factor != 0 else 0
#   image = np.pad(
#       image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)],
#       mode='reflect')
#   return image


# image = np.random.randn(512, 512, 3)
# mod_padding_symmetric(image).shape

In [26]:
# Build the MAXIM Keras model for the "S-2" size variant (see Model()).
# NOTE(review): the variable is named maxim_s1 but this builds variant "S-2";
# confirm which variant the SIDD checkpoint loaded below is meant to match.
maxim_s1 = Model(variant="S-2")

2022-09-30 15:30:36.165185: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [27]:
# dummy_inputs = tf.ones((1, 256, 256, 3))

# # Warmup
# print("Benchmarking TF model...")
# for _ in range(2):
#     _ = maxim_s1(dummy_inputs, training=False)

# # Timing
# tf_runtimes = timeit.repeat(
#     lambda: maxim_s1(dummy_inputs, training=False), number=1, repeat=10
# )
# print(f"Average latency (seconds): {np.mean(tf_runtimes)}.")

In [28]:
# @tf.function(jit_compile=True)
# def infer(x):
#     return maxim_s1(x, training=False)

In [29]:
# # Warmup
# print("Benchmarking Jit-compiled TF model...")
# for _ in range(2):
#     _ = infer(dummy_inputs)

# # Timing
# tf_runtimes = timeit.repeat(lambda: infer(dummy_inputs), number=1, repeat=10)
# print(f"Average latency (seconds): {np.mean(tf_runtimes)}.")

In [30]:
import collections
import io


# MAXIM checkpoint for SIDD denoising (.npz); read via tf.io.gfile in get_params().
ckpt_path = "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"


def recover_tree(keys, values):
    """Rebuild a nested dict from flat, '/'-separated names and their values.

    Useful for analyzing checkpoints saved by our programs without access to
    the exact experiment source code; any subtree (e.g. the parameters) can
    then be extracted and reused.

    Args:
      keys: iterable of names, where '/' separates nesting levels.
      values: iterable of leaf values, aligned element-wise with `keys`.
    Returns:
      A nested dict. A name without '/' becomes a leaf entry; names sharing a
      first component are grouped and rebuilt recursively under it. If a name
      is both a leaf and a group prefix, the rebuilt group replaces the leaf.
    """
    tree = {}
    children = {}
    for name, leaf in zip(keys, values):
        head, sep, tail = name.partition("/")
        if not sep:
            tree[name] = leaf
            continue
        children.setdefault(head, []).append((tail, leaf))

    for head, pairs in children.items():
        child_names = [tail for tail, _ in pairs]
        child_values = [leaf for _, leaf in pairs]
        tree[head] = recover_tree(child_names, child_values)
    return tree


def get_params(ckpt_path):
    """Load an ``.npz`` checkpoint and return its ``["opt"]["target"]`` subtree.

    The file is read with ``tf.io.gfile`` (so the gs:// path used in this
    notebook can be read), its flat '/'-separated array names are rebuilt into
    a nested dict with ``recover_tree``, and the optimizer-target subtree is
    returned.

    Args:
      ckpt_path: path or URL of the ``.npz`` checkpoint.
    Returns:
      Nested dict of numpy arrays rooted at ``["opt"]["target"]``.
    Raises:
      KeyError: if the checkpoint has no ``opt/target`` subtree (including an
          empty archive).
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as f:
        data = f.read()

    # np.load returns a lazily-reading NpzFile; pull every array out while the
    # archive is open, then let the context manager close it.
    with np.load(io.BytesIO(data)) as npz:
        flat_items = list(npz.items())

    params = recover_tree(
        [name for name, _ in flat_items], [array for _, array in flat_items]
    )
    return params["opt"]["target"]


# Read the checkpoint and keep only its "opt"/"target" parameter subtree.
params = get_params(ckpt_path)

In [31]:
# Inspect the top-level parameter groups stored in the checkpoint.
params.keys()

dict_keys(['UpSampleRatio_0', 'UpSampleRatio_1', 'UpSampleRatio_10', 'UpSampleRatio_11', 'UpSampleRatio_12', 'UpSampleRatio_13', 'UpSampleRatio_14', 'UpSampleRatio_15', 'UpSampleRatio_16', 'UpSampleRatio_17', 'UpSampleRatio_18', 'UpSampleRatio_19', 'UpSampleRatio_2', 'UpSampleRatio_20', 'UpSampleRatio_21', 'UpSampleRatio_22', 'UpSampleRatio_23', 'UpSampleRatio_24', 'UpSampleRatio_25', 'UpSampleRatio_26', 'UpSampleRatio_27', 'UpSampleRatio_28', 'UpSampleRatio_29', 'UpSampleRatio_3', 'UpSampleRatio_30', 'UpSampleRatio_31', 'UpSampleRatio_32', 'UpSampleRatio_33', 'UpSampleRatio_34', 'UpSampleRatio_35', 'UpSampleRatio_36', 'UpSampleRatio_37', 'UpSampleRatio_38', 'UpSampleRatio_39', 'UpSampleRatio_4', 'UpSampleRatio_40', 'UpSampleRatio_41', 'UpSampleRatio_42', 'UpSampleRatio_43', 'UpSampleRatio_44', 'UpSampleRatio_45', 'UpSampleRatio_46', 'UpSampleRatio_47', 'UpSampleRatio_48', 'UpSampleRatio_49', 'UpSampleRatio_5', 'UpSampleRatio_50', 'UpSampleRatio_51', 'UpSampleRatio_52', 'UpSampleRatio_

In [32]:
# Shape of one checkpoint kernel, to compare with the Keras "upsample" variable
# printed further below.
params["UpSampleRatio_0"]["Conv_0"]["kernel"].shape

(1, 1, 32, 128)

In [33]:
def get_model_vars(model):
    """Index ``model.variables`` by each variable's ``.name``.

    Names look like ``"conv2d/kernel:0"``. If two variables share a name, the
    one appearing later in ``model.variables`` wins.
    """
    return {variable.name: variable for variable in model.variables}

In [34]:
# Run one forward pass on a dummy (1, 256, 256, 3) batch before listing variables.
# NOTE(review): a functional keras.Model usually already has its variables after
# construction, so this call may be redundant — confirm.
_ = maxim_s1(tf.ones((1, 256, 256, 3)), training=False)

# Name -> variable mapping, presumably for matching against checkpoint params.
model_variables_dict = get_model_vars(maxim_s1)
model_variables_dict.keys()

dict_keys(['stage_0_input_conv_0/kernel:0', 'stage_0_input_conv_0/bias:0', 'conv2d/kernel:0', 'conv2d/bias:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_LayerNorm_in/gamma:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_LayerNorm_in/beta:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_in_project/kernel:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_in_project/bias:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_GridGmlpLayer_LayerNorm/gamma:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_GridGmlpLayer_LayerNorm/beta:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_BlockGmlpLayer_LayerNorm/gamma:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_BlockGmlpLayer_LayerNorm/beta:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_GridGmlpLayer_in_project/kernel:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_GridGmlpLayer_in_project/bias:0', 'stage_0_encoder_block_0_SplitHeadMultiAxisGml

In [36]:
# Print the first Keras variable whose name contains "upsample", with its shape,
# for comparison against the checkpoint's UpSampleRatio_0 kernel shape above.
for k in model_variables_dict.keys():
    if "upsample" in k:
        print(k, model_variables_dict[k].shape)
        break

stage_0_upsample_ratio_2_0_encoder_point_conv/kernel:0 (1, 1, 32, 128)
