In [318]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler, ConcatDataset
import numpy as np
import os
import random
from typing import Dict, List, Tuple, Any
import json
import glob
import sys
import pandas as pd
import math

In [319]:
def load_frame_data_standardized(npz_path):
    """
    Load saved frame data from an NPZ file.
    
    Args:
        npz_path (str): Path to the saved .npz file
        
    Returns:
        tuple: All the detection results for the frame
    """
    data = np.load(npz_path)
    
    # Extract all arrays from the npz file
    dom_landmarks_standardized = data['dom_landmarks_standardized']
    non_dom_landmarks_standardized = data['non_dom_landmarks_standardized']
    confidence_scores = data['confidence_scores']
    interpolation_scores = data['interpolation_scores']
    detection_status = data['detection_status']
    blendshape_scores_standardized = data['blendshape_scores_standardized']
    face_detected = data['face_detected'].item()  # Convert 0-d array to scalar
    nose_to_wrist_dist_standardized = data['nose_to_wrist_dist_standardized']
    frame_idx = data['frame_idx'].item()
    timestamp_ms = data['timestamp_ms'].item()
    dom_velocity_small_standardized = data['dom_velocity_small_standardized']
    dom_velocity_large_standardized = data['dom_velocity_large_standardized']
    non_dom_velocity_small_standardized = data['non_dom_velocity_small_standardized']
    non_dom_velocity_large_standardized = data['non_dom_velocity_large_standardized']
    velocity_confidence = data['velocity_confidence']
    velocity_calculation_confidence = data['velocity_calculation_confidence']
    nose_to_wrist_velocity_small_standardized = data['wrist_velocity_small_standardized']
    nose_to_wrist_velocity_large_standardized = data['wrist_velocity_large_standardized']
    
    return (dom_landmarks_standardized, non_dom_landmarks_standardized, confidence_scores, interpolation_scores,
            detection_status, blendshape_scores_standardized, face_detected, 
            nose_to_wrist_dist_standardized, frame_idx, timestamp_ms, dom_velocity_small_standardized, dom_velocity_large_standardized, non_dom_velocity_small_standardized, non_dom_velocity_large_standardized, velocity_confidence, velocity_calculation_confidence, nose_to_wrist_velocity_small_standardized, nose_to_wrist_velocity_large_standardized)

def sorted_npz_files_checked_label(directory_path):
    if os.path.exists(directory_path) and os.path.isdir(directory_path):
        # List all NPZ files in the directory
        npz_files = sorted(glob.glob(os.path.join(directory_path, "*.npz")))
    else:
        print(f"Directory path {directory_path} doesn't exist or it isn't a directory")
        sys.exit(1)
        
    
    # Skip if no files found
    if not npz_files:
        print(f"No NPZ files found in {directory_path}")
        sys.exit(1)
    
    
    with open(os.path.join(directory_path, 'detection_statistics.json')) as f:
        statistics_file = json.load(f)
    
    if statistics_file['video_info']['total_frames'] != (len(npz_files)-1):
        print("npz filepath list contain different amount of items than total frames")
        sys.exit(1)


    frame_to_file = {}
    for file_path in npz_files:
        if os.path.basename(file_path) == 'smooth_labels.npz':
            label_path = file_path
            continue
        try:
            frame_data = load_frame_data_standardized(file_path)
        except Exception as e:
            print(f"Error loading frame with path: {file_path}: {e}")
            sys.exit(1)
            
        frame_idx = frame_data[8]  # Index for frame_idx
        frame_to_file[frame_idx] = file_path

    
    frame_indices = sorted(frame_to_file.keys())
    if not all(frame_indices[i+1] - frame_indices[i] == 1 for i in range(len(frame_indices) - 1)):
        print("Consecutive frames are not different by one frame")
        sys.exit(1)

    

    return frame_to_file, frame_indices, label_path

def load_label(label_path):
    label_data = np.load(label_path)
    L_index = label_data['L_index']
    L_values = label_data['L_values']
    return L_index, L_values

In [320]:

class ASLFrameDataset(Dataset):
    """Dataset for ASL frame data from video clips with feature extraction."""
    def __init__(self, dataframe):
        """
        Initialize the dataset.
        
        Args:
            dataframe: Pandas DataFrame containing 'landmarks_file_path' column
        """
        self.dataframe = dataframe
        self.video_paths = list(dataframe['landmarks_file_path'])
        
    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_paths)
    
    def __getitem__(self, idx):
        """Get data for a complete video with all features."""
        directory_path = self.video_paths[idx]
        
        # Get paths to all frame files in this video
        frame_to_file, frame_indices, label_path = sorted_npz_files_checked_label(directory_path)
        
        # Initialize dictionaries to store all data
        all_data = {
            # Primary features for model input
            'dom_landmarks': [],
            'non_dom_landmarks': [],
            'blendshape_scores': [],
            'nose_to_wrist_dist': [],
            'dom_velocity_small': [],
            'dom_velocity_large': [],
            'non_dom_velocity_small': [],
            'non_dom_velocity_large': [],
            'nose_to_wrist_velocity_small': [],
            'nose_to_wrist_velocity_large': [],
            
            # Additional data for later use
            'confidence_scores': [],
            'interpolation_scores': [],
            'detection_status': [],
            'face_detected': [],
            'frame_idx': [],
            'velocity_confidence': [],
            'velocity_calculation_confidence': []
        }
        
        # Load data from each frame
        for frame_idx in frame_indices:
            file_path = frame_to_file[frame_idx]
            frame_data = load_frame_data_standardized(file_path)
            
            # Unpack frame data
            (dom_landmarks_standardized,
             non_dom_landmarks_standardized,
             confidence_scores,
             interpolation_scores,
             detection_status,
             blendshape_scores_standardized,
             face_detected,
             nose_to_wrist_dist_standardized,
             frame_idx_val,
             timestamp_ms,  # We'll skip this one
             dom_velocity_small_standardized,
             dom_velocity_large_standardized,
             non_dom_velocity_small_standardized,
             non_dom_velocity_large_standardized,
             velocity_confidence,
             velocity_calculation_confidence,
             nose_to_wrist_velocity_small_standardized,
             nose_to_wrist_velocity_large_standardized) = frame_data
            
            # Store primary features for model input
            all_data['dom_landmarks'].append(dom_landmarks_standardized)
            all_data['non_dom_landmarks'].append(non_dom_landmarks_standardized)
            all_data['blendshape_scores'].append(blendshape_scores_standardized)
            all_data['nose_to_wrist_dist'].append(nose_to_wrist_dist_standardized)
            all_data['dom_velocity_small'].append(dom_velocity_small_standardized)
            all_data['dom_velocity_large'].append(dom_velocity_large_standardized)
            all_data['non_dom_velocity_small'].append(non_dom_velocity_small_standardized)
            all_data['non_dom_velocity_large'].append(non_dom_velocity_large_standardized)
            all_data['nose_to_wrist_velocity_small'].append(nose_to_wrist_velocity_small_standardized)
            all_data['nose_to_wrist_velocity_large'].append(nose_to_wrist_velocity_large_standardized)
            
            # Store additional data for later use
            all_data['confidence_scores'].append(confidence_scores)
            all_data['interpolation_scores'].append(interpolation_scores)
            all_data['detection_status'].append(detection_status)
            all_data['face_detected'].append(face_detected)
            all_data['frame_idx'].append(frame_idx_val)
            all_data['velocity_confidence'].append(velocity_confidence)
            all_data['velocity_calculation_confidence'].append(velocity_calculation_confidence)
        
        # Convert lists to numpy arrays
        for key in all_data:
            all_data[key] = np.array(all_data[key])
        
        # Load label data
        L_index, L_values = load_label(label_path)
        all_data['L_index'] = L_index
        all_data['L_values'] = L_values
        
        # Store sequence length and directory path
        all_data['seq_length'] = len(frame_indices)
        all_data['directory_path'] = directory_path
        
        return all_data


class SingleDataFrameBatchSampler(Sampler):
    """
    Custom batch sampler that ensures each batch contains samples 
    from only one dataframe.
    """
    def __init__(self, dataset_sizes: List[int], batch_size: int, drop_last: bool = False):
        """
        Initialize the batch sampler.
        
        Args:
            dataset_sizes: List of sizes for each dataset
            batch_size: Batch size
            drop_last: Whether to drop the last batch if incomplete
        """
        self.dataset_sizes = dataset_sizes
        self.batch_size = batch_size
        self.drop_last = drop_last
        
        # Calculate offsets for indexing into the combined dataset
        self.offsets = [0]
        for size in dataset_sizes[:-1]:
            self.offsets.append(self.offsets[-1] + size)
    
    def __iter__(self):
        """Generate batches of indices, ensuring each batch comes from one dataset."""
        # Create index lists for each dataset
        all_indices = []
        for dataset_idx, size in enumerate(self.dataset_sizes):
            offset = self.offsets[dataset_idx]
            indices = list(range(offset, offset + size))
            random.shuffle(indices)
            all_indices.append(indices)
            
        # Create batches for each dataset
        all_batches = []
        for dataset_idx, indices in enumerate(all_indices):
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i:min(i + self.batch_size, len(indices))]
                
                # Skip last incomplete batch if drop_last is True
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                
                all_batches.append(batch)
        
        # Shuffle the order of batches
        random.shuffle(all_batches)
        
        # Yield batches one at a time
        for batch in all_batches:
            yield batch
    
    def __len__(self):
        """Return the number of batches."""
        if self.drop_last:
            return sum(size // self.batch_size for size in self.dataset_sizes)
        else:
            return sum((size + self.batch_size - 1) // self.batch_size for size in self.dataset_sizes)

def collate_with_dynamic_padding(batch):
    """
    Custom collate function that handles variable-length sequences and label data.
    """
    # Find the maximum sequence length in this batch
    max_seq_length = max(sample['seq_length'] for sample in batch)
    batch_size = len(batch)
    
    # Initialize the result dictionary
    result = {
        'directory_paths': [],
        'seq_lengths': []
    }
    
    # Store directory paths and sequence lengths
    for sample in batch:
        result['directory_paths'].append(sample['directory_path'])
        result['seq_lengths'].append(sample['seq_length'])
    
    result['seq_lengths'] = torch.tensor(result['seq_lengths'], dtype=torch.long)
    
    # Create mask tensor for frames [batch_size, max_seq_length]
    frame_mask = torch.zeros((batch_size, max_seq_length), dtype=torch.bool)
    
    # Handle variable-sized label data
    # Find maximum dimensions for L_index and L_values
    max_tokens = max(sample['L_index'].shape[0] for sample in batch)
    token_width = batch[0]['L_index'].shape[1]  # Assuming all have same width (6)
    
    # Create padded tensors for labels - using appropriate dtypes
    L_index_padded = torch.zeros((batch_size, max_tokens, token_width), dtype=torch.long)
    L_values_padded = torch.zeros((batch_size, max_tokens, token_width), dtype=torch.float32)
    label_mask = torch.zeros((batch_size, max_tokens), dtype=torch.bool)
    
    # Fill in label data
    for i, sample in enumerate(batch):
        num_tokens = sample['L_index'].shape[0]
        L_index_padded[i, :num_tokens] = torch.tensor(sample['L_index'], dtype=torch.long)
        L_values_padded[i, :num_tokens] = torch.tensor(sample['L_values'], dtype=torch.float32)
        label_mask[i, :num_tokens] = True
    
    result['L_index'] = L_index_padded
    result['L_values'] = L_values_padded
    result['label_mask'] = label_mask
    
    # Process feature data with consistent dimensions
    feature_keys = [
        # Primary features for model input
        'dom_landmarks', 'non_dom_landmarks', 'blendshape_scores',
        'nose_to_wrist_dist', 'dom_velocity_small', 'dom_velocity_large',
        'non_dom_velocity_small', 'non_dom_velocity_large',
        'nose_to_wrist_velocity_small', 'nose_to_wrist_velocity_large',
        
        # Additional data for later use
        'confidence_scores', 'interpolation_scores', 'detection_status',
        'frame_idx', 'velocity_confidence',
        'velocity_calculation_confidence'
    ]
    
    # Process all standard features
    for key in feature_keys:
        try:
            # Get the sample feature
            sample_feature = batch[0][key]
            feature_shape = sample_feature.shape[1:] if len(sample_feature.shape) > 1 else ()
            
            # Create padded tensor [batch_size, max_seq_length, *feature_shape]
            padded_tensor = torch.zeros((batch_size, max_seq_length) + feature_shape, dtype=torch.float32)
            
            # Fill in the actual data and update the mask
            for i, sample in enumerate(batch):
                seq_length = sample['seq_length']
                feature_data = sample[key]
                padded_tensor[i, :seq_length] = torch.tensor(feature_data, dtype=torch.float32)
                frame_mask[i, :seq_length] = True
                
            # Add to result
            result[key] = padded_tensor
            
        except Exception as e:
            print(f"Error processing feature '{key}': {e}")
            print(f"  Shape in first sample: {np.array(batch[0][key]).shape}")
            if i > 0:
                print(f"  Shape in problematic sample {i}: {np.array(sample[key]).shape}")
    
    # Process face_detected separately with proper reshaping
    try:
        # Create a tensor specifically for face_detected (which needs special handling)
        face_detected_tensor = torch.zeros((batch_size, max_seq_length), dtype=torch.float32)
        
        for i, sample in enumerate(batch):
            seq_length = sample['seq_length']
            face_data = sample['face_detected']
            
            # Convert to tensor and ensure it's 1D
            face_tensor = torch.tensor(face_data, dtype=torch.float32)
            
            # Assign directly without reshaping
            face_detected_tensor[i, :seq_length] = face_tensor
            
        result['face_detected'] = face_detected_tensor
        
    except Exception as e:
        print(f"Error processing face_detected: {e}")
        print(f"  Shape: {np.array(batch[0]['face_detected']).shape}")
    
    # Add the frame mask
    result['mask'] = frame_mask
    
    return result


def create_asl_dataloader(low_df, mid_df, high_df, batch_size=16, num_workers=4, drop_last=False):
    """
    Create a data loader for ASL data that ensures batches only contain samples from one dataframe.
    
    Args:
        low_df: DataFrame with low frame count videos
        mid_df: DataFrame with medium frame count videos
        high_df: DataFrame with high frame count videos
        batch_size: Batch size
        num_workers: Number of worker processes for data loading
        drop_last: Whether to drop the last batch if incomplete
        
    Returns:
        A DataLoader that yields batches from the three dataframes
    """
    # Create datasets for each dataframe
    low_dataset = ASLFrameDataset(low_df)
    mid_dataset = ASLFrameDataset(mid_df)
    high_dataset = ASLFrameDataset(high_df)
    
    # Get dataset sizes
    dataset_sizes = [len(low_dataset), len(mid_dataset), len(high_dataset)]
    
    # Combine datasets
    combined_dataset = ConcatDataset([low_dataset, mid_dataset, high_dataset])
    
    # Create a batch sampler that ensures batches only contain samples from one dataframe
    batch_sampler = SingleDataFrameBatchSampler(dataset_sizes, batch_size, drop_last)
    
    # Create the data loader
    data_loader = DataLoader(
        combined_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_with_dynamic_padding
    )
    
    return data_loader

In [321]:
low_df = pd.read_csv("./low_df_only_path.csv")
mid_df = pd.read_csv("./mid_df_only_path.csv")
high_df = pd.read_csv("./high_df_only_path.csv")

In [322]:
loader = create_asl_dataloader(
    low_df=low_df, 
    mid_df=mid_df, 
    high_df=high_df,
    batch_size=2,
    num_workers=0
)
for i, batch in enumerate(loader):
    print(f"Successfully loaded batch {i}")
    print(batch['seq_lengths'])
    if i >= 5:  # Check a few batches
        break

Successfully loaded batch 0
tensor([50, 45])
Successfully loaded batch 1
tensor([42, 48])
Successfully loaded batch 2
tensor([67, 91])
Successfully loaded batch 3
tensor([40, 42])
Successfully loaded batch 4
tensor([104, 105])
Successfully loaded batch 5
tensor([36, 26])


In [323]:
def inspect_tensor(tensor, name, max_items=3):
    """Print shape and sample values from a tensor."""
    print(f"\n{name}:")
    print(f"  Shape: {tensor.shape}")
    print(f"  Type: {tensor.dtype}")
    
    # Print a few values if it's not too large
    if tensor.numel() > 0:
        if tensor.dim() <= 2:
            print(f"  Sample values:\n{tensor[:max_items]}")
        else:
            # For higher dimensional tensors, show the first slice
            print(f"  First slice sample:\n{tensor[0, 0]}")
    
    # Check for NaN or infinity values
    if torch.isnan(tensor).any():
        print("  WARNING: Contains NaN values!")
    if torch.isinf(tensor).any():
        print("  WARNING: Contains infinity values!")

# Print basic batch information
print(f"Batch contains data from {len(batch['directory_paths'])} videos")
print(f"Video paths: {batch['directory_paths']}")
print(f"Sequence lengths: {batch['seq_lengths']}")
print(f"Maximum sequence length in this batch: {batch['mask'].shape[1]}")

# Inspect key tensors
inspect_tensor(batch['mask'], "Frame mask")
inspect_tensor(batch['label_mask'], "Label mask")
inspect_tensor(batch['L_index'], "Label indices")
inspect_tensor(batch['L_values'], "Label values")

# Check primary feature tensors
primary_features = [
    'dom_landmarks', 'non_dom_landmarks', 'blendshape_scores',
    'nose_to_wrist_dist', 'dom_velocity_small', 'dom_velocity_large', 
    'non_dom_velocity_small', 'non_dom_velocity_large',
    'nose_to_wrist_velocity_small', 'nose_to_wrist_velocity_large'
]

for feature in primary_features:
    inspect_tensor(batch[feature], feature)

# Check a few additional data tensors
additional_features = [
    'confidence_scores', 'face_detected', 'velocity_confidence'
]

for feature in additional_features:
    inspect_tensor(batch[feature], feature)

# Visualize the padding with the mask
print("\nFrame mask visualization (1=real data, 0=padding):")
for i in range(len(batch['mask'])):
    seq_len = batch['seq_lengths'][i].item()
    max_len = batch['mask'].shape[1]
    padding = max_len - seq_len
    print(f"Sample {i}: {'1'*seq_len}{'0'*padding} ({seq_len} real frames, {padding} padding)")

# Optional: Check for consistent shapes across batch dimension
print("\nChecking consistency across batch dimension:")
for key in batch:
    if isinstance(batch[key], torch.Tensor) and batch[key].dim() > 0:
        print(f"{key}: First dimension size = {batch[key].shape[0]}")

Batch contains data from 2 videos
Video paths: ['./OpenASL-main/Open_asl_all_data/clips/aWYgWpsTAjo-00:00:11.929-00:00:13.910_fps15_landmarks/', './ASL_Citizen/videos/3837343931767261-OPINION 1_fps15_L_landmarks/']
Sequence lengths: tensor([36, 26])
Maximum sequence length in this batch: 36

Frame mask:
  Shape: torch.Size([2, 36])
  Type: torch.bool
  Sample values:
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True, False, False, False, False,
         False, False, False, False, False, False]])

Label mask:
  Shape: torch.Size([2, 8])
  Type: to

In [324]:


class LandmarkEmbedding(nn.Module):
    """
    Creates learnable embeddings for hand landmarks.
    
    This module maps each landmark (across both hands) to a unique 
    embedding vector that encodes its semantic meaning.
    """
    def __init__(self, embedding_dim, num_landmarks_per_hand=21):
        """
        Initialize the landmark embedding module.
        
        Args:
            embedding_dim: Dimension of the embedding vectors
            num_landmarks_per_hand: Number of landmarks per hand (default: 21)
        """
        super(LandmarkEmbedding, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.num_landmarks_per_hand = num_landmarks_per_hand
        self.total_landmarks = 2 * num_landmarks_per_hand  # Both hands
        
        # Create the embedding table: [total_landmarks, embedding_dim]
        self.embedding_table = nn.Embedding(
            num_embeddings=self.total_landmarks,
            embedding_dim=embedding_dim
        )
        
        # Initialize the embeddings with a normal distribution
        nn.init.normal_(self.embedding_table.weight, mean=0.0, std=0.02)
    
    def forward(self, landmark_indices=None):
        """
        Get embeddings for landmarks.
        
        Args:
            landmark_indices: Optional tensor of landmark indices to retrieve.
                             If None, returns all landmark embeddings.
        
        Returns:
            Tensor of landmark embeddings
        """
        if landmark_indices is None:
            # Return all landmark embeddings
            # Create indices for all landmarks: 0 to total_landmarks-1
            landmark_indices = torch.arange(self.total_landmarks, device=self.embedding_table.weight.device)
        
        # Get the embeddings for the specified indices
        embeddings = self.embedding_table(landmark_indices)
        return embeddings
    
    def get_dominant_hand_embeddings(self):
        """
        Get embeddings for landmarks in the dominant hand.
        
        Returns:
            Tensor of shape [num_landmarks_per_hand, embedding_dim]
        """
        indices = torch.arange(self.num_landmarks_per_hand, 
                              device=self.embedding_table.weight.device)
        return self.embedding_table(indices)
    
    def get_non_dominant_hand_embeddings(self):
        """
        Get embeddings for landmarks in the non-dominant hand.
        
        Returns:
            Tensor of shape [num_landmarks_per_hand, embedding_dim]
        """
        indices = torch.arange(self.num_landmarks_per_hand, self.total_landmarks, 
                              device=self.embedding_table.weight.device)
        return self.embedding_table(indices)

In [325]:
embedding_dim = 36
embedding_table = LandmarkEmbedding(embedding_dim=embedding_dim, num_landmarks_per_hand=21)

In [326]:
embeddings = embedding_table.forward()

In [327]:
embeddings.shape

torch.Size([42, 36])

In [328]:
class LandmarkSpatialEncoder(nn.Module):
    """
    Encodes the spatial information (x,y,z coordinates) of individual hand landmarks.
    
    This module transforms the 3D coordinates of each landmark into a higher-dimensional
    representation that captures the 'where' aspect of the landmark.
    """
    def __init__(self, 
                 embedding_dim, 
                 hidden_dims=None, 
                 num_layers=2,
                 activation='relu',
                 init_method='kaiming_normal',
                 init_gain=1.0,
                 init_nonlinearity='relu'):
        """
        Initialize the spatial encoder with customizable architecture.
        
        Args:
            embedding_dim: Base dimension for the model
            hidden_dims: List of hidden layer dimensions. If None, uses [4*embedding_dim] * num_layers
            num_layers: Number of hidden layers (default: 2)
            activation: Activation function to use ('relu', 'leaky_relu', 'gelu', 'silu', 'tanh', etc.)
            init_method: Weight initialization method ('kaiming_normal', 'kaiming_uniform', 
                        'xavier_normal', 'xavier_uniform', 'normal', 'uniform')
            init_gain: Gain parameter for certain initialization methods
            init_nonlinearity: Nonlinearity parameter for certain initialization methods
        """
        super(LandmarkSpatialEncoder, self).__init__()
        
        # The output dimension will be 2*embedding_dim as requested
        self.output_dim = 2 * embedding_dim
        
        # If hidden_dims not provided, create default configuration
        if hidden_dims is None:
            hidden_dims = [4 * embedding_dim] * num_layers
        
        # Get the activation function
        self.activation_fn = self._get_activation(activation)
        
        # Create layers list starting with input layer
        layers = []
        
        # Input layer
        layers.append(nn.Linear(3, hidden_dims[0]))
        layers.append(self.activation_fn)
        
        # Add hidden layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            layers.append(self.activation_fn)
        
        # Add output layer
        layers.append(nn.Linear(hidden_dims[-1], self.output_dim))
        
        # Create the feed-forward network
        self.spatial_encoder = nn.Sequential(*layers)
        
        # Initialize weights using the specified method
        self._init_weights(init_method, init_gain, init_nonlinearity)
    
    def _get_activation(self, activation_name):
        """Get the activation function based on name."""
        activations = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'gelu': nn.GELU(),
            'silu': nn.SiLU(),  # Also known as Swish
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'elu': nn.ELU(),
            'prelu': nn.PReLU(),
        }
        
        if activation_name.lower() not in activations:
            raise ValueError(f"Activation function '{activation_name}' not supported. "
                           f"Choose from: {', '.join(activations.keys())}")
        
        return activations[activation_name.lower()]
    
    def _init_weights(self, init_method, gain, nonlinearity):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
                
            if init_method == 'kaiming_normal':
                nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'xavier_normal':
                nn.init.xavier_normal_(module.weight, gain=gain)
            elif init_method == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            elif init_method == 'normal':
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif init_method == 'uniform':
                nn.init.uniform_(module.weight, a=-0.1, b=0.1)
            else:
                raise ValueError(f"Initialization method '{init_method}' not supported.")
            
            # Initialize bias if it exists
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, landmarks):
        """
        Encode the spatial coordinates of landmarks.
        
        Args:
            landmarks: Tensor of shape [..., 3] containing x,y,z coordinates
                       The leading dimensions can be anything (batch, sequence, landmark)
        
        Returns:
            Tensor of shape [..., output_dim] with the spatial encodings
        """
        # Get the original shape to reshape the output later
        original_shape = landmarks.shape
        
        # Reshape to [-1, 3] to process all landmarks in parallel
        flat_landmarks = landmarks.reshape(-1, 3)
        
        # Apply the spatial encoder
        encoded = self.spatial_encoder(flat_landmarks)
        
        # Reshape back to original dimensions but with output_dim as the last dimension
        reshaped_encoded = encoded.reshape(*original_shape[:-1], self.output_dim)
        
        return reshaped_encoded

In [329]:
ladmark_encoder = LandmarkSpatialEncoder(embedding_dim, hidden_dims=[30, 60, 30],activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')


In [330]:
batch['dom_landmarks'].shape

torch.Size([2, 36, 20, 3])

In [331]:
dom_landmarks_where = ladmark_encoder.forward(batch['dom_landmarks'])
non_dom_landmarks_where = ladmark_encoder.forward(batch['non_dom_landmarks'])

In [333]:
def combine_spatial_and_semantic_features(spatial_features, semantic_features):
    """
    Combines the spatial encoder output with the semantic embedding features.
    
    This function concatenates the "where" (spatial) information with the "what" 
    (semantic) information to create a comprehensive landmark representation.
    
    Args:
        spatial_features: Tensor of shape [..., n_spatial_encode] where
                         n_spatial_encode = 2*embedding_dim
        semantic_features: Tensor of shape [..., embedding_dim]
    
    Returns:
        Tensor of shape [..., 3*embedding_dim] containing the combined representation
    """

    batch_dims = spatial_features.shape[:-2]
    expanded_embeddings = semantic_features.expand(*batch_dims, -1, -1)
    # Verify that the batch dimensions match
    assert spatial_features.shape[:-1] == expanded_embeddings.shape[:-1], \
        "Batch dimensions of spatial and semantic features must match"
    
    # Concatenate along the last dimension
    combined_features = torch.cat([expanded_embeddings, spatial_features], dim=-1)
    
    return combined_features

In [334]:
dom_landmarks_conc = combine_spatial_and_semantic_features(spatial_features=dom_landmarks_where, semantic_features=embeddings[:20])
non_dom_landmarks_conc = combine_spatial_and_semantic_features(spatial_features=non_dom_landmarks_where, semantic_features=embeddings[21:41])


In [335]:
class WristSpatialEncoder(nn.Module):
    """
    Encodes the spatial information of wrist landmarks relative to the nose.
    
    This module processes the 2D coordinates (x,y) of each wrist independently
    but in parallel, using shared weights across both wrists.
    """
    def __init__(self, 
                 embedding_dim, 
                 hidden_dims=None, 
                 num_layers=2,
                 activation='relu',
                 init_method='kaiming_normal',
                 init_gain=1.0,
                 init_nonlinearity='relu'):
        """
        Initialize the wrist spatial encoder with customizable architecture.
        
        Args:
            embedding_dim: Base dimension for the model
            hidden_dims: List of hidden layer dimensions. If None, uses [4*embedding_dim] * num_layers
            num_layers: Number of hidden layers (default: 2)
            activation: Activation function to use ('relu', 'leaky_relu', 'gelu', 'silu', 'tanh', etc.)
            init_method: Weight initialization method ('kaiming_normal', 'kaiming_uniform', 
                        'xavier_normal', 'xavier_uniform', 'normal', 'uniform')
            init_gain: Gain parameter for certain initialization methods
            init_nonlinearity: Nonlinearity parameter for certain initialization methods
        """
        super(WristSpatialEncoder, self).__init__()
        
        # The output dimension will be 2*embedding_dim as requested
        self.output_dim = 2 * embedding_dim
        
        # If hidden_dims not provided, create default configuration
        if hidden_dims is None:
            hidden_dims = [4 * embedding_dim] * num_layers
        
        # Get the activation function
        self.activation_fn = self._get_activation(activation)
        
        # Create layers list starting with input layer
        layers = []
        
        # Input layer (2D coordinates instead of 3D)
        layers.append(nn.Linear(2, hidden_dims[0]))
        layers.append(self.activation_fn)
        
        # Add hidden layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            layers.append(self.activation_fn)
        
        # Add output layer
        layers.append(nn.Linear(hidden_dims[-1], self.output_dim))
        
        # Create the feed-forward network
        self.wrist_encoder = nn.Sequential(*layers)
        
        # Initialize weights using the specified method
        self._init_weights(init_method, init_gain, init_nonlinearity)
    
    def _get_activation(self, activation_name):
        """Get the activation function based on name."""
        activations = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'gelu': nn.GELU(),
            'silu': nn.SiLU(),  # Also known as Swish
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'elu': nn.ELU(),
            'prelu': nn.PReLU(),
        }
        
        if activation_name.lower() not in activations:
            raise ValueError(f"Activation function '{activation_name}' not supported. "
                           f"Choose from: {', '.join(activations.keys())}")
        
        return activations[activation_name.lower()]
    
    def _init_weights(self, init_method, gain, nonlinearity):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
                
            if init_method == 'kaiming_normal':
                nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'xavier_normal':
                nn.init.xavier_normal_(module.weight, gain=gain)
            elif init_method == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            elif init_method == 'normal':
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif init_method == 'uniform':
                nn.init.uniform_(module.weight, a=-0.1, b=0.1)
            else:
                raise ValueError(f"Initialization method '{init_method}' not supported.")
            
            # Initialize bias if it exists
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, wrist_coordinates):
        """
        Encode the spatial coordinates of wrist landmarks.
        
        Args:
            wrist_coordinates: Tensor of shape [..., 2, 2] containing x,y coordinates
                              for both wrists. Leading dimensions can be anything
                              (batch, sequence), and the last two dimensions are:
                              - Dimension -2: Wrist index (0=dominant, 1=non-dominant)
                              - Dimension -1: Coordinates (x,y)
        
        Returns:
            Tensor of shape [..., 2, output_dim] with the spatial encodings for each wrist
        """
        # Get the original shape to reshape the output later
        original_shape = wrist_coordinates.shape
        
        # Reshape to [-1, 2] to process all wrist coordinates in parallel
        # This flattens all leading dimensions and processes each (x,y) pair independently
        flat_wrists = wrist_coordinates.reshape(-1, 2)
        
        # Apply the wrist encoder
        encoded = self.wrist_encoder(flat_wrists)
        
        # Reshape back to original dimensions but with output_dim as the last dimension
        # Replace the coordinate dimension (2) with output_dim
        new_shape = original_shape[:-1] + (self.output_dim,)
        reshaped_encoded = encoded.reshape(new_shape)
        
        return reshaped_encoded

In [336]:
wrist_encoder = WristSpatialEncoder(embedding_dim, hidden_dims=[30, 60, 30],activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

In [337]:
wrists_where = wrist_encoder.forward(wrist_coordinates=batch['nose_to_wrist_dist'])

In [338]:
def combine_wrist_embedding_and_spatial(wrist_embeddings, wrist_spatial_features):
    """
    Combines wrist semantic embeddings with their spatial features.
    
    This function integrates:
    1. The semantic meaning of each wrist (from embeddings)
    2. The spatial position of each wrist (from the WristSpatialEncoder)
    
    Args:
        wrist_embeddings: Tensor of shape [2, embedding_dim] with wrist embeddings
                         where [0] is dom wrist and [1] is non-dom wrist
        wrist_spatial_features: Tensor of shape [..., 2, 2*embedding_dim] 
                               from the WristSpatialEncoder
    
    Returns:
        Tensor of shape [..., 2, 3*embedding_dim] with the combined representation
    """
    # Get the batch dimensions from the spatial features tensor
    batch_dims = wrist_spatial_features.shape[:-2]
    
    # Expand wrist embeddings to match the batch dimensions
    # From [2, embedding_dim] to [..., 2, embedding_dim]
    expanded_embeddings = wrist_embeddings.expand(*batch_dims, -1, -1)
    
    # Verify that the shapes are compatible for concatenation
    assert expanded_embeddings.shape[:-1] == wrist_spatial_features.shape[:-1], \
        "Batch dimensions of embeddings and spatial features must match"
    
    # Concatenate along the last dimension
    combined_features = torch.cat([
        expanded_embeddings,     # Wrist identity (what)
        wrist_spatial_features   # Wrist position (where)
    ], dim=-1)
    
    return combined_features

In [339]:
embeddings[20].shape

torch.Size([36])

In [340]:
wrists_where.shape

torch.Size([2, 36, 2, 72])

In [341]:
wrists_conc = combine_wrist_embedding_and_spatial(wrist_embeddings=torch.cat([embeddings[20], embeddings[41]], dim=-1).reshape((2,-1)), wrist_spatial_features=wrists_where)

In [342]:
class BlendshapeEncoder(nn.Module):
    """
    Encodes facial blendshape scores into a higher-dimensional representation.
    
    This network processes the 52 facial blendshape parameters that capture
    expressions and face movements relevant to ASL interpretation.
    """
    def __init__(self, 
                 embedding_dim, 
                 hidden_dims=None, 
                 num_layers=2,
                 activation='relu',
                 init_method='kaiming_normal',
                 init_gain=1.0,
                 init_nonlinearity='relu'):
        """
        Initialize the blendshape encoder with customizable architecture.
        
        Args:
            embedding_dim: Base dimension for the model
            hidden_dims: List of hidden layer dimensions. If None, uses [4*embedding_dim] * num_layers
            num_layers: Number of hidden layers (default: 2)
            activation: Activation function to use ('relu', 'leaky_relu', 'gelu', 'silu', 'tanh', etc.)
            init_method: Weight initialization method ('kaiming_normal', 'kaiming_uniform', 
                        'xavier_normal', 'xavier_uniform', 'normal', 'uniform')
            init_gain: Gain parameter for certain initialization methods
            init_nonlinearity: Nonlinearity parameter for certain initialization methods
        """
        super(BlendshapeEncoder, self).__init__()
        
        # The output dimension will be 2*embedding_dim as requested
        self.output_dim = 2 * embedding_dim
        
        # Input dimension for blendshape scores
        self.input_dim = 52
        
        # If hidden_dims not provided, create default configuration
        if hidden_dims is None:
            hidden_dims = [4 * embedding_dim] * num_layers
        
        # Get the activation function
        self.activation_fn = self._get_activation(activation)
        
        # Create layers list starting with input layer
        layers = []
        
        # Input layer (52 blendshape scores)
        layers.append(nn.Linear(self.input_dim, hidden_dims[0]))
        layers.append(self.activation_fn)
        
        # Add hidden layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            layers.append(self.activation_fn)
        
        # Add output layer
        layers.append(nn.Linear(hidden_dims[-1], self.output_dim))
        
        # Create the feed-forward network
        self.blendshape_encoder = nn.Sequential(*layers)
        
        # Initialize weights using the specified method
        self._init_weights(init_method, init_gain, init_nonlinearity)
    
    def _get_activation(self, activation_name):
        """Get the activation function based on name."""
        activations = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'gelu': nn.GELU(),
            'silu': nn.SiLU(),  # Also known as Swish
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'elu': nn.ELU(),
            'prelu': nn.PReLU(),
        }
        
        if activation_name.lower() not in activations:
            raise ValueError(f"Activation function '{activation_name}' not supported. "
                           f"Choose from: {', '.join(activations.keys())}")
        
        return activations[activation_name.lower()]
    
    def _init_weights(self, init_method, gain, nonlinearity):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
                
            if init_method == 'kaiming_normal':
                nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'xavier_normal':
                nn.init.xavier_normal_(module.weight, gain=gain)
            elif init_method == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            elif init_method == 'normal':
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif init_method == 'uniform':
                nn.init.uniform_(module.weight, a=-0.1, b=0.1)
            else:
                raise ValueError(f"Initialization method '{init_method}' not supported.")
            
            # Initialize bias if it exists
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, blendshape_scores):
        """
        Encode the facial blendshape scores.
        
        Args:
            blendshape_scores: Tensor of shape [..., 52] containing facial expression parameters.
                              Leading dimensions can be anything (batch, sequence).
        
        Returns:
            Tensor of shape [..., output_dim] with the encoded facial features
        """
        # Get the original shape to reshape the output later
        original_shape = blendshape_scores.shape
        
        # Reshape to [-1, 52] to process all blendshape scores in parallel
        flat_blendshapes = blendshape_scores.reshape(-1, self.input_dim)
        
        # Apply the blendshape encoder
        encoded = self.blendshape_encoder(flat_blendshapes)
        
        # Reshape back to original dimensions but with output_dim as the last dimension
        # Replace the blendshape dimension (52) with output_dim
        new_shape = original_shape[:-1] + (self.output_dim,)
        reshaped_encoded = encoded.reshape(new_shape)
        
        return reshaped_encoded

In [343]:
blendshapes_feedforward = BlendshapeEncoder(embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

In [344]:
blendshapes_encoded = blendshapes_feedforward(batch['blendshape_scores'])

In [345]:
class VelocityEncoder(nn.Module):
    """
    Encodes velocity features of hand landmarks into a higher-dimensional representation.
    
    This network processes the 5 spherical coordinate velocity features for each landmark
    independently but in parallel, using the same weights across all landmarks, hands,
    and velocity windows (small and large).
    """
    def __init__(self, 
                 n_velocity_encoding, 
                 hidden_dims=None, 
                 num_layers=2,
                 activation='relu',
                 init_method='kaiming_normal',
                 init_gain=1.0,
                 init_nonlinearity='relu'):
        """
        Initialize the velocity encoder with customizable architecture.
        
        Args:
            n_velocity_encoding: Output dimension for each landmark's velocity encoding
            hidden_dims: List of hidden layer dimensions. If None, uses [4*n_velocity_encoding] * num_layers
            num_layers: Number of hidden layers (default: 2)
            activation: Activation function to use ('relu', 'leaky_relu', 'gelu', 'silu', 'tanh', etc.)
            init_method: Weight initialization method ('kaiming_normal', 'kaiming_uniform', 
                        'xavier_normal', 'xavier_uniform', 'normal', 'uniform')
            init_gain: Gain parameter for certain initialization methods
            init_nonlinearity: Nonlinearity parameter for certain initialization methods
        """
        super(VelocityEncoder, self).__init__()
        
        # The output dimension as specified
        self.output_dim = n_velocity_encoding
        
        # Input dimension for velocity features (spherical coordinates)
        self.input_dim = 5
        
        # If hidden_dims not provided, create default configuration
        if hidden_dims is None:
            hidden_dims = [4 * n_velocity_encoding] * num_layers
        
        # Get the activation function
        self.activation_fn = self._get_activation(activation)
        
        # Create layers list starting with input layer
        layers = []
        
        # Input layer (5 velocity features in spherical coordinates)
        layers.append(nn.Linear(self.input_dim, hidden_dims[0]))
        layers.append(self.activation_fn)
        
        # Add hidden layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            layers.append(self.activation_fn)
        
        # Add output layer
        layers.append(nn.Linear(hidden_dims[-1], self.output_dim))
        
        # Create the feed-forward network
        self.velocity_encoder = nn.Sequential(*layers)
        
        # Initialize weights using the specified method
        self._init_weights(init_method, init_gain, init_nonlinearity)
    
    def _get_activation(self, activation_name):
        """Get the activation function based on name."""
        activations = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'gelu': nn.GELU(),
            'silu': nn.SiLU(),  # Also known as Swish
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'elu': nn.ELU(),
            'prelu': nn.PReLU(),
        }
        
        if activation_name.lower() not in activations:
            raise ValueError(f"Activation function '{activation_name}' not supported. "
                           f"Choose from: {', '.join(activations.keys())}")
        
        return activations[activation_name.lower()]
    
    def _init_weights(self, init_method, gain, nonlinearity):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
                
            if init_method == 'kaiming_normal':
                nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'xavier_normal':
                nn.init.xavier_normal_(module.weight, gain=gain)
            elif init_method == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            elif init_method == 'normal':
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif init_method == 'uniform':
                nn.init.uniform_(module.weight, a=-0.1, b=0.1)
            else:
                raise ValueError(f"Initialization method '{init_method}' not supported.")
            
            # Initialize bias if it exists
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, velocity_features):
        """
        Encode the velocity features for hand landmarks.
        
        Args:
            velocity_features: Tensor of shape [..., 5] containing velocity features
                              in spherical coordinates. Leading dimensions can be anything
                              (batch, sequence, landmark).
        
        Returns:
            Tensor of shape [..., output_dim] with the encoded velocity features
        """
        # Get the original shape to reshape the output later
        original_shape = velocity_features.shape
        
        # Reshape to [-1, 5] to process all velocity features in parallel
        flat_velocities = velocity_features.reshape(-1, self.input_dim)
        
        # Apply the velocity encoder
        encoded = self.velocity_encoder(flat_velocities)
        
        # Reshape back to original dimensions but with output_dim as the last dimension
        # Replace the velocity dimension (5) with output_dim
        new_shape = original_shape[:-1] + (self.output_dim,)
        reshaped_encoded = encoded.reshape(new_shape)
        
        return reshaped_encoded
    
    def encode_all_velocity_windows(self, dom_vel_small, dom_vel_large, non_dom_vel_small, non_dom_vel_large):
        """
        Encode all four velocity window tensors using the same encoder.
        
        Args:
            dom_vel_small: Dominant hand small window velocities [batch_size, seq_len, 20, 5]
            dom_vel_large: Dominant hand large window velocities [batch_size, seq_len, 20, 5]
            non_dom_vel_small: Non-dominant hand small window velocities [batch_size, seq_len, 20, 5]
            non_dom_vel_large: Non-dominant hand large window velocities [batch_size, seq_len, 20, 5]
            
        Returns:
            Dictionary containing encoded velocity features for all windows
        """
        # Process each velocity window
        dom_small_encoded = self.forward(dom_vel_small)  # [batch_size, seq_len, 20, output_dim]
        dom_large_encoded = self.forward(dom_vel_large)  # [batch_size, seq_len, 20, output_dim]
        non_dom_small_encoded = self.forward(non_dom_vel_small)  # [batch_size, seq_len, 20, output_dim]
        non_dom_large_encoded = self.forward(non_dom_vel_large)  # [batch_size, seq_len, 20, output_dim]
        
        return {
            'dom_velocity_small_encoded': dom_small_encoded,
            'dom_velocity_large_encoded': dom_large_encoded,
            'non_dom_velocity_small_encoded': non_dom_small_encoded,
            'non_dom_velocity_large_encoded': non_dom_large_encoded
        }

In [346]:
velocity_feedforward = VelocityEncoder(n_velocity_encoding=2*embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

In [347]:
dom_small_vel_encoded = velocity_feedforward.forward(batch['dom_velocity_small']) 
dom_large_vel_encoded = velocity_feedforward.forward(batch['dom_velocity_large']) 
non_dom_small_vel_encoded = velocity_feedforward.forward(batch['non_dom_velocity_small']) 
non_dom_large_vel_encoded = velocity_feedforward.forward(batch['non_dom_velocity_large']) 

In [348]:
def combine_semantic_and_velocity_features(semantic_features, velocity_small_features, velocity_large_features):
    """
    Combines landmark semantic embeddings with velocity features from both time windows.
    
    This function concatenates:
    1. The "what" (semantic embedding) of each landmark
    2. The "how fast small window" (small window velocity encoding)
    3. The "how fast large window" (large window velocity encoding)
    
    Args:
        semantic_features: Tensor of shape [..., embedding_dim] containing landmark embeddings
        velocity_small_features: Tensor of shape [..., n_velocity_encoding] from small window
        velocity_large_features: Tensor of shape [..., n_velocity_encoding] from large window
    
    Returns:
        Tensor of shape [..., embedding_dim + 2*n_velocity_encoding] with the combined representation
    """
    batch_shape = velocity_small_features.shape[:-2]
    semantic_features_expanded = semantic_features.expand(*batch_shape, -1, -1)
    # Verify that the batch dimensions match
    assert semantic_features_expanded.shape[:-1] == velocity_small_features.shape[:-1] == velocity_large_features.shape[:-1], \
        "Batch dimensions of semantic and velocity features must match"
    
    # Concatenate all three feature types along the last dimension
    combined_features = torch.cat([
        semantic_features_expanded,        # Landmark identity (what)
        velocity_small_features,  # Short-term movement (how fast recently)
        velocity_large_features   # Long-term movement (how fast overall)
    ], dim=-1)
    
    return combined_features

In [349]:
dom_landmarks_velocity_conc = combine_semantic_and_velocity_features(semantic_features=embeddings[:20], velocity_small_features=dom_small_vel_encoded, velocity_large_features=dom_large_vel_encoded)
non_dom_landmarks_velocity_conc = combine_semantic_and_velocity_features(semantic_features=embeddings[21:41], velocity_small_features=non_dom_small_vel_encoded, velocity_large_features=non_dom_large_vel_encoded)

In [350]:
dom_landmarks_velocity_conc.shape

torch.Size([2, 36, 20, 180])

In [351]:
class WristVelocityEncoder(nn.Module):
    """
    Encodes velocity features of wrist landmarks relative to the nose.
    
    This network processes the 3 polar coordinate velocity features for each wrist
    independently but in parallel, using the same weights across both wrists
    and both velocity windows (small and large).
    """
    def __init__(self, 
                 n_velocity_encoding, 
                 hidden_dims=None, 
                 num_layers=2,
                 activation='relu',
                 init_method='kaiming_normal',
                 init_gain=1.0,
                 init_nonlinearity='relu'):
        """
        Initialize the wrist velocity encoder with customizable architecture.
        
        Args:
            n_velocity_encoding: Output dimension for each wrist's velocity encoding
            hidden_dims: List of hidden layer dimensions. If None, uses [4*n_velocity_encoding] * num_layers
            num_layers: Number of hidden layers (default: 2)
            activation: Activation function to use ('relu', 'leaky_relu', 'gelu', 'silu', 'tanh', etc.)
            init_method: Weight initialization method ('kaiming_normal', 'kaiming_uniform', 
                        'xavier_normal', 'xavier_uniform', 'normal', 'uniform')
            init_gain: Gain parameter for certain initialization methods
            init_nonlinearity: Nonlinearity parameter for certain initialization methods
        """
        super(WristVelocityEncoder, self).__init__()
        
        # The output dimension as specified
        self.output_dim = n_velocity_encoding
        
        # Input dimension for wrist velocity features (polar coordinates)
        self.input_dim = 3
        
        # If hidden_dims not provided, create default configuration
        if hidden_dims is None:
            hidden_dims = [4 * n_velocity_encoding] * num_layers
        
        # Get the activation function
        self.activation_fn = self._get_activation(activation)
        
        # Create layers list starting with input layer
        layers = []
        
        # Input layer (3 velocity features in polar coordinates)
        layers.append(nn.Linear(self.input_dim, hidden_dims[0]))
        layers.append(self.activation_fn)
        
        # Add hidden layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            layers.append(self.activation_fn)
        
        # Add output layer
        layers.append(nn.Linear(hidden_dims[-1], self.output_dim))
        
        # Create the feed-forward network
        self.wrist_velocity_encoder = nn.Sequential(*layers)
        
        # Initialize weights using the specified method
        self._init_weights(init_method, init_gain, init_nonlinearity)
    
    def _get_activation(self, activation_name):
        """Get the activation function based on name."""
        activations = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'gelu': nn.GELU(),
            'silu': nn.SiLU(),  # Also known as Swish
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'elu': nn.ELU(),
            'prelu': nn.PReLU(),
        }
        
        if activation_name.lower() not in activations:
            raise ValueError(f"Activation function '{activation_name}' not supported. "
                           f"Choose from: {', '.join(activations.keys())}")
        
        return activations[activation_name.lower()]
    
    def _init_weights(self, init_method, gain, nonlinearity):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
                
            if init_method == 'kaiming_normal':
                nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'kaiming_uniform':
                nn.init.kaiming_uniform_(module.weight, a=0.0, nonlinearity=nonlinearity)
            elif init_method == 'xavier_normal':
                nn.init.xavier_normal_(module.weight, gain=gain)
            elif init_method == 'xavier_uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            elif init_method == 'normal':
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif init_method == 'uniform':
                nn.init.uniform_(module.weight, a=-0.1, b=0.1)
            else:
                raise ValueError(f"Initialization method '{init_method}' not supported.")
            
            # Initialize bias if it exists
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    
    def forward(self, wrist_velocity_features):
        """
        Encode the velocity features for wrist landmarks.
        
        Args:
            wrist_velocity_features: Tensor of shape [..., 3] containing velocity features
                                    in polar coordinates. Leading dimensions can be anything
                                    (batch, sequence, wrist).
        
        Returns:
            Tensor of shape [..., output_dim] with the encoded velocity features
        """
        # Get the original shape to reshape the output later
        original_shape = wrist_velocity_features.shape
        
        # Reshape to [-1, 3] to process all velocity features in parallel
        flat_velocities = wrist_velocity_features.reshape(-1, self.input_dim)
        
        # Apply the wrist velocity encoder
        encoded = self.wrist_velocity_encoder(flat_velocities)
        
        # Reshape back to original dimensions but with output_dim as the last dimension
        # Replace the velocity dimension (3) with output_dim
        new_shape = original_shape[:-1] + (self.output_dim,)
        reshaped_encoded = encoded.reshape(new_shape)
        
        return reshaped_encoded
    
    def encode_both_velocity_windows(self, wrist_vel_small, wrist_vel_large):
        """
        Encode both velocity window tensors for wrists using the same encoder.
        
        Args:
            wrist_vel_small: Wrist small window velocities [batch_size, seq_len, 2, 3]
            wrist_vel_large: Wrist large window velocities [batch_size, seq_len, 2, 3]
            
        Returns:
            Dictionary containing encoded velocity features for both windows
        """
        # Process each velocity window
        small_window_encoded = self.forward(wrist_vel_small)  # [batch_size, seq_len, 2, output_dim]
        large_window_encoded = self.forward(wrist_vel_large)  # [batch_size, seq_len, 2, output_dim]
        
        return {
            'wrist_velocity_small_encoded': small_window_encoded,
            'wrist_velocity_large_encoded': large_window_encoded
        }

In [352]:
wrist_vel_feedforward = WristVelocityEncoder(n_velocity_encoding=2*embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

In [353]:
wrist_vel_small_encoded = wrist_vel_feedforward.forward(batch['nose_to_wrist_velocity_small'])
wrist_vel_large_encoded = wrist_vel_feedforward.forward(batch['nose_to_wrist_velocity_large'])

In [354]:
wrist_vel_small_encoded.shape

torch.Size([2, 36, 2, 72])

In [355]:
def combine_wrist_embedding_and_velocity(wrist_embeddings, wrist_velocity_small, wrist_velocity_large):
    """
    Combines wrist semantic embeddings with velocity features from both time windows.
    
    This function handles the specific arrangement of wrist data in your model:
    - In embeddings: Wrists are at indices 20 (dom) and 41 (non-dom) in the embedding table
    - In velocity tensors: Wrists are at indices 0 (dom) and 1 (non-dom)
    
    Args:
        wrist_embeddings: Tensor of shape [2, embedding_dim] with wrist embeddings
                         where [0] is dom wrist and [1] is non-dom wrist
        wrist_velocity_small: Tensor of shape [..., 2, n_velocity_encoding] 
                             from small window velocity encoder
        wrist_velocity_large: Tensor of shape [..., 2, n_velocity_encoding] 
                             from large window velocity encoder
    
    Returns:
        Tensor of shape [..., 2, embedding_dim + 2*n_velocity_encoding] 
        with the combined representation for both wrists
    """
    # Get the batch dimensions from the velocity tensors
    batch_dims = wrist_velocity_small.shape[:-2]
    
    # Expand wrist embeddings to match the batch dimensions
    # From [2, embedding_dim] to [..., 2, embedding_dim]
    expanded_embeddings = wrist_embeddings.expand(*batch_dims, -1, -1)
    
    # Verify that the shapes are compatible for concatenation
    assert expanded_embeddings.shape[:-1] == wrist_velocity_small.shape[:-1] == wrist_velocity_large.shape[:-1], \
        "Batch dimensions of embeddings and velocity features must match"
    
    # Concatenate along the last dimension
    combined_features = torch.cat([
        expanded_embeddings,      # Wrist identity (what)
        wrist_velocity_small,     # Short-term movement (how fast recently)
        wrist_velocity_large      # Long-term movement (how fast overall)
    ], dim=-1)
    
    return combined_features

In [356]:
torch.cat([embeddings[20], embeddings[41]], dim=-1).reshape((2,-1)).shape

torch.Size([2, 36])

In [357]:
wrists_vel_conc = combine_wrist_embedding_and_velocity(wrist_embeddings=torch.cat([embeddings[20], embeddings[41]], dim=-1).reshape((2,-1)), wrist_velocity_small=wrist_vel_small_encoded, wrist_velocity_large=wrist_vel_large_encoded)

In [358]:
class LandmarkTransformerEncoder(nn.Module):
    """
    Transformer encoder for processing hand landmarks and learning contextual relationships.
    
    This module treats the set of landmarks as a sequence and applies self-attention
    to learn the relationships between different parts of the hand.
    """
    def __init__(self, 
                 input_dim, 
                 num_layers=2,
                 num_heads=8,
                 hidden_dim=None,
                 ff_dim=None,
                 prenorm=True,
                 activation='gelu',
                 init_method='xavier_uniform',
                 init_gain=1.0):
        """
        Initialize the landmark transformer encoder.
        
        Args:
            input_dim: Dimension of input features per landmark (3*embedding_dim)
            num_layers: Number of transformer encoder layers
            num_heads: Number of attention heads
            hidden_dim: Hidden dimension size (if None, uses input_dim)
            ff_dim: Feed-forward dimension (if None, uses 4*hidden_dim)
            prenorm: Whether to use pre-norm (True) or post-norm (False) architecture
            activation: Activation function in feed-forward network
            init_method: Weight initialization method
            init_gain: Gain parameter for initialization
        """
        super(LandmarkTransformerEncoder, self).__init__()
        
        # Set dimensions
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim
        self.ff_dim = ff_dim if ff_dim is not None else 4 * self.hidden_dim
        
        # Input projection if needed
        self.input_projection = None
        if self.input_dim != self.hidden_dim:
            self.input_projection = nn.Linear(self.input_dim, self.hidden_dim)
        
        # Create transformer encoder layers
        encoder_layer = LandmarkTransformerLayer(
            hidden_dim=self.hidden_dim,
            num_heads=num_heads,
            ff_dim=self.ff_dim,
            prenorm=prenorm,
            activation=activation
        )
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        
        # Final normalization
        self.norm = nn.LayerNorm(self.hidden_dim)
        
        # Initialize weights
        self._init_weights(init_method, init_gain)
    
    def _init_weights(self, init_method, gain):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if init_method == 'xavier_uniform':
                    nn.init.xavier_uniform_(module.weight, gain=gain)
                elif init_method == 'xavier_normal':
                    nn.init.xavier_normal_(module.weight, gain=gain)
                elif init_method == 'kaiming_uniform':
                    nn.init.kaiming_uniform_(module.weight, a=0, mode='fan_in')
                elif init_method == 'kaiming_normal':
                    nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                else:
                    raise ValueError(f"Initialization method '{init_method}' not supported.")
                
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    
    def forward(self, x):
        """
        Process hand landmarks through the transformer.
        
        Args:
            x: Tensor of shape [batch_size, seq_len, 20, input_dim]
               where 20 is the number of landmarks and input_dim is 3*embedding_dim
        
        Returns:
            Tensor of shape [batch_size, seq_len, 20, hidden_dim]
            with contextually enriched landmark representations
        """
        # Get original shape
        batch_size, seq_len, num_landmarks, _ = x.shape
        
        # Reshape to process each frame separately
        # [batch_size * seq_len, 20, input_dim]
        x_reshaped = x.reshape(-1, num_landmarks, self.input_dim)
        
        # Apply input projection if needed
        if self.input_projection is not None:
            x_reshaped = self.input_projection(x_reshaped)
        
        # Process through transformer layers
        for layer in self.layers:
            x_reshaped = layer(x_reshaped)
        
        # Apply final normalization
        x_reshaped = self.norm(x_reshaped)
        
        # Reshape back to original dimensions
        # [batch_size, seq_len, 20, hidden_dim]
        output = x_reshaped.reshape(batch_size, seq_len, num_landmarks, self.hidden_dim)
        
        return output


class LandmarkTransformerLayer(nn.Module):
    """
    Single transformer encoder layer for landmark processing.
    """
    def __init__(self, hidden_dim, num_heads, ff_dim, prenorm=True, activation='gelu'):
        """
        Initialize a transformer encoder layer.
        
        Args:
            hidden_dim: Hidden dimension size
            num_heads: Number of attention heads
            ff_dim: Feed-forward dimension
            prenorm: Whether to use pre-norm (True) or post-norm (False)
            activation: Activation function in feed-forward network
        """
        super(LandmarkTransformerLayer, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.prenorm = prenorm
        
        # Multi-head attention
        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True
        )
        
        # Feed-forward network
        self.ff_network = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            self._get_activation(activation),
            nn.Linear(ff_dim, hidden_dim)
        )
        
        # Layer normalizations
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
    
    def _get_activation(self, name):
        """Get activation function by name."""
        if name.lower() == 'relu':
            return nn.ReLU()
        elif name.lower() == 'gelu':
            return nn.GELU()
        elif name.lower() == 'silu' or name.lower() == 'swish':
            return nn.SiLU()
        else:
            raise ValueError(f"Activation function '{name}' not supported.")
    
    def forward(self, x):
        """
        Process landmarks through a transformer layer.
        
        Args:
            x: Tensor of shape [batch_size*seq_len, 20, hidden_dim]
               representing landmarks in a single frame
        
        Returns:
            Tensor of same shape with contextualized representations
        """
        # Pre-norm or post-norm architecture
        if self.prenorm:
            # Pre-norm: Apply normalization before attention
            norm_x = self.norm1(x)
            attn_output, _ = self.self_attention(norm_x, norm_x, norm_x)
            x = x + attn_output  # Residual connection
            
            # Feed-forward with normalization
            norm_x = self.norm2(x)
            ff_output = self.ff_network(norm_x)
            x = x + ff_output  # Residual connection
        else:
            # Post-norm: Apply attention then normalization
            attn_output, _ = self.self_attention(x, x, x)
            x = self.norm1(x + attn_output)  # Residual connection and norm
            
            # Feed-forward and normalization
            ff_output = self.ff_network(x)
            x = self.norm2(x + ff_output)  # Residual connection and norm
        
        return x

In [359]:
dom_landmarks_conc.shape

torch.Size([2, 36, 20, 108])

In [360]:
dom_transformer = LandmarkTransformerEncoder(input_dim=3 * embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)
non_dom_transformer = LandmarkTransformerEncoder(input_dim=3 * embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)


In [361]:
dom_contextualized = dom_transformer(dom_landmarks_conc)
non_dom_contextualized=non_dom_transformer(non_dom_landmarks_conc)

In [362]:
dom_contextualized.shape

torch.Size([2, 36, 20, 256])

In [363]:

class LandmarkAttentionPooling(nn.Module):
    """
    Applies attention pooling over landmarks using PyTorch's MultiheadAttention.
    """
    def __init__(self, input_dim, output_dim):
        """
        Initialize the attention pooling module.
        
        Args:
            input_dim: Dimension of input features per landmark
            output_dim: Dimension of the output representation
        """
        super(LandmarkAttentionPooling, self).__init__()
        
        # Using PyTorch's built-in attention mechanism
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=1,  # Single head is sufficient for pooling
            batch_first=True
        )
        
        # Learnable query vector
        self.query = nn.Parameter(torch.randn(1, 1, input_dim))
        
        # Output projection
        self.output_projection = nn.Linear(input_dim, output_dim)
        
        # Layer normalization for stability
        self.layer_norm = nn.LayerNorm(input_dim)
    
    def forward(self, x):
        """
        Apply attention pooling over landmarks.
        
        Args:
            x: Tensor of shape [batch_size, seq_len, num_landmarks, input_dim]
        
        Returns:
            Tensor of shape [batch_size, seq_len, output_dim]
        """
        batch_size, seq_len, num_landmarks, input_dim = x.shape
        
        # Reshape to process each sequence element separately
        x_reshaped = x.reshape(batch_size * seq_len, num_landmarks, input_dim)
        
        # Apply layer normalization
        x_norm = self.layer_norm(x_reshaped)
        
        # Expand query to match the batch size
        query = self.query.expand(batch_size * seq_len, -1, -1)
        
        # Apply attention
        # The query attends to all landmarks (keys and values are the same: x_norm)
        pooled, _ = self.attention(query, x_norm, x_norm)
        
        # Remove the sequence dimension (which is 1 for the query)
        pooled = pooled.squeeze(1)  # [batch_size * seq_len, input_dim]
        
        # Project to output dimension
        output = self.output_projection(pooled)  # [batch_size * seq_len, output_dim]
        
        # Reshape back to [batch_size, seq_len, output_dim]
        output = output.reshape(batch_size, seq_len, -1)
        
        return output

In [364]:
dom_pooling = LandmarkAttentionPooling(input_dim=dom_contextualized.shape[-1],output_dim=256)
non_dom_pooling = LandmarkAttentionPooling(input_dim=non_dom_contextualized.shape[-1],output_dim=256)

In [365]:
dom_pooled = dom_pooling(dom_contextualized)
non_dom_pooled = non_dom_pooling(non_dom_contextualized)

In [366]:
wrists_conc.shape

torch.Size([2, 36, 2, 108])

In [367]:
wrists_conc[:,:,0].shape

torch.Size([2, 36, 108])

In [368]:
dom_wrist_conc = wrists_conc[:,:,0]
non_dom_wrist_conc = wrists_conc[:,:,1]

In [369]:
dom_pooled.shape

torch.Size([2, 36, 256])

In [370]:
def concat_pooled_wrists(pooled, wrist):
# Verify that the shapes are compatible for concatenation
    assert pooled.shape[:-1] == wrist.shape[:-1], \
        "Batch dimensions of embeddings and spatial features must match"
    
    # Concatenate along the last dimension
    combined_features = torch.cat([
        pooled,     # Wrist identity (what)
        wrist   # Wrist position (where)
    ], dim=-1)
    
    return combined_features

In [371]:
dom_spatial_combined = concat_pooled_wrists(pooled=dom_pooled, wrist=dom_wrist_conc)
non_dom_spatial_combined = concat_pooled_wrists(pooled=non_dom_pooled, wrist=non_dom_wrist_conc)

In [372]:
dom_spatial_combined.shape

torch.Size([2, 36, 364])

In [373]:
dom_landmarks_velocity_conc.shape

torch.Size([2, 36, 20, 180])

In [374]:
dom_vel_transformer = LandmarkTransformerEncoder(input_dim=dom_landmarks_velocity_conc.shape[-1],num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)
non_dom_vel_transformer = LandmarkTransformerEncoder(input_dim=dom_landmarks_velocity_conc.shape[-1],num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)


In [375]:
dom_vel_contextualized = dom_vel_transformer(dom_landmarks_velocity_conc)
non_dom_vel_contextualized=non_dom_vel_transformer(non_dom_landmarks_velocity_conc)

In [376]:
dom_vel_pooling = LandmarkAttentionPooling(input_dim=dom_vel_contextualized.shape[-1],output_dim=256)
non_dom_vel_pooling = LandmarkAttentionPooling(input_dim=non_dom_vel_contextualized.shape[-1],output_dim=256)

In [377]:
dom_vel_pooled = dom_vel_pooling(dom_vel_contextualized)
non_dom_vel_pooled = non_dom_vel_pooling(non_dom_vel_contextualized)

In [378]:
dom_vel_pooled.shape

torch.Size([2, 36, 256])

In [379]:
dom_wrist_vel_conc = wrists_vel_conc[:,:,0]
non_dom_wrist_vel_conc = wrists_vel_conc[:,:,1]

In [380]:
dom_wrist_vel_conc.shape

torch.Size([2, 36, 180])

In [381]:
dom_velocity_combined = concat_pooled_wrists(pooled=dom_vel_pooled, wrist=dom_wrist_vel_conc)
non_dom_velocity_combined = concat_pooled_wrists(pooled=non_dom_vel_pooled, wrist=non_dom_wrist_vel_conc)

In [382]:
dom_velocity_combined.shape

torch.Size([2, 36, 436])

In [383]:
dom_spatial_combined.shape

torch.Size([2, 36, 364])

In [384]:
dom_combined = concat_pooled_wrists(dom_spatial_combined, dom_velocity_combined)
non_dom_combined = concat_pooled_wrists(non_dom_spatial_combined, non_dom_velocity_combined)

In [385]:
dom_combined.shape

torch.Size([2, 36, 800])

In [386]:
hands_combined = torch.stack([dom_combined, non_dom_combined], dim=2)
hands_combined.shape

torch.Size([2, 36, 2, 800])

In [388]:

class ConfidenceWeightedTransformerEncoder(nn.Module):
    """
    Transformer encoder that incorporates confidence scores into attention calculations.
    
    This second-stage transformer learns relationships between the two hands while
    taking into account confidence and interpolation scores from both spatial and
    velocity features.
    """
    def __init__(self, 
                 input_dim, 
                 num_layers=2,
                 num_heads=8,
                 hidden_dim=None,
                 ff_dim=None,
                 prenorm=True,
                 activation='gelu',
                 init_method='xavier_uniform',
                 init_gain=1.0):
        """
        Initialize the confidence-weighted transformer encoder.
        
        Args:
            input_dim: Dimension of input features per hand
            num_layers: Number of transformer encoder layers
            num_heads: Number of attention heads
            hidden_dim: Hidden dimension size (if None, uses input_dim)
            ff_dim: Feed-forward dimension (if None, uses 4*hidden_dim)
            prenorm: Whether to use pre-norm (True) or post-norm (False) architecture
            activation: Activation function in feed-forward network
            init_method: Weight initialization method
            init_gain: Gain parameter for initialization
        """
        super(ConfidenceWeightedTransformerEncoder, self).__init__()
        
        # Set dimensions
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim is not None else input_dim
        self.ff_dim = ff_dim if ff_dim is not None else 4 * self.hidden_dim
        
        # Input projection if needed
        self.input_projection = None
        if self.input_dim != self.hidden_dim:
            self.input_projection = nn.Linear(self.input_dim, self.hidden_dim)
        
        # Create transformer encoder layers with confidence weighting
        layers = []
        for _ in range(num_layers):
            layers.append(
                ConfidenceWeightedTransformerLayer(
                    hidden_dim=self.hidden_dim,
                    num_heads=num_heads,
                    ff_dim=self.ff_dim,
                    prenorm=prenorm,
                    activation=activation
                )
            )
        self.layers = nn.ModuleList(layers)
        
        # Final normalization
        self.norm = nn.LayerNorm(self.hidden_dim)
        
        # Initialize weights
        self._init_weights(init_method, init_gain)
    
    def _init_weights(self, init_method, gain):
        """Initialize the weights using the specified method."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if init_method == 'xavier_uniform':
                    nn.init.xavier_uniform_(module.weight, gain=gain)
                elif init_method == 'xavier_normal':
                    nn.init.xavier_normal_(module.weight, gain=gain)
                elif init_method == 'kaiming_uniform':
                    nn.init.kaiming_uniform_(module.weight, a=0, mode='fan_in')
                elif init_method == 'kaiming_normal':
                    nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                else:
                    raise ValueError(f"Initialization method '{init_method}' not supported.")
                
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    
    def forward(self, x, confidence_scores):
        """
        Process hand features through the transformer with confidence weighting.
        
        Args:
            x: Tensor of shape [batch_size, seq_len, 2, input_dim]
               where 2 represents the dom and non-dom hands
            confidence_scores: Dictionary containing:
                - Cd_spatial: [batch_size, seq_len, 2] confidence scores
                - Ci_spatial: [batch_size, seq_len, 2] interpolation scores
                - Cd_velocity: [batch_size, seq_len, 2] velocity calculation confidence
                - Ci_velocity: [batch_size, seq_len, 2] velocity confidence
        
        Returns:
            Tensor of shape [batch_size, seq_len, 2, hidden_dim]
            with confidence-weighted contextual representations
        """
        # Get original shape
        batch_size, seq_len, num_hands, _ = x.shape
        
        # Reshape to process each frame separately
        # [batch_size * seq_len, 2, input_dim]
        x_reshaped = x.reshape(-1, num_hands, self.input_dim)
        
        # Apply input projection if needed
        if self.input_projection is not None:
            x_reshaped = self.input_projection(x_reshaped)
        
        # Reshape confidence scores for per-frame processing
        conf_scores_reshaped = {}
        for key, tensor in confidence_scores.items():
            conf_scores_reshaped[key] = tensor.reshape(-1, num_hands)
        
        # Process through transformer layers
        for layer in self.layers:
            x_reshaped = layer(x_reshaped, conf_scores_reshaped)
        
        # Apply final normalization
        x_reshaped = self.norm(x_reshaped)
        
        # Reshape back to original dimensions
        # [batch_size, seq_len, 2, hidden_dim]
        output = x_reshaped.reshape(batch_size, seq_len, num_hands, self.hidden_dim)
        
        return output


class ConfidenceWeightedTransformerLayer(nn.Module):
    """
    Transformer encoder layer with confidence-weighted attention.
    """
    def __init__(self, hidden_dim, num_heads, ff_dim, prenorm=True, activation='gelu'):
        """
        Initialize a confidence-weighted transformer encoder layer.
        
        Args:
            hidden_dim: Hidden dimension size
            num_heads: Number of attention heads
            ff_dim: Feed-forward dimension
            prenorm: Whether to use pre-norm (True) or post-norm (False)
            activation: Activation function in feed-forward network
        """
        super(ConfidenceWeightedTransformerLayer, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.prenorm = prenorm
        self.num_heads = num_heads
        
        # Custom attention with confidence weighting
        self.self_attention = ConfidenceWeightedAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads
        )
        
        # Feed-forward network
        self.ff_network = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            self._get_activation(activation),
            nn.Linear(ff_dim, hidden_dim)
        )
        
        # Layer normalizations
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
    
    def _get_activation(self, name):
        """Get activation function by name."""
        if name.lower() == 'relu':
            return nn.ReLU()
        elif name.lower() == 'gelu':
            return nn.GELU()
        elif name.lower() == 'silu' or name.lower() == 'swish':
            return nn.SiLU()
        else:
            raise ValueError(f"Activation function '{name}' not supported.")
    
    def forward(self, x, confidence_scores):
        """
        Process through a transformer layer with confidence-weighted attention.
        
        Args:
            x: Tensor of shape [batch_size*seq_len, 2, hidden_dim]
            confidence_scores: Dictionary of confidence scores
            
        Returns:
            Tensor of same shape with contextualized representations
        """
        # Pre-norm or post-norm architecture
        if self.prenorm:
            # Pre-norm: Apply normalization before attention
            norm_x = self.norm1(x)
            attn_output = self.self_attention(norm_x, norm_x, norm_x, confidence_scores)
            x = x + attn_output  # Residual connection
            
            # Feed-forward with normalization
            norm_x = self.norm2(x)
            ff_output = self.ff_network(norm_x)
            x = x + ff_output  # Residual connection
        else:
            # Post-norm: Apply attention then normalization
            attn_output = self.self_attention(x, x, x, confidence_scores)
            x = self.norm1(x + attn_output)  # Residual connection and norm
            
            # Feed-forward and normalization
            ff_output = self.ff_network(x)
            x = self.norm2(x + ff_output)  # Residual connection and norm
        
        return x


class ConfidenceWeightedAttention(nn.Module):
    """
    Multi-head attention with confidence weighting.
    
    This applies the formula:
    Attention(Q,K,V,Cd_spatial,Ci_spatial,Cd_velocity,Ci_velocity) = 
        softmax(QK^T/sqrt(dk) + f(Cd_spatial,Ci_spatial,Cd_velocity,Ci_velocity))V
    """
    def __init__(self, embed_dim, num_heads):
        super(ConfidenceWeightedAttention, self).__init__()
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        
        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # Learnable parameters for confidence weighting
        self.a = nn.Parameter(torch.zeros(1))  # For Cd_spatial
        self.b = nn.Parameter(torch.zeros(1))  # For Cd_velocity
        self.c = nn.Parameter(torch.zeros(1))  # For Ci_spatial
        self.d = nn.Parameter(torch.zeros(1))  # For Ci_velocity
        
        # Small epsilon to avoid log(0)
        self.epsilon = 0.01
    
  
    def compute_confidence_weights(self, confidence_scores):
        """
        Compute the confidence weighting matrix f(Cd_spatial, Ci_spatial, Cd_velocity, Ci_velocity).

        Args:
            confidence_scores: Dictionary with confidence score tensors of shape [flattened_batch_size, 2]

        Returns:
            Tensor of shape [flattened_batch_size, 2, 2] for weighting attention scores
        """
        Cd_spatial = confidence_scores['Cd_spatial']
        Ci_spatial = confidence_scores['Ci_spatial']
        Cd_velocity = confidence_scores['Cd_velocity']
        Ci_velocity = confidence_scores['Ci_velocity']

        # These tensors have shape [flattened_batch_size, 2]
        flattened_batch_size, num_hands = Cd_spatial.shape

        # Apply the confidence weighting formula
        f_values = (
            torch.log2(self.epsilon + Cd_spatial) * torch.sigmoid(self.a) * 0.25 +
            torch.log2(self.epsilon + Cd_velocity) * torch.sigmoid(self.b) * 0.25 +
            torch.log2(self.epsilon + Ci_spatial) * torch.sigmoid(self.c) * 0.5 +
            torch.log2(self.epsilon + Ci_velocity) * torch.sigmoid(self.d) * 0.5
        )
    
        
        # Create the 2x2 matrix for each batch item where columns have same values
        confidence_matrix = f_values.unsqueeze(1).expand(-1, num_hands, -1)
        
        
        return confidence_matrix
    
    def forward(self, query, key, value, confidence_scores):
        """
        Apply confidence-weighted attention.
        
        Args:
            query, key, value: Tensors of shape [batch_size*seq_len, num_hands, embed_dim]
                              where batch_size*seq_len represents flattened batch and sequence dimensions
            confidence_scores: Dictionary of confidence scores
            
        Returns:
            Attention output tensor of same shape
        """
        # Get the shape components - note there's no separate sequence dimension here!
        flattened_batch_size, num_hands, embed_dim = query.shape
        
        # Linear projections
        q = self.q_proj(query)  # [flattened_batch_size, num_hands, embed_dim]
        k = self.k_proj(key)    # [flattened_batch_size, num_hands, embed_dim]
        v = self.v_proj(value)  # [flattened_batch_size, num_hands, embed_dim]
        
        # Compute confidence weights
        # This should return: [flattened_batch_size, num_hands, num_hands]
        confidence_weights = self.compute_confidence_weights(confidence_scores)
        
        # Reshape for multi-head attention
        # Split embed_dim into num_heads × head_dim
        q = q.reshape(flattened_batch_size, num_hands, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # [flattened_batch_size, num_heads, num_hands, head_dim]
        
        k = k.reshape(flattened_batch_size, num_hands, self.num_heads, self.head_dim)
        k = k.permute(0, 2, 1, 3)  # [flattened_batch_size, num_heads, num_hands, head_dim]
        
        v = v.reshape(flattened_batch_size, num_hands, self.num_heads, self.head_dim)
        v = v.permute(0, 2, 1, 3)  # [flattened_batch_size, num_heads, num_hands, head_dim]
        
        # Calculate attention scores
        # [flattened_batch_size, num_heads, num_hands, num_hands]
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Add confidence weights to attention scores
        # Expand confidence_weights for all heads
        # [flattened_batch_size, 1, num_hands, num_hands]
        confidence_weights = confidence_weights.unsqueeze(1)
        
        # Add confidence weights to attention scores
        attention_scores = attention_scores + confidence_weights
        
        # Apply softmax
        attention_probs = F.softmax(attention_scores, dim=-1)
        
        # Apply attention to values
        # [flattened_batch_size, num_heads, num_hands, head_dim]
        context = torch.matmul(attention_probs, v)
        
        # Reshape back
        context = context.permute(0, 2, 1, 3)  # [flattened_batch_size, num_hands, num_heads, head_dim]
        context = context.reshape(flattened_batch_size, num_hands, embed_dim)
        
        # Final projection
        output = self.out_proj(context)  # [flattened_batch_size, num_hands, embed_dim]
        
        return output

In [389]:
cross_hand_transformer = ConfidenceWeightedTransformerEncoder(
    input_dim=hands_combined.shape[-1],
    num_layers=2,
    num_heads=4,
    hidden_dim=hands_combined.shape[-1],
    prenorm=True,
    activation='gelu'
)

In [390]:
confidence_scores = {
    'Cd_spatial': batch['confidence_scores'],
    'Ci_spatial': batch['interpolation_scores'],
    'Cd_velocity': batch['velocity_calculation_confidence'],
    'Ci_velocity': batch['velocity_confidence']
}


In [391]:
enhanced_hands = cross_hand_transformer(hands_combined, confidence_scores)

In [208]:
enhanced_hands.shape

torch.Size([2, 59, 2, 800])

In [393]:
final_pooling = LandmarkAttentionPooling(
    input_dim=enhanced_hands.shape[-1],
    output_dim=enhanced_hands.shape[-1])

In [394]:
final_pooling.query.shape

torch.Size([1, 1, 800])

In [395]:
final_hands_representation = final_pooling(enhanced_hands)

In [396]:
final_hands_representation.shape

torch.Size([2, 36, 800])

In [398]:
frame_representation = concat_pooled_wrists(final_hands_representation, blendshapes_encoded)

In [399]:
frame_representation.shape

torch.Size([2, 36, 872])

In [400]:
class TemporalDownsampler(nn.Module):
    """
    Reduces frame count using 1D convolution with configurable parameters.
    
    This module applies a 1D convolution across the temporal dimension,
    effectively reducing the number of frames while preserving important
    temporal information through learned filters.
    """
    def __init__(self, 
                 input_dim, 
                 output_channels=None, 
                 kernel_size=3, 
                 stride=2,
                 activation='relu',
                 norm_layer=True):
        """
        Initialize the temporal downsampler.
        
        Args:
            input_dim: Input dimension (d) - feature size per frame
            output_channels: Number of convolutional filters (C), defaults to input_dim
            kernel_size: Size of the convolutional kernel (k)
            stride: Stride of the convolution, controls downsampling factor
            activation: Activation function ('relu', 'gelu', None)
            norm_layer: Whether to include layer normalization after convolution
        """
        super(TemporalDownsampler, self).__init__()
        
        # Default output channels to input dimension if not specified
        self.output_channels = input_dim if output_channels is None else output_channels
        self.kernel_size = kernel_size
        self.stride = stride
        
        # Calculate padding to maintain temporal alignment
        # For even kernel sizes, we'll use asymmetric padding later
        self.padding = (kernel_size - 1) // 2
        self.is_even_kernel = (kernel_size % 2 == 0)
        
        # Convolutional layer
        self.conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=self.output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,  # This will be adjusted for even kernels
            bias=True
        )
        
        # Normalization layer
        self.norm = nn.LayerNorm(self.output_channels) if norm_layer else None
        
        # Activation function
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation is None:
            self.activation = None
        else:
            raise ValueError(f"Unsupported activation function: {activation}")
    
    def forward(self, x):
        """
        Apply temporal downsampling to the input sequence.
        
        Args:
            x: Input tensor of shape [batch_size, n_frames, input_dim]
            
        Returns:
            Tensor of shape [batch_size, n_frames/stride, output_channels]
        """
        batch_size, n_frames, input_dim = x.shape
        
        # Reshape for conv1d which expects [batch_size, channels, length]
        x = x.permute(0, 2, 1)  # -> [batch_size, input_dim, n_frames]
        
        # Handle even-sized kernels with asymmetric padding if needed
        if self.is_even_kernel:
            # For even kernels, PyTorch padding is not symmetric
            # We'll pad manually to handle this
            pad_size = (self.kernel_size - 1) // 2
            x = nn.functional.pad(x, (pad_size, pad_size+1), mode='constant', value=0)
            
        # Apply convolution
        x = self.conv(x)  # -> [batch_size, output_channels, n_frames/stride]
        
        # Reshape back to [batch_size, n_frames/stride, output_channels]
        x = x.permute(0, 2, 1)
        
        # Apply normalization if specified
        if self.norm is not None:
            x = self.norm(x)
        
        # Apply activation if specified
        if self.activation is not None:
            x = self.activation(x)
        
        return x
    
    def compute_output_shape(self, input_length):
        """
        Calculate the output sequence length given the input length.
        
        Args:
            input_length: Length of the input sequence (n_frames)
            
        Returns:
            Length of the output sequence
        """
        # For even kernels with our manual padding
        if self.is_even_kernel:
            padding = self.padding + 1
        else:
            padding = self.padding
        
        # Standard formula for conv output shape
        return math.floor((input_length + 2 * padding - self.kernel_size) / self.stride + 1)

In [401]:
conv1d = TemporalDownsampler(
    input_dim=frame_representation.shape[-1],          # Feature dimension (d)
    output_channels=768,    # Number of filters (C), same as input to preserve dimension
    kernel_size=5,          # Kernel size (k)
    stride=2,               # Stride for downsampling
    activation='relu',      # Activation function
    norm_layer=True         # Include layer normalization
)

In [402]:
downsampled_representation = conv1d(frame_representation)

In [403]:
downsampled_representation.shape

torch.Size([2, 18, 768])

In [405]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # Create fixed positional encodings
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x, scale=1.0):
        """
        Add positional encodings to the input tensor.
        
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            scale: Scaling factor for the positional encodings
            
        Returns:
            Tensor with added positional encodings
        """
        x = x + (self.pe[:, :x.size(1), :] * scale)
        return x


In [406]:
positional_encoder = PositionalEncoding(
    d_model=downsampled_representation.shape[-1],  # Feature dimension
    max_len=120  # Add some buffer for sequence length
)

In [407]:
downsampled_with_positional_encoding = positional_encoder(downsampled_representation, scale=1.0)

In [408]:
downsampled_with_positional_encoding.shape

torch.Size([2, 18, 768])

In [410]:
class MultiScaleTemporalTransformer(nn.Module):
    """
    Transformer that processes sequences with multi-scale temporal attention.
    
    Uses exactly three attention heads:
    - Short-term head: Attends to frames within ±5 frames
    - Medium-term head: Attends to frames within ±15 frames
    - Long-term head: Attends to frames within ±45 frames
    """
    def __init__(self, 
                 d_model, 
                 num_layers=4,
                 short_range=5,
                 medium_range=15,
                 long_range=45,
                 dim_feedforward=2048,
                 activation='gelu',
                 stride=2):  # Add stride parameter
        """
        Initialize the multi-scale temporal transformer.
        
        Args:
            d_model: Model dimension / feature size
            num_layers: Number of transformer encoder layers
            short_range: Range for short-term attention (±frames)
            medium_range: Range for medium-term attention (±frames)
            long_range: Range for long-term attention (±frames)
            dim_feedforward: Dimension of feedforward network
            activation: Activation function type
            stride: Stride used in downsampling (needed for mask adjustment)
        """
        super(MultiScaleTemporalTransformer, self).__init__()
        
        # Store parameters
        self.d_model = d_model
        self.total_heads = 3  # Exactly 3 heads
        self.head_ranges = {
            'short': short_range,
            'medium': medium_range,
            'long': long_range
        }
        self.stride = stride
        
        # Create transformer layers
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                MultiScaleTransformerEncoderLayer(
                    d_model=d_model,
                    head_ranges=self.head_ranges,
                    dim_feedforward=dim_feedforward,
                    activation=activation
                )
            )
        self.layers = nn.ModuleList(encoder_layers)
        
        # Layer normalization
        self.norm = nn.LayerNorm(d_model)
    
    def forward(self, src, mask=None):
        """
        Process the input sequence through the transformer.
        
        Args:
            src: Input tensor [batch_size, seq_len_downsampled, d_model]
            mask: Boolean mask [batch_size, seq_len_original] where True indicates valid frames
                 and False indicates padding frames
            
        Returns:
            Output tensor of same shape as input with multi-scale temporal context
        """
        output = src
        
        # Adjust mask for downsampled sequence length
        if mask is not None:
            # Subsample the mask to match downsampled sequence
            # Take every stride-th element, starting from 0
            # This accounts for how conv1d downsampling affects the sequence length
            downsample_mask = mask[:, ::self.stride]
            
            # Make sure downsampled mask matches sequence length
            # It might be off by 1 due to padding in conv1d
            if downsample_mask.shape[1] > src.shape[1]:
                downsample_mask = downsample_mask[:, :src.shape[1]]
            elif downsample_mask.shape[1] < src.shape[1]:
                # This shouldn't normally happen, but just in case
                pad_size = src.shape[1] - downsample_mask.shape[1]
                pad = torch.zeros((downsample_mask.shape[0], pad_size), dtype=torch.bool, device=mask.device)
                downsample_mask = torch.cat([downsample_mask, pad], dim=1)
            
            # Convert from True=valid to True=padding format used by transformer
            padding_mask = ~downsample_mask
        else:
            padding_mask = None
        
        # Pass through each transformer layer
        for layer in self.layers:
            output = layer(output, padding_mask=padding_mask)
        
        # Apply final normalization
        output = self.norm(output)
        
        return output


class MultiScaleTransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer with multi-scale temporal attention.
    """
    def __init__(self, 
                 d_model, 
                 head_ranges,
                 dim_feedforward=2048, 
                 activation="gelu"):
        super(MultiScaleTransformerEncoderLayer, self).__init__()
        
        # Multi-scale attention
        self.self_attn = MultiScaleAttention(
            embed_dim=d_model,
            head_ranges=head_ranges
        )
        
        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Activation function
        if activation == "relu":
            self.activation = F.relu
        elif activation == "gelu":
            self.activation = F.gelu
        else:
            raise ValueError(f"Unsupported activation: {activation}")
    
    def forward(self, src, padding_mask=None):
        """
        Forward pass through the transformer encoder layer.
        
        Args:
            src: Input tensor [batch_size, seq_len_downsampled, d_model]
            padding_mask: Boolean mask [batch_size, seq_len_downsampled] 
                         where True indicates padding
            
        Returns:
            Output tensor of the same shape
        """
        # Multi-scale attention with residual connection
        src2 = self.norm1(src)
        src2 = self.self_attn(src2, src2, src2, padding_mask=padding_mask)
        src = src + src2
        
        # Feed-forward network with residual connection
        src2 = self.norm2(src)
        src2 = self.linear2(self.activation(self.linear1(src2)))
        src = src + src2
        
        return src


class MultiScaleAttention(nn.Module):
    """
    Multi-head attention where different heads attend to different temporal ranges.
    Uses exactly 3 heads: short, medium, and long-term.
    """
    def __init__(self, embed_dim, head_ranges):
        super(MultiScaleAttention, self).__init__()
        
        self.embed_dim = embed_dim
        self.head_ranges = head_ranges
        self.total_heads = 3  # Fixed: one head per range
        
        assert embed_dim % self.total_heads == 0, "embed_dim must be divisible by 3"
        self.head_dim = embed_dim // self.total_heads
        
        # Create linear projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
        # Head indices (fixed for 3 heads)
        self.head_indices = {
            'short': (0, 1),
            'medium': (1, 2),
            'long': (2, 3)
        }
    
    def forward(self, query, key, value, padding_mask=None):
        """
        Apply multi-scale attention.
        
        Args:
            query, key, value: Input tensors [batch_size, seq_len, embed_dim]
            padding_mask: Boolean mask [batch_size, seq_len] where True indicates padding frames
            
        Returns:
            Output tensor [batch_size, seq_len, embed_dim]
        """
        batch_size, tgt_len, _ = query.shape
        src_len = key.shape[1]
        
        # Linear projections and reshape for multi-head attention
        q = self.q_proj(query).view(batch_size, tgt_len, self.total_heads, self.head_dim)
        k = self.k_proj(key).view(batch_size, src_len, self.total_heads, self.head_dim)
        v = self.v_proj(value).view(batch_size, src_len, self.total_heads, self.head_dim)
        
        # Transpose for attention computation
        q = q.transpose(1, 2)  # [batch_size, total_heads, tgt_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, total_heads, src_len, head_dim]
        v = v.transpose(1, 2)  # [batch_size, total_heads, src_len, head_dim]
        
        # Compute attention scores
        attn_output = self._multi_scale_attention(q, k, v, tgt_len, src_len, padding_mask)
        
        # Reshape and apply final projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.embed_dim)
        output = self.out_proj(attn_output)
        
        return output
    
    def _multi_scale_attention(self, q, k, v, tgt_len, src_len, padding_mask):
        """
        Apply attention with different temporal ranges for different heads.
        """
        # Compute scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Create temporal range masks for each head type
        temporal_masks = self._create_temporal_masks(tgt_len, src_len, device=q.device)
        
        # Apply padding mask if provided
        if padding_mask is not None:
            # Convert mask from [batch_size, seq_len] to [batch_size, 1, 1, seq_len]
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
            # True values are masked positions (set to -inf)
            attn_weights = attn_weights.masked_fill(padding_mask, float('-inf'))
        
        # Apply the temporal masks
        for scale, (start_idx, end_idx) in self.head_indices.items():
            mask = temporal_masks[scale]
            attn_weights[:, start_idx:end_idx] = attn_weights[:, start_idx:end_idx].masked_fill(mask, float('-inf'))
        
        # Apply softmax (no dropout)
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        # Apply attention weights to values
        output = torch.matmul(attn_weights, v)
        
        return output
    
    def _create_temporal_masks(self, tgt_len, src_len, device):
        """
        Create masks to restrict attention to specific temporal ranges.
        
        Returns:
            Dictionary of masks for each head type
        """
        temporal_masks = {}
        
        # Create position indices
        pos_i = torch.arange(tgt_len, device=device).unsqueeze(1)
        pos_j = torch.arange(src_len, device=device).unsqueeze(0)
        
        # Calculate distance between positions
        dist = torch.abs(pos_i - pos_j)  # [tgt_len, src_len]
        
        # Create masks for each temporal range
        for scale, range_val in self.head_ranges.items():
            # True where attention should be blocked (outside of the range)
            mask = dist > range_val
            # Expand for batch dimension and appropriate number of heads
            # Shape: [1, 1, tgt_len, src_len]
            temporal_masks[scale] = mask.unsqueeze(0).unsqueeze(0)
        
        return temporal_masks

In [411]:
downsampled_with_positional_encoding.shape[-1]

768

In [412]:


temporal_transformer = MultiScaleTemporalTransformer(
    d_model=downsampled_with_positional_encoding.shape[-1],
    num_layers=4,
    short_range=5,
    medium_range=15,
    long_range=45,
    dim_feedforward=2 * downsampled_with_positional_encoding.shape[-1],
    activation='gelu',
    stride=2  
)


multi_scale_representation = temporal_transformer(downsampled_with_positional_encoding, mask=batch['mask'])

In [413]:
multi_scale_representation.shape

torch.Size([2, 18, 768])

In [416]:
multi_scale_representation_reinforced = positional_encoder(multi_scale_representation, scale=0.25)

In [417]:
multi_scale_representation_reinforced.shape

torch.Size([2, 18, 768])

In [625]:
dom_contextualized.shape[-1]

256

In [637]:
dom_vel_contextualized.shape[-1]

256

In [636]:
dom_vel_transformer.hidden_dim

256

In [None]:
embedding_dim = 36
embedding_table = LandmarkEmbedding(embedding_dim=embedding_dim, num_landmarks_per_hand=21)
ladmark_encoder = LandmarkSpatialEncoder(embedding_dim, hidden_dims=[30, 60, 30],activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

wrist_encoder = WristSpatialEncoder(embedding_dim, hidden_dims=[30, 60, 30],activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

blendshapes_feedforward = BlendshapeEncoder(embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

velocity_feedforward = VelocityEncoder(n_velocity_encoding=2*embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

wrist_vel_feedforward = WristVelocityEncoder(n_velocity_encoding=2*embedding_dim, hidden_dims=None, num_layers=2,activation='relu',init_method='kaiming_normal',init_gain=1.0,init_nonlinearity='relu')

dom_transformer = LandmarkTransformerEncoder(input_dim=3 * embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)
non_dom_transformer = LandmarkTransformerEncoder(input_dim=3 * embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)

dom_pooling = LandmarkAttentionPooling(input_dim=dom_transformer.hidden_dim,output_dim=256)
non_dom_pooling = LandmarkAttentionPooling(input_dim=non_dom_transformer.hidden_dim,output_dim=256)

dom_vel_transformer = LandmarkTransformerEncoder(input_dim=velocity_feedforward.output_dim*2+embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)
non_dom_vel_transformer = LandmarkTransformerEncoder(input_dim=velocity_feedforward.output_dim*2+embedding_dim,num_layers=4,num_heads=8,hidden_dim=256,activation='gelu',prenorm=True)

dom_vel_pooling = LandmarkAttentionPooling(input_dim=dom_vel_transformer.hidden_dim,output_dim=256)
non_dom_vel_pooling = LandmarkAttentionPooling(input_dim=non_dom_vel_transformer.hidden_dim,output_dim=256)

cross_hand_transformer = ConfidenceWeightedTransformerEncoder(
    input_dim=hands_combined.shape[-1],
    num_layers=2,
    num_heads=4,
    hidden_dim=hands_combined.shape[-1],
    prenorm=True,
    activation='gelu'
)


final_pooling = LandmarkAttentionPooling(
    input_dim=enhanced_hands.shape[-1],
    output_dim=enhanced_hands.shape[-1])

conv1d = TemporalDownsampler(
    input_dim=frame_representation.shape[-1],          # Feature dimension (d)
    output_channels=768,    # Number of filters (C)
    kernel_size=5,          # Kernel size (k)
    stride=2,               # Stride for downsampling
    activation='relu',      
    norm_layer=True         # Include layer normalization
)


positional_encoder = PositionalEncoding(
    d_model=downsampled_representation.shape[-1],  # Feature dimension
    max_len=120/2  # max_frames
)



temporal_transformer = MultiScaleTemporalTransformer(
    d_model=downsampled_with_positional_encoding.shape[-1],
    num_layers=4,
    short_range=5,
    medium_range=15,
    long_range=45,
    dim_feedforward=2 * downsampled_with_positional_encoding.shape[-1],
    activation='gelu',
    stride=2  
)

TypeError: zeros() received an invalid combination of arguments - got (float, int), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)


In [None]:
embeddings = embedding_table.forward()

dom_landmarks_where = ladmark_encoder.forward(batch['dom_landmarks'])
non_dom_landmarks_where = ladmark_encoder.forward(batch['non_dom_landmarks'])

dom_landmarks_conc = combine_spatial_and_semantic_features(spatial_features=dom_landmarks_where, semantic_features=embeddings[:20])
non_dom_landmarks_conc = combine_spatial_and_semantic_features(spatial_features=non_dom_landmarks_where, semantic_features=embeddings[21:41])

wrists_where = wrist_encoder.forward(wrist_coordinates=batch['nose_to_wrist_dist'])

wrists_conc = combine_wrist_embedding_and_spatial(wrist_embeddings=torch.cat([embeddings[20], embeddings[41]], dim=-1).reshape((2,-1)), wrist_spatial_features=wrists_where)

blendshapes_encoded = blendshapes_feedforward(batch['blendshape_scores'])

dom_small_vel_encoded = velocity_feedforward.forward(batch['dom_velocity_small']) 
dom_large_vel_encoded = velocity_feedforward.forward(batch['dom_velocity_large']) 
non_dom_small_vel_encoded = velocity_feedforward.forward(batch['non_dom_velocity_small']) 
non_dom_large_vel_encoded = velocity_feedforward.forward(batch['non_dom_velocity_large']) 

dom_landmarks_velocity_conc = combine_semantic_and_velocity_features(semantic_features=embeddings[:20], velocity_small_features=dom_small_vel_encoded, velocity_large_features=dom_large_vel_encoded)
non_dom_landmarks_velocity_conc = combine_semantic_and_velocity_features(semantic_features=embeddings[21:41], velocity_small_features=non_dom_small_vel_encoded, velocity_large_features=non_dom_large_vel_encoded)

wrist_vel_small_encoded = wrist_vel_feedforward.forward(batch['nose_to_wrist_velocity_small'])
wrist_vel_large_encoded = wrist_vel_feedforward.forward(batch['nose_to_wrist_velocity_large'])

wrists_vel_conc = combine_wrist_embedding_and_velocity(wrist_embeddings=torch.cat([embeddings[20], embeddings[41]], dim=-1).reshape((2,-1)), wrist_velocity_small=wrist_vel_small_encoded, wrist_velocity_large=wrist_vel_large_encoded)

dom_contextualized = dom_transformer(dom_landmarks_conc)
non_dom_contextualized=non_dom_transformer(non_dom_landmarks_conc)

dom_pooled = dom_pooling(dom_contextualized)
non_dom_pooled = non_dom_pooling(non_dom_contextualized)

dom_wrist_conc = wrists_conc[:,:,0]
non_dom_wrist_conc = wrists_conc[:,:,1]

dom_spatial_combined = concat_pooled_wrists(pooled=dom_pooled, wrist=dom_wrist_conc)
non_dom_spatial_combined = concat_pooled_wrists(pooled=non_dom_pooled, wrist=non_dom_wrist_conc)

dom_vel_contextualized = dom_vel_transformer(dom_landmarks_velocity_conc)
non_dom_vel_contextualized=non_dom_vel_transformer(non_dom_landmarks_velocity_conc)

dom_vel_pooled = dom_vel_pooling(dom_vel_contextualized)
non_dom_vel_pooled = non_dom_vel_pooling(non_dom_vel_contextualized)

dom_wrist_vel_conc = wrists_vel_conc[:,:,0]
non_dom_wrist_vel_conc = wrists_vel_conc[:,:,1]

dom_velocity_combined = concat_pooled_wrists(pooled=dom_vel_pooled, wrist=dom_wrist_vel_conc)
non_dom_velocity_combined = concat_pooled_wrists(pooled=non_dom_vel_pooled, wrist=non_dom_wrist_vel_conc)

dom_combined = concat_pooled_wrists(dom_spatial_combined, dom_velocity_combined)
non_dom_combined = concat_pooled_wrists(non_dom_spatial_combined, non_dom_velocity_combined)

hands_combined = torch.stack([dom_combined, non_dom_combined], dim=2)

confidence_scores = {
    'Cd_spatial': batch['confidence_scores'],
    'Ci_spatial': batch['interpolation_scores'],
    'Cd_velocity': batch['velocity_calculation_confidence'],
    'Ci_velocity': batch['velocity_confidence']
}

enhanced_hands = cross_hand_transformer(hands_combined, confidence_scores)

frame_representation = concat_pooled_wrists(final_hands_representation, blendshapes_encoded)


downsampled_representation = conv1d(frame_representation)

downsampled_with_positional_encoding = positional_encoder(downsampled_representation, scale=1.0)

multi_scale_representation = temporal_transformer(downsampled_with_positional_encoding, mask=batch['mask'])

video_representation = positional_encoder(multi_scale_representation, scale=0.25)

In [None]:
video_representation.shape

torch.Size([2, 26, 768])

In [418]:
def semantic_smoothing_loss(logits, L_index, L_values, label_mask=None):
    """
    Custom loss function that supports semantic label smoothing with proper
    handling of first token prediction.
    
    Args:
        logits: Model output logits [batch_size, seq_len, vocab_size]
        L_index: Token indices [batch_size, max_n_tokens, 6] where each row contains 
                 the original label index and 5 semantically similar tokens
        L_values: Token values [batch_size, max_n_tokens, 6] containing smoothed probabilities
        label_mask: Boolean mask [batch_size, max_n_tokens] with True for valid tokens
        
    Returns:
        tuple: (total_loss, first_token_loss, next_token_loss)
    """
    batch_size, seq_len, vocab_size = logits.shape
    _, max_n_tokens, k = L_index.shape
    
    # For first token prediction: use first position logits to predict first token
    first_token_logits = logits[:, 0, :]  # [batch_size, vocab_size]
    first_token_targets = L_index[:, 0, :]  # [batch_size, 6]
    first_token_values = L_values[:, 0, :]  # [batch_size, 6]
    
    # For subsequent tokens: use shifted logits to predict shifted targets
    next_token_logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
    next_token_targets = L_index[:, 1:, :]  # [batch_size, max_n_tokens-1, 6]
    next_token_values = L_values[:, 1:, :]  # [batch_size, max_n_tokens-1, 6]
    
    if label_mask is not None:
        first_token_mask = label_mask[:, 0]  # [batch_size]
        next_token_mask = label_mask[:, 1:]  # [batch_size, max_n_tokens-1]
    else:
        first_token_mask = (first_token_targets[:, 0] != 0)
        next_token_mask = (next_token_targets[:, :, 0] != 0)
    
    # Initialize loss components
    first_token_loss = torch.zeros(batch_size, device=logits.device)
    next_token_loss = torch.zeros(batch_size, min(next_token_targets.shape[1], next_token_logits.shape[1]), device=logits.device)
    
    # Calculate first token loss
    for b in range(batch_size):
        if first_token_mask[b]:
            pos_loss = 0
            for i in range(k):
                token_idx = first_token_targets[b, i]
                token_value = first_token_values[b, i]
                
                if token_idx == 0:  # Skip padding
                    continue
                
                log_prob = F.log_softmax(first_token_logits[b], dim=-1)[token_idx]
                pos_loss -= token_value * log_prob
            
            first_token_loss[b] = pos_loss
    
    # Calculate subsequent token losses
    for b in range(batch_size):
        for pos in range(min(next_token_targets.shape[1], next_token_logits.shape[1])):
            if pos < next_token_mask.shape[1] and next_token_mask[b, pos]:
                pos_loss = 0
                for i in range(k):
                    token_idx = next_token_targets[b, pos, i]
                    token_value = next_token_values[b, pos, i]
                    
                    if token_idx == 0:
                        continue
                    
                    log_prob = F.log_softmax(next_token_logits[b, pos], dim=-1)[token_idx]
                    pos_loss -= token_value * log_prob
                
                next_token_loss[b, pos] = pos_loss
    
    # Combine losses
    # Count valid tokens 
    valid_first_tokens = first_token_mask.sum().clamp(min=1)
    valid_next_tokens = 0
    for b in range(batch_size):
        valid_next_tokens += next_token_mask[b, :next_token_loss.shape[1]].sum()
    valid_next_tokens = valid_next_tokens.clamp(min=1)
    
    # Average first token loss
    avg_first_token_loss = first_token_loss.sum() / valid_first_tokens
    
    # Average next token losses
    avg_next_token_loss = next_token_loss.sum() / valid_next_tokens
    
    # Combine both losses
    total_loss = (avg_first_token_loss + avg_next_token_loss) / 2
    
    return total_loss, avg_first_token_loss, avg_next_token_loss

In [None]:
def compute_minimum_loss(target_labels, label_mask=None):
    """
    Compute the theoretical minimum loss (entropy of target distribution)
    """
    # Add small epsilon to avoid log(0)
    epsilon = 1e-10
    
    # Calculate entropy for each position: -∑(p_i * log(p_i))
    position_entropy = -(target_labels * torch.log(target_labels + epsilon)).sum(dim=2)
    
    # Apply mask if provided
    if label_mask is not None:
        position_entropy = position_entropy * label_mask.float()
        # Average entropy over valid tokens
        min_loss = position_entropy.sum() / label_mask.sum().clamp(min=1)
    else:
        # If no mask, use all tokens
        min_loss = position_entropy.mean()
    
    return min_loss

In [600]:
def optimized_semantic_smoothing_loss(logits, L_index, L_values, label_mask=None):
    """
    Highly optimized semantic smoothing loss using a single scatter operation.
    No loops needed!
    """
    batch_size, seq_len, vocab_size = logits.shape
    
    # Create target label distributions all at once
    target_labels = torch.zeros(batch_size, seq_len, vocab_size, device=logits.device)
    target_labels.scatter_(2, L_index, L_values)
    
    # Apply log_softmax to get log probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    
    # Compute loss (batch_size, seq_len)
    token_losses = -(target_labels * log_probs).sum(dim=2)
    
    # Apply mask if provided
    if label_mask is not None:
        token_losses = token_losses * label_mask.float()
        # Average loss over valid tokens
        total_loss = token_losses.sum() / label_mask.sum().clamp(min=1)
    else:
        # If no mask, use all tokens
        total_loss = token_losses.mean()
    
    return total_loss  

In [522]:
multi_scale_representation_reinforced.shape

torch.Size([2, 18, 768])

In [449]:
class OptimizedCrossAttention(nn.Module):
    """
    Cross-attention using PyTorch's optimized MultiheadAttention implementation.
    """
    def __init__(self, hidden_size=768, num_heads=12):
        super(OptimizedCrossAttention, self).__init__()
        
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        
        # PyTorch's optimized multi-head attention
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            batch_first=True  # Important for our [batch, seq, features] format
        )
        
        # Layer normalization for pre-norm architecture (like GPT-2)
        self.layer_norm = nn.LayerNorm(hidden_size)
    
    def forward(self, hidden_states, video_representations, video_mask=None, stride=1):
        """
        Compute cross-attention between GPT token representations and video frames.
        """
        # Apply layer normalization to hidden states (pre-norm approach)
        query = self.layer_norm(hidden_states)
        
        # Handle strided video mask
        if video_mask is not None and stride > 1:
            # Subsample the mask to match video_representations shape
            video_mask = video_mask[:, ::stride]
            
            # Ensure mask length matches
            frame_length = video_representations.shape[1]
            if video_mask.shape[1] > frame_length:
                video_mask = video_mask[:, :frame_length]
            elif video_mask.shape[1] < frame_length:
                pad_size = frame_length - video_mask.shape[1]
                pad = torch.zeros((video_mask.shape[0], pad_size), dtype=torch.bool, device=video_mask.device)
                video_mask = torch.cat([video_mask, pad], dim=1)
            
            # Convert to attention mask format expected by PyTorch
            # True = don't attend, False = attend
            attn_mask = ~video_mask
        else:
            attn_mask = None
        
        # PyTorch's MultiheadAttention expects:
        # - query: [batch_size, target_seq_length, embed_dim]
        # - key: [batch_size, source_seq_length, embed_dim]
        # - value: [batch_size, source_seq_length, embed_dim]
        # - attn_mask: [batch_size, target_seq_length, source_seq_length] or [target_seq_length, source_seq_length]
        
        # Use PyTorch's optimized implementation
        cross_attention_output, _ = self.multihead_attn(
            query=query,                  # From GPT tokens
            key=video_representations,    # From video frames
            value=video_representations,  # From video frames
            key_padding_mask=attn_mask,   # Mask for padding frames
            need_weights=False            # Don't return attention weights to save computation
        )
        
        return cross_attention_output

In [451]:
opt_cross_att = OptimizedCrossAttention()
opt_cross_att.forward(hidden_states=outputs['hidden_states'], video_representations=multi_scale_representation_reinforced, video_mask=batch['mask'], stride=2)

tensor([[[-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154],
         [-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154],
         [-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154],
         ...,
         [-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154],
         [-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154],
         [-0.2158,  0.0108, -0.2586,  ..., -0.2118, -0.7830, -0.2154]],

        [[-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806],
         [-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806],
         [-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806],
         ...,
         [-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806],
         [-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806],
         [-0.4754, -0.2581, -0.3415,  ..., -0.3153, -1.0134, -0.2806]]],
       grad_fn=<TransposeBackward0>)

In [None]:
outputs['hidden_states']

torch.Size([2, 8, 768])

In [441]:
batch['mask'].shape[1]

36

In [442]:

# Subsample the mask to match video_representations shape
video_mask = batch['mask'][:, ::2]

# Ensure mask matches the number of frames after downsampling

video_mask = video_mask[:, :batch['mask'].shape[1]]


# Reshape mask to match attention scores: [batch_size, 1, 1, frame_length]
attention_mask = ~video_mask.unsqueeze(1).unsqueeze(2)

In [444]:
batch['mask']

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True, False, False, False, False,
         False, False, False, False, False, False]])

In [443]:
attention_mask

tensor([[[[False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False]]],


        [[[False, False, False, False, False, False, False, False, False, False,
           False, False, False,  True,  True,  True,  True,  True]]]])

In [None]:
class VideoGPT(nn.Module):
    """
    Integrates pre-trained GPT-2 with cross-attention for video-to-text translation.
    """
    def __init__(self, model_name="distilgpt2", num_cross_heads=12, freeze_gpt=True, stride=2):
        super(VideoGPT, self).__init__()
        
        # Load pre-trained model
        self.gpt = AutoModelForCausalLM.from_pretrained(model_name)
        self.config = self.gpt.config
        self.stride = stride
        
        # Dimensions
        self.hidden_size = self.config.n_embd  # 768 for distilGPT-2
        
        # Create cross-attention layer
        self.cross_attention = OptimizedCrossAttention(
            hidden_size=self.hidden_size,
            num_heads=num_cross_heads
        )
        
        # Freeze GPT-2 weights if specified
        if freeze_gpt:
            self._freeze_gpt_parameters()
    
    def _freeze_gpt_parameters(self):
        """Freeze all parameters of the GPT model."""
        for param in self.gpt.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, video_representations, video_mask=None, 
               L_index=None, L_values=None, label_mask=None):
        """
        Forward pass with integrated cross-attention and semantic smoothing loss.
        
        Args:
            input_ids: Token IDs for GPT [batch_size, n_tokens]
            video_representations: Video frame features [batch_size, n_frames/stride, hidden_size]
            video_mask: Mask tensor [batch_size, n_frames] with True for valid frames
            attention_mask: Mask for input tokens [batch_size, n_tokens]
            L_index: Token indices [batch_size, max_n_tokens, 6]
            L_values: Token values [batch_size, max_n_tokens, 6]
            label_mask: Boolean mask [batch_size, max_n_tokens]
            
        Returns:
            outputs: Model outputs including loss and logits
        """
        batch_size, n_tokens = input_ids.shape
        
        # Get GPT embeddings (word + position)
        position_ids = torch.arange(0, n_tokens, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
        gpt_embeds = self.gpt.transformer.wte(input_ids) + self.gpt.transformer.wpe(position_ids)
        
        # Store states at each step
        hidden_states = gpt_embeds

        print(f"Input IDs min/max: {input_ids.min().item()}, {input_ids.max().item()}")
        print(f"Video representations has NaN: {torch.isnan(video_representations).any().item()}")

        if label_mask is not None:
            # Create attention mask that combines padding and causal constraints
            extended_attention_mask = label_mask.unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]

            # Step 2: Create causal mask (lower triangular matrix)
            seq_length = label_mask.size(1)
            causal_mask = torch.tril(torch.ones((seq_length, seq_length), 
                                               device=label_mask.device))
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

            # Step 3: Combine padding mask with causal mask
            combined_mask = causal_mask * extended_attention_mask.float()

            # Step 4: Convert to additive mask where 0 means "attend" and 
            # a large negative number means "don't attend"
            attention_mask = combined_mask.to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0

        print(f"Attention mask min/max: {attention_mask.min().item()}, {attention_mask.max().item()}")
        print(f"Attention mask has -inf: {torch.isinf(attention_mask).any().item()}")
        print(f"Number of non-masked positions: {(attention_mask > -1000).sum().item()}")
        # Process through GPT layers with cross-attention
        for i, block in enumerate(self.gpt.transformer.h):
            print(f"Block {i} hidden states has NaN: {torch.isnan(hidden_states).any().item()}")
            # 1. GPT self-attention
            attn_outputs = block.attn(
                hidden_states,
                attention_mask=attention_mask if label_mask is not None else None
            )

            gpt_attn_output = attn_outputs[0]
            
            # Add residual connection
            hidden_states = gpt_attn_output + hidden_states
            
            # 2. Insert our cross-attention between self-attention and FFN
            cross_attention_output = self.cross_attention(
                hidden_states, 
                video_representations, 
                video_mask=video_mask,
                stride=self.stride
            )
            
            # Add residual connection to cross-attention
            hidden_states = hidden_states + cross_attention_output
            
            # 3. Feed-forward network
            feed_forward_output = block.mlp(hidden_states)
            hidden_states = hidden_states + feed_forward_output
        
        # Final layer norm
        hidden_states = self.gpt.transformer.ln_f(hidden_states)
        
        # Language modeling head
        lm_logits = self.gpt.lm_head(hidden_states)
        
        # Calculate loss if labels are provided
        loss = None

        
        if L_index is not None and L_values is not None:
            # Use our custom semantic smoothing loss
            loss = optimized_semantic_smoothing_loss(
                logits=lm_logits,
                L_index=L_index,
                L_values=L_values,
                label_mask=label_mask
            )
    
        
        return {
            "loss": loss, 
            "logits": lm_logits, 
            "hidden_states": hidden_states
        }

In [453]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer



model_name = "distilgpt2"  
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to("cuda")  # Move to GPU

In [454]:

tokenizer.pad_token = tokenizer.eos_token

# Initialize model
model = VideoGPT(
    model_name="distilgpt2",
    num_cross_heads=12,
    freeze_gpt=True,
    stride=2
)

In [455]:
batch["L_values"]

tensor([[[0.7000, 0.0646, 0.0611, 0.0587, 0.0579, 0.0577],
         [0.7000, 0.0653, 0.0643, 0.0610, 0.0548, 0.0546],
         [0.7000, 0.0738, 0.0602, 0.0559, 0.0556, 0.0545],
         [0.7000, 0.0638, 0.0636, 0.0582, 0.0580, 0.0565],
         [0.7000, 0.0703, 0.0629, 0.0587, 0.0580, 0.0500],
         [0.7000, 0.0756, 0.0569, 0.0564, 0.0562, 0.0548],
         [0.7000, 0.0658, 0.0651, 0.0601, 0.0562, 0.0527],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.7000, 0.0662, 0.0633, 0.0572, 0.0567, 0.0565],
         [0.7000, 0.0605, 0.0600, 0.0598, 0.0598, 0.0598],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])

In [456]:
batch["L_index"][:, :, 0]

tensor([[   72,   705,    76,  2658,   558,   474,  1952, 50256],
        [  404, 23971, 50256,     0,     0,     0,     0,     0]])

In [457]:
primary_targets = batch["L_index"][:, :, 0].clone()

# Create input_ids by shifting right (add BOS at beginning, remove last token)
batch_size, seq_len = primary_targets.shape
input_ids = torch.zeros_like(primary_targets)
input_ids[:, 0] = tokenizer.bos_token_id  # Start with BOS token
input_ids[:, 1:] = primary_targets[:, :-1]

In [458]:
batch['label_mask']

tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True, False, False, False, False, False]])

In [459]:
batch["L_index"][:, :, 0]

tensor([[   72,   705,    76,  2658,   558,   474,  1952, 50256],
        [  404, 23971, 50256,     0,     0,     0,     0,     0]])

In [460]:
input_ids

tensor([[50256,    72,   705,    76,  2658,   558,   474,  1952],
        [50256,   404, 23971, 50256,     0,     0,     0,     0]])

In [461]:
batch["label_mask"].shape

torch.Size([2, 8])

In [462]:
batch['label_mask']

tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True, False, False, False, False, False]])

In [463]:
outputs = model(
    input_ids=input_ids,
    video_representations=multi_scale_representation_reinforced,
    video_mask=batch["mask"],
    L_index=batch["L_index"],
    L_values=batch["L_values"],
    label_mask=batch["label_mask"]
)

Input IDs min/max: 0, 50256
Video representations has NaN: False
Attention mask min/max: -10000.0, -0.0
Attention mask has -inf: False
Number of non-masked positions: 57
Block 0 hidden states has NaN: False
Block 1 hidden states has NaN: False
Block 2 hidden states has NaN: False
Block 3 hidden states has NaN: False
Block 4 hidden states has NaN: False
Block 5 hidden states has NaN: False


In [614]:
torch.argmax(outputs['logits'], dim=-1)

tensor([[464, 464, 464, 464, 464, 464, 464, 464],
        [464, 464, 464, 464, 464, 464, 464, 464]])

In [622]:
def get_predictions_from_logits(logits, tokenizer):
    """
    Convert model logits to human-readable text predictions.
    
    Args:
        logits: Tensor of shape [batch_size, sequence_length, vocab_size]
        tokenizer: The GPT tokenizer
    
    Returns:
        Dictionary containing token IDs and decoded text for each batch item
    """
    # Get the most likely token at each position (argmax along vocab dimension)
    predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch_size, sequence_length]
    
    # Convert to numpy for easier handling
    token_ids_np = predicted_token_ids.cpu().numpy()
    
    # Container for results
    results = []
    
    # Process each sequence in the batch
    for i, ids in enumerate(token_ids_np):
        # Decode the token IDs to text
        text = tokenizer.decode(ids)
        
        # For more detailed analysis, get individual tokens
        tokens = []
        for token_id in ids:
            token_str = tokenizer.decode([token_id])
            tokens.append((token_id, token_str))
            
        results.append({
            "sequence_idx": i,
            "token_ids": ids.tolist(),
            "tokens": tokens,
            "text": text,
        })
    
    return results

In [623]:
predictions = get_predictions_from_logits(outputs["logits"], tokenizer)

In [624]:
predictions

[{'sequence_idx': 0,
  'token_ids': [464, 464, 464, 464, 464, 464, 464, 464],
  'tokens': [(464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The')],
  'text': 'TheTheTheTheTheTheTheThe'},
 {'sequence_idx': 1,
  'token_ids': [464, 464, 464, 464, 464, 464, 464, 464],
  'tokens': [(464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The'),
   (464, 'The')],
  'text': 'TheTheTheTheTheTheTheThe'}]

In [617]:
input_ids[1]

tensor([50256,   404, 23971, 50256,     0,     0,     0,     0])

In [618]:
def compare_predictions(input_ids, predictions, L_index, tokenizer):
    """Show comparison between inputs, predictions, and expected labels"""
    for i, pred in enumerate(predictions):
        print(f"\nExample {i}:")
        
        # Input sequence
        input_sequence = tokenizer.decode(input_ids[i])
        print(f"Input:      {input_sequence}")
        
        # Generated sequence
        print(f"Generated:  {pred['text']}")


# Use in your test
compare_predictions(
    input_ids,
    predictions,
    batch["L_index"],
    tokenizer
)


Example 0:
Input:      <|endoftext|>i 'm candace jones
Generated:  TheTheTheTheTheTheTheThe

Example 1:
Input:      <|endoftext|>opinion<|endoftext|>!!!!
Generated:  TheTheTheTheTheTheTheThe


In [619]:
def analyze_next_token_predictions(logits, current_position, tokenizer, top_k=5):
    """Analyze top-k predictions for the next token at a specified position"""
    batch_size = logits.shape[0]
    
    for b in range(batch_size):
        next_token_logits = logits[b, current_position, :]
        
        # Get top-k predictions
        values, indices = torch.topk(next_token_logits, top_k)
        probs = torch.softmax(values, dim=0)
        
        print(f"\nTop {top_k} predictions for batch {b}, position {current_position}:")
        for i, (idx, prob) in enumerate(zip(indices.tolist(), probs.tolist())):
            token = tokenizer.decode([idx])
            print(f"  {i+1}. Token: '{token}', ID: {idx}, Probability: {prob:.4f}")


# Use in your test to analyze what comes after the last input token
analyze_next_token_predictions(
    outputs["logits"], 
    current_position=input_ids.shape[1]-1,  # Last position
    tokenizer=tokenizer
)


Top 5 predictions for batch 0, position 7:
  1. Token: 'The', ID: 464, Probability: 0.3296
  2. Token: 'A', ID: 32, Probability: 0.2013
  3. Token: '"', ID: 1, Probability: 0.1883
  4. Token: 'I', ID: 40, Probability: 0.1809
  5. Token: 'In', ID: 818, Probability: 0.1000

Top 5 predictions for batch 1, position 7:
  1. Token: 'The', ID: 464, Probability: 0.3314
  2. Token: 'A', ID: 32, Probability: 0.2017
  3. Token: '"', ID: 1, Probability: 0.1858
  4. Token: 'I', ID: 40, Probability: 0.1813
  5. Token: 'In', ID: 818, Probability: 0.0998
