In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer
import numpy as np

In [20]:
def tlnms(proposals, threshold=0.1):
    if proposals.shape[0] == 0:
        return proposals

    proposals = proposals.cpu().detach().numpy()
    start = proposals[:, 0]
    end = proposals[:, 1]
    lengths = end - start

    # Only keep proposals where start < end
    valid_indices = lengths > 0
    start = start[valid_indices]
    end = end[valid_indices]
    lengths = lengths[valid_indices]
    sorted_indices = lengths.argsort()[::-1]

    selected = []
    while len(sorted_indices) > 0:
        current = sorted_indices[0]
        selected.append(current)
        sorted_indices = sorted_indices[1:]

        overlapping = []
        for i in sorted_indices:
            intersection = max(0, min(end[current], end[i]) - max(start[current], start[i]))
            union = (end[current] - start[current]) + (end[i] - start[i]) - intersection
            iou = intersection / union if union > 0 else 0  # Avoid division by zero
            if iou > threshold:
                overlapping.append(i)

        sorted_indices = [i for i in sorted_indices if i not in overlapping]

    return torch.tensor(proposals[selected], dtype=torch.float32)


# ================= Temporal Event Proposal =================
class TemporalEventProposal(nn.Module):
    def __init__(self, input_dim):
        super(TemporalEventProposal, self).__init__()
        self.gru = nn.GRU(input_dim, 512, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(1024, 2)  # Predict start and end boundaries

    def forward(self, x):
        outputs, _ = self.gru(x)
        proposals = self.fc(outputs)
        return proposals


class TemporalSemanticRelationModule(nn.Module):
    def __init__(self, input_dim, output_dim=512):
        super(TemporalSemanticRelationModule, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Ensure x is at least 2D (batch_size, feature_dim)
        if x.ndim == 3:  # Shape: (batch_size, seq_len, feature_dim)
            x = x.mean(dim=1)  # Temporal mean pooling over seq_len
        elif x.ndim == 2:  # Shape: (seq_len, feature_dim)
            x = x.mean(dim=0)  # Temporal mean pooling over seq_len
        else:
            raise ValueError(f"Unexpected input shape: {x.shape}")
        
        # After mean pooling, x should have shape (batch_size, input_dim)
        return self.relu(self.fc(x))



# ================= Caption Generator =================
class CaptionGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(CaptionGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + 512, hidden_dim, batch_first=True)  # Combined input size
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, captions, features):
        embeddings = self.embedding(captions)  # (batch_size, seq_len, embedding_dim)
        transformed_features = features.unsqueeze(1).repeat(1, embeddings.size(1), 1)  # Match seq_len
        inputs = torch.cat([transformed_features, embeddings], dim=2)  # Concatenate along feature dim
        outputs, _ = self.lstm(inputs)
        logits = self.fc(outputs)
        return logits


# ================= Dataset =================
class VideoDataset(Dataset):
    def __init__(self, annotations, video_features, tokenizer, max_len):
        self.annotations = annotations
        self.video_features = video_features
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        events = self.annotations[idx]["events"]
        captions = self.annotations[idx]["captions"]
        features = self.video_features[idx]

        tokens = []
        masks = []
        for caption in captions:
            tokenized = self.tokenizer(caption, max_length=self.max_len, padding="max_length", truncation=True, return_tensors="pt")
            tokens.append(tokenized["input_ids"].squeeze(0))
            masks.append(tokenized["attention_mask"].squeeze(0))

        return features, torch.tensor(events, dtype=torch.float32), tokens, masks


# ================= Collate Function =================
def collate_fn(batch):
    features, events, captions, masks = zip(*batch)
    padded_features = pad_sequence(features, batch_first=True)
    captions = [torch.stack(cap) for cap in captions]
    masks = [torch.stack(mask) for mask in masks]
    return padded_features, events, captions, masks


def train_with_tlnms(dataset, tokenizer, input_dim, vocab_size, embedding_dim, hidden_dim, max_len, batch_size, num_epochs, learning_rate):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    tep = TemporalEventProposal(input_dim).cuda()
    tsrm = TemporalSemanticRelationModule(input_dim).cuda()
    caption_gen = CaptionGenerator(vocab_size, embedding_dim, hidden_dim).cuda()

    optimizer = optim.Adam(list(tep.parameters()) + list(tsrm.parameters()) + list(caption_gen.parameters()), lr=learning_rate)
    proposal_criterion = nn.MSELoss()  # Proposal boundary loss
    caption_criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        total_loss = 0.0

        for features, events, captions, masks in dataloader:
            features = features.cuda()  # Shape: (batch_size, seq_len, input_dim)
            captions = [cap.cuda() for cap in captions]

            # Temporal Event Proposals
            raw_proposals = tep(features)  # Shape: (batch_size, seq_len, 2)
            raw_proposals = torch.sigmoid(raw_proposals) * features.size(1)  # Normalize to [0, seq_len]

            # Proposal Loss
            proposal_loss = 0
            for i in range(len(events)):
                expanded_target = torch.zeros_like(raw_proposals[i]).cuda()
                for start, end in events[i]:
                    # Find the closest indices for start and end
                    closest_start_idx = torch.argmin(torch.abs(raw_proposals[i][:, 0] - start))
                    closest_end_idx = torch.argmin(torch.abs(raw_proposals[i][:, 1] - end))
                    expanded_target[closest_start_idx, 0] = start
                    expanded_target[closest_end_idx, 1] = end
                proposal_loss += proposal_criterion(raw_proposals[i], expanded_target)
            proposal_loss /= len(events)  # Average over batch

            # Temporal Semantic Relation Module
            visual_features = tsrm(features)  # Shape: (batch_size, 512)

            # Caption Generation
            caption_loss = 0
            for i, cap in enumerate(captions):
                current_visual_features = visual_features[i].unsqueeze(0).repeat(cap.size(0), 1)  # Match caption batch size
                logits = caption_gen(cap[:, :-1], current_visual_features)
                caption_loss += caption_criterion(logits.reshape(-1, vocab_size), cap[:, 1:].reshape(-1))
            caption_loss /= len(captions)

            # Total Loss
            total_step_loss = proposal_loss + caption_loss
            total_loss += total_step_loss.item()

            optimizer.zero_grad()
            total_step_loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

    # Save models
    torch.save(tep.state_dict(), "tep_model.pth")
    torch.save(tsrm.state_dict(), "tsrm_model.pth")
    torch.save(caption_gen.state_dict(), "caption_gen_model.pth")



def inference_with_tlnms(video_features, tep, tsrm, caption_gen, tokenizer):
    tep.eval()
    tsrm.eval()
    caption_gen.eval()

    captions_list = []

    for features in video_features:
        features = features.unsqueeze(0).cuda()

        # Temporal Event Proposals
        raw_proposals = tep(features)
        raw_proposals = torch.sigmoid(raw_proposals) * features.size(1)  # Normalize proposals

        # Filter valid proposals (start < end)
        raw_proposals = raw_proposals[0]
        valid_indices = raw_proposals[:, 0] < raw_proposals[:, 1]
        raw_proposals = raw_proposals[valid_indices]

        if raw_proposals.size(0) == 0:
            captions_list.append(["No valid proposals"])
            continue

        # Apply TLNMS
        filtered_proposals = tlnms(raw_proposals, threshold=0.2)

        valid_proposals = []
        for proposal in filtered_proposals:
            start, end = map(int, proposal)
            if start < end and start >= 0 and end <= features.size(1):
                valid_proposals.append((start, end))

        if not valid_proposals:
            captions_list.append(["No valid proposals"])
            continue

        captions_for_video = []
        for start, end in valid_proposals:
            segment_features = features[:, start:end, :]
            if segment_features.size(1) == 0:  # Skip empty segments
                continue

            # Process segment features
            pooled_features = segment_features.mean(dim=1)  # Temporal mean pooling
            print(f"Pooled features shape: {pooled_features.shape}")

            # Normalize pooled features
            pooled_features = (pooled_features - pooled_features.mean()) / (pooled_features.std() + 1e-6)

            visual_features = tsrm(pooled_features)  # Pass through TSRM

            # Debugging visual features
            print(f"Visual features shape: {visual_features.shape}")
            assert not torch.isnan(visual_features).any(), "Visual features contain NaN values."
            assert torch.isfinite(visual_features).all(), "Visual features contain Inf values."

            # Generate caption
            input_ids = torch.tensor([tokenizer.cls_token_id]).unsqueeze(0).cuda()
            generated_caption = []

            for _ in range(30):  # Limit caption length to 30 tokens
                logits = caption_gen(input_ids, visual_features.unsqueeze(0))

                # Debug logits
                print(f"Logits shape: {logits.shape}")
                next_token_id = logits.argmax(-1)[:, -1].item()

                if next_token_id == tokenizer.sep_token_id:
                    break

                generated_caption.append(next_token_id)
                input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]]).cuda()], dim=1)

            decoded_caption = tokenizer.decode(generated_caption, skip_special_tokens=True)
            print(f"Decoded Caption: {decoded_caption}")
            captions_for_video.append(decoded_caption)

        captions_list.append(captions_for_video if captions_for_video else ["No valid captions generated"])

    return captions_list






In [24]:
import numpy as np
import h5py
import torch
import json

def get_video_features_and_annotations(file_path='data/features/c3d_features/c3d_features.hdf5', annotation_path = 'data/features/c3d_features/youcookii_annotations_trainval.json'):
    video_features = []
    annotations = []
    with open(annotation_path,'r') as annotations_file:
        total_annotations = json.load(annotations_file)
    total_annotations = total_annotations['database']
    
    with h5py.File(file_path, 'r') as f:
        for video_name in f.keys():
            group = f[video_name]
            curr_annotation = total_annotations[video_name]
            events = []
            captions = []
            for data in curr_annotation['annotations']:
                seg = data['segment']
                events.append(seg)
                sentence = data['sentence']
                captions.append(sentence)
            annotations.append({'events':events, 'captions':captions})

            features = np.array(group['c3d_features'][()])
            video_features.append(torch.from_numpy(features).float())

    return video_features,annotations

def get_only_video_features(file_path='data/features/c3d_features/c3d_infer_features.hdf5'):
    video_features = []
    with h5py.File(file_path, 'r') as f:
        for video_name in f.keys():
            group = f[video_name]
            features = np.array(group['c3d_features'][()])
            video_features.append(torch.from_numpy(features).float())

    return video_features

In [30]:


# ================= Main =================
if __name__ == "__main__":
    # Example annotations and features
    
    video_features,annotations = get_video_features_and_annotations()
    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Create dataset
    dataset = VideoDataset(annotations, video_features, tokenizer, max_len=30)

    # Hyperparameters
    input_dim = 4096
    vocab_size = tokenizer.vocab_size  # Dynamically fetch vocabulary size
    embedding_dim = 512
    hidden_dim = 512
    max_len = 30
    batch_size = 4
    num_epochs = 100
    learning_rate = 1e-4

    # Train the models
    print("Starting training...")
    train_with_tlnms(dataset, tokenizer, input_dim, vocab_size, embedding_dim, hidden_dim, max_len, batch_size, num_epochs, learning_rate)
    print("Training completed!")




Starting training...
Epoch 1/100, Loss: 5605.2392
Epoch 2/100, Loss: 3324.9333
Epoch 3/100, Loss: 3284.0263
Epoch 4/100, Loss: 1960.4773
Epoch 5/100, Loss: 2451.5225
Epoch 6/100, Loss: 2478.1666
Epoch 7/100, Loss: 1909.6986
Epoch 8/100, Loss: 2909.6877
Epoch 9/100, Loss: 2087.9720
Epoch 10/100, Loss: 1450.9621
Epoch 11/100, Loss: 1657.1963
Epoch 12/100, Loss: 1851.4252
Epoch 13/100, Loss: 2135.6555
Epoch 14/100, Loss: 2507.7818
Epoch 15/100, Loss: 1494.4784
Epoch 16/100, Loss: 1434.3891
Epoch 17/100, Loss: 890.5954
Epoch 18/100, Loss: 1524.7976
Epoch 19/100, Loss: 1290.4775
Epoch 20/100, Loss: 1161.6757
Epoch 21/100, Loss: 262.1055
Epoch 22/100, Loss: 239.4555
Epoch 23/100, Loss: 199.9641
Epoch 24/100, Loss: 228.6244
Epoch 25/100, Loss: 193.8211
Epoch 26/100, Loss: 228.0022
Epoch 27/100, Loss: 237.7659
Epoch 28/100, Loss: 251.1528
Epoch 29/100, Loss: 204.3293
Epoch 30/100, Loss: 238.2575
Epoch 31/100, Loss: 229.9862
Epoch 32/100, Loss: 198.8987
Epoch 33/100, Loss: 208.7679
Epoch 34/100

In [31]:
    # Load models
    tep = TemporalEventProposal(4096).cuda()
    tsrm = TemporalSemanticRelationModule(4096, 512).cuda()
    caption_gen = CaptionGenerator(30522, 512, 512).cuda()

    tep.load_state_dict(torch.load("tep_model.pth"))
    tsrm.load_state_dict(torch.load("tsrm_model.pth"))
    caption_gen.load_state_dict(torch.load("caption_gen_model.pth"))

    # Perform inference
    print("Performing inference...")
    captions = inference_with_tlnms(video_features, tep, tsrm, caption_gen, tokenizer)
    print("Generated Captions:", captions)

  tep.load_state_dict(torch.load("tep_model.pth"))
  tsrm.load_state_dict(torch.load("tsrm_model.pth"))
  caption_gen.load_state_dict(torch.load("caption_gen_model.pth"))


Performing inference...
Pooled features shape: torch.Size([1, 4096])
Visual features shape: torch.Size([512])
Logits shape: torch.Size([1, 1, 30522])
Logits shape: torch.Size([1, 2, 30522])
Logits shape: torch.Size([1, 3, 30522])
Logits shape: torch.Size([1, 4, 30522])
Logits shape: torch.Size([1, 5, 30522])
Logits shape: torch.Size([1, 6, 30522])
Logits shape: torch.Size([1, 7, 30522])
Logits shape: torch.Size([1, 8, 30522])
Logits shape: torch.Size([1, 9, 30522])
Decoded Caption: add add some lettuce the salad
Pooled features shape: torch.Size([1, 4096])
Visual features shape: torch.Size([512])
Logits shape: torch.Size([1, 1, 30522])
Logits shape: torch.Size([1, 2, 30522])
Logits shape: torch.Size([1, 3, 30522])
Logits shape: torch.Size([1, 4, 30522])
Logits shape: torch.Size([1, 5, 30522])
Logits shape: torch.Size([1, 6, 30522])
Logits shape: torch.Size([1, 7, 30522])
Logits shape: torch.Size([1, 8, 30522])
Decoded Caption: add some some the chicken in the
Pooled features shape: tor