In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

from ml_templates.pytree_factory import PyTreeFactory

In [2]:
class StandardScaler(eqx.Module):
    """Standardize data by removing the mean and scaling to unit variance.

    Attributes:
        mean: Offset subtracted by `forward` and added back by `inverse`.
        std: Scale divided out by `forward` and multiplied back by `inverse`.

    Both fields are passed through `jnp.asarray` when the module is built, so
    plain Python floats and lists (for example hyperparameters written as
    literals) are stored as JAX arrays.
    """
    # JAX arrays refuse arithmetic with Python lists, so convert on construction.
    mean: jax.Array = eqx.field(converter=jnp.asarray, default_factory=lambda: jnp.array(0.0))
    std: jax.Array = eqx.field(converter=jnp.asarray, default_factory=lambda: jnp.array(1.0))

    @classmethod
    def fit(cls, data, axis=None):
        """Build a scaler from the statistics of `data`.

        Args:
            data: Object exposing `.mean(axis=...)` and `.std(axis=...)`.
            axis: Forwarded unchanged to both reductions.

        Returns:
            A scaler whose `mean` and `std` are the reductions of `data`.
            No guard is applied against a zero standard deviation.
        """
        return cls(mean=data.mean(axis=axis), std=data.std(axis=axis))

    def forward(self, data):
        """Return `(data - mean) / std`."""
        return (data - self.mean) / self.std

    def inverse(self, data):
        """Return `data * std + mean`, undoing `forward`."""
        return data * self.std + self.mean

class UnscaledModel(eqx.Module):
    """Wrap a model that works on scaled data so it can be called with unscaled data.

    The input goes through `input_scaler.forward`, then `scaled_model`, and the
    result is mapped back with `output_scaler.inverse`.
    """
    scaled_model: eqx.Module
    input_scaler: StandardScaler
    output_scaler: StandardScaler

    def __call__(self, x):
        """Evaluate the wrapped model on the unscaled input `x`."""
        scaled_input = self.input_scaler.forward(x)
        scaled_output = self.scaled_model(scaled_input)
        return self.output_scaler.inverse(scaled_output)

In [3]:
def make_mlp(**kwargs):
    """Build an `eqx.nn.MLP` from keyword hyperparameters.

    `seed` (default 0) is removed from the keyword arguments and turned into
    the PRNG `key`; any `key` supplied by the caller is overwritten.
    `activation` defaults to `jax.nn.gelu`. All remaining keyword arguments
    are forwarded to `eqx.nn.MLP` unchanged.
    """
    seed = kwargs.pop('seed', 0)
    kwargs.setdefault('activation', jax.nn.gelu)
    kwargs['key'] = jrandom.PRNGKey(seed)
    return eqx.nn.MLP(**kwargs)

def make_standard_scaler(shape=()):
    """Create a `StandardScaler` with zero mean and unit std of the given shape."""
    zero_mean = jnp.zeros(shape)
    unit_std = jnp.ones(shape)
    return StandardScaler(mean=zero_mean, std=unit_std)

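# TODO: decide how nested models should be composed: register a subfactory
# method, add a dedicated subfactory for unscaled models, or register upstream
# (parent) factories? The draft below is disabled: it is a plain function, yet
# its helper calls `self.generate`, so it has no factory to generate from.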
# def make_unscaled_model(
#     scaled_model_gen_and_hyp,
#     input_scaler_gen_and_hyp,
#     output_scaler_gen_and_hyp
# ):
#     # Helper function
#     def _generate(gen_and_hyp):
#         generator_name, hyperparams = gen_and_hyp
#         return self.generate(generator_name, hyperparams)
    
#     # Use factory to generate the contents of the injection model
#     scaled_model = _generate(scaled_model_gen_and_hyp)
#     input_scaler = _generate(input_scaler_gen_and_hyp)
#     output_scaler = _generate(output_scaler_gen_and_hyp)
    
#     return UnscaledModel(scaled_model=scaled_model, input_scaler=input_scaler, output_scaler=output_scaler)



In [4]:
# Leaf factory for neural networks: 'mlp' builds an MLP via make_mlp
nn_factory = PyTreeFactory()
nn_factory.register_generator('mlp', make_mlp)

# Leaf factory for scalers: 'standard' calls the StandardScaler constructor directly
scaler_factory = PyTreeFactory()
scaler_factory.register_generator('standard', StandardScaler)

# model_factory = PyTreeFactory(parent_factories={'scaler': scaler_factory, 'nn': nn_factory})
# model_factory.register_generator('unscaled_model', make_unscaled_model)

# hyperparams = dict(

# )
# model_factory.generate("unscaled_model")

class ModelFactory(PyTreeFactory):
    """Factory for composite models built from the `nn` and `scaler` factories.

    Registers one generator, "unscaled_model", which assembles an
    `UnscaledModel` from parts produced by the module-level `nn_factory` and
    `scaler_factory`.
    """

    def __init__(self):
        # Initialise base-class state before overriding the registries below.
        super().__init__()
        self.parent_factories = {"nn": nn_factory, "scaler": scaler_factory}
        self.generators = {"unscaled_model": self.make_unscaled_model}

    def make_unscaled_model(
        self,
        scaled_model_gen_name,
        scaled_model_hyperparams,
        input_scaler_gen_name,
        input_scaler_hyperparams,
        output_scaler_gen_name,
        output_scaler_hyperparams,
    ):
        """Assemble an `UnscaledModel` from three generated components.

        Args:
            scaled_model_gen_name: Generator name looked up in the "nn" factory.
            scaled_model_hyperparams: Hyperparameters passed to that generator.
            input_scaler_gen_name: Generator name looked up in the "scaler" factory.
            input_scaler_hyperparams: Hyperparameters for the input scaler.
            output_scaler_gen_name: Generator name looked up in the "scaler" factory.
            output_scaler_hyperparams: Hyperparameters for the output scaler.

        Returns:
            An `UnscaledModel` wrapping the generated model and scalers.
        """
        # Use the parent factories to generate the contents of the unscaled model
        nn_parent = self.parent_factories["nn"]
        scaler_parent = self.parent_factories["scaler"]
        scaled_model = nn_parent.generate(scaled_model_gen_name, scaled_model_hyperparams)
        input_scaler = scaler_parent.generate(input_scaler_gen_name, input_scaler_hyperparams)
        output_scaler = scaler_parent.generate(output_scaler_gen_name, output_scaler_hyperparams)

        # Create unscaled model
        return UnscaledModel(scaled_model=scaled_model, input_scaler=input_scaler, output_scaler=output_scaler)

In [17]:
model_factory = ModelFactory()
unscaled_model_hyperparams = dict(
    scaled_model_gen_name="mlp",
    scaled_model_hyperparams=dict(in_size=3, out_size=2, width_size=30, depth=3),
    input_scaler_gen_name="standard",
    # One mean entry per MLP input feature (in_size=3)
    input_scaler_hyperparams=dict(mean=[1., 2., 3.], std=0.1),
    output_scaler_gen_name="standard",
    # One mean entry per MLP output (out_size=2) so the inverse scaling broadcasts
    output_scaler_hyperparams=dict(mean=[100., 200.], std=10.0),
)
unscaled_model = model_factory.generate("unscaled_model", unscaled_model_hyperparams)

In [21]:
# Save the model together with the generator name and hyperparameters used to
# build it, then read it back through the same factory.
model_factory.save_pytree(
    unscaled_model,
    "unscaled_model.eqx",
    "unscaled_model",
    unscaled_model_hyperparams,
)
loaded_unscaled_model = model_factory.load_pytree("unscaled_model.eqx")

In [22]:
# Round-trip check: compare the original model with the one returned by load_pytree
eqx.tree_equal(unscaled_model, loaded_unscaled_model)

Array(True, dtype=bool)

In [None]:
class NestedFactory(PyTreeFactory):
    """Draft of a single factory that registers the MLP, scaler and unscaled-model generators.

    NOTE(review): this cell has no execution count, so it appears never to
    have been run; treat it as a sketch until the points below are resolved.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_generator('mlp', make_mlp)
        self.register_generator('standard_scaler', make_standard_scaler)
        # NOTE(review): `make_unscaled_model` is looked up as a module-level name, not as
        # the method defined below, and no such function is defined in the visible
        # notebook (only a commented-out draft) -- instantiating would raise NameError.
        # NOTE(review): `register_subfactory` is not used anywhere else in this notebook;
        # confirm that PyTreeFactory provides it.
        self.register_subfactory('unscaled_model', make_unscaled_model)

    def make_unscaled_model(self, scaled_model_gen_and_hyp, input_scaler_gen_and_hyp, output_scaler_gen_and_hyp):
        # Forwards the three (generator, hyperparameters) arguments as one tuple to the
        # 'unscaled_model' entry; other generate() calls in this notebook pass a dict.
        # NOTE(review): if 'unscaled_model' resolves back to this method the call would
        # recurse -- confirm the intended dispatch.
        return self.generate('unscaled_model', (scaled_model_gen_and_hyp, input_scaler_gen_and_hyp, output_scaler_gen_and_hyp))

In [13]:
# Flat factory holding the generators that are defined as plain functions
my_factory = PyTreeFactory()
my_factory.register_generator("make_mlp", make_mlp)
my_factory.register_generator("make_standard_scaler", make_standard_scaler)
# "make_unscaled_model" is deliberately not registered here: the notebook has no
# module-level `make_unscaled_model` function (only a commented-out draft), so
# referencing it raises NameError on a fresh kernel. Unscaled models are built
# through ModelFactory instead.

In [14]:
# Build an MLP through the flat factory; its repr is the cell output
my_factory.generate(
    "make_mlp",
    {"in_size": 3, "out_size": 2, "width_size": 30, "depth": 3},
)

MLP(
  layers=(
    Linear(
      weight=f32[30,3],
      bias=f32[30],
      in_features=3,
      out_features=30,
      use_bias=True
    ),
    Linear(
      weight=f32[30,30],
      bias=f32[30],
      in_features=30,
      out_features=30,
      use_bias=True
    ),
    Linear(
      weight=f32[30,30],
      bias=f32[30],
      in_features=30,
      out_features=30,
      use_bias=True
    ),
    Linear(
      weight=f32[2,30],
      bias=f32[2],
      in_features=30,
      out_features=2,
      use_bias=True
    )
  ),
  activation=<function gelu>,
  final_activation=<function <lambda>>,
  use_bias=True,
  use_final_bias=True,
  in_size=3,
  out_size=2,
  width_size=30,
  depth=3
)