In [1]:
import jax
import jax.numpy as np
import functools



In [2]:
key = jax.random.PRNGKey(11)



## Data loading

In [114]:
def load_x(filename, debug=True):
    # returns full data array: (N, T=var, F=12)
    with open(filename, 'r') as f:
        blocks = f.read().strip().split('\n\n')
    x = [
        np.array([list(map(float, e.split())) for e in block.split('\n')])
        for block in blocks
    ]
    lens = [len(e) for e in x]
    if debug:
        print(f'shape: (N={len(x)}, T={min(lens)}-{max(lens)}, F={len(x[0][0])})')
    return x

In [115]:
def load_y_flat(sizes, debug=True):
    # returns flat output: (N, )
    sizes = list(map(int, sizes.split()))
    y = []
    for idx, size in enumerate(sizes):
        y += [idx+1] * size
    if debug:
        print(len(y))
    return y

In [159]:
def load_y(sizes, debug=True):
    # returns one-hot encoded, but without time dimension: (N, C=9)
    sizes = list(map(int, sizes.split()))
    y = []
    for idx, size in enumerate(sizes):
        y += np.array([[1. if i == idx else 0. for i in range(9)]] * size)
    if debug:
        print(f'({len(y)}, {len(y[0])})', end=' ')
    return y

In [191]:
def inflate_time_dimension(y_train, time_dims, debug=True):
    # inflate single vector along given time dimension
    y = []
    for idx in range(len(y_train)):
        y.append(
            np.array([y_train[idx] for _ in range(time_dims[idx])])
        )
    if debug:
        lens = [len(e) for e in y]
        print(f'(N={len(y)}, T={min(lens)}-{max(lens)}, C={len(y[0][0])})')
    return y

In [192]:
print('train ', end='')
x_train = load_x('data/ae.train')
print('test  ', end='')
x_test = load_x('data/ae.test')

train shape: (N=270, T=7-26, F=12)
test  shape: (N=370, T=7-29, F=12)


In [194]:
print('train ', end='')
y_train = load_y('30 30 30 30 30 30 30 30 30')
y_train = inflate_time_dimension(y_train, time_dims=[e.shape[0] for e in x_train])
print('test  ', end='')
y_test = load_y('31 35 88 44 29 24 40 50 29')
y_test = inflate_time_dimension(y_test, time_dims=[len(e) for e in x_test])

train (270, 9) (N=270, T=7-26, C=9)
test  (370, 9) (N=370, T=7-29, C=9)


## Preprocessing 

***TODO***

# Reservoir

In [198]:
N_res = 100
N_inp = 12
N_out = 9
rhoW_target = 1.25

## Reservoir with JAX

In [199]:
Win = (jax.random.uniform(key, (N_res, 1+N_inp)) - 0.5) * 1.
W = (jax.random.uniform(key, (N_res, N_res)) - 0.5) * 1.
rhoW = np.max(np.absolute(np.linalg.eig(W)[0]))
W *= rhoW_target / rhoW
Wout = (jax.random.uniform(key, (N_out, 1+N_inp+N_res)) - 0.5) * 1.
print(f'Win: {Win.shape}, W: {W.shape}, Wout: {Wout.shape}')

Win: (100, 13), W: (100, 100), Wout: (9, 113)


In [200]:
def forward(u, Win, W, Wout=None, x_init=np.zeros((N_res,))):
    # u: (T, F)
    T = u.shape[0]
    F = u.shape[1]
    assert F == N_inp, 'input shape mismatch'
    X, Y = [], []
    # X: (1+F+N_res, T)
    # Y: (N_out, T)
    x = x_init.copy()
    for t in range(u.shape[0]):
        x = np.tanh(
            np.dot(Win, np.concatenate((np.ones(1,), u[t]))) + np.dot(W, x)
        )
        full_state = np.concatenate((np.ones(1,), u[t], x))
        X.append(full_state)
        if Wout is not None:
            y = np.dot(Wout, full_state)
            Y.append(y)
        # generative mode:
        # u = y
        # predictive mode:
        # u = data[trainLen+t+1]
    # when returning, need to transpose the data arrays, such that dimensions: (T, x)
    if Wout is None:
        return x, np.array(X).T
    else:
        return x, np.array(X).T, np.array(Y).T

In [201]:
x, X = forward(x_train[0], Win, W)
print(f'u: {x_train[0].shape}, x: {x.shape}, X: {X.shape}')

u: (20, 12), x: (100,), X: (113, 20)


In [202]:
x_train[0].shape, y_train[0].shape

((20, 12), (20, 9))

### Reservoir training

TODO: confirm equations are correct

In [204]:
reg = 1e-8
Wout_rc = np.linalg.solve(
    np.dot(X, X.T) + reg * np.eye(1+N_inp+N_res),
    np.dot(X, y_train[0])
)

In [208]:
Wout_rc_ = np.dot(
    np.dot(y_train[0].T, X.T),
    np.linalg.inv(
        np.dot(X, X.T) + reg*np.eye(1+N_inp+N_res)
    )
)

### Backprop training

The current `forward` function will not work with JAX, we need to use JAX primitives to implement it (rather than iterating over lists). The `forward_bp` function uses the `jax.lax.scan` method to apply the forward pass over a list of inputs, while carrying over the network's state from one time step to the next (see the API reference [here](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan)).

We then apply a softmax to the output layer and compute the cross-entropy loss. 

JAX does gradient descent automatically using the `grad` function. With that, we can define our gradient update step.

TODO: implement batching (handle variable sequence lengths)

In [242]:
def softmax(x):
    # normalized softmax
    x_norm = x - np.max(x)
    x_exp = np.exp(x_norm)
    return x_exp / np.sum(x_exp)

In [243]:
def forward_bp(params, u, x_init=np.zeros((N_res,))):
    """ Loop over the time steps of the input sequence
    u:      (time, features)
    x_init: (n_res, )
    """
    Win, W, Wout = params
    x = x_init.copy()

    def apply_fun_scan(params, x, ut):
        """ Perform single step update of the network.
        x:  (n_res, )
        ut: (features, )
        """
        Win, W, Wout = params
        x = np.tanh(
            np.dot(Win, np.concatenate((np.ones(1,), ut))) + np.dot(W, x)
        )
        y = softmax(np.dot(
            Wout,
            np.concatenate((np.ones(1,), ut, x))
        ))
        return x, y

    f = functools.partial(apply_fun_scan, params)
    _, Y = jax.lax.scan(f, x, u)
    return Y

In [247]:
Y = forward_bp((Win, W, Wout), x_train[0])
Y.shape

(20, 9)

In [248]:
def loss(params, u, y_true):
    # cross entropy loss (see Bishop's Pattern Recognition book, page 209).
    y_pred = forward_bp(params, u)
    return -np.sum(np.sum(y_true * np.log(y_pred), axis=1)) / x.shape[0]

In [249]:
loss((Win, W, Wout), x_train[0], y_train[0])

DeviceArray(1.3605596, dtype=float32)

In [267]:
dWin, dW, dWout = jax.grad(loss)((Win, W, Wout), x_train[0], y_train[0])

In [270]:
dW.shape, dW.mean(), dW.std()

((100, 100),
 DeviceArray(-0.0001335, dtype=float32),
 DeviceArray(0.05244131, dtype=float32))

In [271]:
@jax.jit
def update(params, x, y_true, step_size=1e-2):
    grads = jax.grad(loss)(params, x, y_true)
    return [
        w - step_size * dw
        for w, dw in zip(params, grads)
    ]

In [272]:
Win_, W_, Wout_ = update((Win, W, Wout), x_train[0], y_train[0])

## Reservoir with TF

TODO: make this work for ragged tensors.

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa

In [None]:
esn_keras = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[N_inp], ragged=True),
    tfa.layers.ESN(units=N_res, connectivity=0.1, leaky=1.0, spectral_radius=rhoW_target),
    tf.keras.layers.Dense(9, activation='softmax'),
])
esn_keras.compile(
    optimizer=tf.keras.optimizers.SGD(),
    loss='categorical_crossentropy',
    metrics=['mse', 'accuracy']
)

In [None]:
x = tf.ragged.constant([[np.array(a) for a in e] for e in x_train])
x = x.to_tensor(default_value=0., shape=[None, N_inp])
y = tf.ragged.constant([e.tolist() for e in y_train])

In [None]:
history = esn_keras.fit(
    x, y,
    batch_size=64,
    epochs=20
)