In [None]:
# https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html

import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

from jax.config import config
# Flax Linen needs JAX "omnistaging" turned on.
# NOTE(review): this flag only exists on the older JAX releases this notebook
# targets (where omnistaging was opt-in); newer releases enable it by default
# and may not provide `enable_omnistaging` -- confirm against the installed JAX.
config.enable_omnistaging() # Linen requires enabling omnistaging

In [None]:
# https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html
# 

# `nn.Dense` is told only its output width (5); its input width is inferred
# from the example input handed to `init` (dim(x) = 10 -> dim(y) = 5).
model = nn.Dense(features=5)

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,))  # dummy input used only for shape inference

# Initialization call: shapes are inferred here and an immutable frozen dict
# of parameters comes back.
params = model.init(key2, x)

# Show the shape of every parameter leaf.
print(jax.tree_map(jnp.shape, params))

# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5
noise_scale = 0.1  # std-dev of the Gaussian noise added to the targets

# Split the root key once into independent subkeys so every random draw below
# consumes its own key (JAX PRNG keys must never be reused).
key = random.PRNGKey(0)
k1, k2, ksample, knoise = random.split(key, 4)

# Random ground-truth W and b.
W = random.normal(k1, (xdim, ydim))
b = random.normal(k2, (ydim,))
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Samples y = x @ W + b + noise, built in one expression so re-running the
# cell never accumulates extra noise.
x_samples = random.normal(ksample, (nsamples, xdim))
y_samples = (jnp.dot(x_samples, W) + b
             + noise_scale * random.normal(knoise, (nsamples, ydim)))

assert x_samples.shape == (nsamples, xdim)
assert y_samples.shape == (nsamples, ydim)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)


def make_mse_func(x_batched, y_batched, net=None):
    """Build a jitted mean (half) squared-error loss over a fixed batch.

    Args:
        x_batched: inputs, one sample per entry along axis 0.
        y_batched: targets, aligned with `x_batched` along axis 0.
        net: object exposing `apply(params, x)` (e.g. a Flax module). Defaults
            to the global `model` *as it is when this factory is called*, so
            re-assigning `model` in a later cell does not silently change the
            returned loss.

    Returns:
        A jitted function `mse(params)` returning the mean over samples of
        ``inner(y - pred, y - pred) / 2`` with ``pred = net.apply(params, x)``.
    """
    if net is None:
        net = model  # bind now, not at call time

    def mse(params):
        # Half squared error for a single (x, y) pair.
        def squared_error(x, y):
            residual = y - net.apply(params, x)
            return jnp.inner(residual, residual) / 2.0
        # vmap over the leading (sample) axis, then average over samples.
        return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

    return jax.jit(mse)

# Jitted loss over the sampled data.
loss = make_mse_func(x_samples, y_samples)

alpha = 0.3  # gradient-descent step size
print('Loss for "true" W,b: ', loss(true_params))
grad_fn = jax.value_and_grad(loss)

for step in range(101):
    # One plain gradient-descent update: p <- p - alpha * dL/dp, leaf by leaf.
    loss_val, grad = grad_fn(params)
    params = jax.tree_multimap(lambda p, g: p - alpha * g, params, grad)
    if step % 10 == 0:
        print(f'Loss step {step}: ', loss_val)


In [None]:
tuple(arr.shape for arr in (x_samples, y_samples))  # (x shape, y shape)

In [None]:
from flax import optim
from gp import flax_create_optimizer

def flax_create_optimizer(params, lr, optimizer='GradientDescent'):
    """Create a `flax.optim` optimizer holding `params`.

    Args:
        params: parameter pytree the optimizer will wrap (e.g. from `model.init`).
        lr: learning rate, passed as `learning_rate` to the optimizer definition.
        optimizer: name of the optimizer class looked up on `flax.optim`
            (e.g. 'GradientDescent'); an unknown name raises AttributeError.

    Returns:
        ``<optimizer class>(learning_rate=lr).create(params)``.

    NOTE(review): this local definition shadows the `flax_create_optimizer`
    imported from `gp` in this same cell; one of the two is redundant.
    """
    optimizer_def = getattr(optim, optimizer)(learning_rate=lr)
    return optimizer_def.create(params)



# Explicit model definition: a small multi-layer perceptron.
class MLP(nn.Module):
    """Stack of Dense layers with ReLU between them and a linear output layer.

    Attributes:
        features: output width of each Dense layer, in order; the last entry
            is the width of the model's output.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        out_width = self.features[-1]
        for width in hidden_widths:
            x = nn.relu(nn.Dense(width)(x))
        return nn.Dense(out_width)(x)


key1, key2 = random.split(random.PRNGKey(0), num=2)
model = MLP(features=[ydim])

# Initialize with a dummy batch of one xdim-wide input.
dummy_input = jnp.ones((1, xdim))
params = model.init(key2, dummy_input)

param_shapes = jax.tree_map(jnp.shape, unfreeze(params))
print('initialized parameter shapes:\n', param_shapes)
print('output:\n', model.apply(params, x_samples[0]))

def get_loss_fn(x_batched, y_batched, net=None):
    """Build a jitted mean (half) squared-error objective over a fixed batch.

    Args:
        x_batched: inputs, one sample per entry along axis 0.
        y_batched: targets, aligned with `x_batched` along axis 0.
        net: object exposing `apply(params, x)` (e.g. a Flax module). Defaults
            to the global `model` *as it is when this factory is called*, so a
            later `model = ...` in another cell does not change the objective.

    Returns:
        A jitted function `loss_fn(params)` (params -> objective).
    """
    if net is None:
        net = model  # bind now, not at call time

    def loss_fn(params):
        # params -> objective
        def loss_fn_one(x, y):
            residual = y - net.apply(params, x)
            return jnp.inner(residual, residual) / 2.0
        # vmap over the leading (sample) axis, then average over samples.
        return jnp.mean(jax.vmap(loss_fn_one)(x_batched, y_batched), axis=0)

    return jax.jit(loss_fn)

def loss_fn(params):
    """Mean half squared error of `model` on (x_samples, y_samples).

    NOTE(review): `model`, `x_samples` and `y_samples` are read as globals at
    call time, so re-assigning any of them in a later cell changes what this
    function computes.
    """
    def half_sq_err(x, y):
        residual = y - model.apply(params, x)
        return jnp.inner(residual, residual) / 2.0

    per_sample = jax.vmap(half_sq_err)(x_samples, y_samples)
    return jnp.mean(per_sample, axis=0)


loss = jax.jit(loss_fn)  # jitted variant of the same objective
f = loss_fn


# Train the MLP with a `flax.optim` optimizer (via the `gp.flax_run_optim` helper).

# Manual alternative: optimizer = flax_create_optimizer(params, alpha)


from gp import flax_run_optim, log_func_default

def _print_progress(i, f, params):
    # Logging callback handed to flax_run_optim as log_func(i, f, params):
    # prints the step index and the objective f evaluated at params.
    print(i, f(params))


params = flax_run_optim(loss_fn, params, lr=0.2, num_steps=10,
                        log_func=_print_progress)





In [None]:
class SimpleDense(nn.Module):
    """Minimal re-implementation of a dense layer: y = inputs . kernel + bias.

    Attributes:
        features: number of output features (size of the last output axis).
        kernel_init: initializer for the (inputs.shape[-1], features) kernel.
        bias_init: initializer for the (features,) bias.
    """
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,
                            (inputs.shape[-1], self.features))
        # Contract the last axis of `inputs` with axis 0 of `kernel`, with no
        # batch dimensions. For a 2-D kernel this is the same contraction as
        # `jnp.dot(inputs, kernel)`; `lax.dot_general` just makes the
        # contracted axes explicit.
        y = lax.dot_general(inputs, kernel,
                            (((inputs.ndim - 1,), (0,)), ((), ())))
        bias = self.param('bias', self.bias_init, (self.features,))
        return y + bias

key1, key2 = random.split(random.PRNGKey(0), num=2)
x = random.uniform(key1, shape=(4, 4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

for label, value in (('initialized parameters:\n', params), ('output:\n', y)):
    print(label, value)