<a href="https://colab.research.google.com/github/tiffany-gu/vjepa2-nyjt/blob/main/Cross_Modal_Transformer_Fusion_Architecture_(with_V_JEPA_2_%26_CoMotion_Context).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Configuration Parameters ---
# Module-level defaults used by the model classes below.
# These would typically be loaded from a config file or passed as arguments.
# VISION_FEATURE_DIM: Per-time-step dimension of the features extracted from V-JEPA 2.
# NOTE(review): 768 is the embedding width of a ViT-B-sized encoder; the V-JEPA 2
# ViT-g backbone is reported with a wider embedding (1408) -- confirm this value
# against the checkpoint actually used for feature extraction.
VISION_FEATURE_DIM = 768

# SKELETON_FEATURE_DIM: Dimension of features derived from CoMotion's 3D joint coordinates.
# If CoMotion outputs (N_joints * 3) coordinates, this would be that dimension,
# or a projected dimension after an initial processing layer.
# NOTE(review): SkeletonEncoder below uses this value as its *input* width, so the
# skeleton tensors fed to the model must already have this many features per step.
SKELETON_FEATURE_DIM = 256 # Example: 25 joints * 3 coords = 75, then projected to 256

FUSION_EMBED_DIM = 512    # Common embedding dimension for fused features
NUM_HEADS = 8             # Number of attention heads in the transformer
NUM_FUSION_LAYERS = 3     # Number of transformer encoder layers in the fusion module
NUM_CLASSES = 10          # Example: Number of action classes in InHARD dataset
MAX_SEQUENCE_LENGTH = 16  # Example: Number of frames/time steps in a video clip

# --- 1. Vision Feature Encoder (Conceptual V-JEPA 2 Integration) ---
# This class conceptually represents the process of taking pre-extracted
# V-JEPA 2 features and preparing them for the fusion module.
# In a real setup, you would have already run V-JEPA 2 on your video clips
# (as per Task 1.4) to get these `video_features`.
class VisionEncoder(nn.Module):
    """Adapter that prepares pre-extracted V-JEPA 2 features for fusion.

    No video backbone is run here; the module only receives feature tensors.
    When the requested input and output widths match, the features pass
    through unchanged; otherwise a single linear layer maps between them.

    Args:
        input_dim: Width of the incoming feature vectors.
        output_dim: Width of the feature vectors handed to the fusion module.
    """

    def __init__(self, input_dim=VISION_FEATURE_DIM, output_dim=VISION_FEATURE_DIM):
        super().__init__()
        # Equal widths need no learnable parameters, so skip the projection.
        if input_dim == output_dim:
            self.processor = nn.Identity()
        else:
            self.processor = nn.Linear(input_dim, output_dim)

        summary = (
            "VisionEncoder: Represents processing of pre-extracted V-JEPA 2 features. "
            f"Expects input features of dim {input_dim}."
        )
        print(summary)

    def forward(self, video_features):
        """Return the (optionally projected) vision features.

        Args:
            video_features: Feature tensor, expected as
                (batch_size, sequence_length, input_dim).

        Returns:
            The output of ``self.processor`` applied to ``video_features``.
        """
        processed = self.processor(video_features)
        return processed

# --- 2. Skeleton Feature Encoder (Conceptual CoMotion Integration) ---
# This class conceptually represents taking 3D skeleton data from CoMotion
# and converting it into a feature representation suitable for fusion.
# In a real setup, you would have already run CoMotion on your video clips
# (as per Task 1.3) to get `skeleton_data` (e.g., 3D joint coordinates or SMPL params).
class SkeletonEncoder(nn.Module):
    """Adapter that turns per-time-step skeleton data into fusion features.

    A single linear layer stands in for whatever skeleton processing is
    eventually used (an MLP or a small transformer would also fit here).

    Args:
        input_dim: Width of the incoming skeleton vectors per time step.
        output_dim: Width of the feature vectors handed to the fusion module.
    """

    def __init__(self, input_dim=SKELETON_FEATURE_DIM, output_dim=SKELETON_FEATURE_DIM):
        super().__init__()
        # Unlike the vision path, a linear layer is always applied,
        # even when both widths are equal.
        self.processor = nn.Linear(in_features=input_dim, out_features=output_dim)

        summary = (
            "SkeletonEncoder: Represents processing of CoMotion 3D skeleton data. "
            f"Expects input features of dim {input_dim} (e.g., N_joints * 3)."
        )
        print(summary)

    def forward(self, skeleton_data):
        """Project the skeleton data through the linear layer.

        Args:
            skeleton_data: Tensor expected as
                (batch_size, sequence_length, input_dim).

        Returns:
            Tensor with the last dimension mapped to ``output_dim``.
        """
        projected = self.processor(skeleton_data)
        return projected

# --- 3. Cross-Modal Transformer Fusion Architecture ---
# This module takes vision and skeleton features and fuses them using a transformer.
# The "skeleton conditioning" is achieved by combining the projected vision and
# skeleton features before feeding them into the transformer, allowing the
# self-attention mechanism to learn relationships between the two modalities.
class CrossModalTransformer(nn.Module):
    """Fuse per-time-step vision and skeleton features with a transformer encoder.

    Both modalities are linearly projected to a common width and summed per
    time step ("skeleton conditioning"). A learnable [CLS] token is prepended,
    learnable positional encodings are added, and the sequence is passed
    through a stack of transformer encoder layers. The layer-normalised [CLS]
    output is returned as the fused clip representation.

    Args:
        vision_input_dim: Width of the incoming vision features.
        skeleton_input_dim: Width of the incoming skeleton features.
        embed_dim: Common embedding width used inside the transformer.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Largest number of time steps supported (the positional
            table holds ``max_seq_len + 1`` entries to cover the [CLS] slot).
    """

    def __init__(self, vision_input_dim, skeleton_input_dim, embed_dim, num_heads, num_layers, max_seq_len):
        super().__init__()
        # Project vision and skeleton features to a common embedding dimension
        self.vision_projection = nn.Linear(vision_input_dim, embed_dim)
        self.skeleton_projection = nn.Linear(skeleton_input_dim, embed_dim)

        # Token for capturing global context (e.g., [CLS] token in BERT)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Positional encoding to capture temporal information
        # +1 for the [CLS] token at the beginning of the sequence
        self.positional_encoding = nn.Parameter(torch.randn(1, max_seq_len + 1, embed_dim))

        # Transformer Encoder layers
        # d_model is the feature dimension for the transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4, # Standard practice for feedforward network dimension
            batch_first=True               # Input and output tensors are (batch, sequence, feature)
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.norm = nn.LayerNorm(embed_dim) # Layer normalization after transformer output

        print(f"CrossModalTransformer: Fusing features with embed_dim={embed_dim}, "
              f"num_heads={num_heads}, num_layers={num_layers}.")

    def forward(self, vision_features, skeleton_features):
        """Return the fused [CLS] representation for a batch of clips.

        Args:
            vision_features: Tensor of shape
                (batch_size, sequence_length, vision_input_dim).
            skeleton_features: Tensor of shape
                (batch_size, sequence_length, skeleton_input_dim).

        Returns:
            Tensor of shape (batch_size, embed_dim).

        Raises:
            ValueError: If either input is not 3-D, if the two inputs disagree
                on batch size or sequence length, or if the sequence (plus the
                [CLS] token) is longer than the positional-encoding table.
        """
        if vision_features.dim() != 3 or skeleton_features.dim() != 3:
            raise ValueError(
                "vision_features and skeleton_features must both be 3-D "
                "(batch, sequence, feature); got shapes "
                f"{tuple(vision_features.shape)} and {tuple(skeleton_features.shape)}."
            )
        # Without this check a size-1 batch/sequence dimension on one modality
        # would be silently broadcast by the element-wise addition below.
        if vision_features.shape[:2] != skeleton_features.shape[:2]:
            raise ValueError(
                "vision_features and skeleton_features must share batch size and "
                f"sequence length; got {tuple(vision_features.shape[:2])} and "
                f"{tuple(skeleton_features.shape[:2])}."
            )

        batch_size = vision_features.shape[0]

        # Project features to the common embedding dimension
        vision_proj = self.vision_projection(vision_features)   # (B, S, embed_dim)
        skeleton_proj = self.skeleton_projection(skeleton_features) # (B, S, embed_dim)

        # "Skeleton Conditioning": element-wise addition lets the skeleton
        # information modulate the visual features at each time step.
        # Concatenation or cross-attention would be alternative designs.
        combined_features_per_timestep = vision_proj + skeleton_proj # (B, S, embed_dim)

        # Input sequence for the transformer: [CLS_token, feature_t1, feature_t2, ...]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1) # (B, 1, embed_dim)
        input_sequence = torch.cat((cls_tokens, combined_features_per_timestep), dim=1) # (B, S+1, embed_dim)

        # The positional table is sliced to the current length, so shorter
        # clips are fine but longer ones cannot be encoded.
        if input_sequence.shape[1] > self.positional_encoding.shape[1]:
            raise ValueError(
                f"Input sequence length ({input_sequence.shape[1]}) exceeds "
                f"max_seq_len for positional encoding ({self.positional_encoding.shape[1]}). "
                f"Adjust MAX_SEQUENCE_LENGTH or positional_encoding size."
            )
        input_sequence = input_sequence + self.positional_encoding[:, :input_sequence.shape[1], :]

        # Pass the combined sequence through the transformer encoder
        transformer_output = self.transformer_encoder(input_sequence)

        # The [CLS] position (index 0) serves as the fused representation.
        fused_representation = transformer_output[:, 0, :] # Shape: (batch_size, embed_dim)

        # Apply final layer normalization to the fused representation
        return self.norm(fused_representation)

# --- 4. Final Classification Head ---
# This takes the fused features and maps them to per-class action logits (unnormalized scores, not probabilities).
class ClassificationHead(nn.Module):
    """Single linear layer mapping a fused representation to class logits.

    Args:
        input_dim: Width of the fused feature vector.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)
        summary = f"ClassificationHead: Outputting {num_classes} classes from {input_dim}-dim features."
        print(summary)

    def forward(self, fused_features):
        """Return raw logits for ``fused_features``.

        Args:
            fused_features: Tensor expected as (batch_size, input_dim).

        Returns:
            Tensor of logits with last dimension ``num_classes``; no softmax
            is applied.
        """
        logits = self.fc(fused_features)
        return logits

# --- 5. Complete Fusion Model ---
# Orchestrates the entire process, from feature encoding to fusion and classification.
class FusionModel(nn.Module):
    """End-to-end model: modality encoders, transformer fusion, classifier.

    Args:
        vision_feature_dim: Width of the incoming vision features.
        skeleton_feature_dim: Width of the incoming skeleton features.
        fusion_embed_dim: Embedding width used inside the fusion transformer.
        num_heads: Attention heads per fusion layer.
        num_fusion_layers: Number of transformer encoder layers.
        num_classes: Number of output classes.
        max_sequence_length: Maximum number of time steps per clip.
    """

    def __init__(self,
                 vision_feature_dim=VISION_FEATURE_DIM,
                 skeleton_feature_dim=SKELETON_FEATURE_DIM,
                 fusion_embed_dim=FUSION_EMBED_DIM,
                 num_heads=NUM_HEADS,
                 num_fusion_layers=NUM_FUSION_LAYERS,
                 num_classes=NUM_CLASSES,
                 max_sequence_length=MAX_SEQUENCE_LENGTH):
        super().__init__()

        # Per-modality adapters; both keep their feature width unchanged.
        self.vision_encoder = VisionEncoder(vision_feature_dim, vision_feature_dim)
        self.skeleton_encoder = SkeletonEncoder(skeleton_feature_dim, skeleton_feature_dim)

        # Fusion transformer operating on the adapter outputs.
        fusion_kwargs = {
            "vision_input_dim": vision_feature_dim,
            "skeleton_input_dim": skeleton_feature_dim,
            "embed_dim": fusion_embed_dim,
            "num_heads": num_heads,
            "num_layers": num_fusion_layers,
            "max_seq_len": max_sequence_length,
        }
        self.cross_modal_transformer = CrossModalTransformer(**fusion_kwargs)

        # Linear classifier on top of the fused representation.
        self.classification_head = ClassificationHead(
            input_dim=fusion_embed_dim,
            num_classes=num_classes,
        )
        print("FusionModel: All components initialized.")

    def forward(self, video_features, skeleton_data):
        """Compute class logits from paired vision and skeleton features.

        Args:
            video_features: Pre-extracted vision features, expected as
                (batch_size, sequence_length, vision_feature_dim).
            skeleton_data: Skeleton features, expected as
                (batch_size, sequence_length, skeleton_feature_dim).

        Returns:
            The output of the classification head applied to the fused
            representation.
        """
        # Encode each modality (vision first, then skeleton), fuse, classify.
        fused = self.cross_modal_transformer(
            self.vision_encoder(video_features),
            self.skeleton_encoder(skeleton_data),
        )
        return self.classification_head(fused)

# --- Example Usage (Conceptual) ---
if __name__ == "__main__":
    # Smoke test: build the model with default settings and run one forward
    # pass on random tensors shaped like the expected inputs.
    print("--- Initializing Fusion Model ---")
    model = FusionModel()
    # This is an inference-only demo: eval mode switches off dropout inside the
    # transformer layers so the printed logits are not randomly perturbed.
    model.eval()

    # Create dummy input data to simulate features from Task 1.4 (V-JEPA 2)
    # and Task 1.3 (CoMotion).
    # Shape: (batch_size, sequence_length, feature_dim)
    batch_size = 2
    dummy_video_features = torch.randn(batch_size, MAX_SEQUENCE_LENGTH, VISION_FEATURE_DIM)
    dummy_skeleton_data = torch.randn(batch_size, MAX_SEQUENCE_LENGTH, SKELETON_FEATURE_DIM)

    print(f"\nDummy Video Features shape (simulating V-JEPA 2 output): {dummy_video_features.shape}")
    print(f"Dummy Skeleton Data shape (simulating CoMotion output): {dummy_skeleton_data.shape}")

    # Perform a forward pass through the entire model.
    # No gradients are needed here, so skip building the autograd graph.
    print("\n--- Performing Forward Pass ---")
    with torch.no_grad():
        output_logits = model(dummy_video_features, dummy_skeleton_data)

    print(f"Output Logits shape: {output_logits.shape}")
    print(f"Example Output Logits (first sample): {output_logits[0].detach().numpy()}")

    # This section outlines the conceptual steps for training the model.
    # In a real scenario, you would load your InHARD dataset, define
    # a loss function, an optimizer, and run a training loop.
    print("\n--- Conceptual Training Loop Outline ---")
    # criterion = nn.CrossEntropyLoss() # For multi-class classification
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    #
    # # Example of a training loop (assuming you have a PyTorch DataLoader for InHARD)
    # num_epochs = 10
    # for epoch in range(num_epochs):
    #     model.train() # Set model to training mode
    #     for batch_idx, (video_data_batch, skeleton_data_batch, labels_batch) in enumerate(dataloader):
    #         # Move data to the appropriate device (e.g., GPU)
    #         # video_data_batch, skeleton_data_batch, labels_batch = video_data_batch.to(device), skeleton_data_batch.to(device), labels_batch.to(device)
    #
    #         optimizer.zero_grad() # Clear gradients
    #         outputs = model(video_data_batch, skeleton_data_batch) # Forward pass
    #         loss = criterion(outputs, labels_batch) # Calculate loss
    #         loss.backward() # Backward pass
    #         optimizer.step() # Update model parameters
    #
    #     print(f"Epoch {epoch+1} Loss: {loss.item()}")
    #
    #     # After each epoch, typically evaluate on a validation set
    #     # model.eval() # Set model to evaluation mode
    #     # with torch.no_grad():
    #     #     # ... run evaluation logic ...
    #     #     print(f"Validation Accuracy: {accuracy}")

    print("\nThis conceptual implementation covers Task 2.2: Fusion Module Implementation.")
    print("It sets up the cross-modal transformer architecture with skeleton conditioning and the final classification head.")
    print("The next step (Task 2.3) would be to integrate this model into a full training pipeline.")