In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import models
from torch.utils.data import DataLoader

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
class RNN_Model(nn.Module):
    """Sequence classifier: frozen ResNet-18 frame features -> LSTM -> linear head.

    Each frame of the input clip is passed through a pretrained ResNet-18 whose
    last child module has been removed, the per-frame feature vectors are fed to
    an LSTM, and the LSTM output at the final time step is projected to class
    scores.

    The ResNet backbone is a fixed feature extractor: it is evaluated under
    ``torch.no_grad()``, its parameters are marked ``requires_grad=False``, and
    it is kept in eval mode even when the rest of the model is in training mode.

    Args:
        num_classes: Number of output scores produced by the linear head.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size, num_layers):
        super(RNN_Model, self).__init__()
        # Load the pretrained ResNet-18 and drop its last child module so the
        # backbone emits features instead of class scores.
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.resnet = nn.Sequential(*(list(backbone.children())[:-1]))

        # forward() runs the backbone under torch.no_grad(), so it can never be
        # trained. Make that explicit: no gradients are allocated for it, and it
        # runs in eval mode so BatchNorm uses (and does not overwrite) its
        # pretrained running statistics.
        self.resnet.requires_grad_(False)
        self.resnet.eval()

        # RNN (LSTM) layer; 512 is the per-frame feature size it consumes.
        self.lstm = nn.LSTM(input_size=512, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)

        # Classification layer
        self.fc = nn.Linear(hidden_size, num_classes)

    def train(self, mode=True):
        """Set the training mode of the LSTM and head; the backbone stays in eval.

        ``torch.no_grad()`` alone does not stop BatchNorm layers from switching
        to per-batch statistics and updating their running mean/variance in
        training mode, so the backbone is forced back to eval mode after every
        mode change (``eval()`` also routes through this method).

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super(RNN_Model, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, x):
        """Compute class scores for a batch of frame sequences.

        Args:
            x: Tensor of shape ``[batch, time, channels, height, width]``.

        Returns:
            Tensor of shape ``[batch, num_classes]`` computed from the LSTM
            output at the last time step.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                "expected input of shape [batch, time, channels, height, width], "
                "got a tensor with {} dimension(s)".format(x.dim())
            )
        batch_size, timesteps, C, H, W = x.size()

        # Fold time into the batch axis so the 2-D backbone sees single frames.
        # reshape (unlike view) also accepts non-contiguous input, e.g. a clip
        # that was permuted or transposed upstream.
        frames = x.reshape(batch_size * timesteps, C, H, W)

        # Feature extraction through the frozen ResNet (no autograd graph).
        with torch.no_grad():
            features = self.resnet(frames)

        # Restore the time axis and flatten each frame's features for the LSTM.
        features = features.reshape(batch_size, timesteps, -1)

        # Sequence processing through LSTM
        lstm_out, _ = self.lstm(features)

        # Classify from the output at the final time step.
        return self.fc(lstm_out[:, -1, :])

# Model configuration
num_classes = 2    # number of output classes
hidden_size = 256  # size of the LSTM hidden state
num_layers = 2     # number of stacked LSTM layers

# Build the network and move it to the selected device
model = RNN_Model(
    num_classes=num_classes,
    hidden_size=hidden_size,
    num_layers=num_layers,
).to(device)

# Random stand-in clip laid out as (batch size, time steps, channels, height, width)
input_tensor = torch.randn(1, 5, 3, 224, 224).to(device)

# Push the clip through the model and show the resulting class scores
output = model(input_tensor)
print(output)


tensor([[-0.0217,  0.1010]], device='cuda:0', grad_fn=<AddmmBackward0>)
