In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel
from sklearn.decomposition import PCA
import numpy as np

In [5]:


# ================= Temporal Localization Non-Maximum Suppression =================
def tlnms(proposals, threshold=0.5):
    """
    Apply Non-Maximum Suppression to the temporal proposals.
    Args:
        proposals (torch.Tensor): Shape (N, 2), where N is the number of proposals.
                                  Each proposal has a start and end time.
        threshold (float): IoU threshold for suppression.
    Returns:
        filtered_proposals (torch.Tensor): Shape (M, 2), filtered proposals.
    """
    if proposals.shape[0] == 0:
        return proposals

    proposals = proposals.cpu().detach().numpy()
    start = proposals[:, 0]
    end = proposals[:, 1]
    lengths = end - start
    sorted_indices = lengths.argsort()[::-1]

    selected = []
    while len(sorted_indices) > 0:
        current = sorted_indices[0]
        selected.append(current)
        sorted_indices = sorted_indices[1:]

        overlapping = []
        for i in sorted_indices:
            intersection = max(0, min(end[current], end[i]) - max(start[current], start[i]))
            union = (end[current] - start[current]) + (end[i] - start[i]) - intersection
            iou = intersection / union
            if iou > threshold:
                overlapping.append(i)

        sorted_indices = [i for i in sorted_indices if i not in overlapping]

    return torch.tensor(proposals[selected], dtype=torch.float32)

# ================= Temporal Event Proposal =================
class TemporalEventProposal(nn.Module):
    def __init__(self, input_dim):
        super(TemporalEventProposal, self).__init__()
        self.gru = nn.GRU(input_dim, 512, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(1024, 2)  # Predict start and end boundaries

    def forward(self, x):
        outputs, _ = self.gru(x)
        proposals = self.fc(outputs)
        return proposals

class TemporalSemanticRelationModule(nn.Module):
    def __init__(self, input_dim, output_dim=512):
        super(TemporalSemanticRelationModule, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim)
        Returns:
            Tensor of shape (batch_size, output_dim)
        """
        x = x.mean(dim=1)  # Apply temporal mean pooling over seq_len
        print(f"Shape after mean pooling: {x.shape}")  # Debugging statement
        return self.relu(self.fc(x))




class CaptionGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(CaptionGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, captions, features):
        embeddings = self.embedding(captions)  # (batch_size, seq_len, embedding_dim)
        transformed_features = features.unsqueeze(1).repeat(1, embeddings.size(1), 1)  # Align with seq_len
        print(f"Shape of transformed_features: {transformed_features.shape}")
        print(f"Shape of embeddings: {embeddings.shape}")

        inputs = torch.cat([transformed_features, embeddings], dim=2)  # Concatenate along feature dim
        print(f"Shape of inputs to LSTM: {inputs.shape}")

        outputs, _ = self.lstm(inputs)
        logits = self.fc(outputs)
        return logits









# ================= Dataset =================
class VideoDataset(Dataset):
    def __init__(self, annotations, video_features, tokenizer, max_len):
        self.annotations = annotations
        self.video_features = video_features
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        events = self.annotations[idx]["events"]
        captions = self.annotations[idx]["captions"]
        features = self.video_features[idx]

        tokens = []
        masks = []
        for caption in captions:
            tokenized = self.tokenizer(caption, max_length=self.max_len, padding="max_length", truncation=True, return_tensors="pt")
            tokens.append(tokenized["input_ids"].squeeze(0))
            masks.append(tokenized["attention_mask"].squeeze(0))

        return features, torch.tensor(events, dtype=torch.float32), tokens, masks

# Collate function to handle variable-length video features
def collate_fn(batch):
    features, events, captions, masks = zip(*batch)
    padded_features = pad_sequence(features, batch_first=True)
    captions = [torch.stack(cap) for cap in captions]
    masks = [torch.stack(mask) for mask in masks]
    return padded_features, events, captions, masks


def train_with_tlnms(dataset, tokenizer, input_dim, vocab_size, embedding_dim, hidden_dim, max_len, batch_size, num_epochs, learning_rate):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    tep = TemporalEventProposal(input_dim).cuda()
    tsrm = TemporalSemanticRelationModule(input_dim).cuda()
    caption_gen = CaptionGenerator(vocab_size, embedding_dim, hidden_dim).cuda()

    optimizer = optim.Adam(list(tep.parameters()) + list(tsrm.parameters()) + list(caption_gen.parameters()), lr=learning_rate)
    proposal_criterion = nn.MSELoss()
    caption_criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        total_loss = 0.0

        for features, events, captions, masks in dataloader:
            features = features.cuda()
            captions = [cap.cuda() for cap in captions]

            # Temporal Event Proposals
            raw_proposals = tep(features)
            raw_proposals = torch.sigmoid(raw_proposals) * features.size(1)  # Scale to [0, seq_len]
            print(f"Shape of raw_proposals: {raw_proposals.shape}")

            # Proposal Loss
            proposal_loss = sum(proposal_criterion(raw, target.cuda()) for raw, target in zip(raw_proposals, events))
            proposal_loss /= len(events)

            # Temporal Semantic Relation Module
            visual_features = tsrm(features)
            print(f"Shape of visual_features: {visual_features.shape}")

            # Caption Generation
            caption_loss = 0
            for i, cap in enumerate(captions):
                current_visual_features = visual_features[i].unsqueeze(0).repeat(cap.size(0), 1)
                logits = caption_gen(cap[:, :-1], current_visual_features)
                caption_loss += caption_criterion(logits.reshape(-1, vocab_size), cap[:, 1:].reshape(-1))
            caption_loss /= len(captions)

            # Total Loss
            total_step_loss = proposal_loss + caption_loss
            optimizer.zero_grad()
            total_step_loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader):.4f}")

    torch.save(tep.state_dict(), "tep_model.pth")
    torch.save(tsrm.state_dict(), "tsrm_model.pth")
    torch.save(caption_gen.state_dict(), "caption_gen_model.pth")


# ================= Inference Fixes =================
def inference_with_tlnms(video_features, tep, tsrm, caption_gen, tokenizer):
    tep.eval()
    tsrm.eval()
    caption_gen.eval()

    captions_list = []

    for features in video_features:
        features = features.unsqueeze(0).cuda()
        raw_proposals = tep(features)
        raw_proposals = torch.sigmoid(raw_proposals)
        filtered_proposals = tlnms(raw_proposals[0])  # Apply TLNMS to the proposals

        captions_for_video = []
        for proposal in filtered_proposals:
            start, end = map(int, proposal)
            segment_features = features[:, start:end, :]  # Extract features for the segment
            visual_features = tsrm(segment_features.mean(dim=1))

            input_ids = torch.tensor([tokenizer.cls_token_id]).unsqueeze(0).cuda()
            generated_caption = []

            for _ in range(30):
                logits = caption_gen(input_ids, visual_features.unsqueeze(0))
                next_token = logits.argmax(-1)[:, -1]
                generated_caption.append(next_token.item())
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
                if next_token.item() == tokenizer.sep_token_id:
                    break

            captions_for_video.append(tokenizer.decode(generated_caption, skip_special_tokens=True))
        captions_list.append(captions_for_video)

    return captions_list





# ================= Main =================
if __name__ == "__main__":
    # Example dataset
    annotations = [
        {"events": [[0, 30], [40, 70]], "captions": ["event1 caption", "event2 caption"]},
        {"events": [[10, 50]], "captions": ["single event caption"]},
    ]
    video_features = [
        torch.rand((150, 2048)),  # Example variable-length features for video 1
        torch.rand((100, 2048)),  # Example variable-length features for video 2
    ]

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = VideoDataset(annotations, video_features, tokenizer, max_len=30)

    # Hyperparameters
    input_dim = 2048
    vocab_size = 30522
    embedding_dim = 512
    hidden_dim = 512
    max_len = 30
    batch_size = 4
    num_epochs = 10
    learning_rate = 1e-4

    train_with_tlnms(dataset, tokenizer, input_dim, vocab_size, embedding_dim, hidden_dim, max_len, batch_size, num_epochs, learning_rate)

    # Load trained models
    tep = TemporalEventProposal(2048).cuda()
    tep.load_state_dict(torch.load("tep_model.pth"))

    tsrm = TemporalSemanticRelationModule(2048).cuda()
    tsrm.load_state_dict(torch.load("tsrm_model.pth"))

    caption_gen = CaptionGenerator(30522, 512, 512).cuda()
    caption_gen.load_state_dict(torch.load("caption_gen_model.pth"))

    captions = inference_with_tlnms(video_features, tep, tsrm, caption_gen, tokenizer)
    print("Generated Captions:", captions)


Shape of raw_proposals: torch.Size([2, 150, 2])


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (150) must match the size of tensor b (2) at non-singleton dimension 0