In [1]:
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as jopt
import haiku as hk
import mnist
import numpy as np
import math
import datetime
import dataclasses
import optax

from typing import Optional, NamedTuple

jax.config.update("jax_debug_nans", True)
jnp.set_printoptions(suppress=True, precision=2, floatmode='fixed')
flt = jnp.float32
assert jax.devices()[0].device_kind == 'NVIDIA GeForce RTX 3060'

In [2]:
order = np.array(range(784))
# np.random.RandomState(0).shuffle(order)

train_images = mnist.train_images().reshape((-1, 784))[:, order] / 255
train_labels = mnist.train_labels()
train_labels_hot = jax.nn.one_hot(train_labels, 10)

test_img = mnist.test_images().reshape((-1, 784))[:, order] / 255
test_lbl = mnist.test_labels()

# specific processing for this model
train_images = jnp.expand_dims(train_images, axis=2)
test_img = jnp.expand_dims(test_img, axis=2)

In [3]:
@dataclasses.dataclass
class DLN(hk.Module):
    head_len: int
    input_dim: int
    seq_len: int
    name: Optional[str] = None

    def __call__(
        self,
        x # [B, seq_len, input_dim]
    ):
        k = self.head_len * self.input_dim
        size = hk.get_parameter('size', shape=[k], dtype=flt, init=hk.initializers.RandomNormal())
        theta = hk.get_parameter('theta', shape=[k], dtype=flt, init=hk.initializers.RandomNormal(jnp.pi))
        weights = hk.get_parameter('weights', shape=[k,k], dtype=flt, init=hk.initializers.TruncatedNormal(stddev=1./jnp.sqrt(k)))
        linear = jnp.arange(0, self.seq_len).reshape(-1, 1)  # [seq_len, 1]
        diags = jnp.exp((-jnp.square(size) + 1j * theta) * linear)  # (seq_len, head_len * input_dim)
        x = diags * jnp.tile(x, (1, 1, self.head_len))  # [B, seq_len, head_len * input_dim]
        x = jnp.sum(x, axis=1)  # [B, head_len * input_dim]
        x = (x @ weights)  # [B, head_len * input_dim]
        x = jnp.real(x) * jnp.imag(x)  # [B, head_len * input_dim]
        return x

@hk.transform
def model(x):
    x = DLN(head_len=800, input_dim=1, seq_len=784, name='dln')(x)  # [B, head_len * input_dim]
    x = hk.dropout(hk.next_rng_key(), 0.2, x)
    x = hk.Linear(output_size=10, with_bias=True)(x)  # [B, 10]
    x = jax.nn.softmax(x)
    return x

def loss(params: hk.Params, rnd, inputs, outputs):
    guess = model.apply(params, rnd, inputs)
    return -jnp.mean(jnp.log(jnp.sum(guess * outputs, axis=1)))

a = model.init(x=train_images[0:2,:], rng=jax.random.PRNGKey(0))

In [None]:
class State(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState

@jax.jit
def update(state: State, rnd, inputs, outputs) -> State:
    grads = jax.grad(loss)(state.params, rnd, inputs, outputs)
    updates, opt_state = optimizer.update(grads, state.opt_state)
    return State(optax.apply_updates(state.params, updates), opt_state)

def init(optimizer) -> State:
    sample = train_images[0:5,:]
    init_params = model.init(jax.random.PRNGKey(0), sample)
    return State(init_params, optimizer.init(init_params))

optimizer = optax.adam(1e-4)
state = init(optimizer)
t = datetime.datetime.now()
steps = 100000
batch_size = (30,)
rnd = jax.random.split(jax.random.PRNGKey(0), steps)

try:
    for i in range(0, steps):
        inputs = jax.random.choice(rnd[i], train_images, batch_size)
        outputs = jax.random.choice(rnd[i], train_labels_hot, batch_size)
        state = update(state, rnd[i], inputs, outputs)
        if i == 0 or (datetime.datetime.now() - t).total_seconds() > 5.0:
            accuracy = jnp.mean(jnp.argmax(model.apply(state.params, rnd[i], train_images[0:100,:]), axis=1) == train_labels[0:100])
            test = jnp.mean(jnp.argmax(model.apply(state.params, rnd[i], test_img[0:100,:]), axis=1) == test_lbl[0:100])
            print(f'{i:03d} {accuracy:0.3f} / {test:0.3f}')
            t = datetime.datetime.now()
except KeyboardInterrupt:
    print('interrupt')

accuracy = jnp.mean(jnp.argmax(model.apply(state.params, rnd[i], train_images[0:500,:]), axis=1) == train_labels[0:500])
test = jnp.mean(jnp.argmax(model.apply(state.params, rnd[i], test_img[0:500,:]), axis=1) == test_lbl[0:500])
print(f'{i:03d} {accuracy:0.3f} / {test:0.3f} (done)')

000 0.220 / 0.170
955 0.890 / 0.850
1937 0.930 / 0.970
interrupt
2098 0.908 / 0.874 (done)


In [None]:
state.params