In [1]:
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as jopt
import haiku as hk
import mnist
import numpy as np
import math
import datetime
import dataclasses
import optax

from typing import Optional, NamedTuple

# Turn on JAX's "jax_debug_nans" flag for the whole session.
jax.config.update("jax_debug_nans", True)
# Array printing: suppress=True, two digits of precision, 'fixed' float mode.
jnp.set_printoptions(suppress=True, precision=2, floatmode='fixed')
# Short alias for the dtype used when declaring model parameters.
flt = jnp.float32
# NOTE(review): this pins the notebook to one exact GPU model and raises
# AssertionError on any other device — confirm this is intentional before sharing.
assert jax.devices()[0].device_kind == 'NVIDIA GeForce RTX 3060'

In [2]:
# Fixed (seed 0) permutation of the 784 pixel positions, shared by train and test.
order = np.arange(784)
np.random.RandomState(0).shuffle(order)


def _prepare_images(raw_images):
    """Flatten to 784 columns, apply the `order` permutation, divide by 255 and
    append a trailing axis of size 1 (the per-step feature axis the model expects).

    Assumes `raw_images` can be reshaped to (-1, 784) — TODO confirm against the
    `mnist` package's return shape.
    """
    flat = raw_images.reshape((-1, 784))
    scaled = flat[:, order] / 255
    # specific processing for this model
    return jnp.expand_dims(scaled, axis=2)


train_images = _prepare_images(mnist.train_images())
train_labels = mnist.train_labels()
train_labels_hot = jax.nn.one_hot(train_labels, 10)

test_img = _prepare_images(mnist.test_images())
test_lbl = mnist.test_labels()

In [3]:
@dataclasses.dataclass
class DLN(hk.Module):
    """Linear recurrence over the sequence axis with a learned complex decay.

    Each of the `head_len * input_dim` channels accumulates
    `h_t = h_{t-1} * exp(z) + x_t` with `z = -size**2 + 1j * theta`, evaluated
    for all time steps at once through `jax.lax.associative_scan`. The scanned
    states are mixed by a square weight matrix and reduced to real values as
    `real * imag`.
    """
    # Number of copies of the input features tiled along the channel axis.
    head_len: int
    # Size of the last axis of the input.
    input_dim: int
    # Declared sequence length; not read inside `__call__`.
    seq_len: int
    name: Optional[str] = None

    def __call__(
        self,
        x # [B, seq_len, input_dim]
    ):
        """Apply the recurrence; returns an array of shape [B, seq_len, head_len * input_dim]."""
        width = self.head_len * self.input_dim
        # Parameter names and creation order are kept as-is: they define the
        # parameter tree and the order in which init keys are consumed.
        size = hk.get_parameter('size', shape=[width], dtype=flt, init=hk.initializers.RandomNormal())
        theta = hk.get_parameter('theta', shape=[width], dtype=flt, init=hk.initializers.RandomNormal(jnp.pi))
        weights = hk.get_parameter('weights', shape=[width, width], dtype=flt, init=hk.initializers.TruncatedNormal(stddev=1./jnp.sqrt(width)))
        # Real part is -size**2 (never positive); imaginary part is the rotation angle.
        log_decay = -jnp.square(size) + 1j * theta

        def merge(left, right):
            # Elements carry (span length, accumulated value) on the last axis.
            # Joining two adjacent spans decays the left value by the length of
            # the right span before adding the right value.
            right_len = right[..., 0]
            total_len = left[..., 0] + right_len
            carried = left[..., 1] * jnp.exp(log_decay * right_len) + right[..., 1]
            return jnp.stack([total_len, carried], axis=-1)

        tiled = jnp.tile(x, (1, 1, self.head_len))  # [B, seq_len, head_len * input_dim]
        # Every time step starts as a span of length one holding its own input.
        elems = jnp.stack([jnp.ones(tiled.shape), tiled], axis=-1, dtype=complex)  # [B, seq_len, head_len * input_dim, 2]
        states = jax.lax.associative_scan(merge, elems, axis=1)[..., 1]  # [B, seq_len, head_len * input_dim]
        mixed = states @ weights  # [B, seq_len, head_len * input_dim]
        return jnp.real(mixed) * jnp.imag(mixed)  # [B, seq_len, head_len * input_dim]

@hk.transform
def model(x):
    """Classifier: one root DLN layer, three deeper DLN layers, then a 10-way softmax.

    Every DLN layer is followed by a residual `hk.Linear` and a ReLU. Only the
    last time step feeds the output layer. Returns probabilities of shape [B, 10].
    """
    hidden = 20
    # Layers are listed in creation order, which fixes the Haiku module names.
    layer_specs = [dict(head_len=hidden, input_dim=1, name='dln_root')]
    layer_specs += [dict(head_len=1, input_dim=hidden, name=f'dln_{depth}') for depth in range(3)]
    for spec in layer_specs:
        x = DLN(seq_len=784, **spec)(x)  # [B, seq_len, hidden]
        x = jax.nn.relu(hk.Linear(output_size=hidden)(x) + x)  # [B, seq_len, hidden]
    # x = DLN(head_len=10, input_dim=middle_len, seq_len=784, name='dln')(x)  # [B, seq_len, head_len * input_dim]
    # x = hk.dropout(hk.next_rng_key(), 0.2, x)
    last_step = x[:, -1, :]  # [B, hidden]
    logits = hk.Linear(output_size=10, with_bias=True)(last_step)  # [B, 10]
    return jax.nn.softmax(logits)

def loss(params: hk.Params, rnd, inputs, outputs):
    """Mean negative log of the probability assigned to the target class.

    `outputs` multiplies the model's probabilities element-wise and the product
    is summed over axis 1; 0.001 is added inside the log before averaging.
    """
    probs = model.apply(params, rnd, inputs)
    target_prob = jnp.sum(probs * outputs, axis=1)
    return -jnp.mean(jnp.log(0.001+target_prob))

# Initialise the parameters once on the first five training images;
# `a` is not read by any other code visible in this notebook.
a = model.init(jax.random.PRNGKey(0), train_images[0:5,:])

In [5]:
class State(NamedTuple):
    """Training state passed into and returned from `update`."""
    # Model parameters (produced by `model.init` in `init`).
    params: hk.Params
    # Optimizer state that accompanies `params`.
    opt_state: optax.OptState
    # Loss from the most recent step; `init` seeds it with the placeholder -1.0.
    loss_value: float

@jax.jit
def update(state: State, rnd, inputs, outputs) -> State:
    """Run one optimisation step and return the resulting `State`.

    Differentiates `loss` with respect to `state.params`, feeds the gradients to
    the module-level `optimizer`, and applies the resulting updates. The loss
    computed before the update is stored in the returned state.
    """
    loss_and_grad = jax.value_and_grad(loss)
    loss_value, grads = loss_and_grad(state.params, rnd, inputs, outputs)
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return State(params=new_params, opt_state=new_opt_state, loss_value=loss_value)

def init(optimizer) -> State:
    """Create the initial `State`: fresh parameters, matching optimizer state, loss -1.0.

    Parameters are initialised with `PRNGKey(0)` on the first five training images.
    """
    init_params = model.init(jax.random.PRNGKey(0), train_images[0:5,:])
    return State(
        params=init_params,
        opt_state=optimizer.init(init_params),
        loss_value=-1.0,
    )

optimizer = optax.adam(1e-3)
state = init(optimizer)
t = u = datetime.datetime.now()
steps = 100000
batch_size = (30,)
# One PRNG key per training step.
rnd = jax.random.split(jax.random.PRNGKey(0), steps)


def _accuracy(params, key, images, labels):
    """Fraction of `images` whose arg-max model output equals `labels`."""
    predicted = jnp.argmax(model.apply(params, key, images), axis=1)
    return jnp.mean(predicted == labels)


# Bound up front so the interrupt handler can always format it.
i = 0
try:
    for i in range(0, steps):
        # Draw the batch indices once and use them for both arrays, so that each
        # image is paired with its own label by construction rather than by two
        # separate sampling calls happening to agree.
        batch_idx = jax.random.choice(rnd[i], train_images.shape[0], batch_size)
        inputs = train_images[batch_idx]
        outputs = train_labels_hot[batch_idx]
        state = update(state, rnd[i], inputs, outputs)
        # Progress line on the first step and then at most every 5 seconds;
        # accuracy on 50 train / 50 test samples at most every 20 seconds.
        if i == 0 or (datetime.datetime.now() - t).total_seconds() > 5.0:
            if i == 0 or (datetime.datetime.now() - u).total_seconds() > 20.0:
                print(f'{i:04d} {state.loss_value:0.2f} ', end='')
                accuracy = _accuracy(state.params, rnd[i], train_images[0:50,:], train_labels[0:50])
                test = _accuracy(state.params, rnd[i], test_img[0:50,:], test_lbl[0:50])
                print(f'{accuracy:0.2f} / {test:0.2f}')
                u = datetime.datetime.now()
            else:
                print(f'{i:04d} {state.loss_value:0.2f}')
            t = datetime.datetime.now()

except KeyboardInterrupt:
    # Manual stop: fall through to the final report with the current parameters.
    print(f'interrupt {i:03d} ', end='')

# Final report on the first 200 training and test samples.
accuracy = _accuracy(state.params, rnd[i], train_images[0:200,:], train_labels[0:200])
test = _accuracy(state.params, rnd[i], test_img[0:200,:], test_lbl[0:200])
print(f'{accuracy:0.3f} / {test:0.3f} (done)')

0000 6.04 0.16 / 0.14
0001 5.16
0401 2.58
0808 2.29
1215 2.26 0.22 / 0.26
1622 2.14
2024 2.23
2425 2.17
2831 2.18 0.22 / 0.20
3238 2.29
3645 1.93
4043 1.95
4448 2.00 0.36 / 0.26
4855 1.84
5262 1.84
5669 1.78
6075 1.65 0.56 / 0.40
6482 1.45
6889 1.44
7296 0.89
7703 0.97 0.60 / 0.58
8110 0.82
8517 0.45
8924 0.87
9330 0.71 0.70 / 0.60
9732 0.72
interrupt 9773 0.755 / 0.685 (done)
