## Intro 

Stax is a simple Python module stored in a single `.py` [file](https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm). It provides simple functions that implement layers such as `Dense` and `Conv`. The key difference from other libraries is that Stax follows the fundamental JAX paradigm: everything is done with functions (no OOP classes). This fits well, because a neural network is simply a parameterized function that transforms a given input into an output. The reason most libraries, such as PyTorch, use OOP modules for NNs is that a network is not just a function... it is a *parameterized* function, and classes let us store **state** + **function** together. Stax instead uses only functions for each layer, via the "`init` and `apply`" approach.
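To make the state + function contrast concrete, here is a minimal sketch. `ScaleLayer` and `scale_apply` are hypothetical names invented for illustration; they are not part of Stax or PyTorch.

```python
import jax.numpy as jnp

# Object-oriented style: the parameter lives on the instance (state + function).
class ScaleLayer:  # hypothetical layer, for illustration only
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return self.scale * x

# Functional style: the parameter is an explicit argument, nothing is stored.
def scale_apply(params, x):  # hypothetical function, for illustration only
    (scale,) = params
    return scale * x

x = jnp.array([1.0, 2.0, 3.0])
oo_out = ScaleLayer(2.0)(x)       # parameters hidden inside the object
fn_out = scale_apply((2.0,), x)   # parameters passed in by the caller
```

Both calls compute the same thing; the only difference is who holds the parameters.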

Here is Stax's own `Dense` layer as an example:

```python
def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b
  return init_fun, apply_fun
```

Here, `Dense` returns two functions: `init_fun`, which returns the output shape and the initialized params, and `apply_fun`, which takes the params plus the data and applies the layer to real inputs.

```python
import numpy as np
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Build the (init, apply) pair for a dense layer with one output unit.
init_fun, apply_fun = stax.Dense(1)

# Dummy batch: 3 examples with 2 features each.
dummy_x = jnp.array(np.random.rand(3, 2))

key = random.PRNGKey(42)

init_fun(key, input_shape=dummy_x.shape)
```

```
((3, 1),
 (DeviceArray([[ 0.57130516],
               [-0.59154177]], dtype=float32),
  DeviceArray([0.01369469], dtype=float32)))
```

- `init` returns the output shape and the initialized params, based on the shape of the input.
- `apply` takes the params and an input, and returns the transformed output.
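The `apply` half can be sketched as follows. The variable names (`x_key`, `init_key`, `manual`) are chosen here for illustration, and the dummy input is drawn with `jax.random` rather than NumPy so the block is reproducible.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.Dense(1)

# Separate keys for the dummy data and for the parameter initialization.
x_key, init_key = random.split(random.PRNGKey(42))
dummy_x = random.uniform(x_key, (3, 2))  # 3 examples, 2 features

# init: input shape in, (output shape, params) out.
output_shape, params = init_fun(init_key, dummy_x.shape)

# apply: params + inputs in, layer output out.
outputs = apply_fun(params, dummy_x)

# For Dense this is the same as computing x @ W + b by hand.
W, b = params
manual = jnp.dot(dummy_x, W) + b
```

Because the params are just a value returned by `init_fun`, the caller decides where they are stored and when they are updated.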

To-do: chain layers together so that the result is a single pair of `init` and `apply` functions (i.e., a model!).
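One way to sketch this chaining is with `stax.serial`, the combinator in the same module that composes several layers into one `init`/`apply` pair. The layer sizes and PRNG seeds below are arbitrary choices for illustration.

```python
from jax import random
from jax.example_libraries import stax

# Compose three layers into a single (init, apply) pair.
init_fun, apply_fun = stax.serial(
    stax.Dense(4),
    stax.Relu,
    stax.Dense(1),
)

x = random.normal(random.PRNGKey(0), (3, 2))  # 3 examples, 2 features

# One init call initializes every layer; params holds one entry per layer.
output_shape, params = init_fun(random.PRNGKey(1), x.shape)

# One apply call runs the input through the layers in order.
preds = apply_fun(params, x)
```

The combined pair has the same interface as a single layer, so it can be nested inside another `stax.serial` call.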