# 06 - Stateful computation in JAX

All programs need to deal with state. In ML, program state usually takes the form of

- model parameters,
- optimizer state, and
- stateful layers, such as BatchNorm.

As we already know by now, JAX transformations such as `jax.jit` impose constraints on the functions they transform.
These functions must be free of side effects. Otherwise, the side effects run only when JAX traces the function to compile it, which typically happens on the first call. Subsequent calls execute the compiled version, and the side effects do not happen again.
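A minimal sketch of this behaviour, assuming a hypothetical function `add_one` that records each call in a Python list (`trace_log`, also hypothetical). The list only grows while JAX traces the function, so after three calls it holds a single entry.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical Python-side log, used only to observe side effects

@jax.jit
def add_one(x):
    # This append is a Python side effect: it runs while JAX traces the
    # function, not every time the compiled version is executed.
    trace_log.append("traced")
    return x + 1

for i in range(3):
    print(add_one(jnp.array(i)))  # 1, 2, 3: the computation itself runs every time

print(len(trace_log))  # 1: the side effect ran only during the first (tracing) call
```

All three inputs have the same shape and dtype, so JAX reuses the compiled function instead of tracing again.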

Modifying program state is also a kind of side effect. So how does JAX handle updates to program state if functions are not allowed to have side effects? Let's dig a bit deeper.

## A simple example: Counter

In [1]:
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self):
        self.n = 0  # hidden state stored on the instance

    def count(self) -> int:
        self.n += 1  # side effect: mutates the instance
        return self.n

    def reset(self):
        self.n = 0

counter = Counter()

for _ in range(5):
    print(counter.count())

1
2
3
4
5


Now, what happens if we jit this method?

In [2]:
counter.reset()

In [3]:
jit_count = jax.jit(counter.count)

In [4]:
for _ in range(5):
    print(jit_count())

1
1
1
1
1


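Well, now you can see what we mean by side effects not being allowed: `self.n` was read and incremented only once, while `count` was being traced, and the returned value `1` was baked into the compiled function as a constant.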
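To make this concrete, the sketch below repeats the cells above in one place and then inspects `counter.n`. The attribute was incremented exactly once, during tracing, and the compiled function keeps returning the constant it captured.

```python
import jax

class Counter:
    def __init__(self):
        self.n = 0  # hidden state stored on the instance

    def count(self) -> int:
        self.n += 1  # runs only while JAX traces the method
        return self.n

counter = Counter()
jit_count = jax.jit(counter.count)

results = [int(jit_count()) for _ in range(5)]
print(results)    # [1, 1, 1, 1, 1]
print(counter.n)  # 1: self.n was updated only once, while tracing
```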

### Explicit state

A common remedy, popularized by functional programming, is to make the state explicit: instead of hiding it inside an object, pass it to the function as an argument and return the updated state alongside the result.

In [5]:
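class CounterWithState:
    def count(self, n):
        # State comes in as an argument; (value, new_state) goes out.
        return n + 1, n + 1

    def reset(self):
        # Return the initial state instead of storing it on the instance.
        return 0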

In [6]:
counter = CounterWithState()
state = counter.reset()
state

0

In [7]:
for _ in range(5):
    value, state = counter.count(state)
    print(value)

1
2
3
4
5


Now, let's jit this and see.

In [8]:
state = counter.reset()
jit_count = jax.jit(counter.count)

In [9]:
for _ in range(5):
    value, state = jit_count(state)
    print(value)

1
2
3
4
5


## That's really all there is to it.

Make all state explicit by passing it to functions as arguments and returning the updated state. That way, no hidden state gets baked into the jitted function as a constant.

When there is no hidden state, only the computation is compiled; the state stays outside the jitted function and flows in and out as ordinary inputs and outputs.
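One way to check that the state really stays outside is to count how often JAX traces the function. In the sketch below, `trace_count` is a hypothetical helper added only for illustration: the jitted `count` is traced once, while the state still advances correctly on every call.

```python
import jax
import jax.numpy as jnp

trace_count = []  # hypothetical helper: records each time JAX traces `count`

def count(n):
    trace_count.append(1)  # Python side effect, runs only during tracing
    return n + 1, n + 1    # (value, new_state)

jit_count = jax.jit(count)

state = jnp.array(0)
for _ in range(5):
    value, state = jit_count(state)

print(int(state))        # 5: the state advanced on every call
print(len(trace_count))  # 1: the computation was traced and compiled once
```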

### Does it really make sense to use classes anymore?

Now that the main purpose of an OOP class, managing state, no longer applies, it is often simpler to keep plain functions in a namespace such as a module.
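As a sketch of the namespace-of-functions style, the counter can be written as two plain functions. Grouping them with `types.SimpleNamespace` is an illustrative assumption; in practice a Python module plays the same role.

```python
from types import SimpleNamespace

import jax

def init():
    """Return the initial counter state."""
    return 0

def count(n):
    """Return (value, new_state) for the current state n."""
    return n + 1, n + 1

# Hypothetical namespace standing in for a module such as `counter.py`.
counter = SimpleNamespace(init=init, count=jax.jit(count))

state = counter.init()
values = []
for _ in range(3):
    value, state = counter.count(state)
    values.append(int(value))

print(values)  # [1, 2, 3]
```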

### Can you now see how and why JAX manages RNG state through functions, with the state made explicit?

The example we worked through in the previous chapter can now be reworked using functions that pass the state explicitly.
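The previous chapter's example is not reproduced here, so the sketch below uses a hypothetical `draw` function to show the same pattern for randomness: the PRNG key is the state, it is passed in explicitly, and a fresh key is returned alongside the sample.

```python
import jax

def draw(key):
    """Return (sample, new_key); the PRNG key is explicit state."""
    key, subkey = jax.random.split(key)  # derive a new key for the next call
    sample = jax.random.normal(subkey)    # consume the subkey exactly once
    return sample, key

jit_draw = jax.jit(draw)

key = jax.random.PRNGKey(0)
samples = []
for _ in range(3):
    x, key = jit_draw(key)
    samples.append(float(x))

# Restarting from the same seed reproduces the same first sample.
x_again, _ = jit_draw(jax.random.PRNGKey(0))
print(samples)
print(float(x_again) == samples[0])  # True
```

Because the key is threaded through explicitly, the jitted function stays pure and the sequence of samples is reproducible from the seed.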

Managing a lot of state explicitly can become cumbersome. This is where libraries built on top of JAX come in handy.
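Even without such libraries, `jax.jit` accepts pytrees (nested dicts, lists and tuples of arrays), so several pieces of state can travel together as a single argument. A minimal sketch, assuming a hypothetical `update` step that applies a plain gradient-descent update and increments a step counter; the learning rate and gradients are illustrative stand-ins.

```python
import jax
import jax.numpy as jnp

def init_state():
    # Hypothetical state bundle: model parameters plus an optimizer step counter.
    return {"params": jnp.array([1.0, 2.0]), "step": jnp.array(0)}

@jax.jit
def update(state, grads):
    """Return the new state; `state` is a pytree (a dict of arrays)."""
    lr = 0.1  # illustrative fixed learning rate
    return {
        "params": state["params"] - lr * grads,
        "step": state["step"] + 1,
    }

state = init_state()
grads = jnp.array([1.0, 1.0])  # stand-in gradients for illustration
for _ in range(3):
    state = update(state, grads)

print(state["step"])    # 3
print(state["params"])  # approximately [0.7, 1.7]
```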