In [3]:
import torch
import torch.nn as nn

In [11]:
class ResidualBlock(nn.Module):
    """Fully connected residual block computing ``x + f(x)``.

    ``f`` is a stack of ``nn.Linear`` layers whose widths are given by
    ``units`` (layer ``i`` maps ``units[i] -> units[i + 1]``). ``activation``
    is applied after *every* layer, including the last one. The branch output
    is therefore already activated when it is added to the input; with
    ``nn.ReLU()`` this means the block can only add non-negative values.

    Args:
        units: Layer widths ``[in, hidden..., out]``. Must contain at least
            two entries, and the first and last must be equal so the branch
            output can be added element-wise to the input.
        activation: Callable applied after every linear layer
            (e.g. ``nn.ReLU()``). The same object is reused for all layers.
        name: Stored as ``self._name``; not otherwise used by this class.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``units`` has fewer than two entries, or its first and
            last entries differ.
    """

    def __init__(self, units, activation, name="residual_block", **kwargs):
        super().__init__()
        # Copy so later mutation of the caller's sequence cannot desync
        # self._units from the layers built below.
        units = list(units)
        # Explicit checks instead of `assert`, which is removed under `python -O`.
        if len(units) < 2:
            raise ValueError(
                f"units must contain at least two sizes, got {units!r}"
            )
        if units[0] != units[-1]:
            raise ValueError(
                f"First and last units must be the same, got {units[0]} and {units[-1]}"
            )

        self._units = units
        self._name = name
        self._activation = activation

        # One Linear per consecutive pair of widths.
        self._layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(units[:-1], units[1:])]
        )

    def forward(self, inputs):
        """Return ``inputs + f(inputs)``.

        ``nn.Linear`` acts on the last dimension, so the last dimension of
        ``inputs`` must equal ``units[0]``. The output has the same shape as
        ``inputs``.
        """
        hidden = inputs
        for layer in self._layers:
            hidden = self._activation(layer(hidden))
        # Skip connection: add the untouched input back onto the branch output.
        return inputs + hidden

In [13]:
# Seed the global torch RNG so weight init (and any later random draws in this
# notebook) are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
resnet = ResidualBlock([1, 2, 3, 1], nn.ReLU())

In [14]:
# 10 samples with 1 feature each; the feature width must match units[0] of the block.
inputs = torch.randn(size=(10, 1))

In [15]:
# Bind the result to a new name so `inputs` stays intact; re-running this cell
# then gives the same result instead of feeding the block its own output.
outputs = resnet(inputs)
assert outputs.shape == inputs.shape  # the residual add preserves shape
outputs.shape