In [1]:
import sys

sys.path.append("..")

In [2]:
from create_maxim_model import Model
from maxim.configs import MAXIM_CONFIGS

2022-10-03 16:07:46.125680: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py#L55

VARIANT = "S-3"

# Build a fresh config dict instead of calling `.update` on the registry entry:
# `MAXIM_CONFIGS.get(VARIANT)` returns the shared dict, so updating it would
# silently mutate MAXIM_CONFIGS for every later user (hidden notebook state).
# Indexing (rather than `.get`) also fails loudly with a KeyError for an
# unknown variant instead of an AttributeError on `None`.
configs = {
    **MAXIM_CONFIGS[VARIANT],
    "variant": VARIANT,
    "dropout_rate": 0.0,
    "num_outputs": 3,
    "use_bias": True,
    "num_supervision_scales": 3,
}

model = Model(**configs)

2022-10-03 16:07:50.774652: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py

import tensorflow as tf
import numpy as np
import collections 
import io


def recover_tree(keys, values):
    """Rebuild a nested dict from flat '/'-separated names and their values.

    Useful for inspecting checkpoints saved by our programs without access to
    the exact experiment source code. In particular, it lets you extract and
    reuse subtrees of the checkpoint, e.g. the subtree of parameters.

    Args:
      keys: a list of keys, where '/' separates the nodes of a path.
      values: a list of leaf values, aligned with ``keys``.
    Returns:
      A nested, tree-like dict.
    """
    tree = {}
    children = collections.defaultdict(list)

    for name, value in zip(keys, values):
        head, sep, rest = name.partition("/")
        if sep:
            # Deeper path: defer to the subtree rooted at `head`.
            children[head].append((rest, value))
        else:
            tree[name] = value

    for head, pairs in children.items():
        tree[head] = recover_tree(*zip(*pairs))

    return tree


def get_params(ckpt_path):
    """Read an ``.npz`` checkpoint and return its ``opt/target`` parameter tree.

    The file is read through ``tf.io.gfile`` (so ``gs://`` paths work), its
    flat '/'-separated array names are turned back into a nested dict with
    ``recover_tree``, and the ``["opt"]["target"]`` subtree is returned.
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as ckpt_file:
        raw_bytes = ckpt_file.read()

    flat_arrays = np.load(io.BytesIO(raw_bytes))
    full_tree = recover_tree(*zip(*flat_arrays.items()))
    return full_tree["opt"]["target"]

In [5]:
# Denoising (SIDD) checkpoint on GCS; get_params reads it via tf.io.gfile.
CKPT_PATH = "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"
# Nested Flax parameter tree: the checkpoint's ["opt"]["target"] subtree.
jax_params = get_params(CKPT_PATH)

In [6]:
# Total number of model parameters, in millions.
model.count_params() / 1e6

22.212795

In [7]:
def get_model_vars(model):
    """Index ``model.variables`` by each variable's ``name``.

    If two variables share a name, the later one wins.
    """
    return {variable.name: variable for variable in model.variables}

In [8]:
# Map each TF variable name to its variable, then show the first five names.
tf_params = get_model_vars(model)
list(tf_params.keys())[:5]

['stage_0_input_conv_0/kernel:0',
 'stage_0_input_conv_0/bias:0',
 'stage_0_encoder_block_0_Conv_0/kernel:0',
 'stage_0_encoder_block_0_Conv_0/bias:0',
 'stage_0_encoder_block_0_SplitHeadMultiAxisGmlpLayer_0_LayerNorm_in/gamma:0']

In [9]:
# Top-level module names of the JAX parameter tree.
jax_params.keys()

dict_keys(['UpSampleRatio_0', 'UpSampleRatio_1', 'UpSampleRatio_10', 'UpSampleRatio_11', 'UpSampleRatio_12', 'UpSampleRatio_13', 'UpSampleRatio_14', 'UpSampleRatio_15', 'UpSampleRatio_16', 'UpSampleRatio_17', 'UpSampleRatio_18', 'UpSampleRatio_19', 'UpSampleRatio_2', 'UpSampleRatio_20', 'UpSampleRatio_21', 'UpSampleRatio_22', 'UpSampleRatio_23', 'UpSampleRatio_24', 'UpSampleRatio_25', 'UpSampleRatio_26', 'UpSampleRatio_27', 'UpSampleRatio_28', 'UpSampleRatio_29', 'UpSampleRatio_3', 'UpSampleRatio_30', 'UpSampleRatio_31', 'UpSampleRatio_32', 'UpSampleRatio_33', 'UpSampleRatio_34', 'UpSampleRatio_35', 'UpSampleRatio_36', 'UpSampleRatio_37', 'UpSampleRatio_38', 'UpSampleRatio_39', 'UpSampleRatio_4', 'UpSampleRatio_40', 'UpSampleRatio_41', 'UpSampleRatio_42', 'UpSampleRatio_43', 'UpSampleRatio_44', 'UpSampleRatio_45', 'UpSampleRatio_46', 'UpSampleRatio_47', 'UpSampleRatio_48', 'UpSampleRatio_49', 'UpSampleRatio_5', 'UpSampleRatio_50', 'UpSampleRatio_51', 'UpSampleRatio_52', 'UpSampleRatio_

In [10]:
def modify_convs(jax_params):
    """Flatten top-level conv modules into TF-style variable names.

    Every top-level entry whose name contains ``"conv"`` (case-sensitive) is
    expected to be a ``{param: value}`` dict and becomes
    ``"<module>/<param>:0" -> value``.

    Returns:
      An ``OrderedDict`` in checkpoint iteration order.
    """
    flattened = collections.OrderedDict()
    conv_modules = (name for name in jax_params if "conv" in name)

    for module_name in conv_modules:
        for param_name, value in jax_params[module_name].items():
            flattened[f"{module_name}/{param_name}:0"] = value

    return flattened

In [11]:
# modified_jax_params_convs = modify_convs(jax_params)
# modified_jax_params_convs.keys()

LayerNorm parameters are named differently in the two frameworks (Keras name => Flax name):
```
gamma => scale
beta => bias
```

In [12]:
def modify_cross_gating(jax_params):
    """Flatten the cross-gating blocks of the Flax tree into TF variable names.

    For every top-level entry whose name contains ``"cross_gating"``:
      * children whose name lacks ``"SplitHeadMultiAxisGating"`` are treated as
        ``{param: value}`` dicts -> ``"<block>_<child>/<param>:0"``;
      * ``SplitHeadMultiAxisGating`` children are one level deeper
        -> ``"<block>_<child>_<sub>/<param>:0"``.

    Afterwards, LayerNorm parameters are renamed from the Flax leaf names
    (``scale``/``bias``) to the Keras ones (``gamma``/``beta``) and moved to
    the end of the mapping. LayerNorm entries with any other leaf name are
    left unchanged.

    Args:
      jax_params: nested parameter dict as returned by ``get_params``.
    Returns:
      A flat dict mapping TF-style variable names to parameter values.
    """
    modified_jax_params = {}

    for k in jax_params:
        if "cross_gating" not in k:
            continue
        for j in jax_params[k]:
            if "SplitHeadMultiAxisGating" not in j:
                for l in jax_params[k][j]:
                    modified_jax_params[f"{k}_{j}/{l}:0"] = jax_params[k][j][l]
            else:
                for l in jax_params[k][j]:
                    for m in jax_params[k][j][l]:
                        modified_jax_params[f"{k}_{j}_{l}/{m}:0"] = jax_params[k][j][l][m]

    # Rename only the leaf (the part after the last '/'). A whole-key
    # str.replace could also rewrite a module path containing "scale"/"bias".
    # A LayerNorm key with any other leaf previously used an unbound or stale
    # `param_name`; such keys are now kept as-is.
    layernorm_renames = {"scale:0": "gamma:0", "bias:0": "beta:0"}
    layernorm_dict = {}
    layernorm_keys = [n for n in modified_jax_params if "layernorm" in n.lower()]
    for name in layernorm_keys:
        prefix, _, leaf = name.rpartition("/")
        if leaf in layernorm_renames:
            new_name = f"{prefix}/{layernorm_renames[leaf]}"
            layernorm_dict[new_name] = modified_jax_params.pop(name)
    modified_jax_params.update(layernorm_dict)

    return modified_jax_params

In [13]:
# modify_cross_gating(jax_params).keys()

In [14]:
# Top-level module names of the JAX tree other than the UpSampleRatio_* ones
# (presumably copied from jax_params.keys() — TODO confirm / derive instead of
# hard-coding). Used below to derive the distinct block types.
non_upsample = [
    "stage_0_cross_gating_block_0",
    "stage_0_cross_gating_block_1",
    "stage_0_cross_gating_block_2",
    "stage_0_decoder_block_0",
    "stage_0_decoder_block_1",
    "stage_0_decoder_block_2",
    "stage_0_encoder_block_0",
    "stage_0_encoder_block_1",
    "stage_0_encoder_block_2",
    "stage_0_global_block_0",
    "stage_0_global_block_1",
    "stage_0_input_conv_0",
    "stage_0_input_conv_1",
    "stage_0_input_conv_2",
    "stage_0_supervised_attention_module_0",
    "stage_0_supervised_attention_module_1",
    "stage_0_supervised_attention_module_2",
    "stage_1_cross_gating_block_0",
    "stage_1_cross_gating_block_1",
    "stage_1_cross_gating_block_2",
    "stage_1_decoder_block_0",
    "stage_1_decoder_block_1",
    "stage_1_decoder_block_2",
    "stage_1_encoder_block_0",
    "stage_1_encoder_block_1",
    "stage_1_encoder_block_2",
    "stage_1_global_block_0",
    "stage_1_global_block_1",
    "stage_1_input_conv_0",
    "stage_1_input_conv_1",
    "stage_1_input_conv_2",
    "stage_1_input_fuse_sam_0",
    "stage_1_input_fuse_sam_1",
    "stage_1_input_fuse_sam_2",
    "stage_1_supervised_attention_module_0",
    "stage_1_supervised_attention_module_1",
    "stage_1_supervised_attention_module_2",
    "stage_2_cross_gating_block_0",
    "stage_2_cross_gating_block_1",
    "stage_2_cross_gating_block_2",
    "stage_2_decoder_block_0",
    "stage_2_decoder_block_1",
    "stage_2_decoder_block_2",
    "stage_2_encoder_block_0",
    "stage_2_encoder_block_1",
    "stage_2_encoder_block_2",
    "stage_2_global_block_0",
    "stage_2_global_block_1",
    "stage_2_input_conv_0",
    "stage_2_input_conv_1",
    "stage_2_input_conv_2",
    "stage_2_input_fuse_sam_0",
    "stage_2_input_fuse_sam_1",
    "stage_2_input_fuse_sam_2",
    "stage_2_output_conv_0",
    "stage_2_output_conv_1",
    "stage_2_output_conv_2",
]


# Distinct block types: drop the "stage_" prefix and every digit, then trim
# the leftover leading/trailing underscore. Conv layers are excluded.
unique_blocks = {
    "".join(ch for ch in name.replace("stage_", "") if not ch.isdigit())[1:-1]
    for name in non_upsample
    if "conv" not in name
}

print(unique_blocks)

{'input_fuse_sam', 'encoder_block', 'decoder_block', 'global_block', 'cross_gating_block', 'supervised_attention_module'}


In [15]:
# Python3 Program to find depth of a dictionary
# https://www.geeksforgeeks.org/python-find-depth-of-a-dictionary/
def dict_depth(my_dict):
    """Return how many levels of dict nesting ``my_dict`` has.

    A non-dict has depth 0 and an empty dict has depth 1. A non-empty dict is
    one level deeper than its deepest value.
    """
    if not isinstance(my_dict, dict):
        return 0
    if not my_dict:
        return 1
    return 1 + max(dict_depth(value) for value in my_dict.values())


# Nesting depth of one global block's parameter tree.
dict_depth(jax_params["stage_0_global_block_1"])

5

In [16]:
def modify_unique_block(jax_params, block_name):
    """Flatten every top-level block whose name contains ``block_name``.

    Each child module ``j`` of a matching block ``k`` is flattened according
    to its nesting depth (``dict_depth``) into TF-style names of the form
    ``"<k>_<j>[_<l>[_<m>[_<n>]]]/<param>:0"``:

      * depth 1: ``{param: value}``;
      * depth 3: two more levels; a grandchild entry is either a leaf (named
        ``<k>_<j>_<l>/<m>:0``) or a ``{param: value}`` dict;
      * depth 4: children of depth 1 or 3; inside depth-3 children,
        grandchildren of depth 1 or 2.

    Children with any other depth (e.g. 2 or 5) are skipped; they still need
    a conversion strategy. Finally, LayerNorm parameters are renamed from the
    Flax leaf names (``scale``/``bias``) to the Keras ones (``gamma``/``beta``)
    and moved to the end of the mapping. LayerNorm entries with any other leaf
    name are left unchanged.

    Args:
      jax_params: nested parameter dict as returned by ``get_params``.
      block_name: substring selecting the block type, e.g. ``"encoder_block"``.
    Returns:
      An ``OrderedDict`` mapping TF-style variable names to parameter values.
    """
    modified_jax_params = collections.OrderedDict()

    for k in jax_params:
        if block_name not in k:
            continue
        block = jax_params[k]
        for j in block:
            child = block[j]
            child_depth = dict_depth(child)

            if child_depth == 1:
                for l in child:
                    modified_jax_params[f"{k}_{j}/{l}:0"] = child[l]

            elif child_depth == 3:
                for l in child:
                    sub = child[l]
                    for m in sub:
                        node = sub[m]
                        if not dict_depth(node):
                            # `node` is a leaf parameter.
                            modified_jax_params[f"{k}_{j}_{l}/{m}:0"] = node
                        else:
                            for n in node:
                                modified_jax_params[f"{k}_{j}_{l}_{m}/{n}:0"] = node[n]

            elif child_depth == 4:
                for l in child:
                    sub = child[l]
                    sub_depth = dict_depth(sub)
                    if sub_depth == 1:
                        for m in sub:
                            modified_jax_params[f"{k}_{j}_{l}/{m}:0"] = sub[m]
                    elif sub_depth == 3:
                        for m in sub:
                            node = sub[m]
                            node_depth = dict_depth(node)
                            if node_depth == 1:
                                for n in node:
                                    modified_jax_params[
                                        f"{k}_{j}_{l}_{m}/{n}:0"
                                    ] = node[n]
                            elif node_depth == 2:
                                for n in node:
                                    for o in node[n]:
                                        modified_jax_params[
                                            f"{k}_{j}_{l}_{m}_{n}/{o}:0"
                                        ] = node[n][o]

    # Rename only the leaf (the part after the last '/'). A whole-key
    # str.replace could also rewrite a module path containing "scale"/"bias".
    # A LayerNorm key with any other leaf previously used an unbound or stale
    # `param_name`; such keys are now kept as-is.
    layernorm_renames = {"scale:0": "gamma:0", "bias:0": "beta:0"}
    layernorm_dict = {}
    layernorm_keys = [n for n in modified_jax_params if "layernorm" in n.lower()]
    for name in layernorm_keys:
        prefix, _, leaf = name.rpartition("/")
        if leaf in layernorm_renames:
            new_name = f"{prefix}/{layernorm_renames[leaf]}"
            layernorm_dict[new_name] = modified_jax_params.pop(name)
    modified_jax_params.update(layernorm_dict)

    return modified_jax_params

In [17]:
def modify_upsample(jax_params):
    """Flatten the ``UpSampleRatio_<n>`` modules into TF variable names.

    Every top-level entry whose lower-cased name contains ``"upsample"`` gets a
    1-based sequential index. It is then flattened to
    ``"<prefix>_<index>_<child>/<param>:0"``, where ``<prefix>`` is the part of
    its name before the first ``_`` (e.g. ``UpSampleRatio``).

    The checkpoint yields keys in lexicographic order (``_0, _1, _10, _11, ...,
    _2, ...``). Entries are therefore sorted by their numeric suffix before
    indices are assigned. Otherwise ``UpSampleRatio_10`` would be numbered 3
    instead of 11. Entries without a numeric suffix keep their relative order
    after the numbered ones. This assumes the TF model numbers its upsample
    layers in the Flax creation order, starting at 1 (TODO confirm against the
    TF variable names).

    Returns:
      An ``OrderedDict`` mapping TF-style variable names to parameter values.
    """

    def _suffix_order(name):
        # Numeric suffix first (in numeric order); anything else afterwards.
        suffix = name.rsplit("_", 1)[-1]
        return (0, int(suffix)) if suffix.isdecimal() else (1, 0)

    upsample_names = sorted(
        (name for name in jax_params if "upsample" in name.lower()),
        key=_suffix_order,
    )

    modified_jax_params = collections.OrderedDict()
    for i, k in enumerate(upsample_names, start=1):
        k_t = k.split("_")[0] + "_" + str(i)
        for j in jax_params[k]:
            for l in jax_params[k][j]:
                modified_jax_params[f"{k_t}_{j}/{l}:0"] = jax_params[k][j][l]

    return modified_jax_params

In [18]:
# Merge the output of every converter into one flat name -> value mapping.
# Results are applied in order, so later converters win on duplicate keys.
partial_results = [modify_unique_block(jax_params, name) for name in unique_blocks]
partial_results += [
    modify_convs(jax_params),
    modify_cross_gating(jax_params),
    modify_upsample(jax_params),
]

all_modified_jax_params = collections.OrderedDict()
for partial in partial_results:
    all_modified_jax_params.update(partial)

In [19]:
# Number of TF variables the converted weights must eventually cover.
tf_params = get_model_vars(model)
len(tf_params)

2754

In [20]:
# Number of JAX parameters converted to TF-style names so far.
len(all_modified_jax_params)

1968

In [36]:
# set(tf_params.keys()) - set(all_modified_jax_params.keys())

In [46]:
def get_dict_depths(jax_params, block_name):
    """Return the set of ``dict_depth`` values of the child modules of every
    top-level entry whose name contains ``block_name``."""
    return {
        dict_depth(module[child_name])
        for top_name, module in jax_params.items()
        if block_name in top_name
        for child_name in module
    }


# For each block type, collect the set of nesting depths of its child modules.
# This shows which traversal patterns the converters have to support.
final_depth_mapping = collections.defaultdict(set)

for block_name in unique_blocks:
    # `|=` merges into the (possibly freshly created) set. Calling
    # `set.add(<set>)` here would raise TypeError, because sets are unhashable.
    final_depth_mapping[block_name] |= get_dict_depths(jax_params, block_name)









In [47]:
# Block type -> set of child-module nesting depths.
final_depth_mapping

defaultdict(set,
            {'input_fuse_sam': {1, 2},
             'encoder_block': {1, 3, 4},
             'decoder_block': {1, 5},
             'global_block': {1, 3, 4},
             'cross_gating_block': {1, 2},
             'supervised_attention_module': {1}})

In [48]:
# Sanity check: every unique block type should have a depth entry.
len(final_depth_mapping), len(unique_blocks)

(6, 6)

## Todos

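* Devise conversion strategies for each of these block types except `cross_gating_block`, which `modify_cross_gating` already handles. `modify_unique_block` only handles child depths 1, 3 and 4, so the depth-2 parts of `input_fuse_sam` and the depth-5 parts of `decoder_block` are still missing. This is why only 1968 converted entries exist versus 2754 TF variables.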