In [6]:
import torch
import torch.nn as nn
from torchvision import models

In [8]:
# Inspect the stock MobileNetV2 architecture.
# `pretrained=True` is deprecated since torchvision 0.13 (it emits a UserWarning);
# IMAGENET1K_V1 is exactly the checkpoint `pretrained=True` used to load.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model



MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [1]:
import torch
import torch.nn as nn
from torchvision import models

class MobileNetV2_LSTM(nn.Module):
    """Clip classifier: per-frame MobileNetV2 features aggregated over time by an LSTM.

    Every frame is encoded independently by the MobileNetV2 convolutional trunk and
    globally average-pooled to a feature vector. The resulting per-clip sequence is
    fed to an LSTM, and the final hidden state of the last LSTM layer (both
    directions concatenated when bidirectional) is mapped to class logits.

    Args:
        num_classes: number of output classes (size of the logits dimension).
        hidden_dim: hidden size of each LSTM direction.
        lstm_layers: number of stacked LSTM layers.
        bidirectional: if True, use a bidirectional LSTM; the classifier input
            then has ``2 * hidden_dim`` features.
        pretrained: if True (default, the original behaviour), initialise the
            backbone from torchvision's ImageNet ``IMAGENET1K_V1`` weights;
            otherwise use random initialisation.
    """

    def __init__(self, num_classes, hidden_dim=256, lstm_layers=1, bidirectional=False,
                 pretrained=True):
        super().__init__()

        # IMAGENET1K_V1 is the checkpoint the deprecated `pretrained=True` loaded.
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)
        self.feature_extractor = backbone.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Per-frame feature size comes from the backbone (1280 for width_mult=1.0)
        # rather than a hard-coded constant.
        feature_dim = backbone.last_channel
        self.lstm = nn.LSTM(feature_dim, hidden_dim, lstm_layers,
                            batch_first=True, bidirectional=bidirectional)

        direction = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * direction, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: tensor of shape ``(batch, timesteps, C, H, W)``; need not be contiguous.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (batch, timesteps, C, H, W), got {tuple(x.shape)}"
            )
        batch_size, timesteps, C, H, W = x.shape

        # Fold time into the batch so the CNN sees independent frames.
        # `reshape` (unlike `view`) also works on non-contiguous inputs.
        frames = x.reshape(batch_size * timesteps, C, H, W)
        feats = self.pool(self.feature_extractor(frames))
        feats = feats.reshape(batch_size, timesteps, -1)

        _, (h_n, _) = self.lstm(feats)

        # h_n is (num_layers * num_directions, batch, hidden_dim). For a
        # bidirectional LSTM, lstm_out[:, -1] would hold a backward state that has
        # seen only the last frame, so take each direction's final state from the
        # top layer instead. Unidirectional: h_n[-1] == lstm_out[:, -1, :].
        if self.lstm.bidirectional:
            last_hidden = torch.cat((h_n[-2], h_n[-1]), dim=1)
        else:
            last_hidden = h_n[-1]

        return self.fc(last_hidden)

# Example usage: one forward pass on random clips to sanity-check the output shape.
torch.manual_seed(0)
NUM_CLASSES = 10
BATCH_SIZE, NUM_FRAMES = 8, 16  # 8 clips, each 16 RGB frames of 224x224

model = MobileNetV2_LSTM(num_classes=NUM_CLASSES)
# eval(): in train mode, BatchNorm would overwrite the pretrained running stats
# with statistics of random noise. no_grad(): skip the autograd graph for 128 frames.
model.eval()
input_tensor = torch.randn(BATCH_SIZE, NUM_FRAMES, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)
assert output.shape == (BATCH_SIZE, NUM_CLASSES), output.shape
output.shape


  Referenced from: <85A36C65-3F71-3C3B-B529-961AE17DBE73> /Users/szaboreka/anaconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


torch.Size([8, 10])
