In [None]:
%load_ext autoreload
%matplotlib widget
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import tensorflow as tf
from keras.datasets import mnist
from tqdm.autonotebook import tqdm, trange

%autoreload 1
%aimport jsi

gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# training dataset generation
# Number of random parameter sets to simulate for the training set.
data_len = 10_000

# Build a tf.train.Example (serialized later with SerializeToString before writing)
def create_example(data, label):
    """Wrap one (data, label) pair in a ``tf.train.Example``.

    Args:
        data: array-like with a ``flatten()`` method; stored flattened as the
            float-list feature ``"data"``.
        label: 1-D iterable of numbers, stored as the float-list feature ``"label"``.

    Returns:
        A ``tf.train.Example``; call ``SerializeToString()`` on it before writing.
    """
    def _floats(values):
        # Every feature in this dataset is a flat list of floats.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    feature_map = {"data": _floats(data.flatten()), "label": _floats(label)}
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

# Write examples to a TFRecord file.
# Each record pairs a normalised |output| image ("data") with the random ring
# parameters that produced it ("label").
filename = "my_dataset.tfrecord"
rng = np.random.default_rng(0)  # seed the parameter draws so the dataset is reproducible
max_attempts = 10 * data_len    # stop instead of looping forever if simulations keep failing

n_written = 0
n_skipped = 0
with tf.io.TFRecordWriter(filename) as writer, tqdm(total=data_len) as pbar:
    # Draw until data_len *valid* examples are written. A `for` over
    # range(data_len) cannot do this: the range is fixed when the loop starts,
    # so skipped samples would simply be lost.
    while n_written < data_len:
        if n_written + n_skipped >= max_attempts:
            raise RuntimeError(
                f"only {n_written}/{data_len} valid samples after {max_attempts} attempts"
            )
        js = rng.uniform(low=0.0, high=1, size=(3, 5))
        phis = rng.uniform(low=0.0, high=2 * np.pi, size=(3, 5))
        loss = rng.uniform(low=0.0, high=0.5, size=(2, 3, 28 * 2))
        g = rng.uniform(low=0.0, high=0.1)
        sim_params = {
            'nodes': 28,
            'padding': 0,
            'n_rings': 3,
            'orth_itr': 10,
            'length': 5,
            'js': np.array(js * np.exp(1j * phis), dtype=np.complex64),
            'jr': np.array(loss[0], dtype=np.complex64),
            'g': np.array([g], dtype=np.complex64),
            'y0s': np.array(loss[1], dtype=np.complex64),
        }
        _, _, output = jsi.jsi_backprop(sim_params, train=False)
        # Skip failed simulations (NaN/inf values) ...
        if not np.all(np.isfinite(output)):
            n_skipped += 1
            continue
        peak = np.max(np.abs(output))
        # ... and all-zero outputs, which would become NaN when normalised.
        if peak == 0:
            n_skipped += 1
            continue
        output_norm = np.abs(output) / peak
        # Label layout (must match feature_description in the reader cell):
        #   js (3*5) | phis (3*5) | loss (2*3*56, loss[0] then loss[1]) | g (1)
        label_vec = np.concatenate((js.flatten(), phis.flatten(), loss.flatten(), [g]), axis=0)
        example = create_example(output_norm, label_vec)
        writer.write(example.SerializeToString())
        n_written += 1
        pbar.update(1)

print(f"wrote {n_written} examples to {filename} ({n_skipped} invalid samples skipped)")


In [None]:
# Read from TFRecord file
# Create a dataset from the TFRecords file
dataset = tf.data.TFRecordDataset(["my_dataset.tfrecord"])

# Fixed-length schema; must match what the generation cell wrote:
#   data  : 28*28 normalised |output| values
#   label : js (3*5) + phis (3*5) + loss (2*3*28*2) + g (1) = 367 values
DATA_SIZE = 28 * 28
LABEL_SIZE = 3 * 5 * 2 + 2 * 3 * 28 * 2 + 1
feature_description = dict(
    data=tf.io.FixedLenFeature([DATA_SIZE], tf.float32),
    label=tf.io.FixedLenFeature([LABEL_SIZE], tf.float32),
)

# Parse the serialized example
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into a dict of dense tensors.

    Uses the notebook-level ``feature_description`` schema, so the result has
    float32 ``'data'`` and ``'label'`` entries.
    """
    parsed = tf.io.parse_single_example(serialized=example_proto, features=feature_description)
    return parsed

# Apply the parsing function to each example
parsed_dataset = dataset.map(_parse_function)

# Sanity-check the file: show the feature lengths of the first record and count
# all records, instead of printing two lines per example (~20k lines of output).
# Eager iteration ends normally via StopIteration; a malformed record surfaces
# as a tf error here, which is what we want.
n_examples = 0
for example in parsed_dataset:
    if n_examples == 0:
        print("data: ", len(example['data']))
        print("label: ", len(example['label']))
    n_examples += 1
print("Finished reading data: ", n_examples, "examples")

In [None]:
# training using cnn
# Fixed ring-structure settings passed to the CNN trainer; named conv_params
# rather than `input` so the builtin input() is not shadowed for the session.
conv_params = {
    'nodes': 28,
    'padding': 0,
    'n_rings': 3,
    'orth_itr': 10,
    'length': 5,
}
model = jsi.jsi_conv(conv_params, "my_dataset.tfrecord", epochs=10)

In [None]:
# verify against mnist data
# Predict ring parameters for one MNIST test digit, re-simulate from them, and
# plot the digit next to the reconstruction.
index = 10  # which MNIST test image to use

(train_X, train_y), (test_X, test_y) = mnist.load_data()

# Scale pixels to [0, 1] (training targets were |output| normalised to peak 1).
mnist_target = test_X[index] / 255
# Only one image is needed: predict on a batch of one, not the whole test set.
pred = model.predict(mnist_target[np.newaxis])
params_vec = pred[0]

# Decode the prediction using the training-label layout:
#   js (0:15) | phis (15:30) | jr = loss[0] (30:198) | y0s = loss[1] (198:366) | g (366)
pred_params = {
    'nodes': 28,
    'padding': 0,
    'n_rings': 3,
    'orth_itr': 10,
    'length': 5,
    'js': np.array(params_vec[0:15].reshape((3, 5)) * np.exp(1j * params_vec[15:30].reshape((3, 5))), dtype=np.complex64),
    'jr': np.array(params_vec[30:198].reshape((3, 56)), dtype=np.complex64),
    'g': np.array([params_vec[-1]], dtype=np.complex64),
    'y0s': np.array(params_vec[198:366].reshape((3, 56)), dtype=np.complex64),
}
print(pred_params['js'])
_, _, output = jsi.jsi_backprop(pred_params, train=False)
# Close the figure from a previous run of a plotting cell, if any (NameError on
# the first run only). NOTE(review): assumes jsi.pltCtst returns a matplotlib Figure.
try:
    plt.close(fig)
except NameError:
    pass
fig = jsi.pltCtst(mnist_target, output, 0, 0, 28, 28)

In [None]:
# verify against a synthetic target: simulate from random ring parameters (same
# ranges as the training set), predict parameters with the CNN, re-simulate,
# and plot target vs. reconstruction.
base_params = {
    'nodes': 28,
    'padding': 0,
    'n_rings': 3,
    'orth_itr': 10,
    'length': 5,
}

# The simulation can return NaNs (the generation cell skips such samples), so
# redraw until the target is finite and non-zero.
max_attempts = 100
for _ in range(max_attempts):
    js = np.random.uniform(low=0.0, high=1, size=(3, 5))
    phis = np.random.uniform(low=0.0, high=2*np.pi, size=(3, 5))
    loss = np.random.uniform(low=0.0, high=0.5, size=(2, 3, 28*2))
    g = np.random.uniform(low=0.0, high=0.1)
    # Constrained alternative: uncomment for a uniform loss of 0.1 and g = 0.03.
    # loss = np.ones((2, 3, 28*2)) / 10
    # g = 0.03
    target_params = {
        **base_params,
        'js': np.array(js * np.exp(1j * phis), dtype=np.complex64),
        'jr': np.array(loss[0], dtype=np.complex64),
        'g': np.array([g], dtype=np.complex64),
        'y0s': np.array(loss[1], dtype=np.complex64),
    }
    _, _, output = jsi.jsi_backprop(target_params, train=False)
    if np.all(np.isfinite(output)) and np.max(np.abs(output)) > 0:
        break
else:
    raise RuntimeError(f"no finite, non-zero target after {max_attempts} attempts")
target = np.abs(output) / np.max(np.abs(output))

# Predict parameters for the single target (batch of one).
pred = model.predict(np.array([target]))
params_vec = pred[0]

# Decode the prediction using the training-label layout:
#   js (0:15) | phis (15:30) | jr = loss[0] (30:198) | y0s = loss[1] (198:366) | g (366)
pred_params = {
    **base_params,
    'js': np.array(params_vec[0:15].reshape((3, 5)) * np.exp(1j * params_vec[15:30].reshape((3, 5))), dtype=np.complex64),
    'jr': np.array(params_vec[30:198].reshape((3, 56)), dtype=np.complex64),
    'g': np.array([params_vec[-1]], dtype=np.complex64),
    'y0s': np.array(params_vec[198:366].reshape((3, 56)), dtype=np.complex64),
}
print(pred_params['js'])
_, _, output = jsi.jsi_backprop(pred_params, train=False)
# Close the figure from a previous run of a plotting cell, if any (NameError on
# the first run only). NOTE(review): assumes jsi.pltCtst returns a matplotlib Figure.
try:
    plt.close(fig)
except NameError:
    pass
fig = jsi.pltCtst(target, output, 0, 0, 28, 28)