In [None]:
import json
import logging
import math
import os
from collections import deque
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, Union

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Ellipse
from matplotlib.lines import Line2D
from skimage.feature import canny
from skimage.measure import ransac, EllipseModel
from skimage.transform import hough_circle, hough_circle_peaks, hough_line, hough_line_peaks
from sklearn.linear_model import RANSACRegressor, Ridge
from sklearn.preprocessing import PolynomialFeatures

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@dataclass
class Config:
    """Configuration parameters for the well tracking system.

    A plain dataclass of defaults: build ``Config(...)`` with keyword
    arguments to override individual values. Field order below is also the
    positional constructor order, so new fields should be appended.

    NOTE(review): ``WellCenterTracker.save_to_json`` refers to
    ``Config.output_json_dir`` on the class itself, so changing that field on
    an instance may not change where the JSON is written — confirm intent.
    """
    # Data paths
    data_path: str = "/qfs/projects/bioprep/data/automation/new_grid_center_db.2/"
    output_dir: str = "output_videos"
    output_images_dir: str = "output_images"
    output_json_dir: str = "output_json"
    
    # Frame processing
    min_frame: int = 0
    max_frame: int = 4000
    # presumably the accepted phi motor range (units not shown here) — confirm
    phi_min: float = 110.0
    phi_max: float = 125.0
    loop_count: int = 1  # 0=no loop, 1=forward then reverse, 2=forward-reverse-forward, etc.
    
    # Detection parameters
    total_wells: int = 10  # WellTracker predicts and assigns well IDs 1..total_wells
    target_radius: int = 110
    radius_view: int = 10
    roundness_threshold: float = 0.92
    min_circles_required: int = 2
    max_radius_variation: float = 20.0
    
    # Tracking parameters
    buffer_size: int = 10  # presumably the FrameBuffer window length — confirm against caller
    # NOTE(review): WellTracker's RANSAC line fit passes residual_threshold=None
    # and does not read this value — confirm where it is meant to be used.
    ransac_residual_threshold: float = 5.0
    ransac_max_trials: int = 100  # max_trials for WellTracker's RANSAC line fit
    # WellTracker gate: a detection must lie within this fraction of the well
    # spacing from a predicted/inferred position to receive that well's ID
    association_threshold_ratio: float = 0.45
    
    # Video output
    video_fps: int = 10
    save_video: bool = True
    save_individual_frames: bool = True
    display_frames: bool = True
    
    # Feature flags
    enable_well_tracking: bool = True
    track_well_centers: bool = True
    use_perpendicular_filter: bool = True
    use_average_mode: bool = True
    use_annular_mask: bool = True
    enable_motor_calibration: bool = True
    
    # Advanced parameters
    ellipse_margin_factor: float = 1.0
    min_edge_points: int = 0
    inner_factor: float = 0.0
    edge_sigma: float = 10.0
    edge_low_threshold: float = 0.15
    edge_high_threshold: float = 0.7
    hull_min_area: int = 200
    border_buffer: int = 2


@dataclass
class WellData:
    """Data structure for well information.

    Holds one well's circle in image coordinates plus optional metadata.
    """
    x: float  # circle center x in image pixels
    y: float  # circle center y in image pixels
    radius: float  # circle radius in image pixels
    confidence: float = 1.0  # presumably a detection confidence score — range not shown here
    well_id: Optional[int] = None  # well number, or None when not yet identified


@dataclass
class MotorPosition:
    """Motor position data.

    Snapshot of the stage motors for one frame. ``WellCenterTracker``
    serializes it with ``dataclasses.asdict``, so every field becomes a
    ``motor_<name>`` entry in the exported JSON.
    """
    x: float
    y: float
    z: float
    # presumably a rotation angle; Config.phi_min/phi_max (110-125) suggest
    # degrees — confirm units
    phi: float


class WellCenterTracker:
    """Track when each well is closest to the frame center.

    For every well ID the tracker keeps every observation (``all_positions``)
    and the observation with the smallest distance to the image center
    (``best_positions``). It also keeps a calibration snapshot around well #1:
    the frame in which well #1 was seen together with the most detections,
    the line parameters supplied for that frame, and well #1's image and
    motor coordinates. Every well-spacing measurement passed to ``update`` is
    accumulated in ``spacing_history``.
    """

    def __init__(self, frame_shape: Optional[Tuple[int, int]] = None):
        """Create a tracker.

        Args:
            frame_shape: Image shape ``(height, width[, ...])``. When omitted,
                ``set_frame_shape`` must be called before ``update`` records
                anything.
        """
        self.frame_shape = frame_shape
        self.frame_center = None
        if frame_shape:
            self.set_frame_shape(frame_shape)

        # well_id -> observation closest to the frame center so far
        self.best_positions: Dict[int, Dict] = {}
        # well_id -> every observation of that well, in update order
        self.all_positions: Dict[int, List] = {}

        # Calibration snapshot centred on well #1
        self.best_frame_for_well_1 = None
        self.line_params_at_best_frame = None
        self.spacing_history = []
        self.well_1_coordinates = None

    def set_frame_shape(self, frame_shape: Tuple[int, int]):
        """Store the frame shape and derive the ``(x, y)`` frame center from it."""
        self.frame_shape = frame_shape
        height, width = frame_shape[:2]
        self.frame_center = (width / 2, height / 2)

    def update(self, frame_number: int, detected_wells: Dict, motor_data: 'MotorPosition',
               line_params: Optional[Tuple] = None, well_spacing: Optional[float] = None):
        """Record one frame of well detections.

        Does nothing until the frame shape (and hence the center) is known.

        Args:
            frame_number: Index of the frame being recorded.
            detected_wells: Mapping ``well_id -> {'x', 'y'[, 'radius']}``.
                Entries keyed by ``None`` are skipped for per-well tracking.
            motor_data: Dataclass instance with the motor position for this
                frame; serialized with ``dataclasses.asdict``.
            line_params: Optional ``(slope, intercept)`` of the fitted line.
            well_spacing: Optional well-spacing measurement in pixels.
        """
        if self.frame_center is None:
            return

        if well_spacing is not None:
            self.spacing_history.append(well_spacing)

        if 1 in detected_wells:
            self._update_well_1_reference(frame_number, detected_wells, motor_data, line_params)

        for well_id, well_info in detected_wells.items():
            if well_id is None:
                continue

            x, y = well_info['x'], well_info['y']
            distance = self._calculate_distance(x, y)

            self.all_positions.setdefault(well_id, []).append({
                'frame': frame_number,
                'distance': distance,
                'position': (x, y),
                'motor_data': asdict(motor_data)
            })

            best = self.best_positions.get(well_id)
            if best is None or distance < best['distance']:
                self.best_positions[well_id] = {
                    'frame': frame_number,
                    'distance': distance,
                    'position': (x, y),
                    'motor_data': asdict(motor_data),
                    'radius': well_info.get('radius')
                }

    def _update_well_1_reference(self, frame_number: int, detected_wells: Dict,
                                 motor_data: 'MotorPosition', line_params: Optional[Tuple]):
        """Replace the well #1 snapshot if this frame shows more detections than the stored one."""
        num_detected = len(detected_wells)
        best = self.best_frame_for_well_1
        if best is not None and num_detected <= best['num_detected']:
            return

        self.best_frame_for_well_1 = {
            'frame': frame_number,
            'num_detected': num_detected,
            'detected_wells': list(detected_wells.keys())
        }

        # Keep the previously stored line when this frame supplies none
        if line_params is not None and len(line_params) >= 2:
            self.line_params_at_best_frame = line_params

        well_1_info = detected_wells[1]
        radius = well_1_info.get('radius')
        self.well_1_coordinates = {
            'image': {
                'x': float(well_1_info['x']),
                'y': float(well_1_info['y']),
                # 100 px fallback when no radius was recorded (missing or None)
                'radius': float(radius) if radius is not None else 100.0
            },
            'motor': {
                'x': float(motor_data.x),
                'y': float(motor_data.y),
                'z': float(motor_data.z),
                'phi': float(motor_data.phi)
            }
        }

    def _calculate_distance(self, x: float, y: float) -> float:
        """Euclidean distance from ``(x, y)`` to the frame center."""
        return math.hypot(x - self.frame_center[0], y - self.frame_center[1])

    def get_average_spacing(self) -> Optional[float]:
        """Calculate average well spacing using median for robustness"""
        if not self.spacing_history:
            return None
        return float(np.median(self.spacing_history))

    @staticmethod
    def _best_position_to_json(well_id: int, data: Dict) -> Dict:
        """Convert one ``best_positions`` record into JSON-serializable scalars."""
        radius = data.get('radius')
        return {
            'well_id': int(well_id),
            'closest_frame': int(data['frame']),
            'distance_from_center': float(data['distance']),
            'position_x': float(data['position'][0]),
            'position_y': float(data['position'][1]),
            **{f'motor_{k}': float(v) for k, v in data['motor_data'].items()},
            # ``is not None`` so a radius of 0 is exported rather than dropped
            'radius': float(radius) if radius is not None else None
        }

    def save_to_json(self, filename: Optional[str] = None,
                     motor_calibration: Optional['MotorCalibration'] = None,
                     output_dir: Optional[str] = None) -> str:
        """Write the tracking results to a JSON file.

        Args:
            filename: File name inside the output directory. Defaults to a
                timestamped ``well_center_positions_YYYYmmdd_HHMMSS.json``.
            motor_calibration: Optional object whose ``export_calibration()``
                result, when truthy, is stored under ``'motor_calibration'``.
            output_dir: Target directory. Defaults to ``Config.output_json_dir``.
                Missing parent directories are created.

        Returns:
            Path of the written file, as a string.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"well_center_positions_{timestamp}.json"

        avg_spacing = self.get_average_spacing()
        line = self.line_params_at_best_frame
        best = self.best_frame_for_well_1

        json_data = {
            'metadata': {
                'frame_shape': list(self.frame_shape) if self.frame_shape else None,
                'frame_center': list(self.frame_center) if self.frame_center else None,
                'total_wells_tracked': len(self.best_positions),
                'timestamp': datetime.now().isoformat()
            },
            'calibration_data': {
                'line_parameters': {
                    'slope': float(line[0]) if line is not None else None,
                    'intercept': float(line[1]) if line is not None else None,
                    'best_frame': best['frame'] if best else None,
                    'wells_detected_in_best_frame': best['num_detected'] if best else None
                },
                'average_well_spacing': avg_spacing,
                'spacing_measurements_count': len(self.spacing_history),
                'well_1_reference': self.well_1_coordinates
            },
            'wells': {
                str(well_id): self._best_position_to_json(well_id, data)
                for well_id, data in self.best_positions.items()
            }
        }

        if motor_calibration:
            calibration_export = motor_calibration.export_calibration()
            if calibration_export:
                json_data['motor_calibration'] = calibration_export
                logger.info(f"Motor Calibration Data Included: {calibration_export['model_type']} Model With {calibration_export['num_samples']} Samples")

        target_dir = Path(output_dir if output_dir is not None else Config.output_json_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        filepath = target_dir / filename

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2)

        logger.info(f"Well Center Positions Saved To: {filepath}")

        if line is not None:
            logger.info(f"Line Parameters: y = {line[0]:.4f}x + {line[1]:.2f}")
        if avg_spacing is not None:
            logger.info(f"Average Well Spacing: {avg_spacing:.2f} Pixels")
        if self.well_1_coordinates:
            logger.info("Well #1 Reference Coordinates Saved")

        return str(filepath)


class WellTracker:
    """Adaptive well tracking using line fitting and spatial inference.

    Geometry used throughout: wells lie on a straight line ``y = m*x + b``
    with roughly constant spacing. Well #1 is the rightmost well and IDs grow
    towards the left, so ``x_i = x_1 - (i - 1) * spacing``.

    Per frame (``update_tracks``) the tracker does the following:

    * fits the line through the detections;
    * estimates the spacing;
    * predicts the positions of wells ``1..config.total_wells``;
    * associates detections with those predictions;
    * tries to identify leftover detections from their offsets to wells that
      are already identified.
    """

    def __init__(self, config: 'Config'):
        """Create a tracker.

        Args:
            config: Tracking configuration. This class reads ``total_wells``,
                ``ransac_max_trials`` and ``association_threshold_ratio``.
        """
        self.config = config
        self.frame_number = 0

        self.line_params = None  # (slope, intercept) for y = mx + b
        self.well_spacing = None  # robust well spacing estimate in pixels
        self.predicted_positions = None  # well_id -> {'x', 'y', 'radius'}

        self.detected_wells: Dict[int, Dict] = {}  # well_id -> {'x', 'y', 'radius'}
        self.unassigned_detections: List[Tuple] = []  # (x, y, r) without a well ID

        self.average_radius = None
        self.spacing_history = deque(maxlen=20)  # recent accepted spacing samples

    def get_current_line_params(self) -> Optional[Tuple]:
        """Return the current ``(slope, intercept)`` fit, or ``None``."""
        return self.line_params

    def get_current_spacing(self) -> Optional[float]:
        """Return the current well spacing estimate in pixels, or ``None``."""
        return self.well_spacing

    def update_tracks(self, detected_circles: Optional[Tuple]) -> Tuple[Optional[Tuple], Optional[List]]:
        """Update tracking with one frame of circle detections.

        Args:
            detected_circles: ``(accum, cx, cy, radii)`` sequences, or ``None``
                when nothing was detected.

        Returns:
            ``(circles, well_ids)``. ``circles`` is an
            ``(accum, cx, cy, radii)`` tuple of arrays, or ``None``.
            ``well_ids`` is parallel to the circles and holds each circle's
            well ID, or ``None`` for an unidentified detection.

            When the frame has no detections, the predicted positions of the
            wells not currently tracked are returned, labelled with their IDs.
        """
        self.frame_number += 1
        self.unassigned_detections = []

        if not detected_circles or len(detected_circles[1]) == 0:
            # IDs must line up element-for-element with the predicted circles
            return self._get_predicted_wells_as_circles(), self._get_predicted_well_ids()

        _, cx, cy, radii = detected_circles
        detections = [(cx[i], cy[i], radii[i]) for i in range(len(cx))]

        if len(radii) > 0:
            self.average_radius = np.mean(radii)

        if len(detections) >= 2:
            self._fit_line_and_predict(detections)

        if self.predicted_positions:
            self._associate_detections_to_predictions(detections)

            # Periodically check that IDs still decrease from left to right
            if len(self.detected_wells) >= 3 and self.frame_number % 10 == 0:
                self._validate_and_recalibrate_assignments()
        else:
            # No prediction model yet: number detections from the left, so the
            # leftmost one gets ``total_wells``.
            for i, (x, y, r) in enumerate(sorted(detections, key=lambda d: d[0])):
                well_id = self.config.total_wells - i
                self.detected_wells[well_id] = {'x': x, 'y': y, 'radius': r}

        return self._get_tracked_wells_as_circles(), self.get_well_ids()

    def _fit_line_and_predict(self, detections: List[Tuple]):
        """Fit the well line, update the spacing, and regenerate predictions.

        With more than two detections the line is fitted with RANSAC. With
        exactly two it passes through both points (horizontal when they share
        an x).
        """
        x_coords = np.array([d[0] for d in detections])
        y_coords = np.array([d[1] for d in detections])
        radii = np.array([d[2] for d in detections])

        if len(detections) > 2:
            slope, intercept, inlier_detections = self._ransac_fit(
                x_coords.reshape(-1, 1), y_coords, detections)
        else:
            dx = x_coords[1] - x_coords[0]
            slope = (y_coords[1] - y_coords[0]) / dx if x_coords[1] != x_coords[0] else 0
            intercept = y_coords[0] - slope * x_coords[0]
            inlier_detections = detections

        self.line_params = (slope, intercept)

        # _calculate_spacing already measures from the known IDs when at least
        # two wells are tracked. Re-measure below only when the IDs were
        # assigned during this call, so a frame contributes one sample.
        had_known_ids = len(self.detected_wells) >= 2
        self._calculate_spacing(inlier_detections)

        if not self.detected_wells and self.well_spacing:
            self._assign_initial_well_ids(inlier_detections)

        self._generate_predictions(np.mean(radii))

        if not had_known_ids and len(self.detected_wells) >= 2:
            self._calculate_spacing_from_known_wells()

    def _ransac_fit(self, X: np.ndarray, y: np.ndarray, detections: List) -> Tuple:
        """Fit ``y = m*x + b`` with RANSAC.

        Returns:
            ``(slope, intercept, inlier_detections)``.

        Falls back to ``_median_line_fit`` over all detections in two cases:
        RANSAC finds no consensus set, or it keeps fewer than half of the
        detections as inliers.
        """
        ransac = RANSACRegressor(
            random_state=42,
            min_samples=2,
            # None: sklearn uses the MAD of ``y`` as the inlier threshold.
            # NOTE(review): Config.ransac_residual_threshold is not used here.
            residual_threshold=None,
            max_trials=self.config.ransac_max_trials
        )
        try:
            ransac.fit(X, y)
        except ValueError:
            # Raised when no trial yields a valid consensus set
            slope, intercept = self._median_line_fit(X.flatten(), y, detections)
            return slope, intercept, detections

        slope = ransac.estimator_.coef_[0]
        intercept = ransac.estimator_.intercept_
        inlier_detections = [d for d, keep in zip(detections, ransac.inlier_mask_) if keep]

        # Fallback to median fit if too many outliers
        if len(inlier_detections) < len(detections) * 0.5:
            slope, intercept = self._median_line_fit(X.flatten(), y, detections)
            inlier_detections = detections

        return slope, intercept, inlier_detections

    def _median_line_fit(self, x_coords: np.ndarray, y_coords: np.ndarray,
                         detections: List) -> Tuple[float, float]:
        """Theil-Sen fit: median of pairwise slopes, then median intercept.

        Returns a horizontal line through the mean ``y`` when every pair of
        points shares the same ``x``.
        """
        n = len(detections)
        slopes = [(y_coords[j] - y_coords[i]) / (x_coords[j] - x_coords[i])
                  for i in range(n) for j in range(i + 1, n)
                  if x_coords[j] != x_coords[i]]

        if slopes:
            slope = np.median(slopes)
            intercept = np.median([y_coords[i] - slope * x_coords[i] for i in range(n)])
            return slope, intercept

        return 0, np.mean(y_coords)

    def _calculate_spacing(self, detections: List):
        """Estimate the unit well spacing and feed it to the validator.

        When at least two wells are tracked, the known IDs are used instead.

        Otherwise, each detection's x-offset from the leftmost detection is
        divided by its estimated number of spacing units, taking the first
        gap as one unit. The median of these unit spacings is the estimate.
        """
        if len(detections) < 2:
            return

        if self.detected_wells and len(self.detected_wells) >= 2:
            self._calculate_spacing_from_known_wells()
            return

        positions = sorted(d[0] for d in detections)

        spacings = []
        for i in range(len(positions) - 1):
            offset = positions[i + 1] - positions[0]
            if i == 0:
                spacings.append(offset)
            else:
                n_wells_apart = round(offset / spacings[0]) if spacings[0] > 0 else 1
                if n_wells_apart > 0:
                    spacings.append(offset / n_wells_apart)

        if spacings:
            self._update_spacing_with_validation(np.median(spacings))

    def _calculate_spacing_from_known_wells(self):
        """Estimate the unit spacing from pairs of ID-adjacent tracked wells.

        Each pair gives ``|dx| / (id gap)``. Pairs separated by untracked
        wells are down-weighted by ``1 / id gap``, and the weighted mean is
        passed to ``_update_spacing_with_validation``.
        """
        if len(self.detected_wells) < 2:
            return

        well_positions = sorted(((well_id, info['x']) for well_id, info in self.detected_wells.items()),
                                key=lambda p: p[0])

        unit_spacings, weights = [], []
        for (id1, pos1), (id2, pos2) in zip(well_positions, well_positions[1:]):
            n_wells_apart = id2 - id1
            if n_wells_apart > 0:
                unit_spacings.append(abs(pos2 - pos1) / n_wells_apart)
                weights.append(1.0 / n_wells_apart)

        if unit_spacings:
            weighted_mean = np.average(unit_spacings, weights=weights)
            self._update_spacing_with_validation(weighted_mean)

    def _update_spacing_with_validation(self, new_spacing: float):
        """Add a spacing sample to the history unless it looks implausible.

        * A sample smaller than the average radius is accepted only as the
          very first sample; otherwise it is ignored.
        * Once a spacing exists and at least 5 samples are stored, a sample
          more than 2 standard deviations from the history mean is rejected.

        The spacing becomes the history median (the mean while fewer than
        3 samples are stored).
        """
        if self.average_radius and new_spacing < self.average_radius:
            if not self.spacing_history:
                self.spacing_history.append(new_spacing)
                self.well_spacing = new_spacing
            return

        if self.well_spacing and len(self.spacing_history) >= 5:
            spacings_array = np.array(list(self.spacing_history))
            mean_spacing = np.mean(spacings_array)
            std_spacing = np.std(spacings_array)

            if std_spacing > 0 and abs(new_spacing - mean_spacing) > 2 * std_spacing:
                return

        self.spacing_history.append(new_spacing)

        if len(self.spacing_history) >= 3:
            self.well_spacing = np.median(list(self.spacing_history))
        else:
            self.well_spacing = np.mean(list(self.spacing_history))

    def _assign_initial_well_ids(self, detections: List):
        """Assign the first well IDs, numbering right to left (rightmost = 1).

        IDs are consecutive in two cases: there are two or fewer detections,
        or the gaps are roughly uniform (largest gap < 1.5x the smallest).
        Otherwise ``_assign_with_gaps`` skips IDs across large gaps.
        """
        sorted_detections = sorted(detections, key=lambda d: d[0], reverse=True)

        if len(sorted_detections) <= 2:
            for well_id, (x, y, r) in enumerate(sorted_detections, start=1):
                self.detected_wells[well_id] = {'x': x, 'y': y, 'radius': r}
            return

        positions = [d[0] for d in sorted_detections]
        gaps = [abs(a - b) for a, b in zip(positions, positions[1:])]

        min_gap = np.min(gaps)
        max_gap = np.max(gaps)
        gaps_are_uniform = min_gap > 0 and max_gap / min_gap < 1.5

        if gaps_are_uniform:
            for well_id, (x, y, r) in enumerate(sorted_detections, start=1):
                self.detected_wells[well_id] = {'x': x, 'y': y, 'radius': r}
        else:
            self._assign_with_gaps(sorted_detections, gaps)

    def _assign_with_gaps(self, sorted_detections: List, gaps: List):
        """Assign IDs right to left, skipping one ID per extra smallest-gap unit.

        Falls back to consecutive IDs if the result would exceed
        ``total_wells``.
        """
        unit_spacing = np.min(gaps) if gaps else 0
        current_id = 1

        for i, (x, y, r) in enumerate(sorted_detections):
            self.detected_wells[current_id] = {'x': x, 'y': y, 'radius': r}

            if i < len(gaps) and unit_spacing > 0:
                current_id += max(1, round(gaps[i] / unit_spacing))
            else:
                current_id += 1

        if max(self.detected_wells.keys()) > self.config.total_wells:
            self.detected_wells = {i + 1: {'x': d[0], 'y': d[1], 'radius': d[2]}
                                   for i, d in enumerate(sorted_detections)}

    def _generate_predictions(self, avg_radius: float):
        """Predict positions for all wells from the anchor, spacing and line.

        Sets ``predicted_positions`` to an empty dict when no wells are
        tracked. Leaves it unchanged when the line or spacing is missing.
        """
        if not self.line_params or not self.well_spacing:
            return

        slope, intercept = self.line_params
        self.predicted_positions = {}

        if self.detected_wells:
            well_1_x = self._calculate_anchor_position()

            # Well 1 is rightmost; IDs step left by one spacing each
            for i in range(1, self.config.total_wells + 1):
                x = well_1_x - (i - 1) * self.well_spacing
                y = slope * x + intercept
                self.predicted_positions[i] = {'x': x, 'y': y, 'radius': avg_radius}

    def _calculate_anchor_position(self) -> float:
        """Estimate well #1's x-position from the tracked wells.

        With several wells, the median back-projected estimate is used. If
        the estimates disagree by more than the association gate (measured
        by their MAD), only the lowest-ID well is trusted.
        """
        if len(self.detected_wells) >= 2:
            positions = np.asarray(
                [info['x'] + (well_id - 1) * self.well_spacing
                 for well_id, info in self.detected_wells.items()],
                dtype=float)

            anchor = np.median(positions)

            mad = np.median(np.abs(positions - anchor))
            if mad > self.well_spacing * self.config.association_threshold_ratio:
                anchor_id = min(self.detected_wells.keys())
                anchor_x = self.detected_wells[anchor_id]['x']
                anchor = anchor_x + (anchor_id - 1) * self.well_spacing

            return anchor

        anchor_id = next(iter(self.detected_wells))
        anchor_x = self.detected_wells[anchor_id]['x']
        return anchor_x + (anchor_id - 1) * self.well_spacing

    def _associate_detections_to_predictions(self, detections: List):
        """Match detections to predicted wells, closest pairs first.

        A detection can only match a prediction within
        ``association_threshold_ratio * well_spacing``. All gated
        (detection, well) pairs are sorted by distance and matched greedily.
        Each prediction therefore goes to its nearest detection, independent
        of detection order.

        Unmatched detections go to ``unassigned_detections`` and are then
        offered to ``_infer_well_ids_from_spatial_clues``.
        """
        if not self.predicted_positions:
            return

        max_distance = self._get_association_distance()

        if max_distance is None:
            self.unassigned_detections = detections
            return

        candidates = []
        for det_idx, (det_x, det_y, _) in enumerate(detections):
            for well_id, pred in self.predicted_positions.items():
                dist = math.hypot(det_x - pred['x'], det_y - pred['y'])
                if dist < max_distance:
                    candidates.append((dist, det_idx, well_id))
        candidates.sort()

        self.detected_wells = {}
        used_detections = set()
        used_wells = set()
        for _, det_idx, well_id in candidates:
            if det_idx in used_detections or well_id in used_wells:
                continue
            used_detections.add(det_idx)
            used_wells.add(well_id)
            x, y, r = detections[det_idx]
            self.detected_wells[well_id] = {'x': x, 'y': y, 'radius': r}

        self.unassigned_detections = [(x, y, r) for i, (x, y, r) in enumerate(detections)
                                      if i not in used_detections]

        if self.unassigned_detections and self.detected_wells and self.well_spacing:
            self._infer_well_ids_from_spatial_clues()

    def _get_association_distance(self) -> Optional[float]:
        """Return the association gate in pixels, or ``None`` without a spacing."""
        if self.well_spacing:
            return self.well_spacing * self.config.association_threshold_ratio
        return None

    def _infer_well_ids_from_spatial_clues(self):
        """Try to give IDs to unassigned detections from offsets to known wells.

        For each unassigned detection, every free candidate ID is checked
        against each identified well. The detection must lie within
        ``association_threshold_ratio * well_spacing`` of the position that ID
        would occupy: x from the spacing, y following the fitted line's slope.

        The lowest-error candidate is accepted if ``_validate_inference``
        agrees. An ID is never handed out twice within one call.
        """
        if not self.well_spacing:
            return

        spacing_tolerance = self.well_spacing * self.config.association_threshold_ratio
        line_tolerance = spacing_tolerance
        slope = self.line_params[0] if self.line_params else 0.0

        claimed_ids = set(self.detected_wells)
        newly_assigned = []
        remaining_unassigned = []

        for det_x, det_y, det_r in self.unassigned_detections:
            best_id = None
            best_score = float('inf')

            for candidate_id in range(1, self.config.total_wells + 1):
                if candidate_id in claimed_ids:
                    continue

                for known_id, known in self.detected_wells.items():
                    offset = candidate_id - known_id
                    expected_x = known['x'] - offset * self.well_spacing
                    expected_y = known['y'] + slope * (expected_x - known['x'])
                    x_error = abs(det_x - expected_x)
                    y_error = abs(det_y - expected_y)

                    if x_error < spacing_tolerance and y_error < line_tolerance:
                        score = math.hypot(x_error, y_error)
                        if score < best_score:
                            best_score = score
                            best_id = candidate_id

            if best_id is not None and self._validate_inference(best_id, det_x, det_y, line_tolerance):
                newly_assigned.append((best_id, det_x, det_y, det_r))
                claimed_ids.add(best_id)
            else:
                remaining_unassigned.append((det_x, det_y, det_r))

        for well_id, x, y, r in newly_assigned:
            self.detected_wells[well_id] = {'x': x, 'y': y, 'radius': r}

        self.unassigned_detections = remaining_unassigned

        if newly_assigned:
            logger.debug(f"Spatially Inferred Well IDs: {[w[0] for w in newly_assigned]}")

    def _validate_inference(self, well_id: int, x: float, y: float, tolerance: float) -> bool:
        """Sanity-check an inferred ID.

        With three or fewer tracked wells, the ID must be within 3 of the
        tracked ID range. The point must also lie within ``tolerance`` of the
        fitted line.
        """
        current_ids = list(self.detected_wells.keys())
        if current_ids and len(current_ids) <= 3:
            min_id, max_id = min(current_ids), max(current_ids)
            if well_id < min_id - 3 or well_id > max_id + 3:
                return False

        if self.line_params:
            slope, intercept = self.line_params
            return abs(y - (slope * x + intercept)) < tolerance

        return True

    def _validate_and_recalibrate_assignments(self):
        """Recalibrate IDs if they do not strictly decrease as x increases."""
        if len(self.detected_wells) < 3 or not self.well_spacing:
            return

        well_positions = sorted(((well_id, info['x']) for well_id, info in self.detected_wells.items()),
                                key=lambda p: p[1])

        expected_order = all(well_positions[i + 1][0] < well_positions[i][0]
                             for i in range(len(well_positions) - 1))

        if not expected_order:
            self._recalibrate_assignments()

    def _recalibrate_assignments(self):
        """Renumber tracked wells right to left from their gaps.

        A gap of about ``k`` median gaps skips ``k - 1`` IDs. The result is
        applied only if it fits within ``total_wells``.
        """
        sorted_wells = sorted(((info['x'], info) for info in self.detected_wells.values()),
                              key=lambda p: p[0], reverse=True)

        gaps = [abs(sorted_wells[i + 1][0] - sorted_wells[i][0])
                for i in range(len(sorted_wells) - 1)]

        median_gap = np.median(gaps) if gaps else self.well_spacing

        new_assignments = {}
        current_id = 1

        for i, (_, well_info) in enumerate(sorted_wells):
            new_assignments[current_id] = well_info

            if i < len(gaps):
                n_missing = round(gaps[i] / median_gap) - 1 if median_gap > 0 else 0
                current_id += 1 + max(0, n_missing)
            else:
                current_id += 1

        if max(new_assignments.keys()) <= self.config.total_wells:
            old_ids = set(self.detected_wells.keys())
            new_ids = set(new_assignments.keys())
            if old_ids != new_ids:
                logger.debug(f"Recalibrated Well IDs: {sorted(old_ids)} → {sorted(new_ids)}")

            self.detected_wells = new_assignments
            self._calculate_spacing_from_known_wells()

    def _get_tracked_wells_as_circles(self) -> Optional[Tuple]:
        """Return tracked wells (by ID), then unassigned detections, as circle arrays."""
        if not self.detected_wells and not self.unassigned_detections:
            return None

        all_detections = [
            {'x': w['x'], 'y': w['y'], 'radius': w['radius'], 'well_id': well_id}
            for well_id, w in sorted(self.detected_wells.items())
        ]
        all_detections.extend(
            {'x': x, 'y': y, 'radius': r, 'well_id': None}
            for x, y, r in self.unassigned_detections
        )

        if not all_detections:
            return None

        return (np.ones(len(all_detections)),
                np.array([d['x'] for d in all_detections]),
                np.array([d['y'] for d in all_detections]),
                np.array([d['radius'] for d in all_detections]))

    def _get_predicted_wells_as_circles(self) -> Optional[Tuple]:
        """Return predicted positions of wells not currently tracked, sorted by ID."""
        if not self.predicted_positions:
            return None

        data = [(1.0, pred['x'], pred['y'], pred['radius'])
                for well_id, pred in sorted(self.predicted_positions.items())
                if well_id not in self.detected_wells]

        if not data:
            return None

        return tuple(np.array(d) for d in zip(*data))

    def _get_predicted_well_ids(self) -> List[int]:
        """IDs matching ``_get_predicted_wells_as_circles`` element-for-element."""
        if not self.predicted_positions:
            return []
        return [well_id for well_id in sorted(self.predicted_positions)
                if well_id not in self.detected_wells]

    def get_well_ids(self) -> List[Optional[int]]:
        """Return tracked IDs (sorted), then ``None`` per unassigned detection."""
        ids = sorted(self.detected_wells.keys())
        ids.extend([None] * len(self.unassigned_detections))
        return ids

    def get_all_predicted_positions(self) -> Optional[Dict]:
        """Return the ``well_id -> {'x', 'y', 'radius'}`` predictions, if any."""
        return self.predicted_positions

    def get_line_endpoints(self) -> Optional[Tuple]:
        """Return ``((x, y) of well total_wells, (x, y) of well 1)`` for drawing.

        Returns ``None`` when either prediction is missing.
        """
        if not self.predicted_positions or len(self.predicted_positions) < 2:
            return None

        last_well = self.predicted_positions.get(self.config.total_wells)
        first_well = self.predicted_positions.get(1)

        if not last_well or not first_well:
            return None

        return (last_well['x'], last_well['y']), (first_well['x'], first_well['y'])


class FrameBuffer:
    """Rolling window of per-frame phi/roundness statistics used to suggest phi changes.

    Only the most recent ``buffer_size`` frames are kept; older entries are
    evicted automatically by the bounded deque.
    """

    def __init__(self, buffer_size: int = 7):
        self.buffer_size = buffer_size
        self.data = deque(maxlen=buffer_size)

    def add_frame(self, frame_number: int, phi: float, roundness: float,
                  motor_data: MotorPosition, is_good_frame: bool = False):
        """Append one frame's measurements to the window."""
        entry = dict(
            frame_number=frame_number,
            phi=phi,
            roundness=roundness,
            motor_data=motor_data,
            is_good_frame=is_good_frame,
        )
        self.data.append(entry)

    def _trend(self, key: str, min_change: float) -> str:
        """Classify the oldest-to-newest change of ``key`` across the window.

        Fewer than three frames, or an absolute change below ``min_change``,
        counts as 'Stable'.
        """
        if len(self.data) < 3:
            return 'Stable'
        delta = self.data[-1][key] - self.data[0][key]
        if abs(delta) < min_change:
            return 'Stable'
        if delta > 0:
            return 'Increasing'
        return 'Decreasing'

    def get_phi_trend(self, min_change: float = 0.5) -> str:
        """Return 'Increasing', 'Decreasing' or 'Stable' for phi over the window."""
        return self._trend('phi', min_change)

    def get_roundness_trend(self, min_change: float = 0.05) -> str:
        """Return 'Increasing', 'Decreasing' or 'Stable' for roundness over the window."""
        return self._trend('roundness', min_change)

    def get_phi_suggestion(self, current_frame_is_good: bool = False) -> str:
        """Produce a human-readable hint on which way to move phi.

        Checks are applied in this order:
        1. The current frame is good.
        2. Fewer than four buffered frames (not enough data).
        3. At least two of the last three buffered frames were good.
        4. Otherwise the phi and roundness trends are compared. While roundness
           improves, keep moving phi the same way; when it worsens, reverse.
        """
        if current_frame_is_good:
            return "Good Frame"
        if len(self.data) < 4:
            return "Insufficient Data"

        last_three = list(self.data)[-3:]
        good_count = sum(bool(entry.get('is_good_frame', False)) for entry in last_three)
        if good_count >= 2:
            return "Recent Good Frames - Monitor"

        phi_trend = self.get_phi_trend()
        roundness_trend = self.get_roundness_trend()

        if phi_trend == 'Stable':
            return "Monitor Trends"
        if roundness_trend == 'Increasing':
            return f"Continue {phi_trend} Phi"
        if roundness_trend == 'Decreasing':
            opposite = 'Decreasing' if phi_trend == 'Increasing' else 'Increasing'
            return f"Try {opposite} Phi"
        return "Monitor Trends"


class MotorCalibration:
    """Learns the mapping from motor-position deltas to well pixel shifts.

    Consecutive observations are paired. The change in (x, y, z) motor
    position between two observations is the feature vector. The change in
    each commonly detected well's (x, y) pixel position is the target. Two
    independent Ridge regressors (one per pixel axis) are refit on a bounded
    history once at least ``min_samples`` pairs have been collected.
    """

    def __init__(self, min_samples: int = 10, max_samples: int = 100,
                 use_polynomial: bool = False, alpha: float = 1.0):
        """
        Args:
            min_samples: Number of (motor delta, pixel delta) pairs required
                before the models are trained.
            max_samples: Capacity of the rolling sample history; the oldest
                pairs are discarded once it is full.
            use_polynomial: Expand motor deltas with degree-2 polynomial
                features (no bias column) before regression.
            alpha: Ridge regularization strength for both axis models.
        """
        self.min_samples = min_samples
        self.max_samples = max_samples
        self.use_polynomial = use_polynomial
        self.alpha = alpha

        # Parallel bounded histories: one entry per (well, frame-pair) sample.
        self.motor_history = deque(maxlen=max_samples)
        self.pixel_history = deque(maxlen=max_samples)
        self.well_id_history = deque(maxlen=max_samples)

        self.model_x = Ridge(alpha=alpha)
        self.model_y = Ridge(alpha=alpha)
        self.poly_features = PolynomialFeatures(degree=2, include_bias=False) if use_polynomial else None

        self.is_calibrated = False
        self.calibration_score_x = 0.0
        self.calibration_score_y = 0.0
        self.last_prediction_error = {}

        self.prev_motor_data = None
        self.prev_pixel_positions = {}

    def add_observation(self, motor_data: MotorPosition, detected_wells: Dict[int, Dict]):
        """Pair this observation with the previous one and record training samples.

        A sample is stored for every well id present in both the previous and
        the current observation. Samples are stored only when the motor
        actually moved, i.e. the L2 norm of the delta exceeds 1e-6. An
        observation with no detected wells is ignored entirely, so the next
        delta spans back to the last observation that had wells. Models are
        retrained whenever at least ``min_samples`` samples are held.

        Args:
            motor_data: Object exposing ``x``, ``y`` and ``z`` motor coordinates.
            detected_wells: Mapping of well id to a dict with at least ``'x'``
                and ``'y'`` pixel coordinates.
        """
        if not detected_wells:
            return

        current_motor = np.array([motor_data.x, motor_data.y, motor_data.z])

        if self.prev_motor_data is not None:
            prev = self.prev_motor_data
            motor_delta = current_motor - np.array([prev.x, prev.y, prev.z])

            # Only store if there's actual movement
            if np.linalg.norm(motor_delta) > 1e-6:
                for well_id, well_info in detected_wells.items():
                    prev_pos = self.prev_pixel_positions.get(well_id)
                    if prev_pos is None:
                        continue
                    pixel_delta = np.array([
                        well_info['x'] - prev_pos['x'],
                        well_info['y'] - prev_pos['y'],
                    ])
                    self.motor_history.append(motor_delta)
                    self.pixel_history.append(pixel_delta)
                    self.well_id_history.append(well_id)

        self.prev_motor_data = motor_data
        # Copy each per-well dict, not just the outer mapping. If the caller later
        # updates those dicts in place, the next pixel delta must still be
        # measured against this observation's coordinates.
        self.prev_pixel_positions = {
            well_id: dict(well_info) for well_id, well_info in detected_wells.items()
        }

        if len(self.motor_history) >= self.min_samples:
            self._train_models()

    def _train_models(self):
        """Refit both axis regressors on the full sample history.

        Also records the training-set scores (scikit-learn's ``score``, i.e. R^2
        for regressors) and marks the calibration usable.
        """
        if len(self.motor_history) < self.min_samples:
            return

        X = np.asarray(self.motor_history)
        pixels = np.asarray(self.pixel_history)
        y_x = pixels[:, 0]
        y_y = pixels[:, 1]

        if self.use_polynomial and self.poly_features:
            X = self.poly_features.fit_transform(X)

        self.model_x.fit(X, y_x)
        self.model_y.fit(X, y_y)

        self.calibration_score_x = self.model_x.score(X, y_x)
        self.calibration_score_y = self.model_y.score(X, y_y)

        self.is_calibrated = True

    def predict_pixel_shifts(self, motor_delta: np.ndarray) -> Optional[np.ndarray]:
        """Predict the ``[dx, dy]`` pixel shift for a 3-component motor delta.

        Returns ``None`` until the models have been trained.
        """
        if not self.is_calibrated:
            return None

        X = np.asarray(motor_delta, dtype=float).reshape(1, -1)

        if self.use_polynomial and self.poly_features:
            X = self.poly_features.transform(X)

        dx_pixel = self.model_x.predict(X)[0]
        dy_pixel = self.model_y.predict(X)[0]

        return np.array([dx_pixel, dy_pixel])

    def predict_well_positions(self, current_motor: MotorPosition,
                               next_motor: MotorPosition,
                               current_wells: Dict[int, Dict]) -> Dict[int, Dict]:
        """Shift every current well by the pixel offset predicted for the motor move.

        All wells receive the same predicted shift. A well without a
        ``'radius'`` key is given radius 100. Returns ``{}`` when uncalibrated,
        when there are no wells, or when no shift can be predicted.
        """
        if not self.is_calibrated or not current_wells:
            return {}

        motor_delta = np.array([
            next_motor.x - current_motor.x,
            next_motor.y - current_motor.y,
            next_motor.z - current_motor.z,
        ])

        pixel_shift = self.predict_pixel_shifts(motor_delta)
        if pixel_shift is None:
            return {}

        return {
            well_id: {
                'x': well_info['x'] + pixel_shift[0],
                'y': well_info['y'] + pixel_shift[1],
                'radius': well_info.get('radius', 100),
            }
            for well_id, well_info in current_wells.items()
        }

    def update_prediction_error(self, predicted_wells: Dict[int, Dict],
                                actual_wells: Dict[int, Dict]):
        """Store the Euclidean pixel error for each well both predicted and observed."""
        self.last_prediction_error = {}

        for well_id, pred in predicted_wells.items():
            actual = actual_wells.get(well_id)
            if actual is None:
                continue
            error = np.sqrt((pred['x'] - actual['x']) ** 2 +
                            (pred['y'] - actual['y']) ** 2)
            self.last_prediction_error[well_id] = error

    def get_calibration_info(self) -> Dict:
        """Summarize calibration status, training scores and the latest prediction error."""
        if not self.is_calibrated:
            return {
                'is_calibrated': False,
                'samples_collected': len(self.motor_history),
                'samples_needed': self.min_samples
            }

        avg_error = np.mean(list(self.last_prediction_error.values())) if self.last_prediction_error else 0

        return {
            'is_calibrated': True,
            'samples_collected': len(self.motor_history),
            'calibration_score_x': self.calibration_score_x,
            'calibration_score_y': self.calibration_score_y,
            'avg_score': (self.calibration_score_x + self.calibration_score_y) / 2,
            'last_avg_error': avg_error,
            'method': 'Polynomial Ridge' if self.use_polynomial else 'Linear Ridge'
        }

    def export_calibration(self) -> Optional[Dict]:
        """Export model coefficients and sample statistics as JSON-serializable data.

        Returns ``None`` until the models have been trained.
        """
        if not self.is_calibrated:
            return None

        calibration_data = {
            'model_type': 'Polynomial Ridge' if self.use_polynomial else 'Linear Ridge',
            'alpha': float(self.alpha),
            'num_samples': len(self.motor_history),
            'calibration_scores': {
                'x': float(self.calibration_score_x),
                'y': float(self.calibration_score_y),
                'average': float((self.calibration_score_x + self.calibration_score_y) / 2)
            }
        }

        # Export model coefficients and intercepts
        if self.use_polynomial and self.poly_features:
            calibration_data['polynomial_degree'] = self.poly_features.degree
            calibration_data['polynomial_include_bias'] = self.poly_features.include_bias

            try:
                feature_names = self.poly_features.get_feature_names_out(['motor_x', 'motor_y', 'motor_z'])
                calibration_data['feature_names'] = feature_names.tolist()
            except (AttributeError, ValueError):
                # Older scikit-learn has no get_feature_names_out (AttributeError).
                # Unfitted transformers or mismatched names raise ValueError
                # (NotFittedError subclasses both). Fall back to generic names.
                calibration_data['feature_names'] = [f'feature_{i}' for i in range(len(self.model_x.coef_))]

        calibration_data['model_x'] = {
            'coefficients': self.model_x.coef_.tolist(),
            'intercept': float(self.model_x.intercept_)
        }

        calibration_data['model_y'] = {
            'coefficients': self.model_y.coef_.tolist(),
            'intercept': float(self.model_y.intercept_)
        }

        # Store sample statistics for validation
        if len(self.motor_history) > 0:
            motor_array = np.array(self.motor_history)
            pixel_array = np.array(self.pixel_history)

            calibration_data['motor_statistics'] = {
                'mean': motor_array.mean(axis=0).tolist(),
                'std': motor_array.std(axis=0).tolist(),
                'min': motor_array.min(axis=0).tolist(),
                'max': motor_array.max(axis=0).tolist()
            }

            calibration_data['pixel_statistics'] = {
                'mean': pixel_array.mean(axis=0).tolist(),
                'std': pixel_array.std(axis=0).tolist(),
                'min': pixel_array.min(axis=0).tolist(),
                'max': pixel_array.max(axis=0).tolist()
            }

        if self.last_prediction_error:
            calibration_data['last_prediction_errors'] = {
                str(well_id): float(error)
                for well_id, error in self.last_prediction_error.items()
            }
            calibration_data['avg_prediction_error'] = float(np.mean(list(self.last_prediction_error.values())))

        return calibration_data

    def reset_calibration(self):
        """Forget all samples and mark the calibration unusable.

        The fitted model objects are left in place but ignored until retrained.
        """
        self.motor_history.clear()
        self.pixel_history.clear()
        self.well_id_history.clear()
        self.is_calibrated = False
        self.calibration_score_x = 0.0
        self.calibration_score_y = 0.0
        self.last_prediction_error = {}
        self.prev_motor_data = None
        self.prev_pixel_positions = {}


class ImageProcessor:
    """Stateless image-processing helpers: circles, hull contours, lines and ellipses."""

    @staticmethod
    def find_circles(img: np.ndarray, config: Config, min_distance: int = 225,
                     num_peaks: int = 4, peak_threshold: float = 0.15) -> Tuple:
        """Detect well-sized circles with Canny edges and a circular Hough transform.

        Args:
            img: Image passed to ``skimage.feature.canny``.
            config: Supplies ``edge_sigma``, ``edge_low_threshold`` and
                ``edge_high_threshold``. The thresholds are interpreted as
                quantiles because ``use_quantiles=True``. It also supplies
                ``target_radius`` and ``radius_view``.
            min_distance: Minimum x and y separation (pixels) between peaks.
            num_peaks: Maximum number of circles returned.
            peak_threshold: Minimum Hough accumulator value for a peak.

        Returns:
            ``(edge, (accum, cx, cy, radii))``: the edge map and the circle peaks.
        """
        edge = canny(img, sigma=config.edge_sigma,
                     low_threshold=config.edge_low_threshold,
                     high_threshold=config.edge_high_threshold,
                     use_quantiles=True)

        # Candidate radii: target_radius +/- radius_view, both ends inclusive.
        candidate_radii = np.arange(config.target_radius - config.radius_view,
                                    config.target_radius + config.radius_view + 1)
        hough_space = hough_circle(edge, candidate_radii)
        accum, cx, cy, radii = hough_circle_peaks(hough_space, candidate_radii,
                                                  min_xdistance=min_distance,
                                                  min_ydistance=min_distance,
                                                  num_peaks=num_peaks,
                                                  threshold=peak_threshold)
        return edge, (accum, cx, cy, radii)

    @staticmethod
    def extract_contour_coordinates(edge_image: np.ndarray, min_area: int = 100,
                                    remove_border_points: bool = True,
                                    border_buffer: int = 2) -> Tuple[Optional[np.ndarray], Optional[List]]:
        """Return the convex hull of all edge contours and its drawable polyline segments.

        Edges are closed and dilated so broken fragments merge. External
        contours are then found, and one convex hull is taken over all of
        their points.

        Args:
            edge_image: 2-D edge map. Non-uint8 input is scaled by 255 first.
            min_area: Hulls with a smaller area are rejected.
            remove_border_points: Drop hull vertices within ``border_buffer``
                pixels of the image edge. The segments then only connect
                surviving neighbours.
            border_buffer: Width in pixels of the excluded border band.

        Returns:
            ``(hull_coords, segments)``. ``hull_coords`` is an (N, 2) array of
            (x, y) hull vertices, possibly filtered. ``segments`` is a list of
            (k, 2) polylines following the closed hull outline. Returns
            ``(None, None)`` when no contour is found or the hull is too small.
        """
        if edge_image.dtype != np.uint8:
            # Boolean / [0, 1] edge maps become 0/255 images for OpenCV.
            edge_uint8 = (edge_image * 255).astype(np.uint8)
        else:
            edge_uint8 = edge_image

        height, width = edge_image.shape

        # Morphological operations to connect edges
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        processed = cv2.morphologyEx(edge_uint8, cv2.MORPH_CLOSE, kernel, iterations=1)
        processed = cv2.dilate(processed, kernel, iterations=2)

        contours, _ = cv2.findContours(processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None, None

        # One convex hull over every contour point.
        all_points = np.vstack([contour.reshape(-1, 2) for contour in contours])
        hull = cv2.convexHull(all_points)
        hull_coords = hull.reshape(-1, 2)

        if cv2.contourArea(hull) < min_area:
            return None, None

        if not remove_border_points:
            all_indices = np.arange(len(hull_coords))
            return hull_coords, ImageProcessor._create_segments_from_hull(hull_coords, all_indices)

        inside = (
            (hull_coords[:, 0] >= border_buffer) &
            (hull_coords[:, 0] < width - border_buffer) &
            (hull_coords[:, 1] >= border_buffer) &
            (hull_coords[:, 1] < height - border_buffer)
        )
        kept_indices = np.flatnonzero(inside)
        segments = ImageProcessor._create_segments_from_hull(hull_coords, kept_indices)
        return hull_coords[inside], segments

    @staticmethod
    def _create_segments_from_hull(hull_coords: np.ndarray, kept_indices: np.ndarray) -> List:
        """Split the kept hull vertices into runs of consecutive hull vertices.

        The hull is a closed polygon, so vertex ``n-1`` is adjacent to vertex
        0. When both ends are kept, the last and first runs are joined so the
        closing edge is not lost. When every vertex is kept, the single run is
        closed by repeating its first vertex. Runs with fewer than two
        vertices are dropped.

        Args:
            hull_coords: (n, 2) hull vertices in hull order.
            kept_indices: Ascending indices of the vertices to keep.

        Returns:
            List of (k, 2) arrays, each a polyline along the hull.
        """
        if len(kept_indices) <= 1:
            return []

        kept = np.asarray(kept_indices)
        breaks = np.flatnonzero(np.diff(kept) > 1) + 1
        runs = [run.tolist() for run in np.split(kept, breaks)]

        n_vertices = len(hull_coords)
        if kept[0] == 0 and kept[-1] == n_vertices - 1:
            if len(runs) == 1:
                # Every vertex kept: close the polygon.
                runs[0].append(runs[0][0])
            else:
                # The run ending at n-1 continues into the run starting at 0.
                runs[0] = runs.pop() + runs[0]

        return [hull_coords[run] for run in runs if len(run) > 1]

    @staticmethod
    def _lines_from_hough_peaks(angles, dists, width: int, height: int) -> List:
        """Sample Hough line peaks as (x, y) coordinate arrays inside the image.

        Each peak describes the line ``x*cos(angle) + y*sin(angle) = dist``.
        It is evaluated at every integer x in ``[0, width)`` and kept where
        ``0 < y < height``. A line with ``sin(angle) == 0`` is vertical and
        cannot be written as y(x), so it is skipped. Lines through the origin
        (``dist == 0``) are kept.
        """
        xline = np.arange(width)
        lines = []
        for ang, dist in zip(angles, dists):
            sin_a = np.sin(ang)
            if np.isclose(sin_a, 0.0):
                continue
            yline = (dist - xline * np.cos(ang)) / sin_a
            sel = (yline > 0) & (yline < height)
            if np.any(sel):
                lines.append((xline[sel], yline[sel]))
        return lines

    @staticmethod
    def _is_border_line(x_coords: np.ndarray, y_coords: np.ndarray, width: int,
                        height: int, border_buffer: int) -> bool:
        """True when a sampled line lies entirely within ``border_buffer`` px of one image edge.

        Every line clipped to the image reaches the image boundary at its ends,
        so only lines running along an edge are rejected, not lines that touch it.
        """
        return bool(
            np.all(x_coords <= border_buffer) or
            np.all(x_coords >= width - border_buffer) or
            np.all(y_coords <= border_buffer) or
            np.all(y_coords >= height - border_buffer)
        )

    @staticmethod
    def find_lines(contour_coords: Optional[np.ndarray], img_shape: Tuple,
                   segments: Optional[List] = None, config: Optional[Config] = None,
                   backup_img: Optional[np.ndarray] = None,
                   border_buffer: Optional[int] = None) -> Tuple:
        """Detect straight lines, preferring the hull outline and falling back to raw edges.

        Primary path: the hull ``segments`` are drawn as 2-px lines into a
        blank ``img_shape`` image. If ``segments`` is None, the individual
        ``contour_coords`` points are drawn instead. The image then goes
        through a straight-line Hough transform.

        Fallback, used only when the primary path yields no lines: Canny edges
        of ``backup_img`` go through the same transform. Lines lying along an
        image edge (within ``border_buffer`` px) are discarded.

        Args:
            contour_coords: (N, 2) hull vertices as (x, y), or None.
            img_shape: ``(height, width)`` of the primary raster.
            segments: Polylines to rasterize, as returned by
                ``extract_contour_coordinates``.
            config: Provides ``border_buffer`` when ``border_buffer`` is None.
            backup_img: Image used for the fallback edge detection.
            border_buffer: Border band width; defaults to ``config.border_buffer`` or 3.

        Returns:
            ``(image, peak_accumulators, lines)``. ``image`` is the rasterized
            hull on the primary path or the Canny edge map on the fallback.
            ``lines`` is a list of ``(x_coords, y_coords)`` arrays. Returns
            ``(None, None, [])`` when the primary path finds no lines and
            ``backup_img`` is None.
        """
        threshold = 80
        min_distance = 20
        min_angle = 80
        num_peaks = 4
        backup_sigma = 15.0
        backup_low = 0.2
        backup_high = 0.7

        if border_buffer is None:
            border_buffer = config.border_buffer if config else 3

        # 360 candidate normal angles in [-pi/2, pi/2), i.e. 0.5 degree steps.
        angles = np.linspace(-np.pi / 2, np.pi / 2, 360, endpoint=False)

        # Try to detect lines from hull/contour coordinates first
        if contour_coords is not None and len(contour_coords) > 0 and img_shape is not None:
            height, width = img_shape
            contour_img = np.zeros((height, width), dtype=np.uint8)

            if segments is not None:
                for segment in segments:
                    for start, end in zip(segment[:-1], segment[1:]):
                        cv2.line(contour_img, tuple(start.astype(int)), tuple(end.astype(int)),
                                 255, thickness=2)
            else:
                for point in contour_coords:
                    cv2.circle(contour_img, tuple(point.astype(int)), 1, 255, -1)

            h, theta, d = hough_line(contour_img, angles)
            ph, pang, pdist = hough_line_peaks(h, theta, d, threshold=threshold,
                                               min_distance=min_distance, min_angle=min_angle,
                                               num_peaks=num_peaks)
            primary_lines = ImageProcessor._lines_from_hough_peaks(pang, pdist, width, height)
            if primary_lines:
                return contour_img, ph, primary_lines

        # Fallback to direct edge detection if no lines found from hull
        if backup_img is None:
            return None, None, []

        edge = canny(backup_img, sigma=backup_sigma, low_threshold=backup_low,
                     high_threshold=backup_high, use_quantiles=True)
        h, theta, d = hough_line(edge, angles)
        ph, pang, pdist = hough_line_peaks(h, theta, d, threshold=threshold,
                                           min_distance=min_distance, min_angle=min_angle,
                                           num_peaks=num_peaks)

        height, width = edge.shape
        backup_lines = [
            (x_coords, y_coords)
            for x_coords, y_coords in ImageProcessor._lines_from_hough_peaks(pang, pdist, width, height)
            if not ImageProcessor._is_border_line(x_coords, y_coords, width, height, border_buffer)
        ]
        return edge, ph, backup_lines

    @staticmethod
    def fit_ellipses_to_circles(edge_img: np.ndarray, circles: Tuple,
                                config: Config) -> Tuple[List, List]:
        """Fit an ellipse to the edge pixels around each detected circle.

        Edge pixels are taken within ``radius * config.ellipse_margin_factor``
        of each centre. When ``config.use_annular_mask`` is set, pixels closer
        than ``radius * config.inner_factor`` are excluded. Circles with fewer
        than ``config.min_edge_points + 5`` edge pixels are skipped. A RANSAC
        ``EllipseModel`` fit is tried first; ``cv2.fitEllipse`` is the fallback.

        Returns:
            ``(ellipse_params, fitting_methods)``. Each ellipse is
            ``((cx, cy), (width, height), angle_degrees)``. Each method is
            ``"RANSAC"`` or ``"Fallback"``.
        """
        _, cx, cy, radii = circles
        ellipse_params = []
        fitting_methods = []

        # The pixel coordinate grids are the same for every circle; build them once.
        y_indices, x_indices = np.ogrid[:edge_img.shape[0], :edge_img.shape[1]]

        for center_x, center_y, radius in zip(cx, cy, radii):
            outer_radius = int(radius * config.ellipse_margin_factor)
            dist_sq = (x_indices - center_x) ** 2 + (y_indices - center_y) ** 2
            mask = dist_sq <= outer_radius ** 2
            if config.use_annular_mask:
                inner_radius = int(radius * config.inner_factor)
                mask &= dist_sq >= inner_radius ** 2

            edge_points = np.column_stack(np.where(edge_img & mask))
            if len(edge_points) < config.min_edge_points + 5:
                continue

            # np.where yields (row, col); the fitters expect (x, y).
            edge_points_xy = edge_points[:, [1, 0]].astype(np.float64)
            ellipse = None
            method = "Fallback"

            try:
                model, _ = ransac(edge_points_xy,
                                  EllipseModel,
                                  min_samples=5,
                                  residual_threshold=config.ransac_residual_threshold,
                                  max_trials=config.ransac_max_trials)
            except Exception as exc:  # best effort: cv2.fitEllipse below is the fallback
                logger.debug("RANSAC ellipse fit failed near (%s, %s): %s", center_x, center_y, exc)
                model = None

            # Explicit length check: `if params` is ambiguous when params is an ndarray.
            params = getattr(model, 'params', None)
            if params is not None and len(params) == 5:
                xc, yc, a, b, theta = params
                ellipse = ((xc, yc), (2 * a, 2 * b), np.degrees(theta))
                method = "RANSAC"

            if ellipse is None:
                try:
                    ellipse = cv2.fitEllipse(edge_points_xy.astype(np.float32))
                except cv2.error:
                    continue

            ellipse_params.append(ellipse)
            fitting_methods.append(method)

        return ellipse_params, fitting_methods

    @staticmethod
    def evaluate_perpendicularity(ellipse_params: List, config: Config) -> Tuple[bool, Dict]:
        """Judge whether the tray is perpendicular from the roundness of the fitted ellipses.

        Roundness is minor axis / major axis (0 for a degenerate ellipse). With
        ``config.use_average_mode``, the mean roundness must reach
        ``config.roundness_threshold`` and at least ``config.min_circles_required``
        ellipses must exist. Otherwise at least ``min_circles_required``
        individual ellipses must reach the threshold.
        """
        if not ellipse_params:
            return False, {"error": "No Ellipses", "mean_roundness": 0.0}

        roundness_scores = []
        for _center, (width, height), _angle in ellipse_params:
            major_axis = max(width, height)
            minor_axis = min(width, height)
            roundness_scores.append(minor_axis / major_axis if major_axis > 0 else 0)

        round_count = sum(1 for score in roundness_scores
                          if score >= config.roundness_threshold)

        if config.use_average_mode:
            is_perpendicular = (len(ellipse_params) >= config.min_circles_required and
                                np.mean(roundness_scores) >= config.roundness_threshold)
        else:
            is_perpendicular = round_count >= config.min_circles_required

        return is_perpendicular, {
            "roundness_scores": roundness_scores,
            "mean_roundness": np.mean(roundness_scores),
            "round_circles_count": round_count,
            "total_circles": len(ellipse_params)
        }


class Visualizer:
    """Matplotlib overlays for ellipse fits, motor-based predictions and tracked wells."""

    @staticmethod
    def add_ellipse_plots(ax, ellipse_params: List, roundness_scores: List,
                          config: Config, fitting_methods: List):
        """Draw each fitted ellipse with its roundness score and fit-method initial.

        Ellipses at or above ``config.roundness_threshold`` are lime, the rest
        orange. RANSAC fits get a solid outline and other fits a dashed one.
        The three lists are zipped, so drawing stops at the shortest.
        """
        for ellipse, roundness, method in zip(ellipse_params, roundness_scores, fitting_methods):
            (cx, cy), (width, height), angle = ellipse

            is_round = roundness >= config.roundness_threshold
            color = 'lime' if is_round else 'orange'
            alpha = 0.7 if is_round else 0.5
            linestyle = '-' if method == 'RANSAC' else '--'

            ellipse_patch = Ellipse((cx, cy), width, height, angle=angle,
                                    ec=color, fc='none', linewidth=3,
                                    alpha=alpha, linestyle=linestyle)
            ax.add_patch(ellipse_patch)

            label = f'{roundness:.2f}\n{method[0]}'
            ax.text(cx + width / 4, cy + height / 4, label,
                    fontsize=10, color=color, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.7))

    @staticmethod
    def add_motor_predictions(ax, motor_predicted_wells: Dict[int, Dict]):
        """Draw dashed magenta circles labelled ``M#<id>`` for motor-predicted wells.

        A prediction without a ``'radius'`` key is drawn with radius 100.
        """
        if not motor_predicted_wells:
            return

        for well_id, pred in motor_predicted_wells.items():
            circle = plt.Circle((pred['x'], pred['y']), pred.get('radius', 100),
                                ec='magenta', fc='none', ls='--', alpha=0.6, lw=2)
            ax.add_patch(circle)

            ax.text(pred['x'] - 20, pred['y'] - 20, f'M#{well_id}',
                    fontsize=9, color='magenta', alpha=0.8,
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.5))

    @staticmethod
    def add_well_tracking_visualization(ax, tracked_circles: Optional[Tuple],
                                        well_ids: Optional[List],
                                        well_tracker: WellTracker,
                                        motor_predicted_wells: Optional[Dict] = None):
        """Overlay the tracker state on ``ax``. Layers are drawn in this order:

        1. Motor-based predictions, when given.
        2. The well #10 to well #1 line, extended 50 px at both ends.
        3. Faint dotted circles for predicted wells that were not detected.
        4. Tracked detections: lime circles labelled ``#<id>`` for assigned
           wells, orange ``?`` circles for unassigned ones.

        ``tracked_circles`` follows the ``(accum, cx, cy, radii)`` layout and
        ``well_ids`` lists one entry per circle, with ``None`` for unassigned.
        """
        if not well_tracker:
            return

        if motor_predicted_wells:
            Visualizer.add_motor_predictions(ax, motor_predicted_wells)

        line_endpoints = well_tracker.get_line_endpoints()
        if line_endpoints:
            (x1, y1), (x2, y2) = line_endpoints
            dx, dy = x2 - x1, y2 - y1
            length = np.sqrt(dx ** 2 + dy ** 2)
            if length > 0:
                dx_norm, dy_norm = dx / length, dy / length
                extend = 50
                ax.plot([x1 - dx_norm * extend, x2 + dx_norm * extend],
                        [y1 - dy_norm * extend, y2 + dy_norm * extend],
                        'yellow', linewidth=3, alpha=0.5, linestyle='--',
                        label='Fitted Line')

        # `is not None` so that a well id of 0 still counts as detected.
        detected_ids = {wid for wid in well_ids if wid is not None} if well_ids else set()
        predicted_positions = well_tracker.get_all_predicted_positions()

        if predicted_positions:
            for well_id, pred in predicted_positions.items():
                if well_id in detected_ids:
                    continue
                circle = plt.Circle((pred['x'], pred['y']), pred['radius'],
                                    ec='yellow', fc='none', ls=':', alpha=0.3, lw=2)
                ax.add_patch(circle)
                ax.text(pred['x'], pred['y'], f'#{well_id}',
                        ha='center', va='center', fontsize=10,
                        color='yellow', alpha=0.5,
                        bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.3))

        if tracked_circles and well_ids:
            _, cx, cy, radii = tracked_circles

            for x, y, r, well_id in zip(cx, cy, radii, well_ids):
                if well_id is not None:
                    color = 'lime'
                    circle = plt.Circle((x, y), r, ec=color, fc='none',
                                        ls='-', alpha=0.8, lw=4)
                    ax.add_patch(circle)
                    ax.text(x, y, f'#{well_id}', ha='center', va='center',
                            fontsize=11, fontweight='bold', color='white',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.7))
                else:
                    circle = plt.Circle((x, y), r, ec='orange', fc='none',
                                        ls='-', alpha=0.8, lw=4)
                    ax.add_patch(circle)
                    ax.text(x, y, '?', ha='center', va='center',
                            fontsize=14, fontweight='bold', color='white',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor='orange', alpha=0.7))


class WellTrackingSystem:
    """Main system for well tracking.

    Orchestrates, per frame: loading the .npz data, circle/ellipse detection
    and the perpendicularity check, optional well tracking / well-center
    tracking / motor calibration, the 4-panel visualization, and the
    video / PNG / JSON outputs.
    """

    def __init__(self, config: Config):
        """Create the processing components and the output directories.

        Args:
            config: Run configuration. The well tracker, well-center tracker and
                motor calibration are only created when the corresponding flags
                on ``config`` are enabled; otherwise those attributes are None.
        """
        self.config = config
        self.frame_buffer = FrameBuffer(config.buffer_size)
        self.well_tracker = WellTracker(config) if config.enable_well_tracking else None
        self.well_center_tracker = (WellCenterTracker()
                                    if config.track_well_centers and config.enable_well_tracking
                                    else None)
        self.motor_calibration = (MotorCalibration(min_samples=10, max_samples=100,
                                                   use_polynomial=False, alpha=1.0)
                                  if config.enable_motor_calibration else None)
        self.image_processor = ImageProcessor()
        self.visualizer = Visualizer()

        # Counters reported in the summary at the end of run().
        self.frames_processed = 0
        self.frames_skipped = 0
        self.frames_with_tracking = 0

        # State carried over from the previous tracked frame for motor-based prediction.
        self.prev_motor_data = None
        self.prev_detected_wells = {}
        self.motor_predicted_wells = {}

        # parents=True so nested paths (e.g. "results/videos") work, not only
        # single-level directory names.
        for dir_path in (config.output_dir, config.output_images_dir, config.output_json_dir):
            Path(dir_path).mkdir(parents=True, exist_ok=True)

    def load_frame_data(self, frame_number: int) -> Tuple[Optional[MotorPosition], Optional[np.ndarray]]:
        """Load the motor position and image for one frame from its .npz file.

        The file path is ``f"{config.data_path}test{frame_number}.npz"``; the
        archive must contain the keys ``x``, ``y``, ``z``, ``phi`` and ``sample``.

        Returns:
            ``(motor_position, image)`` on success, or ``(None, None)`` if the
            frame cannot be loaded (the error is logged at DEBUG level so that
            missing frames are skipped quietly).
        """
        frame_path = f"{self.config.data_path}test{frame_number}.npz"
        try:
            # np.load returns an NpzFile that keeps the archive open; the context
            # manager closes it. Indexing reads each array eagerly, so `sample`
            # remains valid after the file is closed.
            with np.load(frame_path) as data:
                motor_pos = MotorPosition(
                    x=float(data['x']),
                    y=float(data['y']),
                    z=float(data['z']),
                    phi=float(data['phi'])
                )
                sample = data['sample']
        except Exception as e:  # deliberate best-effort: any unreadable frame is skipped
            logger.debug(f"Error Loading Frame {frame_number}: {e}")
            return None, None
        return motor_pos, sample

    def process_frame_detection(self, frame_number: int, img: np.ndarray,
                                motor_data: MotorPosition) -> Dict:
        """Run detection and the perpendicularity evaluation on one frame.

        Args:
            frame_number: Index of the frame (currently unused here).
            img: Frame image.
            motor_data: Motor position for the frame (currently unused here).

        Returns:
            Dict with the image, Canny edge map, Hough circles, fitted ellipses,
            hull contour/segments, detected lines, the perpendicularity verdict
            and analysis results, ``mean_roundness`` (0 if the analysis did not
            report one) and ``is_good_frame`` (perpendicular and at least
            ``config.min_circles_required`` circles).
        """
        if self.well_center_tracker and self.well_center_tracker.frame_shape is None:
            self.well_center_tracker.set_frame_shape(img.shape)

        edge, circles = self.image_processor.find_circles(img, self.config)

        ellipse_params, fitting_methods = self.image_processor.fit_ellipses_to_circles(
            edge, circles, self.config)
        is_perpendicular, analysis_results = self.image_processor.evaluate_perpendicularity(
            ellipse_params, self.config)

        analysis_results['ellipse_params'] = ellipse_params
        analysis_results['fitting_methods'] = fitting_methods

        # Hull contour and line detection are used for visualization.
        contour_coords, segments = self.image_processor.extract_contour_coordinates(
            edge, min_area=self.config.hull_min_area, remove_border_points=True,
            border_buffer=self.config.border_buffer)

        # find_lines only receives the hull when BOTH parts are available;
        # a partial result is passed as (None, None).
        has_hull = contour_coords is not None and segments is not None
        _contour_img, _ph, lines = self.image_processor.find_lines(
            contour_coords if has_hull else None, img.shape,
            segments if has_hull else None, self.config, backup_img=img,
            border_buffer=self.config.border_buffer)

        mean_roundness = analysis_results.get('mean_roundness', 0)
        is_good_frame = (is_perpendicular and
                         analysis_results.get('total_circles', 0) >= self.config.min_circles_required)

        return {
            'img': img,
            'edge': edge,
            'circles': circles,
            'lines': lines,
            'contour_coords': contour_coords,
            'segments': segments,
            'ellipse_params': ellipse_params,
            'fitting_methods': fitting_methods,
            'is_perpendicular': is_perpendicular,
            'analysis_results': analysis_results,
            'is_good_frame': is_good_frame,
            'mean_roundness': mean_roundness
        }

    def update_tracking(self, frame_number: int, detection_results: Dict,
                        motor_data: MotorPosition) -> Dict:
        """Update well tracking, motor calibration and the phi suggestion.

        Args:
            frame_number: Index of the frame, forwarded to the well-center tracker.
            detection_results: Output of process_frame_detection().
            motor_data: Motor position for this frame.

        Returns:
            Dict with ``tracked_circles``, ``well_ids``, ``phi_suggestion``,
            ``motor_predicted_wells`` and ``motor_calibration_info``. When well
            tracking is disabled the defaults are returned unchanged and the
            frame buffer is NOT updated.
        """
        tracking_results = {
            'tracked_circles': None,
            'well_ids': None,
            'phi_suggestion': 'Not Tracking',
            'motor_predicted_wells': {},
            'motor_calibration_info': {}
        }

        if not self.config.enable_well_tracking or not self.well_tracker:
            return tracking_results

        tracked_circles, well_ids = self.well_tracker.update_tracks(detection_results['circles'])
        tracking_results['tracked_circles'] = tracked_circles
        tracking_results['well_ids'] = well_ids

        if self.motor_calibration and self.well_tracker.detected_wells:
            self.motor_calibration.add_observation(motor_data, self.well_tracker.detected_wells)

            # Predictions are only meaningful for the current frame: clear any left
            # over from an earlier frame so stale positions are never reported.
            self.motor_predicted_wells = {}
            if self.prev_motor_data is not None and self.prev_detected_wells:
                self.motor_predicted_wells = self.motor_calibration.predict_well_positions(
                    self.prev_motor_data, motor_data, self.prev_detected_wells
                )

                if self.motor_predicted_wells:
                    self.motor_calibration.update_prediction_error(
                        self.motor_predicted_wells, self.well_tracker.detected_wells
                    )

            tracking_results['motor_predicted_wells'] = self.motor_predicted_wells
            tracking_results['motor_calibration_info'] = self.motor_calibration.get_calibration_info()

        if self.well_center_tracker and self.well_tracker.detected_wells:
            line_params = self.well_tracker.get_current_line_params()
            well_spacing = self.well_tracker.get_current_spacing()

            self.well_center_tracker.update(frame_number,
                                            self.well_tracker.detected_wells,
                                            motor_data,
                                            line_params,
                                            well_spacing)

        self.frame_buffer.add_frame(frame_number, motor_data.phi,
                                    detection_results['mean_roundness'],
                                    motor_data, detection_results['is_good_frame'])
        tracking_results['phi_suggestion'] = self.frame_buffer.get_phi_suggestion(
            detection_results['is_good_frame'])

        # Remember this frame for the next motor-based prediction.
        # NOTE(review): .copy() is shallow — if the tracker mutates per-well values
        # in place, prev_detected_wells will see those changes; confirm.
        self.prev_motor_data = motor_data
        self.prev_detected_wells = self.well_tracker.detected_wells.copy()

        return tracking_results

    @staticmethod
    def _draw_hull_and_lines(ax, results: Dict) -> None:
        """Overlay the convex-hull segments/vertices and the detected lines on ``ax``."""
        contour_coords = results.get('contour_coords')
        segments = results.get('segments')
        if contour_coords is not None and segments is not None:
            for i, segment in enumerate(segments):
                if len(segment) > 1:
                    ax.plot(segment[:, 0], segment[:, 1], 'r-', linewidth=3,
                            alpha=0.8, label='Convex Hull Boundary' if i == 0 else "")
            ax.scatter(contour_coords[:, 0], contour_coords[:, 1], c='red', s=30,
                       alpha=0.8, zorder=5, label='Hull Vertices')

        for i, (xline, yline) in enumerate(results['lines']):
            ax.plot(xline, yline, 'cyan', linewidth=2, alpha=0.9, linestyle='--',
                    label='Detected Lines' if i == 0 else None)

    def create_visualization(self, frame_number: int, results: Dict,
                             motor_data: MotorPosition) -> plt.Figure:
        """Build the 4-panel figure for one frame.

        The panels are:
            1. The original image.
            2. The Canny edges.
            3. The ellipse analysis: hull, lines, Hough circles and fitted ellipses.
            4. The well tracking view: tracking overlay plus boxes for motor
               position, phi suggestion and motor calibration.

        Args:
            frame_number: Index shown in the tracking panel title.
            results: Merged output of process_frame_detection() and update_tracking().
            motor_data: Motor position shown in the tracking panel.

        Returns:
            The figure. The caller is responsible for closing it.
        """
        analysis = results['analysis_results']

        fig = plt.figure(figsize=(24, 24))
        gs = GridSpec(3, 2, figure=fig, height_ratios=[0.8, 1.5, 2], hspace=0.3)

        axes = [fig.add_subplot(gs[0, 0]),
                fig.add_subplot(gs[0, 1]),
                fig.add_subplot(gs[1, :]),
                fig.add_subplot(gs[2, :])]

        # Subplot 1: Original Image
        axes[0].imshow(results['img'], cmap='gray', aspect='equal')
        axes[0].set_title('Original Image', fontsize=14)
        axes[0].axis('off')

        # Subplot 2: Canny Edges
        axes[1].imshow(results['edge'], cmap='gray', aspect='equal')
        axes[1].set_title('Canny Edges', fontsize=14)
        axes[1].axis('off')

        # Subplot 3: Ellipse Analysis
        axes[2].imshow(results['img'], cmap='gray', aspect='equal')
        self._draw_hull_and_lines(axes[2], results)

        _, cx, cy, radii = results['circles']
        for i, (x, y, r) in enumerate(zip(cx, cy, radii)):
            circle = plt.Circle((x, y), r, ec='cyan', fc='none', ls='--', alpha=0.8, lw=3,
                                label='Hough Circles' if i == 0 else None)
            axes[2].add_patch(circle)

        if results['ellipse_params']:
            self.visualizer.add_ellipse_plots(axes[2],
                                              results['ellipse_params'],
                                              analysis['roundness_scores'],
                                              self.config,
                                              results['fitting_methods'])

        legend_elements = [
            Line2D([0], [0], color='red', lw=3, label='Convex Hull'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=8, lw=0, label='Hull Vertices'),
            Line2D([0], [0], color='cyan', lw=2, ls='--', label='Detected Lines'),
            Line2D([0], [0], color='cyan', lw=3, ls='--', label='Hough Circles'),
            Line2D([0], [0], color='lime', lw=3, ls='-', label=f'Round (≥{self.config.roundness_threshold:.2f})'),
            Line2D([0], [0], color='orange', lw=3, ls='-', label=f'Not Round (<{self.config.roundness_threshold:.2f})'),
            Line2D([0], [0], color='gray', lw=3, ls='-', label='RANSAC Fit'),
            Line2D([0], [0], color='gray', lw=3, ls='--', label='Fallback Fit')
        ]
        axes[2].legend(handles=legend_elements, loc='upper right', fontsize=9,
                       framealpha=0.9, edgecolor='white')

        # Same default as process_frame_detection() when the analysis lacks the key.
        axes[2].set_title(f"Ellipse Analysis | Roundness: {analysis.get('mean_roundness', 0):.3f}")
        axes[2].axis('off')

        # Subplot 4: Well Tracking
        axes[3].imshow(results['img'], cmap='gray')
        self._draw_hull_and_lines(axes[3], results)

        # Raw detections, drawn faintly underneath the tracking overlay.
        _, cx, cy, radii = results['circles']
        for x, y, r in zip(cx, cy, radii):
            circle = plt.Circle((x, y), r, ec='blue', fc='none', ls=':', alpha=0.4, lw=2)
            axes[3].add_patch(circle)

        if results.get('tracked_circles') and results.get('well_ids'):
            self.visualizer.add_well_tracking_visualization(
                axes[3],
                results['tracked_circles'],
                results['well_ids'],
                self.well_tracker,
                results.get('motor_predicted_wells')
            )

        title = (f"Frame {frame_number} - φ={motor_data.phi:.1f}° - "
                 f"Round: {analysis.get('round_circles_count', 0)}/"
                 f"{analysis.get('total_circles', 0)}")
        axes[3].set_title(title)

        # Motor position box
        motor_text = (f"Motor Positions\n"
                      f"X: {motor_data.x:.3f}\n"
                      f"Y: {motor_data.y:.3f}\n"
                      f"Z: {motor_data.z:.3f}\n"
                      f"φ: {motor_data.phi:.3f}°")
        axes[3].text(0.02, 0.82, motor_text,
                     transform=axes[3].transAxes, fontsize=11, fontweight='bold',
                     verticalalignment='top', horizontalalignment='left',
                     bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.9,
                               edgecolor='darkblue', linewidth=2))

        # Phi suggestion box
        phi_suggestion = results.get('phi_suggestion', 'Not Tracking')
        suggestion_color = 'lime' if phi_suggestion == "Good Frame" else 'yellow'
        axes[3].text(0.02, 0.98, f"Suggestion: {phi_suggestion}",
                     transform=axes[3].transAxes, fontsize=14, fontweight='bold',
                     verticalalignment='top', horizontalalignment='left',
                     bbox=dict(boxstyle='round', facecolor=suggestion_color, alpha=0.8))

        # Motor calibration box (only when calibration is enabled and reported info).
        calibration_info = results.get('motor_calibration_info', {})
        if calibration_info:
            if calibration_info.get('is_calibrated'):
                # NOTE(review): assumes 'avg_score' and 'last_avg_error' are numeric once
                # is_calibrated is True — a None would raise in the format spec; confirm.
                cal_text = (f"Motor Calibration: {calibration_info['method']}\n"
                            f"Score: {calibration_info['avg_score']:.3f}\n"
                            f"Avg Error: {calibration_info['last_avg_error']:.1f}px")
                cal_color = 'green' if calibration_info['avg_score'] > 0.8 else 'orange'
            else:
                cal_text = (f"Motor Calibration: Learning...\n"
                            f"Samples: {calibration_info['samples_collected']}/{calibration_info['samples_needed']}")
                cal_color = 'gray'

            axes[3].text(0.98, 0.98, cal_text,
                         transform=axes[3].transAxes, fontsize=10,
                         verticalalignment='top', horizontalalignment='right',
                         bbox=dict(boxstyle='round', facecolor=cal_color, alpha=0.6))

        legend_elements = [
            Line2D([0], [0], color='red', lw=3, label='Convex Hull'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=8, lw=0, label='Hull Vertices'),
            Line2D([0], [0], color='blue', lw=2, ls=':', alpha=0.4, label='Original Circles'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='lime',
                   markersize=10, lw=3, markeredgecolor='lime', label='Detected Wells'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='none',
                   markersize=10, lw=2, markeredgecolor='yellow', linestyle=':', label='Predicted (Line)'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='none',
                   markersize=10, lw=2, markeredgecolor='magenta', linestyle='--', label='Predicted (Motor)'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='none',
                   markersize=10, lw=3, markeredgecolor='orange', label='Unassigned'),
            Line2D([0], [0], color='yellow', lw=3, ls='--', alpha=0.5, label='Fitted Line'),
            Line2D([0], [0], color='cyan', lw=2, ls='--', alpha=0.9, label='Detected Lines')
        ]
        axes[3].legend(handles=legend_elements, loc='lower right', fontsize=9,
                       framealpha=0.9, edgecolor='white', ncol=2)

        axes[3].axis('off')

        return fig

    def generate_frame_sequence(self) -> List[int]:
        """Generate the frame sequence for the configured ``loop_count``.

        Pass 0 runs forward over ``range(min_frame, max_frame)``. Each further
        pass alternates direction: odd passes run in reverse, even passes run
        forward. Consecutive passes share their turning frame, and it is
        emitted only once.

        For example, min_frame=0, max_frame=4, loop_count=2 gives
        ``[0, 1, 2, 3, 2, 1, 0, 1, 2, 3]``.
        """
        forward_frames = list(range(self.config.min_frame, self.config.max_frame))
        # Each later pass drops its first element: that is the frame the
        # previous pass just ended on (max_frame-1 for reverse, min_frame for forward).
        reverse_pass = forward_frames[::-1][1:]
        forward_pass = forward_frames[1:]

        frames: List[int] = []
        for pass_index in range(self.config.loop_count + 1):
            if pass_index == 0:
                frames.extend(forward_frames)
            elif pass_index % 2 == 1:
                frames.extend(reverse_pass)
            else:
                frames.extend(forward_pass)

        return frames

    @staticmethod
    def _figure_to_rgb(fig) -> np.ndarray:
        """Render ``fig`` and return its pixels as an (H, W, 3) uint8 array."""
        fig.canvas.draw()
        # buffer_rgba() is already shaped (H, W, 4) in physical pixels; reshaping via
        # get_width_height() (logical pixels) breaks on HiDPI canvases.
        return np.asarray(fig.canvas.buffer_rgba())[:, :, :3]

    def _print_frame_status(self, frame_number: int, progress_str: str, direction_str: str,
                            motor_data: MotorPosition, results: Dict) -> None:
        """Print the per-frame console report for a tracked frame."""
        print(f"Processing Frame {frame_number} {progress_str} {direction_str} - φ={motor_data.phi:.1f}° - "
              f"Roundness: {results['mean_roundness']:.3f} - "
              f"Suggestion: {results['phi_suggestion']}")

        print(f"Motor: X={motor_data.x:.3f}, Y={motor_data.y:.3f}, Z={motor_data.z:.3f}")

        if self.well_tracker and self.well_tracker.detected_wells:
            # well_ids may be None (update_tracking's default), even when wells
            # are still known from earlier frames.
            well_ids = results.get('well_ids') or []
            assigned_count = sum(1 for well_id in well_ids if well_id)
            print(f"Wells: {assigned_count}/{self.config.total_wells} Detected")

            if self.well_tracker.well_spacing:
                print(f"Well Spacing: {self.well_tracker.well_spacing:.1f}px")

        cal_info = results.get('motor_calibration_info', {})
        if cal_info.get('is_calibrated'):
            print(f"Motor Calibration: Active (Score: {cal_info['avg_score']:.3f}, "
                  f"Error: {cal_info['last_avg_error']:.1f}px)")
        elif cal_info.get('samples_collected') is not None:
            print(f"Motor Calibration: Learning ({cal_info.get('samples_collected', 0)}/"
                  f"{cal_info.get('samples_needed', 10)} Samples)")

        if len(self.frame_buffer.data) > 1:
            phi_trend = self.frame_buffer.get_phi_trend()
            roundness_trend = self.frame_buffer.get_roundness_trend()
            print(f"Trends: φ {phi_trend}, Roundness {roundness_trend}")

    def run(self):
        """Main processing loop.

        Walks generate_frame_sequence() and skips frames that:
            * fail to load;
            * have phi outside ``[phi_min, phi_max]``;
            * are not perpendicular, when ``use_perpendicular_filter`` is set.

        Each remaining frame is tracked, reported on stdout and rendered. The
        figure is then optionally saved as a PNG, appended to the video and
        displayed.

        The ``finally`` clause handles shutdown even if processing is
        interrupted with Ctrl-C: it closes the video writer, saves the
        well-center data and logs a summary.
        """
        try:
            from IPython.display import clear_output, display
            use_ipython = True
        except ImportError:
            # Outside IPython the terminal is cleared via the module-level `os`.
            use_ipython = False

        writer = None
        video_path = None

        if self.config.save_video:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = f"well_tracking_{timestamp}.mp4"
            video_path = Path(self.config.output_dir) / video_filename
            writer = imageio.get_writer(video_path, fps=self.config.video_fps)
            logger.info(f"Recording Video To: {video_path}")

        frame_sequence = self.generate_frame_sequence()
        total_frames_to_process = len(frame_sequence)
        looping = self.config.loop_count > 0

        if looping:
            logger.info(f"Loop Mode: {self.config.loop_count} {'Reversal' if self.config.loop_count == 1 else 'Reversals'}")
            logger.info(f"Total Frames To Process: {total_frames_to_process}")

        try:
            current_direction = "Forward"
            loop_iteration = 0
            prev_frame = None

            for frame_index, frame_number in enumerate(frame_sequence, start=1):
                # A change in the sign of the frame step marks a new pass.
                if prev_frame is not None:
                    if frame_number < prev_frame and current_direction == "Forward":
                        current_direction = "Reverse"
                        loop_iteration += 1
                        logger.info(f"\n--- Starting Reverse Pass (Loop {loop_iteration}) ---\n")
                    elif frame_number > prev_frame and current_direction == "Reverse":
                        current_direction = "Forward"
                        loop_iteration += 1
                        logger.info(f"\n--- Starting Forward Pass (Loop {loop_iteration}) ---\n")
                prev_frame = frame_number

                if self.config.display_frames:
                    if use_ipython:
                        clear_output(wait=True)
                    else:
                        os.system('cls' if os.name == 'nt' else 'clear')

                motor_data, img = self.load_frame_data(frame_number)
                if motor_data is None or img is None:
                    continue

                if not (self.config.phi_min <= motor_data.phi <= self.config.phi_max):
                    continue

                progress_str = f"[{frame_index}/{total_frames_to_process}]" if looping else ""
                direction_str = f"({current_direction})" if looping else ""

                detection_results = self.process_frame_detection(frame_number, img, motor_data)

                if self.config.use_perpendicular_filter and not detection_results['is_perpendicular']:
                    self.frames_skipped += 1
                    print(f"Skipped Frame {frame_number} {progress_str} {direction_str} - Not Perpendicular - "
                          f"φ={motor_data.phi:.1f}° - "
                          f"Roundness: {detection_results['mean_roundness']:.3f}")
                    continue

                tracking_results = self.update_tracking(frame_number, detection_results, motor_data)
                results = {**detection_results, **tracking_results}

                self.frames_processed += 1
                if results['tracked_circles']:
                    self.frames_with_tracking += 1

                self._print_frame_status(frame_number, progress_str, direction_str,
                                         motor_data, results)

                fig = self.create_visualization(frame_number, results, motor_data)
                try:
                    if looping:
                        fig.suptitle(f"Loop Mode: Pass {loop_iteration + 1}/{self.config.loop_count + 1} "
                                     f"({current_direction}) - Frame {frame_index}/{total_frames_to_process}",
                                     fontsize=16, y=0.99)

                    if self.config.save_individual_frames:
                        if looping:
                            frame_path = Path(self.config.output_images_dir) / f"frame_{frame_number}_loop{loop_iteration}_{current_direction.lower()}.png"
                        else:
                            frame_path = Path(self.config.output_images_dir) / f"frame_{frame_number}.png"
                        fig.savefig(frame_path, dpi=150, bbox_inches='tight')

                    if writer:
                        writer.append_data(self._figure_to_rgb(fig))

                    if self.config.display_frames:
                        if use_ipython:
                            display(fig)
                        else:
                            plt.show()
                finally:
                    # Always release the figure, even if saving/encoding fails.
                    plt.close(fig)

        except KeyboardInterrupt:
            logger.info("Processing Interrupted By User")
        finally:
            if writer:
                writer.close()
                logger.info(f"Video Saved: {video_path} ({self.frames_processed} Frames)")

            if self.well_center_tracker:
                saved_path = self.well_center_tracker.save_to_json(motor_calibration=self.motor_calibration)
                logger.info(f"Well Center Tracking Saved To: {saved_path}")

            logger.info("\nProcessing Summary:")
            logger.info(f"Frames Processed (Perpendicular): {self.frames_processed}")
            logger.info(f"Frames Skipped (Non-Perpendicular): {self.frames_skipped}")
            logger.info(f"Frames With Tracking: {self.frames_with_tracking}")
            if self.config.save_individual_frames:
                logger.info(f"Individual Frames Saved To: {self.config.output_images_dir}/")


def main():
    """Entry point: run the well tracking pipeline with the default Config."""
    WellTrackingSystem(Config()).run()


if __name__ == "__main__":
    main()