import cv2
import time
from picamera2 import Picamera2
import statistics  # <-- ADDED for median

def _write_results(results_filename, camera_resolution, durations_ms, fps_values):
    """Write summary capture statistics to *results_filename*.

    Args:
        results_filename: Path of the text file to (over)write.
        camera_resolution: ``(width, height)`` tuple, written verbatim.
        durations_ms: Per-frame ``capture_array()`` durations in milliseconds.
        fps_values: Per-frame instantaneous FPS values, parallel to
            *durations_ms*.

    Every statistic is reported as 0 when no frames were captured.
    """

    def to_fps(duration_ms):
        # A zero duration (e.g. an empty run) would otherwise divide by zero.
        return 1000.0 / duration_ms if duration_ms > 0 else 0

    if durations_ms:
        min_duration = min(durations_ms)
        max_duration = max(durations_ms)
        avg_duration = sum(durations_ms) / len(durations_ms)
        median_duration = statistics.median(durations_ms)
        median_fps = statistics.median(fps_values)
    else:
        min_duration = max_duration = avg_duration = median_duration = median_fps = 0

    with open(results_filename, "w") as f:
        f.write(f"Camera resolution: {camera_resolution}\n")
        f.write(f"Total frames: {len(durations_ms)}\n")
        # Each duration is paired with its own FPS, so the *min* duration line
        # carries the *highest* FPS and the *max* duration line the lowest.
        f.write(f"Min capture duration: {min_duration:.3f} ms | FPS: {to_fps(min_duration):.2f}\n")
        f.write(f"Max capture duration: {max_duration:.3f} ms | FPS: {to_fps(max_duration):.2f}\n")
        f.write(f"Overall average capture duration: {avg_duration:.3f} ms | FPS: {to_fps(avg_duration):.2f}\n")
        f.write(f"Median capture duration: {median_duration:.3f} ms | FPS: {median_fps:.2f}\n")


def main():
    """Benchmark Picamera2 ``capture_array()`` timing for a fixed duration.

    Configures a video stream at ``camera_resolution`` using sensor mode
    index 1. It discards one warm-up frame, then captures frames for
    ``run_duration_sec`` seconds, or until 'q' is pressed in the preview
    window.

    Per-frame timings go to ``capture_raw_imx900_<W>x<H>.txt`` and a summary
    to ``capture_results_imx900_<W>x<H>.txt``. The camera is stopped and
    closed, and OpenCV windows are destroyed, even if something fails
    part-way. The summary is written whenever the capture loop was entered,
    including when the run is interrupted (e.g. Ctrl+C).
    """
    show_stream = True    # whether or not to display the camera stream
    display_every = 10    # how many frames to wait between imshow() calls
    show_print = True     # whether to print each frame's timing to the terminal
    save_image = True     # whether to periodically save a still with capture_file()
    save_every = 20       # how many frames to wait between saved stills
    run_duration_sec = 5  # wall-clock length of the capture loop

    supported_resolutions = [
        (2064, 1552),
        (1920, 1080),
        (1032, 776),
        (1024, 720),
    ]
    camera_resolution = supported_resolutions[3]

    # Prepare filenames with resolution
    width, height = camera_resolution
    raw_filename = f"capture_raw_imx900_{width}x{height}.txt"
    results_filename = f"capture_results_imx900_{width}x{height}.txt"

    # Per-frame statistics; min/max are kept running for the live printout.
    frame_count = 0
    min_duration = float('inf')
    max_duration = 0.0
    all_durations = []
    all_fps = []
    capture_started = False  # only write a results file once the loop is entered

    picam2 = Picamera2()
    try:
        mode = picam2.sensor_modes[1]
        camera_config = picam2.create_video_configuration(
            sensor={'output_size': mode['size'], 'bit_depth': mode['bit_depth']},
            main={"size": camera_resolution, "format": "XRGB8888", "preserve_ar": True},
        )
        # Raw stream configuration previously observed:
        # {'format': 'R16', 'size': (2064, 1552), 'stride': 4160, 'framesize': 6456320}
        picam2.configure(camera_config)
        print(picam2.stream_configuration("raw"))

        picam2.start()

        # NOTE(review): camera_controls values appear to be (min, max, default)
        # tuples, so this prints all three; confirm whether only [2] was intended.
        default_crop = picam2.camera_controls['ScalerCrop']
        print("Default crop: ", default_crop)

        # Discard first frame because it's always slow
        picam2.capture_array()
        print("Camera stream started. Press 'q' to exit.")

        with open(raw_filename, "w") as raw_file:
            raw_file.write("Frame#\tDuration(ms)\tFPS\n")
            capture_started = True
            start_time_global = time.perf_counter()

            while time.perf_counter() - start_time_global <= run_duration_sec:
                # NOTE(review): capture_array() presumably blocks until a frame is
                # available, so this measures wait + copy time. Slow work later in
                # the loop (preview, saving) may let frames queue up and shorten
                # the following reading.
                start_time = time.perf_counter()
                frame = picam2.capture_array()
                capture_duration = (time.perf_counter() - start_time) * 1000  # ms
                instant_fps = 1000.0 / capture_duration if capture_duration > 0 else 0

                frame_count += 1
                min_duration = min(min_duration, capture_duration)
                max_duration = max(max_duration, capture_duration)
                all_durations.append(capture_duration)
                all_fps.append(instant_fps)

                if show_print:
                    print(f"Frame {frame_count}: {capture_duration:.3f} ms | Instant FPS: {instant_fps:.2f} | "
                          f"Min: {min_duration:.3f} ms | Max: {max_duration:.3f} ms")

                # Flush every line so the raw data survives an interrupted run.
                raw_file.write(f"{frame_count}\t{capture_duration:.3f}\t{instant_fps:.2f}\n")
                raw_file.flush()

                if show_stream and frame_count % display_every == 0:
                    # Resize into a separate preview so `frame` keeps the captured
                    # resolution for the diagnostics printed below.
                    preview = cv2.resize(frame, (640, 480))
                    cv2.imshow("Camera Stream", preview)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

                if save_image and frame_count % save_every == 0:
                    # NOTE(review): capture_file() appears to take its own capture,
                    # so the saved JPEG is not necessarily `frame`. The same
                    # filename is overwritten on every save.
                    picam2.capture_file("test_true.jpg")
                    print(frame.shape)
                    print(frame.dtype)
                    print(frame.min(), frame.max())
    finally:
        picam2.stop()
        picam2.close()  # release the camera even if configuration/start-up failed
        cv2.destroyAllWindows()
        if capture_started:
            _write_results(results_filename, camera_resolution, all_durations, all_fps)
            print("Camera stopped. Raw data saved to", raw_filename)
            print("Results saved to", results_filename)

if __name__ == '__main__':
    # Run the capture benchmark only when executed as a script, not on import.
    main()
