# Runbook for converting [MAXIM models](https://github.com/google-research/maxim) into TensorFlow.js-compatible models

This notebook demonstrates how to convert MAXIM models into TensorFlow.js equivalents. The goal is to produce TensorFlow.js models whose outputs match those of the original Python models.

[This notebook consumes the original JAX versions of the models](http://github.com/google-research/maxim/).

Start by installing the dependencies and cloning the relevant repository.

In [None]:
!pip install scikit-image numpy tensorflow tensorflow_hub matplotlib jax flax Pillow ml-collections "tensorflowjs>=4.5.0"
!cd node && npm install && cd ..

In [2]:
# Clone the forked version of MAXIM to avoid conflicts with future upstream changes
!git clone https://github.com/thekevinscott/maxim cloned-code/maxim

In [57]:
%load_ext autoreload
%autoreload 2

## Model Definitions & Constants

[Here are all the MAXIM trained model checkpoints](https://github.com/google-research/maxim#results-and-pre-trained-models). For each task we choose the best-performing checkpoint. There are two exceptions. For Dehazing we use both the indoor (SOTS-Indoor) and outdoor (SOTS-Outdoor) checkpoints, and for Enhancement we use both the LOL and FiveK checkpoints.

Each model is also paired with a task-specific sample image to run inference on.

In [None]:
models_and_images = [
    # (Task, Dataset, Sample Image)
    ('Dehazing', 'SOTS-Indoor', 'https://github.com/google-research/maxim/raw/main/maxim/images/Dehazing/input/1444_10.png'),
    ('Dehazing', 'SOTS-Outdoor', 'https://github.com/google-research/maxim/raw/main/maxim/images/Dehazing/input/0003_0.8_0.2.png'),
    ('Denoising', 'SIDD', 'https://github.com/google-research/maxim/raw/main/maxim/images/Denoising/input/0003_30.png'),
    ('Deblurring', 'GoPro', 'https://github.com/google-research/maxim/raw/main/maxim/images/Deblurring/input/1fromGOPR0950.png'),
    ('Deraining', 'Rain13k', 'https://github.com/google-research/maxim/raw/main/maxim/images/Deraining/input/15.png'),
    ('Enhancement', 'LOL', 'https://github.com/google-research/maxim/raw/main/maxim/images/Enhancement/input/a4541-DSC_0040-2.png'),
    ('Enhancement', 'FiveK', 'https://github.com/google-research/maxim/raw/main/maxim/images/Enhancement/input/a4541-DSC_0040-2.png'),
]
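Before inference, each sample image has to be decoded into a numeric array. The evaluation helper's internals are not shown in this notebook, so the snippet below is only a minimal sketch. It assumes the models take float32 RGB arrays scaled to [0, 1], which the source does not confirm. `load_image_from_bytes` and `fetch_sample_image` are hypothetical helper names.

```python
import io
import urllib.request

import numpy as np
from PIL import Image


def load_image_from_bytes(data: bytes) -> np.ndarray:
    """Decode PNG/JPEG bytes into a float32 (H, W, 3) array scaled to [0, 1].

    Assumption: the models expect 3-channel float input in [0, 1].
    """
    image = Image.open(io.BytesIO(data)).convert('RGB')  # drop alpha, force 3 channels
    return np.asarray(image, dtype=np.float32) / 255.0


def fetch_sample_image(url: str) -> np.ndarray:
    """Download one of the URLs in models_and_images (hypothetical helper)."""
    with urllib.request.urlopen(url) as response:
        return load_image_from_bytes(response.read())
```

For example, `fetch_sample_image(models_and_images[0][2])` returns a `(height, width, 3)` array for the SOTS-Indoor sample.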

Then choose the output folders and conversion settings:

In [None]:
import pathlib

MODEL_OUTPUT_FOLDER = pathlib.Path('./models/jax')
IMAGES_OUTPUT_FOLDER = pathlib.Path('./images')
QUANTIZATION_SETTINGS = '' # Leave empty for no quantization, or choose 'float16', 'uint16' or 'uint8'
INPUT_SIZE = None # Set a fixed input size, or leave as None to accept dynamically sized images. If dynamic, the image height and width must each be divisible by 64 (due to the architecture of MAXIM)
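With `INPUT_SIZE = None`, an arbitrary image must be brought to a height and width divisible by 64 before inference. Below is a minimal sketch of one way to do this. It assumes reflect padding is acceptable for these models and that the output can be cropped back to the original size afterwards. `pad_to_multiple` and `crop_to_original` are hypothetical helpers, not part of the MAXIM code.

```python
import numpy as np


def pad_to_multiple(image: np.ndarray, factor: int = 64):
    """Reflect-pad an (H, W, C) image so H and W are multiples of `factor`.

    Returns the padded image and the original (H, W) so the model output
    can be cropped back afterwards.
    """
    height, width = image.shape[:2]
    pad_h = (-height) % factor  # rows needed to reach the next multiple
    pad_w = (-width) % factor   # columns needed to reach the next multiple
    padded = np.pad(
        image,
        ((0, pad_h), (0, pad_w), (0, 0)),  # pad only bottom and right edges
        mode='reflect',
    )
    return padded, (height, width)


def crop_to_original(output: np.ndarray, original_hw):
    """Undo pad_to_multiple on a model output with the same spatial layout."""
    height, width = original_hw
    return output[:height, :width]
```

Padding only the bottom and right edges keeps the original pixels at the same coordinates, so a plain slice is enough to crop the output back.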

## Create the Models

Alternatively, you can run this step from the command line:

```
python3 create_jax_and_tf_models.py --task task --dataset dataset --output output_folder --tf_output tf_output --quantization_settings quantization_setting
```
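To drive the command-line version from Python for every entry in `models_and_images`, the command can be assembled as an argument list. This is a sketch only. The flag names are copied from the command above. The rule of omitting `--quantization_settings` when it is empty is an assumption, and `build_cli_command` and `run_conversion` are hypothetical helpers.

```python
import subprocess
import sys


def build_cli_command(task, dataset, output, tf_output, quantization_settings=''):
    """Assemble the conversion command for one (task, dataset) pair.

    Flag names come from the command above. Assumption: the
    --quantization_settings flag can be omitted when no quantization is wanted.
    """
    command = [
        sys.executable, 'create_jax_and_tf_models.py',  # same interpreter as the notebook
        '--task', task,
        '--dataset', dataset,
        '--output', str(output),
        '--tf_output', str(tf_output),
    ]
    if quantization_settings:
        command += ['--quantization_settings', quantization_settings]
    return command


def run_conversion(command):
    """Run one conversion, raising CalledProcessError if the script fails."""
    subprocess.run(command, check=True)
```

Using `sys.executable` rather than a bare `python3` makes the script run under the same interpreter, and with the same installed packages, as the notebook kernel.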

In [None]:
# NOTE: `get_jax_model` is not defined or imported in the cells shown here;
# it must be defined or imported before running this cell.
from create_jax_and_tf_models import create_jax_and_tf_models
from evaluate import evaluate_jax_models

for task, dataset, sample_image in models_and_images:
    # Folder names encode the conversion settings, e.g. .../tfjs/uncompressed/none
    q_folder = QUANTIZATION_SETTINGS if QUANTIZATION_SETTINGS else 'uncompressed'
    input_size_folder = f'{INPUT_SIZE}' if INPUT_SIZE else 'none'
    tfjs_output = MODEL_OUTPUT_FOLDER / task / dataset / 'tfjs' / q_folder / input_size_folder

    # Path to the pretrained checkpoint on Google Cloud Storage
    checkpoint_path = f'gs://gresearch/maxim/ckpt/{task}/{dataset}/checkpoint.npz'

    print(f'Prepare to get Python Jax model for task {task} and dataset {dataset}')
    jax_model, params = get_jax_model(task, checkpoint_path)

    print(f'Prepare to create TFJS model for task {task} and dataset {dataset}')
    create_jax_and_tf_models(
        jax_model,
        params,
        tfjs_output_folder=str(tfjs_output),
        quantization_settings=QUANTIZATION_SETTINGS,
        input_size=INPUT_SIZE,
    )
    print(f'Saved TFJS model for task {task} and dataset {dataset} to folder: "{tfjs_output}"')

    # Run both models on the sample image and compare their outputs
    print(f'Evaluating Python and Tensorflow.js models for task "{task}" and dataset "{dataset}"')
    evaluate_jax_models(jax_model, params, str(tfjs_output), sample_image, checkpoint_path)

    print('-' * 40)


Prepare to get Python Jax model for task Enhancement and dataset LOL
Prepare to create TFJS model for task Enhancement and dataset LOL



In [None]:
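# Re-evaluate the last converted model at several square input sizes and
# record the SSIM score for each. jax_model, params, tfjs_output and
# sample_image are left over from the final iteration of the loop above.
# All sizes are multiples of 64, as MAXIM requires.
records = []
for size in [64, 128, 256, 512, 768]:
    input_resolution = (size, size)
    _, _, ssim = evaluate_jax_models(jax_model, params, str(tfjs_output), sample_image, input_resolution=input_resolution)
    records.append((size, ssim))

records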
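The `records` list pairs each input size with its SSIM score. To read the sweep at a glance, it can be rendered as a small text table. This is a minimal sketch: `format_records` is a hypothetical helper, and the 0.99 threshold is an illustrative cut-off, not a value from the source.

```python
def format_records(records, threshold=0.99):
    """Render (size, ssim) pairs as a plain-text table.

    `threshold` is an illustrative cut-off for flagging low SSIM scores.
    """
    lines = [f'{"size":>6} | {"ssim":>8} | ok']
    for size, ssim in records:
        flag = 'yes' if ssim >= threshold else 'no'
        lines.append(f'{size:>6} | {ssim:8.5f} | {flag}')
    return '\n'.join(lines)
```

Calling `print(format_records(records))` after the sweep prints one row per input size.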