In [1]:
import sys

sys.path.append("..")

In [2]:
from create_maxim_model import Model
from maxim.configs import MAXIM_CONFIGS

2022-10-01 19:18:23.414196: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py#L55

VARIANT = "S-3"

# Work on a copy so the shared MAXIM_CONFIGS entry is not mutated in place by
# the update below. Indexing (instead of .get) raises a KeyError naming the
# variant when it is unknown, rather than failing later on None.update(...).
configs = dict(MAXIM_CONFIGS[VARIANT])

# Extra constructor arguments that are passed to Model on top of the
# variant-specific settings.
configs.update(
    {
        "variant": VARIANT,
        "dropout_rate": 0.0,
        "num_outputs": 3,
        "use_bias": True,
        "num_supervision_scales": 3,
    }
)

model = Model(**configs)

2022-10-01 19:18:26.351227: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py

import tensorflow as tf
import numpy as np
import collections 
import io


def recover_tree(keys, values):
    """Rebuilds a nested dict from flat, '/'-separated names and their values.

    Handy for inspecting saved checkpoints without the source code of the
    experiment that produced them, e.g. to extract and reuse a subtree of the
    checkpoint such as the parameters.

    Args:
      keys: a list of keys, where '/' is used as separator between nodes.
      values: a list of leaf values, aligned with `keys`.
    Returns:
      A nested tree-like dict. At every level the leaf entries come first,
      followed by the sub-trees in order of first appearance.
    """
    tree = {}
    # First path component -> (remaining key parts, matching values).
    pending = collections.defaultdict(lambda: ([], []))

    for name, value in zip(keys, values):
        head, sep, rest = name.partition("/")
        if not sep:
            # No separator: this is a leaf of the current level.
            tree[name] = value
            continue
        rest_keys, rest_values = pending[head]
        rest_keys.append(rest)
        rest_values.append(value)

    # Recurse once per distinct first component.
    for head, (rest_keys, rest_values) in pending.items():
        tree[head] = recover_tree(rest_keys, rest_values)

    return tree


def get_params(ckpt_path):
    """Load a checkpoint file and return its `opt/target` parameter subtree.

    Args:
      ckpt_path: path of the checkpoint, opened in binary mode via
        `tf.io.gfile.GFile`.
    Returns:
      The value stored under `["opt"]["target"]` of the tree recovered from
      the flat, '/'-separated entry names of the checkpoint.
    """
    # Read the whole file into memory first, then let NumPy parse the bytes.
    with tf.io.gfile.GFile(ckpt_path, "rb") as ckpt_file:
        raw_bytes = ckpt_file.read()

    flat_entries = np.load(io.BytesIO(raw_bytes))

    # Split the (name, array) pairs into parallel sequences of names/arrays.
    nested = recover_tree(*zip(*flat_entries.items()))
    return nested["opt"]["target"]

In [5]:
# Checkpoint location (gs:// path). get_params reads it through tf.io.gfile
# and returns the ["opt"]["target"] subtree of the recovered parameter tree.
CKPT_PATH = "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"
jax_params = get_params(CKPT_PATH)

In [6]:
model.count_params() / 1e6

22.212795

In [7]:
def get_model_vars(model):
    """Index a model's variables by their names.

    Args:
      model: object exposing a `variables` iterable whose items have a `name`.
    Returns:
      A dict mapping each variable's `name` to the variable itself, in the
      order of `model.variables`. If two variables share a name, the later
      one replaces the earlier one.
    """
    return {variable.name: variable for variable in model.variables}

In [8]:
tf_params = get_model_vars(model)
list(tf_params.keys())[:5]

['stage_0_input_conv_0/kernel:0',
 'stage_0_input_conv_0/bias:0',
 'stage_0_encoder_block_0_conv_in/kernel:0',
 'stage_0_encoder_block_0_conv_in/bias:0',
 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_LayerNorm_in/gamma:0']

In [9]:
jax_params.keys()

dict_keys(['UpSampleRatio_0', 'UpSampleRatio_1', 'UpSampleRatio_10', 'UpSampleRatio_11', 'UpSampleRatio_12', 'UpSampleRatio_13', 'UpSampleRatio_14', 'UpSampleRatio_15', 'UpSampleRatio_16', 'UpSampleRatio_17', 'UpSampleRatio_18', 'UpSampleRatio_19', 'UpSampleRatio_2', 'UpSampleRatio_20', 'UpSampleRatio_21', 'UpSampleRatio_22', 'UpSampleRatio_23', 'UpSampleRatio_24', 'UpSampleRatio_25', 'UpSampleRatio_26', 'UpSampleRatio_27', 'UpSampleRatio_28', 'UpSampleRatio_29', 'UpSampleRatio_3', 'UpSampleRatio_30', 'UpSampleRatio_31', 'UpSampleRatio_32', 'UpSampleRatio_33', 'UpSampleRatio_34', 'UpSampleRatio_35', 'UpSampleRatio_36', 'UpSampleRatio_37', 'UpSampleRatio_38', 'UpSampleRatio_39', 'UpSampleRatio_4', 'UpSampleRatio_40', 'UpSampleRatio_41', 'UpSampleRatio_42', 'UpSampleRatio_43', 'UpSampleRatio_44', 'UpSampleRatio_45', 'UpSampleRatio_46', 'UpSampleRatio_47', 'UpSampleRatio_48', 'UpSampleRatio_49', 'UpSampleRatio_5', 'UpSampleRatio_50', 'UpSampleRatio_51', 'UpSampleRatio_52', 'UpSampleRatio_

In [10]:
def modify_convs(jax_params):
    """Flatten the conv entries of `jax_params` into TF-style variable names.

    Args:
      jax_params: dict mapping top-level layer names to dicts of parameters.
    Returns:
      An OrderedDict with one entry per parameter of every top-level layer
      whose name contains "conv", keyed as "<layer>/<param>:0".
    """
    renamed = collections.OrderedDict()

    for layer_name, layer_params in jax_params.items():
        # Only top-level layers with "conv" in their name are handled here.
        if "conv" not in layer_name:
            continue
        for param_name, param_value in layer_params.items():
            renamed[f"{layer_name}/{param_name}:0"] = param_value

    return renamed

In [11]:
modified_jax_params_convs = modify_convs(jax_params)
modified_jax_params_convs.keys()

odict_keys(['stage_0_input_conv_0/bias:0', 'stage_0_input_conv_0/kernel:0', 'stage_0_input_conv_1/bias:0', 'stage_0_input_conv_1/kernel:0', 'stage_0_input_conv_2/bias:0', 'stage_0_input_conv_2/kernel:0', 'stage_1_input_conv_0/bias:0', 'stage_1_input_conv_0/kernel:0', 'stage_1_input_conv_1/bias:0', 'stage_1_input_conv_1/kernel:0', 'stage_1_input_conv_2/bias:0', 'stage_1_input_conv_2/kernel:0', 'stage_2_input_conv_0/bias:0', 'stage_2_input_conv_0/kernel:0', 'stage_2_input_conv_1/bias:0', 'stage_2_input_conv_1/kernel:0', 'stage_2_input_conv_2/bias:0', 'stage_2_input_conv_2/kernel:0', 'stage_2_output_conv_0/bias:0', 'stage_2_output_conv_0/kernel:0', 'stage_2_output_conv_1/bias:0', 'stage_2_output_conv_1/kernel:0', 'stage_2_output_conv_2/bias:0', 'stage_2_output_conv_2/kernel:0'])

LayerNorm parameter names (TF variable name => JAX checkpoint name):

```
gamma => scale
beta => bias
```

In [40]:
def _to_tf_param_name(name):
    """Rename JAX normalization parameters inside `name` to their TF names.

    "scale" becomes "gamma"; otherwise, for names containing "LayerNorm",
    "bias" becomes "beta". Any other name is returned unchanged.
    """
    if "scale" in name:
        return name.replace("scale", "gamma")
    if "LayerNorm" in name and "bias" in name:
        return name.replace("bias", "beta")
    return name


def modify_cross_gating(jax_params):
    """Flatten the cross-gating entries of `jax_params` into TF-style names.

    Only top-level entries whose name contains "cross_gating" are processed.
    Their children are handled in two ways:

    * children without "SplitHeadMultiAxisGating" in their name are treated
      as one level of parameters and keyed "<block>_<child>/<param>:0";
    * "SplitHeadMultiAxisGating" children are treated as two levels and keyed
      "<block>_<child>_<sub>/<param>:0".

    In both cases "scale" is renamed to "gamma" and a LayerNorm "bias" is
    renamed to "beta". Values are stored as found; anything nested below
    these levels is not flattened further.

    Args:
      jax_params: nested dict of parameters keyed by layer name.
    Returns:
      An OrderedDict mapping the generated names to the parameter values.
    """
    modified_jax_params = collections.OrderedDict()

    for block_name, block_params in jax_params.items():
        if "cross_gating" not in block_name:
            continue

        for child_name, child_params in block_params.items():
            if "SplitHeadMultiAxisGating" not in child_name:
                for param_name, param in child_params.items():
                    # The renaming conditions previously referenced an
                    # undefined name (`modified`) and raised NameError.
                    tf_name = _to_tf_param_name(
                        f"{block_name}_{child_name}/{param_name}:0"
                    )
                    modified_jax_params[tf_name] = param
            else:
                for sub_name, sub_params in child_params.items():
                    for param_name, param in sub_params.items():
                        tf_name = _to_tf_param_name(
                            f"{block_name}_{child_name}_{sub_name}/{param_name}:0"
                        )
                        modified_jax_params[tf_name] = param

    return modified_jax_params

In [53]:
# Top-level parameter groups that are not "UpSampleRatio_*" entries.
non_upsample = [
    "stage_0_cross_gating_block_0",
    "stage_0_cross_gating_block_1",
    "stage_0_cross_gating_block_2",
    "stage_0_decoder_block_0",
    "stage_0_decoder_block_1",
    "stage_0_decoder_block_2",
    "stage_0_encoder_block_0",
    "stage_0_encoder_block_1",
    "stage_0_encoder_block_2",
    "stage_0_global_block_0",
    "stage_0_global_block_1",
    "stage_0_input_conv_0",
    "stage_0_input_conv_1",
    "stage_0_input_conv_2",
    "stage_0_supervised_attention_module_0",
    "stage_0_supervised_attention_module_1",
    "stage_0_supervised_attention_module_2",
    "stage_1_cross_gating_block_0",
    "stage_1_cross_gating_block_1",
    "stage_1_cross_gating_block_2",
    "stage_1_decoder_block_0",
    "stage_1_decoder_block_1",
    "stage_1_decoder_block_2",
    "stage_1_encoder_block_0",
    "stage_1_encoder_block_1",
    "stage_1_encoder_block_2",
    "stage_1_global_block_0",
    "stage_1_global_block_1",
    "stage_1_input_conv_0",
    "stage_1_input_conv_1",
    "stage_1_input_conv_2",
    "stage_1_input_fuse_sam_0",
    "stage_1_input_fuse_sam_1",
    "stage_1_input_fuse_sam_2",
    "stage_1_supervised_attention_module_0",
    "stage_1_supervised_attention_module_1",
    "stage_1_supervised_attention_module_2",
    "stage_2_cross_gating_block_0",
    "stage_2_cross_gating_block_1",
    "stage_2_cross_gating_block_2",
    "stage_2_decoder_block_0",
    "stage_2_decoder_block_1",
    "stage_2_decoder_block_2",
    "stage_2_encoder_block_0",
    "stage_2_encoder_block_1",
    "stage_2_encoder_block_2",
    "stage_2_global_block_0",
    "stage_2_global_block_1",
    "stage_2_input_conv_0",
    "stage_2_input_conv_1",
    "stage_2_input_conv_2",
    "stage_2_input_fuse_sam_0",
    "stage_2_input_fuse_sam_1",
    "stage_2_input_fuse_sam_2",
    "stage_2_output_conv_0",
    "stage_2_output_conv_1",
    "stage_2_output_conv_2",
]


# Collect the distinct block kinds, ignoring stage and block indices.
unique_blocks = set()

for block in non_upsample:
    # Conv layers are converted separately, so they are left out here.
    if "conv" in block:
        continue
    block = block.replace("stage_", "")
    block = "".join(ch for ch in block if not ch.isdigit())
    # With the digits gone, one "_" is left at each end; slice both off.
    unique_blocks.add(block[1:-1])

print(unique_blocks)

{'global_block', 'cross_gating_block', 'decoder_block', 'supervised_attention_module', 'input_fuse_sam', 'encoder_block'}


In [77]:
def dict_depth(my_dict):
    """Return the nesting depth of a dictionary.

    A non-dict value has depth 0, an empty dict has depth 1, and any other
    dict has depth one more than its deepest value.
    """
    if not isinstance(my_dict, dict):
        return 0
    if not my_dict:
        return 1
    return 1 + max(dict_depth(child) for child in my_dict.values())


dict_depth(jax_params["stage_0_global_block_1"])

5

In [80]:
# Explore how deeply the "global_block" parameter groups are nested.
# Each print sits inside the loop over the children of the dict whose depth it
# reports, so the same line is printed once per child.
for k in jax_params:
    if "global_block" not in k:
        continue
    print(f"{k}: {dict_depth(jax_params[k])}")
    for j in jax_params[k]:
        for l in jax_params[k][j]:
            print(f"{k}: {j}: {dict_depth(jax_params[k][j])}")
            # Entries named after bias/kernel/scale are not descended into.
            if "bias" in l or "kernel" in l or "scale" in l:
                continue
            for m in jax_params[k][j][l]:
                print(f"{k}: {j}: {l}: {dict_depth(jax_params[k][j][l])}")
                if any(tag in m for tag in ("bias", "kernel", "scale")):
                    continue
                for n in jax_params[k][j][l][m]:
                    print(
                        f"{k}: {j}: {l}: {m} {dict_depth(jax_params[k][j][l][m])}"
                    )

stage_0_global_block_0: 5
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: 4
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: 3
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: BlockGatingUnit 2
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: BlockGatingUnit 2
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: 3
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: LayerNorm 1
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: LayerNorm 1
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: 3
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: in_project 1
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: in_project 1
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: 3
stage_0_global_block_0: SplitHeadMultiAxisGmlpLayer_0: BlockGmlpLayer: out_project 1
stage_0_global_block_0: Spli