# initialization

> Add initialization semantically (and rewrite a bunch of what we've already done)

In [None]:
#| default_exp basics_with_init

## Current problems to solve
- Global state:
  - Names aren't unique between modules, let alone instances of the same module
  - Our objects are represented by a handful of variables that are floating around in the ether
- Parameter initialization is a bit of a pain: having lingering variables aloft kinda sucks.

# Solving initialization

We'll turn the object into two functions: one that initializes and another that applies
- This is very functional programming:
  - `init: () -> T`
  - `apply: (T) -> A`
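
To make the split concrete, here is a minimal sketch in plain Python (no JAX). Everything in it (`InitApply`, `make_dot_model`, the `seed` argument) is a hypothetical name made up for illustration, not part of the code that follows. It also assumes `apply` takes the inputs alongside the params, which the signatures above leave implicit.

```python
import random
from typing import Callable, Dict, List, NamedTuple

Params = Dict[str, List[float]]

class InitApply(NamedTuple):
    """Hypothetical pair of pure functions standing in for an object."""
    init: Callable[[], Params]                      # () -> T
    apply: Callable[[Params, List[float]], float]   # (T, inputs) -> A

def make_dot_model(size: int, seed: int = 0) -> InitApply:
    def init() -> Params:
        # All state is created here and handed back to the caller.
        rng = random.Random(seed)
        return {"w": [rng.gauss(0.0, 1.0) for _ in range(size)]}

    def apply(params: Params, x: List[float]) -> float:
        # No hidden state: the result depends only on the arguments.
        return sum(w_i * x_i for w_i, x_i in zip(params["w"], x))

    return InitApply(init, apply)

model = make_dot_model(3)
params = model.init()
y = model.apply(params, [1.0, 2.0, 3.0])
```

The caller owns `params` and threads it through `apply` explicitly, which is what removes the floating global variables.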

We are going to support this with the `Frame`: a mechanism to better control `get_param(...)`'s behavior
- If initializing, `get_param(...)` will create the param of the correct `shape` (a new argument) and include it among the current params in the `Frame` before returning our result
- Otherwise, the `shape` argument is ignored and the existing param is looked up by its identifier

---
As an aside: we're going to tag in `numpy` because [random is effortful in JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers)
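
For comparison, this is roughly what the "effortful" JAX route looks like: a small sketch, assuming we wanted explicit PRNG keys instead of `numpy`'s global generator. The shapes and variable names are made up for illustration.

```python
import jax

# JAX has no global random state: every draw needs an explicit key.
key = jax.random.PRNGKey(0)

# Keys should not be reused for independent draws, so split first.
key, w_key, b_key = jax.random.split(key, 3)

w = jax.random.normal(w_key, (3, 4))
b = jax.random.normal(b_key, (4,))

# The same key always produces the same values.
w_again = jax.random.normal(w_key, (3, 4))
```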

---

In [None]:
#| export
from typing import NamedTuple, Dict, Callable
import numpy as np
import jax
import jax.numpy as jnp

In [None]:
#| export
frame_stack = []

class Frame(NamedTuple):
    """Tracks machinery state during a call of a transformed function."""
    params: Dict[str, jnp.ndarray]
    is_initializing: bool = False

    @classmethod
    def current_frame(cls): return frame_stack[-1]

class TransformedFunc(NamedTuple):
    init: Callable   # (*args, **kwargs) -> params dict
    apply: Callable  # (params, *args, **kwargs) -> outputs of f

def transform(f) -> TransformedFunc:

    def init_f(*args, **kwargs):
        frame_stack.append(
            Frame({}, is_initializing=True)
        )
        f(*args, **kwargs)  # run f so its get_param(...) calls fill in the frame's params; the outputs are discarded
        frame = frame_stack.pop()
        return frame.params

    def apply_f(params, *args, **kwargs):
        frame_stack.append(Frame(params))
        outputs = f(*args, **kwargs)
        frame_stack.pop()
        return outputs

    return TransformedFunc(init_f, apply_f)

def get_param(identifier, shape=None):
    """Get parameter according to `identifier`, initializing with `shape` if necessary.

    Improvement over the tutorial code: we don't fall prey to race conditions. I know Pythonistas
    aren't generally concerned with that (what with the GIL and all) but it's a good thing to
    note: when you're modifying something that's supposed to be atomic, you best make changes
    while you've got the thing. Never know if you're going to lose your grip.

    Args:
        identifier:
            valid str to identify the parameter
        shape (optional):
            MUST BE INCLUDED IF YOUR PARAMETER IS NOT YET INITIALIZED.
            ignored if the parameter has already been instantiated.
    """
    if (top_frame := Frame.current_frame()).is_initializing:
        top_frame.params[identifier] = np.random.normal(size=shape)

    return top_frame.params[identifier]

In [None]:
"""Testing the functionality that we've just implemented"""
def parameter_shapes(params):
    return jax.tree_util.tree_map(lambda p: p.shape, params)

class Linear:
    def __init__(self, width): self._width = width
    def __call__(self, x):
        w = get_param('w', shape=(x.shape[-1], self._width))
        b = get_param('b', shape=(self._width,))
        return x @ w + b

init, apply = transform(Linear(4))

data = jnp.ones((2, 3))

params = init(data)  # this runs Linear.__call__ -> get_param(...) -> frame_stack[-1].params[identifier] = ...
parameter_shapes(params)

{'b': (4,), 'w': (3, 4)}

In [None]:
apply(params, data)

Array([[-0.48675254, -0.06542091,  1.583561  ,  1.0822306 ],
       [-0.48675254, -0.06542091,  1.583561  ,  1.0822306 ]],      dtype=float32)

# Solving unique parameter names: finishing our mini-Haiku
So close to done already?! You know it!

This should also facilitate nesting modules.

1. Give each parameter an unambiguous name.
   - Our scheme will be different from - and incompatible with - real Haiku, but it'll be simpler and still correct
   - The key idea is to assign a name based on the position in the call stack
2. We'll now define a `Module` class to solve (1) and more
   - each module will have a unique identifier a la `MyClass_<instance-number>`
3. We'll define a decorator for `Module` methods called (wait for it...) `module_method`, which gives us better access to the call stack and, with it, the ability to scope parameters.
   - Real Haiku uses `metaclasses` for automatic method wrapping; we're doing this manually
     1. For learning
     2. For simplicity of implementation i.e. avoiding working at the underbellies of Python
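
As a rough preview of the naming scheme, here is a standalone sketch that tracks the call stack and the per-class counters in module-level variables. `unique_module_name` and `param_path` are hypothetical helpers for illustration only; the notebook hangs this bookkeeping off `Frame` instead.

```python
from collections import Counter
from typing import List

call_stack: List[str] = []
module_counts: Counter = Counter()

def unique_module_name(class_name: str) -> str:
    # Number instances of each class in creation order: Linear_0, Linear_1, ...
    number = module_counts[class_name]
    module_counts[class_name] += 1
    return f"{class_name}_{number}"

def param_path(identifier: str) -> str:
    # A param's name is the full path of module/method calls leading to it.
    return "/".join(["~"] + call_stack + [identifier])

# Simulate an MLP whose __call__ creates two Linear layers and is
# currently inside the second one.
outer = unique_module_name("MLP")
call_stack.extend([outer, "__call__"])
first = unique_module_name("Linear")
second = unique_module_name("Linear")
call_stack.extend([second, "__call__"])

path = param_path("W")
# path == "~/MLP_0/__call__/Linear_1/__call__/W"
```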

In [None]:
#| export
import dataclasses, collections

In [None]:
#| export
@dataclasses.dataclass
class Frame:
    """Tracks what's going on during a call of a transformed function"""
    params: Dict[str, jnp.ndarray]
    is_initializing: bool = False

    """Keeps track of how many modules of each class have been created so far.
    Used to assign new modules unique names"""
    module_counts: Dict[str, int] = dataclasses.field(default_factory=collections.Counter)

    """Keeps track of the entire path to the current module method call.
    Module methods will add themselves to this stack when called.
    Used to give each parameter a unique name corresponding to stack location.
    """
    call_stack: list = dataclasses.field(default_factory=list)

    def create_param_path(self, identifier) -> str:
        """Creates a unique path for param identified by `identifier`"""
        return "/".join(["~"] + self.call_stack + [identifier])

    def create_unique_module_name(self, module_name: str) -> str:
        """Creates a unique name to identify this module, by appending its instance count to its name"""
        number = self.module_counts[module_name]
        self.module_counts[module_name] += 1
        # concerns with this state modification:
        # 1. it only refers to this Frame, not some communal global state
        # 2. creating and updating don't have to be separate steps, but keeping them apart can ease some thinking.
        return f"{module_name}_{number}"

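    # NOTE: stacking @classmethod on @property needs Python 3.9-3.12 (deprecated in 3.11, removed in 3.13)
    @classmethod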
    @property
    def current(cls):
        "Current frame on the frame stack"
        return frame_stack[-1] if frame_stack else None

"global state for tracking frames"
frame_stack = []

Shmidge of test code

In [None]:
assert Frame.current is None

test_frame = Frame({})

frame_stack.append(test_frame)
assert Frame.current is test_frame
frame_stack.pop()

Frame(params={}, is_initializing=False, module_counts=Counter(), call_stack=[])

Now for the Module tidbits

In [None]:
#| export
class Module:
    def __init__(self):
        "Assign a unique name for instance of this module for the given `transform` call"
        self._unique_name = Frame.current.create_unique_module_name(
            self.__class__.__name__)

def module_method(f):
    """Decorate a Module method

    In the real Haiku, the user wouldn't see this, as it's handled by the metaclass"""
    def wrapped(self, *args, **kwargs):
        """A version of f that gives us some call stack information"""
        module_name = self._unique_name
        f_name = f.__name__

        call_stack = Frame.current.call_stack
        call_stack.append(module_name)
        call_stack.append(f_name)
        outputs = f(self, *args, **kwargs)
        assert call_stack.pop() == f_name
        assert call_stack.pop() == module_name
        return outputs

    return wrapped

def get_param(identifier, shape=()):
    frame = Frame.current
    param_path = frame.create_param_path(identifier)
    if frame.is_initializing:
        frame.params[param_path] = np.random.normal(size=shape)

    return frame.params[param_path]

class Linear(Module):
    def __init__(self, width):
        super().__init__()
        self._width = width

    @module_method
    def __call__(self, x):
        # same as before
        W = get_param('W', shape=(x.shape[-1], self._width))
        b = get_param('b', shape=(self._width,))
        return x @ W + b

## Additional notes
We don't need to rewrite `transform` because what we have is adequate as is.

Here's some functionality that's missing, but is a part of real Haiku
- [ ] control over initialization (we're just doing everything with `np.random.normal`, which really is going away any day now)
- [ ] Random Number Generation ("rng") handling
- [ ] State handling: easiest to implement, because it's conceptually analogous to parameter handling
- [ ] most validation and error handling
- [ ] freezing parameters once they're created
- [ ] more thread-safety
- [ ] JAX transforms in `transform`s
- [ ] JAX control flow inside of `transform`s
- [ ] More thorough documentation
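
To give a flavor of the first item, here is one way control over initialization might look: a sketch, not how real Haiku does it. `get_param_with_init` and `normal_init` are hypothetical names, and the param store and the initializing flag are passed in explicitly so the snippet stands alone instead of reading them from `Frame`.

```python
import numpy as np

def normal_init(shape):
    """Same default as the notebook's get_param: standard normal draws."""
    return np.random.normal(size=shape)

def get_param_with_init(params, identifier, shape=(), init=normal_init,
                        is_initializing=False):
    """Hypothetical get_param variant that accepts an initializer callable."""
    if is_initializing:
        # Any callable mapping a shape to an array works as an initializer.
        params[identifier] = init(shape)
    return params[identifier]

params = {}
w = get_param_with_init(params, "w", shape=(3, 4), is_initializing=True)
b = get_param_with_init(params, "b", shape=(4,), init=np.zeros,
                        is_initializing=True)

# Outside of initialization, shape and init are ignored.
b_again = get_param_with_init(params, "b")
```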

Still, the bones are here, so let's take it for a spin.

In [None]:
data, data.shape

(Array([[1., 1., 1.],
        [1., 1., 1.]], dtype=float32),
 (2, 3))

In [None]:
Linear(4)(data)

Array([[-0.03975022, -0.9340855 , -1.2082163 , -1.6628195 ],
       [-0.03975022, -0.9340855 , -1.2082163 , -1.6628195 ]],      dtype=float32)

In [None]:
init, apply = transform(lambda x: Linear(4)(x))

params = init(data)
parameter_shapes(params)

{'~/Linear_0/__call__/W': (3, 4), '~/Linear_0/__call__/b': (4,)}

In [None]:
apply(params, data)

Array([[-0.4722252, -1.3183072, -1.6448145, -5.9238143],
       [-0.4722252, -1.3183072, -1.6448145, -5.9238143]], dtype=float32)

What about an MLP?

In [None]:
class MLP(Module):

  def __init__(self, widths):
    super().__init__()
    self._widths = widths

  @module_method
  def __call__(self, x):
    for w in self._widths:
      out = Linear(w)(x)
      x = jax.nn.sigmoid(out)
    return out


In [None]:
init, apply = transform(lambda x: MLP([3, 5])(x))
parameter_shapes(init(data))

{'~/MLP_0/__call__/Linear_0/__call__/W': (3, 3),
 '~/MLP_0/__call__/Linear_0/__call__/b': (3,),
 '~/MLP_0/__call__/Linear_1/__call__/W': (3, 5),
 '~/MLP_0/__call__/Linear_1/__call__/b': (5,)}

The same module called multiple times holds the same parameters:

In [None]:
class ParameterReuseTest(Module):

  @module_method
  def __call__(self, x):
    f = Linear(x.shape[-1])

    x = f(x)
    x = jax.nn.relu(x)
    return f(x)

init, forward = transform(lambda x: ParameterReuseTest()(x))
parameter_shapes(init(data))

{'~/ParameterReuseTest_0/__call__/Linear_0/__call__/W': (3, 3),
 '~/ParameterReuseTest_0/__call__/Linear_0/__call__/b': (3,)}

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()