In [1]:
import os
import sys
sys.path.append(os.getcwd())

In [5]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit
import optax

# Dataset generation
def generate_add_task_dataset(num_examples, timesteps, t1, t2, tau_task, rng_key):
    Xs, Ys = [], []
    for _ in range(num_examples):
        rng_key, subkey = random.split(rng_key)
        X, Y = generate_single_sequence(timesteps, t1, t2, tau_task, subkey)
        Xs.append(X)
        Ys.append(Y)
    return jnp.stack(Xs), jnp.stack(Ys)  # Shape: (num_examples, timesteps, 2)

def generate_single_sequence(timesteps, t1, t2, tau_task, rng_key):
    N = timesteps // tau_task
    x = jax.random.bernoulli(rng_key, 0.5, (N,)).astype(jnp.float32)
    y = 0.5 + 0.5 * jnp.roll(x, t1) - 0.25 * jnp.roll(x, t2)
    X = jnp.asarray([x, 1 - x]).T
    Y = jnp.asarray([y, 1 - y]).T
    X = jnp.tile(X, tau_task).reshape((timesteps, 2))
    Y = jnp.tile(Y, tau_task).reshape((timesteps, 2))
    return X, Y

# RNN model
def init_rnn_params(hidden_size, input_size, output_size, rng_key):
    k1, k2, k3, k4 = random.split(rng_key, 4)
    return {
        'Wxh': random.normal(k1, (input_size, hidden_size)) * 0.01,
        'Whh': random.normal(k2, (hidden_size, hidden_size)) * 0.01,
        'bh': jnp.zeros((hidden_size,)),
        'Why': random.normal(k3, (hidden_size, output_size)) * 0.01,
        'by': jnp.zeros((output_size,))
    }

def rnn_step(params, h, x):
    # x: (batch_size, input_size), h: (batch_size, hidden_size)
    h_new = jnp.tanh(jnp.dot(x, params['Wxh']) + jnp.dot(h, params['Whh']) + params['bh'])
    logits = jnp.dot(h_new, params['Why']) + params['by']
    return h_new, logits

def rnn_scan(params, h_init, xs):
    # xs: (batch_size, truncate_steps, input_size), h_init: (batch_size, hidden_size)
    def step(carry, x):
        h = carry
        h_new, logits = rnn_step(params, h, x)
        return h_new, logits
    h_final, logits = jax.lax.scan(step, h_init, xs, dimension=1)  # Scan over truncate_steps
    return h_final, logits  # logits: (batch_size, truncate_steps, output_size)

# Loss function
def cross_entropy_loss(logits, targets):
    # logits, targets: (batch_size, truncate_steps, output_size)
    probs = jax.nn.softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(targets * jnp.log(probs + 1e-10), axis=-1))

def compute_loss(params, h_init, xs, ys):
    _, logits = rnn_scan(params, h_init, xs)
    return cross_entropy_loss(logits, ys)

# Training with TBPTT
@jit
def update_step(params, opt_state, h_init, xs, ys):
    loss, grads = jax.value_and_grad(compute_loss)(params, h_init, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    h_final, _ = rnn_scan(params, h_init, xs)
    return params, opt_state, loss, h_final

def train_rnn(params, X, Y, hidden_size, truncate_steps, learning_rate, num_epochs, batch_size, rng_key):
    optimizer = optax.sgd(learning_rate)  # SGD as requested
    opt_state = optimizer.init(params)

    num_examples = X.shape[0]  # 1000
    timesteps = X.shape[1]  # 500
    num_chunks = timesteps // truncate_steps  # 500 / 20 = 25

    # Reshape X and Y into chunks
    X_chunks = X.reshape(num_examples, num_chunks, truncate_steps, input_size)
    Y_chunks = Y.reshape(num_examples, num_chunks, truncate_steps, output_size)

    for epoch in range(num_epochs):
        rng_key, subkey = random.split(rng_key)
        indices = random.permutation(subkey, num_examples)
        total_loss = 0.0
        num_batches = (num_examples + batch_size - 1) // batch_size

        for i in range(0, num_examples, batch_size):
            batch_indices = indices[i:min(i + batch_size, num_examples)]
            batch_X = X_chunks[batch_indices]  # (batch_size, num_chunks, truncate_steps, input_size)
            batch_Y = Y_chunks[batch_indices]  # (batch_size, num_chunks, truncate_steps, output_size)
            h_init = jnp.zeros((batch_indices.shape[0], hidden_size))  # Dynamic batch size

            batch_loss = 0.0
            for chunk in range(num_chunks):
                chunk_X = batch_X[:, chunk, :, :]  # (batch_size, truncate_steps, input_size)
                chunk_Y = batch_Y[:, chunk, :, :]  # (batch_size, truncate_steps, output_size)
                params, opt_state, loss, h_init = update_step(params, opt_state, h_init, chunk_X, chunk_Y)
                batch_loss += loss

            total_loss += batch_loss / num_chunks

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

    return params

# Parameters
t1, t2 = 3, 5
tau_task = 1
num_examples = 1000
timesteps = 500
input_size = 2
output_size = 2
truncate_steps = 20
hidden_size = 64
learning_rate = 0.01
num_epochs = 10
batch_size = 128

# Generate dataset
prng = random.PRNGKey(0)
prng, new_prng = random.split(prng)
X, Y = generate_add_task_dataset(num_examples, timesteps, t1, t2, tau_task, new_prng)
print(f"X shape: {X.shape}, Y shape: {Y.shape}")  # (1000, 500, 2)

# Initialize model
prng, new_prng = random.split(prng)
params = init_rnn_params(hidden_size, input_size, output_size, new_prng)

# Train
trained_params = train_rnn(params, X, Y, hidden_size, truncate_steps, learning_rate, num_epochs, batch_size, new_prng)

X shape: (1000, 500, 2), Y shape: (1000, 500, 2)


TypeError: add got incompatible shapes for broadcasting: (20, 64), (128, 64).