
Credit to [geohot](https://github.com/geohot/ai-notebooks/blob/master/mnist_gan.ipynb) for most of the GAN code in this notebook.

In [None]:
# check if notebook is in colab
try:
    import google.colab  # noqa: F401 -- imported only to detect the Colab runtime
    IN_COLAB = True
except ImportError:
    # not in colab: rely on a local installation of ezkl, tf2onnx and onnx
    IN_COLAB = False

if IN_COLAB:
    # Install ezkl and the ONNX export tooling into this kernel's interpreter.
    # Kept outside the try/except so a failed install raises CalledProcessError
    # instead of being silently swallowed.
    import subprocess
    import sys
    for package in ("ezkl", "tf2onnx", "onnx"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# make sure you have the dependencies required here already installed
import ezkl
import os
import json
import time
import random
import logging

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist

# uncomment for more descriptive logging 
# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
# logging.basicConfig(format=FORMAT)
# logging.getLogger().setLevel(logging.INFO)

# Can we build a simple GAN that can produce all 10 MNIST digits?

In [None]:

# Load MNIST, scale pixel values from [0, 255] to [0, 1], and one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [None]:
ZDIM = 100  # size of the generator's latent noise vector


def build_discriminator(optimizer):
    """Build and compile the discriminator.

    Maps a 28x28 image to a sigmoid score: 0 if it's fake, 1 if it's real.
    Compiled with binary cross-entropy and the given optimizer.
    """
    image = Input((28, 28))
    h = Reshape((28, 28, 1))(image)

    # two strided conv stages (64 then 128 filters), each followed by BN + ELU
    for filters in (64, 128):
        h = Conv2D(filters, (5, 5), padding='same', strides=(2, 2))(h)
        h = BatchNormalization()(h)
        h = ELU()(h)

    h = Flatten()(h)
    h = Dense(128)(h)
    h = BatchNormalization()(h)
    h = ELU()(h)
    score = Dense(1, activation='sigmoid')(h)

    model = Model(image, score)
    model.compile(optimizer, 'binary_crossentropy')
    return model


def build_generator(zdim):
    """Build and compile the generator: latent vector of size `zdim` -> 28x28 digit."""
    noise = Input((zdim,))

    h = Dense(7 * 7 * 64)(noise)
    h = BatchNormalization()(h)
    h = ELU()(h)
    h = Reshape((7, 7, 64))(h)

    # 7x7 -> 14x14
    h = Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = ELU()(h)

    # 14x14 -> 28x28, single channel squashed to [0, 1]
    h = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same')(h)
    h = Activation('sigmoid')(h)
    digits = Reshape((28, 28))(h)

    model = Model(noise, digits)
    model.compile('adam', 'mse')
    # name the output so the ONNX export has a stable output name
    model.output_names = ['output']
    return model


opt = Adam()
dm = build_discriminator(opt)
dm.summary()

gm = build_generator(ZDIM)
gm.summary()

# Combined GAN: generator -> frozen discriminator, trained with its own Adam optimizer.
opt = Adam()
dm.trainable = False
tm = Model(gm.input, dm(gm.output))
tm.compile(opt, 'binary_crossentropy')

dlosses, glosses = [], []

In [None]:
import numpy as np
from matplotlib.pyplot import figure, imshow, show

BS = 256          # batch size for both the discriminator and generator updates
N_ITERS = 1       # number of GAN training iterations; make larger if you want it to look better
LOG_EVERY = 50    # print the losses every LOG_EVERY iterations
PLOT_EVERY = 250  # show a strip of 10 generated digits every PLOT_EVERY iterations


def sample_noise(n):
    """Draw `n` latent vectors of size ZDIM from a standard normal distribution."""
    return np.random.normal(0, 1, size=(n, ZDIM))


# GAN training loop
for i in range(N_ITERS):
    # train discriminator: one real batch labelled 1, one generated batch labelled 0
    dm.trainable = True
    real_i = x_train[np.random.choice(x_train.shape[0], BS)]
    fake_i = gm.predict_on_batch(sample_noise(BS))
    dloss_r = dm.train_on_batch(real_i, np.ones(BS))
    dloss_f = dm.train_on_batch(fake_i, np.zeros(BS))
    dloss = (dloss_r + dloss_f) / 2

    # train generator through the combined model: two noise batches labelled "real"
    dm.trainable = False
    gloss_0 = tm.train_on_batch(sample_noise(BS), np.ones(BS))
    gloss_1 = tm.train_on_batch(sample_noise(BS), np.ones(BS))
    gloss = (gloss_0 + gloss_1) / 2

    if i % LOG_EVERY == 0:
        print("%4d: dloss:%8.4f   gloss:%8.4f" % (i, dloss, gloss))
    dlosses.append(dloss)
    glosses.append(gloss)

    if i % PLOT_EVERY == 0:
        figure(figsize=(16, 16))
        imshow(np.concatenate(gm.predict(sample_noise(10)), axis=1))
        show()

In [None]:
from matplotlib.pyplot import figure, legend, plot, xlabel

# Skip the first WARMUP iterations (the noisy start of training), but only if
# there is history left afterwards; otherwise plot everything. Without this
# guard a short run (fewer than WARMUP iterations) produces an empty plot.
WARMUP = 100
start = WARMUP if len(dlosses) > WARMUP else 0

figure(figsize=(8, 8))
plot(dlosses[start:], label="Discriminator Loss")
plot(glosses[start:], label="Generator Loss")
xlabel("iteration (offset by %d)" % start)
legend();

In [None]:
# 10x10 grid of generated digits: each row is a horizontal strip of 10 samples.
grid_rows = [
    np.concatenate(gm.predict(np.random.normal(size=(10, ZDIM))), axis=1)
    for _ in range(10)
]
imshow(np.concatenate(grid_rows, axis=0))

In [None]:
import os

# Directory for every ezkl artifact below; "" keeps them in the current working directory.
OUTPUT_DIR = ""
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)

model_path = os.path.join(OUTPUT_DIR, 'gan.onnx')                # exported generator (ONNX)
compiled_model_path = os.path.join(OUTPUT_DIR, 'gan.compiled')   # output of ezkl.compile_circuit
pk_path = os.path.join(OUTPUT_DIR, 'gan.pk')                     # key written by ezkl.setup, used by ezkl.prove
vk_path = os.path.join(OUTPUT_DIR, 'gan.vk')                     # key written by ezkl.setup, used by ezkl.verify
settings_path = os.path.join(OUTPUT_DIR, 'gan_settings.json')    # written by ezkl.gen_settings
srs_path = os.path.join(OUTPUT_DIR, 'gan_kzg.srs')
witness_path = os.path.join(OUTPUT_DIR, 'gan_witness.json')      # written by ezkl.gen_witness
data_path = os.path.join(OUTPUT_DIR, 'gan_input.json')           # sample generator input



Now we export the _generator_ to ONNX.

In [None]:

import json

import numpy as np
import tensorflow as tf
import tf2onnx

shape = [1, ZDIM]
# After training, export the generator to ONNX (model_path) and write a sample
# input file (data_path) holding one small random latent vector in [0, 0.1).
x = 0.1 * np.random.rand(1, *shape)

spec = tf.TensorSpec(shape, tf.float32, name='input_0')

tf2onnx.convert.from_keras(gm, input_signature=[spec], inputs_as_nchw=['input_0'], opset=12, output_path=model_path)

data_array = x.reshape([-1]).tolist()
data = dict(input_data=[data_array])

# Serialize data into file. The context manager guarantees the file is closed
# (and flushed) before ezkl reads it later.
with open(data_path, 'w') as f:
    json.dump(data, f)


In [None]:
import ezkl

# Visibility of the model input, parameters and output in the generated circuit.
run_args = ezkl.PyRunArgs()
run_args.input_visibility = "private"
run_args.param_visibility = "fixed"
run_args.output_visibility = "public"
run_args.variables = [("batch_size", 1)]

# NOTE: a shell escape such as `!RUST_LOG=trace` only sets the variable inside a
# throw-away subshell, so it never reaches this kernel or ezkl. To try verbose
# Rust logging, set it on the kernel process before calling ezkl (uncomment):
# os.environ["RUST_LOG"] = "trace"

# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True


In [None]:
cal_path = os.path.join("calibration.json")

# 20 random inputs with values in [0, 0.2), flattened into one list, used by
# ezkl to calibrate the settings file.
data_array = (0.2 * np.random.rand(20, *shape)).reshape([-1]).tolist()

data = dict(input_data=[data_array])

# Serialize data into file; the context manager closes it before ezkl reads it.
with open(cal_path, 'w') as f:
    json.dump(data, f)

res = ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources", scales=[0, 6])
assert res == True

In [None]:
# Compile the ONNX generator into an ezkl circuit using the calibrated settings.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

In [None]:
# srs path
res = await ezkl.get_srs( settings_path)

In [None]:
# Now generate the witness file. witness_path comes from the paths cell rather than
# being hard-coded again here, so a change there is honoured by this step.
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
# uncomment to mock prove
# res = ezkl.mock(witness_path, compiled_model_path)
# assert res == True

In [None]:

# Set up the circuit: ezkl.setup writes the verification key (vk_path) and the
# proving key (pk_path) for the compiled circuit.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)

assert res == True
for artifact in (vk_path, pk_path, settings_path):
    assert os.path.isfile(artifact)

In [None]:
# Generate a proof from the witness, the compiled circuit and the proving key.
proof_path = os.path.join('test.pf')

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path)

print(res)
assert os.path.isfile(proof_path)

In [None]:
# Verify the proof against the settings file and the verification key.
res = ezkl.verify(proof_path, settings_path, vk_path)

assert res == True
print("verified")