# basics
> implements the basics transform with a naive global state

In [None]:
#| default_exp basics

In [None]:
#| export
"""global state for tracking parameters"""
# Used as a stack: `transform` pushes the active params before calling the wrapped
# function and pops them afterwards; `get_param` reads the last (innermost) entry.
current_params = []

In [None]:
#| export
def transform(func):
    "Wrap `func` so its parameters are passed explicitly as the first argument"

    def apply_f(params, *args, **kwargs):
        "Call `func(*args, **kwargs)` with `params` visible to `get_param`"
        # Push the params so `get_param` calls made inside `func` can find them.
        current_params.append(params)
        try:
            return func(*args, **kwargs)
        finally:
            # Pop even when `func` raises; otherwise a stale frame stays on the
            # global stack and later `get_param` calls would read the wrong params.
            current_params.pop()
    return apply_f

In [None]:
#| export
def get_param(identifier):
    "Look up `identifier` in the params of the innermost active `transform` call"
    # NOTE: relies on the push/pop discipline of `transform` -- the last entry
    # of the global stack is the frame currently being applied.
    active_frame = current_params[-1]
    return active_frame[identifier]

Let's exercise the module a little bit

In [None]:
"prework to get back to the state from the top of the tutorial"
def my_stateless_apply(params, x):
    "Multiply `x` by the value stored under `params['w']`"
    weight = params['w']
    return weight * x

In [None]:
# The parameters are handed over explicitly here, so no global state is involved.
params = {'w': 5}
my_stateless_apply(params, 5)

25

In [None]:
class MyModule:
    "Toy module whose `apply` reads its weight through `get_param`"

    def apply(self, x):
        "Multiply `x` by the parameter registered under 'w'"
        weight = get_param('w')
        return weight * x

# `transform` supplies `params` to the `get_param` call made inside `apply`.
transform(MyModule().apply)(params, 5)

25

Will this work with JAX, even though that library chokes when global state is involved?

In [None]:
#| export
import jax
import jax.numpy as jnp

def linear(x):
    "Apply `x @ w + b`, reading 'w' and 'b' through `get_param`"
    projected = x @ get_param('w')
    return projected + get_param('b')

In [None]:
# All-ones weight of shape (3, 5) and all-ones bias of shape (5,).
params = {'w': jnp.ones((3, 5)), 'b': jnp.ones((5,))}
apply = transform(linear)

# JIT-compile the transformed function and run it on a (10, 3) batch of ones.
jax.jit(apply)(params, jnp.ones((10, 3)))

Array([[4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.],
       [4., 4., 4., 4., 4.]], dtype=float32)

## Why does this work?
The global state is identical before and after the function call (the params are pushed on entry and popped on exit), and the function's outputs depend only on its inputs.
- So it's not global state _in the function_.
- It is global state _in the system_.

Details matter in programming, ladies and gents.
So JAX is happy, because we've respected its boundaries (which is key to any good relationship).

In [None]:
#| hide
import nbdev

# Export the cells marked `#| export` in this notebook to the library module.
nbdev.nbdev_export()