In [None]:
from dataclasses import KW_ONLY
from typing import Optional

import torch
from torch import Tensor, nn

from tsdm.models.activations import ACTIVATIONS
from tsdm.utils.config import Config

In [None]:
class Dense(nn.Module):
    """A single fully-connected layer followed by an activation.

    Computes ``activation(linear(x))`` where ``linear`` is
    ``nn.Linear(input_size, output_size)`` (applied to the last dimension of ``x``).

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        activation: Either an ``nn.Module`` that is used as-is, or a name that is
            looked up in ``ACTIVATIONS`` and called with no arguments.
    """

    class HP(Config):
        """Hyperparameters for :class:`Dense`."""

        input_size: int
        output_size: int
        # Same default as ``Dense.__init__`` so both construction paths agree on
        # the activation name (ACTIVATIONS lookup is a case-sensitive key access).
        # NOTE(review): ``Config`` is allowed here, but ``__init__`` only handles
        # ``str`` and ``nn.Module`` -- confirm how a Config activation is meant
        # to be resolved.
        activation: str | nn.Module | Config = "ReLU"

    def __init__(
        self, input_size: int, output_size: int, activation: str | nn.Module = "ReLU"
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

        if isinstance(activation, str):
            # Resolve the name and instantiate with defaults
            # (presumably ACTIVATIONS maps names to nn.Module classes -- confirm).
            activation = ACTIVATIONS[activation]()
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map, then the activation."""
        return self.activation(self.linear(x))


# Smoke-test Dense on a random batch. Seed first so the initial weights and the
# input -- and therefore the displayed output -- are identical on every re-run.
torch.manual_seed(0)
model = Dense(3, 5)
x = torch.randn(16, 3)
out = model(x)
assert out.shape == (16, 5), f"unexpected output shape {tuple(out.shape)}"
out

In [None]:
class MLP(nn.Sequential):
    """Multi-layer perceptron built as a sequence of ``Linear`` / ``ReLU`` modules.

    Use :meth:`from_config` to build the layer stack from hyperparameters; the
    plain constructor just forwards already-built modules to ``nn.Sequential``.
    """

    class HP(Config):
        """Hyperparameters for :class:`MLP`."""

        input_size: int
        output_size: int
        _: KW_ONLY
        # Width of the latent (hidden) layers; None -> filled in by from_config.
        latent_size: Optional[int] = None
        # Number of extra latent -> latent layers between input and output layers.
        num_hidden: int = 0

    config: HP

    def __init__(self, *layers: nn.Module) -> None:
        super().__init__(*layers)

    @staticmethod
    def _linear(in_features: int, out_features: int, nonlinearity: str) -> nn.Linear:
        """Create an ``nn.Linear`` with Kaiming-normal init for weight and bias."""
        layer = nn.Linear(in_features, out_features)
        nn.init.kaiming_normal_(layer.weight, nonlinearity=nonlinearity)
        # The bias is 1-D, but kaiming_normal_ needs >= 2 dims to compute a fan,
        # so initialise it in place through a (1, out_features) view.
        nn.init.kaiming_normal_(layer.bias[None], nonlinearity=nonlinearity)
        return layer

    @classmethod
    def from_config(
        cls,
        *args,
        **kwargs,
    ) -> "MLP":
        """Build an MLP from hyperparameters.

        Positional and keyword arguments are forwarded to :class:`MLP.HP`.
        The resulting module sequence is::

            Linear(input_size, latent_size)
            [ReLU, Linear(latent_size, latent_size)] * num_hidden
            ReLU, Linear(latent_size, output_size)

        The config is attached to the returned module as ``config``.
        """
        config = cls.HP(*args, **kwargs)
        # Only fill in a default when the caller did not choose a latent size.
        # Previously this update ran unconditionally and silently replaced any
        # user-supplied latent_size with 1.
        # NOTE(review): a fallback width of 1 is a severe bottleneck; confirm the
        # intended default (e.g. input_size).
        if config.latent_size is None:
            config |= {"latent_size": 1}

        # input layer (no activation precedes it)
        layers: list[nn.Module] = [
            cls._linear(config.input_size, config.latent_size, "linear")
        ]

        # hidden layers: each preceded by a ReLU
        for _ in range(config.num_hidden):
            layers.append(nn.ReLU())
            layers.append(cls._linear(config.latent_size, config.latent_size, "relu"))

        # output layer
        layers.append(nn.ReLU())
        layers.append(cls._linear(config.latent_size, config.output_size, "relu"))

        module = cls(*layers)
        module.config = config
        return module

In [None]:
# Build a 64 -> 10 MLP with two 32-wide hidden layers and display its structure.
mlp = MLP.from_config(64, 10, latent_size=32, num_hidden=2)
mlp