In [14]:
import torch
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt

In [96]:
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor


class ConvLSTMCell(nn.Module):
    """
    Single convolutional LSTM cell.

    All four gates are computed by one 2D convolution applied to the
    channel-wise concatenation of the input and the previous hidden state.
    """

    def __init__(
        self,
        input_channel: int,
        hidden_channel: int,
        kernel_size: int,
        bias=True,
        activation=torch.tanh,
        batchnorm=False,
    ):
        """
        ConvLSTM Cell

        Args:
            input_channel: Number of input channels
            hidden_channel: Number of hidden channels
            kernel_size: Side length of the square kernel. Must be odd: the
                convolution pads by ``kernel_size // 2`` on each side, which
                only preserves the spatial size for odd kernels
            bias: Whether to add bias
            activation: Activation applied to the candidate gate and to the
                new cell state
            batchnorm: Stored on the instance; this cell does not apply any
                batch normalization itself

        Raises:
            ValueError: If ``kernel_size`` is even
        """
        super().__init__()

        if kernel_size % 2 == 0:
            # With padding k // 2 an even kernel yields H+1 x W+1 outputs, so
            # the gates could no longer be combined with the previous state.
            raise ValueError(
                f"kernel_size must be odd to preserve the spatial size, got {kernel_size}"
            )

        self.input_channel = input_channel
        self.hidden_channel = hidden_channel

        self.kernel_size = kernel_size
        self.bias = bias
        self.activation = activation
        self.batchnorm = batchnorm

        half_kernel = self.kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=self.input_channel + self.hidden_channel,
            out_channels=4 * self.hidden_channel,  # i, f, o, g gates stacked on the channel axis
            kernel_size=self.kernel_size,
            padding=(half_kernel, half_kernel),
            bias=self.bias,
            padding_mode="replicate",
        )

        self.reset_parameters()

    def forward(self, x: torch.Tensor, prev_state: list) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute forward pass

        Args:
            x: Input tensor of [Batch, Channel, Height, Width]
            prev_state: Pair ``(h_prev, c_prev)`` holding the previous hidden
                and cell state, each of [Batch, hidden_channel, Height, Width]

        Returns:
            The new hidden state and the new cell state
        """
        h_prev, c_prev = prev_state

        combined = torch.cat((x, h_prev), dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)

        # The convolution output holds the four gate pre-activations back to back.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_channel, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = self.activation(cc_g)

        c_cur = f * c_prev + i * g
        h_cur = o * self.activation(c_cur)

        return h_cur, c_cur

    def init_hidden(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Initializes the hidden state

        Args:
            x: Input tensor of [Batch, Channel, Height, Width] to initialize for

        Returns:
            Tuple ``(h, c)`` of zero tensors of [Batch, hidden_channel, Height, Width]
            with the same dtype and device as ``x``
        """
        batch, _, height, width = x.size()
        shape = (batch, self.hidden_channel, height, width)
        # Allocate directly with the input's dtype/device instead of creating
        # default tensors first and converting them afterwards.
        return (
            torch.zeros(shape, dtype=x.dtype, device=x.device),
            torch.zeros(shape, dtype=x.dtype, device=x.device),
        )

    def reset_parameters(self) -> None:
        """Resets parameters: Xavier-uniform weights (tanh gain) and zero bias"""
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain("tanh"))
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)
        
        
class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over the time axis of a 5D input."""

    def __init__(
        self, 
        input_channels: int,
        hidden_channels: int,
        kernel_size: int, 
        num_layers: int, 
        bias: bool = False,
        activation = torch.tanh,
        return_state_history: bool = False
    ):
        """
        Build the stacked ConvLSTM

        Args:
            input_channels: Number of channels of the input frames
            hidden_channels: Number of hidden channels of every layer
            kernel_size: Kernel size used by every cell
            num_layers: Number of stacked cells, at least 1
            bias: Whether the cell convolutions add a bias
            activation: Activation passed to every cell
            return_state_history: If True, ``forward`` also returns the final
                ``(h, c)`` pair of every layer; otherwise it returns ``None``
                in that position

        Raises:
            ValueError: If ``num_layers`` is smaller than 1
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.activation = activation
        self.num_layers = num_layers
        self.bias = bias
        self.return_state_history = return_state_history

        # Only the first layer sees the raw input; deeper layers consume the
        # hidden states produced by the layer below.
        self.cells = nn.ModuleList(
            [
                ConvLSTMCell(
                    input_channel=self.input_channels if layer_idx == 0 else self.hidden_channels,
                    hidden_channel=self.hidden_channels,
                    kernel_size=self.kernel_size,
                    bias=self.bias,
                    activation=activation,
                )
                for layer_idx in range(self.num_layers)
            ]
        )

    def _initial_states(self, first_frame: torch.Tensor, hidden_state: Optional[list]) -> list:
        """Return one ``(h, c)`` start state per layer from the user-supplied value."""
        if hidden_state is None:
            return [cell.init_hidden(first_frame) for cell in self.cells]

        # A bare (h, c) pair of tensors is accepted and used as the start
        # state of every layer (all layers share the same hidden shape).
        if len(hidden_state) == 2 and all(isinstance(s, torch.Tensor) for s in hidden_state):
            return [(hidden_state[0], hidden_state[1]) for _ in self.cells]

        states = list(hidden_state)
        if len(states) != len(self.cells):
            raise ValueError(
                f"expected {len(self.cells)} per-layer (h, c) states, got {len(states)}"
            )
        return states

    def forward(
        self, x: torch.Tensor, hidden_state: Optional[list] = None
    ) -> tuple[Tensor, Optional[list[tuple[Tensor, Tensor]]]]:
        """
        Computes the output of the ConvLSTM

        Args:
            x: Input Tensor of shape [Batch, Time, Channel, Height, Width]
            hidden_state: Initial states. Either a list with one ``(h, c)``
                pair per layer (the format of the returned state list), or a
                single ``(h, c)`` pair of tensors that is used for every
                layer. If none passed, zero states are generated

        Returns:
            The hidden states of the last layer stacked along the time axis,
            and the list of each layer's final ``(h, c)`` pair (``None``
            unless ``return_state_history`` is set)

        Raises:
            ValueError: If ``x`` is not 5D, has no time steps, or the number
                of supplied per-layer states does not match the number of layers
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5D input [Batch, Time, Channel, Height, Width], got {x.dim()}D"
            )

        layer_input = torch.unbind(x, dim=1)  # T x (B, C, H, W)
        if len(layer_input) == 0:
            raise ValueError("input must contain at least one time step")

        initial_states = self._initial_states(layer_input[0], hidden_state)
        last_state_list = [] if self.return_state_history else None

        # Every layer starts from its own initial state; the state reached at
        # the end of one layer must not seed the next layer.
        for cell, (h, c) in zip(self.cells, initial_states):
            layer_output = []
            for frame in layer_input:
                h, c = cell(x=frame, prev_state=[h, c])
                layer_output.append(h)

            layer_input = layer_output

            if last_state_list is not None:
                last_state_list.append((h, c))

        return torch.stack(layer_input, dim=1), last_state_list

In [103]:
# Random demo batch laid out as [Batch, Time, Channel, Height, Width]
x = torch.rand(8, 10, 16, 256, 256)

# Three stacked ConvLSTM layers mapping 16 input channels to 8 hidden channels with 3x3 kernels
model = ConvLSTM(input_channels=16, hidden_channels=8, kernel_size=3, num_layers=3)

In [104]:
# Run the forward pass; the second value is None unless the model was built
# with return_state_history=True.
layer_output, last_state_list = model(x)

In [105]:
# Shape of the per-time-step outputs that the model stacked along dim=1
layer_output.shape

torch.Size([8, 10, 8, 256, 256])

In [94]:
# Fresh random batch, split along dim=1 into a tuple of per-step tensors
x = torch.rand(8, 10, 16, 256, 256)
cur_layer_input = x.unbind(dim=1)

In [83]:
# Shape of the random input tensor
x.shape

torch.Size([8, 10, 16, 256, 256])

In [84]:
# Number of tensors obtained by unbinding x along dim=1
len(cur_layer_input)

10

In [85]:
# Shape of the first unbound tensor
cur_layer_input[0].shape

torch.Size([8, 16, 256, 256])

In [70]:
# NOTE(review): the ValueError recorded below appears to be stale output from an
# earlier execution (count 70, lower than the class-definition cell's 96); the
# same call runs successfully in the cell that unpacks
# `layer_output, last_state_list`. Presumably this cell should be re-run or
# removed — confirm.
model(x)

ValueError: too many values to unpack (expected 4)