In [1]:
import sys
from pathlib import Path

# Make the repository root (one level above this notebook) importable so that
# `create_maxim_model` and the `maxim` package resolve. Resolving to an absolute
# path keeps imports working even if the working directory changes later, and
# the membership check stops re-runs of this cell from piling up duplicates.
_REPO_ROOT = str(Path("..").resolve())
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

In [2]:
from create_maxim_model import Model
from maxim.configs import MAXIM_CONFIGS

2022-10-09 13:18:28.774098: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py#L55

VARIANT = "S-3"

# Build a fresh dict instead of calling `.update()` on the entry stored in
# `MAXIM_CONFIGS`: that entry is shared module-level state, and mutating it would
# leak these evaluation overrides to every other user of the config table.
# Indexing (rather than `.get`) fails fast with a KeyError naming the unknown
# variant instead of an opaque AttributeError on `None.update`.
configs = {
    **MAXIM_CONFIGS[VARIANT],
    "variant": VARIANT,
    "dropout_rate": 0.0,
    "num_outputs": 3,
    "use_bias": True,
    "num_supervision_scales": 3,
}

In [4]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py

# Standard library.
import collections
import io
import re

# Third-party.
import numpy as np
import pandas as pd
import tensorflow as tf

# Fix the global random seed so that runs of this notebook are reproducible.
tf.keras.utils.set_random_seed(2022)


def recover_tree(keys, values):
    """Rebuild a nested dict from slash-separated flat names and their values.

    Handy for inspecting checkpoints written by our training programs without
    needing the exact experiment source: any subtree (for example the model
    parameters) can then be extracted and reused.

    Args:
      keys: iterable of names; "/" separates the levels of the tree.
      values: iterable of leaf values, aligned with `keys`.

    Returns:
      A nested dict. Names without "/" become top-level leaves; names that share
      a first component are grouped under it and recovered recursively.
    """
    tree = {}
    children = collections.defaultdict(list)
    for name, value in zip(keys, values):
        head, sep, rest = name.partition("/")
        if sep:
            children[head].append((rest, value))
        else:
            tree[name] = value
    # Subtrees are attached after all top-level leaves, in first-seen order.
    for head, pairs in children.items():
        sub_keys, sub_values = zip(*pairs)
        tree[head] = recover_tree(sub_keys, sub_values)
    return tree


def get_params(ckpt_path):
    """Load the trainable parameters stored in a MAXIM ``.npz`` checkpoint.

    Args:
      ckpt_path: path readable by ``tf.io.gfile`` (e.g. a ``gs://`` URL).

    Returns:
      The nested parameter dict found under ``["opt"]["target"]`` of the tree
      recovered from the checkpoint's flat, "/"-separated array names.

    Raises:
      ValueError: if the checkpoint contains no arrays.
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as f:
        data = f.read()
    # `np.load` on an .npz returns an NpzFile that reads arrays lazily from a zip
    # handle; use it as a context manager so the handle is released, and
    # materialise every array (via `items()`) before leaving the block.
    with np.load(io.BytesIO(data)) as values:
        if not values.files:
            raise ValueError(f"Checkpoint {ckpt_path!r} contains no arrays.")
        names, arrays = zip(*values.items())
    params = recover_tree(names, arrays)
    return params["opt"]["target"]

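Flax `LayerNorm` and Keras `LayerNormalization` give their parameters different names, so the LayerNorm weights are renamed during porting:

```
gamma (TF) => scale (JAX)
beta (TF) => bias (JAX)
```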

In [5]:
# https://stackoverflow.com/questions/5491913/sorting-list-in-python
def sort_nicely(l):
    """Return the items of `l` sorted in natural ("human") order.

    Each string is split into alternating text and digit runs, and the digit
    runs are compared as integers, so "x2" sorts before "x10".
    """
    def _natural_key(text):
        return [int(chunk) if chunk.isdigit() else chunk
                for chunk in re.split('([0-9]+)', text)]

    return sorted(l, key=_natural_key)

In [6]:
def modify_upsample(jax_params):
    """Rename the JAX upsampling subtrees to their TF variable names.

    Args:
      jax_params: the nested (not flattened) JAX parameter dict. Every top-level
        key containing "UpSample" is assumed to map to
        ``{submodule: {param_name: array}}``.

    Returns:
      An OrderedDict from TF variable names to arrays. The matching modules are
      visited in natural-sort order and renumbered from 1: the n-th one is
      renamed to ``<text before its first "_">_<n>``, and each of its entries is
      keyed ``"<that name>_<submodule>/<param_name>:0"``.
    """
    upsample_keys = sort_nicely([name for name in jax_params.keys() if "UpSample" in name])

    renamed = collections.OrderedDict()
    for index, module_name in enumerate(upsample_keys, start=1):
        tf_module_name = f"{module_name.split('_')[0]}_{index}"
        for submodule, leaves in jax_params[module_name].items():
            for leaf_name, array in leaves.items():
                renamed[f"{tf_module_name}_{submodule}/{leaf_name}:0"] = array
    return renamed


def modify_jax_params(jax_params):
    """Map flattened JAX parameter names to TF variable names (non-upsampling part).

    Args:
      jax_params: flat mapping from "_"-joined JAX parameter paths (as produced
        by ``pd.json_normalize(..., sep="_")``) to arrays. Keys containing
        "UpSample" are skipped; they are handled by ``modify_upsample``.

    Returns:
      An OrderedDict from TF variable names to arrays. The last "_"-separated
      component of a key is treated as the leaf (parameter) name, so
      ``"<path>_<leaf>"`` becomes ``"<path>/<leaf>:0"``, with two adjustments:

      * non-bias ConvTranspose parameters (kernels) have their last two axes
        swapped to match the TF kernel layout;
      * LayerNorm leaves are renamed ``scale -> gamma`` and ``bias -> beta``.
    """
    layernorm_leaf_renames = {"scale": "gamma", "bias": "beta"}
    modified_jax_params = collections.OrderedDict()

    for k, params in jax_params.items():
        if "UpSample" in k:
            continue

        *path_parts, leaf = k.split("_")
        module_path = "_".join(path_parts)

        if "ConvTranspose" in k and leaf != "bias":
            params = params.transpose(0, 1, 3, 2)

        # Rename only the leaf: a blanket str.replace on the full name would also
        # rewrite module names that merely contain "scale" or "bias".
        if "layernorm" in module_path.lower():
            leaf = layernorm_leaf_renames.get(leaf, leaf)

        modified_jax_params[f"{module_path}/{leaf}:0"] = params

    return modified_jax_params

In [7]:
def port_jax_params(configs, ckpt_path):
    """Build the TF MAXIM model and load the JAX checkpoint weights into it.

    Args:
      configs: keyword arguments for ``Model``.
      ckpt_path: path of the JAX ``.npz`` checkpoint (see ``get_params``).

    Returns:
      ``(modified_jax_params, tf_model)``: the JAX parameters keyed by their TF
      variable names, and the TF model with those values assigned.

    Raises:
      ValueError: if the renamed JAX parameters and the TF model variables do
        not cover exactly the same names. The check runs before any assignment,
        so a mismatch never leaves the model partially ported.
    """
    # Initialize TF Model.
    tf_model = Model(**configs)

    # Map each TF variable name to the variable itself.
    tf_model_variables_dict = {v.name: v for v in tf_model.variables}

    # Obtain the JAX pre-trained variables, flattened with "_"-joined names.
    jax_params = get_params(ckpt_path)
    [flat_jax_dict] = pd.json_normalize(jax_params, sep="_").to_dict(orient="records")

    # Amend the JAX variables to match the names of the TF variables.
    modified_jax_params = modify_jax_params(flat_jax_dict)
    modified_jax_params.update(modify_upsample(jax_params))

    # Compare both name sets before assigning anything and report the offending
    # names on mismatch (an explicit raise also survives `python -O`).
    missing_in_tf = sorted(set(modified_jax_params) - set(tf_model_variables_dict))
    missing_in_jax = sorted(set(tf_model_variables_dict) - set(modified_jax_params))
    if missing_in_tf or missing_in_jax:
        raise ValueError(
            "JAX/TF parameter names do not match. "
            f"JAX params with no TF variable: {missing_in_tf}. "
            f"TF variables with no JAX param: {missing_in_jax}."
        )

    # Porting: assign every value in one batched call.
    tf_weights = [
        (tf_model_variables_dict[name], value)
        for name, value in modified_jax_params.items()
    ]
    tf.keras.backend.batch_set_value(tf_weights)

    return modified_jax_params, tf_model

In [8]:
# Port the SIDD denoising checkpoint; only the TF model is kept here.
SIDD_DENOISING_CKPT = "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"
_, tf_model = port_jax_params(configs, SIDD_DENOISING_CKPT)

2022-10-09 13:19:01.445333: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "INVALID_ARGUMENT: Error executing an HTTP request: HTTP response code 400 with body '{
  "error": "invalid_grant",
  "error_description": "Bad Request"
}'". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata".


In [9]:
# Smoke-test the ported model on a constant all-ones input of shape (1, 256, 256, 3).
dummy_image = np.ones((1, 256, 256, 3))
preds = tf_model.predict(dummy_image)



In [10]:
# The model returns nested lists of outputs; keep the last element at every
# nesting level (the printed shape below shows it is the full 256x256 image).
while isinstance(preds, list):
    preds = preds[-1]

# Strip the batch axis only while it is still present, so re-running this cell
# does not silently slice away another axis of an already-unwrapped array.
if np.ndim(preds) == 4:
    preds = preds[0]
preds = np.array(preds, np.float32)
print(f"Predictions: {preds.shape, preds[0, :3]}")

Predictions: ((256, 256, 3), array([[0.979627  , 0.9843435 , 0.97292805],
       [1.0002174 , 0.99844694, 0.9940431 ],
       [0.99668646, 0.9964776 , 0.9939418 ]], dtype=float32))


In [11]:
# Persist the ported TF weights to an HDF5 file in the current working directory.
WEIGHTS_FILE = "denoising_sidd.h5"
tf_model.save_weights(WEIGHTS_FILE)

In [12]:
# np.testing.assert_allclose(
#     preds[0, :3],
#     np.array(
#         [
#             [1.0001332, 1.0020351, 0.99739677],
#             [0.999209, 1.0013864, 0.99625367],
#             [1.0003445, 1.0004236, 0.996228],
#         ]
#     ),
# )

## Verify the port: compare every ported TF variable with the original JAX parameter

In [13]:
# Reload the JAX checkpoint and redo the renaming independently of
# `port_jax_params`, so the comparison below checks the ported TF variables
# against a freshly computed copy of the expected values.
ckpt_path = "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"
jax_params = get_params(ckpt_path)

# Flatten the nested dict into a single record with "_"-joined keys.
[flat_jax_dict] = pd.json_normalize(jax_params, sep="_").to_dict(orient="records")

# Regular parameters first, then the renumbered upsampling ones.
modified_jax_params = modify_jax_params(flat_jax_dict)
upsample_jax_params = modify_upsample(jax_params)
modified_jax_params.update(upsample_jax_params)

In [14]:
# Index the ported TF variables by name for lookup against the JAX names.
modified_tf_params = tf_model.variables
tf_model_variables_dict = {variable.name: variable for variable in modified_tf_params}

In [15]:
# Compare every renamed JAX parameter with the TF variable it was assigned to.
# The loop variables get their own names so the `jax_params` checkpoint tree
# loaded above is not overwritten. Only assertion failures (value or shape
# mismatches) are recorded; any other error propagates instead of being
# silently counted as a mismatch.
unmatched_params = []

for name, expected in modified_jax_params.items():
    actual = tf_model_variables_dict[name].numpy()
    try:
        np.testing.assert_allclose(expected, actual)
    except AssertionError:
        unmatched_params.append(name)

In [16]:
# Fail loudly if any parameter differs; otherwise display the (empty) list.
assert not unmatched_params, f"Unmatched parameters: {unmatched_params}"
unmatched_params

[]

## JAX implementation of MAXIM

In [17]:
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for the MAXIM model."""

import functools
from typing import Any, Sequence, Tuple

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


# Convolution presets shared by the modules below.
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3))
Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1))  # pointwise projection

# 2x spatial upsampling (transposed conv) and downsampling (strided conv).
ConvT_up = functools.partial(nn.ConvTranspose, kernel_size=(2, 2), strides=(2, 2))
Conv_down = functools.partial(nn.Conv, kernel_size=(4, 4), strides=(2, 2))

# Kernel initializer for the Dense layers: normal(stddev=0.02).
weight_initializer = nn.initializers.normal(stddev=2e-2)


class MlpBlock(nn.Module):
  """A 1-hidden-layer MLP block, applied over the last dimension.

  Dense(mlp_dim) -> GELU -> Dropout -> Dense(back to the input width).

  Attributes:
    mlp_dim: width of the hidden layer.
    dropout_rate: dropout rate applied after the activation.
    use_bias: whether the Dense layers use a bias.
  """
  mlp_dim: int
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    # Only the last (channel) axis is needed, so inputs of any rank are
    # accepted rather than requiring exactly (n, h, w, d).
    d = x.shape[-1]
    x = nn.Dense(self.mlp_dim, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
    x = nn.Dense(d, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    return x


def block_images_einops(x, patch_size):
  """Split NHWC images into non-overlapping patches.

  Args:
    x: array of shape (n, height, width, c).
    patch_size: (fh, fw) patch height and width.

  Returns:
    Array of shape (n, gh * gw, fh * fw, c), where gh = height // fh and
    gw = width // fw. Patches are ordered row-major over the grid and pixels
    row-major within each patch.
  """
  _, height, width, _ = x.shape
  patch_h, patch_w = patch_size[0], patch_size[1]
  return einops.rearrange(
      x,
      "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
      gh=height // patch_h,
      gw=width // patch_w,
      fh=patch_h,
      fw=patch_w,
  )


def unblock_images_einops(x, grid_size, patch_size):
  """Reassemble patches into NHWC images (inverse of `block_images_einops`).

  Args:
    x: array of shape (n, gh * gw, fh * fw, c).
    grid_size: (gh, gw) number of patches along height and width.
    patch_size: (fh, fw) patch height and width.

  Returns:
    Array of shape (n, gh * fh, gw * fw, c).
  """
  grid_h, grid_w = grid_size[0], grid_size[1]
  patch_h, patch_w = patch_size[0], patch_size[1]
  pattern = "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c"
  return einops.rearrange(x, pattern, gh=grid_h, gw=grid_w, fh=patch_h, fw=patch_w)


class UpSampleRatio(nn.Module):
  """Upsample features given a ratio > 0.

  Bilinearly resizes the spatial dims by `ratio` (the new size is truncated to
  an int), then projects to `features` channels with a 1x1 conv.

  Attributes:
    features: output channels of the 1x1 projection.
    ratio: spatial scale factor; values < 1 downsample.
    use_bias: whether the 1x1 conv uses a bias.
  """
  features: int
  ratio: float
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    # Input is unpacked as (n, h, w, c), i.e. channels-last.
    n, h, w, c = x.shape
    x = jax.image.resize(
        x,
        shape=(n, int(h * self.ratio), int(w * self.ratio), c),
        method="bilinear")
    # Unnamed submodule, auto-named by Flax. `modify_upsample` builds TF names
    # from these submodule names, so do not add or reorder submodules here.
    x = Conv1x1(features=self.features, use_bias=self.use_bias)(x)
    return x


class CALayer(nn.Module):
  """Squeeze-and-excitation block for channel attention.

  ref: https://arxiv.org/abs/1709.01507

  Global-average-pools the spatial dims, squeezes to
  `features // reduction` channels (ReLU), expands to `features` channels
  (sigmoid), and rescales the input channel-wise with the result.

  Attributes:
    features: channels of the excitation output. The result is broadcast-
      multiplied with `x`, so this presumably equals x's channel count.
    reduction: channel reduction factor of the squeeze step.
    use_bias: whether the 1x1 convs use a bias.
  """
  features: int
  reduction: int = 4
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    # 2D global average pooling over axes 1 and 2 (h, w for NHWC input)
    y = jnp.mean(x, axis=[1, 2], keepdims=True)
    # Squeeze (in Squeeze-Excitation)
    y = Conv1x1(self.features // self.reduction, use_bias=self.use_bias)(y)
    y = nn.relu(y)
    # Excitation (in Squeeze-Excitation)
    y = Conv1x1(self.features, use_bias=self.use_bias)(y)
    y = nn.sigmoid(y)
    # Per-channel gates in (0, 1), broadcast over the spatial dims.
    return x * y


class RCAB(nn.Module):
  """Residual channel attention block. Contains LN,Conv,lRelu,Conv,SELayer.

  Computes x + CALayer(Conv3x3(LeakyReLU(Conv3x3(LayerNorm(x))))).

  Attributes:
    features: output channels of both 3x3 convs; expected to match x's
      channel count for the residual addition.
    reduction: channel reduction factor of the channel-attention layer.
    lrelu_slope: negative slope of the leaky ReLU.
    use_bias: whether the convs use a bias.
  """
  features: int
  reduction: int = 4
  lrelu_slope: float = 0.2
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    shortcut = x
    x = nn.LayerNorm(name="LayerNorm")(x)
    x = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv1")(x)
    x = nn.leaky_relu(x, negative_slope=self.lrelu_slope)
    x = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv2")(x)
    x = CALayer(features=self.features, reduction=self.reduction,
                use_bias=self.use_bias, name="channel_attention")(x)
    return x + shortcut


class GridGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper, mixing over the grid axis.

  The input is split in half along channels into (u, v). v is layer-normed and
  mixed by a Dense layer along the **third-to-last** axis (x.shape[-3]). In
  `GridGmlpLayer` that axis is the gh*gw grid axis of the blocked
  (n, gh*gw, fh*fw, c) tensor. The result gates u as u * (v + 1), so the output
  has half of x's channels. If applied on other dims, you should swapaxes first.

  Attributes:
    use_bias: whether the spatial Dense layer uses a bias.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    n = x.shape[-3]   # get spatial dim (axis -3, the grid axis)
    # Move the spatial axis last so Dense mixes across it, then move it back.
    v = jnp.swapaxes(v, -1, -3)
    v = nn.Dense(n, use_bias=self.use_bias, kernel_init=weight_initializer)(v)
    v = jnp.swapaxes(v, -1, -3)
    return u * (v + 1.)


class GridGmlpLayer(nn.Module):
  """Grid gMLP layer that performs global mixing of tokens.

  Blocks the image into a fixed `grid_size` grid of patches, then applies a
  residual gMLP (LayerNorm -> Dense -> GELU -> GridGatingUnit -> Dense ->
  Dropout) that mixes tokens across grid cells, and unblocks back to the
  input layout.

  Attributes:
    grid_size: (gh, gw) number of grid cells; h and w are presumably expected
      to be divisible by them.
    use_bias: whether the Dense layers use a bias.
    factor: expansion factor of the input projection.
    dropout_rate: dropout rate on the output projection.
  """
  grid_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    n, h, w, num_channels = x.shape
    gh, gw = self.grid_size
    # Patch size is derived from the fixed grid, so it scales with the input.
    fh, fw = h // gh, w // gw
    x = block_images_einops(x, patch_size=(fh, fw))
    # gMLP1: Global (grid) mixing part, provides global grid communication.
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = nn.Dense(num_channels * self.factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(y)
    y = nn.gelu(y)
    y = GridGatingUnit(use_bias=self.use_bias, name="GridGatingUnit")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic)
    x = x + y
    x = unblock_images_einops(x, grid_size=(gh, gw), patch_size=(fh, fw))
    return x


class BlockGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper.

  The 'spatial' dim is defined as the **second last**.
  If applied on other dims, you should swapaxes first.

  The input is split in half along channels into (u, v). v is layer-normed and
  mixed by a Dense layer along axis -2; in `BlockGmlpLayer` that is the fh*fw
  within-block axis. The result gates u as u * (v + 1), so the output has half
  of x's channels.

  Attributes:
    use_bias: whether the spatial Dense layer uses a bias.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    n = x.shape[-2]  # get spatial dim
    # Move the spatial axis last so Dense mixes across it, then move it back.
    v = jnp.swapaxes(v, -1, -2)
    v = nn.Dense(n, use_bias=self.use_bias, kernel_init=weight_initializer)(v)
    v = jnp.swapaxes(v, -1, -2)
    return u * (v + 1.)


class BlockGmlpLayer(nn.Module):
  """Block gMLP layer that performs local mixing of tokens.

  Blocks the image into fixed-size `block_size` patches, then applies a
  residual gMLP (LayerNorm -> Dense -> GELU -> BlockGatingUnit -> Dense ->
  Dropout) that mixes pixels within each patch, and unblocks back to the
  input layout.

  Attributes:
    block_size: (fh, fw) patch size; h and w are presumably expected to be
      divisible by it.
    use_bias: whether the Dense layers use a bias.
    factor: expansion factor of the input projection.
    dropout_rate: dropout rate on the output projection.
  """
  block_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    n, h, w, num_channels = x.shape
    fh, fw = self.block_size
    # Grid size is derived from the fixed block size, so it scales with the input.
    gh, gw = h // fh, w // fw
    x = block_images_einops(x, patch_size=(fh, fw))
    # MLP2: Local (block) mixing part, provides within-block communication.
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = nn.Dense(num_channels * self.factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(y)
    y = nn.gelu(y)
    y = BlockGatingUnit(use_bias=self.use_bias, name="BlockGatingUnit")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic)
    x = x + y
    x = unblock_images_einops(x, grid_size=(gh, gw), patch_size=(fh, fw))
    return x


class ResidualSplitHeadMultiAxisGmlpLayer(nn.Module):
  """The multi-axis gated MLP block.

  LayerNorm -> Dense (expand by `input_proj_factor`) -> GELU, then the channels
  are split in half. One half goes through a GridGmlpLayer (global mixing), the
  other through a BlockGmlpLayer (local mixing). The halves are concatenated,
  projected back to the input width, passed through dropout, and added to the
  input.

  Attributes:
    block_size: (fh, fw) patch size of the local branch.
    grid_size: (gh, gw) grid size of the global branch.
    block_gmlp_factor: expansion factor inside the block branch.
    grid_gmlp_factor: expansion factor inside the grid branch.
    input_proj_factor: expansion factor of the input projection.
    use_bias: whether the Dense layers use a bias.
    dropout_rate: dropout rate of the branches and the output.
  """
  block_size: Sequence[int]
  grid_size: Sequence[int]
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    shortcut = x
    n, h, w, num_channels = x.shape
    x = nn.LayerNorm(name="LayerNorm_in")(x)
    x = nn.Dense(num_channels * self.input_proj_factor, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="in_project")(x)
    x = nn.gelu(x)

    # Two "heads": u for global (grid) mixing, v for local (block) mixing.
    u, v = jnp.split(x, 2, axis=-1)
    # GridGMLPLayer
    u = GridGmlpLayer(
        grid_size=self.grid_size,
        factor=self.grid_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="GridGmlpLayer")(u, deterministic)

    # BlockGMLPLayer
    v = BlockGmlpLayer(
        block_size=self.block_size,
        factor=self.block_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="BlockGmlpLayer")(v, deterministic)

    x = jnp.concatenate([u, v], axis=-1)

    x = nn.Dense(num_channels, use_bias=self.use_bias,
                 kernel_init=weight_initializer, name="out_project")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic)
    x = x + shortcut
    return x


class RDCAB(nn.Module):
  """Residual dense channel attention block. Used in Bottlenecks.

  Computes x + CALayer(MlpBlock(LayerNorm(x))).

  Attributes:
    features: hidden width of the channel MLP and channel count of the channel
      attention; expected to match x's channel count for the residual add.
    reduction: channel reduction factor of the channel attention.
    use_bias: whether the MLP / convs use a bias.
    dropout_rate: dropout rate inside the channel MLP.
  """
  features: int
  reduction: int = 16
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = MlpBlock(
        mlp_dim=self.features,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="channel_mixing")(
            y, deterministic=deterministic)
    y = CALayer(
        features=self.features,
        reduction=self.reduction,
        use_bias=self.use_bias,
        name="channel_attention")(
            y)
    x = x + y
    return x


class BottleneckBlock(nn.Module):
  """The bottleneck block consisting of multi-axis gMLP block and RDCAB.

  Projects the input to `features` channels with a 1x1 conv, then applies
  `num_groups` rounds of (multi-axis gMLP layer -> RDCAB), followed by a long
  skip connection from the projected input.

  Attributes:
    features: channel width of the block.
    block_size: (fh, fw) patch size for the local (block) gMLP branch.
    grid_size: (gh, gw) grid size for the global (grid) gMLP branch.
    num_groups: number of gMLP + RDCAB rounds.
    block_gmlp_factor: expansion factor of the block gMLP branch.
    grid_gmlp_factor: expansion factor of the grid gMLP branch.
    input_proj_factor: expansion factor of the multi-axis layer's input
      projection.
    channels_reduction: channel reduction factor of the RDCAB channel attention.
    dropout_rate: dropout rate of the gMLP layers and of RDCAB.
    use_bias: whether convs / Dense layers use a bias.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    """Applies the Mixer block to inputs."""
    assert x.ndim == 4  # Input has shape [batch, h, w, c]

    # input projection
    x = Conv1x1(self.features, use_bias=self.use_bias, name="input_proj")(x)
    shortcut_long = x

    for i in range(self.num_groups):
      x = ResidualSplitHeadMultiAxisGmlpLayer(
          grid_size=self.grid_size,
          block_size=self.block_size,
          grid_gmlp_factor=self.grid_gmlp_factor,
          block_gmlp_factor=self.block_gmlp_factor,
          input_proj_factor=self.input_proj_factor,
          use_bias=self.use_bias,
          dropout_rate=self.dropout_rate,
          name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      # Channel-mixing part, which provides within-patch communication.
      # dropout_rate and deterministic are forwarded so RDCAB's dropout follows
      # the block's train/eval mode; without them RDCAB would use its defaults
      # (rate 0.0, deterministic=True) and never apply dropout. Dropout has no
      # parameters, so the parameter tree is unchanged.
      x = RDCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          dropout_rate=self.dropout_rate,
          name=f"channel_attention_block_1_{i}")(
              x, deterministic=deterministic)

    # long skip-connect
    x = x + shortcut_long
    return x


class UNetEncoderBlock(nn.Module):
  """Encoder block in MAXIM.

  Optionally concatenates a `skip` tensor, then projects to `features` channels
  with a 1x1 conv. Next it applies `num_groups` rounds of [multi-axis gMLP layer
  (if `use_global_mlp`) -> RCAB] and adds a long skip connection. It optionally
  fuses `enc + dec` through a cross-gating block, and optionally downsamples 2x
  with a strided conv.

  Attributes:
    features: channel width of the block.
    block_size: (fh, fw) patch size for the local (block) gMLP branch.
    grid_size: (gh, gw) grid size for the global (grid) gMLP branch.
    num_groups: number of (gMLP +) RCAB rounds.
    lrelu_slope: NOTE(review): not read in this block; RCAB is built with its
      own default slope.
    block_gmlp_factor: expansion factor of the block gMLP branch.
    grid_gmlp_factor: expansion factor of the grid gMLP branch.
    input_proj_factor: passed to the multi-axis gMLP layer and the cross-gating
      block.
    channels_reduction: channel reduction factor of the RCAB channel attention.
    dropout_rate: dropout rate of the gMLP and cross-gating layers.
    downsample: if True, also return a downsampled copy.
    use_global_mlp: whether each round includes the multi-axis gMLP layer.
    use_bias: whether convs / Dense layers use a bias.
    use_cross_gating: must be True when both `enc` and `dec` are passed.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True
  use_cross_gating: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray, skip: jnp.ndarray = None,
               enc: jnp.ndarray = None, dec: jnp.ndarray = None, *,
               deterministic: bool = True) -> jnp.ndarray:
    """Returns `(x_down, x)` when `self.downsample` is True, otherwise `x`."""
    if skip is not None:
      x = jnp.concatenate([x, skip], axis=-1)

    # convolution-in
    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    shortcut_long = x

    for i in range(self.num_groups):
      if self.use_global_mlp:
        x = ResidualSplitHeadMultiAxisGmlpLayer(
            grid_size=self.grid_size,
            block_size=self.block_size,
            grid_gmlp_factor=self.grid_gmlp_factor,
            block_gmlp_factor=self.block_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            use_bias=self.use_bias,
            dropout_rate=self.dropout_rate,
            name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      # NOTE(review): this name has no "_" before {i}, unlike BottleneckBlock's
      # "channel_attention_block_1_{i}". It determines parameter names, so it
      # presumably must stay as-is to match existing checkpoints / the TF port.
      x = RCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1{i}")(x)

    x = x + shortcut_long

    # Cross-gate with the sum of encoder and decoder features when both given.
    if enc is not None and dec is not None:
      assert self.use_cross_gating
      x, _ = CrossGatingBlock(
          features=self.features,
          block_size=self.block_size,
          grid_size=self.grid_size,
          dropout_rate=self.dropout_rate,
          input_proj_factor=self.input_proj_factor,
          upsample_y=False,
          use_bias=self.use_bias,
          name="cross_gating_block")(
              x, enc + dec, deterministic=deterministic)

    if self.downsample:
      x_down = Conv_down(self.features, use_bias=self.use_bias)(x)
      return x_down, x
    else:
      return x


class UNetDecoderBlock(nn.Module):
  """Decoder block in MAXIM.

  Upsamples 2x with a transposed conv, then runs a non-downsampling
  `UNetEncoderBlock` that concatenates the skip tensor `bridge` first.

  Attributes:
    features: channel width of the block.
    block_size: (fh, fw) patch size for the local (block) gMLP branch.
    grid_size: (gh, gw) grid size for the global (grid) gMLP branch.
    num_groups: number of rounds in the inner encoder block.
    lrelu_slope: forwarded to the inner encoder block.
    block_gmlp_factor: expansion factor of the block gMLP branch.
    grid_gmlp_factor: expansion factor of the grid gMLP branch.
    input_proj_factor: NOTE(review): not forwarded to the inner
      UNetEncoderBlock, which therefore always uses its default of 2.
      Forwarding it would change parameter shapes whenever the value is not 2,
      so confirm against the checkpoints / TF port before changing.
    channels_reduction: channel reduction factor of the RCAB channel attention.
    dropout_rate: dropout rate of the inner encoder block.
    downsample: not read here; the inner encoder block never downsamples.
    use_global_mlp: whether the inner block uses the multi-axis gMLP layer.
    use_bias: whether convs / Dense layers use a bias.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, bridge: jnp.ndarray = None,
               deterministic: bool = True) -> jnp.ndarray:
    # 2x spatial upsampling.
    x = ConvT_up(self.features, use_bias=self.use_bias)(x)

    x = UNetEncoderBlock(
        self.features,
        num_groups=self.num_groups,
        lrelu_slope=self.lrelu_slope,
        block_size=self.block_size,
        grid_size=self.grid_size,
        block_gmlp_factor=self.block_gmlp_factor,
        grid_gmlp_factor=self.grid_gmlp_factor,
        channels_reduction=self.channels_reduction,
        use_global_mlp=self.use_global_mlp,
        dropout_rate=self.dropout_rate,
        downsample=False,
        use_bias=self.use_bias)(x, skip=bridge, deterministic=deterministic)
    return x


class GetSpatialGatingWeights(nn.Module):
  """Get gating weights for cross-gating MLP block.

  Projects the input (LayerNorm -> Dense -> GELU) and splits it in half along
  channels. One half is mixed across grid cells (global branch) and the other
  within blocks (local branch), each by a spatial Dense layer. The halves are
  then concatenated, projected back to the input channel count, and passed
  through dropout.

  Attributes:
    features: NOTE(review): not read in this block; the output width is taken
      from the input's channel count.
    block_size: (fh, fw) patch size of the local branch.
    grid_size: (gh, gw) grid size of the global branch.
    input_proj_factor: expansion factor of the input projection.
    dropout_rate: dropout rate on the output.
    use_bias: whether the Dense layers use a bias.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  input_proj_factor: int = 2
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    n, h, w, num_channels = x.shape

    # input projection
    x = nn.LayerNorm(name="LayerNorm_in")(x)
    x = nn.Dense(
        num_channels * self.input_proj_factor,
        use_bias=self.use_bias,
        name="in_project")(
            x)
    x = nn.gelu(x)
    u, v = jnp.split(x, 2, axis=-1)

    # Get grid MLP weights: mix along the grid axis (-3) of the blocked tensor.
    # The two spatial Dense layers below are unnamed, so Flax names them by
    # creation order; keep that order stable. bias_init=ones presumably makes
    # the initial gating weights start near 1.
    gh, gw = self.grid_size
    fh, fw = h // gh, w // gw
    u = block_images_einops(u, patch_size=(fh, fw))
    dim_u = u.shape[-3]
    u = jnp.swapaxes(u, -1, -3)
    u = nn.Dense(
        dim_u, use_bias=self.use_bias, kernel_init=nn.initializers.normal(2e-2),
        bias_init=nn.initializers.ones)(u)
    u = jnp.swapaxes(u, -1, -3)
    u = unblock_images_einops(u, grid_size=(gh, gw), patch_size=(fh, fw))

    # Get Block MLP weights: mix along the within-block axis (-2).
    fh, fw = self.block_size
    gh, gw = h // fh, w // fw
    v = block_images_einops(v, patch_size=(fh, fw))
    dim_v = v.shape[-2]
    v = jnp.swapaxes(v, -1, -2)
    v = nn.Dense(
        dim_v, use_bias=self.use_bias, kernel_init=nn.initializers.normal(2e-2),
        bias_init=nn.initializers.ones)(v)
    v = jnp.swapaxes(v, -1, -2)
    v = unblock_images_einops(v, grid_size=(gh, gw), patch_size=(fh, fw))

    x = jnp.concatenate([u, v], axis=-1)
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic)
    return x


class CrossGatingBlock(nn.Module):
  """Cross-gating MLP block.

  Projects `x` and `y` (the gating signal) to a common width and computes
  spatial gating weights from each. Each signal is then gated by the other's
  weights (y * gx, x * gy), with residual connections.

  Attributes:
    features: channel width after the 1x1 projection of `x` (and of `y` after
      the optional upsampling).
    block_size: (fh, fw) patch size for the gating weights' local branch.
    grid_size: (gh, gw) grid size for the gating weights' global branch.
    dropout_rate: dropout rate of the gating weights and output projections.
    input_proj_factor: NOTE(review): not forwarded to GetSpatialGatingWeights,
      which therefore uses its default of 2. Forwarding it would change
      parameter shapes whenever the value is not 2, so confirm against the
      checkpoints / TF port before changing.
    upsample_y: if True, `y` is first upsampled 2x with a transposed conv.
    use_bias: whether convs / Dense layers use a bias.

  `__call__` returns a tuple `(x, y)`: the aggregated signal
  `x + y + shortcut_x` and the gated `y` signal.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  dropout_rate: float = 0.0
  input_proj_factor: int = 2
  upsample_y: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, y, deterministic=True):
    # Upscale Y signal, y is the gating signal.
    if self.upsample_y:
      y = ConvT_up(self.features, use_bias=self.use_bias)(y)

    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    n, h, w, num_channels = x.shape
    y = Conv1x1(num_channels, use_bias=self.use_bias)(y)

    assert y.shape == x.shape
    shortcut_x = x
    shortcut_y = y

    # Get gating weights from X
    x = nn.LayerNorm(name="LayerNorm_x")(x)
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_x")(x)
    x = nn.gelu(x)
    gx = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_x")(
            x, deterministic=deterministic)

    # Get gating weights from Y
    y = nn.LayerNorm(name="LayerNorm_y")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_y")(y)
    y = nn.gelu(y)
    gy = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_y")(
            y, deterministic=deterministic)

    # Apply cross gating: X = X * GY, Y = Y * GX
    y = y * gx
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_y")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
    y = y + shortcut_y

    x = x * gy  # gating x using y
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_x")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
    x = x + y + shortcut_x  # get all aggregated signals
    return x, y


class SAM(nn.Module):
  """Supervised attention module for multi-stage training.

  Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet

  Attributes:
    features: channels of the attended features; must match x's channel count
      for the residual `x1 + x`.
    output_channels: channels of the predicted image. When 3, the prediction
      is added to `x_image` as a residual.
    use_bias: whether the 3x3 convs use a bias.
  """
  features: int
  output_channels: int = 3
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, x_image: jnp.ndarray, *,
               train: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply the SAM module to the input and features.

    Args:
      x: the output features from the UNet decoder, channels-last (documented
        upstream as (h, w, c); presumably batched as (n, h, w, c) in practice).
      x_image: the input image, same spatial size as `x`, with 3 channels when
        `output_channels == 3`.
      train: Whether it is training. NOTE(review): not read in this method.

    Returns:
      A tuple of tensors (x1, image) where (x1) is the sam features used for the
        next stage, and (image) is the output restored image at current stage.
    """
    # Get features
    x1 = Conv3x3(self.features, use_bias=self.use_bias)(x)

    # Output restored image X_s
    if self.output_channels == 3:
      image = Conv3x3(self.output_channels, use_bias=self.use_bias)(x) + x_image
    else:
      image = Conv3x3(self.output_channels, use_bias=self.use_bias)(x)

    # Get attention maps for features
    x2 = nn.sigmoid(Conv3x3(self.features, use_bias=self.use_bias)(image))

    # Get attended feature maps
    x1 = x1 * x2

    # Residual connection
    x1 = x1 + x
    return x1, image


class MAXIM(nn.Module):
  """The MAXIM model function with multi-stage and multi-scale supervision.

  For more model details, please check the CVPR paper:
  MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

  Attributes:
    features: initial hidden dimension for the input resolution.
    depth: the number of downsampling depth for the model.
    num_stages: how many stages to use. It also affects the output list.
    num_groups: how many blocks each stage contains.
    use_bias: whether to use bias in all the conv/mlp layers.
    num_supervision_scales: the number of desired supervision scales.
    lrelu_slope: the negative slope parameter in leaky_relu layers.
    use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
      layer.
    use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
      skip connections and multi-stage feature fusion layers.
    high_res_stages: how many stages are specified as high-res stages. The
      rest (depth - high_res_stages) are called low_res_stages.
    block_size_hr: the block_size parameter for high-res stages.
    block_size_lr: the block_size parameter for low-res stages (and the
      bottleneck blocks).
    grid_size_hr: the grid_size parameter for high-res stages.
    grid_size_lr: the grid_size parameter for low-res stages (and the
      bottleneck blocks).
    num_bottleneck_blocks: how many bottleneck blocks.
    block_gmlp_factor: the input projection factor for block_gMLP layers.
    grid_gmlp_factor: the input projection factor for grid_gMLP layers.
    input_proj_factor: the input projection factor for the MAB block.
    channels_reduction: the channel reduction factor for SE layer.
    num_outputs: the output channels.
    dropout_rate: Dropout rate.

  Returns:
    The output contains a list of arrays consisting of multi-stage multi-scale
    outputs. For example, if num_stages = num_supervision_scales = 3 (the
    model used in the paper), the output specs are: outputs =
    [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
     [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
     [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
    The final output can be retrieved by outputs[-1][-1].
  """
  features: int = 64
  depth: int = 3
  num_stages: int = 2
  num_groups: int = 1
  use_bias: bool = True
  num_supervision_scales: int = 1
  lrelu_slope: float = 0.2
  use_global_mlp: bool = True
  use_cross_gating: bool = True
  high_res_stages: int = 2
  block_size_hr: Sequence[int] = (16, 16)
  block_size_lr: Sequence[int] = (8, 8)
  grid_size_hr: Sequence[int] = (16, 16)
  grid_size_lr: Sequence[int] = (8, 8)
  num_bottleneck_blocks: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  num_outputs: int = 3
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool = False) -> Any:
    """Run every stage and return the multi-stage, multi-scale outputs.

    Args:
      x: input images, unpacked as (n, h, w, c).
      train: forwarded to sub-blocks as `deterministic=not train` (and to SAM
        as `train`).

    Returns:
      A list with one entry per stage; each entry lists the restored images
      ordered from the coarsest supervision scale to the full input resolution.
    """
    # NOTE: unnamed submodules (e.g. UpSampleRatio, Conv in the non-cross-gating
    # path) get auto-generated names from their creation order, so that order
    # presumably has to stay fixed for checkpoint loading.
    deterministic = not train

    def sizes_at(level):
      # Levels below `high_res_stages` use the larger high-res block/grid sizes.
      # The low-res grid size previously came from `block_size_lr`, which made
      # `grid_size_lr` a dead attribute; with the default (8, 8) for both the
      # result is unchanged.
      if level < self.high_res_stages:
        return self.block_size_hr, self.grid_size_hr
      return self.block_size_lr, self.grid_size_lr

    n, h, w, c = x.shape  # input image shape
    shortcuts = []
    shortcuts.append(x)
    # Get multi-scale input images
    for i in range(1, self.num_supervision_scales):
      shortcuts.append(jax.image.resize(
          x, shape=(n, h // (2**i), w // (2**i), c), method="nearest"))

    # store outputs from all stages and all scales
    # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
    #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
    outputs_all = []
    sam_features, encs_prev, decs_prev = [], [], []

    for idx_stage in range(self.num_stages):
      # Input convolution, get multi-scale input features
      x_scales = []
      for i in range(self.num_supervision_scales):
        x_scale = Conv3x3(
            (2**i) * self.features,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_input_conv_{i}")(
                shortcuts[i])

        # If later stages, fuse input features with SAM features from prev stage
        if idx_stage > 0:
          if self.use_cross_gating:
            # use larger blocksize at high-res stages
            block_size, grid_size = sizes_at(i)
            x_scale, _ = CrossGatingBlock(
                features=(2**i) * self.features,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=self.dropout_rate,
                input_proj_factor=self.input_proj_factor,
                upsample_y=False,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_fuse_sam_{i}")(
                    x_scale, sam_features.pop(), deterministic=deterministic)
          else:
            x_scale = Conv1x1(
                (2**i) * self.features,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_catconv_{i}")(
                    jnp.concatenate(
                        [x_scale, sam_features.pop()], axis=-1))

        x_scales.append(x_scale)

      # start encoder blocks
      encs = []
      x = x_scales[0]  # First full-scale input feature

      for i in range(self.depth):  # 0, 1, 2
        # use larger blocksize at high-res stages, vice versa.
        block_size, grid_size = sizes_at(i)
        use_cross_gating_layer = idx_stage > 0

        # Multi-scale input if multi-scale supervision
        x_scale = x_scales[i] if i < self.num_supervision_scales else None

        # UNet Encoder block
        enc_prev = encs_prev.pop() if idx_stage > 0 else None
        dec_prev = decs_prev.pop() if idx_stage > 0 else None

        x, bridge = UNetEncoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            downsample=True,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            use_cross_gating=use_cross_gating_layer,
            name=f"stage_{idx_stage}_encoder_block_{i}")(
                x,
                skip=x_scale,
                enc=enc_prev,
                dec=dec_prev,
                deterministic=deterministic)

        # Cache skip signals
        encs.append(bridge)

      # Global MLP bottleneck blocks (always at the low-res block/grid sizes)
      for i in range(self.num_bottleneck_blocks):
        x = BottleneckBlock(
            block_size=self.block_size_lr,
            grid_size=self.grid_size_lr,
            features=(2**(self.depth - 1)) * self.features,
            num_groups=self.num_groups,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            channels_reduction=self.channels_reduction,
            name=f"stage_{idx_stage}_global_block_{i}")(
                x, deterministic=deterministic)
      # cache global feature for cross-gating
      global_feature = x

      # start cross gating. Use multi-scale feature fusion
      skip_features = []
      for i in reversed(range(self.depth)):  # 2, 1, 0
        # use larger blocksize at high-res stages
        block_size, grid_size = sizes_at(i)

        # get additional multi-scale signals
        signal = jnp.concatenate(
            [
                UpSampleRatio(
                    (2**i) * self.features,
                    ratio=2**(j - i),
                    use_bias=self.use_bias)(enc)
                for j, enc in enumerate(encs)
            ],
            axis=-1)

        # Use cross-gating to cross modulate features
        if self.use_cross_gating:
          skips, global_feature = CrossGatingBlock(
              features=(2**i) * self.features,
              block_size=block_size,
              grid_size=grid_size,
              input_proj_factor=self.input_proj_factor,
              dropout_rate=self.dropout_rate,
              upsample_y=True,
              use_bias=self.use_bias,
              name=f"stage_{idx_stage}_cross_gating_block_{i}")(
                  signal, global_feature, deterministic=deterministic)
        else:
          skips = Conv1x1(
              (2**i) * self.features, use_bias=self.use_bias)(
                  signal)
          skips = Conv3x3((2**i) * self.features, use_bias=self.use_bias)(skips)

        skip_features.append(skips)

      # start decoder. Multi-scale feature fusion of cross-gated features
      outputs, decs, sam_features = [], [], []
      for i in reversed(range(self.depth)):
        # use larger blocksize at high-res stages
        block_size, grid_size = sizes_at(i)

        # get multi-scale skip signals from cross-gating block
        signal = jnp.concatenate(
            [
                UpSampleRatio(
                    (2**i) * self.features,
                    ratio=2**(self.depth - j - 1 - i),
                    use_bias=self.use_bias)(skip)
                for j, skip in enumerate(skip_features)
            ],
            axis=-1)

        # Decoder block
        x = UNetDecoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_decoder_block_{i}")(
                x, bridge=signal, deterministic=deterministic)

        # Cache decoder features for later-stage's usage
        decs.append(x)

        # output conv, if not final stage, use supervised-attention-block.
        if i < self.num_supervision_scales:
          if idx_stage < self.num_stages - 1:  # not last stage, apply SAM
            sam, output = SAM(
                (2**i) * self.features,
                output_channels=self.num_outputs,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_supervised_attention_module_{i}")(
                    x, shortcuts[i], train=train)
            outputs.append(output)
            sam_features.append(sam)
          else:  # Last stage, apply output convolutions
            output = Conv3x3(self.num_outputs,
                             use_bias=self.use_bias,
                             name=f"stage_{idx_stage}_output_conv_{i}")(x)
            output = output + shortcuts[i]
            outputs.append(output)
      # Cache encoder and decoder features for later-stage's usage.
      # Both are consumed with .pop() in the next stage's encoder loop, so the
      # level-0 entries must end up last.
      encs_prev = encs[::-1]
      decs_prev = decs

      # Store outputs
      outputs_all.append(outputs)
    return outputs_all


def Model(*, variant=None, **kw):
  """Factory that builds a MAXIM model, optionally from a named variant.

  Every model file should have this Model() function that returns the flax
  model function. The function name should be fixed.

  Args:
    variant: optional preset name, one of 'S-1' | 'S-2' | 'S-3' | 'M-1' |
      'M-2' | 'M-3'. Its settings only fill in keys missing from `kw`.
    **kw: explicit MAXIM attributes; these always take precedence.

  Returns:
    A `MAXIM(...)` module instance.

  Raises:
    KeyError: if `variant` is not one of the names above.
  """
  if variant is None:
    return MAXIM(**kw)

  # Every released variant shares these settings.
  shared = {
      "depth": 3,
      "num_groups": 2,
      "num_bottleneck_blocks": 2,
      "block_gmlp_factor": 2,
      "grid_gmlp_factor": 2,
      "input_proj_factor": 2,
      "channels_reduction": 4,
  }
  # variant -> (features, num_stages)
  table = {
      "S-1": (32, 1),  # params: 6.108515000000001 M, GFLOPS: 93.163716608
      "S-2": (32, 2),  # params: 13.35383 M, GFLOPS: 206.743273472
      "S-3": (32, 3),  # params: 20.599145 M, GFLOPS: 320.32194560000005
      "M-1": (64, 1),  # params: 19.361219000000002 M, 308.495712256 GFLOPs
      "M-2": (64, 2),  # params: 40.83911 M, 675.25541888 GFLOPs
      "M-3": (64, 3),  # params: 62.317001 M, 1042.014666752 GFLOPs
  }
  features, num_stages = table[variant]
  preset = dict(shared, features=features, num_stages=num_stages)

  # Caller-supplied keyword arguments override the preset.
  return MAXIM(**{**preset, **kw})

## Verifying that the TensorFlow ports match the JAX/Flax reference layers

In [18]:
from maxim.layers import Resizing

In [19]:
# Check that the Keras `Resizing` layer (bilinear, antialiased) reproduces
# `jax.image.resize` when downsampling a non-square image to 256x256.
dummy_inputs = tf.random.normal((1, 365, 385, 3))

resizing_tf = Resizing(256, 256, method="bilinear", antialias=True)

tf_outputs = resizing_tf(dummy_inputs).numpy()
jax_outputs = jax.image.resize(dummy_inputs.numpy(), (1, 256, 256, 3), method="bilinear")

np.testing.assert_allclose(tf_outputs, np.array(jax_outputs), rtol=1e-5, atol=1e-5)

In [20]:
# All-ones initializer used by the Flax test modules in the cells below. The
# Keras ports use kernel_initializer="ones" / bias_initializer="ones", so both
# implementations start from identical weights and their outputs can be
# compared directly.
weight_initializer = nn.initializers.ones

class MlpBlock(nn.Module):
  """A 1-hidden-layer MLP block, applied over the last dimension.

  Attributes:
    mlp_dim: width of the hidden layer.
    dropout_rate: dropout rate applied after the GELU activation.
    use_bias: whether both Dense layers use a bias term.
  """
  mlp_dim: int
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    # Only the channel (last) dimension is needed. Leading dimensions pass
    # through untouched, so inputs of any rank >= 1 are accepted (previously a
    # 4-tuple unpack rejected everything but 4-D inputs).
    d = x.shape[-1]
    x = nn.Dense(self.mlp_dim, use_bias=self.use_bias,
                 kernel_init=weight_initializer, bias_init=weight_initializer)(x)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
    # Project back to the input channel count.
    x = nn.Dense(d, use_bias=self.use_bias,
                 kernel_init=weight_initializer, bias_init=weight_initializer)(x)
    return x

In [21]:
from tensorflow.keras import backend as K 
from tensorflow.keras import layers 

def MlpBlockTF(
    mlp_dim: int,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "mlp_block",
):
    """A 1-hidden-layer MLP block, applied over the last dimension.

    Keras counterpart of `MlpBlock`. Both Dense layers are initialised to ones
    so the output can be compared with the Flax version.

    Returns:
        A callable `apply(x)` that builds and applies the layers.
    """
    dense_options = dict(
        use_bias=use_bias, kernel_initializer="ones", bias_initializer="ones"
    )

    def apply(x):
        channels = K.int_shape(x)[-1]
        hidden = layers.Dense(mlp_dim, name=f"{name}_Dense_0", **dense_options)(x)
        # tanh-approximate GELU, presumably to match Flax's `nn.gelu` default.
        hidden = tf.nn.gelu(hidden, approximate=True)
        hidden = layers.Dropout(dropout_rate)(hidden)
        return layers.Dense(channels, name=f"{name}_Dense_1", **dense_options)(hidden)

    return apply

In [22]:
dummy_inputs = tf.random.normal(shape=(1, 256, 256, 32))  # shared NHWC batch for the layer checks below

In [23]:
from jax import random

# Flax reference run of the MLP block on the shared dummy batch.
dummy_np = dummy_inputs.numpy()
mlp_block_jax = MlpBlock(mlp_dim=128)
variables = mlp_block_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = mlp_block_jax.apply(variables, dummy_np)

In [24]:
# Keras port with identical all-ones weights; must agree with the Flax result.
mlp_block_tf = MlpBlockTF(mlp_dim=128)
tf_outputs = mlp_block_tf(dummy_inputs.numpy())
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-4, atol=1e-4
)

In [25]:
class UpSampleRatio(nn.Module):
  """Upsample features by `ratio` (> 0), then project with a 1x1 conv.

  NOTE(review): this cell rebinds the global name `UpSampleRatio`, which
  `MAXIM.__call__` looks up at call time. After running it, MAXIM would build
  this all-ones-initialised variant. Re-run the original definition before
  constructing MAXIM again.

  Attributes:
    features: output channels of the 1x1 projection.
    ratio: spatial scale factor; target sizes are truncated with `int()`.
    use_bias: whether the projection uses a bias term.
  """
  features: int
  ratio: float
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    batch, height, width, channels = x.shape
    target_shape = (batch, int(height * self.ratio), int(width * self.ratio), channels)
    resized = jax.image.resize(x, shape=target_shape, method="bilinear")
    projection = Conv1x1(features=self.features, use_bias=self.use_bias,
                         kernel_init=weight_initializer, bias_init=weight_initializer)
    return projection(resized)

In [26]:
# 1x1 "same"-padded Keras convolution, mirroring the Flax `Conv1x1` helper.
Conv1x1TF = functools.partial(layers.Conv2D, padding="same", kernel_size=(1, 1))


def UpSampleRatioTF(
    num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
):
    """Upsample features by `ratio` (> 0), then project with a 1x1 conv.

    Keras counterpart of `UpSampleRatio`. Returns a callable `apply(x)`.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        height, width = static_shape[1], static_shape[2]
        # Casting to int32 truncates like Python's int(), following
        # `jax.image.resize()` sizing on the Flax side.
        resized = Resizing(
            height=tf.cast(height * ratio, tf.int32),
            width=tf.cast(width * ratio, tf.int32),
            method="bilinear",
            antialias=True,
        )(x)
        return Conv1x1TF(
            filters=num_channels,
            use_bias=use_bias,
            name=f"{name}_Conv_0",
            kernel_initializer="ones",
            bias_initializer="ones",
        )(resized)

    return apply

In [27]:
# Flax reference: 2x upsampling followed by a 1x1 projection to 128 channels.
dummy_np = dummy_inputs.numpy()
upsample_block_jax = UpSampleRatio(features=128, ratio=2)
variables = upsample_block_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = upsample_block_jax.apply(variables, dummy_np)

In [28]:
# Keras port of the same upsample + projection; compare against Flax.
upsample_block_tf = UpSampleRatioTF(num_channels=128, ratio=2)
tf_outputs = upsample_block_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-4, atol=1e-4
)

In [29]:
class CALayer(nn.Module):
  """Squeeze-and-excitation block for channel attention.
  ref: https://arxiv.org/abs/1709.01507

  Attributes:
    features: channel count of the input (and of the excitation conv).
    reduction: divisor applied to `features` for the squeeze conv.
    use_bias: whether the 1x1 convs use a bias term.
  """
  features: int
  reduction: int = 4
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    squeezed_channels = self.features // self.reduction
    # 2D global average pooling over axes 1 and 2, kept as size-1 dims.
    gate = jnp.mean(x, axis=(1, 2), keepdims=True)
    # Squeeze (in Squeeze-Excitation)
    gate = nn.relu(Conv1x1(squeezed_channels, use_bias=self.use_bias,
                           kernel_init=weight_initializer,
                           bias_init=weight_initializer)(gate))
    # Excitation (in Squeeze-Excitation)
    gate = nn.sigmoid(Conv1x1(self.features, use_bias=self.use_bias,
                              kernel_init=weight_initializer,
                              bias_init=weight_initializer)(gate))
    return x * gate

In [30]:
def CALayerTF(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Squeeze-and-excitation block for channel attention.
    ref: https://arxiv.org/abs/1709.01507

    Keras counterpart of `CALayer`. Returns a callable `apply(x)`.
    """

    def apply(x):
        conv_options = dict(
            use_bias=use_bias, kernel_initializer="ones", bias_initializer="ones"
        )
        # 2D global average pooling, keeping spatial dims as size 1.
        pooled = layers.GlobalAvgPool2D(keepdims=True)(x)
        # Squeeze (in Squeeze-Excitation)
        squeezed = tf.nn.relu(
            Conv1x1TF(filters=num_channels // reduction, name=f"{name}_Conv_0", **conv_options)(pooled)
        )
        # Excitation (in Squeeze-Excitation)
        gate = tf.nn.sigmoid(
            Conv1x1TF(filters=num_channels, name=f"{name}_Conv_1", **conv_options)(squeezed)
        )
        return x * gate

    return apply

In [31]:
# Flax reference run of the channel-attention layer.
dummy_np = dummy_inputs.numpy()
ca_block_jax = CALayer(features=32)
variables = ca_block_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = ca_block_jax.apply(variables, dummy_np)

In [32]:
# Keras port of channel attention; compare against Flax.
ca_block_tf = CALayerTF(num_channels=32)
tf_outputs = ca_block_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-5, atol=1e-5
)

In [33]:
class RCAB(nn.Module):
  """Residual channel attention block: LN -> Conv -> lReLU -> Conv -> SE, plus skip.

  Attributes:
    features: channel count of both 3x3 convs and the SE layer.
    reduction: SE squeeze divisor.
    lrelu_slope: negative slope of the leaky ReLU.
    use_bias: whether the convs use a bias term.
  """
  features: int
  reduction: int = 4
  lrelu_slope: float = 0.2
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    y = nn.LayerNorm(name="LayerNorm")(x)
    y = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv1",
                kernel_init=weight_initializer, bias_init=weight_initializer)(y)
    y = nn.leaky_relu(y, negative_slope=self.lrelu_slope)
    y = Conv3x3(features=self.features, use_bias=self.use_bias, name="conv2",
                kernel_init=weight_initializer, bias_init=weight_initializer)(y)
    y = CALayer(features=self.features, reduction=self.reduction,
                use_bias=self.use_bias, name="channel_attention")(y)
    # Residual connection back to the untouched block input.
    return y + x

In [34]:
# 3x3 "same"-padded Keras convolution, mirroring the Flax `Conv3x3` helper.
Conv3x3TF = functools.partial(layers.Conv2D, padding="same", kernel_size=(3, 3))

def RCABTF(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block: LN -> Conv -> lReLU -> Conv -> SE, plus skip.

    Keras counterpart of `RCAB`. Returns a callable `apply(x)`.
    """
    ones_init = dict(kernel_initializer="ones", bias_initializer="ones")

    def apply(x):
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = Conv3x3TF(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1", **ones_init)(y)
        y = tf.nn.leaky_relu(y, alpha=lrelu_slope)
        y = Conv3x3TF(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2", **ones_init)(y)
        y = CALayerTF(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        # Residual connection back to the untouched block input.
        return y + x

    return apply

In [35]:
# Flax reference run of the residual channel-attention block.
dummy_np = dummy_inputs.numpy()
rcab_block_jax = RCAB(features=32)
variables = rcab_block_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = rcab_block_jax.apply(variables, dummy_np)

In [36]:
# Keras port of RCAB; compare against Flax.
rcab_block_tf = RCABTF(num_channels=32)
tf_outputs = rcab_block_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-5, atol=1e-5
)

In [37]:
class GridGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper, mixing along axis -3.

  The 'spatial' dim here is the **third last** (axis -3), not the second last.
  The channel axis is split in half into `u` and `v`. `v` is layer-normalised,
  projected by a Dense layer along axis -3 (swapped into last position and
  back), and then gates `u` as `u * (v + 1)`. The input's channel count must
  be even. If applied on other dims, you should swapaxes first.

  Attributes:
    use_bias: whether the Dense projection uses a bias term.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    n = x.shape[-3]  # size of the mixed (third-last) axis
    # Move axis -3 last so Dense mixes along it, then move it back.
    v = jnp.swapaxes(v, -1, -3)
    v = nn.Dense(n, use_bias=self.use_bias, kernel_init=weight_initializer, bias_init=weight_initializer)(v)
    v = jnp.swapaxes(v, -1, -3)
    return u * (v + 1.)

In [38]:
import tensorflow.experimental.numpy as tnp

class SwapAxes(layers.Layer):
    """Keras layer that swaps two axes of its input via `tnp.swapaxes`.

    NOTE: a later cell runs `from maxim.layers import ... SwapAxes ...`, which
    replaces this definition for everything executed after it.
    """

    def call(self, x, axis_one, axis_two):
        return tnp.swapaxes(x, axis_one, axis_two)

    def get_config(self):
        # No extra constructor arguments; a shallow copy of the base config.
        return dict(super().get_config())


def GridGatingUnitTF(use_bias: bool = True, name: str = "grid_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper, mixing along axis -3.

    Keras counterpart of `GridGatingUnit`. The 'spatial' dim here is the
    **third last** (axis -3), not the second last. The channel axis is split
    in half into `u` and `v`. `v` is layer-normalised, projected by a Dense
    layer along axis -3 (swapped into last position and back), and gates `u`
    as `u * (v + 1)`. The input's channel count must be even.

    Returns:
        A callable `apply(x)` that builds and applies the layers.
    """

    def apply(x):
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)
        n = K.int_shape(x)[-3]  # size of the mixed (third-last) axis
        # Move axis -3 last so Dense mixes along it, then move it back.
        v = SwapAxes()(v, -1, -3)
        v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0", kernel_initializer="ones", bias_initializer="ones")(v)
        v = SwapAxes()(v, -1, -3)
        return u * (v + 1.)

    return apply

In [39]:
# Flax reference run of the grid gating unit.
dummy_np = dummy_inputs.numpy()
gg_block_jax = GridGatingUnit()
variables = gg_block_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = gg_block_jax.apply(variables, dummy_np)

In [40]:
# Keras port of the grid gating unit; compare against Flax.
gg_block_tf = GridGatingUnitTF()
tf_outputs = gg_block_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-5, atol=1e-4
)

In [41]:
class GridGmlpLayer(nn.Module):
  """Grid gMLP layer that performs global mixing of tokens.

  Attributes:
    grid_size: (gh, gw) number of patches along height and width; the patch
      size is derived as (h // gh, w // gw).
    use_bias: whether the Dense layers use a bias term.
    factor: channel expansion factor of the input projection.
    dropout_rate: dropout rate applied to the gMLP branch output.
  """
  grid_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    _, height, width, channels = x.shape
    grid_h, grid_w = self.grid_size
    patch = (height // grid_h, width // grid_w)
    tokens = block_images_einops(x, patch_size=patch)

    # gMLP1: Global (grid) mixing part, provides global grid communication.
    branch = nn.LayerNorm(name="LayerNorm")(tokens)
    branch = nn.Dense(channels * self.factor, use_bias=self.use_bias,
                      kernel_init=weight_initializer, bias_init=weight_initializer,
                      name="in_project")(branch)
    branch = nn.gelu(branch)
    branch = GridGatingUnit(use_bias=self.use_bias, name="GridGatingUnit")(branch)
    branch = nn.Dense(channels, use_bias=self.use_bias,
                      kernel_init=weight_initializer, bias_init=weight_initializer,
                      name="out_project")(branch)
    branch = nn.Dropout(self.dropout_rate)(branch, deterministic)

    # Residual add, then restore the image layout.
    return unblock_images_einops(tokens + branch, grid_size=(grid_h, grid_w), patch_size=patch)

In [42]:
from maxim.layers import BlockImages, SwapAxes, UnblockImages


def GridGmlpLayerTF(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Grid gMLP layer that performs global mixing of tokens.

    Keras counterpart of `GridGmlpLayer` (Dense weights initialised to ones).
    Returns a callable `apply(x)`.
    """
    dense_init = dict(kernel_initializer="ones", bias_initializer="ones")

    def apply(x):
        static_shape = K.int_shape(x)
        height, width, channels = static_shape[1], static_shape[2], static_shape[3]
        grid_h, grid_w = grid_size
        patch = (height // grid_h, width // grid_w)

        tokens = BlockImages()(x, patch_size=patch)
        # gMLP1: Global (grid) mixing part, provides global grid communication.
        branch = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(tokens)
        branch = layers.Dense(
            channels * factor, use_bias=use_bias, name=f"{name}_in_project", **dense_init
        )(branch)
        branch = tf.nn.gelu(branch, approximate=True)
        branch = GridGatingUnitTF(use_bias=use_bias, name=f"{name}_GridGatingUnit")(branch)
        branch = layers.Dense(
            channels, use_bias=use_bias, name=f"{name}_out_project", **dense_init
        )(branch)
        branch = layers.Dropout(dropout_rate)(branch)
        # Residual add, then restore the image layout.
        return UnblockImages()(tokens + branch, grid_size=(grid_h, grid_w), patch_size=patch)

    return apply

In [43]:
# Flax reference run of the grid gMLP layer with a 16x16 grid.
dummy_np = dummy_inputs.numpy()
gmlp_jax = GridGmlpLayer(grid_size=(16, 16))
variables = gmlp_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = gmlp_jax.apply(variables, dummy_np)

In [44]:
# Keras port of the grid gMLP layer; compare against Flax.
gmlp_tf = GridGmlpLayerTF(grid_size=(16, 16))
tf_outputs = gmlp_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-5, atol=1e-5
)

In [45]:
class BlockGatingUnit(nn.Module):
  """A SpatialGatingUnit as defined in the gMLP paper.
  The 'spatial' dim is defined as the **second last**.
  If applied on other dims, you should swapaxes first.

  The channel axis is split in half into `u` and `v`. `v` is layer-normalised,
  mixed by a Dense layer along axis -2, and then gates `u` as `u * (v + 1)`.

  Attributes:
    use_bias: whether the Dense projection uses a bias term.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    u, v = jnp.split(x, 2, axis=-1)
    v = nn.LayerNorm(name="intermediate_layernorm")(v)
    tokens = x.shape[-2]
    # Move the token axis last so Dense mixes along it, then move it back.
    mixed = nn.Dense(tokens, use_bias=self.use_bias,
                     kernel_init=weight_initializer,
                     bias_init=weight_initializer)(jnp.swapaxes(v, -1, -2))
    return u * (jnp.swapaxes(mixed, -1, -2) + 1.)

class BlockGmlpLayer(nn.Module):
  """Block gMLP layer that performs local mixing of tokens.

  Attributes:
    block_size: (fh, fw) patch size; the number of patches is (h // fh, w // fw).
    use_bias: whether the Dense layers use a bias term.
    factor: channel expansion factor of the input projection.
    dropout_rate: dropout rate applied to the gMLP branch output.
  """
  block_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    _, height, width, channels = x.shape
    patch_h, patch_w = self.block_size
    grid = (height // patch_h, width // patch_w)
    tokens = block_images_einops(x, patch_size=(patch_h, patch_w))

    # MLP2: Local (block) mixing part, provides within-block communication.
    branch = nn.LayerNorm(name="LayerNorm")(tokens)
    branch = nn.Dense(channels * self.factor, use_bias=self.use_bias,
                      kernel_init=weight_initializer, bias_init=weight_initializer,
                      name="in_project")(branch)
    branch = nn.gelu(branch)
    branch = BlockGatingUnit(use_bias=self.use_bias, name="BlockGatingUnit")(branch)
    branch = nn.Dense(channels, use_bias=self.use_bias,
                      kernel_init=weight_initializer, bias_init=weight_initializer,
                      name="out_project")(branch)
    branch = nn.Dropout(self.dropout_rate)(branch, deterministic)

    # Residual add, then restore the image layout.
    return unblock_images_einops(tokens + branch, grid_size=grid, patch_size=(patch_h, patch_w))

In [46]:
def BlockGatingUnitTF(use_bias: bool = True, name: str = "block_gating_unit"):
    """A SpatialGatingUnit as defined in the gMLP paper.
    The 'spatial' dim is defined as the **second last**.
    If applied on other dims, you should swapaxes first.

    Keras counterpart of `BlockGatingUnit`. Returns a callable `apply(x)`.
    """

    def apply(x):
        u, v = tf.split(x, 2, axis=-1)
        v = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(v)
        tokens = K.int_shape(x)[-2]
        # Move the token axis last so Dense mixes along it, then move it back.
        mixed = layers.Dense(
            tokens,
            use_bias=use_bias,
            name=f"{name}_Dense_0",
            kernel_initializer="ones",
            bias_initializer="ones",
        )(SwapAxes()(v, -1, -2))
        return u * (SwapAxes()(mixed, -1, -2) + 1.0)

    return apply


def BlockGmlpLayerTF(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Block gMLP layer that performs local mixing of tokens.

    Keras counterpart of `BlockGmlpLayer` (Dense weights initialised to ones).
    Returns a callable `apply(x)`.
    """
    dense_init = dict(kernel_initializer="ones", bias_initializer="ones")

    def apply(x):
        static_shape = K.int_shape(x)
        height, width, channels = static_shape[1], static_shape[2], static_shape[3]
        patch_h, patch_w = block_size
        grid = (height // patch_h, width // patch_w)

        tokens = BlockImages()(x, patch_size=(patch_h, patch_w))
        # MLP2: Local (block) mixing part, provides within-block communication.
        branch = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(tokens)
        branch = layers.Dense(
            channels * factor, use_bias=use_bias, name=f"{name}_in_project", **dense_init
        )(branch)
        branch = tf.nn.gelu(branch, approximate=True)
        branch = BlockGatingUnitTF(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(branch)
        branch = layers.Dense(
            channels, use_bias=use_bias, name=f"{name}_out_project", **dense_init
        )(branch)
        branch = layers.Dropout(dropout_rate)(branch)
        # Residual add, then restore the image layout.
        return UnblockImages()(tokens + branch, grid_size=grid, patch_size=(patch_h, patch_w))

    return apply


In [48]:
# Flax reference run of the block gMLP layer with 16x16 blocks.
dummy_np = dummy_inputs.numpy()
bmlp_jax = BlockGmlpLayer(block_size=(16, 16))
variables = bmlp_jax.init(random.PRNGKey(0), dummy_np)
jax_outputs = bmlp_jax.apply(variables, dummy_np)

In [49]:
# Keras port of the block gMLP layer; compare against Flax.
bmlp_tf = BlockGmlpLayerTF(block_size=(16, 16))
tf_outputs = bmlp_tf(dummy_inputs)
np.testing.assert_allclose(
    tf_outputs, np.asarray(jax_outputs), rtol=1e-5, atol=1e-5
)

In [50]:
class ResidualSplitHeadMultiAxisGmlpLayer(nn.Module):
  """The multi-axis gated MLP block.

  LayerNorm -> Dense (expand) -> GELU. The channels are then split in half: one
  half goes through ``GridGmlpLayer`` and the other through ``BlockGmlpLayer``.
  The halves are concatenated, projected back to the input channel count,
  passed through dropout, and added to the input.

  Attributes:
    block_size: Block size forwarded to ``BlockGmlpLayer``.
    grid_size: Grid size forwarded to ``GridGmlpLayer``.
    block_gmlp_factor: Expansion factor of the block-gMLP branch.
    grid_gmlp_factor: Expansion factor of the grid-gMLP branch.
    input_proj_factor: Channel expansion of the input projection.
    use_bias: Whether Dense layers and sub-layers use biases.
    dropout_rate: Dropout rate for this block and both branches.
  """
  block_size: Sequence[int]
  grid_size: Sequence[int]
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    residual = x
    _, _, _, channels = x.shape

    hidden = nn.LayerNorm(name="LayerNorm_in")(x)
    hidden = nn.Dense(
        channels * self.input_proj_factor,
        use_bias=self.use_bias,
        kernel_init=weight_initializer,
        bias_init=weight_initializer,
        name="in_project")(hidden)
    hidden = nn.gelu(hidden)

    # One half of the channels per branch.
    grid_branch, block_branch = jnp.split(hidden, 2, axis=-1)

    grid_branch = GridGmlpLayer(
        grid_size=self.grid_size,
        factor=self.grid_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="GridGmlpLayer")(grid_branch, deterministic)

    block_branch = BlockGmlpLayer(
        block_size=self.block_size,
        factor=self.block_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="BlockGmlpLayer")(block_branch, deterministic)

    merged = jnp.concatenate([grid_branch, block_branch], axis=-1)
    merged = nn.Dense(
        channels,
        use_bias=self.use_bias,
        kernel_init=weight_initializer,
        bias_init=weight_initializer,
        name="out_project")(merged)
    merged = nn.Dropout(self.dropout_rate)(merged, deterministic)
    return merged + residual

In [51]:
def ResidualSplitHeadMultiAxisGmlpLayerTF(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """The multi-axis gated MLP block.

    Keras port of the Flax ``ResidualSplitHeadMultiAxisGmlpLayer``.

    Args:
        block_size: Block size forwarded to ``BlockGmlpLayerTF``.
        grid_size: Grid size forwarded to ``GridGmlpLayerTF``.
        block_gmlp_factor: Expansion factor of the block-gMLP branch.
        grid_gmlp_factor: Expansion factor of the grid-gMLP branch.
        input_proj_factor: Channel expansion of the input projection.
        use_bias: Whether Dense layers and sub-layers use biases.
        dropout_rate: Dropout rate for this block and both branches.
        name: Prefix for all sub-layer names.

    Returns:
        A callable ``apply(x)``. It reads the channel count from dim 3 of ``x``
        and returns the branch output added residually to ``x``.
    """

    def apply(x):
        shortcut = x
        channels = K.int_shape(x)[3]

        hidden = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        hidden = layers.Dense(
            int(channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
            kernel_initializer="ones",
            bias_initializer="ones",
        )(hidden)
        hidden = tf.nn.gelu(hidden, approximate=True)

        # Half of the channels go to the grid branch, half to the block branch.
        grid_branch, block_branch = tf.split(hidden, 2, axis=-1)

        grid_branch = GridGmlpLayerTF(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(grid_branch)

        block_branch = BlockGmlpLayerTF(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(block_branch)

        merged = tf.concat([grid_branch, block_branch], axis=-1)
        merged = layers.Dense(
            channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
            kernel_initializer="ones",
            bias_initializer="ones",
        )(merged)
        merged = layers.Dropout(dropout_rate)(merged)
        return merged + shortcut

    return apply

In [52]:
# Flax reference for the multi-axis gMLP block, initialised with PRNG key 0.
r_maxim_jax = ResidualSplitHeadMultiAxisGmlpLayer(block_size=(16, 16), grid_size=(16, 16))
variables = r_maxim_jax.init(random.PRNGKey(0), dummy_inputs.numpy())
jax_outputs = r_maxim_jax.apply(variables, dummy_inputs.numpy())

In [53]:
# Keras port on the same inputs; must match the Flax output within atol = rtol = 1e-5.
r_maxim_tf = ResidualSplitHeadMultiAxisGmlpLayerTF(block_size=(16, 16), grid_size=(16, 16))
tf_outputs = r_maxim_tf(dummy_inputs)

np.testing.assert_allclose(tf_outputs, np.array(jax_outputs), atol=1e-5, rtol=1e-5)

In [68]:
class RDCAB(nn.Module):
  """Residual dense channel attention block. Used in Bottlenecks.

  LayerNorm -> MlpBlock ("channel_mixing") -> CALayer ("channel_attention"),
  with the result added back onto the input.

  Attributes:
    features: ``mlp_dim`` of the MlpBlock and ``features`` of the CALayer.
    reduction: Reduction ratio passed to the CALayer.
    use_bias: Whether sub-layers use biases.
    dropout_rate: Dropout rate passed to the MlpBlock.
  """
  features: int
  reduction: int = 16
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    normed = nn.LayerNorm(name="LayerNorm")(x)
    mixed = MlpBlock(
        mlp_dim=self.features,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="channel_mixing",
    )(normed, deterministic=deterministic)
    attended = CALayer(
        features=self.features,
        reduction=self.reduction,
        use_bias=self.use_bias,
        name="channel_attention",
    )(mixed)
    return x + attended

In [78]:
def RDCABTF(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block. Used in Bottlenecks.

    Keras port of the Flax ``RDCAB``:
    LayerNorm -> MlpBlockTF -> CALayerTF, added residually to the input.

    Args:
        num_channels: ``mlp_dim`` of ``MlpBlockTF`` and ``num_channels`` of
            ``CALayerTF`` (``features`` in the Flax module).
        reduction: Reduction ratio passed to ``CALayerTF``.
        use_bias: Whether both sub-layers use biases. It is forwarded to both,
            matching the Flax module, which passes ``self.use_bias``.
        dropout_rate: Dropout rate passed to ``MlpBlockTF``.
        name: Prefix for all sub-layer names.

    Returns:
        A callable ``apply(x)`` returning ``x + CALayerTF(MlpBlockTF(LN(x)))``.
    """

    def apply(x):
        y = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_LayerNorm")(x)
        y = MlpBlockTF(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )(y)
        y = CALayerTF(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        return x + y

    return apply

In [79]:
# Flax reference RDCAB with features=32, initialised with PRNG key 0.
rdcab_jax = RDCAB(32)
variables = rdcab_jax.init(random.PRNGKey(0), dummy_inputs.numpy())
jax_outputs = rdcab_jax.apply(variables, dummy_inputs.numpy())

In [80]:
# Keras port with num_channels=32; must match the Flax output within atol = rtol = 1e-5.
rdcab_tf = RDCABTF(32)
tf_outputs = rdcab_tf(dummy_inputs)

np.testing.assert_allclose(tf_outputs, np.array(jax_outputs), atol=1e-5, rtol=1e-5)

In [91]:
class BottleneckBlock(nn.Module):
  """The bottleneck block consisting of multi-axis gMLP block and RDCAB.

  Applies a 1x1 input projection to ``features`` channels. It then runs
  ``num_groups`` repetitions of
  (ResidualSplitHeadMultiAxisGmlpLayer -> RDCAB) and adds the projected input
  back as a long skip connection.

  Attributes:
    features: Channel count of the 1x1 input projection.
    block_size: Block size forwarded to the multi-axis gMLP layers.
    grid_size: Grid size forwarded to the multi-axis gMLP layers.
    num_groups: Number of (gMLP, RDCAB) repetitions.
    block_gmlp_factor: Expansion factor of the block-gMLP branch.
    grid_gmlp_factor: Expansion factor of the grid-gMLP branch.
    input_proj_factor: Channel expansion of the gMLP input projection.
    channels_reduction: Reduction ratio passed to each RDCAB.
    dropout_rate: Dropout rate for all sub-layers.
    use_bias: Whether sub-layers use biases.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Applies the bottleneck block to inputs.

    Args:
      x: Input of shape [batch, h, w, c].
      deterministic: Forwarded to every sub-layer. True runs dropout in
        inference mode.

    Returns:
      The processed features plus the long skip connection.
    """
    assert x.ndim == 4  # Input has shape [batch, h, w, c]

    # input projection
    x = Conv1x1(self.features, use_bias=self.use_bias, name="input_proj",
                kernel_init=weight_initializer, bias_init=weight_initializer)(x)
    shortcut_long = x

    for i in range(self.num_groups):
      x = ResidualSplitHeadMultiAxisGmlpLayer(
          grid_size=self.grid_size,
          block_size=self.block_size,
          grid_gmlp_factor=self.grid_gmlp_factor,
          block_gmlp_factor=self.block_gmlp_factor,
          input_proj_factor=self.input_proj_factor,
          use_bias=self.use_bias,
          dropout_rate=self.dropout_rate,
          name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      # Channel-mixing part, which provides within-patch communication.
      # `deterministic` is forwarded so RDCAB's dropout follows the caller's
      # train/eval mode instead of always using its default of True.
      x = RDCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1_{i}")(
              x, deterministic=deterministic)

    # long skip-connect
    x = x + shortcut_long
    return x

In [92]:
def BottleneckBlockTF(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """The bottleneck block consisting of multi-axis gMLP block and RDCAB.

    Keras port of the Flax ``BottleneckBlock``. It applies a 1x1 projection to
    ``features`` filters and ``num_groups`` rounds of (multi-axis gMLP -> RDCAB),
    then adds the projection back as a long skip connection.

    Args:
        features: Filter count of the 1x1 input projection.
        block_size: Block size forwarded to the multi-axis gMLP layers.
        grid_size: Grid size forwarded to the multi-axis gMLP layers.
        num_groups: Number of (gMLP, RDCAB) repetitions.
        block_gmlp_factor: Expansion factor of the block-gMLP branch.
        grid_gmlp_factor: Expansion factor of the grid-gMLP branch.
        input_proj_factor: Channel expansion of the gMLP input projection.
        channels_reduction: Reduction ratio passed to each RDCAB.
        dropout_rate: Dropout rate for all sub-layers.
        use_bias: Whether sub-layers use biases.
        name: Prefix for all sub-layer names.

    Returns:
        A callable ``apply(x)`` that builds the block on ``x``.
    """

    def apply(x):
        """Applies the Mixer block to inputs."""
        projected = Conv1x1TF(
            filters=features,
            use_bias=use_bias,
            name=f"{name}_input_proj",
            kernel_initializer="ones",
            bias_initializer="ones",
        )(x)

        y = projected
        for group in range(num_groups):
            y = ResidualSplitHeadMultiAxisGmlpLayerTF(
                grid_size=grid_size,
                block_size=block_size,
                grid_gmlp_factor=grid_gmlp_factor,
                block_gmlp_factor=block_gmlp_factor,
                input_proj_factor=input_proj_factor,
                use_bias=use_bias,
                dropout_rate=dropout_rate,
                name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group}",
            )(y)
            # Channel mixing: within-patch communication.
            y = RDCABTF(
                num_channels=features,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1_{group}",
            )(y)

        # Long skip connection back to the projected input.
        return y + projected

    return apply

In [93]:
# Flax reference bottleneck: features=32, block_size=(16, 16), grid_size=(16, 16).
bb_jax = BottleneckBlock(32, (16, 16), (16, 16))
variables = bb_jax.init(random.PRNGKey(0), dummy_inputs.numpy())
jax_outputs = bb_jax.apply(variables, dummy_inputs.numpy())

In [94]:
# Keras port with the same configuration; must match within atol = rtol = 1e-5.
bb_tf = BottleneckBlockTF(32, (16, 16), (16, 16))
tf_outputs = bb_tf(dummy_inputs)

np.testing.assert_allclose(tf_outputs, np.array(jax_outputs), atol=1e-5, rtol=1e-5)

In [96]:
inputs = tf.random.normal((1, 256, 256, 3))
num_supervision_scales = 3
x = inputs.numpy()
n, h, w, c = x.shape  # input image shape

# Multi-scale input pyramid: the full-resolution image, followed by
# nearest-neighbour downsamplings by 2**i for i = 1 .. num_supervision_scales - 1.
shortcuts_jax = [x]
for i in range(1, num_supervision_scales):
    image_resized = jax.image.resize(
        x,
        shape=(n, h // 2**i, w // 2**i, c),
        method="nearest",
    )
    shortcuts_jax.append(np.array(image_resized))

In [101]:
num_supervision_scales = 3
x = inputs
n, h, w, c = x.shape  # input image shape

# Level 0 is the input tensor itself. Each further level is a nearest-neighbour
# downsampling by 2**i, built to mirror the `jax.image.resize()` pyramid.
shortcuts_tf = [x]
for i in range(1, num_supervision_scales):
    resizing_layer = Resizing(
        height=h // (2 ** i),
        width=w // (2 ** i),
        method="nearest",
        antialias=True,  # Following `jax.image.resize()`.
    )
    shortcuts_tf.append(resizing_layer(x).numpy())

In [102]:
# Both pyramids must have the same number of levels: `zip` alone would
# silently skip any extra levels.
assert len(shortcuts_jax) == len(shortcuts_tf), (
    f"pyramid depth mismatch: {len(shortcuts_jax)} vs {len(shortcuts_tf)}"
)
for shortcut_jax, shortcut_tf in zip(shortcuts_jax, shortcuts_tf):
    # Report the maximum *absolute* difference. A signed max hides mismatches
    # where the TF value is larger than the JAX one.
    assert np.allclose(shortcut_jax, shortcut_tf), (
        f"max abs difference: {np.max(np.abs(shortcut_jax - shortcut_tf))}"
    )

In [104]:
num_stages = 3
features = 32
outputs_all_jax = []

for idx_stage in range(num_stages):
    # Input convolution, get multi-scale input features.
    x_scales = []  # not populated in this comparison
    for i in range(num_supervision_scales):
        # Channel count doubles at each coarser scale. All-ones init keeps the
        # result independent of the PRNG key.
        conv3x3_jax = Conv3x3(
            features * 2**i,
            use_bias=True,
            kernel_init=nn.initializers.ones,
            bias_init=nn.initializers.ones,
            name=f"stage_{idx_stage}_input_conv_{i}",
        )
        variables = conv3x3_jax.init(random.PRNGKey(0), shortcuts_jax[i])
        x_scale = conv3x3_jax.apply(variables, shortcuts_jax[i])
        outputs_all_jax.append(np.array(x_scale))

In [106]:
outputs_all_tf = []

for idx_stage in range(num_stages):
    # Input convolution, get multi-scale input features.
    x_scales = []  # not populated in this comparison
    for i in range(num_supervision_scales):
        # Same filter schedule and all-ones init as the JAX loop.
        conv3x3_tf = Conv3x3TF(
            filters=features * 2 ** i,
            use_bias=True,
            name=f"stage_{idx_stage}_input_conv_{i}",
            kernel_initializer="ones",
            bias_initializer="ones",
        )
        x_scale = conv3x3_tf(shortcuts_tf[i])
        outputs_all_tf.append(x_scale.numpy())
        del conv3x3_tf

In [107]:
# One output per (stage, scale) on each side. Check the counts first so `zip`
# cannot silently drop unmatched outputs.
assert len(outputs_all_jax) == len(outputs_all_tf), (
    f"output count mismatch: {len(outputs_all_jax)} vs {len(outputs_all_tf)}"
)
for output_jax, output_tf in zip(outputs_all_jax, outputs_all_tf):
    # Report the maximum *absolute* difference, not a signed max.
    assert np.allclose(output_jax, output_tf), (
        f"max abs difference: {np.max(np.abs(output_jax - output_tf))}"
    )