In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

from ml_templates.pytree_factory import PyTreeFactory

# How to use `PyTreeFactory`

Make a factory for neural networks:

In [2]:
def make_mlp(**kwargs):
    if 'seed' not in kwargs:
        kwargs['seed'] = 0
    if 'activation' not in kwargs:
        kwargs['activation'] = jax.nn.gelu
    kwargs['key'] = jrandom.PRNGKey(kwargs.pop('seed'))
    return eqx.nn.MLP(**kwargs)

nn_factory = PyTreeFactory()
nn_factory.register_generator('mlp', make_mlp)

This factory (as of now) has only one thing it can generate&mdash;an MLP.

Let's also make a factory for input/output scalers:

In [6]:
class StandardScaler(eqx.Module):
    """Standardize input data by removing the mean and scaling to unit variance."""
    mean: float = eqx.field(default_factory=lambda: jnp.array(0.0))
    std: float = eqx.field(default_factory=lambda: jnp.array(1.0))

    @classmethod
    def fit(cls, data, axis=None):
        mean = data.mean(axis=axis)
        std = data.std(axis=axis)
        return cls(mean, std)

    def forward(self, data):
        return (data - self.mean) / self.std

    def inverse(self, data):
        return data * self.std + self.mean

def make_standard_scaler(shape=()):
    return StandardScaler(mean=jnp.zeros(shape), std=jnp.ones(shape))

scaler_factory = PyTreeFactory()
scaler_factory.register_generator('standard', StandardScaler)

Again, this factory can only make one thing&mdash;a standard scaler (i.e., a function that normalizes it's inputs).

Let's make a factory for what we'll call "unscaled" models.
By "unscaled" model, we mean a function that takes an input, scales it, passes it through a NN, and then unscales it.
Note that this factory depends on the output of the NN factory and the scaler factory.
So, we will need to create a nested factory pattern to obey this dependency.
We do this by *subclassing* the `PyTreeFactory` class.

Here is how:

In [7]:
# First, as before, define the type object the factory will produce.
class UnscaledModel(eqx.Module):
    """A model that scales its input before passing it to another model."""
    scaled_model: callable
    input_scaler: callable
    output_scaler: callable

    def __call__(self, x):
        x = self.input_scaler.forward(x)
        x = self.scaled_model(x)
        return self.output_scaler.inverse(x)

# Next, subclass PyTreeFactory to create a factory that is dependent on another factory.
class ModelFactory(PyTreeFactory):
    def __init__(self):
        self.parent_factories = {"nn": nn_factory, "scaler": scaler_factory}
        self.generators = {"unscaled_model": self.make_unscaled_model}

    def make_unscaled_model(
        self,
        scaled_model_gen_name,
        scaled_model_hyperparams,
        input_scaler_gen_name,
        input_scaler_hyperparams,
        output_scaler_gen_name,
        output_scaler_hyperparams,
    ):  
        # Use factory to generate the contents of the unscaled model
        scaled_model = self.parent_factories["nn"].generate(scaled_model_gen_name, scaled_model_hyperparams)
        input_scaler = self.parent_factories["scaler"].generate(input_scaler_gen_name, input_scaler_hyperparams)
        output_scaler = self.parent_factories["scaler"].generate(output_scaler_gen_name, output_scaler_hyperparams)
        
        # Create unscaled model
        return UnscaledModel(scaled_model=scaled_model, input_scaler=input_scaler, output_scaler=output_scaler)
    
# Finally, initialize the factory.
model_factory = ModelFactory()

Let's use our newly created model factory to generate a model:

In [13]:
unscaled_model_hyperparams = dict(
    scaled_model_gen_name="mlp",
    scaled_model_hyperparams=dict(in_size=3, out_size=2, width_size=30, depth=3),
    input_scaler_gen_name="standard",
    input_scaler_hyperparams=dict(mean=[1., 2., 3.], std=0.1),
    output_scaler_gen_name="standard",
    output_scaler_hyperparams=dict(mean=[100., 200., 300.], std=10.0),
)
unscaled_model = model_factory.generate("unscaled_model", unscaled_model_hyperparams)
unscaled_model

UnscaledModel(
  scaled_model=MLP(
    layers=(
      Linear(
        weight=f32[30,3],
        bias=f32[30],
        in_features=3,
        out_features=30,
        use_bias=True
      ),
      Linear(
        weight=f32[30,30],
        bias=f32[30],
        in_features=30,
        out_features=30,
        use_bias=True
      ),
      Linear(
        weight=f32[30,30],
        bias=f32[30],
        in_features=30,
        out_features=30,
        use_bias=True
      ),
      Linear(
        weight=f32[2,30],
        bias=f32[2],
        in_features=30,
        out_features=2,
        use_bias=True
      )
    ),
    activation=<function gelu>,
    final_activation=<function <lambda>>,
    use_bias=True,
    use_final_bias=True,
    in_size=3,
    out_size=2,
    width_size=30,
    depth=3
  ),
  input_scaler=StandardScaler(mean=[1.0, 2.0, 3.0], std=0.1),
  output_scaler=StandardScaler(mean=[100.0, 200.0, 300.0], std=10.0)
)

One of the nice things about this factory pattern is that we can easily save and load models:

In [12]:
model_factory.save_pytree(unscaled_model, "unscaled_model.eqx", "unscaled_model", unscaled_model_hyperparams)
loaded_unscaled_model = model_factory.load_pytree("unscaled_model.eqx")
eqx.tree_equal(unscaled_model, loaded_unscaled_model)

Array(True, dtype=bool)