# This is a sample Jupyter Notebook

Below is an example of a code cell. 
Put your cursor into the cell and press Shift+Enter to execute it and select the next one, or click 'Run Cell' button.

Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.

To learn more about Jupyter Notebooks in PyCharm, see [help](https://www.jetbrains.com/help/pycharm/ipython-notebook-support.html).
For an overview of PyCharm, go to Help -> Learn IDE features or refer to [our documentation](https://www.jetbrains.com/help/pycharm/getting-started.html).

In [2]:
# Added/modified parts are marked with comments
import os
import cv2
import torch
import time
import subprocess
from datetime import datetime, timedelta
from collections import deque
from ultralytics import YOLO
from shutil import move

# Configuration
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_SIZE = 640
SEGMENT_DURATION = 600
PROGRESS_INTERVAL = 45
BBOX_PADDING = 0.15
FRAME_SKIP = 3
MIN_SEGMENT_LENGTH = 300
MIN_CONFIDENCE = 0.6  # Added confidence threshold
HISTORY_BUFFER_SIZE = 15  # Frames to remember for position consensus

class VideoProcessor:
    def __init__(self):
        self.model = YOLO('referee17february.pt').to(DEVICE)
        self.model.fuse()
        if DEVICE == 'cuda':
            self.model.half()
            torch.backends.cudnn.benchmark = True
        self.class_id = self._get_referee_class_id()
        self.detection_history = deque(maxlen=HISTORY_BUFFER_SIZE)  # Track recent positions

    def _get_referee_class_id(self):
        """Get class ID for 'referee' from model metadata"""
        return list(self.model.names.keys())[
            list(self.model.names.values()).index('referee')
        ]

    def process_videos(self, input_dir='dataVideo', output_dir='forLabel', used_dir='used'):
        """Main processing pipeline for all videos in input directory"""
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(used_dir, exist_ok=True)

        for video_file in [f for f in os.listdir(input_dir) if f.endswith('.mp4')]:
            video_path = os.path.join(input_dir, video_file)
            print(f"\n🚀 Starting processing: {video_file}")
            try:
                self._process_single_video(video_path, output_dir, used_dir)
                print(f"\n✅ Successfully processed: {video_file}")
            except Exception as e:
                print(f"\n❌ Error processing {video_file}: {str(e)}")

    def _process_single_video(self, video_path, output_dir, used_dir):
        """Process individual video file with split/processing logic"""
        duration = self._get_video_duration(video_path)
        
        if duration > 3600:  # Split videos longer than 1 hour
            print("⏳ Splitting into 1-hour chunks...")
            chunks = self._split_into_hourly_chunks(video_path, output_dir)
            print(f"📦 Created {len(chunks)} temporary chunks")
            
            # Process each chunk separately
            for i, chunk in enumerate(chunks, 1):
                print(f"\n🔧 Processing chunk {i}/{len(chunks)}")
                self._process_video_chunk(chunk, output_dir)
                os.remove(chunk)
                print(f"🧹 Cleaned temporary chunk {i}")
            
            self._safe_move(video_path, used_dir)
            print(f"\n📦 Moved original to: {used_dir}")
            return

        # Direct processing for short videos
        print("🔍 Processing single video chunk...")
        self._process_video_chunk(video_path, output_dir)
        self._safe_move(video_path, used_dir)
        print(f"\n📦 Moved original to: {used_dir}")

    def _process_video_chunk(self, video_path, output_dir):
        """Core video processing with segmentation and tracking"""
        cap = cv2.VideoCapture(video_path)
        orig_fps = cap.get(cv2.CAP_PROP_FPS)
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Processing state
        writer = None
        last_valid = None
        segment_num = 1
        segment_start_time = 0
        last_progress_update = time.time()
        detections = 0
        frame_counter = 0

        try:
            while cap.isOpened():
                # Frame skipping for performance
                for _ in range(FRAME_SKIP):
                    if not cap.grab():
                        break
                
                ret, frame = cap.retrieve()
                if not ret:
                    break

                current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000
                frame_counter += 1

                # Process frame with detection/tracking
                processed_frame, detected = self._process_frame(frame, last_valid)
                
                if processed_frame is not None:
                    # Initialize or update video writer
                    if writer is None or (current_time - segment_start_time >= SEGMENT_DURATION):
                        if writer:
                            writer.release()
                            print(f"💾 Saved segment {segment_num} ({timedelta(seconds=int(current_time - segment_start_time))})")
                        segment_num += 1
                        writer = self._create_writer(base_name, timestamp, output_dir, segment_num, orig_fps)
                        segment_start_time = current_time
                        print(f"🎬 Started segment {segment_num} at {timedelta(seconds=int(segment_start_time))}")

                    writer.write(processed_frame)
                    last_valid = processed_frame
                    if detected:
                        detections += 1

                # Progress reporting
                if time.time() - last_progress_update >= PROGRESS_INTERVAL:
                    elapsed = timedelta(seconds=int(current_time))
                    pct = (current_time / self._get_video_duration(video_path)) * 100
                    print(
                        f"\n⏱️ Progress [{elapsed}] "
                        f"{pct:.1f}% complete | "
                        f"Segments: {segment_num} | "
                        f"Recent detections: {detections}"
                    )
                    last_progress_update = time.time()
                    detections = 0

        finally:
            # Final segment handling
            if writer:
                final_duration = current_time - segment_start_time
                if final_duration >= MIN_SEGMENT_LENGTH:
                    writer.release()
                    print(f"💾 Saved final segment {segment_num} ({timedelta(seconds=int(final_duration))})")
                else:
                    writer.release()
                    os.remove(writer.filename)
                    print(f"🧹 Removed short segment (<{MIN_SEGMENT_LENGTH}s)")
            cap.release()

    def _process_frame(self, frame, last_valid):
        """Process frame with detection consensus and quality preservation"""
        original_frame = frame.copy()
        
        # Get detection with highest confidence that meets threshold
        resized_frame = cv2.resize(frame, (MODEL_SIZE, MODEL_SIZE), interpolation=cv2.INTER_LINEAR)
        tensor = torch.from_numpy(resized_frame).to(DEVICE)
        tensor = tensor.permute(2, 0, 1).float() / 255.0
        if DEVICE == 'cuda':
            tensor = tensor.half()
        tensor = tensor.unsqueeze(0)

        with torch.no_grad():
            results = self.model(tensor, verbose=False)[0]

        best_bbox = None
        if len(results.boxes) > 0:
            # Filter by confidence and get best detection
            valid_detections = [box for box in results.boxes.data if box[4] >= MIN_CONFIDENCE]
            if valid_detections:
                best_detection = max(valid_detections, key=lambda x: x[4])
                best_bbox = best_detection[:4].cpu().numpy()
                self.detection_history.append(best_bbox)  # Add to history buffer

        # Use most frequent recent detection for stability
        current_bbox = self._get_consensus_bbox()
        
        if current_bbox is not None:
            # Convert coordinates to original frame scale
            x_scale = frame.shape[1] / MODEL_SIZE
            y_scale = frame.shape[0] / MODEL_SIZE
            
            # Apply padding relative to detection size
            bbox_width = current_bbox[2] - current_bbox[0]
            bbox_height = current_bbox[3] - current_bbox[1]
            pad_x = int(bbox_width * BBOX_PADDING)
            pad_y = int(bbox_height * BBOX_PADDING)
            
            # Calculate coordinates with boundary checks
            x1 = max(0, int(current_bbox[0] * x_scale) - pad_x)
            y1 = max(0, int(current_bbox[1] * y_scale) - pad_y)
            x2 = min(frame.shape[1], int(current_bbox[2] * x_scale) + pad_x)
            y2 = min(frame.shape[0], int(current_bbox[3] * y_scale) + pad_y)
            
            # Crop and resize using original frame data for better quality
            cropped = original_frame[y1:y2, x1:x2]
            return cv2.resize(cropped, (MODEL_SIZE, MODEL_SIZE), interpolation=cv2.INTER_LINEAR), True
        
        # Fallback to last valid detection if available
        return last_valid, False

    def _get_consensus_bbox(self):
        """Get most common recent bbox from history buffer"""
        if not self.detection_history:
            return None
            
        # Find most frequent bbox position using histogram
        bbox_counts = {}
        for bbox in self.detection_history:
            key = tuple(map(int, bbox))
            bbox_counts[key] = bbox_counts.get(key, 0) + 1
        
        if bbox_counts:
            return max(bbox_counts.items(), key=lambda x: x[1])[0]
        return None


    def _padded_coords(self, coords, shape):
        """Calculate padded coordinates with boundary checks"""
        h, w = shape[:2]
        dx = int((coords[2]-coords[0])*BBOX_PADDING)
        dy = int((coords[3]-coords[1])*BBOX_PADDING)
        return (
            max(0, int(coords[0])-dx),
            max(0, int(coords[1])-dy),
            min(w, int(coords[2])+dx),
            min(h, int(coords[3])+dy)
        )

    def _create_writer(self, base_name, timestamp, output_dir, segment_num, fps):
        """Create video writer with standardized naming"""
        output_path = os.path.join(
            output_dir,
            f"{base_name}_{timestamp}_part{segment_num:03d}.mp4"
        )
        return cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (MODEL_SIZE, MODEL_SIZE)
        )

    def _split_into_hourly_chunks(self, video_path, output_dir):
        """Split long videos using FFmpeg"""
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_pattern = os.path.join(output_dir, f"{base_name}_{timestamp}_temp_part%03d.mp4")
        
        try:
            subprocess.run([
                'ffmpeg', '-i', video_path,
                '-c', 'copy',
                '-map', '0',
                '-segment_time', '01:00:00',
                '-f', 'segment',
                '-reset_timestamps', '1',
                '-loglevel', 'error',
                output_pattern
            ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except subprocess.CalledProcessError as e:
            print(f"❌ FFmpeg error: {e.stderr.decode()}")
            raise

        return sorted([
            os.path.join(output_dir, f) 
            for f in os.listdir(output_dir) 
            if f.startswith(f"{base_name}_{timestamp}_temp")
        ])

    def _get_video_duration(self, video_path):
        """Get video duration in seconds"""
        cap = cv2.VideoCapture(video_path)
        frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()
        return frames / fps if fps else 0

    def _safe_move(self, src, dest_dir):
        """Atomic file move with retry logic"""
        dest = os.path.join(dest_dir, os.path.basename(src))
        for attempt in range(5):
            try:
                if os.path.exists(dest):
                    os.remove(dest)
                move(src, dest)
                print(f"📤 Moved {os.path.basename(src)} successfully")
                return
            except Exception as e:
                if attempt == 4:
                    print(f"❌ Failed to move {os.path.basename(src)}: {str(e)}")
                time.sleep(2 ** attempt)

if __name__ == "__main__":
    processor = VideoProcessor()
    print("🚀 Starting video processing pipeline...")
    processor.process_videos()
    print("\n🎉 All processing completed successfully!")


YOLO11x summary (fused): 464 layers, 56,828,179 parameters, 0 gradients, 194.4 GFLOPs
🚀 Starting video processing pipeline...

🚀 Starting processing: videoSplitted.mp4
🔍 Processing single video chunk...
🎬 Started segment 2 at 0:00:00

⏱️ Progress [0:00:45] 3.6% complete | Segments: 2 | Recent detections: 0

⏱️ Progress [0:01:30] 7.2% complete | Segments: 2 | Recent detections: 0


KeyboardInterrupt: 