In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class TemporalRCNN(nn.Module):
    def __init__(self, input_channels, hidden_dim, map_h, map_w):
        super(TemporalRCNN, self).__init__()

        # Spatial feature extraction layers
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.gap = nn.AdaptiveAvgPool2d((map_h, map_w))  # Reduce to [8x8]

        # LSTM for temporal dependencies
        self.lstm = nn.LSTM(input_size=64 * map_h * map_w, hidden_size=hidden_dim, batch_first=True)

        # Project LSTM output to spatial probability maps
        self.fc = nn.Linear(hidden_dim, map_h * map_w)  # Project to spatial map
        self.conv3 = nn.Conv2d(1, 1, kernel_size=1)  # Optional: Smooth output
        
        self.map_h = map_h
        self.map_w = map_w

    def forward(self, x, lengths):
        batch_size, seq_len, channels, height, width = x.size()  # [B, T, C, H, W]

        # Process each frame independently with convolutional layers
        x = x.view(batch_size * seq_len, channels, height, width)  # [B*T, C, H, W]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))

        # Apply adaptive pooling to reduce spatial size
        x = self.gap(x)  # [B*T, 64, 8, 8]
        x = x.view(batch_size, seq_len, -1)  # [B, T, 64*8*8]

        # Pack the sequences to handle variable lengths
        x_packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        # LSTM processes temporal information
        lstm_out, _ = self.lstm(x_packed)  # [B, T, hidden_dim]

        # Unpack the sequences
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)  # [B, T, hidden_dim]

        # Project LSTM output to spatial dimensions using the fully connected layer
        lstm_out = self.fc(lstm_out)  # [B, T, 8*8]
        lstm_out = lstm_out.view(batch_size * seq_len, 1, self.map_h, self.map_w)  # [B*T, 1, 8, 8]

        # Optionally apply a 1x1 convolution to smooth the output
        prob_map = torch.sigmoid(self.conv3(lstm_out))  # [B*T, 1, 8, 8]

        # Reshape back to sequence format
        prob_map = prob_map.view(batch_size, seq_len, self.map_h, self.map_w)  # [B, T, 8, 8]

        return prob_map

# Example usage
if __name__ == "__main__":
    # Example input: batch of 4 sequences, each with 5 frames, 3 channels, and 32x32 resolution
    input_data = torch.randn(4, 5, 3, 32, 32)  # [B, T, C, H, W]

    # Sequence lengths (for variable-length sequences)
    lengths = torch.tensor([5, 4, 3, 2])  # Different sequence lengths for each batch sample

    # Initialize the model
    model = TemporalRCNN(input_channels=3, hidden_dim=128, map_h=16, map_w=16)

    # Move model and data to the appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    input_data = input_data.to(device).half()  # Half-precision

    # Forward pass with mixed precision
    with torch.cuda.amp.autocast():
        output = model(input_data, lengths)

    print("Output shape:", output)  # Should be [batch_size, seq_len, 8, 8]


Output shape: tensor([[[[0.2693, 0.2603, 0.2430,  ..., 0.2954, 0.2771, 0.2744],
          [0.2708, 0.2800, 0.2754,  ..., 0.2888, 0.2600, 0.3020],
          [0.2605, 0.2969, 0.2646,  ..., 0.2795, 0.2428, 0.2563],
          ...,
          [0.2335, 0.2825, 0.2656,  ..., 0.2800, 0.2681, 0.2859],
          [0.2485, 0.2739, 0.2502,  ..., 0.2815, 0.2856, 0.2834],
          [0.2771, 0.2703, 0.2969,  ..., 0.2595, 0.2307, 0.3025]],

         [[0.2820, 0.2581, 0.2294,  ..., 0.3047, 0.3032, 0.3018],
          [0.2612, 0.2898, 0.2864,  ..., 0.2529, 0.2433, 0.3301],
          [0.2659, 0.3005, 0.2888,  ..., 0.2681, 0.2314, 0.2698],
          ...,
          [0.2781, 0.2952, 0.2834,  ..., 0.2705, 0.2649, 0.2815],
          [0.2778, 0.2791, 0.2485,  ..., 0.2778, 0.2861, 0.2766],
          [0.2830, 0.2534, 0.3054,  ..., 0.2678, 0.2573, 0.2991]],

         [[0.2747, 0.2588, 0.2520,  ..., 0.3196, 0.3069, 0.2991],
          [0.2747, 0.2791, 0.2483,  ..., 0.2456, 0.2281, 0.3013],
          [0.2847, 0.3276, 0

In [17]:
criterion = nn.BCELoss()  # Binary cross-entropy loss
# Example target: [batch_size, seq_len, 8, 8]
target = torch.randint(0, 2, (4, 5, 16, 16)).float().to(device).half()  # Binary labels

# Forward pass to get predictions
with torch.cuda.amp.autocast():
    output = model(input_data, lengths)  # [B, T, 8, 8]

# Compute BCE loss
loss = criterion(output, target) # Half-precision
print("Loss:", loss.item())


Loss: 0.8076171875
