#!/usr/bin/env python3

import time
from collections import deque

import cv2
import depthai as dai
import numpy as np
from scipy.spatial.transform import Rotation as R

def show_pipeline(pipeline):
    """Do nothing; fallback used when the optional ``luxhelp`` package is absent.

    Args:
        pipeline: the pipeline that would have been visualised (ignored).

    Returns:
        None.
    """
    return None


# Prefer the real visualiser from the optional ``luxhelp`` package when it is
# importable; otherwise the no-op fallback defined just above stays in effect.
try:
    from luxhelp.pipeline import show_pipeline
except ImportError:
    pass

# Sensor frame rate requested from both stereo cameras (passed as sensorFps).
FPS = 10

# Board sockets of the left and right cameras of the stereo pair; the
# calibration perturbation and the extrinsics printout use left -> right.
LEFT_SOCKET = dai.CameraBoardSocket.CAM_B
RIGHT_SOCKET = dai.CameraBoardSocket.CAM_C

def perturb_calibration(calibration, angle_deg=0.5, *, rng=None):
    """Rotate the LEFT_SOCKET -> RIGHT_SOCKET extrinsics by a random rotation.

    Used to deliberately de-calibrate the stereo pair so the dynamic
    calibration node has an error to correct. The perturbation ``P`` is a
    rotation of ``angle_deg`` degrees about a uniformly drawn random unit
    axis. It is applied to both the rotation and the translation of the
    extrinsics (``R' = P @ R``, ``t' = P @ t``).

    Args:
        calibration: calibration handler (e.g. from ``device.readCalibration()``).
            It is modified in place via ``setCameraExtrinsics``.
        angle_deg: magnitude of the perturbation rotation, in degrees.
        rng: optional random source exposing ``uniform(low, high, size=...)``
            (e.g. ``np.random.default_rng(seed)``) for reproducible runs.
            Defaults to NumPy's global random state, as before.

    Returns:
        The same ``calibration`` object, to allow call chaining.
    """
    if rng is None:
        rng = np.random  # keep the original global-RNG behaviour by default

    # Top-left 3x3 of the extrinsics is the rotation, the first three rows of
    # the last column are the translation.
    extrinsics = np.asarray(
        calibration.getCameraExtrinsics(LEFT_SOCKET, RIGHT_SOCKET), dtype=float
    )
    rotation = extrinsics[:3, :3]
    translation = extrinsics[:3, 3]

    # Random rotation axis; resample in the (practically impossible) case the
    # draw is numerically zero so the normalisation never divides by zero.
    axis = rng.uniform(-1, 1, size=3)
    norm = np.linalg.norm(axis)
    while norm < 1e-12:
        axis = rng.uniform(-1, 1, size=3)
        norm = np.linalg.norm(axis)
    axis /= norm

    perturbation = R.from_rotvec(axis * np.deg2rad(angle_deg)).as_matrix()
    perturbed_rotation = perturbation @ rotation
    perturbed_translation = perturbation @ translation

    # Read back unchanged and passed as the last argument of
    # setCameraExtrinsics. NOTE(review): presumably this is the spec (board
    # design) translation, per depthai's useSpecTranslation default; confirm.
    spec_translation = calibration.getCameraTranslationVector(LEFT_SOCKET, RIGHT_SOCKET)

    calibration.setCameraExtrinsics(
        LEFT_SOCKET, RIGHT_SOCKET,
        perturbed_rotation.tolist(), perturbed_translation.tolist(), spec_translation,
    )
    return calibration


class FPSCounter:
    """Estimate frame rate from the timestamps of the most recent ticks.

    Call ``tick()`` once per processed frame. ``getFps()`` returns the mean
    rate over the last ``window`` ticks (10 by default).
    """

    def __init__(self, window=10):
        """Create a counter that averages over the last ``window`` ticks.

        Args:
            window: number of timestamps to retain; at least 2 are needed
                for a non-zero estimate.
        """
        # deque(maxlen=...) discards the oldest timestamp in O(1) instead of
        # rebuilding a list slice on every tick.
        self.frameTimes = deque(maxlen=window)

    def tick(self):
        """Record the current time as a frame arrival."""
        # monotonic() never jumps backwards (unlike time.time() under clock
        # adjustments), so the measured intervals are always meaningful.
        self.frameTimes.append(time.monotonic())

    def getFps(self):
        """Return the mean frames-per-second over the stored window.

        Returns 0.0 when fewer than two timestamps are stored, or when the
        stored timestamps span no measurable time (avoids dividing by zero).
        """
        if len(self.frameTimes) <= 1:
            return 0.0
        elapsed = self.frameTimes[-1] - self.frameTimes[0]
        if elapsed <= 0:
            return 0.0
        return (len(self.frameTimes) - 1) / elapsed


# Build the device pipeline: two mono cameras feeding both a neural stereo
# depth node and a dynamic calibration node, plus host-side queues.
pipeline = dai.Pipeline()

# Left/right cameras at FPS, each exposing its full-resolution output.
cameraLeft = pipeline.create(dai.node.Camera).build(LEFT_SOCKET, sensorFps=FPS)
cameraRight = pipeline.create(dai.node.Camera).build(RIGHT_SOCKET, sensorFps=FPS)
leftOutput = cameraLeft.requestFullResolutionOutput()
rightOutput = cameraRight.requestFullResolutionOutput()

# Neural stereo depth built from the two camera outputs with the LARGE model.
neuralDepth = pipeline.create(dai.node.NeuralDepth).build(leftOutput, rightOutput, dai.DeviceModelZoo.NEURAL_DEPTH_LARGE)

# Dynamic calibration node
dynCalib = pipeline.create(dai.node.DynamicCalibration)

# Link cameras to dynamic calibration (the same outputs also feed neuralDepth)
leftOutput.link(dynCalib.left)
rightOutput.link(dynCalib.right)

# Host-side queues for the depth node's confidence, edge and disparity images.
confidenceQueue = neuralDepth.confidence.createOutputQueue()
edgeQueue = neuralDepth.edge.createOutputQueue()
disparityQueue = neuralDepth.disparity.createOutputQueue()

# Queue used by the main loop to send updated depth configuration at runtime.
inputConfigQueue = neuralDepth.inputConfig.createInputQueue()

# Dynamic calibration queues: results, coverage progress, and control commands
dynCalibCalibrationQueue = dynCalib.calibrationOutput.createOutputQueue()
dynCalibCoverageQueue = dynCalib.coverageOutput.createOutputQueue()
dynCalibInputControl = dynCalib.inputControl.createInputQueue()

# Connect to device, inject a deliberately perturbed calibration, start the
# pipeline and run the interactive viewer / dynamic-calibration loop.
with pipeline:
    device = pipeline.getDefaultDevice()
    calibration = device.readCalibration()
    # Corrupt the stereo extrinsics (40 deg random rotation) so the
    # DynamicCalibration node has a large, visible error to recover from.
    calibration = perturb_calibration(calibration, angle_deg=40.0)
    pipeline.setCalibrationData(calibration)

    pipeline.build()
    show_pipeline(pipeline)

    pipeline.start()
    time.sleep(1)  # Wait for auto exposure to settle

    # Set performance mode and start periodic calibration
    dynCalibInputControl.send(
        dai.DynamicCalibrationControl.setPerformanceMode(
            dai.DynamicCalibrationControl.OPTIMIZE_PERFORMANCE
        )
    )
    dynCalibInputControl.send(dai.DynamicCalibrationControl.startCalibration())

    # Label for the periodic extrinsics printout, derived from the socket
    # constants so it always names the pair that is actually queried.
    extrinsicsLabel = f"{LEFT_SOCKET.name} -> {RIGHT_SOCKET.name}"
    lastRotationPrintTime = time.time()
    startTime = time.time()
    # Running maximum of all disparity values seen so far; used to scale the
    # disparity image into 0..255 for display.
    maxDisparity = 1
    colorMap = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_JET)
    colorMap[0] = [0, 0, 0]  # to make zero-disparity pixels black
    # initialConfig is edited in place by the key handlers below and re-sent
    # to the depth node after every change.
    currentConfig = neuralDepth.initialConfig
    fpsCounter = FPSCounter()
    print("For adjusting thresholds, use keys:")
    print(" - 'w': Increase confidence threshold")
    print(" - 's': Decrease confidence threshold")
    print(" - 'd': Increase edge threshold")
    print(" - 'a': Decrease edge threshold")
    print(" - 't': Toggle temporal filtering")
    print(" - 'q': Quit")
    try:
        while pipeline.isRunning():
            fpsCounter.tick()

            # Print the left -> right rotation matrix every 5 seconds
            currentTime = time.time()
            if currentTime - lastRotationPrintTime >= 5.0:
                # NOTE(review): confirm getCalibrationData() reflects
                # calibrations applied on-device via applyCalibration().
                currentExtrinsics = pipeline.getCalibrationData().getCameraExtrinsics(
                    LEFT_SOCKET, RIGHT_SOCKET
                )
                currentRotation = np.array([row[:3] for row in currentExtrinsics[:3]])
                print("\n--- Rotation matrix ({}) at t={:.1f}s ---".format(
                    extrinsicsLabel, currentTime - startTime))
                print(currentRotation)
                lastRotationPrintTime = currentTime

            confidenceData = confidenceQueue.get()
            assert isinstance(confidenceData, dai.ImgFrame)
            npConfidence = confidenceData.getFrame()
            colorizedConfidence = cv2.applyColorMap(npConfidence.astype(np.uint8), colorMap)
            cv2.imshow("confidence", colorizedConfidence)

            edgeData = edgeQueue.get()
            assert isinstance(edgeData, dai.ImgFrame)
            npEdge = edgeData.getFrame()
            colorizedEdge = cv2.applyColorMap(npEdge.astype(np.uint8), colorMap)
            cv2.imshow("edge", colorizedEdge)

            disparityData = disparityQueue.get()
            assert isinstance(disparityData, dai.ImgFrame)
            npDisparity = disparityData.getFrame()
            maxDisparity = max(maxDisparity, np.max(npDisparity))
            colorizedDisparity = cv2.applyColorMap(
                ((npDisparity / maxDisparity) * 255).astype(np.uint8), colorMap
            )
            cv2.putText(
                colorizedDisparity,
                f"FPS: {fpsCounter.getFps():.2f}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 255, 255),
                2,
            )
            cv2.imshow("disparity", colorizedDisparity)

            # Check dynamic calibration coverage (non-blocking)
            coverage = dynCalibCoverageQueue.tryGet()
            if coverage is not None:
                print(f"2D Spatial Coverage = {coverage.meanCoverage:.1f}%, Data Acquired = {coverage.dataAcquired:.1f}%")

            # Check for calibration result (non-blocking)
            dynCalibrationResult = dynCalibCalibrationQueue.tryGet()
            if dynCalibrationResult is not None:
                print(f"Dynamic calibration status: {dynCalibrationResult.info}")
                calibrationData = dynCalibrationResult.calibrationData
                if calibrationData:
                    print("Successfully calibrated - applying new calibration")
                    dynCalibInputControl.send(
                        dai.DynamicCalibrationControl.applyCalibration(calibrationData.newCalibration)
                    )
                    diff = calibrationData.calibrationDifference
                    # Euclidean magnitude of the 3-component rotation change.
                    rotDiff = float(np.sqrt(diff.rotationChange[0]**2 +
                                            diff.rotationChange[1]**2 +
                                            diff.rotationChange[2]**2))
                    print(f"Rotation difference: {rotDiff:.2f} deg")
                    print(f"Sampson error: current={diff.sampsonErrorCurrent:.3f}px, new={diff.sampsonErrorNew:.3f}px")
                    # Reset and continue calibration
                    dynCalibInputControl.send(dai.DynamicCalibrationControl.resetData())
                    dynCalibInputControl.send(dai.DynamicCalibrationControl.startCalibration())

            # Keyboard controls. Thresholds wrap around modulo 255 rather than
            # saturating, so stepping past either end continues from the other.
            key = cv2.waitKey(1)
            if key == ord('q'):
                pipeline.stop()
                break
            elif key in (ord('w'), ord('s')):
                step = 5 if key == ord('w') else -5
                currentThreshold = currentConfig.getConfidenceThreshold()
                currentConfig.setConfidenceThreshold((currentThreshold + step) % 255)
                print("Setting confidence threshold to:", currentConfig.getConfidenceThreshold())
                inputConfigQueue.send(currentConfig)
            elif key in (ord('d'), ord('a')):
                step = 1 if key == ord('d') else -1
                currentThreshold = currentConfig.getEdgeThreshold()
                currentConfig.setEdgeThreshold((currentThreshold + step) % 255)
                print("Setting edge threshold to:", currentConfig.getEdgeThreshold())
                inputConfigQueue.send(currentConfig)
            elif key == ord('t'):
                temporalFilter = currentConfig.postProcessing.temporalFilter
                temporalFilter.enable = not temporalFilter.enable
                print("Temporal filtering:", "on" if temporalFilter.enable else "off")
                inputConfigQueue.send(currentConfig)
    finally:
        # Close the preview windows whether the loop ended normally or not.
        cv2.destroyAllWindows()
