In [1]:
import sys

# Put the parent directory on the module search path so that modules located
# one level above this notebook can be imported in the following cells.
# NOTE(review): presumably this is where `create_maxim_model` and the `maxim`
# package live — confirm against the repository layout.
sys.path.append("..")

In [2]:
from create_maxim_model import Model
from maxim.configs import MAXIM_CONFIGS

2022-10-05 11:30:43.536748: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py#L55

# Name of the model variant whose preset configuration is looked up below.
VARIANT = "S-3"

# Fetch the preset for the chosen variant, then overlay the settings used for
# this port. `update` modifies the fetched object in place.
# NOTE(review): if `MAXIM_CONFIGS.get` hands back the stored entry itself (as a
# plain dict would), this also changes `MAXIM_CONFIGS[VARIANT]` — confirm that
# is acceptable.
configs = MAXIM_CONFIGS.get(VARIANT)
configs.update(
    variant=VARIANT,
    dropout_rate=0.0,
    num_outputs=3,
    use_bias=True,
    num_supervision_scales=3,
)

In [4]:
# From https://github.com/google-research/maxim/blob/main/maxim/run_eval.py

import tensorflow as tf
import pandas as pd
import numpy as np
import collections 
import re
import io


def recover_tree(keys, values):
    """Rebuild a nested dict from flat, '/'-separated names and their values.

    Useful for inspecting checkpoints that were saved as a flat mapping,
    without needing the source code that produced them; for example to pull
    out and reuse a single subtree of the checkpoint, such as the parameters.

    Args:
      keys: a sequence of names in which '/' separates one tree level from
        the next.
      values: a sequence of leaf values, aligned with `keys`.
    Returns:
      A nested tree-like dict.
    """
    nested = {}
    children_by_prefix = collections.defaultdict(list)

    for name, value in zip(keys, values):
        prefix, separator, remainder = name.partition("/")
        if separator:
            # Defer everything below the first '/' to the recursive call.
            children_by_prefix[prefix].append((remainder, value))
        else:
            nested[name] = value

    for prefix, children in children_by_prefix.items():
        child_keys = tuple(child_key for child_key, _ in children)
        child_values = tuple(child_value for _, child_value in children)
        nested[prefix] = recover_tree(child_keys, child_values)

    return nested


def get_params(ckpt_path):
    """Load a checkpoint archive and return its parameter subtree.

    The file is read as raw bytes through `tf.io.gfile.GFile`, loaded with
    `np.load`, and its flat '/'-separated entries are rebuilt into a nested
    dict by `recover_tree`.

    Args:
      ckpt_path: path of the checkpoint file to read.
    Returns:
      The subtree stored under the "opt" -> "target" keys of the checkpoint.
    """
    with tf.io.gfile.GFile(ckpt_path, "rb") as ckpt_file:
        raw_bytes = ckpt_file.read()

    flat_ckpt = np.load(io.BytesIO(raw_bytes))

    # Transpose the (name, array) pairs into a names tuple and an arrays tuple.
    names_and_arrays = zip(*flat_ckpt.items())
    ckpt_tree = recover_tree(*names_and_arrays)

    return ckpt_tree["opt"]["target"]

LayerNorm parameter names, TF => JAX:

```
gamma => scale
beta => bias
```

In [5]:
# https://stackoverflow.com/questions/5491913/sorting-list-in-python
def sort_nicely(l):
    """Sort the given iterable in the way that humans expect.

    Runs of ASCII digits inside each string are compared as integers, so
    "UpSample_2" sorts before "UpSample_10".

    Args:
      l: an iterable of strings.
    Returns:
      A new sorted list; the input is left untouched.
    """

    def natural_key(name):
        # Because the pattern has a capturing group, re.split alternates
        # between text and digit runs: even positions hold the text between
        # numbers, odd positions hold the captured `[0-9]+` runs. Converting
        # by position instead of by `str.isdigit` keeps characters such as
        # "²" (accepted by `isdigit`, rejected by `int`) as plain text, and
        # guarantees every key has the same str/int layout for comparison.
        chunks = re.split("([0-9]+)", name)
        return [
            int(chunk) if index % 2 else chunk
            for index, chunk in enumerate(chunks)
        ]

    return sorted(l, key=natural_key)

In [6]:
def modify_upsample(jax_params):
    """Rename the "UpSample" entries of a nested parameter dict.

    Top-level keys containing "UpSample" are put in natural order with
    `sort_nicely` and renumbered consecutively starting at 1. Each parameter
    found two levels below such a key is stored under the name
    "<prefix>_<index>_<layer>/<param>:0", where <prefix> is the part of the
    original key before its first underscore.

    Args:
      jax_params: nested dict of parameters; entries under "UpSample" keys
        are expected to be dicts of dicts.
    Returns:
      An OrderedDict mapping the new names to the parameter values.
    """
    upsample_names = sort_nicely(
        [name for name in jax_params.keys() if "UpSample" in name]
    )

    renamed = collections.OrderedDict()
    for position, jax_name in enumerate(upsample_names, start=1):
        prefix = jax_name.split("_")[0]
        block_name = f"{prefix}_{position}"
        for layer_name, layer_params in jax_params[jax_name].items():
            for param_name, value in layer_params.items():
                renamed[f"{block_name}_{layer_name}/{param_name}:0"] = value

    return renamed


def modify_jax_params(jax_params):
    """Rename flattened parameter keys, skipping the "UpSample" entries.

    For every key that does not contain "UpSample":
      * values whose key contains "ConvTranspose" but not "bias" have their
        last two axes swapped;
      * the final "_" in the key becomes "/" and ":0" is appended;
      * for names containing "layernorm" (case-insensitive), "scale" is
        replaced by "gamma", otherwise "bias" is replaced by "beta".

    Args:
      jax_params: flat dict mapping "_"-joined names to parameter values.
    Returns:
      An OrderedDict mapping the new names to the (possibly transposed)
      values, in the iteration order of `jax_params`.
    """
    # Checked in order; only the first matching substring is rewritten.
    layernorm_renames = (("scale", "gamma"), ("bias", "beta"))

    renamed = collections.OrderedDict()
    for jax_name, value in jax_params.items():
        if "UpSample" in jax_name:
            continue

        if "ConvTranspose" in jax_name and "bias" not in jax_name:
            value = value.transpose(0, 1, 3, 2)

        scope, _sep, leaf = jax_name.rpartition("_")
        tf_name = f"{scope}/{leaf}:0"

        if "layernorm" in tf_name.lower():
            for jax_term, tf_term in layernorm_renames:
                if jax_term in tf_name:
                    tf_name = tf_name.replace(jax_term, tf_term)
                    break

        renamed[tf_name] = value

    return renamed

In [7]:
def port_jax_params(configs, ckpt_path):
    """Build the TF model and load the JAX checkpoint values into it.

    Args:
      configs: keyword arguments forwarded to `Model`.
      ckpt_path: path of the JAX checkpoint, as accepted by `get_params`.
    Returns:
      A tuple `(modified_jax_params, tf_model)`: the renamed parameter dict
      and the TF model after its variables have been assigned.
    Raises:
      KeyError: if a renamed JAX parameter has no TF variable of that name.
      AssertionError: if the number of renamed parameters differs from the
        number of distinct TF variable names.
    """
    # Build the TF model whose variables will receive the checkpoint values.
    tf_model = Model(**configs)

    # Index the TF variables by name so renamed JAX keys can be looked up.
    tf_variables_by_name = {
        variable.name: variable for variable in tf_model.variables
    }

    # Load the nested JAX parameters and flatten them to "_"-joined keys.
    jax_params = get_params(ckpt_path)
    [flat_jax_dict] = pd.json_normalize(jax_params, sep="_").to_dict(orient="records")

    # Rename the JAX parameters to the TF variable names; the "UpSample"
    # entries are renamed separately from the nested (unflattened) tree.
    modified_jax_params = modify_jax_params(flat_jax_dict)
    modified_jax_params.update(modify_upsample(jax_params))

    # Pair every TF variable with the value it should receive.
    tf_weights = [
        (tf_variables_by_name[name], value)
        for name, value in modified_jax_params.items()
    ]

    # Every TF variable must be covered by exactly one renamed parameter.
    assert len(tf_weights) == len(modified_jax_params) == len(tf_variables_by_name)

    tf.keras.backend.batch_set_value(tf_weights)

    return modified_jax_params, tf_model

In [8]:
# Port the SIDD denoising checkpoint into a newly built TF model. The first
# return value (the renamed parameter dict) is not used afterwards, so it is
# bound to `_`; only the populated model is kept.
_, tf_model = port_jax_params(
    configs, "gs://gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz"
)

2022-10-05 11:30:46.924935: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
