# An Introduction to Subensembles in Probly

The goal of this notebook is to showcase the subensemble transformation and how to use it.

## Different Examples from Different Libraries

The supported libraries are Flax nnx and Torch.<br>

We define a simple model in each library to showcase the subensemble transformation.






In [None]:
import torch
from torch import nn


class TorchModel(nn.Module):
    """Five-layer fully connected PyTorch network (784 -> 64 -> 32 -> 16 -> 8 -> 10).

    ReLU is applied after every linear layer except the last one, so the
    network returns the raw output of ``linear5``.
    """

    def __init__(self) -> None:
        """Create the five linear layers, from widest to narrowest."""
        super().__init__()
        # Layers are created in order linear1..linear5; each one's output
        # width is the next one's input width.
        self.linear1 = nn.Linear(784, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 16)
        self.linear4 = nn.Linear(16, 8)
        self.linear5 = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the network and return the output of the last layer.

        Args:
            x: Input tensor whose last dimension has size 784.

        Returns:
            Tensor whose last dimension has size 10 (no activation applied).
        """
        hidden_layers = (self.linear1, self.linear2, self.linear3, self.linear4)
        for layer in hidden_layers:
            x = torch.relu(layer(x))
        # The final layer is deliberately left without an activation.
        return self.linear5(x)


# Instantiate the PyTorch model; `pytorch_model` is reused by the later
# subensemble cells. Printing a module shows its layer structure.
pytorch_model = TorchModel()
print(pytorch_model)

In [None]:
from flax import nnx
import jax


class FlaxModel(nnx.Module, pytree=False):
    """Five-layer fully connected Flax NNX network (784 -> 64 -> 32 -> 16 -> 8 -> 10).

    Mirrors ``TorchModel``: ReLU follows every linear layer except the last.

    NOTE(review): ``pytree=False`` is forwarded to ``nnx.Module``'s subclass
    hook; confirm its exact effect against the installed Flax version.
    """

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        """Create the five linear layers.

        Args:
            rngs: Random number generator container used to initialise the
                parameters of every layer.
        """
        super().__init__()
        self.linear1 = nnx.Linear(784, 64, rngs=rngs)
        self.linear2 = nnx.Linear(64, 32, rngs=rngs)
        self.linear3 = nnx.Linear(32, 16, rngs=rngs)
        self.linear4 = nnx.Linear(16, 8, rngs=rngs)
        self.linear5 = nnx.Linear(8, 10, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run ``x`` through the network and return the output of the last layer.

        Args:
            x: Input array whose last dimension has size 784.

        Returns:
            Array whose last dimension has size 10 (no activation applied).
        """
        # Activations are kept in the local `x`; the layer attributes must
        # never be reassigned here, otherwise the model would be destroyed
        # after the first call.
        x = nnx.relu(self.linear1(x))
        x = nnx.relu(self.linear2(x))
        x = nnx.relu(self.linear3(x))
        x = nnx.relu(self.linear4(x))
        return self.linear5(x)


# Seed the NNX random number generators with 0 and build the Flax model;
# `flax_model` is reused by the later subensemble cells.
rng = nnx.Rngs(0)
flax_model = FlaxModel(rngs=rng)
nnx.display(flax_model)

Now we have two simple models from two different libraries and can transform each of them into a subensemble.

## Down to Business

There are different ways to build a subensemble: we can choose how many layers each head consists of, or pass a model directly to be used as the head.<br>
In the following we show the different ways to use the subensemble transformation.

### 1. Default Example

We give the transformation a model and the number of heads, but do not pass any further arguments.

In [None]:
from probly.transformation.subensemble import subensemble

# Number of heads the subensemble should contain.
num_heads = 3

# Flax: only the model and the number of heads are passed, so the
# transformation falls back to its default head configuration.
subensemble_flax = subensemble(flax_model, num_heads=num_heads)
# Captions are plain text, so they go through print; nnx.display is kept for
# the model object itself.
print("Subensemble Flax Model:")
nnx.display(subensemble_flax)

# PyTorch: same call, same defaults.
subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads)
print("Subensemble PyTorch Model:")
print(subensemble_pytorch)

Note that, by default, only the last layer is used for the heads, as shown in the example above.

### 2. Head_Layer>1 Example

We give the transformation a model, the number of heads, and the number of layers for the heads.

In [None]:
from probly.transformation.subensemble import subensemble

# Number of heads the subensemble should contain.
num_heads = 3
# Number of trailing layers of the model that make up each head.
head_layer = 2

# Same transformation call for both libraries; only the model differs.
subensemble_flax = subensemble(flax_model, num_heads=num_heads, head_layer=head_layer)
nnx.display(subensemble_flax)
subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads, head_layer=head_layer)
print(subensemble_pytorch)

Note that the last two layers are used for each head when `head_layer=2` is passed explicitly.

### 3. Full-model-as-head Example

We give the transformation a model, number of heads and a model to be used as head.

#### 3.1 TorchModel to pass as a head.

In [None]:
import torch
from torch import nn


class TorchHeadModel(nn.Module):
    """Small PyTorch network (10 -> 8 -> 8 -> 10) intended to be passed as a head.

    The ReLU activations are stored as submodules (``relu1``, ``relu2``), so
    they appear in the printed module structure.
    """

    def __init__(self) -> None:
        """Create three linear layers with a ReLU module after the first two."""
        super().__init__()
        self.fc1 = nn.Linear(10, 8)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(8, 8)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stages in order and return the output of ``fc3``.

        Args:
            x: Input tensor whose last dimension has size 10.

        Returns:
            Tensor whose last dimension has size 10 (no final activation).
        """
        stages = (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3)
        for stage in stages:
            x = stage(x)
        return x


# Instantiate the head model; `torch_head_model` is passed as `head=` in the
# transformation cell of section 3.3. Printing shows its layer structure.
torch_head_model = TorchHeadModel()
print(torch_head_model)

#### 3.2 FlaxModel to pass as a head.

In [None]:
from flax import nnx
import jax


class FlaxHeadModel(nnx.Module, pytree=False):
    """Small Flax NNX network (10 -> 8 -> 8 -> 10) intended to be passed as a head.

    NOTE(review): ``pytree=False`` is forwarded to ``nnx.Module``'s subclass
    hook; confirm its exact effect against the installed Flax version.
    """

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        """Create the three linear layers of the head.

        Args:
            rngs: Random number generator container used to initialise the
                parameters of every layer.
        """
        super().__init__()
        self.linear1head = nnx.Linear(10, 8, rngs=rngs)
        self.linear2head = nnx.Linear(8, 8, rngs=rngs)
        self.linear3head = nnx.Linear(8, 10, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Flatten the input per sample and run it through the head.

        Args:
            x: Input array; everything after the first axis is flattened
                into a single feature axis before the first layer.

        Returns:
            Output of ``linear3head`` (no final activation).
        """
        # Keep the leading axis and collapse all remaining axes into one.
        leading = x.shape[0]
        flat = x.reshape((leading, -1))
        hidden = nnx.relu(self.linear1head(flat))
        hidden = nnx.relu(self.linear2head(hidden))
        return self.linear3head(hidden)


# Rebind `rng` to a fresh generator container seeded with 0 and build the
# Flax head model; `flax_head_model` is passed as `head=` in section 3.3.
rng = nnx.Rngs(0)
flax_head_model = FlaxHeadModel(rngs=rng)
nnx.display(flax_head_model)

#### 3.3 Transformation

In [None]:
from probly.transformation.subensemble import subensemble

# Number of heads the subensemble should contain.
num_heads = 3
# Aliases for the head models instantiated in sections 3.1 and 3.2.
head_flax = flax_head_model
head_torch = torch_head_model

# Instead of a layer count, a complete model is handed over via `head=`.
subensemble_flax = subensemble(flax_model, num_heads=num_heads, head=head_flax)
nnx.display(subensemble_flax)
subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads, head=head_torch)
print(subensemble_pytorch)

Note that the passed model is used as every head. Make sure its input and output dimensions fit the model it is attached to.

### 4. Params for the heads

The subensemble transformation supports resetting the parameters in the different heads.<br>
To enable resetting, pass `reset_params=True` to the transformation.

#### 4.1 Resetting parameters


In [None]:
import numpy as np

from probly.transformation.subensemble import subensemble

# Build subensembles with re-initialised head parameters and print the
# rounded weights of the first heads so they can be compared by eye.
num_heads = 3
reset_params = True

subensemble_flax = subensemble(flax_model, num_heads=num_heads, reset_params=reset_params)

# NOTE(review): only the first `num_heads - 1` heads are printed; confirm that
# leaving out the last head is intentional.
for i in range(num_heads - 1):
    last_layer = subensemble_flax[i].layers[-1]
    leaves = jax.tree_util.tree_leaves(last_layer)
    # presumably leaf 1 is the weight matrix of the layer; this relies on the
    # leaf ordering of the layer, verify against the Flax version in use
    rounded = np.round(np.array(leaves[1]), 2)
    print(f"\nFlaxHead {i} weights:")
    for row in rounded:
        print(row)


subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads, reset_params=reset_params)

for i in range(num_heads - 1):
    # Slice (not index) so that the result still exposes named_parameters().
    tail = subensemble_pytorch[i][-1:]
    print(f"\nPyTorch Head {i} weights:")
    for name, param in tail.named_parameters():
        if "weight" not in name:
            continue  # skip biases, only weights are compared
        print(np.round(param.data.cpu().numpy(), 2))

Note that the heads have different weight values.

#### 4.2 Not-Resetting Parameters



In [None]:
import numpy as np

from probly.transformation.subensemble import subensemble

# Build subensembles WITHOUT re-initialising the head parameters and print
# the rounded weights of the first heads so they can be compared by eye.
num_heads = 3
reset_params = False

subensemble_flax = subensemble(flax_model, num_heads=num_heads, reset_params=reset_params)


# NOTE(review): only the first `num_heads - 1` heads are printed; confirm that
# leaving out the last head is intentional.
for i in range(num_heads - 1):
    # Named `member` (as in the resetting example) rather than `head_layer`,
    # so the integer `head_layer` setting from section 2 is not overwritten
    # in the shared notebook namespace.
    member = subensemble_flax[i].layers[-1]
    params_heads = jax.tree_util.tree_leaves(member)
    # presumably leaf 1 is the weight matrix of the layer; this relies on the
    # leaf ordering of the layer, verify against the Flax version in use
    weights = np.array(params_heads[1])
    weights_rounded = np.round(weights, 2)

    print(f"\nFlax Head {i} weights:")
    for row in weights_rounded:
        print(row)

subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads, reset_params=reset_params)


for i in range(num_heads - 1):
    # Slice (not index) so that the result still exposes named_parameters().
    member_last_layer = subensemble_pytorch[i][-1:]
    print(f"\nPyTorch Head {i} weights:")
    for name, param in member_last_layer.named_parameters():
        if "weight" in name:
            w = param.data.cpu().numpy()
            print(np.round(w, 2))

Note that the heads have the same weight values.

### 5. Outputs of the Subensemble

Each head of the subensemble makes its own prediction, which makes the combined prediction more robust.

In [None]:
import jax
import jax.numpy as jnp

from probly.representation.sampling.sampler import Sampler

# Imported here as well so this cell does not depend on one of the earlier
# example cells having been run.
from probly.transformation.subensemble import subensemble

reset_params = True
num_heads = 3


subensemble_flax = subensemble(flax_model, num_heads=num_heads, reset_params=reset_params)

# Build ONE fixed random input batch up front. Every head receives exactly the
# same inputs, so differences in the printed means come from the heads alone.
# The feature width must match the first layer of FlaxModel (784 inputs).
batch_size = 8
in_features = 784
key = jax.random.PRNGKey(0)  # the key is consumed by a single draw only
data = jax.random.normal(key, (batch_size, in_features))

for i, mem in enumerate(subensemble_flax):
    sample = Sampler(mem)
    pred = sample.predict(data, num_samples=5)
    predmean = pred.mean()

    print(f"Prediction-Mean of Head {i}: ")
    # Show only a small corner of the result to keep the output readable.
    print(predmean[:3, :5])

In [None]:
import torch

from probly.representation.sampling.sampler import Sampler

# Imported here as well so this cell does not depend on one of the earlier
# example cells having been run.
from probly.transformation.subensemble import subensemble

reset_params = True
num_heads = 3


subensemble_pytorch = subensemble(pytorch_model, num_heads=num_heads, reset_params=reset_params)

# Build ONE fixed random input batch up front. Every head receives exactly the
# same inputs, so differences in the printed means come from the heads alone.
# The feature width must match the first layer of TorchModel (784 inputs).
batch_size = 8
in_features = 784
torch.manual_seed(5)  # seeded after the subensemble is built, as before
data = torch.randn(batch_size, in_features)

for i, mem in enumerate(subensemble_pytorch):
    sample = Sampler(mem)
    pred = sample.predict(data, num_samples=5)
    predmean = pred.mean()

    print(f"Prediction-Mean of Head {i}: ")
    # Show only a small corner of the result to keep the output readable.
    print(predmean[:3, :5])