In [None]:
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
import cv2
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import re
import json
import torch
from tqdm.auto import tqdm
import warnings
import urllib.request
import gc
from datetime import datetime
import seaborn as sns

# Shapely 2.0: no longer supports 'from shapely import hausdorff'
# only import the geometry classes and ops we need directly.
from shapely.geometry import Polygon, LineString
from shapely.ops import unary_union  # Possibly used in your code
# For older Shapely fallback, do a local import in the try/except below.

# skimage metrics for SSIM, MSE, PSNR
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio

# KL-divergence from scipy
from scipy.stats import entropy as kl_divergence

warnings.filterwarnings('ignore')

################################################################################
# Memory Management / Global Config
################################################################################

if torch.cuda.is_available():
    # The caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
    # initialised, so the variable must be exported before any call that
    # brings up the CUDA context (set_per_process_memory_fraction does).
    # The former `torch.backends.cuda.max_split_size_mb = 128` assignment was
    # dropped: that attribute is not a PyTorch setting and had no effect; the
    # split size is configured through the environment variable below.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True'
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.7)  # Use 70% of available memory

def cleanup_memory():
    """Free unreachable Python objects, then release cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed before ``torch.cuda.empty_cache()`` is
    asked to hand the allocator's unused cached blocks back to the driver.
    In the opposite order those blocks would stay cached until the next call.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Maps a run label to the directory that holds that run's test images.
# generate_output_dir() uses the text before the first '-' of the FIRST key
# as the prefix of the output folder name.
# NOTE(review): labels look like "<test-city>-<training-city>-model" -- confirm.
TEST_FOLDERS = {
    'brooklyn-boston-model': "../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-charlotte-model': "../data/ny-brooklyn/nc-charlotte-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-manhattan-model': "../data/ny-brooklyn/ny-manhattan-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-pittsburgh-model': "../data/ny-brooklyn/pa-pittsburgh-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
}

# Sample size recorded in the output folder name ("..._n<SAMPLE_SIZE>_...").
# NOTE(review): presumably the number of image pairs drawn per folder -- the
# sampling code is not in this part of the file.
SAMPLE_SIZE = 5  # reduced from 1000 to 5 for quick test runs

def generate_output_dir(test_folders, sample_size):
    """Compose a timestamped results path under ``benchmark-output``.

    The folder name is ``<prefix>_n<sample_size>_<YYYYmmdd_HHMM>`` where
    ``<prefix>`` is the part of the first key of ``test_folders`` that
    precedes its first ``'-'``.

    Args:
        test_folders: Mapping whose first key supplies the name prefix.
        sample_size: Value embedded after ``_n`` in the folder name.

    Returns:
        str: ``benchmark-output/<folder name>`` joined with ``os.path.join``.

    Raises:
        IndexError: If ``test_folders`` is empty.
    """
    labels = list(test_folders)
    prefix = labels[0].split('-')[0]
    stamp = datetime.now().strftime("%Y%m%d_%H%M")
    return os.path.join("benchmark-output", f"{prefix}_n{sample_size}_{stamp}")

# Results directory for this run. Evaluated once when this line executes, so
# the timestamp in the name is fixed for the rest of the session.
OUTPUT_DIR = generate_output_dir(TEST_FOLDERS, SAMPLE_SIZE)

def download_sam_checkpoint():
    """Return the local path of the SAM ViT-H checkpoint, downloading it if absent.

    The file is fetched into a temporary ``.part`` file and renamed to its
    final name only after the transfer has finished. An interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete checkpoint.

    Returns:
        str: Path of the checkpoint, relative to the current working directory.

    Raises:
        OSError: If the download or the final rename fails (``URLError`` and
            ``ContentTooShortError`` from ``urllib`` are subclasses).
    """
    checkpoint_path = "sam_vit_h_4b8939.pth"
    if os.path.exists(checkpoint_path):
        return checkpoint_path

    print("Downloading SAM checkpoint...")
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    partial_path = checkpoint_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, checkpoint_path)
    except BaseException:
        # Also covers Ctrl-C: drop the incomplete file, then re-raise.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    print("Download complete!")
    return checkpoint_path

# Path of the SAM checkpoint. NOTE: this runs when the line executes and may
# download the checkpoint if it is not already in the working directory.
SAM_CHECKPOINT = download_sam_checkpoint()

def get_device():
    """Pick the torch device name: ``"cuda"``, else ``"mps"``, else ``"cpu"``.

    The MPS backend is probed only when ``torch.backends`` exposes an ``mps``
    attribute, so torch builds without it fall through to the CPU.
    """
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Device name chosen once at import time; passed to the detector later on.
DEVICE = get_device()
print(f"Using device: {DEVICE}")

################################################################################
# Configuration Parameters
################################################################################

# Parcel/building detection settings.
# Read in the visible code: 'min_area' (minimum contour and polygon area kept
# by process_single_image) and 'max_image_size' (cap used by preprocess_image
# to decide whether to downscale). The other keys are not referenced in the
# visible part of this file.
BUILDING_DETECTION = {
    'confidence_threshold': 0.5,
    'min_area': 50,
    'grid_points': 4,
    'batch_size': 1,
    'max_image_size': 512
}

# Grid-line detection settings.
# Read in the visible code: 'min_line_length' (segments must be longer than
# this to be kept by detect_grid_pattern). The other keys are not referenced
# in the visible part of this file.
GRID_DETECTION = {
    'min_line_length': 50,
    'angle_tolerance': 5,
    'grid_spacing': 85,
}

# Block detection settings -- not referenced in the visible part of this file.
BLOCK_DETECTION = {
    'clustering_distance': 50,
    'min_block_size': 3,
}

# Geometry settings -- not referenced in the visible part of this file.
GEOMETRY_PARAMS = {
    'smoothness_window': 3,
    'corner_angle_tolerance': 5,
    'parallel_edge_tolerance': 5,
}

# Plot settings consumed by MemoryEfficientVisualizer and by the benchmarker
# when saving figures ('dpi'). 'heatmap_cmap' is not referenced in the visible
# part of this file.
VIZ_PARAMS = {
    'figure_size': (20, 15),
    'dpi': 150,
    'parcel_color': 'red',
    'grid_line_color': 'yellow',
    'line_width': 1,
    'heatmap_cmap': 'hot',
}

# Entropy settings consumed by EntropyCalculator.
# Read in the visible code: 'cell_size' (side length, in pixels, of one cell
# of the spatial occupancy grid). The other keys are not referenced in the
# visible part of this file.
ENTROPY_PARAMS = {
    'cell_size': 32,
    'angle_bins': 36,
    'neighbor_threshold': 10
}

################################################################################
# Detector / Entropy
################################################################################

class EntropyCalculator:
    """Shannon-entropy descriptors for a collection of parcel polygons.

    Three entropies (in bits) are produced by ``compute_all_entropy_metrics``:
    spatial (how parcel coverage is spread over a coarse pixel grid), size
    (histogram of parcel areas) and complexity (histogram of exterior vertex
    counts), plus their arithmetic mean.

    Parcels are expected to expose ``.exterior.coords`` and ``.area`` the way
    Shapely polygons do.
    """

    def __init__(self, params=None):
        """Store the parameter dict.

        Args:
            params: Dict providing at least ``'cell_size'``. When omitted the
                module-level ``ENTROPY_PARAMS`` is used. The default is
                resolved here rather than bound in the signature so the
                shared dict is not captured as a mutable default argument.
        """
        self.params = ENTROPY_PARAMS if params is None else params

    def compute_shannon_entropy(self, probabilities):
        """Return ``-sum(p * log2(p))`` over the strictly positive entries.

        Args:
            probabilities: Sequence of non-negative weights, expected to sum
                to 1. Zero entries are ignored.

        Returns:
            Entropy in bits; ``0.0`` for empty input.
        """
        probs = np.asarray(probabilities, dtype=float)
        probs = probs[probs > 0]
        # "+ 0.0" turns the negative zero produced for degenerate input
        # (empty, or a single probability of 1) into a plain 0.0.
        return -np.sum(probs * np.log2(probs)) + 0.0

    def create_spatial_grid(self, parcels, shape):
        """Count, per grid cell, how many parcels cover at least one pixel.

        The image area is divided into ``cell_size`` x ``cell_size`` pixel
        cells; trailing pixels that do not fill a whole cell are ignored.

        Args:
            parcels: Iterable of polygons in pixel coordinates.
            shape: ``(height, width)`` of the raster; extra entries ignored.

        Returns:
            Float array of shape ``(height // cell_size, width // cell_size)``.
        """
        cell_size = self.params['cell_size']
        height, width = shape[0], shape[1]
        grid_h, grid_w = height // cell_size, width // cell_size
        grid = np.zeros((grid_h, grid_w))
        if grid_h == 0 or grid_w == 0:
            return grid

        covered_h, covered_w = grid_h * cell_size, grid_w * cell_size
        # One scratch raster is reused for all parcels instead of allocating
        # a fresh mask per parcel.
        mask = np.zeros((height, width), dtype=np.uint8)
        for parcel in parcels:
            try:
                coords = np.array(parcel.exterior.coords).astype(np.int32)
                mask.fill(0)
                cv2.fillPoly(mask, [coords], 1)
            except Exception:
                # Best effort: a parcel that cannot be rasterised is skipped.
                continue
            # View the covered region as (grid_h, cell, grid_w, cell) blocks
            # and reduce each block with "any pixel set".
            cells = mask[:covered_h, :covered_w].reshape(
                grid_h, cell_size, grid_w, cell_size
            )
            grid += cells.any(axis=(1, 3))

        return grid

    def _grid_entropy(self, grid):
        """Entropy of the occupancy counts, normalised so they sum to 1."""
        total = grid.sum()
        if total <= 0:
            return 0
        # Normalise by the total count, not by the number of occupied cells:
        # otherwise the weights exceed a probability distribution as soon as
        # one cell is touched by more than one parcel.
        return self.compute_shannon_entropy(grid[grid > 0] / total)

    def _histogram_entropy(self, values):
        """Entropy of the 'auto'-binned histogram of ``values`` (0 if empty)."""
        if not values:
            return 0
        counts, _ = np.histogram(values, bins='auto')
        occupied = counts[counts > 0]
        return self.compute_shannon_entropy(occupied / occupied.sum())

    def compute_all_entropy_metrics(self, parcels, shape):
        """Compute spatial, size and complexity entropies plus their mean.

        Args:
            parcels: List of polygons in pixel coordinates.
            shape: ``(height, width)`` used for the spatial grid.

        Returns:
            Dict with keys ``'spatial_entropy'``, ``'size_entropy'``,
            ``'complexity_entropy'`` and ``'total_entropy'``. On any error the
            message is printed and all four values are 0.
        """
        try:
            grid = self.create_spatial_grid(parcels, shape)
            spatial = self._grid_entropy(grid)
            del grid

            size = self._histogram_entropy([p.area for p in parcels])

            # Complexity = number of exterior coordinates of each parcel.
            complexity = self._histogram_entropy(
                [len(p.exterior.coords) for p in parcels]
            )

            return {
                'spatial_entropy': spatial,
                'size_entropy': size,
                'complexity_entropy': complexity,
                'total_entropy': np.mean([spatial, size, complexity]),
            }
        except Exception as e:
            print(f"Error computing entropy metrics: {str(e)}")
            return {
                'spatial_entropy': 0,
                'size_entropy': 0,
                'complexity_entropy': 0,
                'total_entropy': 0
            }

class MemoryEfficientDetector:
    """Parcel detector that holds a SAM model and runs contour-based detection.

    ``initialize_model`` loads a Segment Anything ViT-H checkpoint and wraps
    it in a ``SamPredictor``. The detection path in this class
    (``process_single_image``) does not query the predictor: parcels come
    from Canny edges followed by contour polygon approximation, and grid
    lines from OpenCV's line segment detector.
    """

    def __init__(self, sam_checkpoint, device='cuda'):
        """Remember the checkpoint path and device; the model loads lazily.

        Args:
            sam_checkpoint: Path of the SAM checkpoint file.
            device: Torch device name the model is moved to.
        """
        self.checkpoint_path = sam_checkpoint
        self.device = device
        self.sam = None
        self.predictor = None
        self.entropy_calc = EntropyCalculator()

    def initialize_model(self):
        """Load the SAM checkpoint and build the predictor.

        Returns:
            bool: True on success. On any failure the error is printed,
            partially created resources are released and False is returned.
        """
        try:
            self.cleanup()

            print("Loading SAM checkpoint...")
            # NOTE: torch.load unpickles the file -- only point this at a
            # checkpoint obtained from a trusted source.
            with open(self.checkpoint_path, "rb") as f:
                state_dict = torch.load(f, map_location='cpu')

            from segment_anything import SamPredictor, sam_model_registry
            self.sam = sam_model_registry["vit_h"](checkpoint=None)
            self.sam.load_state_dict(state_dict)

            if str(self.device).startswith('cpu'):
                # Keep full precision on the CPU; half precision is only
                # requested for accelerator devices.
                self.sam = self.sam.to(device=self.device)
            else:
                self.sam = self.sam.to(device=self.device, dtype=torch.float16)
            self.predictor = SamPredictor(self.sam)

            print("Model initialized successfully")
            return True
        except Exception as e:
            print(f"Error initializing model: {e}")
            self.cleanup()
            return False

    def cleanup(self):
        """Drop the predictor and model references and free memory."""
        predictor = getattr(self, 'predictor', None)
        if predictor is not None:
            predictor.reset_image()
        self.predictor = None
        self.sam = None
        cleanup_memory()

    def preprocess_image(self, image):
        """Downscale oversized images and build a contrast-enhanced gray copy.

        Args:
            image: RGB image array (``height, width, 3``).

        Returns:
            tuple: ``(enhanced, image, scale)`` -- the CLAHE-equalised
            grayscale image, the (possibly resized) colour image, and the
            scale factor that was actually applied (``1.0`` when the image
            already fits within ``max_image_size``). ``(None, None, None)``
            if preprocessing fails.
        """
        try:
            h, w = image.shape[:2]
            max_size = BUILDING_DETECTION['max_image_size']
            scale = min(max_size / h, max_size / w)

            if scale < 1:
                # Guard against a zero-sized target for extreme aspect ratios.
                new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
                image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
            else:
                # Nothing was resized, so report the identity scale rather
                # than the (>= 1) ratio that was never applied.
                scale = 1.0

            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            enhanced = clahe.apply(gray)

            return enhanced, image, scale
        except Exception as e:
            print(f"Error in preprocessing: {e}")
            return None, None, None

    def detect_grid_pattern(self, image):
        """Detect long line segments and the dominant line orientations.

        Args:
            image: Single-channel 8-bit image.

        Returns:
            tuple: ``(lines, peak_angles)``. ``lines`` is a list of
            ``[x1, y1, x2, y2]`` arrays for segments longer than
            ``GRID_DETECTION['min_line_length']``. ``peak_angles`` holds the
            1-degree orientation bins (0-179) whose count exceeds
            mean + one standard deviation. Both are empty lists when nothing
            is found or an error occurs.
        """
        try:
            lsd = cv2.createLineSegmentDetector()
            edges = cv2.Canny(image, 50, 150)
            lines = lsd.detect(edges)[0]

            if lines is None:
                return [], []

            min_length = GRID_DETECTION['min_line_length']
            filtered_lines = []
            angles = []
            for line in lines:
                x1, y1, x2, y2 = map(int, line[0])
                length = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
                if length > min_length:
                    filtered_lines.append(line[0])
                    # Orientation folded into [0, 180): direction is ignored.
                    angles.append(np.degrees(np.arctan2(y2 - y1, x2 - x1)) % 180)

            if angles:
                hist, _ = np.histogram(angles, bins=180, range=(0, 180))
                peak_indices = hist > (np.mean(hist) + np.std(hist))
                peak_angles = np.where(peak_indices)[0]
                return filtered_lines, peak_angles
            return [], []
        except Exception as e:
            print(f"Error in grid detection: {e}")
            return [], []
        finally:
            cleanup_memory()

    def process_single_image(self, image):
        """Find parcels and grid lines in one image and compute entropies.

        Args:
            image: RGB image array.

        Returns:
            dict with keys ``'parcels'`` (valid polygons with 4-8 vertices
            and area >= ``min_area``), ``'grid_lines'``,
            ``'dominant_angles'``, ``'entropy_metrics'``, ``'shape'`` and
            ``'scale'``; or None if preprocessing or processing fails.

            NOTE(review): ``'shape'`` (and the shape handed to the entropy
            grid) is that of the INPUT image, while parcel coordinates are in
            the possibly downscaled image -- confirm this is intended.
        """
        try:
            enhanced, orig_image, scale = self.preprocess_image(image)
            if enhanced is None:
                return None

            # Grid detection
            grid_lines, dom_angles = self.detect_grid_pattern(enhanced)

            # Contours for parcels
            edges = cv2.Canny(enhanced, 50, 150)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            min_area = BUILDING_DETECTION['min_area']
            parcels = []
            for cnt in contours:
                try:
                    if cv2.contourArea(cnt) < min_area:
                        continue
                    epsilon = 0.02 * cv2.arcLength(cnt, True)
                    approx = cv2.approxPolyDP(cnt, epsilon, True)
                    if len(approx) < 4 or len(approx) > 8:
                        continue
                    coords = np.squeeze(approx).reshape(-1, 2)
                    poly = Polygon(coords)
                    if poly.is_valid and poly.area >= min_area:
                        parcels.append(poly)
                except Exception:
                    # Best effort: skip contours that cannot be turned into a
                    # polygon, but let KeyboardInterrupt/SystemExit through.
                    continue

            # Entropy
            entropy_metrics = self.entropy_calc.compute_all_entropy_metrics(parcels, image.shape[:2])

            return {
                'parcels': parcels,
                'grid_lines': grid_lines,
                'dominant_angles': dom_angles,
                'entropy_metrics': entropy_metrics,
                'shape': image.shape[:2],
                'scale': scale
            }
        except Exception as e:
            print(f"Error processing image: {e}")
            return None
        finally:
            cleanup_memory()

    def process_image_pair(self, fake_img, real_img):
        """Run ``process_single_image`` on a (fake, real) image pair.

        Returns:
            tuple: ``(fake_results, real_results)``; either element may be
            None, and ``(None, None)`` is returned on an unexpected error.
        """
        try:
            fake_results = self.process_single_image(fake_img)
            real_results = self.process_single_image(real_img)
            return fake_results, real_results
        except Exception as e:
            print(f"Error in process_image_pair: {e}")
            return None, None

################################################################################
# Visualizer
################################################################################

class MemoryEfficientVisualizer:
    """Builds the per-image 2x2 summary figure (parcels, grid, metrics, sizes)."""

    def __init__(self, viz_params=VIZ_PARAMS):
        """Store plot settings and switch matplotlib to the 'Agg' backend.

        Args:
            viz_params: Dict providing ``'figure_size'``, ``'parcel_color'``,
                ``'grid_line_color'`` and ``'line_width'``.
        """
        self.viz_params = viz_params
        matplotlib.use('Agg')  # Non-interactive backend

    def create_comparison_visualization(self, image, parcels, intermediates, metrics,
                                        max_parcels=50):
        """Create the four-panel summary figure for one image.

        Panels: (1) image with a random sample of parcel outlines, (2) image
        with up to 100 randomly chosen grid lines, (3) text summary of the
        metrics, (4) histogram of parcel areas.

        Args:
            image: Image array shown in the two top panels.
            parcels: List of polygons exposing ``.exterior.xy`` and ``.area``.
            intermediates: Dict; ``'grid_lines'`` (``[x1, y1, x2, y2]``
                entries) is read if present.
            metrics: Dict of metric values; missing keys are shown as 0.
            max_parcels: Maximum number of parcel outlines to draw.

        Returns:
            The matplotlib figure (the caller must close it), or None if
            drawing failed.
        """
        fig = None
        try:
            fig = plt.figure(figsize=self.viz_params['figure_size'])
            gs = GridSpec(2, 2, figure=fig)

            # 1. Original image with detected parcels
            ax1 = fig.add_subplot(gs[0, 0])
            ax1.imshow(image)
            sample_parcels = random.sample(parcels, min(len(parcels), max_parcels))
            for parcel in sample_parcels:
                x, y = parcel.exterior.xy
                # color= accepts any matplotlib colour spec, not only names
                # that happen to parse as a format string.
                ax1.plot(x, y, color=self.viz_params['parcel_color'],
                         linewidth=self.viz_params['line_width'])
            ax1.set_title(f"Parcels (showing {len(sample_parcels)} / {len(parcels)})")

            # 2. Grid lines
            ax2 = fig.add_subplot(gs[0, 1])
            ax2.imshow(image)
            grid_lines = list(intermediates.get('grid_lines', []))
            if len(grid_lines) > 100:
                grid_lines = random.sample(grid_lines, 100)
            for line in grid_lines:
                x1, y1, x2, y2 = map(int, line)
                ax2.plot([x1, x2], [y1, y2], color=self.viz_params['grid_line_color'])
            ax2.set_title("Grid Pattern")

            # 3. Metrics summary
            ax3 = fig.add_subplot(gs[1, 0])
            metrics_text = (
                f"Grid Alignment: {metrics.get('grid_alignment',0):.3f}\n"
                f"Shape Regularity: {metrics.get('regularity_score',0):.3f}\n"
                f"Spatial Entropy: {metrics.get('spatial_entropy',0):.3f}\n"
                f"Size Entropy: {metrics.get('size_entropy',0):.3f}\n"
                f"Complexity Entropy: {metrics.get('complexity_entropy',0):.3f}\n"
                f"Total Entropy: {metrics.get('total_entropy',0):.3f}\n"
            )
            ax3.text(0.5, 0.5, metrics_text, ha='center', va='center', transform=ax3.transAxes)
            ax3.axis('off')

            # 4. Parcel Size Distribution
            ax4 = fig.add_subplot(gs[1, 1])
            areas = [p.area for p in parcels]
            if areas:
                # Nothing to bin (and no density to estimate) without parcels.
                sns.histplot(areas, ax=ax4, bins=30, kde=True)
            ax4.set_title("Parcel Size Distribution")

            plt.tight_layout()
            return fig
        except Exception as e:
            print(f"Error creating visualization: {e}")
            # Close only the figure created here, not every open figure.
            if fig is not None:
                plt.close(fig)
            return None
        finally:
            cleanup_memory()

################################################################################
# Benchmarker - Additional Metrics
################################################################################

class MemoryEfficientBenchmarker:
    """Compares generated ("fake") and real images pair by pair.

    Owns one ``MemoryEfficientDetector`` and one
    ``MemoryEfficientVisualizer``; ``process_single_pair_with_existing_detector``
    is the per-pair entry point and ``calculate_metrics`` builds the metric
    dict.
    """

    def __init__(self):
        self.detector = None
        self.visualizer = MemoryEfficientVisualizer()

    def initialize_detector(self):
        """(Re)create the detector and load its model.

        Returns:
            bool: Result of ``initialize_model()``; False if anything raises.
        """
        try:
            # Releases a previous detector, leaving self.detector set to None
            # so the attribute always exists even if construction fails.
            self.cleanup()
            self.detector = MemoryEfficientDetector(SAM_CHECKPOINT, device=DEVICE)
            return self.detector.initialize_model()
        except Exception as e:
            print(f"Error initializing detector: {e}")
            return False

    def cleanup(self):
        """Release the detector (if any) and free memory."""
        if self.detector is not None:
            self.detector.cleanup()
            self.detector = None
        cleanup_memory()

    def process_single_pair_with_existing_detector(self, fake_path, real_path, output_dir=None):
        """Load one (fake, real) image pair, compute metrics, optionally plot.

        Args:
            fake_path: Path of the generated image. The file id is the second
                ``'_'``-separated token of its base name (for example
                ``combined_123_fake_B.png`` -> ``"123"``); names without
                ``'_'`` fall back to the base name without extension.
            real_path: Path of the reference image.
            output_dir: If given, ``comparison_<file_id>.png`` is saved there
                (the directory is created when missing).

        Returns:
            dict of metrics plus ``'file_id'``, ``'fake_path'`` and
            ``'real_path'``; None if loading or detection fails.
        """
        try:
            fake_img = cv2.imread(fake_path)
            real_img = cv2.imread(real_path)
            if fake_img is None or real_img is None:
                print(f"Error loading images: {fake_path}, {real_path}")
                return None

            # Convert to RGB
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB)
            real_img = cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB)

            fake_results, real_results = self.detector.process_image_pair(fake_img, real_img)
            if not (fake_results and real_results):
                return None

            metrics = self.calculate_metrics(fake_results, real_results, fake_img, real_img)

            # Identify the file/pair
            base_name = os.path.basename(fake_path)
            tokens = base_name.split('_')
            metrics['file_id'] = tokens[1] if len(tokens) > 1 else os.path.splitext(base_name)[0]
            metrics['fake_path'] = fake_path
            metrics['real_path'] = real_path

            if output_dir:
                self._save_comparison(fake_img, fake_results, metrics, output_dir)

            return metrics
        except Exception as e:
            print(f"Error processing pair: {e}")
            return None

    def _save_comparison(self, fake_img, fake_results, metrics, output_dir):
        """Render and save the summary figure; failures never discard metrics."""
        fig = self.visualizer.create_comparison_visualization(
            fake_img, fake_results['parcels'], fake_results, metrics
        )
        if fig is None:
            return
        try:
            os.makedirs(output_dir, exist_ok=True)
            fig.savefig(
                os.path.join(output_dir, f"comparison_{metrics['file_id']}.png"),
                dpi=self.visualizer.viz_params['dpi'],
                bbox_inches='tight'
            )
        except Exception as e:
            print(f"Error saving visualization: {e}")
        finally:
            plt.close(fig)

    def calculate_metrics(self, fake_results, real_results, fake_img, real_img):
        """
        Calculate comparison metrics, including:
          - parcel count ratio, grid alignment, shape regularity
          - the fake image's entropy metrics and their relative difference
            to the real image's ("relative_<key>")
          - IoU / Dice / Hausdorff over greedily matched parcels
          - mean circularity of both parcel sets
          - MSE / PSNR / SSIM (raster-based)
          - KL divergence between the parcel-area histograms

        Args:
            fake_results / real_results: Dicts from
                ``MemoryEfficientDetector.process_single_image``.
            fake_img / real_img: RGB uint8 image arrays.

        Returns:
            dict of metrics. If a step raises, the error is printed and the
            metrics gathered up to that point are returned.
        """
        metrics = {}
        try:
            fake_parcels = fake_results['parcels']
            real_parcels = real_results['parcels']

            # Basic ratio: #fake parcels / #real parcels
            if fake_parcels and real_parcels:
                metrics['parcel_count_ratio'] = (
                    len(fake_parcels) / max(len(real_parcels), 1)
                )
            else:
                metrics['parcel_count_ratio'] = 0

            if fake_parcels and fake_results.get('grid_lines'):
                metrics['grid_alignment'] = self.compute_grid_alignment(
                    fake_parcels, fake_results['grid_lines']
                )
            else:
                metrics['grid_alignment'] = 0

            if fake_parcels:
                metrics['regularity_score'] = self.compute_shape_regularity(fake_parcels)
            else:
                metrics['regularity_score'] = 0

            # Fake entropies, then their relative deviation from the real ones.
            fake_entropy = fake_results.get('entropy_metrics') or {}
            real_entropy = real_results.get('entropy_metrics') or {}
            metrics.update(fake_entropy)
            for key, val in fake_entropy.items():
                if key in real_entropy:
                    real_val = real_entropy[key]
                    metrics[f"relative_{key}"] = abs(val - real_val) / max(real_val, 1e-6)

            # 1. IoU & Dice & Hausdorff distances over matched parcels
            iou_scores, dice_scores, hausdorff_dists = self._match_parcels(
                fake_parcels, real_parcels
            )
            metrics['mean_iou'] = np.mean(iou_scores) if iou_scores else 0
            metrics['mean_dice'] = np.mean(dice_scores) if dice_scores else 0
            metrics['mean_hausdorff'] = np.mean(hausdorff_dists) if hausdorff_dists else 0

            # 2. Shape descriptors (circularity)
            fake_circ = self._mean_circularity(fake_parcels)
            real_circ = self._mean_circularity(real_parcels)
            metrics['circularity_fake'] = fake_circ
            metrics['circularity_real'] = real_circ
            metrics['circularity_diff'] = abs(fake_circ - real_circ)

            # 3. MSE / PSNR / SSIM (raster-based)
            metrics.update(self._raster_metrics(fake_img, real_img))

            # 4. Distribution overlap (KL) on area
            metrics['kl_div_area'] = self._area_kl_divergence(fake_parcels, real_parcels)

        except Exception as e:
            print(f"Error in calculate_metrics: {e}")
        return metrics

    def _match_parcels(self, fake_parcels, real_parcels):
        """Greedily pair each fake parcel with its best-IoU unused real parcel.

        Fake parcels are visited in order; each takes the not-yet-used real
        parcel with the highest IoU, provided that IoU is positive.

        Returns:
            tuple: three parallel lists ``(ious, dices, hausdorff_distances)``
            with one entry per matched pair.
        """
        valid_real = [
            (idx, rpar, rpar.area)
            for idx, rpar in enumerate(real_parcels)
            if rpar.is_valid
        ]
        used_real_indices = set()
        iou_scores, dice_scores, hausdorff_dists = [], [], []

        for fpar in fake_parcels:
            if not fpar.is_valid:
                continue
            fake_area = fpar.area
            best_iou = 0
            best = None
            for idx, rpar, real_area in valid_real:
                if idx in used_real_indices:
                    continue
                inter_area = fpar.intersection(rpar).area
                # |A u B| = |A| + |B| - |A n B| avoids a second overlay op.
                union_area = fake_area + real_area - inter_area
                iou = inter_area / union_area if union_area > 0 else 0
                if iou > best_iou:
                    best_iou = iou
                    best = (idx, rpar, real_area, inter_area)
            if best is None:
                continue

            idx, rpar, real_area, inter_area = best
            used_real_indices.add(idx)

            area_sum = fake_area + real_area
            dice_val = (2.0 * inter_area / area_sum) if area_sum > 0 else 0

            iou_scores.append(best_iou)
            dice_scores.append(dice_val)
            # Geometry method; the shapely.ops import formerly used as a
            # fallback does not exist.
            hausdorff_dists.append(fpar.hausdorff_distance(rpar))

        return iou_scores, dice_scores, hausdorff_dists

    def _mean_circularity(self, parcels):
        """Mean of 4*pi*area / perimeter**2 over valid parcels (0 if none)."""
        circs = []
        for p in parcels:
            if not p.is_valid:
                continue
            perimeter = p.length
            if perimeter > 1e-9:
                circs.append(4.0 * np.pi * (p.area / (perimeter * perimeter)))
        return np.mean(circs) if circs else 0

    def _raster_metrics(self, fake_img, real_img):
        """MSE, PSNR and SSIM on [0, 1]-scaled images; all 0 if shapes differ."""
        if fake_img.shape != real_img.shape:
            return {'mse': 0, 'psnr': 0, 'ssim': 0}

        fake_float = fake_img.astype(np.float32) / 255.0
        real_float = real_img.astype(np.float32) / 255.0

        mse_val = mean_squared_error(real_float, fake_float)
        psnr_val = peak_signal_noise_ratio(real_float, fake_float, data_range=1.0)

        # scikit-image replaced `multichannel=True` with `channel_axis`; pick
        # whichever keyword the installed version understands.
        import inspect
        if 'channel_axis' in inspect.signature(ssim).parameters:
            channel_kwargs = {'channel_axis': -1}
        else:
            channel_kwargs = {'multichannel': True}
        ssim_val = ssim(real_float, fake_float, data_range=1.0, **channel_kwargs)

        return {'mse': mse_val, 'psnr': psnr_val, 'ssim': ssim_val}

    def _area_kl_divergence(self, fake_parcels, real_parcels):
        """KL(fake || real) between area histograms on shared 'auto' bins.

        Returns 0 unless both sets contain at least two valid parcels.
        """
        fake_areas = [p.area for p in fake_parcels if p.is_valid]
        real_areas = [p.area for p in real_parcels if p.is_valid]
        if len(fake_areas) <= 1 or len(real_areas) <= 1:
            return 0

        all_areas = np.concatenate([fake_areas, real_areas])
        bins = np.histogram_bin_edges(all_areas, bins='auto')
        fake_hist, _ = np.histogram(fake_areas, bins=bins, density=True)
        real_hist, _ = np.histogram(real_areas, bins=bins, density=True)
        # Small epsilon keeps empty bins from producing an infinite result.
        return kl_divergence(fake_hist + 1e-9, real_hist + 1e-9)

    def compute_grid_alignment(self, parcels, grid_lines, batch_size=50):
        """Score how well parcel edges line up with the detected grid lines.

        For every (parcel edge, grid line) pair the angular difference is
        folded modulo 90 degrees, so perpendicular counts as aligned, and
        reduced to a deviation in [0, 45]. Deviations are averaged per
        parcel, then over parcels.

        Args:
            parcels: Polygons exposing ``.exterior.coords``.
            grid_lines: Sequence of ``[x1, y1, x2, y2]`` segments.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            ``1 - mean_deviation / 45`` (1 = perfectly aligned, 0 = worst), or
            0 when there is nothing to compare or an error occurs.
        """
        if len(parcels) == 0 or len(grid_lines) == 0:
            return 0
        alignment_scores = []
        try:
            grid = np.asarray(grid_lines, dtype=float).reshape(-1, 4)
            # Grid-line directions are loop-invariant: compute them once.
            grid_angles = np.arctan2(grid[:, 3] - grid[:, 1], grid[:, 2] - grid[:, 0])

            for parcel in parcels:
                points = np.asarray(parcel.exterior.coords, dtype=float)[:, :2]
                edges = np.diff(points, axis=0)
                if len(edges) == 0:
                    continue
                edge_angles = np.arctan2(edges[:, 1], edges[:, 0])
                # Rows: parcel edges, columns: grid lines.
                diff = np.degrees(grid_angles[np.newaxis, :] - edge_angles[:, np.newaxis]) % 90
                deviation = np.minimum(diff, 90 - diff)
                alignment_scores.append(deviation.mean())
            return 1 - (np.mean(alignment_scores) / 45) if alignment_scores else 0
        except Exception as e:
            print(f"Error computing grid alignment: {e}")
            return 0

    def compute_shape_regularity(self, parcels, batch_size=50):
        """Mean regularity of the valid parcels.

        Per parcel the score is area / area(minimum rotated rectangle). For
        quadrilaterals it is averaged with an angle score that is 1 when all
        four corners are right angles.

        Args:
            parcels: Polygons exposing ``is_valid``, ``area``,
                ``minimum_rotated_rectangle`` and ``exterior.coords``.
            batch_size: Unused; kept for interface compatibility.

        Returns:
            Mean score, or 0 when there are no valid parcels or on error.
        """
        if not parcels:
            return 0
        scores = []
        try:
            for p in parcels:
                if not p.is_valid:
                    continue
                min_rect = p.minimum_rotated_rectangle
                reg = 0
                if min_rect.area > 1e-9:
                    reg = p.area / min_rect.area
                # If quadrilateral, check angles near 90
                coords = np.array(p.exterior.coords[:-1])
                if len(coords) == 4:
                    angles = []
                    for idx in range(4):
                        v1 = coords[(idx + 1) % 4] - coords[idx]
                        v2 = coords[(idx - 1) % 4] - coords[idx]
                        # Scalar 2-D cross product written out: np.cross on
                        # 2-element vectors is deprecated in NumPy 2.
                        cross = v1[0] * v2[1] - v1[1] * v2[0]
                        corner = np.degrees(np.arctan2(cross, np.dot(v1, v2)))
                        angles.append(abs(90 - abs(corner)))
                    angle_score = 1 - (np.mean(angles) / 90)
                    reg = (reg + angle_score) / 2
                scores.append(reg)
            return np.mean(scores) if scores else 0
        except Exception as e:
            print(f"Error computing shape regularity: {e}")
            return 0

################################################################################
# Report Generation
################################################################################

def create_memory_efficient_visualizations(df, output_dir):
    """Save distribution, correlation and box plots for the metric columns.

    Metric columns are those whose name starts with one of the known metric
    prefixes or is one of the four entropy columns. For them the function
    writes, into ``output_dir``:

    * ``<metric>_dist.png`` -- histogram with KDE, one per metric;
    * ``metric_correlations.png`` -- correlation heatmap (only when more than
      one metric column is present);
    * ``boxplots_group_<k>.png`` -- box plots, five metrics per figure.

    Any exception is printed rather than raised; ``cleanup_memory()`` runs
    after every figure and once more at the end.
    """
    name_prefixes = ('relative_', 'grid_', 'regularity_', 'mean_', 'mse', 'psnr', 'ssim', 'kl_div_area')
    entropy_names = ('spatial_entropy','size_entropy','complexity_entropy','total_entropy')

    def _finish_figure(filename):
        # Lay out, write and release the current figure, then free memory.
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, filename))
        plt.close()
        cleanup_memory()

    try:
        # Select the columns that hold metrics
        metrics = []
        for column in df.columns:
            if column.startswith(name_prefixes) or column in entropy_names:
                metrics.append(column)

        # 1. One histogram per metric
        for metric in metrics:
            plt.figure(figsize=(8,6))
            sns.histplot(df[metric], kde=True)
            plt.title(metric)
            _finish_figure(f"{metric}_dist.png")

        # 2. Correlation heatmap (needs at least two metrics)
        if len(metrics) > 1:
            correlations = df[metrics].corr()
            plt.figure(figsize=(12,10))
            sns.heatmap(correlations, annot=True, cmap='coolwarm', center=0)
            plt.title("Metric Correlations")
            _finish_figure("metric_correlations.png")

        # 3. Box plots, a handful of metrics per figure
        per_figure = 5
        group_starts = range(0, len(metrics), per_figure)
        for group_no, start in enumerate(group_starts, start=1):
            long_form = df[metrics[start:start + per_figure]].melt()
            plt.figure(figsize=(10,6))
            sns.boxplot(data=long_form, x='variable', y='value')
            plt.xticks(rotation=45)
            plt.title(f"Boxplots (Group {group_no})")
            _finish_figure(f"boxplots_group_{group_no}.png")
    except Exception as e:
        print(f"Error creating visualizations: {e}")
    finally:
        cleanup_memory()

def generate_memory_efficient_reports(results, output_dir):
    """Write CSV/JSON reports and plots for a list of per-pair metric dicts.

    Files written into ``output_dir`` (created if missing):

    * ``detailed_results.csv`` -- one row per entry of ``results``;
    * ``aggregate_metrics.json`` -- mean/std/min/max/median for each known
      metric column that is present;
    * ``folderwise_metrics.csv`` -- per-folder means, only when the rows carry
      a ``'folder'`` column;
    * the figures produced by ``create_memory_efficient_visualizations``.

    Args:
        results: List of dicts, one per processed image pair.
        output_dir: Directory that receives the report files.

    Errors are printed rather than raised, and ``cleanup_memory()`` always
    runs at the end.
    """
    try:
        if not results:
            print("No results to report; skipping report generation.")
            return

        os.makedirs(output_dir, exist_ok=True)

        # Build the frame from chunks of 100 records at a time.
        chunk_size = 100
        all_dfs = []
        for i in range(0, len(results), chunk_size):
            all_dfs.append(pd.DataFrame(results[i:i+chunk_size]))
            cleanup_memory()

        df = pd.concat(all_dfs, ignore_index=True)

        # 1. Save detailed results
        df.to_csv(os.path.join(output_dir, 'detailed_results.csv'), index=False)
        cleanup_memory()

        # 2. Compute overall aggregates
        metrics_to_analyze = [
            'parcel_count_ratio','grid_alignment','regularity_score',
            'mean_iou','mean_dice','mean_hausdorff',
            'mse','psnr','ssim','kl_div_area',
            'spatial_entropy','size_entropy','complexity_entropy','total_entropy'
        ]

        aggregate_metrics = {}
        for metric in metrics_to_analyze:
            if metric not in df.columns:
                continue
            series = df[metric]
            # Cast to built-in float: min()/max() of an integer column (e.g. a
            # metric that is 0 for every row) are NumPy integers, which the
            # json module refuses to serialise.
            aggregate_metrics[metric] = {
                'mean': float(series.mean()),
                'std': float(series.std()),
                'min': float(series.min()),
                'max': float(series.max()),
                'median': float(series.median())
            }

        # Serialise before opening the file so that a serialisation failure
        # cannot leave a truncated JSON file behind.
        aggregate_json = json.dumps(aggregate_metrics, indent=4)
        with open(os.path.join(output_dir, 'aggregate_metrics.json'), 'w') as f:
            f.write(aggregate_json)

        # 3. Folderwise summary (average per folder)
        if 'folder' in df.columns:
            folderwise_cols = [
                'spatial_entropy','size_entropy','complexity_entropy','total_entropy',
                'mean_iou','mean_dice','mean_hausdorff','mse','psnr','ssim',
                'kl_div_area','parcel_count_ratio','grid_alignment','regularity_score'
            ]
            folderwise_cols = [c for c in folderwise_cols if c in df.columns]
            if folderwise_cols:
                folderwise = df.groupby('folder')[folderwise_cols].mean().reset_index()
                folderwise.to_csv(os.path.join(output_dir, 'folderwise_metrics.csv'), index=False)

        # 4. Visualizations
        create_memory_efficient_visualizations(df, output_dir)

    except Exception as e:
        print(f"Error generating reports: {e}")
    finally:
        cleanup_memory()

################################################################################
# Main Processing Logic
################################################################################

def process_test_folder(test_folder, output_dir=None, sample_size=SAMPLE_SIZE):
    """Benchmark the fake/real image pairs found in ``test_folder``.

    Files named ``combined_<n>_fake_B.png`` / ``combined_<n>_real_B.png`` are
    paired by ``<n>``; only pairs with both images are kept. When
    ``sample_size`` is truthy and smaller than the number of complete pairs,
    a random sample of that size is processed instead of all pairs.

    Args:
        test_folder: Directory that is scanned for the PNG files.
        output_dir: Where results go; defaults to
            ``<test_folder>/benchmark_results``. Created if missing.
        sample_size: Maximum number of pairs to process.

    Returns:
        List of the truthy metric results returned for the processed pairs
        (empty when the detector could not be initialised). Reports are
        generated only when the list is non-empty.
    """
    if output_dir is None:
        output_dir = os.path.join(test_folder, 'benchmark_results')
    os.makedirs(output_dir, exist_ok=True)

    benchmarker = MemoryEfficientBenchmarker()

    print(f"Scanning directory: {test_folder}")

    # Group files by their numeric id, e.g. combined_123_fake_B.png and
    # combined_123_real_B.png both land under "123".
    pairs_by_id = {}
    for entry in os.listdir(test_folder):
        if not entry.endswith('.png'):
            continue
        found = re.search(r'combined_(\d+)_(fake|real)_B\.png', entry)
        if found is None:
            continue
        pair_id, role = found.groups()
        slot = pairs_by_id.setdefault(pair_id, {'fake': None, 'real': None})
        slot[role] = os.path.join(test_folder, entry)

    # Keep only ids for which both the fake and the real image were found.
    complete_pairs = []
    for slot in pairs_by_id.values():
        if slot['fake'] and slot['real']:
            complete_pairs.append((slot['fake'], slot['real']))

    if not sample_size or sample_size >= len(complete_pairs):
        print(f"Processing all {len(complete_pairs)} pairs")
    else:
        print(f"Sampling {sample_size} pairs from {len(complete_pairs)} total")
        complete_pairs = random.sample(complete_pairs, sample_size)

    results = []
    with tqdm(total=len(complete_pairs), desc="Processing image pairs") as pbar:
        detector_ready = benchmarker.initialize_detector()
        if detector_ready:
            for fake_path, real_path in complete_pairs:
                try:
                    pair_metrics = benchmarker.process_single_pair_with_existing_detector(
                        fake_path, real_path, output_dir
                    )
                    if pair_metrics:
                        results.append(pair_metrics)
                except Exception as e:
                    # A failing pair is reported and skipped; the run goes on.
                    print(f"\nError processing {os.path.basename(fake_path)}: {e}")
                finally:
                    cleanup_memory()
                pbar.update(1)

            benchmarker.cleanup()

    if results:
        generate_memory_efficient_reports(results, output_dir)

    return results

def main():
    """Run the benchmark over every existing folder in ``TEST_FOLDERS``.

    Prints a short device report, processes each existing test folder into
    its own sub-directory of ``OUTPUT_DIR``, tags each result with its folder
    name and finally generates one overall report across all folders.
    Errors are printed rather than raised; ``cleanup_memory()`` always runs
    at the end.
    """
    print("\nParcel Detection and Benchmarking System - Extended Metrics")
    print("===========================================================")

    try:
        # Quick device report
        if not torch.cuda.is_available():
            print("CUDA not available, using CPU")
        else:
            print(f"CUDA available: {torch.cuda.get_device_name(0)}")
            print(f"Memory allocated: {torch.cuda.memory_allocated(0)/(1024**2):.1f} MB")
            print(f"Memory reserved: {torch.cuda.memory_reserved(0)/(1024**2):.1f} MB")

        # Keep only the configured folders that exist on disk
        valid_folders = {}
        for name, path in TEST_FOLDERS.items():
            if os.path.exists(path):
                valid_folders[name] = path
        if not valid_folders:
            raise ValueError("No valid folders found!")

        print("Valid folders found:")
        for name, path in valid_folders.items():
            print(f"  - {name}: {path}")

        # Top-level output directory
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        print(f"\nOutput directory: {OUTPUT_DIR}")

        all_results = []
        for folder_name, folder_path in valid_folders.items():
            print(f"\nProcessing {folder_name}...")
            output_subdir = os.path.join(OUTPUT_DIR, folder_name)
            try:
                folder_results = process_test_folder(
                    folder_path, output_subdir, sample_size=SAMPLE_SIZE
                )
                if folder_results:
                    # Tag every record so the overall report can group by folder.
                    for record in folder_results:
                        record['folder'] = folder_name
                    all_results.extend(folder_results)
                cleanup_memory()
            except Exception as e:
                # One failing folder must not stop the remaining ones.
                print(f"Error processing folder {folder_name}: {e}")

        # Overall summary across folders
        if all_results:
            print("\nGenerating overall summary across all folders...")
            generate_memory_efficient_reports(all_results, OUTPUT_DIR)

        print("\nProcessing complete!")
        print(f"Results saved to: {OUTPUT_DIR}")

    except Exception as e:
        print(f"Error during execution: {e}")
    finally:
        cleanup_memory()

if __name__ == "__main__":
    # Script entry point: re-apply the CUDA memory limits that the
    # module-level configuration also sets, then run the benchmark.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Cap this process at 70% of the GPU memory.
        torch.cuda.set_per_process_memory_fraction(0.7)
        # NOTE(review): max_split_size_mb does not look like a documented
        # torch.backends.cuda attribute -- confirm this assignment has any
        # effect beyond the PYTORCH_CUDA_ALLOC_CONF setting.
        torch.backends.cuda.max_split_size_mb = 128
    main()


Using device: cuda

Parcel Detection and Benchmarking System - Extended Metrics
CUDA available: NVIDIA GeForce RTX 3070 Ti
Memory allocated: 0.0 MB
Memory reserved: 0.0 MB
Valid folders found:
  - brooklyn-boston-model: ../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images
  - brooklyn-charlotte-model: ../data/ny-brooklyn/nc-charlotte-500-150-v100/test_latest_500e-Brooklyn/images
  - brooklyn-manhattan-model: ../data/ny-brooklyn/ny-manhattan-p2p-500-150-v100/test_latest_500e-Brooklyn/images
  - brooklyn-pittsburgh-model: ../data/ny-brooklyn/pa-pittsburgh-p2p-500-150-v100/test_latest_500e-Brooklyn/images

Output directory: benchmark-output/brooklyn_n5_20241222_1422

Processing brooklyn-boston-model...
Scanning directory: ../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images
Sampling 5 pairs from 1000 total


Processing image pairs:   0%|          | 0/5 [00:00<?, ?it/s]

Loading SAM checkpoint...
Model initialized successfully
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If you

Processing image pairs:   0%|          | 0/5 [00:00<?, ?it/s]

Loading SAM checkpoint...
Model initialized successfully
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If you

Processing image pairs:   0%|          | 0/5 [00:00<?, ?it/s]

Loading SAM checkpoint...
Model initialized successfully
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If you

Processing image pairs:   0%|          | 0/5 [00:00<?, ?it/s]

Loading SAM checkpoint...
Model initialized successfully
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If your images are multichannel (with color channels), set channel_axis to the axis number corresponding to the channels.
Error in calculate_metrics: win_size exceeds image extent. Either ensure that your images are at least 7x7; or pass win_size explicitly in the function call, with an odd value less than or equal to the smaller side of your images. If you