In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
    """Simple Neural Network class with two heads.

    Attributes:
        fc1: nn.Module, first fully connected layer
        fc2: nn.Module, second fully connected layer
        fc31: nn.Module, fully connected layer of first head
        fc32: nn.Module, fully connected layer of second head
    """

    def __init__(self) -> None:
        """Initialize an instance of the Net class."""
        super().__init__()
        self.fc1 = nn.Linear(1, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc31 = nn.Linear(32, 1)

        self.fc32 = nn.Linear(32, 1)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the neural network.

        Args:
            x: torch.Tensor, input data
        Returns:
            torch.Tensor, output data
        """
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        mu = self.fc31(x)
        sigma2 = F.softplus(self.fc32(x))
        x = torch.cat([mu, sigma2], dim=1)
        return x

In [3]:
import jax

jax.Array.__qualname__

'Array'

In [4]:
import jax.numpy as jnp

isinstance(jnp.array([1, 2, 3]), jax.Array)

True

In [5]:
from probly.transformation import dropout

net = Net()
drop_net = dropout(net, p=0.1)

In [6]:
net

Net(
  (fc1): Linear(in_features=1, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=32, bias=True)
  (fc31): Linear(in_features=32, out_features=1, bias=True)
  (fc32): Linear(in_features=32, out_features=1, bias=True)
  (act): ReLU()
)

In [7]:
drop_net

Net(
  (fc1): Linear(in_features=1, out_features=32, bias=True)
  (fc2): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=32, bias=True)
  )
  (fc31): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=1, bias=True)
  )
  (fc32): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=32, out_features=1, bias=True)
  )
  (act): ReLU()
)

In [8]:
drop_net.eval()
drop_net(torch.tensor([[1.0]]))

tensor([[-0.0048,  0.5399]], grad_fn=<CatBackward0>)

In [15]:
from probly.representation import Distribution

x = torch.rand((64, 1))
distribution = Distribution(drop_net)
outputs = distribution.predict(x, num_samples=20).tensor
print(outputs.shape)

torch.Size([64, 20, 2])


In [16]:
from probly.transformation import ensemble

ensemble = ensemble(Net(), n_members=5)

In [19]:
from probly.representation.sampling.sampler import EnsembleSampler

ensemble_sampler = EnsembleSampler(ensemble)
x = torch.rand((64, 1))
outputs = ensemble_sampler.sample(x).tensor
print(outputs.shape)

NotImplementedError: Module [ModuleList] is missing the required "forward" function