In [1]:
# imports
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as jopt
import equinox as eqx
import mnist
import numpy.random as npr
import math
import datetime

jax.config.update("jax_debug_nans", True)
jnp.set_printoptions(suppress=True, precision=2, floatmode='fixed')
flt = jnp.float32
assert jax.devices()[0].device_kind == 'NVIDIA GeForce RTX 3060'

In [2]:
# data source
order = jnp.array(range(784))
# np.random.RandomState(0).shuffle(order)

train_images = mnist.train_images().reshape((-1, 784))[:, order] / 255
train_labels = mnist.train_labels()
train_labels_hot = jax.nn.one_hot(train_labels, 10)

test_img = mnist.test_images().reshape((-1, 784))[:, order] / 255
test_lbl = mnist.test_labels()

# specific processing for this model
train_images = jnp.expand_dims(train_images, axis=2)
test_img = jnp.expand_dims(test_img, axis=2)

In [74]:
# DLN

# B: batch size
# L: input length
# D: input dimensions
# x(t) = x(t-1) * e(-abs(size) + i * theta) + y(t)
class DLN(eqx.Module):
  size: jax.Array # D
  theta: jax.Array # D
  linear: eqx.nn.Linear # D -> D
  D: int = eqx.static_field()

  def __init__(self, key: jax.random.KeyArray, D: int):
    size_key, theta_key = jax.random.split(key)
    self.size = jax.random.normal(size_key, [D])
    self.theta = jax.random.normal(theta_key, [D]) * jnp.pi
    self.linear = eqx.nn.Linear(D, D, use_bias=False, key=key)
    self.D = D

  def __call__(self, y):  # [B, L, D]
    return jax.vmap(self.batchless)(y)

  def batchless(self, y):  # [L, D]
    return jax.vmap(self.dimensionless, (1, 0, 0))(y, self.size, self.theta)

  def dimensionless(self, y, size, theta):  # [L],
    # theta adjustment for size ~ 0
    z = jnp.exp(-jnp.abs(size)) # + 1j * theta
    def combine(a, b):
      pa, va = a
      pb, vb = b
      return jnp.stack([pa + pb, va * jnp.exp(z * pb) + vb])
    y = jnp.stack([jnp.ones(y.shape), y])  # [2, L]
    y = jnp.take(jax.lax.associative_scan(combine, y, axis=1), 1, axis=0)  # [L]
    # y = self.linear(y)
    return y

a = DLN(jax.random.PRNGKey(42), 1)
b = a(jnp.array([[[1], [2], [3], [2], [1]], [[2], [1], [1], [1], [1]]]))
#b = a(jnp.array([[[1], [2], [3], [4], [1]]]))
print(b)


[[[ 1.00  4.39 13.49 34.24 82.82]]

 [[ 2.00  5.78 14.81 36.40 87.98]]]
