## Install dependencies
This code cell installs the required dependencies for the dataset processing and visualization tasks.

In [None]:
!pip install numpy
!pip install opencv-python
!pip install matplotlib
!pip install scipy

## Create a Test Dataset

This code cell defines a function `create_test_dataset` that creates a smaller, random subset of the main dataset for testing purposes. It randomly selects a specified number of folders from the source directory and copies them to a test directory. It also copies the model file (`.pt`) to the test directory.

### Configurations:
- `SOURCE_DIR`: The directory where the original dataset is located.
- `TEST_DIR`: The directory where the test dataset will be created.
- `NUM_FOLDERS_TO_SELECT`: The number of folders to randomly select from the source dataset.

In [None]:
import os
import shutil
import random
import glob

# --- Configuration ---
# NOTE(review): "Datastet" looks like a misspelling of "Dataset". These strings are used
# verbatim as filesystem paths, so only rename them together with the folders on disk.
SOURCE_DIR = "./Cap_Tracker_Datastet"  # must contain a "dataset" sub-folder; a *.pt model file is optional
TEST_DIR = "./test_Cap_Tracker_Datastet_test"  # destination root; created if it does not exist
NUM_FOLDERS_TO_SELECT = 3  # number of sub-folders of SOURCE_DIR/dataset copied into the test set

def create_test_dataset():
    """
    Builds a small test copy of the dataset under TEST_DIR.

    Reads the module-level SOURCE_DIR, TEST_DIR and NUM_FOLDERS_TO_SELECT.
    The first '*.pt' file found in SOURCE_DIR (if any) is copied to TEST_DIR,
    then NUM_FOLDERS_TO_SELECT randomly chosen sub-folders of
    SOURCE_DIR/dataset are copied into TEST_DIR/dataset, replacing any
    previous copy of the same folder. Problems are reported with print()
    and the function returns None in every case.
    """
    print(f"Attempting to create test set from '{SOURCE_DIR}'...")

    # Step 1: both the source root and its "dataset" child must be directories.
    source_dataset_path = os.path.join(SOURCE_DIR, "dataset")
    if not (os.path.isdir(SOURCE_DIR) and os.path.isdir(source_dataset_path)):
        print(f"❌ Error: Source directory '{SOURCE_DIR}/dataset' not found. Please run this script from the correct location.")
        return

    # Step 2: make sure TEST_DIR/dataset exists (no error if it already does).
    print(f"-> Creating test directory at '{TEST_DIR}'...")
    test_dataset_path = os.path.join(TEST_DIR, "dataset")
    os.makedirs(test_dataset_path, exist_ok=True)

    # Step 3: copy the model file; only the first glob match is used.
    model_candidates = glob.glob(os.path.join(SOURCE_DIR, '*.pt'))
    if model_candidates:
        model_source = model_candidates[0]
        model_filename = os.path.basename(model_source)
        model_destination = os.path.join(TEST_DIR, model_filename)
        print(f"-> Copying model file: {model_filename}")
        shutil.copy2(model_source, model_destination)
    else:
        print("️️⚠️ Warning: No .pt model file found in the source directory.")

    # Step 4: collect the candidate folders (directories only, files are ignored).
    all_subfolders = []
    for entry in os.listdir(source_dataset_path):
        if os.path.isdir(os.path.join(source_dataset_path, entry)):
            all_subfolders.append(entry)

    if len(all_subfolders) < NUM_FOLDERS_TO_SELECT:
        print(f"❌ Error: Source dataset has only {len(all_subfolders)} folders, but {NUM_FOLDERS_TO_SELECT} were requested.")
        return

    # Draw the sample without replacement.
    selected_folders = random.sample(all_subfolders, NUM_FOLDERS_TO_SELECT)
    print(f"-> Randomly selected {NUM_FOLDERS_TO_SELECT} folders: {', '.join(selected_folders)}")

    for folder_name in selected_folders:
        copy_from = os.path.join(source_dataset_path, folder_name)
        copy_to = os.path.join(test_dataset_path, folder_name)

        # copytree refuses to write into an existing directory, so clear it first.
        if os.path.exists(copy_to):
            shutil.rmtree(copy_to)

        print(f"   - Copying '{folder_name}'...")
        shutil.copytree(copy_from, copy_to)

    print(f"\n✅ Successfully created the test dataset in '{TEST_DIR}'.")

# Build the test dataset as soon as this cell (or script) runs as the main module.
if __name__ == "__main__":
    create_test_dataset()

Attempting to create test set from '/content/drive/MyDrive/Cap_Tracker_Datastet'...
-> Creating test directory at '/content/drive/MyDrive/test_Cap_Tracker_Datastet_test'...
-> Copying model file: best.pt
-> Randomly selected 1 folders: 0019
   - Copying '0019'...

✅ Successfully created the test dataset in '/content/drive/MyDrive/test_Cap_Tracker_Datastet_test'.


This code cell changes the current working directory to the test dataset directory. This is done to make it easier to work with the files in the test dataset.

In [None]:
# Going to the test directory
%cd ./test_Cap_Tracker_Datastet_test

/content/drive/MyDrive/test_Cap_Tracker_Datastet_test


## Cleaning Dataset
This code cell cleans the dataset by removing annotations of classes that are not in the `ALLOWED_INSTRUMENTS` list. It keeps the classes in the `ALWAYS_KEPT_CLASSES` list regardless of the `ALLOWED_INSTRUMENTS` list. The cleaned annotations are saved to a new file named `annotation_cleaned.json`.

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset.
- `VIDEOS_TO_PROCESS`: A list of video folder names to process.
- `ALLOWED_INSTRUMENTS`: A list of instrument class names to keep.

In [None]:
import json
import os

# --------------------------------------------------------------------------
# ✏️ 1. CONFIGURATION
#    Modify the variables in this section to match your needs.
# --------------------------------------------------------------------------

# Path to the root directory of the dataset.
# This should point to the folder containing your video subfolders.
DATASET_ROOT = "dataset/"

# A list of video folder names to process.
# Example: VIDEOS_TO_PROCESS = ["0020", "0481"]
VIDEOS_TO_PROCESS = ["0019", "0063"] #<-- CHANGE THIS

# A list of instrument class names that you want to KEEP.
# "Cannula", "Cap-Cystotome", "Cap-Forceps", "Cornea", "Forceps",
# "IA-Handpiece", "Lens-Injector", "Phaco-Handpiece", "Primary-Knife",
# "Pupil", "Second-Instrument", "Secondary-Knife"
# Example: ALLOWED_INSTRUMENTS = ["Forceps"]
ALLOWED_INSTRUMENTS = ["Forceps", "Cap-Cystotome", "Cap-Forceps"] #<-- CHANGE THIS

# --------------------------------------------------------------------------
# ⚙️ 2. CORE LOGIC
#    You don't need to change the code below this line.
# --------------------------------------------------------------------------

# These classes are always preserved, regardless of the allowed instruments list.
ALWAYS_KEPT_CLASSES = {"Cornea", "Pupil"}

def clean_annotations(data, allowed_instruments):
    """
    Drops every annotation whose category name is not allowed.

    The allowed names are `allowed_instruments` plus the module-level
    ALWAYS_KEPT_CLASSES. Annotations whose category id has no entry in
    data['categories'] are dropped as well.

    Args:
        data (dict): Parsed annotation file with 'categories' and 'annotations' lists.
        allowed_instruments (iterable): Instrument class names to keep.

    Returns:
        tuple: (data, removed_count). `data` is the same dict object that was
        passed in, with its 'annotations' list replaced by the filtered one.
    """
    keep_names = set(allowed_instruments) | set(ALWAYS_KEPT_CLASSES)

    # id -> name lookup; a later duplicate id overrides an earlier one.
    name_by_id = {}
    for category in data['categories']:
        name_by_id[category['id']] = category['name']

    count_before = len(data['annotations'])

    kept = [
        annotation
        for annotation in data['annotations']
        if name_by_id.get(annotation['category_id']) in keep_names
    ]

    data['annotations'] = kept
    return data, count_before - len(kept)

# --- Main Execution Block ---
def run_cleaning():
    """
    Cleans the annotation file of every video listed in VIDEOS_TO_PROCESS.

    Reads the module-level DATASET_ROOT, VIDEOS_TO_PROCESS, ALLOWED_INSTRUMENTS
    and ALWAYS_KEPT_CLASSES (no command-line arguments are involved). For each
    video folder, 'annotation.json' is loaded, filtered with clean_annotations()
    and written next to it as 'annotation_cleaned.json'; the original file is
    left untouched. Missing folders, missing files and unreadable JSON are
    reported with a warning and skipped so one bad video does not stop the run.
    """
    # --- Setup ---
    if not os.path.exists(DATASET_ROOT):
        print(f"❌ [Error] Dataset directory not found at '{DATASET_ROOT}'")
        print("Please make sure the folder exists and you have uploaded your data.")
        return

    allowed_instruments_set = set(ALLOWED_INSTRUMENTS)
    print(f"✅ Cleaning specified videos to only contain these instruments: {sorted(allowed_instruments_set)}")
    # Sets have no stable iteration order; sort so the message is identical on every run.
    print(f"(Note: '{', '.join(sorted(ALWAYS_KEPT_CLASSES))}' will always be kept)\n")

    # --- Processing Loop ---
    for video_name in VIDEOS_TO_PROCESS:
        video_folder_path = os.path.join(DATASET_ROOT, video_name)
        print(f"--- Processing video: {video_name} ---")

        if not os.path.isdir(video_folder_path):
            print(f"  [Warning] Video folder not found: {video_folder_path}. Skipping.")
            continue

        # 1. Define file paths
        original_annotation_path = os.path.join(video_folder_path, "annotation.json")
        cleaned_annotation_path = os.path.join(video_folder_path, "annotation_cleaned.json")

        if not os.path.exists(original_annotation_path):
            print(f"  [Warning] 'annotation.json' not found for {video_name}. Skipping.")
            continue

        # 2. Load the original annotation data
        try:
            with open(original_annotation_path, 'r') as f:
                data_to_clean = json.load(f)
        except json.JSONDecodeError as exc:
            # Same policy as the other per-video problems: warn and move on.
            print(f"  [Warning] 'annotation.json' for {video_name} is not valid JSON ({exc}). Skipping.")
            continue

        # 3. Clean the annotation data
        cleaned_data, removed_count = clean_annotations(data_to_clean, allowed_instruments_set)

        if removed_count > 0:
            print(f"  Removed {removed_count} annotations for non-allowed instruments.")
        else:
            print("  No non-allowed instruments found. Annotation file is already clean.")

        # 4. Save the new, cleaned data to annotation_cleaned.json
        with open(cleaned_annotation_path, 'w') as f:
            json.dump(cleaned_data, f, indent=4)
        print(f"  Saved cleaned annotations to '{cleaned_annotation_path}'")

        print(f"--- Finished cleaning {video_name} ---\n")

    print("🎉 All done!")

# --- Run the script ---
run_cleaning()

✅ Cleaning specified videos to only contain these instruments: ['Cap-Cystotome', 'Cap-Forceps', 'Forceps']
(Note: 'Pupil, Cornea' will always be kept)

--- Processing video: 0019 ---
  No non-allowed instruments found. Annotation file is already clean.
  Saved cleaned annotations to 'dataset/0019/annotation_cleaned.json'
--- Finished cleaning 0019 ---

--- Processing video: 0063 ---
  Removed 3 annotations for non-allowed instruments.
  Saved cleaned annotations to 'dataset/0063/annotation_cleaned.json'
--- Finished cleaning 0063 ---

🎉 All done!


## Handling Missing Labels
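This code cell fills short runs of missing labels using eased (sine ease-in-out) interpolation between the last labeled frame before a gap and the first labeled frame after it. Bounding boxes, segmentation masks, areas and keypoints are interpolated. Only gaps of at most `MAX_GAP_SIZE` consecutive frames are filled; longer gaps, and gaps that run to the end of the video, are left untouched. The interpolated annotations are saved to a new file named `annotation_miss_handled.json`.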

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset.
- `MAX_GAP_SIZE`: The maximum number of consecutive missing frames to interpolate.
- `VIDEOS_TO_PROCESS`: A list of specific video folder names to process. Leave empty to process all videos.
- `INPUT_ANNOTATION_FILENAME`: The filename of the input annotation file.
- `OUTPUT_ANNOTATION_FILENAME`: The filename for the output annotation file.

In [None]:
import cv2
import numpy as np
import json
import os
import glob
import math

# --------------------------------------------------------------------------
# ✏️ 1. CONFIGURATION
#    Modify the variables in this section to match your needs.
# --------------------------------------------------------------------------

# Path to the root directory of the dataset.
DATASET_ROOT = "dataset/"

# The maximum number of consecutive missing frames to interpolate.
MAX_GAP_SIZE = 10

# A list of specific video folder names to process.
# LEAVE EMPTY (e.g., []) to process ALL video folders found in DATASET_ROOT.
# Example: VIDEOS_TO_PROCESS = ["0020", "0481"]
VIDEOS_TO_PROCESS = [] #<-- CHANGE THIS

# The filename of the input annotation file (the one to read from).
INPUT_ANNOTATION_FILENAME = "annotation_cleaned.json"

# The filename for the output annotation file (the one that will be created).
OUTPUT_ANNOTATION_FILENAME = "annotation_miss_handled.json"


# --------------------------------------------------------------------------
# ⚙️ 2. CORE LOGIC
#    You don't need to change the code below this line.
# --------------------------------------------------------------------------

# --- Helper Functions ---

def ease_in_out_sine(t):
    """
    Sine ease-in-out curve: slow near t=0 and t=1, fastest around t=0.5.

    Maps t=0 to 0, t=0.5 to about 0.5 and t=1 to 1.
    """
    cosine = math.cos(math.pi * t)
    return -(cosine - 1) / 2

def interpolate_bbox(bbox_start, bbox_end, t):
    """
    Blends two [x, y, w, h] boxes at eased position t.

    t is passed through ease_in_out_sine first; each of the four components
    is then mixed linearly and truncated with int(). Returns a new 4-item list.
    """
    weight_end = ease_in_out_sine(t)
    return [
        int(bbox_start[k] * (1 - weight_end) + bbox_end[k] * weight_end)
        for k in range(4)
    ]

def interpolate_keypoints(kp_start, kp_end, t):
    """
    Blends two [x, y, v] keypoints at eased position t.

    Returns [x, y, 2] when both endpoints have visibility 2; otherwise
    returns [0, 0, 0] (not visible). Coordinates are truncated with int().
    """
    weight_end = ease_in_out_sine(t)

    both_visible = kp_start[2] == 2 and kp_end[2] == 2
    if not both_visible:
        # At least one endpoint is not visible, so there is nothing to blend.
        return [0, 0, 0]

    x = int(kp_start[0] * (1 - weight_end) + kp_end[0] * weight_end)
    y = int(kp_start[1] * (1 - weight_end) + kp_end[1] * weight_end)
    return [x, y, 2]

def generate_interpolated_mask(seg_start, bbox_start, bbox_interpolated):
    """
    Generates a new segmentation mask by resizing a template mask.

    The first polygon of `seg_start` is rasterised inside `bbox_start`,
    the raster is resized (nearest neighbour) to the size of
    `bbox_interpolated`, and the largest external contour of the result is
    returned, shifted to the interpolated box position.

    Args:
        seg_start: Segmentation as a list of flat polygons [x0, y0, x1, y1, ...];
            only the first polygon is used.
        bbox_start: [x, y, w, h] box belonging to `seg_start`.
        bbox_interpolated: [x, y, w, h] box the mask should be fitted into.

    Returns:
        list | None: [[x0, y0, x1, y1, ...]] with one polygon, or None when the
        segmentation is empty, a box has non-positive size, or no contour is found.
    """
    if not seg_start or not seg_start[0]:
        return None

    # Array shapes and cv2.resize sizes must be integers; boxes loaded from
    # JSON may hold floats, so coerce once up front.
    x_s, y_s, w_s, h_s = (int(v) for v in bbox_start)
    if w_s <= 0 or h_s <= 0:
        return None

    x_i, y_i, w_i, h_i = (int(v) for v in bbox_interpolated)
    if w_i <= 0 or h_i <= 0:
        return None

    # Rasterise the polygon in box-local coordinates.
    template_mask = np.zeros((h_s, w_s), dtype=np.uint8)
    poly_start = np.array(seg_start[0], dtype=np.int32).reshape((-1, 1, 2))
    poly_start[:, :, 0] -= x_s
    poly_start[:, :, 1] -= y_s
    cv2.fillPoly(template_mask, [poly_start], 255)

    resized_mask = cv2.resize(template_mask, (w_i, h_i), interpolation=cv2.INTER_NEAREST)

    contours, _ = cv2.findContours(resized_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # Resizing can split the shape into several blobs; keep the largest one
    # rather than whichever blob happens to be listed first.
    main_contour = max(contours, key=cv2.contourArea)
    main_contour[:, :, 0] += x_i
    main_contour[:, :, 1] += y_i

    return [main_contour.flatten().tolist()]


# --- Main Processing Function ---

def _fill_gap(ann, start_gap_idx, end_gap_idx, num_keypoints):
    """Interpolates bbox, mask, area and keypoints for frames [start_gap_idx, end_gap_idx) of `ann` in place."""
    gap_size = end_gap_idx - start_gap_idx
    kp_stride = num_keypoints * 3  # each keypoint is stored as x, y, v

    # Anchors: the last labeled frame before the gap and the first one after it.
    bbox_start = ann['bboxes'][start_gap_idx - 1]
    bbox_end = ann['bboxes'][end_gap_idx]
    seg_start = ann['segmentations'][start_gap_idx - 1]

    kp_list_start = ann['keypoints'][(start_gap_idx - 1) * kp_stride : start_gap_idx * kp_stride]
    kp_list_end = ann['keypoints'][end_gap_idx * kp_stride : (end_gap_idx + 1) * kp_stride]

    for i in range(gap_size):
        frame_idx = start_gap_idx + i
        # t runs strictly between 0 and 1 so no filled frame duplicates an anchor.
        t = (i + 1) / (gap_size + 1.0)

        inter_bbox = interpolate_bbox(bbox_start, bbox_end, t)
        ann['bboxes'][frame_idx] = inter_bbox

        inter_seg = generate_interpolated_mask(seg_start, bbox_start, inter_bbox)
        ann['segmentations'][frame_idx] = inter_seg

        if inter_seg and inter_seg[0]:
            contour = np.array(inter_seg[0]).reshape(-1, 2)
            ann['areas'][frame_idx] = int(cv2.contourArea(contour))
        else:
            ann['areas'][frame_idx] = 0

        inter_kp_list = []
        for kp_idx in range(num_keypoints):
            kp_start = kp_list_start[kp_idx*3 : (kp_idx+1)*3]
            kp_end = kp_list_end[kp_idx*3 : (kp_idx+1)*3]
            inter_kp_list.extend(interpolate_keypoints(kp_start, kp_end, t))

        start_kp_json_idx = frame_idx * kp_stride
        ann['keypoints'][start_kp_json_idx : start_kp_json_idx + kp_stride] = inter_kp_list


def process_annotations(data, max_gap_size):
    """
    Finds and fills missing label gaps in the annotation data.

    A gap is a run of frames whose segmentation is None, preceded and
    followed by a labeled frame. Gaps of 1..max_gap_size frames are filled in
    place by interpolating between the two surrounding frames; longer gaps,
    leading gaps and gaps that reach the end of the track are left as they are.
    Gaps of annotations whose category defines no 'keypoints' are skipped
    with a warning.

    NOTE: the outlier-removal cell later in this notebook defines another
    function with this same name and a different signature; re-run this cell
    before calling run_interpolation_script() again in the same kernel.

    Args:
        data (dict): Parsed annotation file ('categories' and 'annotations').
        max_gap_size (int): Largest gap length, in frames, that is filled.

    Returns:
        dict: The same `data` object, modified in place.
    """
    total_gaps_filled = 0
    for ann in data['annotations']:
        segmentations = ann['segmentations']
        num_frames = len(segmentations)
        idx = 0
        while idx < num_frames:
            starts_gap = (
                idx > 0
                and segmentations[idx] is None
                and segmentations[idx - 1] is not None
            )
            if not starts_gap:
                idx += 1
                continue

            start_gap_idx = idx

            # First labeled frame after the gap, or -1 if the gap never closes.
            end_gap_idx = next(
                (j for j in range(start_gap_idx, num_frames) if segmentations[j] is not None),
                -1,
            )
            if end_gap_idx == -1:
                # No end to the gap was found, stop searching for this annotation
                break

            gap_size = end_gap_idx - start_gap_idx
            if 0 < gap_size <= max_gap_size:
                # Find category to determine keypoint stride
                category = next((cat for cat in data['categories'] if cat['id'] == ann['category_id']), None)
                if category and 'keypoints' in category:
                    print(f"  Found gap of size {gap_size} for annotation ID {ann['id']} from frame {start_gap_idx} to {end_gap_idx-1}. Interpolating...")
                    total_gaps_filled += gap_size
                    _fill_gap(ann, start_gap_idx, end_gap_idx, len(category['keypoints']))
                else:
                    # Skip this gap but still advance below; bailing out of the
                    # loop body here would re-detect the same gap forever.
                    print(f"  [Warning] No keypoint definition for annotation ID {ann['id']}; leaving gap at frames {start_gap_idx}-{end_gap_idx-1} unfilled.")

            # Always move past the gap, whether or not it was filled.
            idx = end_gap_idx

    print(f"Total missing labels filled: {total_gaps_filled}")
    return data


# --------------------------------------------------------------------------
# ▶️ 3. EXECUTION
#    This block runs the script using the configuration above.
# --------------------------------------------------------------------------

def run_interpolation_script():
    """
    Runs the gap-filling step for every selected video.

    Uses the module-level DATASET_ROOT, VIDEOS_TO_PROCESS, MAX_GAP_SIZE,
    INPUT_ANNOTATION_FILENAME and OUTPUT_ANNOTATION_FILENAME. When
    VIDEOS_TO_PROCESS is empty, every sub-directory of DATASET_ROOT is
    processed. Videos without the input file are skipped with a warning.
    """
    if not os.path.exists(DATASET_ROOT):
        print(f"❌ [Error] Input dataset directory not found at '{DATASET_ROOT}'")
        return

    if VIDEOS_TO_PROCESS:
        video_names = VIDEOS_TO_PROCESS
    else:
        # Nothing was listed explicitly: take every sub-directory of the dataset root.
        print(f"No specific video provided. Processing all videos in '{DATASET_ROOT}'...")
        video_names = []
        for candidate in glob.glob(os.path.join(DATASET_ROOT, '*')):
            if os.path.isdir(candidate):
                video_names.append(os.path.basename(candidate))

    if not video_names:
        print(f"❌ No video subdirectories found in '{DATASET_ROOT}'")
        return

    ordered_names = sorted(video_names)
    print(f"Found {len(video_names)} video(s) to process: {ordered_names}")

    for video_name in ordered_names:
        folder = os.path.join(DATASET_ROOT, video_name)
        print(f"\n--- Processing video: {video_name} ---")

        source_file = os.path.join(folder, INPUT_ANNOTATION_FILENAME)
        target_file = os.path.join(folder, OUTPUT_ANNOTATION_FILENAME)

        if not os.path.exists(source_file):
            print(f"  [Warning] Input file '{INPUT_ANNOTATION_FILENAME}' not found for {video_name}. Skipping.")
            continue

        with open(source_file, 'r') as f:
            loaded = json.load(f)

        # Fill the gaps, then write the result next to the input file.
        filled = process_annotations(loaded, MAX_GAP_SIZE)

        with open(target_file, 'w') as f:
            json.dump(filled, f, indent=4)
        print(f"  ✅ Saved new annotations to {target_file}")

        print(f"--- Finished processing {video_name} ---")

    print("\n🎉 All done!")

# --- Run the script ---
run_interpolation_script()

No specific video provided. Processing all videos in 'dataset/'...
Found 2 video(s) to process: ['0019', '0063']

--- Processing video: 0019 ---
  Found gap of size 1 for annotation ID 0 from frame 409 to 409. Interpolating...
  Found gap of size 1 for annotation ID 0 from frame 1346 to 1346. Interpolating...
  Found gap of size 1 for annotation ID 0 from frame 1923 to 1923. Interpolating...
  Found gap of size 1 for annotation ID 0 from frame 1937 to 1937. Interpolating...
  Found gap of size 2 for annotation ID 0 from frame 2677 to 2678. Interpolating...
  Found gap of size 1 for annotation ID 1 from frame 223 to 223. Interpolating...
  Found gap of size 1 for annotation ID 1 from frame 409 to 409. Interpolating...
  Found gap of size 2 for annotation ID 1 from frame 471 to 472. Interpolating...
  Found gap of size 1 for annotation ID 1 from frame 587 to 587. Interpolating...
  Found gap of size 2 for annotation ID 1 from frame 595 to 596. Interpolating...
  Found gap of size 1 for a

## Outlier Removal
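This code cell removes outliers from the instrument tip trajectories. For each instrument annotation it takes the tip (the second keypoint of every frame), computes its frame-to-frame velocity, and flags frames whose velocity exceeds the sliding-window median by more than `THRESHOLD_STD_DEV` standard deviations. Flagged tip positions are replaced with values from a Cubic Spline fitted through the remaining visible points. The smoothed annotations are saved to a new file named `annotation_smooth.json`.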

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset.
- `WINDOW_SIZE`: The size of the sliding window used to check for local outliers.
- `THRESHOLD_STD_DEV`: The number of standard deviations from the median velocity to consider a point an outlier.
- `INSTRUMENT_CLASSES`: These classes are instruments whose trajectories will be smoothed.
- `VIDEOS_TO_PROCESS`: A list of specific video folder names to process. Leave empty to process all videos.
- `INPUT_ANNOTATION_FILENAME`: The filename of the input annotation file.
- `OUTPUT_ANNOTATION_FILENAME`: The filename for the output annotation file.

In [None]:
import json
import numpy as np
import os
import glob
from collections import deque
from scipy.interpolate import CubicSpline

# --------------------------------------------------------------------------
# ✏️ 1. CONFIGURATION
#    Modify the variables in this section to match your needs.
# --------------------------------------------------------------------------

# Path to the root directory of the dataset.
DATASET_ROOT = "dataset/"

# --- Smoothing Parameters ---
# The size of the sliding window used to check for local outliers.
WINDOW_SIZE = 30 # since the frame rate of videos is 30fps
# The number of standard deviations from the median velocity to consider a point an outlier.
THRESHOLD_STD_DEV = 2.0
# These classes are instruments whose trajectories will be smoothed.
INSTRUMENT_CLASSES = {
    "Cannula", "Cap-Cystotome", "Cap-Forceps", "Forceps", "IA-Handpiece",
    "Lens-Injector", "Phaco-Handpiece", "Primary-Knife", "Second-Instrument",
    "Secondary-Knife"
}

# --- Video & File Settings ---
# A list of specific video folder names to process.
# LEAVE EMPTY (e.g., []) to process ALL video folders found in DATASET_ROOT.
# Example: VIDEOS_TO_PROCESS = ["0020", "0481"]
VIDEOS_TO_PROCESS = [] #<-- CHANGE THIS

# The filename of the input annotation file (the one to read from).
INPUT_ANNOTATION_FILENAME = "annotation_miss_handled.json"

# The filename for the output annotation file (the one that will be created).
OUTPUT_ANNOTATION_FILENAME = "annotation_smooth.json"


# --------------------------------------------------------------------------
# ⚙️ 2. CORE LOGIC
#    You don't need to change the code below this line.
# --------------------------------------------------------------------------

def smooth_trajectory_with_spline(keypoints, num_keypoints, window_size, threshold):
    """
    Detects outliers in a trajectory based on velocity and corrects them
    using Cubic Spline interpolation.

    Only the tip keypoint (the second keypoint of every frame) is examined,
    and only frames flagged as outliers are rewritten; every other value in
    `keypoints` is copied through unchanged.

    Args:
        keypoints (list): Flat per-frame keypoints, `num_keypoints` triples of
            [x, y, v] per frame. A tip counts as visible only when v == 2.
        num_keypoints (int): Keypoints per frame; must be at least 2.
        window_size (int): Length of the sliding velocity window. Detection
            starts once the window holds window_size // 2 samples.
        threshold (float): A frame is an outlier when its velocity exceeds
            window median + threshold * window std (std floored at 1.0).

    Returns:
        list: `keypoints` itself when no outlier is found, otherwise a new
        list in which each outlier tip is replaced by the spline prediction
        [x, y, 2], or by [0, 0, 0] when fewer than 4 good points exist.
    """
    stride = num_keypoints * 3

    # 1. Extract tip trajectory and calculate frame-to-frame velocities
    # Assumes the 'tip' is the second keypoint in the list for a given frame.
    tip_track = []
    for i in range(0, len(keypoints), stride):
        # Tip is the second keypoint, its data starts at index 3
        tip_data = keypoints[i+3 : i+6]
        tip_track.append([tip_data[0], tip_data[1]] if tip_data[2] == 2 else None)

    # Velocity is 0 for the first frame and whenever either endpoint is not visible.
    velocities = [0.0]
    for i in range(1, len(tip_track)):
        if tip_track[i] is not None and tip_track[i-1] is not None:
            velocities.append(np.linalg.norm(np.array(tip_track[i]) - np.array(tip_track[i-1])))
        else:
            velocities.append(0.0)

    # 2. Pass 1: Detect Outliers using a sliding window on velocity
    outlier_indices = set()
    window = deque(maxlen=window_size)
    for i, velocity in enumerate(velocities):
        window.append(velocity)
        if len(window) < window_size // 2:
            continue

        median_vel = np.median(window)
        # Set a minimum std deviation to handle flat-line velocity sections
        std_dev_vel = max(np.std(window), 1.0)

        if velocity > median_vel + threshold * std_dev_vel:
            outlier_indices.add(i)

    if not outlier_indices:
        return keypoints # No changes needed

    print(f"    -> Detected {len(outlier_indices)} outliers. Applying spline correction.")

    # 3. Pass 2: Correct Outliers with Cubic Spline Interpolation
    good_indices = [
        i for i, point in enumerate(tip_track)
        if point is not None and i not in outlier_indices
    ]

    # A cubic spline needs at least 4 points for good results
    if len(good_indices) < 4:
        print(f"    -> Warning: Not enough good points ({len(good_indices)}) for a reliable spline. Outliers will be removed (set to null).")
        corrections = {i: [0, 0, 0] for i in outlier_indices}
    else:
        # Create splines for x and y coordinates
        spline_x = CubicSpline(good_indices, [tip_track[i][0] for i in good_indices])
        spline_y = CubicSpline(good_indices, [tip_track[i][1] for i in good_indices])
        # NOTE(review): outliers outside the range of good_indices are extrapolated
        # by the spline; confirm that is acceptable for tracks ending in an outlier.
        corrections = {
            i: [int(spline_x(i)), int(spline_y(i)), 2]
            for i in outlier_indices
        }

    # 4. Write the corrections back. Only outlier frames are touched, so tips
    #    that were merely not visible (or stored as floats) keep their data.
    new_keypoints = list(keypoints)
    for i, corrected_tip in corrections.items():
        idx = i * stride
        new_keypoints[idx + 3:idx + 6] = corrected_tip

    return new_keypoints

def process_annotations(data, window_size, threshold):
    """
    Smooths the tip trajectory of every instrument annotation in `data`.

    An annotation is processed when its category name is in the module-level
    INSTRUMENT_CLASSES and the category defines at least two keypoints. Its
    'keypoints' list is replaced by the result of
    smooth_trajectory_with_spline(). Returns the same `data` object.
    """
    categories_by_id = {}
    for category in data['categories']:
        categories_by_id[category['id']] = category

    for annotation in data["annotations"]:
        class_name = categories_by_id.get(annotation['category_id'], {}).get('name')
        if class_name not in INSTRUMENT_CLASSES:
            continue

        print(f"  Processing trajectory for '{class_name}' (ID: {annotation['id']})...")
        num_kps = len(categories_by_id[annotation['category_id']]['keypoints'])

        # We need at least 2 keypoints (e.g., center and tip) to smooth the tip
        if num_kps < 2:
            print(f"    -> Skipping, not enough keypoints ({num_kps}).")
            continue

        smoothed = smooth_trajectory_with_spline(
            annotation['keypoints'], num_kps, window_size, threshold
        )
        annotation['keypoints'] = smoothed

    return data

# --------------------------------------------------------------------------
# ▶️ 3. EXECUTION
#    This block runs the script using the configuration above.
# --------------------------------------------------------------------------

def run_smoothing_script():
    """
    Runs the trajectory smoothing step for every selected video.

    Uses the module-level DATASET_ROOT, VIDEOS_TO_PROCESS, WINDOW_SIZE,
    THRESHOLD_STD_DEV, INPUT_ANNOTATION_FILENAME and OUTPUT_ANNOTATION_FILENAME.
    When VIDEOS_TO_PROCESS is empty, every sub-directory of DATASET_ROOT is
    processed. Videos without the input file are skipped with a warning.
    """
    if not os.path.exists(DATASET_ROOT):
        print(f"❌ [Error] Input dataset directory not found at '{DATASET_ROOT}'")
        return

    if VIDEOS_TO_PROCESS:
        video_names = VIDEOS_TO_PROCESS
    else:
        # Nothing was listed explicitly: take every sub-directory of the dataset root.
        print(f"No specific video provided. Processing all videos in '{DATASET_ROOT}'...")
        video_names = []
        for candidate in glob.glob(os.path.join(DATASET_ROOT, '*')):
            if os.path.isdir(candidate):
                video_names.append(os.path.basename(candidate))

    if not video_names:
        print(f"❌ No video subdirectories found in '{DATASET_ROOT}'")
        return

    ordered_names = sorted(video_names)
    print(f"Found {len(video_names)} video(s) to process: {ordered_names}")

    for video_name in ordered_names:
        folder = os.path.join(DATASET_ROOT, video_name)
        print(f"\n--- Processing video: {video_name} ---")

        input_path = os.path.join(folder, INPUT_ANNOTATION_FILENAME)
        output_path = os.path.join(folder, OUTPUT_ANNOTATION_FILENAME)

        if not os.path.exists(input_path):
            print(f"  [Warning] Input file '{INPUT_ANNOTATION_FILENAME}' not found. Skipping.")
            continue

        with open(input_path, 'r') as f:
            loaded = json.load(f)

        # Smooth the trajectories, then write the result next to the input file.
        smoothed = process_annotations(loaded, WINDOW_SIZE, THRESHOLD_STD_DEV)

        with open(output_path, 'w') as f:
            json.dump(smoothed, f, indent=4)
        print(f"  ✅ Saved smoothed annotations to {output_path}")

        print(f"--- Finished smoothing for {video_name} ---")

    print("\n🎉 All done!")

# --- Run the script ---
run_smoothing_script()

No specific video provided. Processing all videos in 'dataset/'...
Found 2 video(s) to process: ['0019', '0063']

--- Processing video: 0019 ---
  Processing trajectory for 'Forceps' (ID: 0)...
    -> Detected 269 outliers. Applying spline correction.
  Processing trajectory for 'Cap-Cystotome' (ID: 1)...
    -> Detected 243 outliers. Applying spline correction.
  Processing trajectory for 'Cap-Forceps' (ID: 220)...
  ✅ Saved smoothed annotations to dataset/0019/annotation_smooth.json
--- Finished smoothing for 0019 ---

--- Processing video: 0063 ---
  Processing trajectory for 'Cap-Cystotome' (ID: 0)...
    -> Detected 469 outliers. Applying spline correction.
  Processing trajectory for 'Forceps' (ID: 1)...
    -> Detected 129 outliers. Applying spline correction.
  ✅ Saved smoothed annotations to dataset/0063/annotation_smooth.json
--- Finished smoothing for 0063 ---

🎉 All done!


## Motion features
This code cell adds motion features to the dataset. It calculates velocity, acceleration, and jerk for each instrument. It also calculates the relative position of the surgical instrument to the Pupil and the relative motion features. The enriched annotations are saved to a new file named `annotation_full.json`.

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset.
- `INSTRUMENT_CLASSES`: These classes are instruments and will have motion features calculated.
- `VIDEOS_TO_PROCESS`: A list of specific video folder names to process. Leave empty to process all videos.
- `INPUT_ANNOTATION_FILENAME`: The filename of the input annotation file.
- `OUTPUT_ANNOTATION_FILENAME`: The filename for the final output annotation file.

In [None]:
import numpy as np
import json
import os
import glob

# --------------------------------------------------------------------------
# ✏️ 1. CONFIGURATION
#    Modify the variables in this section to match your needs.
# --------------------------------------------------------------------------

# Path to the root directory of the dataset.
# NOTE: DATASET_ROOT is assigned again by the later visualization cells of this
# notebook; functions read the global at call time, so the last cell executed wins.
DATASET_ROOT = "dataset/"

# These classes are instruments and will have motion features calculated.
# Annotations whose category name is not in this set (for example "Pupil")
# are written to the output file without a "motion_features" entry.
INSTRUMENT_CLASSES = {
    "Cannula", "Cap-Cystotome", "Cap-Forceps", "Forceps", "IA-Handpiece",
    "Lens-Injector", "Phaco-Handpiece", "Primary-Knife", "Second-Instrument",
    "Secondary-Knife"
}

# --- Video & File Settings ---
# A list of specific video folder names to process.
# LEAVE EMPTY (e.g., []) to process ALL video folders found in DATASET_ROOT.
# Example: VIDEOS_TO_PROCESS = ["0020", "0481"]
# Listed names are joined onto DATASET_ROOT as-is; a folder without the input
# annotation file is skipped with a warning.
VIDEOS_TO_PROCESS = [] #<-- CHANGE THIS

# The filename of the input annotation file (the one to read from).
# NOTE: this name is reassigned to a different file by the video visualization
# cell further down; re-run this cell before re-running the feature calculation.
INPUT_ANNOTATION_FILENAME = "annotation_smooth.json"

# The filename for the final output annotation file.
# It is written next to the input file inside each video folder.
OUTPUT_ANNOTATION_FILENAME = "annotation_full.json"


# --------------------------------------------------------------------------
# ⚙️ 2. CORE LOGIC
#    You don't need to change the code below this line.
# --------------------------------------------------------------------------

def _calculate_kinematics(position_track):
    """
    Calculates velocity, acceleration, and jerk from a trajectory of 2D points.
    """
    num_frames = len(position_track)
    velocities = [None] * num_frames
    accelerations = [None] * num_frames
    jerks = [None] * num_frames

    # Calculate Velocities (pixels/frame)
    for i in range(1, num_frames):
        p1 = position_track[i-1]
        p2 = position_track[i]
        if p1 is not None and p2 is not None:
            velocities[i] = [p2[0] - p1[0], p2[1] - p1[1]]

    # Calculate Accelerations (pixels/frame^2)
    for i in range(1, num_frames):
        v1 = velocities[i-1]
        v2 = velocities[i]
        if v1 is not None and v2 is not None:
            accelerations[i] = [v2[0] - v1[0], v2[1] - v1[1]]

    # Calculate Jerks (pixels/frame^3)
    for i in range(1, num_frames):
        a1 = accelerations[i-1]
        a2 = accelerations[i]
        if a1 is not None and a2 is not None:
            jerks[i] = [a2[0] - a1[0], a2[1] - a1[1]]

    return velocities, accelerations, jerks

def _extract_keypoint_track(flat_keypoints, num_frames, num_keypoints, keypoint_index):
    """
    Pull one keypoint's per-frame [x, y] track out of a flat keypoint array.

    The array is read as `num_keypoints` consecutive (x, y, visibility)
    triplets per frame. A frame contributes a point only when the triplet is
    complete and its visibility flag equals 2; otherwise the entry is None.

    Args:
        flat_keypoints: Flat list of keypoint values for the whole video.
        num_frames: Number of frames the returned track must cover.
        num_keypoints: Number of keypoints stored per frame.
        keypoint_index: Zero-based index of the keypoint within a frame.

    Returns:
        A list of length `num_frames` holding [x, y] pairs or None.
    """
    stride = num_keypoints * 3
    offset = keypoint_index * 3
    track = [None] * num_frames
    for frame_idx in range(num_frames):
        start = frame_idx * stride + offset
        triplet = flat_keypoints[start:start + 3]
        # A keypoint array shorter than the frame list yields a short slice;
        # treat those frames as "not visible" instead of raising IndexError.
        if len(triplet) == 3 and triplet[2] == 2:
            track[frame_idx] = [triplet[0], triplet[1]]
    return track


def process_video_annotations(data):
    """
    Adds motion features to all instrument annotations in the dataset.

    For every annotation whose category name is in INSTRUMENT_CLASSES, the tip
    (second keypoint) trajectory is extracted and its velocity, acceleration
    and jerk are computed, both in absolute coordinates and relative to the
    Pupil center (first keypoint of the first "Pupil" annotation found).

    Args:
        data: Parsed annotation dictionary with 'categories', 'videos' and
            'annotations' entries. It is modified in place.

    Returns:
        The same `data` object with a 'motion_features' entry added to each
        processed instrument annotation, or None when the frame count cannot
        be determined from data['videos'][0]['file_names'].
    """
    category_map = {cat['id']: cat for cat in data['categories']}
    if not data.get('videos') or not data['videos'][0].get('file_names'):
        print("  [Error] 'videos' or 'file_names' not found in JSON. Cannot determine frame count.")
        return None
    num_frames = len(data['videos'][0]['file_names'])

    # 1. Find the Pupil's center trajectory first. This is our reference.
    pupil_center_track = [None] * num_frames
    pupil_ann = next((ann for ann in data['annotations'] if category_map.get(ann['category_id'], {}).get('name') == "Pupil"), None)

    if pupil_ann:
        num_keypoints = len(category_map[pupil_ann['category_id']]['keypoints'])
        # Center is the first keypoint
        pupil_center_track = _extract_keypoint_track(pupil_ann['keypoints'], num_frames, num_keypoints, 0)
    else:
        print("  [Warning] Pupil annotation not found. Relative kinematics will not be calculated.")

    # 2. Iterate through annotations and process instruments
    for ann in data['annotations']:
        class_name = category_map.get(ann['category_id'], {}).get('name')

        if class_name in INSTRUMENT_CLASSES:
            print(f"  Calculating motion features for '{class_name}' (ID: {ann['id']}).")
            num_keypoints = len(category_map[ann['category_id']]['keypoints'])
            if num_keypoints < 2:
                print(f"    -> Skipping, instrument requires at least 2 keypoints but has {num_keypoints}.")
                continue

            # a. Extract the instrument's absolute tip trajectory (tip is the second keypoint)
            tip_track_abs = _extract_keypoint_track(ann['keypoints'], num_frames, num_keypoints, 1)

            # b. Calculate absolute kinematics
            vel_abs, acc_abs, jerk_abs = _calculate_kinematics(tip_track_abs)

            # c. Calculate the relative position trajectory; a frame needs both
            #    the tip and the pupil center to be available.
            tip_track_rel = [None] * num_frames
            for i in range(num_frames):
                if tip_track_abs[i] is not None and pupil_center_track[i] is not None:
                    tip_track_rel[i] = [
                        tip_track_abs[i][0] - pupil_center_track[i][0],
                        tip_track_abs[i][1] - pupil_center_track[i][1]
                    ]

            # d. Calculate relative kinematics
            vel_rel, acc_rel, jerk_rel = _calculate_kinematics(tip_track_rel)

            # e. Add the new "motion_features" object to the annotation
            ann['motion_features'] = {
                "absolute": {
                    "velocity": vel_abs,
                    "acceleration": acc_abs,
                    "jerk": jerk_abs
                },
                "relative_to_pupil": {
                    "position": tip_track_rel,
                    "velocity": vel_rel,
                    "acceleration": acc_rel,
                    "jerk": jerk_rel
                }
            }

    return data

# --------------------------------------------------------------------------
# ▶️ 3. EXECUTION
#    This block runs the script using the configuration above.
# --------------------------------------------------------------------------

def run_feature_calculation():
    """
    Drive the motion feature calculation over the configured video folders.

    Reads INPUT_ANNOTATION_FILENAME from each video folder under DATASET_ROOT
    (all sub-folders when VIDEOS_TO_PROCESS is empty), enriches it through
    process_video_annotations and writes OUTPUT_ANNOTATION_FILENAME next to
    it. Folders without the input file are skipped with a warning.
    """
    if not os.path.exists(DATASET_ROOT):
        print(f"❌ [Error] Input dataset directory not found at '{DATASET_ROOT}'")
        return

    # Fall back to every sub-directory when no explicit selection was made.
    if VIDEOS_TO_PROCESS:
        selected = VIDEOS_TO_PROCESS
    else:
        print(f"No specific video provided. Processing all videos in '{DATASET_ROOT}'...")
        selected = [
            os.path.basename(entry)
            for entry in glob.glob(os.path.join(DATASET_ROOT, '*'))
            if os.path.isdir(entry)
        ]

    if not selected:
        print(f"❌ No video subdirectories found in '{DATASET_ROOT}'")
        return

    ordered_names = sorted(selected)
    print(f"Found {len(selected)} video(s) to process: {ordered_names}")

    # --- Processing Loop ---
    for video_name in ordered_names:
        folder = os.path.join(DATASET_ROOT, video_name)
        print(f"\n--- Processing video: {video_name} ---")

        source_json = os.path.join(folder, INPUT_ANNOTATION_FILENAME)
        target_json = os.path.join(folder, OUTPUT_ANNOTATION_FILENAME)

        if not os.path.exists(source_json):
            print(f"  [Warning] Input file '{INPUT_ANNOTATION_FILENAME}' not found. Skipping.")
            continue

        with open(source_json, 'r') as handle:
            annotations = json.load(handle)

        # Enrich the annotations; a falsy result signals a processing failure.
        enriched = process_video_annotations(annotations)

        if not enriched:
            print(f"  [Error] Processing failed for video {video_name}. Output file not saved.")
        else:
            with open(target_json, 'w') as handle:
                json.dump(enriched, handle, indent=4)
            print(f"  ✅ Saved final annotations with motion features to {target_json}")

        print(f"--- Finished processing {video_name} ---")

    print("\n🎉 All done!")


# --- Run the script ---
run_feature_calculation()

No specific video provided. Processing all videos in 'dataset/'...
Found 2 video(s) to process: ['0019', '0063']

--- Processing video: 0019 ---
  Calculating motion features for 'Forceps' (ID: 0).
  Calculating motion features for 'Cap-Cystotome' (ID: 1).
  Calculating motion features for 'Cap-Forceps' (ID: 220).
  ✅ Saved final annotations with motion features to dataset/0019/annotation_full.json
--- Finished processing 0019 ---

--- Processing video: 0063 ---
  Calculating motion features for 'Cap-Cystotome' (ID: 0).
  Calculating motion features for 'Forceps' (ID: 1).
  ✅ Saved final annotations with motion features to dataset/0063/annotation_full.json
--- Finished processing 0063 ---

🎉 All done!


## Visualize motion features and trajectories
This code cell visualizes the motion features and trajectories. For each selected instrument in each video it saves one PNG containing a 4×2 grid of plots: the absolute and relative trajectories, followed by the magnitudes of the absolute and relative velocity, acceleration, and jerk over time.

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset to visualize.
- `VISUALIZATION_OUTPUT_DIR`: The directory where the output plot images will be saved.
- `VIDEOS_TO_PROCESS`: A list of specific video folder names to process. Leave empty to process all videos.
- `INSTRUMENTS_TO_PLOT`: A list of specific instrument names to plot. Leave empty to plot all instruments.
- `COMPARE_WITH_CLEANED`: Set to True to overlay the original trajectory from "annotation_cleaned.json" for comparison.

In [None]:
import json
import numpy as np
import os
import glob
import matplotlib.pyplot as plt

# ==================================
# === 📝 GLOBAL CONFIGURATION 📝 ===
# ==================================
# The root directory of the dataset to visualize.
# Each video sub-folder must contain "annotation_full.json"; folders without
# it are skipped with a warning.
DATASET_ROOT = "dataset/"

# The directory where the output plot images will be saved.
# It is created by main() when it does not exist yet.
VISUALIZATION_OUTPUT_DIR = "visualizations/"

# A list of specific video folder names to process.
# LEAVE THIS LIST EMPTY (e.g., []) to process ALL videos.
VIDEOS_TO_PROCESS = []

# A list of specific instrument names to plot.
# LEAVE THIS LIST EMPTY (e.g., []) to plot ALL instruments.
# "All instruments" means every category name in the file except
# 'Pupil' and 'Cornea'.
INSTRUMENTS_TO_PLOT = []

# Set to True to overlay the original trajectory from "annotation_cleaned.json"
# for comparison. Set to False to only plot the final trajectory.
# When the comparison file is missing, a warning is printed and the plot is
# produced without the overlay.
COMPARE_WITH_CLEANED = True


# --- Helper Functions ---

def calculate_magnitude(vectors):
    """
    Return the Euclidean length of every 2D vector in `vectors`.

    Entries that are None or do not have exactly two components map to NaN,
    so the result always has one value per input entry.
    """
    norms = []
    for vec in vectors:
        if vec is None or len(vec) != 2:
            norms.append(np.nan)
        else:
            norms.append(np.linalg.norm(vec))
    return norms

def plot_kinematics(ax, data, title, color):
    """
    Draw one kinematic magnitude series against the frame index.

    Args:
        ax: Matplotlib axes to draw on.
        data: Sequence of magnitudes, one per frame (NaN where undefined).
        title: Used both as the axes title and as the legend label.
        color: Line color.
    """
    line_style = {"label": title, "color": color, "linewidth": 1.5}
    ax.plot(data, **line_style)

    # Axes decoration
    ax.set_title(title)
    ax.set_xlabel("Frame Number")
    ax.set_ylabel("Magnitude (pixels/frame^n)")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.margins(x=0.01)
    ax.legend()

def plot_trajectory(ax, trajectory, title, color, alpha=1.0):
    """
    Plots a 2D trajectory (Y vs. X).

    Args:
        ax: Matplotlib axes to draw on. May already contain other trajectories.
        trajectory: Sequence of [x, y] points; None entries are dropped before
            plotting, so the line connects the remaining points directly.
        title: Axes title and legend label. When it contains "Absolute" the
            y-axis is shown inverted.
        color: Line and marker color.
        alpha: Opacity of the line and markers.
    """
    valid_points = np.array([p for p in trajectory if p is not None])
    if valid_points.size > 0:
        ax.plot(valid_points[:, 0], valid_points[:, 1], 'o-', label=title, color=color, markersize=2, linewidth=1, alpha=alpha)
    ax.set_title(title)
    ax.set_xlabel("X Coordinate")
    ax.set_ylabel("Y Coordinate")
    ax.set_aspect('equal', adjustable='box')
    # invert_yaxis() toggles the direction, so only invert when the axes is not
    # inverted yet; otherwise a second trajectory on the same axes would flip it back.
    if "Absolute" in title and not ax.yaxis_inverted():
        ax.invert_yaxis()
    ax.grid(True, linestyle='--', alpha=0.6)
    # Calling legend() on axes without labelled artists makes matplotlib emit a
    # warning, which happens when every point of the trajectory is None.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

# --- Main Processing Function ---

def generate_plots_for_video(data, video_name, instruments_to_plot, cleaned_data=None):
    """
    Generates and saves a grid of plots for each specified instrument.

    For every annotation in `data` whose category name is in
    `instruments_to_plot` and that carries 'motion_features', a 4x2 figure is
    written to VISUALIZATION_OUTPUT_DIR: absolute and relative trajectories in
    the first row, then velocity, acceleration and jerk magnitudes.

    Args:
        data: Parsed annotation dictionary that includes motion features.
        video_name: Video folder name; used in the figure title and filename.
        instruments_to_plot: Collection of category names to plot.
        cleaned_data: Optional parsed annotation dictionary whose tip
            trajectory is overlaid on the absolute trajectory for comparison.
            The first annotation with the same category name is used.
    """
    category_map = {cat['id']: cat for cat in data['categories']}
    cleaned_category_map = {cat['id']: cat for cat in cleaned_data['categories']} if cleaned_data else None

    for ann in data['annotations']:
        category_id = ann['category_id']
        class_name = category_map.get(category_id, {}).get('name')

        if not class_name or class_name not in instruments_to_plot:
            continue

        if 'motion_features' not in ann:
            print(f"  -> Skipping '{class_name}': No 'motion_features' found.")
            continue

        print(f"  -> Generating plots for '{class_name}' (ID: {ann['id']})...")

        features = ann['motion_features']

        # Extract the smoothed absolute trajectory. Each frame occupies
        # kp_stride values; offsets 3..5 are the second keypoint (the tip).
        num_keypoints = len(category_map[category_id]['keypoints'])
        kp_stride = num_keypoints * 3
        absolute_trajectory = [[ann['keypoints'][i+3], ann['keypoints'][i+4]] if ann['keypoints'][i+5] == 2 else None for i in range(0, len(ann['keypoints']), kp_stride)]

        # Find and extract the original trajectory for comparison
        original_absolute_trajectory = None
        if cleaned_data:
            original_ann = next((o_ann for o_ann in cleaned_data['annotations'] if cleaned_category_map.get(o_ann['category_id'], {}).get('name') == class_name), None)
            if original_ann:
                print(f"     Found corresponding 'cleaned' annotation for comparison.")
                o_num_keypoints = len(cleaned_category_map[original_ann['category_id']]['keypoints'])
                o_kp_stride = o_num_keypoints * 3
                original_absolute_trajectory = [[original_ann['keypoints'][i+3], original_ann['keypoints'][i+4]] if original_ann['keypoints'][i+5] == 2 else None for i in range(0, len(original_ann['keypoints']), o_kp_stride)]

        # --- Create a 4x2 plot grid ---
        fig, axes = plt.subplots(4, 2, figsize=(18, 24))
        fig.suptitle(f"Motion Analysis for '{class_name}'\nVideo: {video_name}", fontsize=20, y=0.96)

        # Plotting
        abs_ax = axes[0, 0]
        plot_trajectory(abs_ax, absolute_trajectory, 'Smoothed Trajectory', 'royalblue')
        if original_absolute_trajectory:
            plot_trajectory(abs_ax, original_absolute_trajectory, 'Original (Cleaned)', 'red', alpha=0.7)
        abs_ax.set_title("Absolute Trajectory Comparison")
        # plot_trajectory only inverts the y-axis when the title it receives
        # contains "Absolute", which is never the case for the labels used
        # above. Invert here so the absolute plot follows the intended
        # orientation (assumes image pixel coordinates with y growing downward).
        if not abs_ax.yaxis_inverted():
            abs_ax.invert_yaxis()
        # Skip the legend when neither trajectory had a visible point;
        # matplotlib warns when legend() finds no labelled artists.
        legend_handles, _ = abs_ax.get_legend_handles_labels()
        if legend_handles:
            abs_ax.legend()

        plot_trajectory(axes[0, 1], features['relative_to_pupil']['position'], 'Relative Trajectory (to Pupil)', 'seagreen')
        plot_kinematics(axes[1, 0], calculate_magnitude(features['absolute']['velocity']), 'Absolute Velocity', 'royalblue')
        plot_kinematics(axes[1, 1], calculate_magnitude(features['relative_to_pupil']['velocity']), 'Relative Velocity', 'seagreen')
        plot_kinematics(axes[2, 0], calculate_magnitude(features['absolute']['acceleration']), 'Absolute Acceleration', 'royalblue')
        plot_kinematics(axes[2, 1], calculate_magnitude(features['relative_to_pupil']['acceleration']), 'Relative Acceleration', 'seagreen')
        plot_kinematics(axes[3, 0], calculate_magnitude(features['absolute']['jerk']), 'Absolute Jerk', 'royalblue')
        plot_kinematics(axes[3, 1], calculate_magnitude(features['relative_to_pupil']['jerk']), 'Relative Jerk', 'seagreen')

        plt.tight_layout(rect=[0, 0, 1, 0.95])

        # Save the figure
        output_filename = os.path.join(VISUALIZATION_OUTPUT_DIR, f"{video_name}_{class_name}_motion_analysis.png")
        plt.savefig(output_filename, bbox_inches='tight')
        print(f"     Saved plot to {output_filename}")
        # Close explicitly so figures do not accumulate across annotations.
        plt.close(fig)

# --- Main Execution Block ---

def main():
    """
    Produce motion-analysis plots for every configured video folder.

    Creates VISUALIZATION_OUTPUT_DIR, loads "annotation_full.json" from each
    video folder under DATASET_ROOT (all sub-folders when VIDEOS_TO_PROCESS is
    empty), optionally loads "annotation_cleaned.json" for comparison, and
    hands both to generate_plots_for_video.
    """
    os.makedirs(VISUALIZATION_OUTPUT_DIR, exist_ok=True)
    if not os.path.exists(DATASET_ROOT):
        print(f"[Error] Input dataset directory not found at '{DATASET_ROOT}'")
        return

    # Determine which videos to process
    if not VIDEOS_TO_PROCESS:
        print("No specific video provided. Processing all videos...")
        candidates = glob.glob(os.path.join(DATASET_ROOT, '*'))
        video_names = sorted(os.path.basename(path) for path in candidates if os.path.isdir(path))
    else:
        video_names = VIDEOS_TO_PROCESS

    if not video_names:
        print(f"No video subdirectories found in '{DATASET_ROOT}'")
        return

    for video_name in video_names:
        folder = os.path.join(DATASET_ROOT, video_name)
        print(f"\n--- Processing video: {video_name} ---")

        featured_json = os.path.join(folder, "annotation_full.json")
        if not os.path.exists(featured_json):
            print("  [Warning] Input file 'annotation_full.json' not found. Skipping.")
            continue

        with open(featured_json, 'r') as handle:
            data = json.load(handle)

        # Optional overlay source; stays None when disabled or missing.
        comparison_data = None
        if COMPARE_WITH_CLEANED:
            cleaned_json = os.path.join(folder, "annotation_cleaned.json")
            if not os.path.exists(cleaned_json):
                print("  [Warning] Comparison file 'annotation_cleaned.json' not found. Comparison skipped.")
            else:
                print("  -> Loading 'annotation_cleaned.json' for comparison.")
                with open(cleaned_json, 'r') as handle:
                    comparison_data = json.load(handle)

        # Determine which instruments to plot: everything except the anatomy
        # classes, unless an explicit selection was configured.
        non_instruments = {'Pupil', 'Cornea'}
        all_instrument_names = {cat['name'] for cat in data['categories'] if cat['name'] not in non_instruments}
        instruments_to_plot = set(INSTRUMENTS_TO_PLOT) if INSTRUMENTS_TO_PLOT else all_instrument_names

        generate_plots_for_video(data, video_name, instruments_to_plot, cleaned_data=comparison_data)

        print(f"--- Finished plotting for {video_name} ---")

# --- Run the script ---
main()

No specific video provided. Processing all videos...

--- Processing video: 0019 ---
  -> Loading 'annotation_cleaned.json' for comparison.
  -> Generating plots for 'Forceps' (ID: 0)...
     Found corresponding 'cleaned' annotation for comparison.
     Saved plot to visualizations/0019_Forceps_motion_analysis.png
  -> Generating plots for 'Cap-Cystotome' (ID: 1)...
     Found corresponding 'cleaned' annotation for comparison.
     Saved plot to visualizations/0019_Cap-Cystotome_motion_analysis.png
  -> Generating plots for 'Cap-Forceps' (ID: 220)...
     Found corresponding 'cleaned' annotation for comparison.


  ax.legend()
  abs_ax.legend()


     Saved plot to visualizations/0019_Cap-Forceps_motion_analysis.png
--- Finished plotting for 0019 ---

--- Processing video: 0063 ---
  -> Loading 'annotation_cleaned.json' for comparison.
  -> Generating plots for 'Cap-Cystotome' (ID: 0)...
     Found corresponding 'cleaned' annotation for comparison.
     Saved plot to visualizations/0063_Cap-Cystotome_motion_analysis.png
  -> Generating plots for 'Forceps' (ID: 1)...
     Found corresponding 'cleaned' annotation for comparison.
     Saved plot to visualizations/0063_Forceps_motion_analysis.png
--- Finished plotting for 0063 ---


## Video Visualization
This code cell creates annotated videos from the dataset. For each video it draws the segmentation masks, bounding boxes, class labels, and keypoints on every frame and saves the result as a new `.mp4` file.

### Configurations:
- `DATASET_ROOT`: The root directory of the dataset.
- `VIDEO_OUTPUT_DIR`: The directory where the output annotated videos will be saved.
- `INPUT_ANNOTATION_FILENAME`: The exact name of the annotation file to use for visualization.
- `VIDEOS_TO_PROCESS`: A list of specific video folder names to process. Leave empty to process all videos.
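- `COLOR_DICT`: A dictionary mapping each class name to the color tuple used to draw it; the `"Default"` entry is used for class names that are not listed. The tuples are passed directly to OpenCV, which interprets them as BGR.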

In [None]:
import cv2
import numpy as np
import json
import os
import glob

# --------------------------------------------------------------------------
# ✏️ 1. CONFIGURATION
#    Modify the variables in this section to match your needs.
# --------------------------------------------------------------------------

# Path to the root directory of the dataset.
DATASET_ROOT = "dataset/"

# Directory where the output annotated videos will be saved.
# It is created by run_video_creation_script() when it does not exist yet.
VIDEO_OUTPUT_DIR = "visualized_videos_motion/"

# The exact name of the annotation file to use for visualization.
# This file must exist in each video folder you process.
# Examples: "annotation.json", "annotation_cleaned.json", "annotation_full.json"
# NOTE: this rebinds the INPUT_ANNOTATION_FILENAME used by the motion features
# cell earlier in the notebook; re-run that cell's configuration before
# re-running the feature calculation.
INPUT_ANNOTATION_FILENAME = "annotation_full.json" #<-- CHANGE THIS

# A list of specific video folder names to process.
# LEAVE EMPTY (e.g., []) to process ALL video folders found in DATASET_ROOT.
# Example: VIDEOS_TO_PROCESS = ["0020"]
VIDEOS_TO_PROCESS = [] #<-- CHANGE THIS

# Define a color dictionary for consistent class colors
# The tuples are handed unchanged to the OpenCV drawing calls, which treat
# them as BGR (the drawing code labels (255, 0, 0) as blue).
# "Default" is the fallback for class names that have no entry here.
COLOR_DICT = {
    "Cannula": (255, 0, 0), "Cap-Cystotome": (0, 255, 0), "Cap-Forceps": (0, 0, 255),
    "Cornea": (255, 255, 0), "Forceps": (255, 0, 255), "IA-Handpiece": (0, 255, 255),
    "Lens-Injector": (125, 125, 0), "Phaco-Handpiece": (0, 125, 125), "Primary-Knife": (125, 0, 125),
    "Pupil": (50, 200, 200), "Second-Instrument": (200, 200, 50), "Secondary-Knife": (200, 50, 200),
    "Default": (128, 128, 128)
}

# --------------------------------------------------------------------------
# ⚙️ 2. CORE LOGIC
#    You don't need to change the code below this line.
# --------------------------------------------------------------------------

def draw_annotations_on_frame(frame, annotations_for_frame):
    """
    Render every annotation of one frame onto the image.

    Masks are filled on a copy of the frame and blended back at the end, while
    boxes, labels and keypoints are drawn on the frame itself before blending.

    Args:
        frame: Image array to annotate; it is modified in place.
        annotations_for_frame: List of dicts with 'class_name',
            'segmentation', 'bbox' and 'keypoints' entries for this frame.

    Returns:
        The same `frame` object after drawing and blending.
    """
    mask_layer = frame.copy()
    center_color = (255, 0, 0)  # Blue center
    tip_color = (0, 0, 255)     # Red tip

    for item in annotations_for_frame:
        class_name = item['class_name']
        class_color = COLOR_DICT.get(class_name, COLOR_DICT["Default"])

        # Segmentation mask goes on the separate layer (first polygon only)
        polygons = item['segmentation']
        if polygons:
            outline = np.array(polygons[0], dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask_layer, [outline], class_color)

        # Bounding box with the class label just above its top-left corner
        box = item['bbox']
        if box:
            left, top, box_w, box_h = (int(value) for value in box)
            cv2.rectangle(frame, (left, top), (left + box_w, top + box_h), class_color, 2)
            cv2.putText(frame, f"{class_name}", (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.7, class_color, 2)

        # Keypoints (assumes max 2 keypoints: center and tip); a point is
        # drawn only when its visibility flag equals 2.
        kps = item['keypoints']
        if kps:
            if kps[2] == 2:
                cv2.circle(frame, (int(kps[0]), int(kps[1])), 6, center_color, -1, cv2.LINE_AA)
            if len(kps) > 5 and kps[5] == 2:
                cv2.circle(frame, (int(kps[3]), int(kps[4])), 6, tip_color, -1, cv2.LINE_AA)

    # Blend the mask layer into the frame (40% layer, 60% frame), in place
    cv2.addWeighted(mask_layer, 0.4, frame, 0.6, 0, frame)
    return frame

def create_video(video_folder_path, json_filename, fps=30.0):
    """
    Creates an annotated video from a folder of frames and a JSON file.

    The frame list, width and height are taken from the first entry of
    data['videos']. Each frame is annotated with draw_annotations_on_frame and
    appended to "<video>_from_<json stem>.mp4" inside VIDEO_OUTPUT_DIR.

    Args:
        video_folder_path: Folder containing the frames and the annotation file.
        json_filename: Name of the annotation file inside that folder.
        fps: Frame rate of the output video. Defaults to 30.0.

    Returns:
        None. Problems (missing annotation file, missing video metadata,
        writer that cannot be opened) are reported with a message and the
        video is skipped.
    """
    video_name = os.path.basename(video_folder_path)
    input_json_path = os.path.join(video_folder_path, json_filename)

    if not os.path.exists(input_json_path):
        print(f"  [Error] Annotation file not found: {input_json_path}. Skipping video creation.")
        return

    print(f"  -> Loading annotations from: {json_filename}")
    with open(input_json_path, 'r') as f:
        data = json.load(f)

    # Same validation as the motion feature step: without the frame list there
    # is nothing to render.
    if not data.get('videos') or not data['videos'][0].get('file_names'):
        print("  [Error] 'videos' or 'file_names' not found in JSON. Skipping video creation.")
        return
    video_info = data['videos'][0]
    frame_filenames = video_info['file_names']

    category_map = {cat['id']: cat['name'] for cat in data['categories']}
    cat_id_to_num_kps = {cat['id']: len(cat.get('keypoints', [])) for cat in data['categories']}

    num_frames = len(frame_filenames)
    all_frames_data = [[] for _ in range(num_frames)]

    # Regroup the per-annotation arrays into one list of annotations per frame.
    for ann in data['annotations']:
        class_name = category_map.get(ann['category_id'], "Unknown")
        num_keypoints = cat_id_to_num_kps.get(ann['category_id'], 0)
        kp_stride = num_keypoints * 3

        for i in range(num_frames):
            # Check if the annotation exists for this frame
            if ann.get('segmentations') and i < len(ann['segmentations']) and ann['segmentations'][i]:
                frame_ann = {
                    'class_name': class_name,
                    'segmentation': ann['segmentations'][i],
                    'bbox': ann['bboxes'][i] if ann.get('bboxes') and i < len(ann['bboxes']) else None,
                    'keypoints': ann['keypoints'][i * kp_stride : (i + 1) * kp_stride] if ann.get('keypoints') and kp_stride > 0 else []
                }
                all_frames_data[i].append(frame_ann)

    # --- Video Creation ---
    output_video_name = f"{video_name}_from_{os.path.splitext(json_filename)[0]}.mp4"
    output_video_path = os.path.join(VIDEO_OUTPUT_DIR, output_video_name)

    width = video_info['width']
    height = video_info['height']

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    # A writer that failed to open ignores write() calls, which would
    # otherwise end in a misleading success message.
    if not out_writer.isOpened():
        print(f"  [Error] Could not open video writer for: {output_video_path}. Skipping video creation.")
        out_writer.release()
        return

    print(f"  -> Creating video: {output_video_name}")
    try:
        for i, frame_filename in enumerate(frame_filenames):
            frame_path = os.path.join(video_folder_path, frame_filename)
            if not os.path.exists(frame_path):
                print(f"\n    [Warning] Frame not found: {frame_path}. Using a black frame instead.")
                frame = np.zeros((height, width, 3), dtype=np.uint8)
            else:
                frame = cv2.imread(frame_path)
                # imread signals an unreadable file by returning None instead of raising.
                if frame is None:
                    print(f"\n    [Warning] Frame could not be read: {frame_path}. Using a black frame instead.")
                    frame = np.zeros((height, width, 3), dtype=np.uint8)

            annotated_frame = draw_annotations_on_frame(frame, all_frames_data[i])

            out_writer.write(annotated_frame)
            print(f"    Processing frame {i+1}/{num_frames}", end='\r')
    finally:
        # Release the writer even if a frame fails.
        out_writer.release()

    print(f"\n  -> ✅ Successfully created video: {output_video_path}")


# --------------------------------------------------------------------------
# ▶️ 3. EXECUTION
#    This block runs the script using the configuration above.
# --------------------------------------------------------------------------

def run_video_creation_script():
    """
    Generate an annotated video for every configured video folder.

    Creates VIDEO_OUTPUT_DIR, resolves the folders to process (all
    sub-directories of DATASET_ROOT when VIDEOS_TO_PROCESS is empty) and calls
    create_video with INPUT_ANNOTATION_FILENAME for each, in sorted order.
    """
    os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)
    if not os.path.exists(DATASET_ROOT):
        print(f"❌ [Error] Input dataset directory not found at '{DATASET_ROOT}'")
        return

    # Explicit selection wins; otherwise discover every sub-directory.
    if VIDEOS_TO_PROCESS:
        selected = VIDEOS_TO_PROCESS
    else:
        print("No specific video provided. Processing all videos...")
        selected = [
            os.path.basename(entry)
            for entry in glob.glob(os.path.join(DATASET_ROOT, '*'))
            if os.path.isdir(entry)
        ]

    if not selected:
        print(f"❌ No video subdirectories found in '{DATASET_ROOT}'")
        return

    ordered_names = sorted(selected)
    print(f"Found {len(selected)} video(s) to process: {ordered_names}")

    for video_name in ordered_names:
        folder = os.path.join(DATASET_ROOT, video_name)
        print(f"\n--- Processing video: {video_name} ---")

        create_video(folder, INPUT_ANNOTATION_FILENAME)

        print(f"--- Finished video creation for {video_name} ---")

    print("\n🎉 All done!")

# --- Run the script ---
run_video_creation_script()

No specific video provided. Processing all videos...
Found 2 video(s) to process: ['0019', '0063']

--- Processing video: 0019 ---
  -> Loading annotations from: annotation_full.json
  -> Creating video: 0019_from_annotation_full.mp4
    Processing frame 3276/3276
  -> ✅ Successfully created video: visualized_videos_motion/0019_from_annotation_full.mp4
--- Finished video creation for 0019 ---

--- Processing video: 0063 ---
  -> Loading annotations from: annotation_full.json
  -> Creating video: 0063_from_annotation_full.mp4
    Processing frame 5799/5799
  -> ✅ Successfully created video: visualized_videos_motion/0063_from_annotation_full.mp4
--- Finished video creation for 0063 ---

🎉 All done!
