# In [1]:
import os
import json
import math
import subprocess
from pathlib import Path
from IPython.display import display, HTML

"""
Jupyter Notebook script for partitioning a concatenated video and its corresponding JSON log file
into training, validation, and test sets with a 70:15:15 ratio.

This script is designed to be run within a Jupyter Notebook cell.
"""

# ---------------------------------------------------------------------------
# User-adjustable settings: edit these before running the cell.
# ---------------------------------------------------------------------------

# Concatenated source video that will be partitioned.
INPUT_VIDEO = 'concatenated_output.mp4'

# JSON log describing the segments of INPUT_VIDEO.
INPUT_JSON = 'concatenated_log.json'

# Directory that receives the generated video and JSON files.
OUTPUT_DIR = '.'

# Target share of frames for each set, in (train, val, test) order.
RATIO = (0.70, 0.15, 0.15)

def display_progress(message, color='black'):
    """Show a status line in the notebook output as a coloured HTML paragraph.

    Args:
        message: Text to show. It is converted with ``str()`` and HTML-escaped,
            so characters such as ``<`` or ``&`` (which can occur in file
            paths) are displayed literally instead of being parsed as markup.
        color: CSS colour used for the paragraph text (default ``'black'``).
    """
    # Imported locally so the function does not depend on a module-level import.
    from html import escape

    safe_message = escape(str(message), quote=False)
    # The colour lands inside a quoted attribute, so quotes must be escaped too.
    safe_color = escape(str(color), quote=True)
    display(HTML(f"<p style='color:{safe_color};'>{safe_message}</p>"))

def get_video_duration(video_path):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        video_path: Path of the file to probe.

    Returns:
        The container duration as a float. If ffprobe cannot be started, exits
        with an error, or does not report a usable duration, a warning is
        displayed and ``0`` is returned.
    """
    cmd = [
        'ffprobe',
        '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'json',
        video_path
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)
        return float(data['format']['duration'])
    except (
        OSError,                     # ffprobe missing or not executable
        subprocess.SubprocessError,  # non-zero exit status
        ValueError,                  # invalid JSON or non-numeric duration
        KeyError,                    # expected fields absent
        TypeError,                   # duration present but null / wrong type
    ):
        display_progress(f"Warning: Could not determine duration of {video_path}", 'orange')
        return 0

def get_frame_count(video_path):
    """Return the number of frames in the first video stream of a file.

    The count is taken from ffprobe's packet count for stream ``v:0``. When
    that probe cannot be run or yields no usable number, a warning is displayed
    and the count is estimated as ``duration * fps`` instead.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        The frame count as an int (an estimate when the fallback is used).
    """
    cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'v:0',
        '-count_packets',
        '-show_entries', 'stream=nb_read_packets',
        '-of', 'json',
        video_path
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)
        return int(data['streams'][0]['nb_read_packets'])
    except (
        OSError,                     # ffprobe missing or not executable
        subprocess.SubprocessError,  # non-zero exit status
        ValueError,                  # invalid JSON or non-numeric count
        KeyError,                    # expected fields absent
        IndexError,                  # file has no video stream
        TypeError,                   # count present but null / wrong type
    ):
        display_progress(f"Warning: Could not determine frame count of {video_path}", 'orange')
        # Estimate from duration and the probed frame rate; get_video_fps
        # itself falls back to 30 fps when the rate cannot be determined.
        duration = get_video_duration(video_path)
        return int(duration * get_video_fps(video_path))

def get_video_fps(video_path):
    """Return the frame rate of the first video stream of a file.

    The value is read from ffprobe's ``r_frame_rate`` field, which is either a
    fraction such as ``"30000/1001"`` or a plain number.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        The frame rate as a positive float. If ffprobe cannot be run or the
        reported rate is missing, malformed, zero (e.g. ``"0/0"``) or not
        finite, a warning is displayed and ``30.0`` is returned.
    """
    cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'v:0',
        '-show_entries', 'stream=r_frame_rate',
        '-of', 'json',
        video_path
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)
        fps_str = data['streams'][0]['r_frame_rate']

        # Parse fraction like "30000/1001" to get actual fps
        if '/' in fps_str:
            num, denom = map(int, fps_str.split('/'))
            fps = num / denom
        else:
            fps = float(fps_str)

        # Callers divide by this value, so a zero/negative/inf/nan rate is
        # treated the same as an unreadable one.
        if not (math.isfinite(fps) and fps > 0):
            raise ValueError(f"unusable frame rate: {fps_str!r}")
        return fps
    except (
        OSError,                     # ffprobe missing or not executable
        subprocess.SubprocessError,  # non-zero exit status
        ValueError,                  # invalid JSON, malformed or unusable rate
        KeyError,                    # expected fields absent
        IndexError,                  # file has no video stream
        TypeError,                   # rate present but null / wrong type
        ZeroDivisionError,           # rate reported as "N/0"
    ):
        display_progress(f"Warning: Could not determine FPS of {video_path}, assuming 30fps", 'orange')
        return 30.0

def extract_video_segment(input_video, output_video, start_time, duration):
    """Cut a time range out of a video with ffmpeg.

    A fast stream copy is attempted first; if that fails, the extraction is
    retried with re-encoding (libx264 video, AAC audio).

    Args:
        input_video: Path of the source video.
        output_video: Path of the file to write (overwritten if it exists).
        start_time: Start of the range, in seconds.
        duration: Length of the range, in seconds.

    Returns:
        ``True`` if an ffmpeg run succeeded and ``output_video`` exists,
        ``False`` otherwise (including when ffmpeg cannot be started).
    """
    # NOTE(review): with '-c copy' ffmpeg does not re-encode, so the cut may not
    # be frame-exact at the requested start time - confirm this is acceptable
    # for frame-indexed labels.
    def run_ffmpeg(codec_args):
        # Same command for both attempts; only the codec options differ.
        cmd = [
            'ffmpeg',
            '-i', input_video,
            '-ss', str(start_time),
            '-t', str(duration),
            *codec_args,
            '-y',  # Overwrite output file
            output_video
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        return os.path.exists(output_video)

    # OSError covers a missing/non-executable ffmpeg binary, which is not a
    # SubprocessError and would otherwise propagate to the caller.
    failures = (OSError, subprocess.SubprocessError)

    try:
        return run_ffmpeg(['-c', 'copy'])  # Copy without re-encoding
    except failures:
        display_progress(f"Warning: Failed to extract segment to {output_video}", 'orange')

    # Try again with re-encoding (more reliable but slower)
    display_progress("Trying with re-encoding...", 'blue')
    try:
        return run_ffmpeg(['-c:v', 'libx264', '-c:a', 'aac'])
    except failures:
        display_progress(f"Error: Failed to extract segment to {output_video} even with re-encoding", 'red')
        return False

def split_json_by_frames(json_data, train_end_frame, val_end_frame, total_frames):
    """Split the segment log into training, validation and test logs.

    Segments are read from ``json_data[0]['segments']`` and assigned by their
    ``start_frame`` / ``end_frame`` values. A segment is never cut: when one
    crosses a boundary, that boundary is moved to the segment's end so the
    whole segment stays in the earlier set. Validation and test segments are
    copied and their frame numbers shifted to be relative to the start of
    their own set; training segments are passed through unchanged.

    Args:
        json_data: List whose first entry is a dict with ``'filename'`` and
            ``'segments'`` keys.
        train_end_frame: Requested frame at which the training set ends.
        val_end_frame: Requested frame at which the validation set ends.
        total_frames: Total number of frames in the source video.

    Returns:
        ``((train_json, val_json, test_json),
        (adjusted_train_end, adjusted_val_end, total_frames))`` where each
        ``*_json`` is a one-entry list in the same layout as ``json_data`` and
        the adjusted values are the boundaries actually used.
    """
    def shifted(segment, offset):
        # Copy so the caller's segment dict is left untouched.
        moved = segment.copy()
        moved['start_frame'] -= offset
        moved['end_frame'] -= offset
        return moved

    train_segments = []
    val_segments = []
    test_segments = []

    # Get the segments from the first (and only) entry
    entry = json_data[0]
    # Process in start order: the offsets applied below are only valid if a
    # boundary can no longer move once a later set has received a segment.
    all_segments = sorted(entry['segments'], key=lambda seg: seg['start_frame'])

    # Boundaries may grow so that segments aren't split
    adjusted_train_end = train_end_frame
    adjusted_val_end = max(val_end_frame, train_end_frame)

    for segment in all_segments:
        start_frame = segment['start_frame']
        end_frame = segment['end_frame']

        if end_frame <= adjusted_train_end or start_frame < adjusted_train_end:
            # Fully inside training, or crossing its end: keep it in training
            # and extend the boundary if needed.
            adjusted_train_end = max(adjusted_train_end, end_frame)
            # The validation boundary can never precede the training one,
            # otherwise the train and test ranges would overlap.
            adjusted_val_end = max(adjusted_val_end, adjusted_train_end)
            train_segments.append(segment)
        elif end_frame <= adjusted_val_end or start_frame < adjusted_val_end:
            # Fully inside validation, or crossing its end: keep it in
            # validation, with frames relative to the start of that set.
            adjusted_val_end = max(adjusted_val_end, end_frame)
            val_segments.append(shifted(segment, adjusted_train_end))
        else:
            # Test set, with frames relative to the start of that set.
            test_segments.append(shifted(segment, adjusted_val_end))

    # Create new JSON entries for each set
    base_name = os.path.splitext(entry['filename'])[0]

    train_json = [{
        "filename": f"{base_name}_train.mp4",
        "segments": train_segments
    }]

    val_json = [{
        "filename": f"{base_name}_val.mp4",
        "segments": val_segments
    }]

    test_json = [{
        "filename": f"{base_name}_test.mp4",
        "segments": test_segments
    }]

    # Calculate actual frame counts for each set
    train_frames = adjusted_train_end
    val_frames = adjusted_val_end - adjusted_train_end
    test_frames = total_frames - adjusted_val_end

    def percent(frames):
        # An empty/unprobed video has no meaningful share; avoid dividing by 0.
        return frames / total_frames * 100 if total_frames > 0 else 0.0

    train_pct = percent(train_frames)
    val_pct = percent(val_frames)
    test_pct = percent(test_frames)

    display_progress("Actual split (adjusted to preserve segments):")
    display_progress(f"  Training:    {train_frames} frames ({train_pct:.1f}%)")
    display_progress(f"  Validation:  {val_frames} frames ({val_pct:.1f}%)")
    display_progress(f"  Testing:     {test_frames} frames ({test_pct:.1f}%)")

    return (train_json, val_json, test_json), (adjusted_train_end, adjusted_val_end, total_frames)

def partition_data(input_video, input_json, output_dir='.', ratio=(0.8, 0.1, 0.1)):
    """Partition a video and its JSON segment log into train/val/test sets.

    The video is probed for frame count, duration and frame rate; frame
    boundaries are derived from ``ratio`` and then adjusted by
    ``split_json_by_frames`` so that no segment is cut. Three JSON files and
    three video files are written to ``output_dir``.

    Args:
        input_video: Path of the source video.
        input_json: Path of the JSON log describing the video's segments.
        output_dir: Directory for the generated files (created if missing).
        ratio: ``(train, val, test)`` fractions of the total frame count.

    Returns:
        Dict mapping ``'train_video'``, ``'val_video'``, ``'test_video'``,
        ``'train_json'``, ``'val_json'`` and ``'test_json'`` to the output
        paths. The paths are returned even if a file could not be created; a
        message is displayed for each file reporting success or failure.

    Raises:
        ValueError: If the probed frame count or frame rate is not positive,
            since no meaningful split can be computed in that case.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get video information
    display_progress("Analyzing video...")
    total_frames = get_frame_count(input_video)
    duration = get_video_duration(input_video)
    fps = get_video_fps(input_video)

    display_progress(f"Video info: {total_frames} frames, {duration:.2f} seconds, {fps:.2f} fps")

    # Fail early with a clear message instead of a ZeroDivisionError later on,
    # after some output files have already been written.
    if total_frames <= 0 or fps <= 0:
        raise ValueError(
            f"Cannot partition '{input_video}': frame count ({total_frames}) "
            f"and fps ({fps}) must both be positive"
        )

    # Calculate frame boundaries based on ratio
    train_ratio, val_ratio, test_ratio = ratio
    train_end_frame = math.floor(total_frames * train_ratio)
    val_end_frame = train_end_frame + math.floor(total_frames * val_ratio)

    display_progress("Initial split:")
    display_progress(f"  Training:    0 to {train_end_frame} ({train_ratio*100:.1f}%)")
    display_progress(f"  Validation:  {train_end_frame+1} to {val_end_frame} ({val_ratio*100:.1f}%)")
    display_progress(f"  Testing:     {val_end_frame+1} to {total_frames} ({test_ratio*100:.1f}%)")

    # Read JSON data
    display_progress("Reading JSON data...")
    with open(input_json, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Split JSON data
    display_progress("Splitting JSON data...")
    (train_json, val_json, test_json), (adj_train_end, adj_val_end, _) = split_json_by_frames(
        json_data, train_end_frame, val_end_frame, total_frames
    )

    # Output locations
    json_train_path = os.path.join(output_dir, 'train_set.json')
    json_val_path = os.path.join(output_dir, 'val_set.json')
    json_test_path = os.path.join(output_dir, 'test_set.json')

    video_train_path = os.path.join(output_dir, 'train_set.mp4')
    video_val_path = os.path.join(output_dir, 'val_set.mp4')
    video_test_path = os.path.join(output_dir, 'test_set.mp4')

    # Write JSON files
    outputs_by_set = (
        (train_json, json_train_path, video_train_path),
        (val_json, json_val_path, video_val_path),
        (test_json, json_test_path, video_test_path),
    )
    for set_json, json_path, video_path in outputs_by_set:
        # Make each log name the video file that is actually written next to
        # it, rather than a derived name that is never created.
        for entry in set_json:
            entry['filename'] = os.path.basename(video_path)
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(set_json, f, indent=2)

    display_progress("Created JSON files:")
    display_progress(f"  {json_train_path}")
    display_progress(f"  {json_val_path}")
    display_progress(f"  {json_test_path}")

    # Calculate time boundaries for video splitting
    train_end_time = adj_train_end / fps
    val_end_time = adj_val_end / fps
    # If the duration probe failed (0) or is shorter than the frames imply,
    # use the frame-based length so the test range is not negative; ffmpeg
    # stops at end of input when the requested length is too long.
    end_time = max(duration, total_frames / fps)

    # Extract video segments
    display_progress("Extracting video segments...")

    # Training set (from start to train_end_time)
    display_progress(f"  Extracting training set: 0 to {train_end_time:.2f}s", 'blue')
    extract_video_segment(input_video, video_train_path, 0, train_end_time)

    # Validation set (from train_end_time to val_end_time)
    display_progress(f"  Extracting validation set: {train_end_time:.2f}s to {val_end_time:.2f}s", 'blue')
    extract_video_segment(input_video, video_val_path, train_end_time, val_end_time - train_end_time)

    # Test set (from val_end_time to end)
    display_progress(f"  Extracting test set: {val_end_time:.2f}s to end", 'blue')
    extract_video_segment(input_video, video_test_path, val_end_time, end_time - val_end_time)

    # Verify the output files
    outputs = [
        video_train_path, video_val_path, video_test_path,
        json_train_path, json_val_path, json_test_path
    ]

    for output in outputs:
        if os.path.exists(output):
            size = os.path.getsize(output) / (1024 * 1024)  # Size in MB
            display_progress(f"Created {output} ({size:.2f} MB)", 'green')
        else:
            display_progress(f"Failed to create {output}", 'red')

    display_progress("Partitioning complete!", 'green')
    return {
        'train_video': video_train_path,
        'val_video': video_val_path,
        'test_video': video_test_path,
        'train_json': json_train_path,
        'val_json': json_val_path,
        'test_json': json_test_path
    }

# Run the partitioning only when both inputs are present; otherwise report
# the first missing one (the video is checked before the JSON log).
if os.path.exists(INPUT_VIDEO) and os.path.exists(INPUT_JSON):
    results = partition_data(INPUT_VIDEO, INPUT_JSON, OUTPUT_DIR, RATIO)
elif not os.path.exists(INPUT_VIDEO):
    display_progress(f"Error: Input video file '{INPUT_VIDEO}' not found", 'red')
else:
    display_progress(f"Error: Input JSON file '{INPUT_JSON}' not found", 'red')