# Hello, World!

In [5]:
import jax
import jax.numpy as jnp

input_dim = 28 * 28
output_dim = 10
batch_size = 128

# Generate random training data
key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (input_dim, batch_size))
targets = jax.random.normal(key, (output_dim, batch_size))

In [4]:
from modula.atom import Linear
from modula.bond import ReLU

width = 256

mlp = Linear(output_dim, width)
mlp @= ReLU() @ Linear(width, width) 
mlp @= ReLU() @ Linear(width, input_dim)

print(mlp)

mlp.jit()

CompositeModule
...consists of 3 atoms and 2 bonds
...non-smooth
...input sensitivity is 1
...contributes proportion 3 to feature learning of any supermodule


In [12]:
from modula.error import SquareError

steps = 1000
learning_rate = 0.1
error = SquareError()

key = jax.random.PRNGKey(0)
w = mlp.initialize(key)
w = mlp.project(w)

for step in range(steps):
    # compute outputs and activations
    outputs, activations = mlp(inputs, w)
    
    # compute loss
    loss = error(outputs, targets)
    
    # compute error gradient
    error_grad = error.grad(outputs, targets)
    
    # compute gradient of weights
    grad_w, _ = mlp.backward(w, activations, error_grad)
    
    # dualize gradient
    d_w = mlp.dualize(grad_w)

    # compute scheduled learning rate
    lr = learning_rate * (1 - step / steps)
    
    # update weights
    w = [weight - lr * d_weight for weight, d_weight in zip(w, d_w)]

    if step % 100 == 0:
        print(f"Step {step}, Loss {loss}")


Step 0, Loss 0.9790326952934265
Step 100, Loss 0.0018738203216344118
Step 200, Loss 0.0014391584554687142
Step 300, Loss 0.0010814154520630836
Step 400, Loss 0.0008106177556328475
Step 500, Loss 0.0005738214822486043
Step 600, Loss 0.0003808117180597037
Step 700, Loss 0.00022766715846955776
Step 800, Loss 0.00011454012565081939
Step 900, Loss 3.979807297582738e-05
