In [None]:
!pip install -q torch-geometric
!pip install -q torch-geometric-temporal
!pip install -q mediapipe==0.10.9

In [None]:
import os
import json
from pathlib import Path
import cv2
import numpy as np
import mediapipe as mp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric_temporal.nn.attention import STConv
from tqdm import tqdm

In [None]:
class SkeletonExtractor:
    """Run MediaPipe Pose over a video and collect per-frame body landmarks."""

    # MediaPipe Pose emits 33 landmarks per detection, each flattened as (x, y, z).
    NUM_LANDMARKS = 33
    COORDS_PER_LANDMARK = 3

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=2,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def extract_keypoints(self, video_path):
        """
        Extract flattened pose landmarks for every frame of a video.

        Args:
            video_path (str): Path to a video file readable by OpenCV.

        Returns:
            np.ndarray: float array of shape (num_frames, 99). A frame in which no
            pose is detected is stored as 99 zeros. If the video cannot be opened
            or has no frames, the shape is (0, 99).
        """
        frame_width = self.NUM_LANDMARKS * self.COORDS_PER_LANDMARK
        keypoints_sequence = []

        # The Pose graph runs in video (tracking) mode and is shared across calls;
        # reset it so the previous clip's last pose does not seed tracking here.
        self.pose.reset()

        cap = cv2.VideoCapture(video_path)
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # OpenCV decodes BGR; MediaPipe expects RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = self.pose.process(rgb_frame)

                if results.pose_landmarks:
                    keypoints_sequence.append([
                        coord
                        for landmark in results.pose_landmarks.landmark
                        for coord in (landmark.x, landmark.y, landmark.z)
                    ])
                else:
                    # Keep the timeline aligned: a missed detection becomes a zero frame.
                    keypoints_sequence.append([0.0] * frame_width)
        finally:
            # Release the capture even if decoding or pose inference raises.
            cap.release()

        # reshape gives an empty result the same 2-D layout as a non-empty one.
        return np.array(keypoints_sequence, dtype=np.float64).reshape(-1, frame_width)

In [None]:
class STGCNDataPreprocessor:
    """Turn raw per-frame keypoints into fixed-length, hip-centred joint sequences."""

    def __init__(self, sequence_length=64, num_joints=33):
        self.sequence_length = sequence_length
        self.num_joints = num_joints

    def normalize_keypoints(self, keypoints):
        """
        Reshape keypoints to (frames, num_joints, 3) and translate every frame so
        the midpoint of landmarks 23 and 24 (MediaPipe's hips) sits at the origin.

        Only translation is applied; there is no scale normalisation.
        """
        frames = keypoints.reshape(-1, self.num_joints, 3)
        if frames.shape[0] == 0:
            return frames

        hip_midpoint = (frames[:, 23, :] + frames[:, 24, :]) / 2
        return frames - hip_midpoint[:, np.newaxis, :]

    def create_temporal_windows(self, skeleton_data):
        """
        Bring a sequence to exactly `sequence_length` frames.

        Longer sequences are subsampled at evenly spaced frame indices. Shorter
        ones are padded by repeating their final frame. An empty sequence becomes
        all zeros of shape (sequence_length, num_joints, 3).
        """
        target = self.sequence_length
        frame_count = len(skeleton_data)

        if frame_count > target:
            picks = np.linspace(0, frame_count - 1, target, dtype=int)
            return skeleton_data[picks]
        if frame_count == target:
            return skeleton_data
        if frame_count == 0:
            return np.zeros((target, self.num_joints, 3))

        tail = skeleton_data[-1:].repeat(target - frame_count, axis=0)
        return np.concatenate([skeleton_data, tail], axis=0)

In [None]:
class GraphConstructor:
    """Build the MediaPipe body graph and convert skeleton sequences to tensors.

    The edge list covers shoulders, arms, torso and legs only (landmarks 11-16
    and 23-32). All other landmark indices are connected only to themselves via
    the self-loops on the adjacency diagonal.
    """

    def __init__(self):
        # Undirected MediaPipe skeleton connections, as (joint_a, joint_b) pairs.
        self.skeleton_edges = [
            (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),  # Arms
            (11, 23), (12, 24), (23, 24),  # Torso
            (23, 25), (25, 27), (27, 29), (27, 31),  # Left leg
            (24, 26), (26, 28), (28, 30), (28, 32),  # Right leg
        ]

    def create_adjacency_matrix(self, num_joints=33):
        """Return a symmetric float (num_joints, num_joints) adjacency with self-loops."""
        adjacency = np.identity(num_joints)
        for joint_a, joint_b in self.skeleton_edges:
            adjacency[joint_a, joint_b] = adjacency[joint_b, joint_a] = 1
        return adjacency

    def skeleton_to_graph(self, skeleton_sequence):
        """
        Convert a (T, N, C) skeleton array into model inputs.

        Returns:
            x: float32 tensor laid out as (C, T, N).
            edge_index: long tensor of shape (2, E) listing every non-zero entry
                of the adjacency matrix in row-major order, self-loops included.
        """
        _, num_joints, _ = skeleton_sequence.shape
        adjacency = self.create_adjacency_matrix(num_joints)

        rows, cols = np.nonzero(adjacency)
        edge_index = torch.tensor(np.stack([rows, cols]), dtype=torch.long)

        channels_first = np.transpose(skeleton_sequence, (2, 0, 1))
        x = torch.tensor(channels_first, dtype=torch.float32)
        return x, edge_index

In [None]:
class ARIDDataset(Dataset):
    """
    ARID v1.5 clips served as skeleton graphs.

    Each item is ``(x, edge_index, label)``. Before ``transform`` is applied,
    ``x`` is a float32 tensor of shape (3, sequence_length, num_joints).
    ``edge_index`` is the (2, E) long tensor produced by ``GraphConstructor``,
    and ``label`` is a scalar long tensor.
    """

    def __init__(self,
                 root_dir,
                 split_file,
                 sequence_length=32,
                 num_joints=33,
                 precomputed_skeletons=None,
                 transform=None):
        """
        Args:
            root_dir (str): Path to ARID dataset root directory
            split_file (str): Path to split file (split0_train.txt or split0_test.txt)
                            Format: "idx class_idx relativepath"
                            Example: "1 0 Drink/Drink1.avi"
            sequence_length (int): Number of frames per sequence
            num_joints (int): Number of keypoints (33 for MediaPipe)
            precomputed_skeletons (str): Optional path to a JSON file mapping a video
                            file name (no directory) to its list of per-frame keypoints
            transform: Optional transform to be applied on samples
        """
        self.root_dir = Path(root_dir)
        self.clips_dir = self.root_dir / "clips_v1.5"
        self.sequence_length = sequence_length
        self.num_joints = num_joints
        self.transform = transform

        # ARID v1.5 has 11 action categories; list position == numeric class index.
        self.class_names = [
            'drink', 'jump', 'pick', 'pour', 'push',
            'run', 'sit', 'stand', 'turn', 'walk', 'wave'
        ]
        self.num_classes = len(self.class_names)

        # MediaPipe is only needed for clips without a precomputed skeleton, so the
        # extractor is built lazily by the `skeleton_extractor` property. This avoids
        # loading the pose model when every clip is precomputed.
        self._skeleton_extractor = None
        self.preprocessor = STGCNDataPreprocessor(sequence_length, num_joints)
        self.graph_constructor = GraphConstructor()

        # Load dataset splits with numeric class labels
        self.samples = self._load_split_file(split_file)

        print(f"Loaded {len(self.samples)} samples from {split_file}")
        print(f"Classes: {self.class_names}")
        print(f"Clips directory: {self.clips_dir}")

        # Load precomputed skeletons if available
        self.precomputed_skeletons = {}
        if precomputed_skeletons and os.path.exists(precomputed_skeletons):
            self._load_precomputed_skeletons(precomputed_skeletons)

    @property
    def skeleton_extractor(self):
        """MediaPipe skeleton extractor, created on first access."""
        if self._skeleton_extractor is None:
            self._skeleton_extractor = SkeletonExtractor()
        return self._skeleton_extractor

    @skeleton_extractor.setter
    def skeleton_extractor(self, extractor):
        self._skeleton_extractor = extractor

    def _load_split_file(self, split_file):
        """
        Parse a split file with lines of the form ``idx class_idx relativepath``.

        Returns:
            list of (full_path_str, class_idx, sample_idx_str) tuples. Blank lines
            are skipped. A line is dropped with a warning if it is malformed, has a
            non-integer or out-of-range class index, or points to a missing file.
        """
        samples = []

        with open(split_file, 'r') as f:
            for line_num, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line:
                    continue

                # maxsplit=2 keeps the relative path verbatim, including any internal
                # whitespace; re-joining split() output would collapse it.
                parts = line.split(maxsplit=2)
                if len(parts) < 3:
                    print(f"Warning: Invalid format at line {line_num}: {line}")
                    continue

                sample_idx, class_field, relative_path = parts  # sample_idx unused for training

                try:
                    class_idx = int(class_field)
                except ValueError:
                    print(f"Warning: Invalid class index {class_field} at line {line_num}")
                    continue

                if not 0 <= class_idx < self.num_classes:
                    print(f"Warning: Invalid class index {class_idx} at line {line_num}")
                    continue

                # Construct full path: ARID/clips_v1.5/relativepath
                full_path = self.clips_dir / relative_path
                if full_path.exists():
                    samples.append((str(full_path), class_idx, sample_idx))
                else:
                    print(f"Warning: File not found - {full_path}")

        return samples

    def _load_precomputed_skeletons(self, skeleton_file):
        """Load precomputed skeleton JSON; on failure, report and keep the empty dict."""
        try:
            with open(skeleton_file, 'r') as f:
                self.precomputed_skeletons = json.load(f)
            print(f"Loaded {len(self.precomputed_skeletons)} precomputed skeletons")
        except Exception as e:
            print(f"Could not load precomputed skeletons from {skeleton_file}: {e}")

    def _get_skeleton_data(self, video_path):
        """Return keypoints for a clip, preferring precomputed data over live extraction."""
        # Precomputed data is keyed by bare file name, not the full path.
        video_key = str(Path(video_path).name)

        if video_key in self.precomputed_skeletons:
            return np.array(self.precomputed_skeletons[video_key])
        # Extract skeleton from video (instantiates MediaPipe on first use)
        return self.skeleton_extractor.extract_keypoints(video_path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, class_idx, sample_idx = self.samples[idx]

        try:
            skeleton_data = self._get_skeleton_data(video_path)

            if len(skeleton_data) > 0:
                normalized_skeleton = self.preprocessor.normalize_keypoints(skeleton_data)
                temporal_skeleton = self.preprocessor.create_temporal_windows(normalized_skeleton)
            else:
                # No frames at all: substitute an all-zero sequence.
                temporal_skeleton = np.zeros((self.sequence_length, self.num_joints, 3))

            x, edge_index = self.graph_constructor.skeleton_to_graph(temporal_skeleton)

            if self.transform:
                x = self.transform(x)

            return x, edge_index, torch.tensor(class_idx, dtype=torch.long)

        except Exception as e:
            # Best-effort: log and return an all-zero sample so one bad clip does not
            # abort the epoch. The sample still carries its true label.
            print(f"Error processing {video_path}: {e}")
            zero_skeleton = np.zeros((self.sequence_length, self.num_joints, 3))
            x, edge_index = self.graph_constructor.skeleton_to_graph(zero_skeleton)
            return x, edge_index, torch.tensor(class_idx, dtype=torch.long)

In [None]:
def collate_fn(batch):
    """
    Merge (x, edge_index, label) samples into one batch.

    Features and labels are stacked along a new leading batch dimension. All
    samples share one skeleton topology, so the first sample's edge_index is
    returned as-is for the whole batch.
    """
    inputs, graphs, targets = zip(*batch)
    return (
        torch.stack(inputs, dim=0),   # (batch_size, channels, time, joints)
        graphs[0],
        torch.stack(targets, dim=0),
    )

In [None]:
def create_dataloaders(dataset_root, train_split_file, test_split_file,
                      batch_size=8, num_workers=2, sequence_length=32,
                      precomputed_skeletons=None):
    """
    Build train/test ARID datasets and wrap them in DataLoaders.

    Only the training loader shuffles. Memory pinning is enabled when CUDA is
    available.

    Returns:
        (train_loader, test_loader, num_classes, class_names). The class
        metadata is taken from the training dataset.
    """
    def build_dataset(split_file):
        return ARIDDataset(
            root_dir=dataset_root,
            split_file=split_file,
            sequence_length=sequence_length,
            precomputed_skeletons=precomputed_skeletons
        )

    train_dataset = build_dataset(train_split_file)
    test_dataset = build_dataset(test_split_file)

    shared_loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_options)

    return train_loader, test_loader, train_dataset.num_classes, train_dataset.class_names

In [None]:
class ARIDSTGCNClassifier(torch.nn.Module):
    """Two stacked ST-Conv blocks, temporal average pooling, and an MLP head.

    ``forward`` takes inputs in the dataset layout (batch, channels, time, nodes).
    It permutes them to (batch, time, nodes, channels), the layout that
    torch_geometric_temporal's ``STConv`` consumes and produces.
    """

    def __init__(self, num_nodes=33, in_channels=3, hidden_channels=64,
                 out_channels=64, num_classes=11, kernel_size=3, K=3):
        super(ARIDSTGCNClassifier, self).__init__()

        # ST-Conv blocks (each has two temporal convolutions, so each block
        # shortens the time axis by 2 * (kernel_size - 1) frames).
        self.stconv1 = STConv(num_nodes, in_channels, hidden_channels,
                             out_channels, kernel_size, K)
        self.stconv2 = STConv(num_nodes, out_channels, hidden_channels,
                             out_channels, kernel_size, K)

        # Classification head over the per-node features left after time pooling.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(out_channels * num_nodes, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, num_classes)
        )

    def forward(self, x, edge_index, edge_weight=None):
        """
        Args:
            x: (batch_size, channels, time_steps, num_nodes) float tensor.
            edge_index: (2, E) long tensor shared by every sample.
            edge_weight: optional (E,) edge weights.

        Returns:
            (batch_size, num_classes) logits.
        """
        # (B, C, T, N) -> (B, T, N, C): STConv expects channels last.
        x = x.permute(0, 2, 3, 1)
        x = F.relu(self.stconv1(x, edge_index, edge_weight))
        x = F.relu(self.stconv2(x, edge_index, edge_weight))

        # STConv output is (B, T', N, C_out); average over time (dim 1).
        x = x.mean(dim=1)  # (batch_size, num_nodes, out_channels)

        # flatten (rather than view) is safe for non-contiguous tensors.
        x = x.flatten(start_dim=1)
        return self.classifier(x)

In [None]:
def evaluate_model(model, data_loader, device):
    """
    Evaluate classification accuracy of model on data_loader.

    Args:
        model: module called as ``model(x, edge_index)`` returning
            (batch, num_classes) logits.
        data_loader: iterable of (x, edge_index, labels) batches.
        device: torch.device to run evaluation on.

    Returns:
        float: accuracy as a percentage, or 0.0 if the loader yields no samples
        (avoids a ZeroDivisionError on an empty split).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, edge_index, labels in data_loader:
            x = x.to(device)
            edge_index = edge_index.to(device)
            labels = labels.to(device)
            outputs = model(x, edge_index)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100.0 * correct / total

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=100, lr=0.001,
                checkpoint_path='best_arid_stgcn.pth'):
    """
    Train the STGCN model on ARID dataset.

    Uses Adam (weight decay 1e-4) and a StepLR schedule that multiplies the
    learning rate by 0.1 every 30 epochs. After each epoch the model is scored on
    ``val_loader`` with ``evaluate_model``. Its state_dict is written to
    ``checkpoint_path`` after the first epoch and whenever validation accuracy
    improves.

    Args:
        model: module called as ``model(x, edge_index)`` returning
            (batch, num_classes) logits.
        train_loader: iterable of (x, edge_index, labels) training batches.
        val_loader: iterable of (x, edge_index, labels) validation batches.
        num_epochs (int): number of passes over ``train_loader``.
        lr (float): initial Adam learning rate.
        checkpoint_path (str): file the best state_dict is saved to.

    Returns:
        float: best validation accuracy in percent (0.0 if ``num_epochs`` is 0).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    # None until the first epoch finishes, so the first checkpoint is always written
    # even if validation accuracy never rises above 0%.
    best_acc = None

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for x, edge_index, labels in train_loader:
            x, edge_index, labels = x.to(device), edge_index.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(x, edge_index)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

        scheduler.step()

        # Validation phase
        val_acc = evaluate_model(model, val_loader, device)

        if best_acc is None or val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        # Guard against an empty training loader.
        avg_loss = train_loss / num_batches if num_batches else 0.0
        train_acc = 100. * correct / total if total else 0.0
        print(f'Epoch {epoch+1}/{num_epochs}: '
              f'Train Loss: {avg_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%, '
              f'Val Acc: {val_acc:.2f}%')

    if best_acc is None:
        best_acc = 0.0
    print(f'Best validation accuracy: {best_acc:.2f}%')
    return best_acc

In [None]:
def precompute_skeletons(dataset_root, split_files, output_file):
    """
    **MAIN SKELETON PRECOMPUTATION FUNCTION**
    Precompute and save skeleton data for all videos in the dataset.

    The output JSON maps each video's bare file name to its list of per-frame
    keypoint lists. A video is extracted once even if it appears in several
    splits. If extraction raises, the video is stored as an empty list and
    reported in a summary.

    Args:
        dataset_root (str): Path to ARID dataset root directory
        split_files (list): List of split file paths to process
        output_file (str): Path to output JSON file for skeleton data
    """
    print("=" * 60)
    print("STARTING SKELETON PRECOMPUTATION")
    print("=" * 60)

    extractor = SkeletonExtractor()
    skeleton_data = {}
    failures = []

    # Process each split file
    for split_file in split_files:
        print(f"\nProcessing split file: {split_file}")

        # Create temporary dataset to get video paths
        temp_dataset = ARIDDataset(dataset_root, split_file)

        for video_path, _class_idx, _sample_idx in tqdm(temp_dataset.samples, desc="Processing videos", unit="video"):
            video_key = Path(video_path).name
            if video_key in skeleton_data:
                continue

            try:
                skeleton_data[video_key] = extractor.extract_keypoints(video_path).tolist()
            except Exception as e:
                # Best-effort: keep going, store empty data, but remember the failure.
                skeleton_data[video_key] = []
                failures.append((video_path, e))

    if failures:
        print(f"\nWarning: skeleton extraction failed for {len(failures)} video(s); stored as empty sequences")
        for video_path, err in failures[:10]:
            print(f"  {video_path}: {err}")

    # Save to JSON file
    print(f"\nSaving {len(skeleton_data)} skeleton sequences to {output_file}")
    output_dir = os.path.dirname(output_file)
    if output_dir:  # os.makedirs('') raises, so skip for bare file names
        os.makedirs(output_dir, exist_ok=True)

    # Write to a temporary file, then atomically rename it into place. An
    # interrupted run therefore never leaves a truncated JSON that an
    # os.path.exists() check would mistake for a finished one. Compact separators
    # keep the file small (indent=2 puts every float on its own line).
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, 'w') as f:
        json.dump(skeleton_data, f, separators=(',', ':'))
    os.replace(tmp_file, output_file)

    print("=" * 60)
    print("SKELETON PRECOMPUTATION COMPLETED")
    print(f"Total videos processed: {len(skeleton_data)}")
    print(f"Saved to: {output_file}")
    print("=" * 60)

In [None]:
if __name__ == "__main__":
    # Configuration
    DATASET_ROOT = "/kaggle/input/aridzip"
    TRAIN_SPLIT = "/kaggle/input/aridzip/list_cvt/split_0/split0_train.txt"
    TEST_SPLIT = "/kaggle/input/aridzip/list_cvt/split_0/split0_test.txt"
    PRECOMPUTED_SKELETON_FILE = "/kaggle/working/arid_precomputed_skeletons.json"

    BATCH_SIZE = 8
    NUM_WORKERS = 2
    SEQUENCE_LENGTH = 32
    NUM_EPOCHS = 100
    LEARNING_RATE = 0.001
    NUM_JOINTS = 33  # MediaPipe Pose landmarks per frame
    SEED = 42

    # Seed NumPy and PyTorch (all devices) so data shuffling, weight init and
    # dropout are reproducible across Restart & Run All.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # **STEP 1: PRECOMPUTE SKELETONS (RUN ONCE)**
    print("Step 1: Checking for precomputed skeleton data...")

    if not os.path.exists(PRECOMPUTED_SKELETON_FILE):
        print("Precomputed skeleton file not found. Starting skeleton extraction...")
        print("WARNING: This process may take several hours for the full ARID dataset!")

        # Precompute skeletons for both train and test splits
        precompute_skeletons(
            dataset_root=DATASET_ROOT,
            split_files=[TRAIN_SPLIT, TEST_SPLIT],
            output_file=PRECOMPUTED_SKELETON_FILE
        )
        print(" Skeleton precomputation completed!")
    else:
        print(f"Found precomputed skeleton file: {PRECOMPUTED_SKELETON_FILE}")

    # **STEP 3: CREATE DATALOADERS WITH PRECOMPUTED SKELETONS**
    print("\n" + "=" * 60)
    print("STEP 3: CREATING DATALOADERS WITH PRECOMPUTED SKELETONS")
    print("=" * 60)

    train_loader, test_loader, num_classes, class_names = create_dataloaders(
        dataset_root=DATASET_ROOT,
        train_split_file=TRAIN_SPLIT,
        test_split_file=TEST_SPLIT,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        sequence_length=SEQUENCE_LENGTH,
        precomputed_skeletons=PRECOMPUTED_SKELETON_FILE  # **KEY INTEGRATION**
    )

    print(f"Number of classes: {num_classes}")
    print(f"Class names: {class_names}")
    print(f"Train samples: {len(train_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    # **STEP 4: TEST DATALOADER**
    print("\n" + "=" * 60)
    print("STEP 4: TESTING DATALOADER")
    print("=" * 60)

    for batch_idx, (x, edge_index, labels) in enumerate(train_loader):
        print(f"Batch {batch_idx + 1}:")
        print(f"  Input shape: {x.shape}")  # (batch_size, channels, time, joints)
        print(f"  Edge index shape: {edge_index.shape}")
        print(f"  Labels shape: {labels.shape}")
        print(f"  Labels (indices): {labels}")
        print(f"  Labels (names): {[class_names[l.item()] for l in labels]}")

        if batch_idx == 2:  # Only show first 3 batches
            break

    # **STEP 5: CREATE AND TRAIN MODEL**
    print("\n" + "=" * 60)
    print("STEP 5: CREATING AND TRAINING MODEL")
    print("=" * 60)

    model = ARIDSTGCNClassifier(
        num_nodes=NUM_JOINTS,
        in_channels=3,
        hidden_channels=64,
        out_channels=64,
        num_classes=num_classes,
        kernel_size=3,
        K=3
    )

    print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

    # NOTE: the test split doubles as the validation set for checkpoint selection,
    # so the reported best accuracy is optimistically biased.
    best_accuracy = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        num_epochs=NUM_EPOCHS,
        lr=LEARNING_RATE
    )

    print(f"\n Training completed! Best accuracy: {best_accuracy:.2f}%")
    print("Model saved as: best_arid_stgcn.pth")
    print(f"Precomputed skeletons saved as: {PRECOMPUTED_SKELETON_FILE}")