In [None]:
from flax import nnx
import jax
import jax.numpy as jnp

## Flax CNN example

Modified from the [Flax NNX MNIST tutorial](https://flax.readthedocs.io/en/latest/mnist_tutorial.html#mnist-tutorial).

In [None]:
from functools import partial


class CNN(nnx.Module):
    """Small convolutional classifier used as the base model for ensembling.

    Two 1x1 convolutions (each followed by ReLU and 1x1 average pooling) feed
    a two-layer dense head that emits 10 outputs per example.

    Note:
        ``dropout1`` and ``dropout2`` are created in ``__init__`` but are not
        applied in ``__call__``.
    """

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        """Create all layers, initializing their parameters from ``rngs``.

        Args:
            rngs: Random number generators handed to every layer.
        """
        # The layers are created in a fixed order because they all draw from
        # the same ``rngs`` object.
        self.conv1 = nnx.Conv(in_features=1, out_features=10, kernel_size=(1, 1), rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
        self.conv2 = nnx.Conv(in_features=10, out_features=20, kernel_size=(1, 1), rngs=rngs)
        self.avg_pool = partial(nnx.avg_pool, window_shape=(1, 1), strides=(1, 1))
        # linear1 takes 20 flattened features, i.e. conv2's 20 output channels.
        self.linear1 = nnx.Linear(in_features=20, out_features=100, rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
        self.linear2 = nnx.Linear(in_features=100, out_features=10, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run the forward pass.

        Args:
            x: Input batch; the leading axis is treated as the batch axis.

        Returns:
            The output of the final linear layer (10 values per example).
        """
        # Both convolution stages share the same conv -> ReLU -> pool pattern.
        for conv in (self.conv1, self.conv2):
            x = self.avg_pool(nnx.relu(conv(x)))

        # Collapse everything except the batch axis before the dense head.
        flat = x.reshape(x.shape[0], -1)
        hidden = nnx.relu(self.linear1(flat))
        return self.linear2(hidden)


# Instantiate the base model from a fixed seed (0).
model = CNN(rngs=nnx.Rngs(0))
# Visualize the instantiated module and its parameters.
nnx.display(model)

In [None]:
from probly.transformation import ensemble

# Build a two-member ensemble from the base CNN, requesting a parameter reset
# for the members and passing 2 as the key.
# NOTE(review): the exact seeding semantics are defined by probly — confirm there.
cnn_ensemble = ensemble(model, num_members=2, reset_params=True, key=2)
# Visualize the resulting ensemble.
nnx.display(cnn_ensemble)

In [None]:
# Show the first conv kernel of the base model next to that of both ensemble
# members so the effect of the parameter reset can be compared by eye.
print(f"base model conv1 kernel value:\n{model.conv1.kernel.value}")
for member_idx in (0, 1):
    member_kernel = cnn_ensemble[member_idx].conv1.kernel.value
    print(f"ensemble member {member_idx + 1} conv1 kernel value:\n{member_kernel}")

Resetting the parameters works for models with an architecture similar to the CNN from the Flax NNX docs. Each ensemble `member` is initialized with its own key, resulting in different kernel values in this instance.

#### Call

In [None]:
# Fixed seed so the random probe input is the same on every run.
seed = 42
x = jax.random.normal(jax.random.key(seed), shape=(1, 1, 1, 1))

# Forward pass through the base (un-ensembled) model.
cnn_out = model(x)
print(f"base cnn:\n{cnn_out}", "ensemble:", sep="\n")

# Feed the same input to every ensemble member; members are reported 1-based.
for i, member in enumerate(cnn_ensemble):
    member_out = member(x)
    print(f"member {i + 1}:\n{member_out}")

We can see that the outputs from the `base cnn` and the `ensemble` are different. The outputs of the ensemble members also differ from one another as a result of their independent initializations.

## Flax nnx block architecture

Let's have a look at how this ensemble functionality behaves for nested models:

In [None]:
class MixedBlock(nnx.Module):
    """Block mixing a dense branch, an optional conv branch and one GRU step.

    The dense branch (linear -> layer norm -> GELU -> dropout) always runs.
    The convolution branch only runs for 4-D inputs. The sum of both branches
    is passed through a single GRU step with a zero initial carry, and the GRU
    output is added back onto that sum.
    """

    def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int) -> None:
        """Create the sub-layers.

        Args:
            rngs: Random number generators used to initialize every sub-layer.
            in_features: Width of the (flattened) input.
            out_features: Width of the block's output.
        """
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nnx.Linear(rngs=rngs, in_features=self.in_features, out_features=self.out_features)
        self.layernorm = nnx.LayerNorm(rngs=rngs, num_features=out_features)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)

        # Convolution branch
        self.conv = nnx.Conv(
            rngs=rngs, in_features=self.in_features, out_features=self.out_features, kernel_size=(3, 3), padding="SAME"
        )

        # Recurrent branch. The GRU is fed the combined branch output, which is
        # `out_features` wide, so its input width must be `out_features`
        # (sizing it with `in_features` breaks whenever the two differ).
        self.gru = nnx.GRUCell(rngs=rngs, in_features=self.out_features, hidden_features=self.out_features)

    def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None) -> jax.Array:
        """Run the block.

        Args:
            x: Input batch; the leading axis is the batch axis. Inputs with
                more than two axes are flattened for the dense branch.
            rngs: Optional random number generators forwarded to the dropout
                layer.

        Returns:
            The combined branch output plus the GRU output.
        """
        # flatten input for Linear if needed
        h = x
        if h.ndim > 2:
            h = h.reshape(h.shape[0], -1)

        # dense -> norm -> activation -> dropout
        h1 = self.linear(h)
        h1 = self.layernorm(h1)
        h1 = jax.nn.gelu(h1)
        h1 = self.dropout(h1, rngs=rngs)

        # conv -> activation (4-D inputs only)
        # NOTE(review): the flattened conv output is added to the dense output,
        # so the shapes only line up when the conv output has a single spatial
        # position — confirm the intended input shapes.
        if x.ndim == 4:
            h2 = jax.nn.gelu(self.conv(x))
            h2 = h2.reshape(h2.shape[0], -1)
        else:
            h2 = 0

        # combination
        h = h1 + h2

        # recurrent update: a single step starting from an all-zero carry
        batch_size = x.shape[0]
        carry = jnp.zeros((batch_size, h.shape[-1]))
        # The updated carry is not needed after this single step.
        _, h_rec = self.gru(carry, h)

        return h + h_rec


class NestedFlaxModel(nnx.Module):
    """Nested model: a ``MixedBlock`` followed by self-attention and an MLP.

    The block output is fed both to a multi-head self-attention layer (as a
    length-1 sequence) and to an MLP; the two results are summed.
    """

    def __init__(self, rngs: nnx.Rngs, in_features: int = 80, out_features: int = 80) -> None:
        """Create the sub-modules.

        Args:
            rngs: Random number generators used to initialize every sub-module.
            in_features: Width of the flattened input.
            out_features: Width of the model's output.
        """
        # multiple nested blocks
        self.block = MixedBlock(rngs=rngs, in_features=in_features, out_features=out_features)

        # Attention block. It consumes the block output, which is
        # `out_features` wide, so its input width must be `out_features`.
        # NOTE(review): some Flax versions require an explicit `decode` flag
        # for MultiHeadAttention — confirm against the installed version.
        self.attention = nnx.MultiHeadAttention(
            rngs=rngs,
            num_heads=4,
            in_features=out_features,
            qkv_features=out_features,
            out_features=out_features,
        )

        # Residual MLP. It also consumes the block output, so the first layer
        # maps from `out_features`.
        self.mlp = nnx.Sequential(
            nnx.Linear(rngs=rngs, in_features=out_features, out_features=out_features),
            nnx.LayerNorm(rngs=rngs, num_features=out_features),
            jax.nn.relu,
            nnx.Linear(rngs=rngs, in_features=out_features, out_features=out_features),
        )

    def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None) -> jax.Array:
        """Run the forward pass.

        Args:
            x: Input batch; everything after the leading batch axis is
                flattened.
            rngs: Optional random number generators forwarded to the block.

        Returns:
            The sum of the MLP output and the attention output.
        """
        batch = x.shape[0]
        x_flat = x.reshape(batch, -1)

        x_flat = self.block(x_flat, rngs=rngs)

        # self-attention over a length-1 sequence: add, then drop, a sequence axis
        x_att = x_flat[:, None, :]
        x_att = self.attention(x_att, x_att, x_att)
        x_att = x_att.squeeze(1)

        # MLP residual
        x_mlp = self.mlp(x_flat)

        return x_mlp + x_att


# Build the nested model from a fixed seed (1), using 100 features instead of
# the constructor default of 80.
rngs = nnx.Rngs(1)
nested_flax_model = NestedFlaxModel(rngs=rngs, in_features=100, out_features=100)
# Visualize the nested module and its parameters.
nnx.display(nested_flax_model)

In [None]:
from probly.transformation.ensemble import ensemble

# Build a two-member ensemble from the nested model, requesting a parameter
# reset for the members and passing 2 as the key.
flax_ensemble = ensemble(nested_flax_model, num_members=2, reset_params=True, key=2)
# Visualize the resulting ensemble.
nnx.display(flax_ensemble)

`NestedFlaxModel.__init__` takes the parameters `in_features` and `out_features`, both with default `80`, whereas the model above was built with `100`.

With the `iter_children()` logic in the `reset_traverser`, the `NestedFlaxModel` itself is not called and is therefore not reinitialized. This should be a neat way to reset Flax models.