In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel
from sklearn.decomposition import PCA
import numpy as np

In [2]:
# ================= Temporal Event Proposal =================
class TemporalEventProposal(nn.Module):
    """Bidirectional-GRU head that emits two scores per time step.

    The two outputs per step are intended as start / end boundary scores. The GRU is
    built with ``batch_first=True``, so batched inputs are (batch, time, input_dim).
    """

    def __init__(self, input_dim):
        super().__init__()
        # 512 hidden units per direction; the two directions are concatenated (2 * 512).
        self.gru = nn.GRU(input_dim, 512, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * 512, 2)

    def forward(self, x):
        """Return per-step scores of shape (batch, time, 2) for batch-first input ``x``."""
        per_step_states = self.gru(x)[0]
        return self.fc(per_step_states)

In [3]:
# ================= Event Captioning Module =================
class TemporalSemanticRelationModule(nn.Module):
    """Linear projection of features to 512 dimensions followed by a ReLU.

    ``nn.Linear`` acts on the last axis, so any leading dimensions pass through unchanged.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 512)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(fc(x))`` with a last dimension of 512."""
        projected = self.fc(x)
        return self.relu(projected)

In [4]:
class CaptionGenerator(nn.Module):
    """LSTM language model whose first input step is a visual feature vector.

    The visual vector is concatenated in front of the embedded caption tokens, so
    ``features`` must have width ``embedding_dim``. For (batch, seq) token ids and
    (batch, embedding_dim) features, the logits are (batch, seq + 1, vocab_size).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, captions, features):
        """Prepend ``features`` as one time step, run the LSTM, and project to the vocabulary."""
        token_embeddings = self.embedding(captions)
        visual_step = features.unsqueeze(1)
        sequence = torch.cat((visual_step, token_embeddings), dim=1)
        hidden_states = self.lstm(sequence)[0]
        return self.fc(hidden_states)

In [5]:
# ================= Dataset =================
class VideoDataset(Dataset):
    """Map-style dataset pairing per-video features with event spans and tokenized captions.

    ``annotations[idx]`` is a dict with an ``"events"`` list (converted to a float32
    tensor) and a ``"captions"`` list of strings; ``video_features`` is indexed in parallel.
    """

    def __init__(self, annotations, video_features, tokenizer, max_len):
        self.annotations, self.video_features = annotations, video_features
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(features, events, input_ids, attention_masks)`` for one video.

        Each caption is tokenized separately with ``padding="max_length"`` and truncation,
        and the per-caption rows are stacked along a new leading axis.
        """
        features = self.video_features[idx]
        record = self.annotations[idx]
        events = torch.tensor(record["events"], dtype=torch.float32)

        id_rows, mask_rows = [], []
        for text in record["captions"]:
            encoded = self.tokenizer(
                text,
                max_length=self.max_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            id_rows.append(encoded["input_ids"].squeeze(0))
            mask_rows.append(encoded["attention_mask"].squeeze(0))

        return features, events, torch.stack(id_rows), torch.stack(mask_rows)

In [6]:
# Collate function to handle variable-length video features
def collate_fn(batch):
    """Collate dataset items into a batch.

    Only the feature sequences are zero-padded (batch-first) to the longest item; the
    events, caption ids and caption masks are returned as tuples with one entry per item.
    """
    feature_seqs, event_seqs, caption_ids, caption_masks = zip(*batch)
    return (
        pad_sequence(feature_seqs, batch_first=True),
        event_seqs,
        caption_ids,
        caption_masks,
    )

In [7]:
# ================= Training =================
def _collate_with_lengths(batch):
    """Run `collate_fn` and append a (batch,) tensor of each video's unpadded frame count."""
    lengths = torch.tensor([item[0].shape[0] for item in batch], dtype=torch.long)
    return (*collate_fn(batch), lengths)


def _event_frame_spans(events, length):
    """Turn a video's (start, end) event rows into clamped, inclusive integer frame spans.

    NOTE(review): assumes event boundaries are frame indices into the video's feature
    sequence (the example annotations fit inside 100/150 frames) -- confirm for real data.
    """
    spans = []
    for start, end in events.reshape(-1, 2).tolist():
        first = min(max(int(round(start)), 0), length - 1)
        last = min(max(int(round(end)), first), length - 1)
        spans.append((first, last))
    return spans


def _boundary_targets(events, lengths, num_frames):
    """Build (batch, num_frames, 2) 0/1 targets marking each event's start and end frame."""
    targets = torch.zeros(len(events), num_frames, 2)
    for b, (video_events, length) in enumerate(zip(events, lengths.tolist())):
        for first, last in _event_frame_spans(video_events, length):
            targets[b, first, 0] = 1.0
            targets[b, last, 1] = 1.0
    return targets


def _pool_event_features(frame_features, events, length):
    """Mean-pool one video's (num_frames, dim) features over each event -> (num_events, dim)."""
    return torch.stack([
        frame_features[first:last + 1].mean(dim=0)
        for first, last in _event_frame_spans(events, length)
    ])


def train():
    """Train the proposal, relation and captioning modules jointly on the toy dataset below.

    Proposal objective: the per-frame outputs of the proposal module are treated as
    start/end boundary logits against 0/1 targets (BCE), with padded frames masked out.
    The previous MSE against a flattened (num_events, 2) tensor could not match the
    (batch * frames, 2) predictions.

    Caption objective: each caption is conditioned on its own event's mean-pooled
    visual features (one row per caption), and padding tokens are ignored.

    The three state dicts are written to ``tep_model.pth``, ``tsrm_model.pth`` and
    ``caption_gen_model.pth`` in the working directory.
    """
    # Hyperparameters
    input_dim = 2048
    vocab_size = 30522
    embedding_dim = 512
    hidden_dim = 512
    max_len = 30
    batch_size = 4
    num_epochs = 10
    learning_rate = 1e-4

    # Prefer the GPU but do not hard-require it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Example dataset
    annotations = [
        {"events": [[0, 30], [40, 70]], "captions": ["event1 caption", "event2 caption"]},
        {"events": [[10, 50]], "captions": ["single event caption"]},
    ]
    video_features = [
        torch.rand((150, 2048)),  # Example variable-length features for video 1
        torch.rand((100, 2048)),  # Example variable-length features for video 2
    ]

    dataset = VideoDataset(annotations, video_features, tokenizer, max_len)
    # The wrapped collate also returns true lengths, needed to mask padded frames.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate_with_lengths)

    tep = TemporalEventProposal(input_dim).to(device)
    tsrm = TemporalSemanticRelationModule(input_dim).to(device)
    caption_gen = CaptionGenerator(vocab_size, embedding_dim, hidden_dim).to(device)

    optimizer = optim.Adam(list(tep.parameters()) + list(tsrm.parameters()) + list(caption_gen.parameters()), lr=learning_rate)
    # Unreduced so that padded frames can be masked out before averaging.
    proposal_criterion = nn.BCEWithLogitsLoss(reduction="none")
    # Padding positions produced by padding="max_length" must not be scored.
    caption_criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(num_epochs):
        total_loss = 0.0
        # Caption attention masks are unused: ignore_index already skips pad tokens.
        for features, events, captions, _caption_masks, lengths in dataloader:
            features = features.to(device)
            num_frames = features.shape[1]
            # (batch, num_frames): 1.0 for real frames, 0.0 for pad_sequence padding.
            frame_mask = (torch.arange(num_frames).unsqueeze(0) < lengths.unsqueeze(1)).to(device, dtype=torch.float32)

            # Temporal Event Proposals: per-frame start/end boundary logits.
            boundary_logits = tep(features)
            boundary_targets = _boundary_targets(events, lengths, num_frames).to(device)
            per_frame_loss = proposal_criterion(boundary_logits, boundary_targets)
            proposal_loss = (per_frame_loss * frame_mask.unsqueeze(-1)).sum() / (2 * frame_mask.sum())

            # Visual Representations
            visual_features = tsrm(features)

            # Caption Generation: one caption per event, conditioned on that event's frames.
            caption_loss = 0
            for i, (cap, video_events) in enumerate(zip(captions, events)):
                cap = cap.to(device)
                event_features = _pool_event_features(visual_features[i], video_events, int(lengths[i]))
                if event_features.shape[0] != cap.shape[0]:
                    raise ValueError(
                        f"video {i} has {event_features.shape[0]} events but {cap.shape[0]} captions"
                    )
                logits = caption_gen(cap[:, :-1], event_features)
                # Position 0 of the logits is the visual step; the output after token k
                # predicts token k + 1, which is how inference() reads the last position.
                caption_loss += caption_criterion(logits[:, 1:].reshape(-1, vocab_size), cap[:, 1:].reshape(-1))

            caption_loss /= len(captions)  # Average over videos in the batch

            # Total Loss
            total_step_loss = proposal_loss + caption_loss
            total_loss += total_step_loss.item()

            optimizer.zero_grad()
            total_step_loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

    # Save models
    torch.save(tep.state_dict(), "tep_model.pth")
    torch.save(tsrm.state_dict(), "tsrm_model.pth")
    torch.save(caption_gen.state_dict(), "caption_gen_model.pth")

In [8]:
# ================= Inference =================
def inference(video_features, tep, tsrm, caption_gen, tokenizer, max_len=30):
    """Greedily decode one caption per video.

    Args:
        video_features: iterable of per-video feature tensors. Each gets a leading batch
            dimension via ``unsqueeze(0)`` before being passed to ``tsrm``.
        tep: temporal event proposal model. It is only switched to eval mode here;
            NOTE(review): its proposals are not yet used to segment the video.
        tsrm: module mapping frame features to visual features. Its output is
            mean-pooled over dim 1 (time) into a single (1, dim) vector per video.
        caption_gen: decoder called as ``caption_gen(token_ids, visual_vector)`` that
            returns per-position vocabulary logits.
        tokenizer: provides ``cls_token_id``, ``sep_token_id`` and ``decode``.
        max_len: maximum number of tokens generated per caption.

    Returns:
        A list of decoded caption strings, one per input video.
    """
    tep.eval()
    tsrm.eval()
    caption_gen.eval()

    # Follow the decoder's device instead of assuming CUDA.
    device = next(caption_gen.parameters()).device
    captions_list = []

    with torch.no_grad():  # pure inference: no autograd graph needed
        for features in video_features:
            features = features.unsqueeze(0).to(device)
            # Collapse the time axis: CaptionGenerator.forward unsqueezes its features
            # at dim 1 and concatenates them with (1, seq, dim) embeddings, so it needs (1, dim).
            visual_features = tsrm(features).mean(dim=1)

            input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)
            token_ids = []

            for _ in range(max_len):
                logits = caption_gen(input_ids, visual_features)
                next_token = logits[:, -1].argmax(dim=-1)  # shape (1,)
                token_ids.append(next_token.item())
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if token_ids[-1] == tokenizer.sep_token_id:
                    break

            captions_list.append(tokenizer.decode(token_ids, skip_special_tokens=True))

    return captions_list

In [None]:
# ================= Main =================
if __name__ == "__main__":
    train()

    # Use the GPU when present instead of hard-requiring CUDA; map_location makes the
    # checkpoints loadable regardless of the device they were saved from.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load trained models (constructor arguments must match those used in train()).
    tep = TemporalEventProposal(2048).to(device)
    tep.load_state_dict(torch.load("tep_model.pth", map_location=device))

    tsrm = TemporalSemanticRelationModule(2048).to(device)
    tsrm.load_state_dict(torch.load("tsrm_model.pth", map_location=device))

    caption_gen = CaptionGenerator(30522, 512, 512).to(device)
    caption_gen.load_state_dict(torch.load("caption_gen_model.pth", map_location=device))

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Random stand-in features with the same (frames, 2048) layout as the training examples.
    video_features = [torch.rand((150, 2048)), torch.rand((100, 2048))]
    captions = inference(video_features, tep, tsrm, caption_gen, tokenizer)
    print("Generated Captions:", captions)
