In [None]:
# Ref: https://blog.keras.io/building-autoencoders-in-keras.html

In [None]:
import os

# Configure TensorFlow through environment variables *before* importing it, so the
# settings are in place when it initialises: silence its C++ logging (level 3) and
# expose no CUDA devices, i.e. run everything on the CPU.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'CUDA_VISIBLE_DEVICES': '',
})

import tensorflow as tf

In [None]:
input_img = tf.keras.Input(shape=(28, 28, 1), name='input_img')

# Encoder: three stages of Conv2D(3x3, ReLU, 'same') followed by 2x2 'same' max-pooling,
# with 16, 8 and 8 filters. 'same' pooling rounds up, so 28 -> 14 -> 7 -> 4.
x = input_img
for n_filters in (16, 8, 8):
    x = tf.keras.layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

# Decoder: three stages of Conv2D(3x3, ReLU) followed by 2x upsampling. The 16-filter
# conv uses 'valid' padding (16x16 -> 14x14) so the last upsampling yields 28x28
# rather than 32x32; a final 1-channel sigmoid conv produces the reconstruction.
x = encoded
for n_filters, pad in ((8, 'same'), (8, 'same'), (16, 'valid')):
    x = tf.keras.layers.Conv2D(n_filters, (3, 3), activation='relu', padding=pad)(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
decoded = tf.keras.layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = tf.keras.Model(input_img, decoded)

In [None]:
# Print the Keras summary of the full (encoder + decoder) autoencoder model.
autoencoder.summary()

In [None]:
# build decoder
# A standalone decoder that takes the 4x4x8 code directly. It mirrors the decoding half
# of `autoencoder`, but is built from newly constructed layers, so its weights are
# independent of the autoencoder's.
encoded_input = tf.keras.Input(shape=(4, 4, 8), name='encoded_input')

x = encoded_input
# The 16-filter conv is 'valid' (16x16 -> 14x14) so the last upsampling gives 28x28.
for n_filters, pad in ((8, 'same'), (8, 'same'), (16, 'valid')):
    x = tf.keras.layers.Conv2D(n_filters, (3, 3), activation='relu', padding=pad)(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
decoded = tf.keras.layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

decoder = tf.keras.Model(encoded_input, decoded)

In [None]:
# Print the Keras summary of the standalone decoder model.
decoder.summary()

In [None]:
from tensorflow.keras.utils import plot_model

# Write an architecture diagram of each model to a PNG file. Jupyter renders only the
# cell's last expression, so only the autoencoder diagram is displayed inline; the
# decoder diagram is just written to decoder.png.
# NOTE(review): plot_model needs pydot + graphviz available — confirm in the environment.
plot_model(decoder, to_file='decoder.png')
plot_model(autoencoder, to_file='autoencoder.png')

In [None]:
# Export one truncated model per layer: each runs from the autoencoder's input up to and
# including layer `idx`, and is saved as autoencoder_<idx>/model.h5. layers[0] (the
# functional model's Input layer) is skipped.
for idx, layer in enumerate(autoencoder.layers[1:], start=1):
    prefix = f"autoencoder_{idx}"
    truncated = tf.keras.Model(inputs=autoencoder.inputs, outputs=layer.output, name=prefix)
    truncated.save(f"{prefix}/model.h5")

In [None]:
# Same per-layer export for the standalone decoder: decoder_<idx>/model.h5 holds the
# sub-model from the decoder's input up to and including layer `idx`.
for idx, layer in enumerate(decoder.layers[1:], start=1):
    prefix = f"decoder_{idx}"
    truncated = tf.keras.Model(inputs=decoder.inputs, outputs=layer.output, name=prefix)
    truncated.save(f"{prefix}/model.h5")

In [None]:
import tf2onnx

spec = tf.TensorSpec([1, 28, 28, 1], tf.float32, name='input_img')

# Every truncated model is converted twice:
#   model.onnx   - opset 12, with the input exposed as NCHW
#   opset18.onnx - opset 18, with the input kept in the NHWC layout of `spec`
onnx_variants = (
    (12, True, 'model.onnx'),
    (18, False, 'opset18.onnx'),
)

for i in range(1, len(autoencoder.layers)):
    model_i = tf.keras.models.load_model(f'autoencoder_{i}/model.h5')
    for opset, as_nchw, filename in onnx_variants:
        tf2onnx.convert.from_keras(
            model_i,
            input_signature=[spec],
            inputs_as_nchw=['input_img'] if as_nchw else None,
            opset=opset,
            output_path=f'autoencoder_{i}/{filename}',
        )

In [None]:
spec = tf.TensorSpec([1, 4, 4, 8], tf.float32, name='encoded_input')

# Every truncated decoder model is converted twice:
#   model.onnx   - opset 12, with the input exposed as NCHW
#   opset18.onnx - opset 18, with the input kept in the NHWC layout of `spec`
onnx_variants = (
    (12, True, 'model.onnx'),
    (18, False, 'opset18.onnx'),
)

for i in range(1, len(decoder.layers)):
    model_i = tf.keras.models.load_model(f'decoder_{i}/model.h5')
    for opset, as_nchw, filename in onnx_variants:
        tf2onnx.convert.from_keras(
            model_i,
            input_signature=[spec],
            inputs_as_nchw=['encoded_input'] if as_nchw else None,
            opset=opset,
            output_path=f'decoder_{i}/{filename}',
        )

In [None]:
# Ref: https://github.com/zkonduit/ezkl/blob/bceac2fab530fd01701aec3d8018ce318f6c42e1/examples/notebooks/mnist_vae.ipynb

# Enable trace logging for ezkl. This must be set on the kernel process itself: a
# `!RUST_LOG=trace` line only sets it inside a short-lived subshell that exits immediately.
# NOTE(review): assumes ezkl's Rust logger reads RUST_LOG when it is first initialised
# (i.e. on import) — on a re-run after ezkl is already imported this may have no effect.
os.environ['RUST_LOG'] = 'trace'

import ezkl
import json


# Generate ezkl circuit settings for the ONNX export of every truncated autoencoder
# model, and report each model's path and the resulting "num_rows".
for i in range(1, len(autoencoder.layers)):
    model_dir = f'autoencoder_{i}'
    model_path = os.path.join(model_dir, 'model.onnx')
    settings_path = os.path.join(model_dir, 'settings.json')
    print(model_path)

    res = ezkl.gen_settings(model_path, settings_path)
    assert res, f'ezkl.gen_settings failed for {model_path}'

    # read the settings back from the generated json
    with open(settings_path, 'r') as f:
        settings = json.load(f)

    # print the "num_rows" from the settings
    print(settings['num_rows'])


In [None]:
# Generate ezkl circuit settings for the ONNX export of every truncated decoder model.
# The model path is printed before "num_rows" (as in the autoencoder cell) so each
# number in the output is labelled.
for i in range(1, len(decoder.layers)):
    model_dir = f'decoder_{i}'
    model_path = os.path.join(model_dir, 'model.onnx')
    settings_path = os.path.join(model_dir, 'settings.json')
    print(model_path)

    res = ezkl.gen_settings(model_path, settings_path)
    assert res, f'ezkl.gen_settings failed for {model_path}'

    # read the settings back from the generated json
    with open(settings_path, 'r') as f:
        settings = json.load(f)

    # print the "num_rows" from the settings
    print(settings['num_rows'])

In [None]:
import sys

# Make the keras2circom package in the parent directory importable. The guard keeps the
# cell idempotent: re-running it no longer appends a duplicate '..' to sys.path each time.
if '..' not in sys.path:
    sys.path.append('..')
from keras2circom.keras2circom import circom, transpiler

# Point keras2circom at the circomlib-ml circuit templates, skipping the listed
# entries. NOTE(review): presumably this registers the templates the transpiler
# emits against — confirm in keras2circom.
circom.dir_parse('../keras2circom/node_modules/circomlib-ml/circuits/', skips=['util.circom', 'circomlib-matrix', 'circomlib', 'crypto'])

In [None]:
# Run the keras2circom transpiler on each truncated autoencoder model, writing into that
# model's directory. The positional arguments mirror the CLI options
# <model.h5>, --output, --raw (False) and --decimals ("18").
for i in range(1, len(autoencoder.layers)):
    out_dir = f'autoencoder_{i}'
    transpiler.transpile(f'{out_dir}/model.h5', out_dir, False, "18")

In [None]:
# Run the keras2circom transpiler on each truncated decoder model, writing into that
# model's directory. The positional arguments mirror the CLI options
# <model.h5>, --output, --raw (False) and --decimals ("18").
for i in range(1, len(decoder.layers)):
    out_dir = f'decoder_{i}'
    transpiler.transpile(f'{out_dir}/model.h5', out_dir, False, "18")