# In [None]:  (Jupyter notebook cell marker kept from the export)
import cv2
import time
import torch
import numpy as np
import os
import math
from collections import deque
from utils.datasets import letterbox
from utils.torch_utils import select_device
from models.experimental import attempt_load
from utils.plots import output_to_keypoint, plot_skeleton_kpts
from utils.general import non_max_suppression_kpt, strip_optimizer
from torchvision import transforms

class FallDetector:
    """Rule-based fall detector driven by YOLOv7 pose keypoints.

    A pose model produces 17 keypoints per detected person; `detect_fall`
    combines five geometric/temporal rules (height, vertical speed, leg angle,
    torso angle, aspect ratio) and confirms a fall once two consecutive
    evaluations each satisfy at least two rules.

    NOTE: the temporal state (previous shoulder height, velocity buffer,
    confirmation buffer) is a single track shared by every person processed
    through this instance. With several people in a frame their histories are
    mixed; use one detector per tracked person if that matters.
    """

    # Row indices (after reshaping to (-1, 3)) of the joints the rules use.
    _KEYPOINT_INDEX = {
        'nose': 0,
        'left_shoulder': 5,
        'right_shoulder': 6,
        'left_hip': 11,
        'right_hip': 12,
        'left_knee': 13,
        'right_knee': 14,
        'left_ankle': 15,
        'right_ankle': 16,
    }

    def __init__(self, poseweights='yolov7-w6-pose.pt', device='0'):
        """
        Initialize the Fall Detector with parameters as defined in the paper
        "Enhanced Fall Detection Using YOLOv7-W6-Pose for Real-Time Elderly Monitoring"

        Loads the pose model onto the selected device, creates the ``output``
        directory and resets all tracking state.

        Args:
            poseweights: Path to the YOLOv7 pose weights file
            device: Device string passed to ``select_device`` ('cpu', '0', ...)

        Key parameters:
        - LENGTH_FACTOR_ALPHA (α): Used in height condition formula (Section 3.1)
        - VELOCITY_THRESHOLD: Threshold for fall speed detection (Section 3.2)
        - LEG_ANGLE_THRESHOLD: Degrees threshold for leg angles (Section 3.2)
        - TORSO_ANGLE_THRESHOLD: Degrees threshold for torso orientation (Section 3.2)
        - ASPECT_RATIO_THRESHOLD: Width/height ratio threshold (Section 3.1)
        - CONFIDENCE_THRESHOLD: Minimum keypoint confidence for reliable detection
        """
        print(f"Initializing Fall Detector with weights: {poseweights} on device: {device}")

        # Select the appropriate device
        self.device = select_device(device)
        # Recorded for callers; inference in process_frame always runs in float32.
        self.half = self.device.type != 'cpu'

        # Load model
        self.model = attempt_load(poseweights, map_location=self.device)
        self.model.eval()

        # Create output directory if it doesn't exist
        os.makedirs('output', exist_ok=True)

        # Threshold parameters as defined in the paper
        self.LENGTH_FACTOR_ALPHA = 0.5     # α in the height condition formula
        self.VELOCITY_THRESHOLD = 1.0      # pixels per unit of the timestamp clock (px/s by default)
        self.LEG_ANGLE_THRESHOLD = 45      # degrees for leg angles
        self.TORSO_ANGLE_THRESHOLD = 50    # degrees for torso orientation
        self.ASPECT_RATIO_THRESHOLD = 0.8  # width/height ratio
        self.CONFIDENCE_THRESHOLD = 0.4    # minimum keypoint confidence
        self.TARGET_FPS = 25

        # State tracking variables
        self.prev_keypoints = None
        self.velocity_buffer = deque(maxlen=3)  # recent absolute vertical speeds
        self.fall_buffer = deque(maxlen=2)      # confirmation buffer (last two decisions)
        self.prev_frame_time = None
        self.fall_start_time = None
        self.prev_shoulder_y = None

        # Fall detection status
        self.fall_detected = False

    def calculate_euclidean_distance(self, point1, point2):
        """
        Calculate Euclidean distance between two points
        Used for the Lfactor (length factor) of the height condition.

        Args:
            point1, point2: Coordinate points (x,y)
        Returns:
            Euclidean distance between the points
        """
        return math.hypot(point1[0] - point2[0], point1[1] - point2[1])

    def calculate_angle(self, a, b, c):
        """
        Calculate the angle a-b-c (in degrees) with b as the vertex.
        Used for the leg (hip-knee-ankle) angles.

        Args:
            a, b, c: Three points where b is the vertex
        Returns:
            Angle in degrees; 180 when the points cannot be interpreted
            (missing, too short or non-numeric coordinates).
        """
        try:
            ba = np.array([a[0] - b[0], a[1] - b[1]])
            bc = np.array([c[0] - b[0], c[1] - b[1]])
            # The epsilon keeps the division finite when two points coincide.
            cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc) + 1e-6)
            return np.degrees(np.arccos(np.clip(cosine_angle, -1.0, 1.0)))
        except (TypeError, ValueError, IndexError):
            return 180  # return maximum angle if calculation fails

    def calculate_torso_angle(self, shoulders, hips):
        """
        Calculate torso angle relative to the vertical image axis.

        Args:
            shoulders: list of shoulder points [(x,y), (x,y)]
            hips: list of hip points [(x,y), (x,y)]
        Returns:
            angle in degrees between the shoulder-centre -> hip-centre vector
            and the downward vertical (0 = upright, 90 = horizontal); 90 when
            the two centres coincide.
        """
        shoulder_center = np.mean(shoulders, axis=0)
        hip_center = np.mean(hips, axis=0)
        vertical_vector = np.array([0, 1])
        torso_vector = np.array([hip_center[0] - shoulder_center[0],
                                 hip_center[1] - shoulder_center[1]])

        torso_length = np.linalg.norm(torso_vector)
        if torso_length < 1e-6:
            return 90  # neutral angle if points overlap

        cosine = np.dot(torso_vector, vertical_vector) / (torso_length + 1e-6)
        return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))

    def detect_fall(self, keypoints, timestamp=None):
        """
        Evaluate the fall rules for one person and update the tracking state.

        Args:
            keypoints: 17 keypoints as (x, y, confidence), flat or shaped (17, 3)
            timestamp: Optional time of this observation. When omitted the
                wall clock (``time.time()``) is used, as before. Supplying the
                video timestamp makes the speed rule independent of how fast
                frames happen to be processed.
        Returns:
            tuple: (is_fall, state, condition_info) where state is one of
            "normal", "falling", "fallen", "low_confidence" or "error".
        """
        try:
            # One (x, y, conf) row per keypoint; accepts arrays and plain sequences.
            points = np.reshape(keypoints, (-1, 3))
            kp = {name: points[index] for name, index in self._KEYPOINT_INDEX.items()}

            # Every joint used by the rules must be confidently located.
            if any(point[2] < self.CONFIDENCE_THRESHOLD for point in kp.values()):
                return False, "low_confidence", []

            def xy(name):
                return (kp[name][0], kp[name][1])

            ls, rs = xy('left_shoulder'), xy('right_shoulder')
            lh, rh = xy('left_hip'), xy('right_hip')
            lk, rk = xy('left_knee'), xy('right_knee')
            la, ra = xy('left_ankle'), xy('right_ankle')

            # 1. HEIGHT CONDITION
            # Length factor: distance from the left shoulder to the hip midpoint.
            torso_mid = ((lh[0] + rh[0]) / 2, (lh[1] + rh[1]) / 2)
            Lfactor = self.calculate_euclidean_distance(ls, torso_mid)

            max_feet_y = max(la[1], ra[1])
            min_shoulder_y = min(ls[1], rs[1])

            # Image y grows downward: true when the highest shoulder is within
            # α·Lfactor of (or below) the lowest foot.
            height_cond = min_shoulder_y >= (max_feet_y - self.LENGTH_FACTOR_ALPHA * Lfactor)

            # 2. VELOCITY CONDITION
            current_time = time.time() if timestamp is None else timestamp

            if self.prev_shoulder_y is not None and self.prev_frame_time is not None:
                time_elapsed = current_time - self.prev_frame_time
                if time_elapsed > 0:
                    vertical_speed = (min_shoulder_y - self.prev_shoulder_y) / time_elapsed
                    self.velocity_buffer.append(abs(vertical_speed))

            if self.velocity_buffer:
                avg_speed = sum(self.velocity_buffer) / len(self.velocity_buffer)
            else:
                avg_speed = 0
            speed_cond = avg_speed >= self.VELOCITY_THRESHOLD

            # 3. ANGLE CONDITIONS
            left_leg_angle = self.calculate_angle(lh, lk, la)
            right_leg_angle = self.calculate_angle(rh, rk, ra)
            min_leg_angle = min(left_leg_angle, right_leg_angle)
            leg_angle_cond = min_leg_angle < self.LEG_ANGLE_THRESHOLD

            torso_angle = self.calculate_torso_angle([ls, rs], [lh, rh])
            torso_cond = torso_angle > self.TORSO_ANGLE_THRESHOLD

            # 4. ASPECT RATIO CONDITION
            # Shoulder width over nose-to-lowest-foot vertical extent.
            body_width = abs(ls[0] - rs[0])
            head_to_feet = abs(kp['nose'][1] - max_feet_y)
            orientation_ratio = body_width / (head_to_feet + 1e-6)
            aspect_cond = orientation_ratio > self.ASPECT_RATIO_THRESHOLD

            # FALL DECISION LOGIC: a frame votes "fall" when at least two rules hold.
            conditions_met = sum([height_cond, speed_cond, leg_angle_cond, torso_cond, aspect_cond])

            current_state = "normal"
            conditions_info = []

            if height_cond:
                if speed_cond:  # Rapid descent
                    current_state = "falling"
                    self.fall_start_time = current_time
                    conditions_info.append(f"speed:{avg_speed:.1f}px/s")
                elif (torso_cond and self.fall_start_time is not None
                        and current_time - self.fall_start_time < 1.0):
                    # Horizontal torso shortly after a rapid descent.
                    current_state = "fallen"
                    conditions_info.append("horizontal")

            if leg_angle_cond:
                conditions_info.append(f"leg_angle:{min_leg_angle:.0f}°")

            if aspect_cond:
                conditions_info.append(f"aspect:{orientation_ratio:.2f}")

            # Confirmation: the two most recent votes must both be "fall".
            is_fall = bool(conditions_met >= 2)
            self.fall_buffer.append(is_fall)
            final_detection = sum(self.fall_buffer) >= 2

            if final_detection:
                current_state = "fallen"
            self.fall_detected = final_detection

            # Update tracking variables
            self.prev_keypoints = kp
            self.prev_shoulder_y = min_shoulder_y
            self.prev_frame_time = current_time

            # Diagnostic information
            conditions_info.extend([
                f"height:{'Y' if height_cond else 'N'}",
                f"speed:{'Y' if speed_cond else 'N'}",
                f"leg_angle:{'Y' if leg_angle_cond else 'N'}",
                f"torso:{'Y' if torso_cond else 'N'}",
                f"aspect:{'Y' if aspect_cond else 'N'}",
                f"conf:{min(p[2] for p in kp.values()):.2f}"
            ])

            return final_detection, current_state, conditions_info

        except Exception as e:
            # Best effort: a malformed detection must not abort video processing.
            print(f"Detection error: {str(e)}")
            return False, "error", [f"Error: {str(e)}"]

    @staticmethod
    def _person_bbox(detection, padding=10, min_conf=0.5):
        """
        Compute a drawing box for one detection row.

        Args:
            detection: 1-D row laid out as
                [batch_id, class_id, x_center, y_center, width, height, conf, kpts...]
            padding: Pixels added around the keypoint extent
            min_conf: Keypoints at or below this confidence are ignored
        Returns:
            ((xmin, ymin, xmax, ymax), from_keypoints) where from_keypoints is
            True when the box was derived from confident keypoints and False
            when it fell back to the detector's own box.
        """
        kpts = np.reshape(detection[7:], (-1, 3))
        confident = kpts[kpts[:, 2] > min_conf]

        if len(confident) > 0:
            xmin = max(0, confident[:, 0].min() - padding)
            ymin = max(0, confident[:, 1].min() - padding)
            xmax = confident[:, 0].max() + padding
            ymax = confident[:, 1].max() + padding
            return (xmin, ymin, xmax, ymax), True

        # Columns 2..5 hold the centre-format box; convert it to corners.
        x_center, y_center, box_w, box_h = detection[2], detection[3], detection[4], detection[5]
        half_w, half_h = box_w / 2, box_h / 2
        return (x_center - half_w, y_center - half_h,
                x_center + half_w, y_center + half_h), False

    def process_frame(self, frame, timestamp=None):
        """
        Process a single frame for fall detection

        Args:
            frame: BGR video frame to process
            timestamp: Optional time of the frame, forwarded to `detect_fall`;
                when omitted the wall clock is used.

        Returns:
            frame: Annotated letterboxed frame (BGR); its size may differ
                from the input frame
            is_fall: Boolean indicating whether a fall was detected
            state: State of the last person flagged as fallen, else "normal"
            condition_info: Conditions reported for that person
        """
        # Preprocess: BGR -> RGB, letterbox to the frame width, to a 1xCxHxW tensor
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_width = frame.shape[1]
        image = letterbox(rgb, (frame_width), stride=64, auto=True)[0]

        image = transforms.ToTensor()(image).unsqueeze(0)
        image = image.to(self.device).float()

        # Inference
        with torch.no_grad():
            output, _ = self.model(image)

        # Post-process
        output = non_max_suppression_kpt(output, 0.25, 0.65, nc=self.model.yaml['nc'],
                                         nkpt=self.model.yaml['nkpt'], kpt_label=True)
        output = output_to_keypoint(output)

        # Convert back to BGR for display
        img = image[0].permute(1, 2, 0) * 255
        img = img.cpu().numpy().astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Frame-level result, overwritten by any person flagged as fallen
        is_fall = False
        current_state = "normal"
        condition_info = []

        # Process each person detected (an empty result yields no iterations)
        for detection in output:
            key_points = detection[7:]

            # Draw skeleton and keypoints
            plot_skeleton_kpts(img, key_points.T, 3)

            (xmin, ymin, xmax, ymax), from_keypoints = self._person_bbox(detection)

            if from_keypoints:
                # For debugging: show the box aspect ratio (not used in detection)
                width = xmax - xmin
                height = ymax - ymin
                bbox_aspect_ratio = width / height if height > 0 else 0
                cv2.putText(img, f"Ratio: {bbox_aspect_ratio:.2f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            cx = int((xmin + xmax) // 2)
            cy = int((ymin + ymax) // 2)
            top_left = (int(xmin), int(ymin))
            bottom_right = (int(xmax), int(ymax))

            # Detect fall for this person
            person_fall, person_state, person_conditions = self.detect_fall(key_points, timestamp)

            if person_fall:
                # If any person is falling, set global fall status
                is_fall = True
                current_state = person_state
                condition_info = person_conditions

                # Add visual indication of fall
                status_text = f"FALL DETECTED: {person_state.upper()}"
                cv2.putText(img, status_text, (50, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                # Red bounding box plus a filled marker at the box centre
                cv2.rectangle(img, top_left, bottom_right, (0, 0, 255), 2)
                cv2.rectangle(img, (cx - 10, cy - 10), (cx + 10, cy + 10), (84, 61, 247), -1)

                # Add condition info to the frame (first 3 conditions only)
                for i, cond in enumerate(person_conditions[:3]):
                    cv2.putText(img, cond, (10, 60 + i * 25),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 1)
            else:
                # Green bounding box and the current state for no fall
                cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 1)
                cv2.putText(img, f"State: {person_state}", (10, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

        return img, is_fall, current_state, condition_info

# Annotation labels this script counts as "fall" frames.
_LE2I_FALL_LABELS = frozenset({1, 7, 8})
# Frame rate used for the output file when the capture reports none (e.g. webcams).
_FALLBACK_WRITER_FPS = 25.0


def _classification_metrics(true_positives, true_negatives, false_positives, false_negatives):
    """Return (accuracy, precision, recall, f1), using 0 for any undefined ratio."""
    total = true_positives + true_negatives + false_positives + false_negatives
    predicted_positive = true_positives + false_positives
    actual_positive = true_positives + false_negatives

    accuracy = (true_positives + true_negatives) / total if total > 0 else 0
    precision = true_positives / predicted_positive if predicted_positive > 0 else 0
    recall = true_positives / actual_positive if actual_positive > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return accuracy, precision, recall, f1_score


def _load_le2i_fall_frames(annotation_path):
    """Parse an annotation file and return the set of frame numbers labelled as falls.

    Each usable line is "frame,label,..." (comma- or whitespace-separated);
    lines that do not start with two integers (e.g. header lines) are skipped.
    """
    fall_frames = set()
    with open(annotation_path, 'r') as f:
        for line in f:
            parts = line.strip().split(',') if ',' in line else line.split()
            if len(parts) < 2:
                continue
            try:
                frame_num = int(parts[0])
                label = int(parts[1])
            except ValueError:
                continue
            if label in _LE2I_FALL_LABELS:
                fall_frames.add(frame_num)
    return fall_frames


def _prepare_le2i_output_dir(input_path):
    """Create and return the results folder for a Le2i video.

    The folder is output/le2i_results/<next path component>_<Fall|Non Fall>
    when the path contains a "Fall" or "Non Fall" component followed by
    another component, otherwise output/le2i_results itself.
    """
    base_dir = os.path.join('output', 'le2i_results')
    os.makedirs(base_dir, exist_ok=True)

    path_parts = input_path.split(os.sep)
    for i, part in enumerate(path_parts):
        if part in ("Fall", "Non Fall"):
            if i + 1 < len(path_parts):
                env_fall_dir = os.path.join(base_dir, f"{path_parts[i + 1]}_{part}")
                try:
                    os.makedirs(env_fall_dir, exist_ok=True)
                except OSError:
                    return base_dir
                return env_fall_dir
            break
    return base_dir


def run_fall_detection(poseweights='yolov7-w6-pose.pt', source='pose.mp4', device='cpu', display=True, save_output=True,
                       save_false_detections=True, detector=None, batch_mode=False):
    """
    Run fall detection on a video or webcam feed

    Args:
        poseweights: Path to the YOLOv7 pose weights
        source: Path to video file, or webcam ID given as an int or numeric string
        device: Device to run inference on ('cpu' or '0', '1', etc. for GPU)
        display: Whether to show video with detections in real-time
        save_output: Whether to save the output video
        save_false_detections: Whether to save false detections to a file
        detector: Optional pre-initialized FallDetector instance (for batch processing)
        batch_mode: Whether this is being run as part of a batch process (suppresses some output)

    Returns:
        dict with 'frames_processed', 'average_fps', 'false_detections',
        'false_positives', 'false_negatives', 'true_positives',
        'true_negatives', 'accuracy', 'precision', 'recall', 'f1_score',
        'output_path' and 'false_detections_path'. The same keys (zeros / None)
        are returned when the source cannot be opened.
    """
    # Initialize the fall detector or use the provided one
    if detector is None:
        detector = FallDetector(poseweights=poseweights, device=device)

    # Parse the input source: webcam IDs may arrive as int or as a numeric string
    if isinstance(source, int):
        input_path = source
    elif isinstance(source, str):
        input_path = int(source) if source.isnumeric() else source
    elif isinstance(source, os.PathLike):
        input_path = os.fspath(source)
    else:
        input_path = source

    # Open video capture
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"Error: Could not open video source {source}")
        cap.release()
        return {
            'frames_processed': 0,
            'average_fps': 0,
            'false_detections': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'true_positives': 0,
            'true_negatives': 0,
            'accuracy': 0,
            'precision': 0,
            'recall': 0,
            'f1_score': 0,
            'output_path': None,
            'false_detections_path': None
        }

    # Get video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Le2i dataset videos are recognised by their path and evaluated against
    # the annotation file that mirrors the video path.
    is_le2i = False
    env_fall_dir = None
    ground_truth_fall_frames = set()

    if isinstance(input_path, str) and "Le2i_Sorted" in input_path and "Videos" in input_path:
        is_le2i = True
        env_fall_dir = _prepare_le2i_output_dir(input_path)

        # Annotation path: replace Videos with Annotation_files and change the extension
        annotation_path = input_path.replace("Videos", "Annotation_files").rsplit(".", 1)[0] + ".txt"

        if os.path.exists(annotation_path):
            try:
                ground_truth_fall_frames = _load_le2i_fall_frames(annotation_path)
                if not batch_mode:
                    print(f"Loaded {len(ground_truth_fall_frames)} annotated fall frames from {annotation_path}")
            except (OSError, ValueError) as e:
                print(f"Error loading annotation file: {str(e)}")
        else:
            print(f"Warning: Annotation file not found at {annotation_path}")
            is_le2i = False

    # Setup output video writer if requested
    out = None
    output_path = None
    if save_output:
        if isinstance(input_path, int):
            # For webcam
            output_path = os.path.join('output', "webcam_fall_detection.mp4")
        else:
            # splitext keeps dots inside the name, so "a.1.mp4" and "a.2.mp4" do not collide
            video_name = os.path.splitext(os.path.basename(input_path))[0]
            target_dir = env_fall_dir if is_le2i else 'output'
            output_path = os.path.join(target_dir, f"{video_name}_fall_detection.mp4")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        # Captures may report 0 (or NaN) FPS; a writer needs a positive rate
        writer_fps = fps if fps and fps > 0 else _FALLBACK_WRITER_FPS
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, writer_fps, (frame_width, frame_height))
        if not batch_mode:
            print(f"Output will be saved to: {output_path}")

    # Process video frames
    frame_count = 0
    total_fps = 0

    # For performance tracking
    false_positives = 0
    false_negatives = 0
    true_positives = 0
    true_negatives = 0

    # For false detection tracking
    false_detections = []

    if not batch_mode:
        print(f"Starting fall detection on {os.path.basename(input_path) if isinstance(input_path, str) else 'webcam'}...")
        print(f"Total frames: {total_frames}")

    # Timeout mechanism to prevent hanging on a single video
    start_time = time.time()
    max_process_time = 300  # 5 minutes max per video

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1

            # Print progress periodically if not in batch mode
            if not batch_mode and frame_count % 10 == 0:
                print(f"Processing frame {frame_count}/{total_frames if total_frames > 0 else 'unknown'}")

            # Check if we've been processing too long
            if time.time() - start_time > max_process_time:
                print(f"Warning: Processing time limit reached ({max_process_time}s). Stopping early.")
                break

            # Process frame for fall detection; a failing frame is reported and skipped
            try:
                frame_start_time = time.time()
                processed_frame, is_fall, current_state, condition_info = detector.process_frame(frame)
                elapsed = time.time() - frame_start_time

                # Guard against a zero interval on coarse clocks
                processing_fps = 1 / elapsed if elapsed > 0 else 0.0
                total_fps += processing_fps

                # Resize processed frame to match original dimensions for display and saving
                processed_frame_resized = cv2.resize(processed_frame, (frame_width, frame_height))

                # Add FPS info and frame count
                cv2.putText(processed_frame_resized, f"FPS: {processing_fps:.2f}", (frame_width - 150, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                cv2.putText(processed_frame_resized, f"Frame: {frame_count}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

                # Compare with ground truth if we have annotation data
                if is_le2i and ground_truth_fall_frames:
                    is_ground_truth_fall = frame_count in ground_truth_fall_frames

                    if is_fall and is_ground_truth_fall:
                        true_positives += 1
                        outcome_text = "TRUE POSITIVE"
                        color = (0, 255, 0)  # Green for TP
                    elif is_fall and not is_ground_truth_fall:
                        false_positives += 1
                        outcome_text = "FALSE POSITIVE"
                        color = (0, 0, 255)  # Red for FP
                    elif not is_fall and is_ground_truth_fall:
                        false_negatives += 1
                        outcome_text = "FALSE NEGATIVE"
                        color = (255, 0, 0)  # Blue for FN
                    else:
                        true_negatives += 1
                        outcome_text = "TRUE NEGATIVE"
                        color = (255, 255, 0)  # Cyan for TN

                    # Add outcome label to frame
                    cv2.putText(processed_frame_resized, outcome_text, (frame_width - 250, 60),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

                    # Record every mismatch between detection and ground truth
                    if is_fall != is_ground_truth_fall:
                        false_detections.append({
                            'frame': frame_count,
                            'predicted': 'Fall' if is_fall else 'No Fall',
                            'actual': 'Fall' if is_ground_truth_fall else 'No Fall',
                            'type': 'False Positive' if is_fall and not is_ground_truth_fall else 'False Negative'
                        })

                # Display the frame if requested
                if display:
                    cv2.imshow('Fall Detection', processed_frame_resized)

                    # Exit on 'q' press
                    key = cv2.waitKey(1) & 0xFF
                    if key == ord('q'):
                        break
                    elif key == ord('n') and batch_mode:
                        # In batch mode, allow 'n' to skip to the next video
                        print("Skipping to next video...")
                        break

                # Save frame to output video if requested
                if out is not None:
                    out.write(processed_frame_resized)

            except Exception as e:
                print(f"Error processing frame {frame_count}: {str(e)}")
                continue
    finally:
        # Release resources even when processing is interrupted
        cap.release()
        if out is not None:
            out.release()
        if display:
            # Only touch the GUI when a window was requested (GUI-less builds may lack it)
            cv2.destroyAllWindows()

    # Summary statistics
    avg_fps = total_fps / frame_count if frame_count > 0 else 0
    accuracy, precision, recall, f1_score = _classification_metrics(
        true_positives, true_negatives, false_positives, false_negatives)

    # Print statistics if not in batch mode
    if frame_count > 0 and not batch_mode:
        print(f"Processed {frame_count} frames")
        print(f"Average FPS: {avg_fps:.2f}")
        if save_output:
            print(f"Output saved to: {output_path}")

        # Print detection metrics
        if is_le2i:
            print("\nDetection Metrics:")
            print(f"True Positives: {true_positives}")
            print(f"True Negatives: {true_negatives}")
            print(f"False Positives: {false_positives}")
            print(f"False Negatives: {false_negatives}")

            print(f"Accuracy: {accuracy*100:.2f}%")
            print(f"Precision: {precision*100:.2f}%")
            print(f"Recall: {recall*100:.2f}%")
            print(f"F1 Score: {f1_score*100:.2f}%")

    # Save false detections to file if we processed Le2i data
    false_detections_path = None
    if is_le2i and save_false_detections and (false_positives > 0 or false_negatives > 0):
        false_detections_path = os.path.join(env_fall_dir, 'false_detections.txt')

        try:
            with open(false_detections_path, 'a') as f:
                f.write(f"\n--- False Detections for {os.path.basename(input_path)} ---\n")
                for detection in false_detections:
                    f.write(f"Frame {detection['frame']}: {detection['type']} - Predicted: {detection['predicted']}, Actual: {detection['actual']}\n")
            if not batch_mode:
                print(f"False detections saved to: {false_detections_path}")
        except OSError as e:
            print(f"Error saving false detections: {str(e)}")

    # Return statistics
    return {
        'frames_processed': frame_count,
        'average_fps': avg_fps,
        'false_detections': len(false_detections) if is_le2i else 0,
        'false_positives': false_positives,
        'false_negatives': false_negatives,
        'true_positives': true_positives,
        'true_negatives': true_negatives,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score,
        'output_path': output_path if save_output else None,
        'false_detections_path': false_detections_path
    }

# Environment sub-folders looked up under Le2i_Sorted/<Fall|Non Fall>/.
_LE2I_ENV_FOLDERS = ["Coffee_room_01", "Coffee_room_02", "Home_01", "Home_02", "Lecture_room", "Office"]

# Class folders of the sorted Le2i layout.
_LE2I_FALL_FOLDERS = ['Fall', 'Non Fall']

# File suffixes treated as videos when listing a "Videos" folder.
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.MP4', '.AVI', '.MOV', '.MKV')


def _prompt_yes_no(question, default="y"):
    """Ask a y/n question on stdin and return True only when the answer is 'y'.

    An empty answer falls back to ``default``. Surrounding whitespace and
    letter case are ignored, so "Y " counts as yes.
    """
    answer = input(question).strip().lower() or default
    return answer == "y"


def _yes_no_label(flag):
    """Return the 'Yes'/'No' label used in the run summaries."""
    return 'Yes' if flag else 'No'


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def _classification_metrics(true_pos, true_neg, false_pos, false_neg):
    """Return (accuracy, precision, recall, f1) for the given confusion counts.

    Every ratio degrades to 0 when its denominator is 0, so the result is
    always defined, even when nothing was processed.
    """
    accuracy = _safe_ratio(true_pos + true_neg, true_pos + true_neg + false_pos + false_neg)
    precision = _safe_ratio(true_pos, true_pos + false_pos)
    recall = _safe_ratio(true_pos, true_pos + false_neg)
    f1_score = _safe_ratio(2 * precision * recall, precision + recall)
    return accuracy, precision, recall, f1_score


def _choose_option(raw_choice, options):
    """Map a 1-based menu answer to an entry of ``options``.

    Falls back to the first entry (with a console message) when the answer
    is not an integer or is out of range.
    """
    try:
        index = int(raw_choice) - 1
    except ValueError:
        print(f"Invalid input. Using {options[0]}.")
        return options[0]
    if 0 <= index < len(options):
        return options[index]
    print(f"Invalid choice. Using {options[0]}.")
    return options[0]


def _locate_le2i_root():
    """Prompt for the dataset root and return the path of its 'le2i' folder.

    Prints an error and returns None when either the root or the 'le2i'
    sub-folder does not exist.
    """
    default_dataset_path = "datasets"
    dataset_path = input(f"Enter dataset root path [default: {default_dataset_path}]: ") or default_dataset_path

    if not os.path.exists(dataset_path):
        print(f"Error: Dataset path '{dataset_path}' does not exist.")
        return None

    le2i_path = os.path.join(dataset_path, "le2i")
    if not os.path.exists(le2i_path):
        print(f"Error: Le2i dataset not found at {le2i_path}")
        return None

    return le2i_path


def _list_video_files(videos_folder):
    """Return the names in ``videos_folder`` that end with a known video suffix."""
    return [name for name in os.listdir(videos_folder) if name.endswith(_VIDEO_EXTENSIONS)]


def _run_video_file(poseweights, device):
    """Option 1: run fall detection on a user-supplied video file."""
    default_video = "sample_video.mp4"
    source = input(f"Enter video file path [default: {default_video}]: ") or default_video
    display = _prompt_yes_no("Display video with pose estimation in real-time? (y/n) [default: y]: ")
    save_output = _prompt_yes_no("Save output video? (y/n) [default: y]: ")

    print("\nRunning fall detection with:")
    print(f"- Weights: {poseweights}")
    print(f"- Device: {device}")
    print(f"- Source: {source}")
    print(f"- Display: {_yes_no_label(display)}")
    print(f"- Save output: {_yes_no_label(save_output)}")

    if not _prompt_yes_no("\nConfirm? (y/n) [default: y]: "):
        print("Operation cancelled")
        return

    # Strip the optimizer from the weights before running, as this workflow always does.
    strip_optimizer(device, poseweights)

    run_fall_detection(
        poseweights=poseweights,
        source=source,
        device=device,
        display=display,
        save_output=save_output
    )


def _run_webcam(poseweights, device):
    """Option 2: run fall detection on a webcam stream (always displayed)."""
    source = input("Enter webcam ID [default: 0]: ") or "0"
    save_output = _prompt_yes_no("Save output video? (y/n) [default: y]: ")

    print("\nRunning fall detection with:")
    print(f"- Weights: {poseweights}")
    print(f"- Device: {device}")
    print(f"- Source: Webcam {source}")
    print("- Display: Yes")  # Always display for webcam
    print(f"- Save output: {_yes_no_label(save_output)}")

    if not _prompt_yes_no("\nConfirm? (y/n) [default: y]: "):
        print("Operation cancelled")
        return

    # Strip the optimizer from the weights before running, as this workflow always does.
    strip_optimizer(device, poseweights)

    run_fall_detection(
        poseweights=poseweights,
        source=source,
        device=device,
        display=True,  # Always display for webcam
        save_output=save_output
    )


def _select_le2i_video(le2i_sorted_path):
    """Interactively pick one video below ``le2i_sorted_path``.

    Returns the full path of the chosen video, or None (after printing an
    error) when the expected Videos folder is missing or holds no videos.
    """
    print("Choose environment folder:")
    for number, env in enumerate(_LE2I_ENV_FOLDERS, 1):
        print(f"{number}: {env}")
    env_choice = input("Enter choice [1-6, default: 1]: ") or "1"
    env_folder = _choose_option(env_choice, _LE2I_ENV_FOLDERS)

    # The answer is matched case-insensitively, so map it back to the real
    # folder name; using the raw answer (e.g. "fall") would build a path that
    # does not exist on case-sensitive filesystems.
    canonical_names = {name.lower(): name for name in _LE2I_FALL_FOLDERS}
    answer = input("Choose 'Fall' or 'Non Fall' [default: Fall]: ").strip() or "Fall"
    fall_type = canonical_names.get(answer.lower())
    if fall_type is None:
        print("Invalid choice. Using 'Fall'.")
        fall_type = "Fall"

    videos_folder = os.path.join(le2i_sorted_path, fall_type, env_folder, "Videos")
    if not os.path.exists(videos_folder):
        print(f"Error: Videos folder not found at {videos_folder}")
        return None

    video_files = _list_video_files(videos_folder)
    if not video_files:
        print(f"Error: No video files found in {videos_folder}")
        return None

    print(f"Found {len(video_files)} video files.")
    print("Choose video file:")
    for number, video in enumerate(video_files, 1):
        print(f"{number}: {video}")

    video_choice = input(f"Enter choice [1-{len(video_files)}, default: 1]: ") or "1"
    video_file = _choose_option(video_choice, video_files)

    return os.path.join(videos_folder, video_file)


def _run_le2i_single(poseweights, device):
    """Option 3: run fall detection on one video picked from the sorted Le2i dataset."""
    le2i_path = _locate_le2i_root()
    if le2i_path is None:
        return

    le2i_sorted_path = os.path.join(le2i_path, "Le2i_Sorted")
    if not os.path.exists(le2i_sorted_path):
        print("Le2i dataset with traditional structure not supported for this option.")
        return

    print("Found Le2i dataset with sorted structure (Fall/Non Fall folders)")
    source = _select_le2i_video(le2i_sorted_path)
    if source is None:
        return

    display = _prompt_yes_no("Display video with pose estimation in real-time? (y/n) [default: y]: ")
    save_output = _prompt_yes_no("Save output video? (y/n) [default: y]: ")
    save_false_detections = _prompt_yes_no("Save false detections to a file? (y/n) [default: y]: ")

    print("\nRunning fall detection with:")
    print(f"- Weights: {poseweights}")
    print(f"- Device: {device}")
    print(f"- Source: {source}")
    print(f"- Display: {_yes_no_label(display)}")
    print(f"- Save output: {_yes_no_label(save_output)}")
    print(f"- Save false detections: {_yes_no_label(save_false_detections)}")
    print("- Fall video comparison: Enabled (will check annotations)")

    if not _prompt_yes_no("\nConfirm? (y/n) [default: y]: "):
        print("Operation cancelled")
        return

    # Strip the optimizer from the weights before running, as this workflow always does.
    strip_optimizer(device, poseweights)

    results = run_fall_detection(
        poseweights=poseweights,
        source=source,
        device=device,
        display=display,
        save_output=save_output,
        save_false_detections=save_false_detections
    )

    if results:
        print("\nProcessing Summary:")
        print(f"Processed {results['frames_processed']} frames")
        print(f"Average FPS: {results['average_fps']:.2f}")
        print(f"False detections: {results['false_detections']}")

        if results['false_detections'] > 0 and save_false_detections:
            print(f"False detection details saved to: {results['false_detections_path']}")


def _format_video_summary(video_file, fall_type, results):
    """Build the one-line per-video entry written to the batch summary file."""
    parts = [
        f"  - {video_file}: {results['frames_processed']} frames, ",
        f"FPS: {results['average_fps']:.2f}",
    ]
    if 'accuracy' in results:
        parts.append(f", Accuracy: {results['accuracy']*100:.2f}%")
    if fall_type == 'Fall':
        parts.append(f", False Negatives: {results['false_negatives']}, ")
        parts.append(f"True Positives: {results['true_positives']}")
    else:
        parts.append(f", False Positives: {results['false_positives']}, ")
        parts.append(f"True Negatives: {results['true_negatives']}")
    parts.append("\n")
    return "".join(parts)


def _run_le2i_batch(poseweights, device):
    """Option 4: run fall detection on every video of the sorted Le2i dataset.

    Writes a per-video and overall summary to a timestamped file under
    output/le2i_results and prints the overall figures to the console.
    A failure on one video is logged and does not stop the batch.
    """
    le2i_path = _locate_le2i_root()
    if le2i_path is None:
        return

    le2i_sorted_path = os.path.join(le2i_path, "Le2i_Sorted")
    if not os.path.exists(le2i_sorted_path):
        print("Error: Le2i dataset with sorted structure (Fall/Non Fall folders) not found.")
        return

    print(f"Found Le2i dataset with sorted structure at {le2i_sorted_path}")

    display = _prompt_yes_no("Display videos with pose estimation in real-time? (y/n) [default: n]: ", default="n")
    save_output = _prompt_yes_no("Save output videos? (y/n) [default: y]: ")
    save_false_detections = _prompt_yes_no("Save false detections to a file? (y/n) [default: y]: ")

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    summary_path = os.path.join('output', 'le2i_results', f'batch_summary_{timestamp}.txt')
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)

    print("\nRunning batch processing of all Le2i dataset videos with:")
    print(f"- Weights: {poseweights}")
    print(f"- Device: {device}")
    print(f"- Display: {_yes_no_label(display)}")
    print(f"- Save outputs: {_yes_no_label(save_output)}")
    print(f"- Save false detections: {_yes_no_label(save_false_detections)}")
    print(f"- Summary will be saved to: {summary_path}")

    if not _prompt_yes_no("\nConfirm? (y/n) [default: y]: "):
        print("Operation cancelled")
        return

    # Strip the optimizer from the weights before running, as this workflow always does.
    strip_optimizer(device, poseweights)

    # Build the detector once and reuse it for every video.
    detector = FallDetector(poseweights=poseweights, device=device)

    total_videos = 0
    total_videos_processed = 0
    total_frames = 0
    total_false_positives = 0
    total_false_negatives = 0
    total_true_positives = 0
    total_true_negatives = 0
    total_processing_time = 0

    with open(summary_path, 'w') as summary_file:
        summary_file.write(f"Le2i Dataset Batch Processing Summary - {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        summary_file.write("=" * 80 + "\n\n")

        for fall_type in _LE2I_FALL_FOLDERS:
            summary_file.write(f"\n{fall_type} Videos:\n")
            summary_file.write("-" * 50 + "\n")

            for env_folder in _LE2I_ENV_FOLDERS:
                videos_folder = os.path.join(le2i_sorted_path, fall_type, env_folder, "Videos")
                if not os.path.exists(videos_folder):
                    continue

                video_files = _list_video_files(videos_folder)
                if not video_files:
                    continue

                total_videos += len(video_files)
                summary_file.write(f"\nEnvironment: {env_folder} - {len(video_files)} videos\n")
                summary_file.flush()  # Keep progress on disk if the run is interrupted

                for video_file in video_files:
                    source = os.path.join(videos_folder, video_file)
                    print(f"\nProcessing: {fall_type}/{env_folder}/{video_file}")

                    start_time = time.time()
                    try:
                        results = run_fall_detection(
                            detector=detector,  # Reuse the same detector
                            source=source,
                            device=device,
                            display=display,
                            save_output=save_output,
                            save_false_detections=save_false_detections,
                            batch_mode=True  # Suppress per-video metric output
                        )
                        processing_time = time.time() - start_time

                        if results['frames_processed'] > 0:
                            total_videos_processed += 1
                            total_frames += results['frames_processed']
                            total_processing_time += processing_time

                            total_false_positives += results['false_positives']
                            total_false_negatives += results['false_negatives']
                            total_true_positives += results['true_positives']
                            total_true_negatives += results['true_negatives']

                            summary_file.write(_format_video_summary(video_file, fall_type, results))
                        else:
                            print(f"Warning: No frames were processed for {video_file}. Skipping.")
                            summary_file.write(f"  - {video_file}: SKIPPED - No frames processed\n")
                    except Exception as e:
                        # Best effort: record the failure and carry on with the next video.
                        print(f"Error processing {video_file}: {str(e)}")
                        summary_file.write(f"  - {video_file}: ERROR - {str(e)}\n")
                    summary_file.flush()

        # Computed unconditionally so the console report below never hits an
        # unassigned name when no frames were processed.
        overall_accuracy, overall_precision, overall_recall, overall_f1 = _classification_metrics(
            total_true_positives, total_true_negatives, total_false_positives, total_false_negatives
        )
        average_fps = _safe_ratio(total_frames, total_processing_time)

        # The overall section is only written to the file when something was processed.
        if total_frames > 0:
            summary_file.write("\n\nOverall Summary:\n")
            summary_file.write("=" * 50 + "\n")
            summary_file.write(f"Total videos: {total_videos}\n")
            summary_file.write(f"Videos processed: {total_videos_processed}\n")
            summary_file.write(f"Total frames processed: {total_frames}\n")
            summary_file.write(f"Total processing time: {total_processing_time:.2f} seconds\n")
            summary_file.write(f"Average FPS: {average_fps:.2f}\n")
            summary_file.write(f"Total true positives: {total_true_positives}\n")
            summary_file.write(f"Total true negatives: {total_true_negatives}\n")
            summary_file.write(f"Total false positives: {total_false_positives}\n")
            summary_file.write(f"Total false negatives: {total_false_negatives}\n")
            summary_file.write(f"Overall accuracy: {overall_accuracy*100:.2f}%\n")
            summary_file.write(f"Overall precision: {overall_precision*100:.2f}%\n")
            summary_file.write(f"Overall recall: {overall_recall*100:.2f}%\n")
            summary_file.write(f"Overall F1 score: {overall_f1*100:.2f}%\n")

    print("\n" + "="*50)
    print("BATCH PROCESSING COMPLETE")
    print("="*50)
    print(f"Videos processed: {total_videos_processed}/{total_videos}")
    print(f"Total frames processed: {total_frames}")
    print(f"Total processing time: {total_processing_time:.2f} seconds")
    print(f"Average FPS: {average_fps:.2f}")
    print(f"Overall accuracy: {overall_accuracy*100:.2f}%")
    print(f"Overall precision: {overall_precision*100:.2f}%")
    print(f"Overall recall: {overall_recall*100:.2f}%")
    print(f"Overall F1 score: {overall_f1*100:.2f}%")
    print(f"\nDetailed summary saved to: {summary_path}")


def run_interactive():
    """
    Interactive function to run fall detection with user input

    Prompts on stdin for the weights file, the device (GPU id or CPU) and
    one of four input sources, then hands over to the matching workflow:

    1. a video file
    2. a webcam
    3. a single video picked from the sorted Le2i dataset
    4. every video of the sorted Le2i dataset (batch run with a summary file)

    Returns None. Invalid menu answers and missing dataset folders are
    reported on the console rather than raised.
    """
    poseweights = input("Enter path to weights file [default: yolov7-w6-pose.pt]: ") or "yolov7-w6-pose.pt"

    if _prompt_yes_no("Use GPU? (y/n) [default: y]: "):
        device = input("Enter GPU device ID [default: 0]: ") or "0"
    else:
        device = "cpu"

    print("\nSelect input source:")
    print("1: Video file")
    print("2: Webcam")
    print("3: Single video from Le2i dataset")
    print("4: Process all Le2i dataset videos")
    source_choice = input("Enter choice [1/2/3/4]: ").strip()

    handlers = {
        "1": _run_video_file,
        "2": _run_webcam,
        "3": _run_le2i_single,
        "4": _run_le2i_batch,
    }
    handler = handlers.get(source_choice)
    if handler is None:
        print("Invalid choice. Please run again and select a valid option.")
        return

    handler(poseweights, device)
    
if __name__ == "__main__":
    # Script entry point: only when this file is executed directly (not
    # imported) start the prompt-driven workflow defined in run_interactive().
    run_interactive()