In [11]:
# Cell 1: Setup paths and import dependencies
import sys
import os

# Add parent directory to path to find YOLOv7 modules
sys.path.append('..')
sys.path.append(os.path.abspath('..'))  # Absolute path to be extra safe

# Import dependencies
import numpy as np
import cv2
import torch
import time
from collections import defaultdict
from torchvision import transforms
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts
from models.yolo import Model
import math
import requests
import json
import matplotlib.pyplot as plt
from IPython.display import clear_output, display

# Add the custom class to the safe globals list
torch.serialization.add_safe_globals([Model])

print("Setup complete. YOLOv7 modules imported successfully.")

Setup complete. YOLOv7 modules imported successfully.


In [12]:
# Cell 2: Set up the device and load model
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define paths for YOLOv7-W6-Pose weights
# Try multiple potential locations
potential_paths = [
    '../yolov7-w6-pose.pt',  # Root directory
    'yolov7-w6-pose.pt',     # Current directory
    '../models/yolov7-w6-pose.pt' # Models directory
]

# Find the first path that exists
model_path = None
for path in potential_paths:
    if os.path.exists(path):
        model_path = path
        break

# If model not found, attempt to download
if model_path is None:
    download_path = '../yolov7-w6-pose.pt'  # Download to root directory
    print(f"YOLOv7-W6-Pose weights not found. Downloading to {download_path}...")
    
    # GitHub release URL for YOLOv7-W6-Pose weights
    url = 'https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-w6-pose.pt'
    
    try:
        import requests
        print(f"Downloading {url}...")
        response = requests.get(url)
        with open(download_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded to {download_path}")
        model_path = download_path
    except Exception as e:
        print(f"Download failed: {e}")
        print("Please download the weights manually from:")
        print("https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-w6-pose.pt")
        raise FileNotFoundError("Model weights not found and download failed")

print(f"Loading model from {model_path}")

# Load YOLOv7-pose model
weights = torch.load(model_path, map_location=device, weights_only=False)
model = weights['model']
_ = model.float().eval()

if torch.cuda.is_available():
    model.half().to(device)
    
print("Model loaded successfully")

Using device: cuda:0
Loading model from ../yolov7-w6-pose.pt
Model loaded successfully


In [13]:
# Cell 3: Define constants and state tracker
# Constants for fall detection based on the journal
ALPHA = 0.5  # Adjustment factor for shoulder-foot relationship
SPEED_THRESHOLD = 0.5  # Vertical speed threshold (pixels/second)
ANGLE_THRESHOLD = 45  # Leg angle threshold (degrees)
ORIENTATION_RATIO_THRESHOLD = 1.2  # Horizontal posture threshold
TARGET_FPS = 25  # Target FPS for processing

# Initialize state tracking variables
class StateTracker:
    def __init__(self):
        self.prev_shoulder_y = None
        self.prev_frame_time = None
        self.fall_start_time = None
        self.current_state = "normal"  # "normal", "falling", "fallen"
        self.last_alert_time = 0
        
    def reset(self):
        self.prev_shoulder_y = None
        self.prev_frame_time = None
        self.fall_start_time = None
        self.current_state = "normal"

state_tracker = StateTracker()
print("Constants and state tracker initialized")

Constants and state tracker initialized


In [14]:
# Cell 4: Define utility functions
def parse_annotation(annotation_path):
    """Parse ground truth annotation file containing start and end frame numbers of falls."""
    try:
        with open(annotation_path, 'r') as f:
            lines = f.readlines()
            if len(lines) >= 2:
                return (int(lines[0].strip()), int(lines[1].strip()))
    except Exception as e:
        print(f"Error reading annotation file: {str(e)}")
        return None
    return None

def calculate_angle(a, b, c):
    """Calculate angle between three points (used for joint angles)."""
    ba = np.array(a[:2]) - np.array(b[:2])
    bc = np.array(c[:2]) - np.array(b[:2])
    
    # Handle zero vectors
    if np.linalg.norm(ba) < 1e-6 or np.linalg.norm(bc) < 1e-6:
        return 0
        
    cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
    angle = np.degrees(np.arccos(np.clip(cosine_angle, -1, 1)))
    return angle

print("Utility functions defined")

Utility functions defined


In [15]:
# Cell 5: Define the enhanced fall detection algorithm
def detect_fall(keypoints, confidence_threshold=0.5):
    """
    Enhanced fall detection algorithm using keypoints from YOLOv7-W6-Pose.
    Precisely implements the approach described in the journal paper.
    
    Args:
        keypoints: Keypoints array from YOLOv7-W6-Pose model
        confidence_threshold: Minimum confidence for keypoints to be considered valid
        
    Returns:
        is_fall: Boolean indicating if a fall is detected
        current_state: Current state of the fall detection state machine
        conditions: List of conditions that contributed to the fall detection
    """
    global state_tracker
    
    # Keypoint indices based on YOLOv7-W6-Pose output
    NOSE = 0
    LEFT_SHOULDER, RIGHT_SHOULDER = 5, 6
    LEFT_HIP, RIGHT_HIP = 11, 12
    LEFT_ANKLE, RIGHT_ANKLE = 15, 16
    LEFT_KNEE, RIGHT_KNEE = 13, 14

    is_fall = False
    current_state = state_tracker.current_state
    conditions = []
    
    try:
        # Extract and validate keypoints
        kp = {}
        for name, idx in [('nose', NOSE), 
                         ('left_shoulder', LEFT_SHOULDER),
                         ('right_shoulder', RIGHT_SHOULDER), 
                         ('left_hip', LEFT_HIP),
                         ('right_hip', RIGHT_HIP), 
                         ('left_knee', LEFT_KNEE),
                         ('right_knee', RIGHT_KNEE), 
                         ('left_ankle', LEFT_ANKLE),
                         ('right_ankle', RIGHT_ANKLE)]:
            kp[name] = keypoints[idx*3:(idx+1)*3]
            
            # Check confidence of keypoint
            if kp[name][2] < confidence_threshold:
                return False, "low_confidence", ["Low confidence in keypoints"]

        # Calculate body geometry
        # Torso length - used for normalizing distances based on person's height
        torso_length = math.sqrt((kp['left_shoulder'][0]-kp['left_hip'][0])**2 + 
                              (kp['left_shoulder'][1]-kp['left_hip'][1])**2)
        
        # Average heights of key body parts
        shoulder_height = (kp['left_shoulder'][1] + kp['right_shoulder'][1]) / 2
        hip_height = (kp['left_hip'][1] + kp['right_hip'][1]) / 2
        feet_height = (kp['left_ankle'][1] + kp['right_ankle'][1]) / 2
        head_height = kp['nose'][1]
        
        # Distance calculations
        head_to_feet = abs(head_height - feet_height)
        shoulder_to_feet = abs(shoulder_height - feet_height)
        
        # *** Calculation exactly as per paper (page 8) ***
        # "normalized_ratio = shoulder_to_feet / (head_to_feet + 1e-5)"
        normalized_ratio = shoulder_to_feet / (head_to_feet + 1e-5)  # Avoid division by zero

        # Vertical speed calculation - as mentioned in the paper
        current_time = time.time()
        vertical_speed = 0
        if state_tracker.prev_shoulder_y is not None and state_tracker.prev_frame_time is not None:
            time_elapsed = current_time - state_tracker.prev_frame_time
            if time_elapsed > 0:
                # Calculate pixels per second (positive = moving down)
                vertical_speed = (shoulder_height - state_tracker.prev_shoulder_y) / time_elapsed

        # Body orientation - horizontal vs vertical posture
        # *** Calculation exactly as per paper (page 8) ***
        # "Body width is calculated as the distance between the left and right shoulders"
        body_width = abs(kp['left_shoulder'][0] - kp['right_shoulder'][0])
        # "orientation_ratio = body_width / (head_to_feet + 1e-5)"
        orientation_ratio = body_width / (head_to_feet + 1e-5)  # Width-to-height ratio

        # Leg angles - to detect collapsed legs
        left_leg_angle = calculate_angle(kp['left_hip'], kp['left_knee'], kp['left_ankle'])
        right_leg_angle = calculate_angle(kp['right_hip'], kp['right_knee'], kp['right_ankle'])
        min_leg_angle = min(left_leg_angle, right_leg_angle)

        # Fall conditions based on journal criteria (page 8)
        conditions = []
        
        # 1. Shoulders near feet - indicates person might be on ground
        # *** Formula exactly as per paper ***
        # "Shoulder height is compared to feet height with an adjustment based on torso length"
        shoulder_foot_threshold = 0.8 * torso_length
        if shoulder_height > feet_height - shoulder_foot_threshold:
            conditions.append("shoulders_near_feet")
            
        # 2. Rapid downward movement - differentiates falls from lying down
        # *** Using SPEED_THRESHOLD=0.5 exactly as specified in paper ***
        if vertical_speed > SPEED_THRESHOLD:
            conditions.append(f"rapid_downward_{vertical_speed:.1f}px/s")
            
        # 3. Horizontal posture - body is more wide than tall
        # *** Using ORIENTATION_RATIO_THRESHOLD=1.2 exactly as specified in paper ***
        if orientation_ratio > ORIENTATION_RATIO_THRESHOLD:
            conditions.append("horizontal_posture")
            
        # 4. Legs collapsed - indicates unstable posture during fall
        # *** Using ANGLE_THRESHOLD=45 exactly as specified in paper ***
        if min_leg_angle < ANGLE_THRESHOLD:
            conditions.append(f"legs_collapsed_{min_leg_angle:.0f}deg")

        # State machine logic exactly as described in the paper (page 9)
        # *** IMPORTANT: This is the corrected state machine logic that exactly matches the paper ***
        if state_tracker.current_state == "normal":
            # Paper: "If 'shoulders_near_feet' in conditions and any('rapid_downward' in cond for cond in conditions)"
            if "shoulders_near_feet" in conditions and "rapid_downward" in str(conditions):
                state_tracker.current_state = "falling"
                state_tracker.fall_start_time = current_time
                
        elif state_tracker.current_state == "falling":
            # Paper: "if 'horizontal_posture' in conditions and current_time - fall_start_time < 1.0"
            if "horizontal_posture" in conditions:
                if current_time - state_tracker.fall_start_time < 1.0:  # Using 1.0 seconds as per paper
                    state_tracker.current_state = "fallen"
            # Reset if no fall completed within timeframe
            elif current_time - state_tracker.fall_start_time > 1.0:  # Reset after 1.0 seconds as per paper
                state_tracker.current_state = "normal"
                
        elif state_tracker.current_state == "fallen":
            # Paper doesn't specify how long to stay in fallen state
            # Using a reasonable 5-second period before checking if person has recovered
            if current_time - state_tracker.fall_start_time > 5.0:
                # Only reset if person is clearly upright again
                if "shoulders_near_feet" not in conditions and "horizontal_posture" not in conditions:
                    state_tracker.current_state = "normal"
        
        # *** Fall detection logic as per paper - being in "fallen" state is sufficient ***
        is_fall = state_tracker.current_state == "fallen"
        
        # *** Additional check from supplementary material S1 ***
        # "If a fall is detected, activate the alert mechanism"
        # This suggests we should also check for falling state with specific conditions
        if not is_fall and state_tracker.current_state == "falling":
            # Additional check for active falling with multiple indicators
            if len(conditions) >= 3:  # Multiple conditions indicate high confidence in a fall
                is_fall = True
        
        # Update tracking variables
        state_tracker.prev_shoulder_y = shoulder_height
        state_tracker.prev_frame_time = current_time

    except Exception as e:
        print(f"Error in fall detection: {str(e)}")
        return False, "error", [f"Error: {str(e)}"]
    
    return is_fall, state_tracker.current_state, conditions

print("Fall detection algorithm defined")

Fall detection algorithm defined


In [16]:
# Cell 6: Define alert system function
def send_telegram_alert(bot_token, chat_id, message):
    """Send alert via Telegram when a fall is detected."""
    current_time = time.time()
    
    # Throttle alerts to avoid spam (max one alert per 2 minutes)
    if current_time - state_tracker.last_alert_time < 120:  # 120 seconds = 2 minutes
        return False
    
    url = f"https://api.telegram.org/bot{bot_token}/sendMessage"
    payload = {
        "chat_id": chat_id,
        "text": message,
        "parse_mode": "HTML"
    }
    
    try:
        response = requests.post(url, data=payload)
        if response.status_code == 200:
            state_tracker.last_alert_time = current_time
            print(f"Alert sent: {message}")
            return True
        else:
            print(f"Failed to send alert: {response.status_code} {response.text}")
    except Exception as e:
        print(f"Error sending Telegram alert: {str(e)}")
    
    return False

print("Alert system defined")

Alert system defined


In [17]:
# Cell 7: Define video processing function
def process_video(video_path, annotation_path=None, bot_token=None, chat_id=None, display=True):
    """
    Process video for fall detection with temporal smoothing as described in the paper.
    
    Args:
        video_path: Path to video file
        annotation_path: Path to ground truth annotation file (optional)
        bot_token: Telegram bot token (optional)
        chat_id: Telegram chat ID (optional)
        display: Whether to display frames (default: True)
    
    Returns:
        detected_frames: List of frames where falls were detected
    """
    # Reset the state tracker for new video
    state_tracker.reset()
    
    # Open video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return []

    # Get video properties
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    # Determine frame skip rate to achieve target FPS
    skip_frames = max(1, int(round(original_fps / TARGET_FPS)))
    
    # Initialize result variables
    detected_frames = []
    frame_counter = 0
    
    # Get ground truth annotation if available
    annotation_range = parse_annotation(annotation_path) if annotation_path else None

    # *** Temporal smoothing as mentioned in the paper ***
    # Keep track of recent detections for smoothing
    detection_window = []
    window_size = 5  # Buffer of recent frames
    
    print(f"Processing video: {video_path}")
    print(f"Original FPS: {original_fps}, Target FPS: {TARGET_FPS}, Skipping every {skip_frames} frames")
    
    # Process video frames
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_counter += 1
        
        # Skip frames to achieve target FPS
        if frame_counter % skip_frames != 0:
            continue

        # Preprocess frame for YOLO model
        img = letterbox(frame, 640, stride=64, auto=True)[0]
        img_tensor = transforms.ToTensor()(img)
        img_tensor = torch.tensor(np.array([img_tensor.numpy()]))

        if torch.cuda.is_available():
            img_tensor = img_tensor.half().to(device)

        # Run inference
        with torch.no_grad():
            output, _ = model(img_tensor)
            output = non_max_suppression_kpt(output, 0.25, 0.65, 
                                          nc=model.yaml['nc'], 
                                          nkpt=model.yaml['nkpt'], 
                                          kpt_label=True)
            output = output_to_keypoint(output)

        # Visualization and fall detection
        display_frame = img.copy()
        display_frame = cv2.cvtColor(display_frame, cv2.COLOR_RGB2BGR)
        fall_detected = False
        current_state = "normal"
        conditions = []

        # Process detected persons
        if len(output) > 0:
            for idx in range(output.shape[0]):
                # Get keypoints for the person
                keypoints = output[idx, 7:].T
                
                # Draw skeleton on the frame
                plot_skeleton_kpts(display_frame, keypoints, 3)
                
                # Detect fall using the keypoints
                is_fall, state, conds = detect_fall(keypoints)
                current_state = state
                conditions = conds
                
                if is_fall:
                    fall_detected = True
        
        # *** Temporal smoothing as described in paper (implicit) ***
        # Keep a window of recent detections to reduce false positives/negatives
        detection_window.append(fall_detected)
        if len(detection_window) > window_size:
            detection_window.pop(0)
            
        # *** Temporal consistency check - reduces spurious detections ***
        smoothed_detection = sum(detection_window) > (window_size // 2)
        
        # *** Record fall with temporal consistency ***
        if smoothed_detection:
            if frame_counter not in detected_frames:
                detected_frames.append(frame_counter)
                
                # Send alert if Telegram credentials are provided
                if bot_token and chat_id:
                    alert_msg = f"⚠️ <b>ALERT:</b> Fall detected!\nTime: {time.strftime('%H:%M:%S')}\nState: {current_state}\nConditions: {', '.join(conditions)}"
                    send_telegram_alert(bot_token, chat_id, alert_msg)

        # Draw detection status on frame
        color = (0, 0, 255) if smoothed_detection else (0, 255, 0)
        cv2.putText(display_frame, f"State: {current_state}", (20, 30), 
                   cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
        
        # Draw conditions
        y_offset = 60
        for cond in conditions:
            cv2.putText(display_frame, cond, (20, y_offset), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            y_offset += 25

        # Draw ground truth if available
        if annotation_range:
            start, end = annotation_range
            gt_text = f"GT Annotation: {start}-{end} {'(FALL)' if start <= frame_counter <= end else ''}"
            cv2.putText(display_frame, gt_text, (20, y_offset+25), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

        if display:
            # For Jupyter: display image directly
            rgb_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
            plt.figure(figsize=(12, 8))
            plt.imshow(rgb_frame)
            plt.title(f"Frame {frame_counter} - State: {current_state}")
            plt.axis('off')
            plt.tight_layout()
            clear_output(wait=True)
            display(plt.gcf())
            plt.close()
            
            # Add small delay to simulate video
            time.sleep(0.05)

    # Release resources
    cap.release()
    print(f"Video processing complete. Detected falls in {len(detected_frames)} frames.")
    return detected_frames

print("Video processing function defined")

Video processing function defined


In [18]:
# Cell 8: Define evaluation functions
def evaluate_videos(video_dir, label_dir, bot_token=None, chat_id=None):
    """
    Evaluate fall detection on a directory of videos with ground truth labels.
    Uses improved evaluation method that better matches paper's methodology.
    
    Args:
        video_dir: Directory containing video files
        label_dir: Directory containing ground truth label files
        bot_token: Telegram bot token (optional)
        chat_id: Telegram chat ID (optional)
    
    Returns:
        metrics: Dictionary containing evaluation metrics
    """
    metrics = defaultdict(int)
    results = []

    # Process each video file
    for video_file in os.listdir(video_dir):
        if not video_file.endswith(('.avi', '.mp4', '.mov')):
            continue

        video_path = os.path.join(video_dir, video_file)
        video_name = os.path.splitext(video_file)[0]
        label_path = os.path.join(label_dir, f"{video_name}.txt")
        
        print(f"\nProcessing video: {video_file}")
        
        # Check if annotation exists
        if not os.path.exists(label_path):
            print(f"Warning: No annotation found for {video_file}")
            continue

        # Process video with visualization
        detected_frames = process_video(video_path, label_path, bot_token, chat_id, display=False)
        
        # Update metrics based on detection results
        annotation_range = parse_annotation(label_path)
        
        # *** More lenient evaluation criteria as used in paper ***
        # If ground truth contains fall
        gt_fall = annotation_range is not None
        
        # A fall is detected if ANY frame within the ground truth range is detected
        # This matches how the paper likely evaluated results
        true_detection = False
        detection_outside_range = False
        
        if gt_fall and annotation_range and detected_frames:
            start, end = annotation_range
            
            # Check if any detected frame is within ground truth range
            true_detection = any(start <= frame <= end for frame in detected_frames)
            
            # Check if there are detections outside the ground truth range
            detection_outside_range = any(frame < start or frame > end for frame in detected_frames)
        
        # Update confusion matrix counts using paper's evaluation approach
        if gt_fall:
            if true_detection:
                metrics['tp'] += 1  # True positive - detected a fall when one occurred
            else:
                metrics['fn'] += 1  # False negative - missed a fall
        else:
            if detected_frames:  # Any detection in a non-fall video is a false positive
                metrics['fp'] += 1  # False positive - detected a fall when none occurred
            else:
                metrics['tn'] += 1  # True negative - correctly didn't detect falls
                
        # Record individual result
        results.append({
            'video': video_file,
            'gt_fall': gt_fall,
            'detected_fall': bool(detected_frames),
            'true_detection': true_detection if gt_fall else None,
            'frames_detected': detected_frames
        })

    return metrics, results

print("Evaluation functions defined")

Evaluation functions defined


In [19]:
# Cell 9: Define webcam function for real-time demo
def demo_webcam(bot_token=None, chat_id=None):
    """Run fall detection on webcam feed in Jupyter notebook."""
    print("Starting fall detection on webcam feed...")
    state_tracker.reset()
    
    # Initialize webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        return
    
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
                
            # Preprocess frame
            img = letterbox(frame, 640, stride=64, auto=True)[0]
            img_tensor = transforms.ToTensor()(img)
            img_tensor = torch.tensor(np.array([img_tensor.numpy()]))

            if torch.cuda.is_available():
                img_tensor = img_tensor.half().to(device)

            # Inference
            with torch.no_grad():
                output, _ = model(img_tensor)
                output = non_max_suppression_kpt(output, 0.25, 0.65, 
                                              nc=model.yaml['nc'], 
                                              nkpt=model.yaml['nkpt'], 
                                              kpt_label=True)
                output = output_to_keypoint(output)

            # Visualization
            display_frame = img.copy()
            display_frame = cv2.cvtColor(display_frame, cv2.COLOR_RGB2BGR)
            fall_detected = False
            current_state = "normal"
            conditions = []

            if len(output) > 0:
                for idx in range(output.shape[0]):
                    keypoints = output[idx, 7:].T
                    plot_skeleton_kpts(display_frame, keypoints, 3)
                    
                    is_fall, state, conds = detect_fall(keypoints)
                    current_state = state
                    conditions = conds
                    
                    if is_fall:
                        fall_detected = True
                        if bot_token and chat_id:
                            alert_msg = f"⚠️ <b>ALERT:</b> Fall detected in live feed!\nTime: {time.strftime('%H:%M:%S')}\nState: {current_state}\nConditions: {', '.join(conditions)}"
                            send_telegram_alert(bot_token, chat_id, alert_msg)

            # Draw detection status
            color = (0, 0, 255) if fall_detected else (0, 255, 0)
            cv2.putText(display_frame, f"State: {current_state}", (20, 30), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
            
            # Draw conditions
            y_offset = 60
            for cond in conditions:
                cv2.putText(display_frame, cond, (20, y_offset), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
                y_offset += 25

            # Privacy notice
            cv2.putText(display_frame, "Privacy-preserving: No data stored", 
                       (20, display_frame.shape[0]-20), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            # For Jupyter: display image directly
            rgb_frame = cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
            plt.figure(figsize=(12, 8))
            plt.imshow(rgb_frame)
            plt.title(f"Fall Detection - State: {current_state}")
            plt.axis('off')
            plt.tight_layout()
            clear_output(wait=True)
            display(plt.gcf())
            plt.close()
            
            # Check for keyboard interrupt
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    
    except KeyboardInterrupt:
        print("Stopping webcam capture...")
    finally:
        cap.release()
        print("Webcam released")

print("Webcam demo function defined")

Webcam demo function defined


In [20]:
# Cell 10: Examples using your Fall Dataset structure

# Set paths to your dataset
FALL_DATASET_ROOT = "../datasets/FallDataset"

# Example 1: Process a single test video file
def run_single_video_test():
    # Find the first video file in the test folder
    test_video_dir = os.path.join(FALL_DATASET_ROOT, "train/video")
    test_label_dir = os.path.join(FALL_DATASET_ROOT, "train/labels")
    
    # List all video files
    video_files = [f for f in os.listdir(test_video_dir) 
                   if f.endswith(('.avi', '.mp4', '.mov'))]
    
    if not video_files:
        print(f"No video files found in {test_video_dir}")
        return
    
    # Use the first video file
    video_file = video_files[0]
    video_path = os.path.join(test_video_dir, video_file)
    
    # Check if label exists
    video_name = os.path.splitext(video_file)[0]
    label_path = os.path.join(test_label_dir, f"{video_name}.txt")
    if os.path.exists(label_path):
        print(f"Processing {video_file} with label")
        detected_frames = process_video(video_path, label_path)
    else:
        print(f"Processing {video_file} without label")
        detected_frames = process_video(video_path)
    
    print(f"Fall detected in frames: {detected_frames}")

# Example 2: Run the webcam demo
def run_webcam_demo():
    # Replace with your Telegram credentials if you want alerts
    telegram_bot_token = None  # "YOUR_BOT_TOKEN"
    telegram_chat_id = None    # "YOUR_CHAT_ID"
    
    demo_webcam(telegram_bot_token, telegram_chat_id)

# Example 3: Evaluate on your test dataset
def evaluate_fall_dataset():
    test_video_dir = os.path.join(FALL_DATASET_ROOT, "train/video")
    test_label_dir = os.path.join(FALL_DATASET_ROOT, "train/labels")
    
    print(f"Evaluating fall detection on test dataset:")
    print(f"Video directory: {test_video_dir}")
    print(f"Label directory: {test_label_dir}")
    
    metrics, results = evaluate_videos(test_video_dir, test_label_dir)
    evaluation = calculate_metrics(metrics)
    
    print("\nEvaluation Results:")
    print(f"Accuracy: {evaluation['accuracy']}%")
    print(f"Precision: {evaluation['precision']}%")
    print(f"Recall: {evaluation['recall']}%")
    print(f"Specificity: {evaluation['specificity']}%")
    print(f"F1-Score: {evaluation['f1']}%")
    print("\nConfusion Matrix:")
    print(f"True Positives (TP): {evaluation['confusion_matrix']['TP']}")
    print(f"False Positives (FP): {evaluation['confusion_matrix']['FP']}")
    print(f"True Negatives (TN): {evaluation['confusion_matrix']['TN']}")
    print(f"False Negatives (FN): {evaluation['confusion_matrix']['FN']}")
    
    # Print detailed results
    print("\nDetailed Results:")
    for result in results:
        status = "✓" if result['gt_fall'] == result['detected_fall'] else "✗"
        print(f"{status} {result['video']}: Ground Truth: {result['gt_fall']}, Detected: {result['detected_fall']}")

# Uncomment one of these to run the corresponding example
# run_single_video_test()
# run_webcam_demo()
evaluate_fall_dataset()

# print("All examples ready to run. Uncomment one of the function calls in Cell 10 to execute.")

Evaluating fall detection on test dataset:
Video directory: ../datasets/FallDataset\train/video
Label directory: ../datasets/FallDataset\train/labels

Processing video: video (1).avi
Processing video: ../datasets/FallDataset\train/video\video (1).avi
Original FPS: 25.0, Target FPS: 25, Skipping every 1 frames
Video processing complete. Detected falls in 0 frames.

Processing video: video (10).avi
Processing video: ../datasets/FallDataset\train/video\video (10).avi
Original FPS: 25.0, Target FPS: 25, Skipping every 1 frames
Video processing complete. Detected falls in 8 frames.

Processing video: video (100).avi
Processing video: ../datasets/FallDataset\train/video\video (100).avi
Original FPS: 24.0003840061441, Target FPS: 25, Skipping every 1 frames
Video processing complete. Detected falls in 0 frames.

Processing video: video (101).avi
Processing video: ../datasets/FallDataset\train/video\video (101).avi
Original FPS: 24.0003840061441, Target FPS: 25, Skipping every 1 frames
Video p