# Scalability of *in situ* backpropagation

In this notebook, we explore the scalability of *in situ* backpropagation in terms of the tradeoff between noise, energy efficiency, and latency in photonic devices.
- For the scalability of the photonic advantage, we incorporate, as far as possible, the elements that contribute to the total energy consumption of the hybrid photonic neural network design, which is dominated by optoelectronic conversions and signal amplification. The assumptions behind this calculation are provided in the main text and/or Supplementary Material of the paper.
- For the scaling of noise and error, we explore the tradeoffs among various error sources (e.g., systematic errors in the photonic elements and random noise at the photodetector). We then perform large-scale simulations on MNIST data to show that realistic problems can be solved using our approach in the presence of error.
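As a rough illustration of the "random noise at the photodetector" mentioned above, the following minimal sketch adds zero-mean Gaussian noise to a power readout. The additive Gaussian model, the helper name `noisy_power_readout`, and the value of `sigma` are assumptions made for illustration only; the noise model used in the paper may differ.

```python
import numpy as np

def noisy_power_readout(fields, sigma, rng):
    """Return |field|^2 with additive Gaussian noise of standard deviation sigma.

    Hypothetical helper: additive Gaussian noise is an assumed model here.
    """
    power = np.abs(fields) ** 2
    return power + sigma * rng.standard_normal(power.shape)

rng = np.random.default_rng(0)
# Unit total optical power spread evenly over 64 modes.
fields = np.ones(64, dtype=np.complex64) / 8.0
clean = noisy_power_readout(fields, 0.0, rng)
noisy = noisy_power_readout(fields, 1e-3, rng)
print("total power (clean):", clean.sum())
print("max deviation due to noise:", np.abs(noisy - clean).max())
```

Sweeping `sigma` in a sketch like this is one simple way to see how readout noise perturbs the measured powers before they reach the electronic nonlinearity.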



In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
pip install simphox

Note: you may need to restart the kernel to use updated packages.


In [3]:
from simphox.circuit import rectangular
from simphox.utils import random_unitary
from jax import jit
import jax
import numpy as np
import jax.numpy as jnp
from mlflow import log_metric, log_param, log_artifacts
import wandb

import holoviews as hv
hv.extension('bokeh')

from dataclasses import dataclass

jax.config.update('jax_platform_name', 'cpu')  # run on CPU (slower!); change to 'gpu' if one is available
jax.config.update("jax_enable_x64", False)
jax.devices()

[CpuDevice(id=0)]

In [4]:
mesh = rectangular(random_unitary(64).astype(np.complex64)) # use this object to get the necessary functions
mesh_0 = mesh
mesh_1 = rectangular(random_unitary(64).astype(np.complex64))
mesh_2 = rectangular(random_unitary(64).astype(np.complex64))

In [8]:
from keras.datasets import mnist, fashion_mnist
from utils import MNISTDataProcessor

mnist_dp = MNISTDataProcessor(fashion = True)
data_64 = mnist_dp.fourier(4)

y train shape: (60000,)
raw train shape: (60000, 28, 28)
ft train shape: (60000, 28, 28)
after flattening: (60000, 64)
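The printed shapes suggest that each 28×28 image is Fourier transformed and then reduced to 64 features. A minimal NumPy sketch of one way to do this is shown below. It assumes that `fourier(4)` keeps the 8×8 block of lowest spatial frequencies around the zero-frequency component; the actual `MNISTDataProcessor` in `utils` may differ, and `fourier_features` is a hypothetical name.

```python
import numpy as np

def fourier_features(images, freq_radius):
    """Keep the (2r x 2r) block of lowest spatial frequencies of each image.

    Hypothetical stand-in for the notebook's preprocessing step.
    """
    # 2D FFT over the last two axes, then move zero frequency to the center.
    spectrum = np.fft.fftshift(np.fft.fft2(images), axes=(-2, -1))
    c = images.shape[-1] // 2  # index of the zero-frequency term after fftshift
    window = spectrum[..., c - freq_radius:c + freq_radius,
                      c - freq_radius:c + freq_radius]
    # Flatten each window into a complex feature vector.
    return window.reshape(images.shape[0], -1)

images = np.random.default_rng(0).random((5, 28, 28))
features = fourier_features(images, 4)
print(features.shape)
```

The resulting complex vectors have length 64, which matches the 64-mode meshes defined earlier in the notebook.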


In [9]:
print(data_64.x_train.shape)
print(data_64.y_train.shape)

(60000, 64)
(60000, 10)


In [10]:
print(data_64.y_train[1])

[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


# JIT function with benchmarking

In [11]:
# fn = jit(mesh.propagate_matrix_fn(use_jax=False, explicit=False))
prop_jit = jit(mesh.propagate_matrix_fn(use_jax=True, explicit=True))
prop = mesh.propagate_matrix_fn(use_jax=False, explicit=True)

In [12]:
from jax import grad, vjp

matrix_fn = mesh.matrix_fn(use_jax=True)
ones = jnp.ones(64, dtype=jnp.complex64)
# Scalar test objective: power in the top-left element of the mesh matrix.
tr = lambda u: jnp.abs(u[0, 0]) ** 2
fn = lambda params: tr(matrix_fn(params))
fn_jit = jit(fn)
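The cells above define jitted functions but do not show a timing measurement. The sketch below illustrates one common way to benchmark a jitted gradient in JAX. The 2×2 rotation in `first_element_power` is a toy stand-in for the mesh matrix (an assumption, not the `simphox` API). JAX dispatches work asynchronously, so `block_until_ready()` is needed for meaningful timings.

```python
import time
import jax
import jax.numpy as jnp

def first_element_power(theta):
    # Toy stand-in for |U[0, 0]|^2, with U a 2x2 rotation by angle theta.
    u = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])
    return jnp.abs(u[0, 0]) ** 2

grad_fn = jax.jit(jax.grad(first_element_power))
theta = jnp.asarray(0.3, dtype=jnp.float32)

# The first call triggers compilation, so exclude it from the timing.
grad_fn(theta).block_until_ready()

start = time.perf_counter()
for _ in range(100):
    out = grad_fn(theta)
out.block_until_ready()  # wait for asynchronous dispatch to finish
elapsed = time.perf_counter() - start
print(f"mean time per call: {elapsed / 100:.2e} s, gradient: {float(out):.4f}")
```

The gradient here is analytically $-\sin(2\theta)$, which gives a quick correctness check alongside the timing.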

## Pure JAX ONN

Implementation of an ONN in JAX using `simphox` and relevant benchmarking.

In [13]:
from jax.scipy.special import logsumexp
from jax import vmap, vjp

def dropout_softmax(outputs, num_classes: int = 10):
    # Keep only the first num_classes outputs (drop the rest), then apply log-softmax.
    return jax.nn.log_softmax(outputs[:num_classes]).astype(jnp.float32)

matrix_fn = jit(mesh.matrix_fn(use_jax=True))

def mnist_onn(params, inputs):
    # per-example predictions
    outputs = inputs
    for param in params:
        outputs = jnp.abs(matrix_fn(param, outputs))[:, 0] + 0j
    outputs = jnp.abs(outputs) ** 2
    return dropout_softmax(outputs, 10)

batched_mnist_onn = vmap(mnist_onn, in_axes=(None, 0))
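The batching pattern used above, `vmap` with `in_axes=(None, 0)`, shares the parameters across the batch and maps only over the leading axis of the inputs. The sketch below shows this on a toy network. The dense `weights` matrices and the names `toy_layer` and `toy_network` are illustrative assumptions and do not reproduce the `simphox` mesh propagation.

```python
import jax
import jax.numpy as jnp

def toy_layer(weights, x):
    # Linear transform followed by an absolute-value nonlinearity.
    return jnp.abs(weights @ x)

def toy_network(params, x):
    # Per-example forward pass: apply each layer, then measure output power.
    for weights in params:
        x = toy_layer(weights, x) + 0j
    return jnp.abs(x) ** 2

# params is broadcast (None); inputs are mapped over axis 0.
batched = jax.vmap(toy_network, in_axes=(None, 0))

params = [jnp.eye(4, dtype=jnp.complex64), 1j * jnp.eye(4, dtype=jnp.complex64)]
inputs = jnp.ones((3, 4), dtype=jnp.complex64)
outputs = batched(params, inputs)
print(outputs.shape)
```

Writing the network per example and batching with `vmap` keeps the forward pass free of explicit batch dimensions.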

In [14]:
@jit
def batch_loss(params, inputs, targets):
    preds = batched_mnist_onn(params, inputs)
    return -jnp.mean(preds.squeeze() * targets).real

@jit
def loss(params, inputs, targets):
    preds = mnist_onn(params, inputs)
    return -jnp.mean(preds.squeeze() * targets).real

print('Defining initial params...')
init_params = [mesh_0.params, mesh_1.params, mesh_2.params]

Defining initial params...
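Note that the loss above takes the mean over every entry of `preds * targets`, not just over examples. With one-hot targets, this equals the usual cross-entropy divided by the number of classes. A small NumPy check of this relation, using made-up logits with three classes:

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax along the last axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

logits = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
targets = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one-hot labels
log_probs = log_softmax(logits)

# Mean over all entries, as in batch_loss.
loss_as_written = -np.mean(log_probs * targets)
# Standard cross-entropy: sum over classes, mean over examples.
cross_entropy = -np.mean(np.sum(log_probs * targets, axis=-1))

num_classes = logits.shape[-1]
print(loss_as_written * num_classes, cross_entropy)
```

For the 10-class problem in this notebook, the reported loss values are therefore a factor of 10 smaller than the per-example cross-entropy.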


### Training 3-layer pure JAX ONN on MNIST

In [22]:
from jax.example_libraries.optimizers import adam
from tqdm.notebook import tqdm as pbar
from jax import value_and_grad

step_size = 0.001

num_iters = 1200
batch_size = 512
num_epochs = 50
x_train = jnp.array(data_64.x_train).astype(np.complex64)
y_train = jnp.array(data_64.y_train).astype(np.float32)

# idx = np.random.choice(60000, batch_size)
x_test = jnp.array(data_64.x_test).astype(np.complex64)
y_test = jnp.array(data_64.y_test).astype(np.float32)
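# Test loss; get_params and opt_state come from the optimizer/training cell below (In [26]),
# so this line only works once that cell has been run.
batch_loss(get_params(opt_state), x_test, y_test)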

Array(0.06834034, dtype=float32)

In [25]:
def evaluate(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_mnist_onn(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

In [26]:
opt_init, opt_update, get_params = adam(step_size=step_size)
init_params = [mesh_0.params, mesh_1.params, mesh_2.params]
init = opt_init(init_params)

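def update_fn(i, state):
    # Slice minibatch i; clamp the end index so a batch never wraps around the dataset.
    num_train = x_train.shape[0]
    start = (i * batch_size) % num_train
    stop = min(start + batch_size, num_train)
    v, g = value_and_grad(batch_loss)(get_params(state), x_train[start:stop], y_train[start:stop])
    return v, opt_update(i, g, state)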
    
losses = []

opt_state = init
for epoch in range(num_epochs):
    iterator = pbar(range(x_train.shape[0]//batch_size))
    for i in iterator:
        v, opt_state = update_fn(i, opt_state)
        losses.append(v)
        iterator.set_description(f"𝓛: {v:.5f}")
    print(f"Batch size {batch_size}, epoch {epoch+1}: test accuracy is {evaluate(get_params(opt_state), x_test, y_test).item()}")

Batch size 512, epoch 1: test accuracy is 0.7619999647140503
Batch size 512, epoch 2: test accuracy is 0.7849999666213989
Batch size 512, epoch 3: test accuracy is 0.7961999773979187
Batch size 512, epoch 4: test accuracy is 0.8047999739646912
Batch size 512, epoch 5: test accuracy is 0.8120999932289124
Batch size 512, epoch 6: test accuracy is 0.817799985408783
Batch size 512, epoch 7: test accuracy is 0.8211999535560608
Batch size 512, epoch 8: test accuracy is 0.8237999677658081
Batch size 512, epoch 9: test accuracy is 0.8246999979019165
Batch size 512, epoch 10: test accuracy is 0.8263999819755554
Batch size 512, epoch 11: test accuracy is 0.8283999562263489
Batch size 512, epoch 12: test accuracy is 0.8296999931335449
Batch size 512, epoch 13: test accuracy is 0.8313999772071838
Batch size 512, epoch 14: test accuracy is 0.8327999711036682
Batch size 512, epoch 15: test accuracy is 0.8341000080108643
Batch size 512, epoch 16: test accuracy is 0.8344999551773071
Batch size 512, epoch 17: test accuracy is 0.8355000019073486
Batch size 512, epoch 18: test accuracy is 0.8363999724388123
Batch size 512, epoch 19: test accuracy is 0.8364999890327454
Batch size 512, epoch 20: test accuracy is 0.8374999761581421
Batch size 512, epoch 21: test accuracy is 0.8384999632835388
Batch size 512, epoch 22: test accuracy is 0.8400999903678894
Batch size 512, epoch 23: test accuracy is 0.8409000039100647
Batch size 512, epoch 24: test accuracy is 0.8416999578475952
Batch size 512, epoch 25: test accuracy is 0.8427000045776367
Batch size 512, epoch 26: test accuracy is 0.8434999585151672
Batch size 512, epoch 27: test accuracy is 0.8445999622344971
Batch size 512, epoch 28: test accuracy is 0.8445999622344971
Batch size 512, epoch 29: test accuracy is 0.8448999524116516
Batch size 512, epoch 30: test accuracy is 0.8456999659538269
Batch size 512, epoch 31: test accuracy is 0.8459999561309814
Batch size 512, epoch 32: test accuracy is 0.8465999960899353
Batch size 512, epoch 33: test accuracy is 0.8468999862670898
Batch size 512, epoch 34: test accuracy is 0.8476999998092651
Batch size 512, epoch 35: test accuracy is 0.8481000065803528
Batch size 512, epoch 36: test accuracy is 0.8484999537467957
Batch size 512, epoch 37: test accuracy is 0.8489999771118164
Batch size 512, epoch 38: test accuracy is 0.8490999937057495
Batch size 512, epoch 39: test accuracy is 0.849399983882904
Batch size 512, epoch 40: test accuracy is 0.8502999544143677
Batch size 512, epoch 41: test accuracy is 0.8501999974250793
Batch size 512, epoch 42: test accuracy is 0.8499999642372131
Batch size 512, epoch 43: test accuracy is 0.8504999876022339
Batch size 512, epoch 44: test accuracy is 0.8504999876022339
Batch size 512, epoch 45: test accuracy is 0.851099967956543
Batch size 512, epoch 46: test accuracy is 0.8509999513626099
Batch size 512, epoch 47: test accuracy is 0.8513000011444092
Batch size 512, epoch 48: test accuracy is 0.8518999814987183
Batch size 512, epoch 49: test accuracy is 0.8515999913215637
Batch size 512, epoch 50: test accuracy is 0.8517999649047852
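The training loop runs 117 steps per epoch, which follows from integer division of the 60000 training examples by the batch size of 512. The short sketch below reproduces that minibatch arithmetic in plain Python. The helper `batch_bounds` is a hypothetical name, and clamping the end index with `min` is an assumption about the intended behavior at the dataset boundary.

```python
num_train, batch_size = 60000, 512
steps_per_epoch = num_train // batch_size

def batch_bounds(i):
    # Start and (clamped) end index of minibatch i.
    start = (i * batch_size) % num_train
    stop = min(start + batch_size, num_train)
    return start, stop

bounds = [batch_bounds(i) for i in range(steps_per_epoch)]
samples_seen = sum(stop - start for start, stop in bounds)
print(steps_per_epoch, samples_seen, num_train - samples_seen)
```

Under these settings, the last 96 training examples are not visited in any epoch, since the data are not reshuffled between epochs.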
