In [None]:
import torch
import torch.nn as nn

class HybridModel(nn.Module):
    """CNN feature extractor followed by a single-step LSTM and a linear head.

    The convolutional stack keeps the spatial size in each convolution
    (3x3 kernel, padding 1) and halves it in each of the two 2x2 max-pools,
    so an ``H x W`` input becomes a ``64 x (H // 4) x (W // 4)`` feature map.
    That map is flattened and fed to the LSTM as a sequence of length one.

    Args:
        num_classes: Size of the output layer.
        input_channels: Number of channels in the input images.
        hidden_size: Hidden state size of the LSTM.
        image_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair. Defaults to 64,
            which reproduces the original fixed ``64 * 16 * 16`` LSTM input.

    Raises:
        ValueError: If ``image_size`` is too small to leave at least one
            pixel per dimension after the two pooling stages.
    """

    def __init__(
        self,
        num_classes: int,
        input_channels: int = 1,
        hidden_size: int = 128,
        image_size: "int | tuple[int, int]" = 64,
    ) -> None:
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # Two MaxPool2d(2, 2) stages each floor-divide the spatial size by 2.
        feat_h, feat_w = height // 4, width // 4
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small: each dimension must "
                "be at least 4 to survive two 2x2 poolings"
            )

        # CNN: extract features from the input images.
        # NOTE: layer order inside the Sequential is kept stable so that
        # state_dict keys (cnn.0.*, cnn.3.*) stay compatible with checkpoints.
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H/2 x W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H/4 x W/4
        )
        # RNN: process the flattened feature map as a single time step.
        self.rnn = nn.LSTM(
            input_size=64 * feat_h * feat_w,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        # Linear head: map the LSTM output to class scores.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch_size, input_channels, H, W)`` where
                ``(H, W)`` matches the ``image_size`` given at construction.

        Returns:
            Tensor of shape ``(batch_size, 1, num_classes)``. The singleton
            middle axis is the LSTM time dimension; it is kept so the output
            shape matches what existing callers receive.
        """
        features = self.cnn(x)  # (batch_size, 64, H // 4, W // 4)
        # Flatten everything but the batch axis, then add a length-1 time axis.
        sequence = features.flatten(start_dim=1).unsqueeze(1)
        rnn_out, _ = self.rnn(sequence)  # (batch_size, 1, hidden_size)
        return self.fc(rnn_out)  # (batch_size, 1, num_classes)