In [None]:
# What makes an autoencoder "variational"?

Credit to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of this code.

In [None]:
# Check whether the notebook is running in Colab. Only the import is guarded:
# catching ImportError (instead of a bare `except:`) means a failed
# `pip install` or a KeyboardInterrupt is no longer silently swallowed.
try:
    import google.colab  # noqa: F401 -- imported only to detect a Colab runtime
except ImportError:
    # rely on local installation of ezkl if the notebook is not in colab
    pass
else:
    # install ezkl and the ONNX export tooling into this kernel's interpreter
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tf2onnx"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])


import os
import time
import random

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# rescale uint8 pixel values from [0, 255] to floats in [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0
# one-hot encode the digit labels
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [None]:
# Size of the latent space: the encoder's final Dense layer emits ZDIM values,
# the decoder's Input takes a ZDIM-vector, and the exported ONNX input is [1, ZDIM].
ZDIM = 4

def get_encoder():
  """Build the convolutional encoder: a (28, 28) image -> a ZDIM-dim vector.

  Two 5x5, stride-2 convolution stages (64 then 128 filters), each followed
  by batch normalization and an ELU, then a linear Dense projection to ZDIM.
  """
  image = Input((28, 28))
  # Conv2D needs an explicit trailing channel axis
  features = Reshape((28, 28, 1))(image)
  for n_filters in (64, 128):
    features = Conv2D(n_filters, (5, 5), padding='same', strides=(2, 2))(features)
    features = BatchNormalization()(features)
    features = ELU()(features)
  latent = Dense(ZDIM)(Flatten()(features))
  return Model(image, latent)

def get_decoder():
  """Build the decoder: a ZDIM-dim latent vector -> a (28, 28) image in (0, 1).

  A Dense layer expands the latent vector to a 7x7x64 feature map, which two
  5x5, stride-2 transposed convolutions upsample to 14x14 and then 28x28.
  """
  def bn_elu(t):
    # batch normalization followed by an ELU activation
    return ELU()(BatchNormalization()(t))

  latent = Input((ZDIM,))
  h = bn_elu(Dense(7 * 7 * 64)(latent))
  h = Reshape((7, 7, 64))(h)
  h = bn_elu(Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same')(h))
  h = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same')(h)
  # sigmoid keeps pixels in (0, 1), matching the /255 input scaling
  h = Activation('sigmoid')(h)
  return Model(latent, Reshape((28, 28))(h))

### Regular Autoencoder

In [None]:
# normal autoencoder without the variational part
enc = get_encoder()
dec = get_decoder()
# chain the decoder onto the encoder's output so both are trained end-to-end
ae = Model(enc.input, dec(enc.output))
# reconstruction objective: the input image is also the target
ae.compile(optimizer='adam', loss='mse')
ae.summary()
# make the epochs larger for better results
# (`shuffle` is documented as a boolean, so pass True rather than 1)
ae.fit(x_train, x_train, batch_size=128, epochs=1, shuffle=True, validation_data=(x_test, x_test))

In [None]:
# The plain autoencoder reconstructs well, but decoding *random* latent
# vectors (sampling) does not give convincing digits without the variational part.
import numpy as np
from matplotlib.pyplot import figure, imshow

# reconstructions of 10 randomly chosen test images, tiled side by side
imshow(np.concatenate(
    ae.predict(np.array([random.choice(x_test) for _ in range(10)])),
    axis=1,
))
figure(figsize=(16,16))
# decode 10 latent vectors drawn from N(0, I); ae.layers[-1] is the decoder sub-model
imshow(np.concatenate(
    ae.layers[-1].predict(np.random.normal(size=(10, ZDIM))),
    axis=1,
))

In [None]:
import os

# File names for every artifact of the autoencoder proving pipeline.
model_path = 'ae.onnx'                  # decoder exported to ONNX
compiled_model_path = 'ae.compiled'     # ezkl-compiled circuit
pk_path = 'ae.pk'                       # proving key (written by setup, used by prove)
vk_path = 'ae.vk'                       # verifying key (written by setup, used by verify)
settings_path = 'ae_settings.json'      # circuit settings (gen_settings / calibrate_settings)
srs_path = 'ae_kzg.srs'                 # NOTE: not passed to any ezkl call below
witness_path = 'ae_witness.json'        # witness produced by gen_witness
data_path = 'ae_input.json'             # proving input written after the export

Now we export the decoder (presumably the part we want to prove) to ONNX.

In [None]:

import numpy as np
import tf2onnx
import tensorflow as tf
import json

shape = [1, ZDIM]
# After training, export the decoder to onnx (model_path) and create a data file
# (data_path) holding one random latent vector with entries in [0, 0.1)
x = 0.1*np.random.rand(1,*shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')

# NOTE(review): inputs_as_nchw targets image-like inputs; for this [1, ZDIM]
# input it looks like it has no effect -- confirm before relying on it
tf2onnx.convert.from_keras(dec, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)

data_array = x.reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file; the `with` block closes (and flushes) the handle
# immediately instead of leaving it to the garbage collector.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)


In [None]:
import ezkl

# NOTE: the original `!RUST_LOG=trace` shell escape only set the variable in a
# throwaway subshell, so it never reached this kernel. To enable trace logging,
# set it in the kernel's own environment instead, e.g. uncomment:
# %env RUST_LOG=trace
res = ezkl.gen_settings(model_path, settings_path)
assert res == True


In [None]:
cal_path = os.path.join("calibration.json")

# 20 random latent vectors (same [0, 0.1) range as the proving input), flattened into one list
data_array = (0.1 * np.random.rand(20, *shape)).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file (the `with` block closes the handle deterministically)
with open(cal_path, 'w') as cal_file:
    json.dump(data, cal_file)

# Calibrate the settings against the calibration data with the "resources" target;
# check the result, as the VAE section does for the same call.
res = ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
assert res == True

In [None]:
# compile the ONNX model at model_path into an ezkl circuit, using settings_path
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

In [None]:
# srs path
# Fetch the structured reference string for the circuit described by settings_path
# (presumably sized from the settings -- confirm). Top-level `await` works because
# IPython supports it; the result is not checked here.
res = await ezkl.get_srs( settings_path)

In [None]:
# now generate the witness file
# (re-binds witness_path to the same value set in the paths cell)
witness_path = "ae_witness.json"

# run the compiled circuit on the input in data_path and write the witness
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
# mock-prove: check the witness against the compiled circuit before generating any keys
res = ezkl.mock(witness_path, compiled_model_path)
assert res == True

In [None]:

# Key generation: from the compiled circuit, produce the verifying key
# (vk_path) and the proving key (pk_path) needed by the prove/verify steps.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)

assert res == True
# both keys and the settings file must now exist on disk
assert all(os.path.isfile(artifact) for artifact in (vk_path, pk_path, settings_path))

In [None]:
# Generate a proof from the witness, the compiled circuit and the proving key.
proof_path = 'ae.pf'

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path)

print(res)
assert os.path.isfile(proof_path)

In [None]:
# Verify the proof against the circuit settings and the verifying key.
res = ezkl.verify(proof_path, settings_path, vk_path)

assert res == True
print("verified")

### Variational Autoencoder

In [None]:
in1 = Input((28,28))
x = get_encoder()(in1)

# add the variational part: predict a mean and a log-variance per latent dimension
z_mu = Dense(ZDIM)(x)
z_log_var = Dense(ZDIM)(x)
# reparameterization: z = mu + exp(0.5 * log_var) * eps, with eps ~ N(0, I)
z = Lambda(lambda args: args[0] + K.exp(0.5 * args[1]) * K.random_normal(shape=K.shape(args[0])))([z_mu, z_log_var])
dec = get_decoder()
# NOTE(review): presumably names the output of the exported ONNX graph -- confirm
dec.output_names=['output']

out = dec(z)

# reconstruction term: per-pixel MSE (mean over 784 pixels) rescaled to a per-image sum
mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28
# KL(q(z|x) || N(0, I)) for a diagonal Gaussian is a SUM over latent dimensions;
# averaging (K.mean) silently divided it by ZDIM relative to the summed reconstruction term
kl_loss = 1 + z_log_var - K.square(z_mu) - K.exp(z_log_var)
kl_loss = -0.5 * K.sum(kl_loss, axis=-1)

vae = Model(in1, out)
# the whole objective is attached via add_loss, so compile takes no loss argument
vae.add_loss(K.mean(mse_loss + kl_loss))
vae.compile('adam')

In [None]:
# z is sampled from z_mu and z_log_var with gaussian noise
# Debug model exposing the sampled latent z next to the predicted mean and
# log-variance for one training image; z should change between calls (random noise).
test = Model(in1, [z, z_mu, z_log_var])
test.predict(x_train[0:1])

In [None]:
# No y targets: the objective comes entirely from add_loss, and the validation
# tuple carries None for y. `shuffle` is documented as a boolean, so pass True.
vae.fit(x_train, batch_size=128, epochs=1, shuffle=True, validation_data=(x_test, None))

In [None]:
# reconstructions of 10 randomly chosen test images, tiled side by side
imshow(np.concatenate(vae.predict(np.array([random.choice(x_test) for i in range(10)])), axis=1))
figure(figsize=(16,16))
# Sampling: decode 10 latent vectors drawn from N(0, I). Use the decoder model
# directly rather than `vae.layers[5]`, whose index depends on Keras' layer ordering.
imshow(np.concatenate(dec.predict(np.random.normal(size=(10, ZDIM))), axis=1))

In [None]:
import os

# File names for every artifact of the VAE proving pipeline.
model_path = 'vae.onnx'                 # VAE decoder exported to ONNX
compiled_model_path = 'vae.compiled'    # ezkl-compiled circuit
pk_path = 'vae.pk'                      # proving key (written by setup, used by prove)
vk_path = 'vae.vk'                      # verifying key (written by setup, used by verify)
settings_path = 'vae_settings.json'     # circuit settings (gen_settings / calibrate_settings)
srs_path = 'vae_kzg.srs'                # NOTE: not passed to any ezkl call below
witness_path = 'vae_witness.json'       # witness produced by gen_witness
data_path = 'vae_input.json'            # proving input written after the export

In [None]:

import numpy as np
import tf2onnx
import tensorflow as tf
import json

shape = [1, ZDIM]
# After training, export the VAE decoder to onnx (model_path) and create a data file
# (data_path) holding one random latent vector with entries in [0, 0.1)
x = 0.1*np.random.rand(1,*shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')

# NOTE(review): inputs_as_nchw targets image-like inputs; for this [1, ZDIM]
# input it looks like it has no effect -- confirm before relying on it
tf2onnx.convert.from_keras(dec, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)

data_array = x.reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file; the `with` block closes (and flushes) the handle
# immediately instead of leaving it to the garbage collector.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)

In [None]:
import ezkl

# NOTE: the original `!RUST_LOG=trace` shell escape only set the variable in a
# throwaway subshell, so it never reached this kernel. To enable trace logging,
# set it in the kernel's own environment instead, e.g. uncomment:
# %env RUST_LOG=trace
res = ezkl.gen_settings(model_path, settings_path)
assert res == True

# NOTE(review): unlike the autoencoder section, calibration here uses only the
# single proving input in data_path, not a separate multi-sample calibration file
res = ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
assert res == True
# nothing has been verified yet at this point -- report what actually happened
print("calibrated")

In [None]:
# compile the ONNX model at model_path into an ezkl circuit, using settings_path
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

In [None]:
# srs path
# Fetch the structured reference string for the circuit described by settings_path
# (presumably sized from the settings -- confirm). Top-level `await` works because
# IPython supports it; the result is not checked here.
res = await ezkl.get_srs( settings_path)

In [None]:
# now generate the witness file
# (re-binds witness_path to the same value set in the paths cell)
witness_path = "vae_witness.json"

# run the compiled circuit on the input in data_path and write the witness
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
# uncomment to mock prove
# (optional sanity check of the witness against the compiled circuit before key
#  generation; the autoencoder section runs the same check unconditionally)
# res = ezkl.mock(witness_path, compiled_model_path)
# assert res == True

In [None]:

# Key generation: from the compiled circuit, produce the verifying key
# (vk_path) and the proving key (pk_path) needed by the prove/verify steps.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)

assert res == True
# both keys and the settings file must now exist on disk
assert all(os.path.isfile(artifact) for artifact in (vk_path, pk_path, settings_path))

In [None]:
# GENERATE A PROOF
# Prove with the witness, the compiled circuit and the proving key.

# Named after the other vae_* artifacts (and the autoencoder's 'ae.pf') instead
# of the generic 'test.pf', which could clash with unrelated files.
proof_path = os.path.join('vae.pf')

res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
    )

print(res)
assert os.path.isfile(proof_path)

In [None]:
# Verify the proof against the circuit settings and the verifying key.
res = ezkl.verify(proof_path, settings_path, vk_path)

assert res == True
print("verified")