In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# --------------------------
# CONFIGURATION
# --------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = 5000
embed_size = 256
hidden_size = 512
max_seq_len = 20
batch_size = 4
video_feat_dim = 1024  # Pretend video features vector size
frames = 16  # Number of frames for 3D CNN
height, width = 64, 64  # Frame spatial size for 3D CNN

# --------------------------
# MODELS
# --------------------------

class VideoCaptioningLSTM(nn.Module):
    def __init__(self, input_dim, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, video_feat, captions):
        video_feat = self.encoder(video_feat).unsqueeze(1)  # (B,1,embed)
        embeddings = self.embed(captions)  # (B, seq_len, embed)
        inputs = torch.cat((video_feat, embeddings), dim=1)  # prepend video feat
        outputs, _ = self.lstm(inputs)
        outputs = self.fc(outputs)
        return outputs

class BiLSTM(nn.Module):
    def __init__(self, input_dim, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, video_feat, captions):
        video_feat = self.encoder(video_feat).unsqueeze(1)
        embeddings = self.embed(captions)
        inputs = torch.cat((video_feat, embeddings), dim=1)
        outputs, _ = self.lstm(inputs)
        outputs = self.fc(outputs)
        return outputs

class VideoCaptioningGRU(nn.Module):
    def __init__(self, input_dim, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, video_feat, captions):
        video_feat = self.encoder(video_feat).unsqueeze(1)
        embeddings = self.embed(captions)
        inputs = torch.cat((video_feat, embeddings), dim=1)
        outputs, _ = self.gru(inputs)
        outputs = self.fc(outputs)
        return outputs

class BiGRU(nn.Module):
    def __init__(self, input_dim, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, video_feat, captions):
        video_feat = self.encoder(video_feat).unsqueeze(1)
        embeddings = self.embed(captions)
        inputs = torch.cat((video_feat, embeddings), dim=1)
        outputs, _ = self.gru(inputs)
        outputs = self.fc(outputs)
        return outputs

class TransformerCaptioning(nn.Module):
    def __init__(self, input_dim, embed_size, vocab_size, num_heads=4, num_layers=2, max_seq_len=20):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = nn.Parameter(torch.randn(1, max_seq_len + 1, embed_size))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, video_feat, captions):
        B = captions.size(0)
        video_feat = self.encoder(video_feat).unsqueeze(1)
        embeddings = self.embed(captions)
        x = torch.cat([video_feat, embeddings], dim=1)
        x = x + self.pos_encoder[:, :x.size(1), :]
        x = self.transformer(x)
        outputs = self.fc(x)
        return outputs

class CNN3D_LSTM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.cnn3d = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d((1,1,1))
        )
        self.fc_feat = nn.Linear(128, embed_size)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, video_clip, captions):
        B = video_clip.size(0)
        features = self.cnn3d(video_clip).view(B, -1)  # (B, 128)
        video_emb = self.fc_feat(features).unsqueeze(1)  # (B, 1, embed)
        embedded = self.embed(captions)  # (B, seq_len, embed)
        inputs = torch.cat((video_emb, embedded), dim=1)
        outputs, _ = self.lstm(inputs)
        outputs = self.fc(outputs)
        return outputs

# --------------------------
# HELPER: SELECT MODEL
# --------------------------
def get_model(model_name):
    if model_name == 'lstm':
        return VideoCaptioningLSTM(video_feat_dim, embed_size, hidden_size, vocab_size)
    elif model_name == 'bilstm':
        return BiLSTM(video_feat_dim, embed_size, hidden_size, vocab_size)
    elif model_name == 'gru':
        return VideoCaptioningGRU(video_feat_dim, embed_size, hidden_size, vocab_size)
    elif model_name == 'bigru':
        return BiGRU(video_feat_dim, embed_size, hidden_size, vocab_size)
    elif model_name == 'transformer':
        return TransformerCaptioning(video_feat_dim, embed_size, vocab_size, max_seq_len=max_seq_len)
    elif model_name == '3dcnn_lstm':
        return CNN3D_LSTM(vocab_size, embed_size, hidden_size)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

# --------------------------
# DUMMY DATA FOR TESTING
# --------------------------
# For non-3dcnn models: video features + captions
video_feats = torch.randn(batch_size, video_feat_dim).to(device)
captions = torch.randint(0, vocab_size, (batch_size, max_seq_len)).to(device)

# For 3D CNN model: raw video clip + captions
video_clips = torch.randn(batch_size, 3, frames, height, width).to(device)

# --------------------------
# TRAIN LOOP EXAMPLE
# --------------------------
def train(model_name='lstm', epochs=2):
    model = get_model(model_name).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    print(f"Training {model_name} model...")

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        if model_name == '3dcnn_lstm':
            outputs = model(video_clips, captions)
        else:
            outputs = model(video_feats, captions)

        # outputs shape: (batch, seq_len+1, vocab_size)
        # We want to predict next words, so shift captions by one for target
        targets = captions  # Simplified target
        outputs = outputs[:, 1:, :]  # remove first token prediction to match targets length

        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# --------------------------
# RUN TRAINING FOR ALL MODELS
# --------------------------
if __name__ == "__main__":
    for model_name in ['lstm', 'bilstm', 'gru', 'bigru', 'transformer', '3dcnn_lstm']:
        train(model_name)
        print("="*40)


Training lstm model...
Epoch 1, Loss: 8.5105
Epoch 2, Loss: 8.2992
Training bilstm model...
Epoch 1, Loss: 8.5178
Epoch 2, Loss: 8.1983
Training gru model...
Epoch 1, Loss: 8.5276
Epoch 2, Loss: 8.1397
Training bigru model...
Epoch 1, Loss: 8.5319
Epoch 2, Loss: 7.9329




Training transformer model...
Epoch 1, Loss: 8.6476
Epoch 2, Loss: 6.8612
Training 3dcnn_lstm model...
Epoch 1, Loss: 8.5180
Epoch 2, Loss: 8.3139
