In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvLSTMCell(nn.Module):
    """Single-step convolutional LSTM cell.

    One ``nn.Conv2d`` is applied to the channel-wise concatenation of the
    current input and the previous hidden state. It produces all four LSTM
    gate pre-activations at once.

    The convolution uses symmetric ``k // 2`` padding, so ``h`` and ``c`` keep
    the input's spatial size. The returned states can therefore be fed
    straight back in on the next time step.

    Args:
        input_channels: Number of channels of ``input_tensor``.
        hidden_channels: Number of channels of the hidden state ``h`` and
            the cell state ``c``.
        kernel_size: Gate-convolution kernel size, either an int or a
            ``(kernel_h, kernel_w)`` pair. Every dimension must be odd. An
            even size would change the feature-map size (H -> H + 1) and break
            the recurrence.
        bias: Whether the gate convolution learns an additive bias
            (default ``True``, the previous behaviour).

    Raises:
        ValueError: If any kernel dimension is even.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):
        super().__init__()

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size

        is_pair = isinstance(kernel_size, (tuple, list))
        kernel_dims = tuple(kernel_size) if is_pair else (kernel_size,)
        if any(k % 2 == 0 for k in kernel_dims):
            raise ValueError(
                "kernel_size must be odd in every dimension to preserve the "
                f"spatial size of the hidden state, got {kernel_size!r}"
            )
        # Keep the historical int form for int kernels; use per-dimension padding otherwise.
        self.padding = tuple(k // 2 for k in kernel_dims) if is_pair else kernel_size // 2

        # A single convolution yields all four gates: 4 * hidden_channels output channels.
        self.conv = nn.Conv2d(in_channels=input_channels + hidden_channels,
                              out_channels=4 * hidden_channels,
                              kernel_size=kernel_size, padding=self.padding,
                              bias=bias)

    def forward(self, input_tensor, hidden_state=None):
        """Advance the cell by one time step.

        Args:
            input_tensor: Input of shape ``(B, input_channels, H, W)``.
            hidden_state: Tuple ``(h_cur, c_cur)``, each of shape
                ``(B, hidden_channels, H, W)``, e.g. from ``init_hidden``.
                If ``None``, zero states matching ``input_tensor`` are used.

        Returns:
            Tuple ``(h_next, c_next)``, each of shape ``(B, hidden_channels, H, W)``.
        """
        if hidden_state is None:
            hidden_state = self.init_hidden(input_tensor.size(0), input_tensor.shape[-2:])
        h_cur, c_cur = hidden_state
        # Concatenate along the channel dimension so one conv sees both input and memory.
        combined = torch.cat([input_tensor, h_cur], dim=1)

        combined_conv = self.conv(combined)
        # Channel blocks, in order: input gate, forget gate, output gate, candidate.
        cc_i, cc_f, cc_o, cc_c = torch.split(combined_conv, self.hidden_channels, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        c = f * c_cur + i * torch.tanh(cc_c)
        h = o * torch.tanh(c)

        return h, c

    def init_hidden(self, batch_size, image_size):
        """Return zero-initialised ``(h, c)`` states for a new sequence.

        Args:
            batch_size: Number of sequences processed in parallel.
            image_size: ``(height, width)`` of the feature maps.

        Returns:
            Tuple of two zero tensors of shape
            ``(batch_size, hidden_channels, height, width)``. They are created
            on the same device and with the same dtype as the conv weights, so
            the states still match after ``.to(...)``, ``.half()`` or ``.double()``.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_channels, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed torch so the dummy input and the ConvLSTM weight initialisation are
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# Define the data dimensions.
# NOTE: the leading dimension is consumed as a *time* axis below, not a batch
# axis. Each slice is one step of a single (batch size 1) sequence. The name
# `batch_size` is kept in case later cells refer to it.
batch_size = 15  # sequence length (number of time steps)
channels = 32
height = width = 128

# Instantiate ConvLSTM cell with matching input and hidden state channels
conv_lstm_cell = ConvLSTMCell(input_channels=channels, hidden_channels=channels, kernel_size=3)

# Dummy input of shape [15, 32, 128, 128] = [time, channels, height, width]
input_tensor = torch.randn(batch_size, channels, height, width)

# Zero initial (h, c) for a single sequence (batch size 1)
hidden_state = conv_lstm_cell.init_hidden(1, (height, width))

# Unroll the cell over time, feeding each step's (h, c) into the next step
output_sequence = []
for t in range(batch_size):
    step_input = input_tensor[t].unsqueeze(0)            # [C, H, W] -> [1, C, H, W]
    hidden_state = conv_lstm_cell(step_input, hidden_state)
    output_sequence.append(hidden_state[0].squeeze(0))  # keep h, drop the batch dim

# Stack the per-step hidden states into a [time, channels, height, width] tensor
output_tensor = torch.stack(output_sequence)

# hidden_channels == input channels, so the output shape must match the input
assert output_tensor.shape == input_tensor.shape, output_tensor.shape


In [3]:
# Display the stacked hidden-state shape: [time, channels, height, width]
output_tensor.size()

torch.Size([15, 32, 128, 128])