In [22]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
from dataclasses import dataclass
from typing import Tuple, Union, overload


class MaskFactory:
    def full(self, kernel_size: Tuple[int, ...]) -> Tensor:
        return torch.ones(kernel_size)

    def causal(self, kernel_size: Tuple[int, ...]) -> Tensor:
        idx = torch.meshgrid([torch.arange(ki) for ki in kernel_size])
        idx = torch.cat(idx, -1)
        cutoff = torch.tensor([*kernel_size])
        return 2 * idx < cutoff


class MaskedConv(nn.Module):
    def __init__(self, conv: Union[nn.Conv1d, nn.Conv2d], mask: Tensor):
        super().__init__()
        self.mask = mask
        self.register_buffer("mask", self.mask)
        self.conv = conv
        assert self.conv.kernel_size == self.mask.shape

    def forward(self, x: Tensor) -> Tensor:
        self.conv.weight *= self.mask
        return self.conv(x)


class LSTMCell(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, bias=True):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = self.H = hidden_features
        self.bias = bias

        self.input_map = nn.Linear(in_features, 4 * hidden_features, bias=bias)
        self.state_map = nn.Linear(hidden_features, 4 * hidden_features, bias=bias)

    def forward(self, x: Tensor, h_c: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        h, c = h_c
        gates = self.state_map(h) + self.input_map(x)
        i, f, g, o = torch.split(gates, self.H, dim=1)
        next_c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        next_h = torch.sigmoid(o) * next_c
        return next_h, next_c


class LSTM(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, bias=True):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.bias = bias

        self.cell = LSTMCell(in_features, hidden_features, bias)

    def _forward_tensor(
        self, xs: Tensor, h_c: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # x.shape = [L, B, N_in]
        h, c = h_c  # [B, H]
        hs = []
        for x in xs:
            h, c = self.cell(x, (h, c))
            hs.append(h)
        hs = torch.stack(hs)
        return hs, (h, c)

    def _forward_packed(
        self, seq: rnn.PackedSequence, h_c: Tuple[Tensor, Tensor]
    ) -> Tuple[rnn.PackedSequence, Tuple[Tensor, Tensor]]:
        h, c = h_c  # [N, H]
        h, c = h[seq.sorted_indices], c[seq.sorted_indices]
        hs = torch.empty(len(seq.data), self.hidden_features)
        cur_offset = 0
        for step_batch in seq.batch_sizes:
            cur_slice = slice(cur_offset, cur_offset + step_batch)
            h[:step_batch], c[:step_batch] = self.cell(
                seq.data[cur_slice], 
                (h[:step_batch], c[:step_batch]),
            )
            hs[cur_slice] = h[:step_batch]
            cur_offset += step_batch
        h_seq = rnn.PackedSequence(
            hs, seq.batch_sizes, seq.sorted_indices, seq.unsorted_indices
        )
        h, c = h[seq.unsorted_indices], c[seq.unsorted_indices]
        return h_seq, (h, c)

    @overload
    def forward(
        self, seq: Tensor, h_c: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        ...

    @overload
    def forward(
        self, seq: rnn.PackedSequence, h_c: Tuple[Tensor, Tensor]
    ) -> Tuple[rnn.PackedSequence, Tuple[Tensor, Tensor]]:
        ...

    def forward(self, seq, h_c):
        if isinstance(seq, Tensor):
            return self._forward_tensor(seq, h_c)
        elif isinstance(seq, rnn.PackedSequence):
            return self._forward_packed(seq, h_c)


In [23]:
N = 8
gen = torch.Generator().manual_seed(0)
lengths = torch.randint(16, 32, (N,), generator=gen)
seq_list = [torch.rand((seq_len, 16), generator=gen) for seq_len in lengths]
seq = rnn.pack_sequence(seq_list, enforce_sorted=False)

lstm = LSTM(16, 32)
h, c = torch.rand((N, 32), generator=gen), torch.rand((N, 32), generator=gen)
lstm(seq, (h, c))

(PackedSequence(data=tensor([[ 0.2645, -0.0116,  0.2647,  ...,  0.2017,  0.1374,  0.1409],
         [ 0.2498,  0.2412,  0.2512,  ...,  0.2587,  0.2885,  0.0894],
         [ 0.3387,  0.0392,  0.1993,  ...,  0.1035,  0.1053,  0.2107],
         ...,
         [ 0.2644, -0.1400, -0.0008,  ...,  0.0713,  0.0903, -0.1108],
         [ 0.2506, -0.1405, -0.1005,  ...,  0.1730,  0.0544, -0.0716],
         [ 0.1099, -0.1045,  0.0009,  ...,  0.0914,  0.0951, -0.1583]],
        grad_fn=<CopySlices>), batch_sizes=tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 5, 5, 4, 4, 3,
         3, 3, 3, 2, 1, 1, 1]), sorted_indices=tensor([1, 0, 5, 7, 2, 4, 6, 3]), unsorted_indices=tensor([1, 0, 4, 7, 5, 2, 6, 3])),
 (tensor([[ 1.8957e-01, -6.5620e-03,  8.8885e-02,  1.5788e-01, -1.3660e-01,
            1.3060e-01, -3.2869e-02,  1.6255e-02,  8.9979e-02,  1.7170e-01,
            4.5223e-02, -2.2077e-01,  5.3314e-02,  2.5696e-01,  1.8950e-01,
            4.0378e-01, -9.6875e-02, -9.3238e-02, -1.32

In [10]:
cell = LSTMCell(16, 32)
x = torch.empty(1, 16)
h, c = torch.empty(1, 32), torch.empty(1, 32)
cell(x, (h, c))

(tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<MulBackward0>),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan, nan, nan, nan, nan, nan, nan]], grad_fn=<AddBackward0>))