In [None]:
# Part 1: Imports and Core Evaluator Class
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
import json
from tqdm.notebook import tqdm
from shapely.geometry import Polygon
from shapely.ops import unary_union
from skimage.metrics import (
    mean_squared_error, 
    structural_similarity as ssim, 
    peak_signal_noise_ratio
)
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

class ParcelEvaluator:
    """Score a generated ("fake") parcel image against a color-coded real one.

    The real image is split into one polygon per exact pixel color; the fake
    image is split by assigning each pixel to the nearest of those colors
    (within ``color_dist_threshold``).  The two polygon sets are compared
    geometrically (count ratio, area ratio, IoU) and the raw images are
    compared with MSE / PSNR / SSIM.
    """

    def __init__(
        self,
        min_area=50,
        color_dist_threshold=30,
        win_size_for_ssim=3,
        save_visualizations=True
    ):
        """Initialize the ParcelEvaluator with configuration parameters.

        Args:
            min_area: Contours / polygons with an area below this value
                (in pixels) are discarded.
            color_dist_threshold: Maximum Euclidean RGB distance (exclusive)
                at which a fake pixel is still assigned to a real color.
            win_size_for_ssim: ``win_size`` forwarded to the SSIM computation.
            save_visualizations: Whether ``evaluate`` writes a comparison PNG.
        """
        self.min_area = min_area
        self.color_dist_threshold = color_dist_threshold
        self.win_size_for_ssim = win_size_for_ssim
        self.save_visualizations = save_visualizations

    def create_benchmark_id(self, model_name=None):
        """Create an identifier for this benchmark run.

        Returns ``"<model_name>_<YYYYmmdd_HHMMSS>"`` when ``model_name`` is
        truthy, otherwise ``"benchmark_<YYYYmmdd_HHMMSS>"``.  The timestamp has
        one-second resolution, so two calls within the same second produce the
        same identifier.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        prefix = model_name if model_name else "benchmark"
        return f"{prefix}_{timestamp}"

    def setup_output_directory(self, base_dir, benchmark_id):
        """Create ``<base_dir>/<benchmark_id>/{visualizations,metrics}``.

        Existing directories are reused.  Returns the run directory as a
        ``Path``.
        """
        output_dir = Path(base_dir) / benchmark_id
        for sub_dir in ("visualizations", "metrics"):
            (output_dir / sub_dir).mkdir(parents=True, exist_ok=True)
        return output_dir

    def save_config(self, output_dir):
        """Write the evaluator configuration to ``metrics/config.json``.

        ``output_dir`` must already contain a ``metrics`` directory (see
        ``setup_output_directory``).
        """
        config = {
            'min_area': self.min_area,
            'color_dist_threshold': self.color_dist_threshold,
            'win_size_for_ssim': self.win_size_for_ssim,
            # Recorded so the saved config describes the full constructor state.
            'save_visualizations': self.save_visualizations,
            'timestamp': datetime.now().isoformat()
        }

        config_path = Path(output_dir) / "metrics" / "config.json"
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)

    def load_and_preprocess(self, image_path):
        """Load an image from disk and return it as an RGB array.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        img_bgr = cv2.imread(str(image_path))
        if img_bgr is None:
            raise ValueError(f"Could not load image: {image_path}")
        return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    def parse_color_coded_image(self, image_rgb):
        """Extract one polygon per exact pixel color of the real image.

        Returns a dict mapping each color (a tuple of ``np.float32`` channel
        values) to its polygon, ordered by first appearance of the color.
        """
        return self._masks_to_polygons(self._exact_color_masks(image_rgb))

    @staticmethod
    def _exact_color_masks(image_rgb, tile_size=100):
        """Build one full-size binary mask per distinct color in ``image_rgb``.

        The image is walked tile by tile (left to right, top to bottom) and,
        inside a tile, in row-major order; colors are inserted into the result
        in the order in which they are first met on that walk.  Each mask is a
        ``uint8`` array of the image's height and width with 1 where the pixel
        has exactly that color.  Note that one full-size mask is allocated
        per distinct color.

        Raises:
            ValueError: If ``image_rgb`` is not a 3-D (height, width,
                channels) array.
        """
        if image_rgb.ndim != 3:
            raise ValueError(
                f"Expected a (height, width, channels) image, got shape {image_rgb.shape}"
            )

        h, w = image_rgb.shape[:2]
        image_f = image_rgb.astype(np.float32)
        color2mask = {}

        for y0 in range(0, h, tile_size):
            for x0 in range(0, w, tile_size):
                tile = image_f[y0:y0 + tile_size, x0:x0 + tile_size]
                pixels = tile.reshape(-1, tile.shape[-1])
                colors, first_seen, inverse = np.unique(
                    pixels, axis=0, return_index=True, return_inverse=True
                )
                # The shape of `inverse` for axis-wise unique differs between
                # NumPy releases; normalise it to the tile's 2-D layout.
                inverse = inverse.reshape(tile.shape[:2])

                # np.unique sorts colors; re-order by first occurrence so that
                # dict insertion order follows the scan, not the color values.
                for idx in np.argsort(first_seen, kind="stable"):
                    color = tuple(colors[idx])
                    mask = color2mask.get(color)
                    if mask is None:
                        mask = color2mask[color] = np.zeros((h, w), dtype=np.uint8)
                    mask[y0:y0 + tile_size, x0:x0 + tile_size][inverse == idx] = 1

        return color2mask

    def parse_fake_image(self, fake_rgb, target_colors):
        """Extract polygons from the fake image by nearest-color assignment.

        Every pixel is assigned to the closest color in ``target_colors``
        (Euclidean distance over the channels; ties go to the color listed
        first) provided that distance is strictly below
        ``self.color_dist_threshold``; other pixels are left unassigned.

        Args:
            fake_rgb: Image array of shape (height, width, channels).
            target_colors: Iterable of colors (sequences of channel values).

        Returns:
            Dict mapping each target color (as a tuple of Python floats) that
            produced a polygon to that polygon.  Empty when ``target_colors``
            is empty.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # is what decides ties between equally distant colors.
        palette = list(dict.fromkeys(
            tuple(float(x) for x in c) for c in target_colors
        ))
        if not palette:
            # Nothing to match against (e.g. the real image produced no
            # polygons); previously this crashed inside min().
            return {}

        h, w = fake_rgb.shape[:2]
        fake_f = fake_rgb.astype(np.float32)
        labels = np.full((h, w), -1, dtype=np.intp)

        # Work on horizontal bands so temporary distance arrays stay small
        # and the progress bar keeps its original granularity.
        band_rows = 100
        for y0 in tqdm(range(0, h, band_rows), desc="Processing fake image", leave=False):
            labels[y0:y0 + band_rows] = self._nearest_color_labels(
                fake_f[y0:y0 + band_rows], palette, self.color_dist_threshold
            )

        color2mask = {
            color: (labels == idx).astype(np.uint8)
            for idx, color in enumerate(palette)
        }
        return self._masks_to_polygons(color2mask)

    @staticmethod
    def _nearest_color_labels(pixels, palette, max_dist):
        """Label each pixel with the index of its nearest palette color.

        Args:
            pixels: Array whose last axis holds the color channels.
            palette: Sequence of colors (each a sequence of channel values).
            max_dist: Pixels whose nearest color is not strictly closer than
                this Euclidean distance get the label ``-1``.

        Returns:
            Integer array shaped like ``pixels`` without its last axis.
        """
        pixels = np.asarray(pixels)
        best_dist = np.full(pixels.shape[:-1], np.inf)
        labels = np.full(pixels.shape[:-1], -1, dtype=np.intp)

        for idx, color in enumerate(palette):
            diff = pixels - np.asarray(color, dtype=np.float64)
            dist = np.sqrt(np.sum(diff ** 2, axis=-1))
            # Strict '<' keeps the earliest palette entry on ties.
            closer = dist < best_dist
            best_dist[closer] = dist[closer]
            labels[closer] = idx

        labels[~(best_dist < max_dist)] = -1
        return labels

    def _masks_to_polygons(self, color2mask):
        """Convert per-color binary masks to a single polygon per color.

        For each mask, external contours with area >= ``self.min_area`` are
        simplified and turned into polygons; invalid or too-small polygons are
        dropped.  The remaining polygons are merged and, if the union has
        several parts, only the largest part is kept.  Colors that end up with
        no polygon are omitted from the result.
        """
        color2poly = {}
        for color, mask in color2mask.items():
            # [-2] selects the contour list for both the 2-tuple and the
            # 3-tuple return conventions of cv2.findContours.
            contours = cv2.findContours(
                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]

            polys = []
            for cnt in contours:
                if cv2.contourArea(cnt) < self.min_area:
                    continue
                # Simplify with a tolerance of 2% of the contour perimeter.
                epsilon = 0.02 * cv2.arcLength(cnt, True)
                coords = cv2.approxPolyDP(cnt, epsilon, True).reshape(-1, 2)
                if len(coords) < 3:
                    # Too few vertices to form a polygon.
                    continue
                try:
                    poly = Polygon(coords)
                except Exception:
                    # Best effort: skip contours shapely cannot build.
                    continue
                if poly.is_valid and poly.area >= self.min_area:
                    polys.append(poly)

            if not polys:
                continue

            try:
                merged = unary_union(polys)
                if merged.geom_type == 'MultiPolygon':
                    merged = max(merged.geoms, key=lambda g: g.area)
                color2poly[color] = merged
            except Exception as e:
                print(f"Error merging polygons for color {color}: {str(e)}")
                continue

        return color2poly

    def compute_geometric_metrics(self, real_polys, fake_polys):
        """Compute geometric comparison metrics between two polygon dicts.

        Returns a dict with:
            polygon_count_ratio: ``len(fake_polys) / len(real_polys)``
                (0 when ``real_polys`` is empty).
            area_ratios: fake/real area ratio for each color present in both.
            iou_scores: intersection-over-union for each such color whose
                intersection/union could be computed.
            mean_area_ratio, mean_iou: means of the two lists (0 when empty).
        """
        metrics = {
            'polygon_count_ratio': len(fake_polys) / len(real_polys) if real_polys else 0,
            'area_ratios': [],
            'iou_scores': []
        }

        for color, real_poly in real_polys.items():
            fake_poly = fake_polys.get(color)
            if fake_poly is None:
                continue

            area_ratio = fake_poly.area / real_poly.area if real_poly.area > 0 else 0
            metrics['area_ratios'].append(area_ratio)

            try:
                intersection = real_poly.intersection(fake_poly).area
                union = real_poly.union(fake_poly).area
            except Exception as e:
                # Keep going, but do not hide the failure.
                print(f"Error computing IoU for color {color}: {str(e)}")
                continue
            metrics['iou_scores'].append(intersection / union if union > 0 else 0)

        metrics['mean_area_ratio'] = (
            float(np.mean(metrics['area_ratios'])) if metrics['area_ratios'] else 0
        )
        metrics['mean_iou'] = (
            float(np.mean(metrics['iou_scores'])) if metrics['iou_scores'] else 0
        )
        return metrics

    def compute_image_metrics(self, real_rgb, fake_rgb):
        """Compute MSE, PSNR and SSIM between two images scaled to [0, 1].

        When the shapes differ no comparison is attempted and
        ``{'mse': inf, 'psnr': 0, 'ssim': 0}`` is returned.  All values are
        plain Python floats so that downstream ``isinstance(x, float)`` checks
        and JSON serialization treat every metric alike.
        """
        if real_rgb.shape != fake_rgb.shape:
            return {'mse': float('inf'), 'psnr': 0, 'ssim': 0}

        real_f = real_rgb.astype(np.float32) / 255
        fake_f = fake_rgb.astype(np.float32) / 255

        return {
            'mse': float(mean_squared_error(real_f, fake_f)),
            'psnr': float(peak_signal_noise_ratio(real_f, fake_f, data_range=1.0)),
            'ssim': float(ssim(real_f, fake_f, data_range=1.0,
                               channel_axis=2, win_size=self.win_size_for_ssim))
        }

    @staticmethod
    def _draw_panel(ax, image_rgb, polys, line_style, title):
        """Show ``image_rgb`` on ``ax`` and outline every valid polygon."""
        ax.imshow(image_rgb)
        for poly in polys.values():
            if poly.is_valid:
                xs, ys = poly.exterior.xy
                ax.plot(xs, ys, line_style, linewidth=1)
        ax.set_title(title)
        ax.axis('off')

    def visualize_comparison(self, real_rgb, fake_rgb, real_polys, fake_polys, output_path):
        """Save a side-by-side figure: fake image (left) and real image (right).

        Polygon outlines are drawn in red on the fake image and in green on
        the real image.  The figure is always closed, even if drawing or
        saving fails.
        """
        fig, (ax_fake, ax_real) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            self._draw_panel(ax_fake, fake_rgb, fake_polys, 'r-',
                             'Generated Image with Detected Parcels')
            self._draw_panel(ax_real, real_rgb, real_polys, 'g-',
                             'Ground Truth with Parcels')
            fig.tight_layout()
            fig.savefig(output_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

    def evaluate(self, real_path, fake_path, benchmark_dir="benchmark-outputs", model_name=None):
        """Evaluate one real/fake image pair and persist the results.

        Creates a new run directory under ``benchmark_dir``, saves the
        configuration, writes the metrics to
        ``metrics/single_image_metrics.csv`` and, if enabled, a comparison
        image to ``visualizations/<fake stem>_comparison.png``.

        Returns:
            ``(all_metrics, benchmark_id)`` where ``all_metrics`` merges the
            geometric and image metrics.
        """
        benchmark_id = self.create_benchmark_id(model_name)
        output_dir = self.setup_output_directory(benchmark_dir, benchmark_id)
        self.save_config(output_dir)

        real_rgb = self.load_and_preprocess(real_path)
        fake_rgb = self.load_and_preprocess(fake_path)

        # The fake image is parsed against the colors found in the real one.
        real_polys = self.parse_color_coded_image(real_rgb)
        fake_polys = self.parse_fake_image(fake_rgb, list(real_polys.keys()))

        all_metrics = {
            **self.compute_geometric_metrics(real_polys, fake_polys),
            **self.compute_image_metrics(real_rgb, fake_rgb),
        }

        pd.DataFrame([all_metrics]).to_csv(
            output_dir / "metrics" / "single_image_metrics.csv", index=False
        )

        if self.save_visualizations:
            viz_path = output_dir / "visualizations" / f"{Path(fake_path).stem}_comparison.png"
            self.visualize_comparison(real_rgb, fake_rgb, real_polys, fake_polys, viz_path)

        return all_metrics, benchmark_id

# Part 2: Directory Evaluation and Usage

def evaluate_directory(
    real_dir, 
    fake_dir, 
    benchmark_dir="benchmark-outputs", 
    model_name=None, 
    pattern="*.png", 
    **kwargs
):
    """Evaluate all matching images in two directories into one benchmark run.

    Files matching ``pattern`` in ``real_dir`` and ``fake_dir`` are sorted by
    path and paired by position, so both directories must contain the same
    number of matches and sort in corresponding order.

    All outputs of the run go under ``<benchmark_dir>/<benchmark_id>/``:
    comparison images in ``visualizations/`` and ``config.json``,
    ``detailed_metrics.csv``, ``summary_statistics.csv`` and ``summary.json``
    in ``metrics/``.  Images that fail to process are reported and skipped.

    Args:
        real_dir: Directory with the ground-truth images.
        fake_dir: Directory with the generated images.
        benchmark_dir: Base directory for benchmark runs.
        model_name: Optional prefix for the benchmark identifier.
        pattern: Glob pattern used in both directories.
        **kwargs: Forwarded to ``ParcelEvaluator``.

    Returns:
        ``(df, benchmark_id)`` where ``df`` has one row per processed image.

    Raises:
        ValueError: If the two directories yield a different number of files.
    """
    evaluator = ParcelEvaluator(**kwargs)

    # Create benchmark ID and setup directories
    benchmark_id = evaluator.create_benchmark_id(model_name)
    output_dir = evaluator.setup_output_directory(benchmark_dir, benchmark_id)
    metrics_dir = output_dir / "metrics"
    viz_dir = output_dir / "visualizations"

    # Save configuration once for the whole run
    evaluator.save_config(output_dir)

    # Get all matching files
    real_files = sorted(Path(real_dir).glob(pattern))
    fake_files = sorted(Path(fake_dir).glob(pattern))

    if len(real_files) != len(fake_files):
        raise ValueError(f"Number of files doesn't match: {len(real_files)} vs {len(fake_files)}")

    all_results = []

    for real_file, fake_file in tqdm(zip(real_files, fake_files), 
                                    total=len(real_files),
                                    desc="Processing images"):
        # The pipeline steps are run directly rather than through
        # evaluator.evaluate(): that method starts a new timestamped run
        # directory on every call, which would nest one run per image inside
        # visualizations/ and collide for images handled in the same second.
        try:
            real_rgb = evaluator.load_and_preprocess(real_file)
            fake_rgb = evaluator.load_and_preprocess(fake_file)

            real_polys = evaluator.parse_color_coded_image(real_rgb)
            fake_polys = evaluator.parse_fake_image(fake_rgb, list(real_polys.keys()))

            metrics = {
                **evaluator.compute_geometric_metrics(real_polys, fake_polys),
                **evaluator.compute_image_metrics(real_rgb, fake_rgb),
            }

            if evaluator.save_visualizations:
                evaluator.visualize_comparison(
                    real_rgb, fake_rgb, real_polys, fake_polys,
                    viz_dir / f"{fake_file.stem}_comparison.png"
                )
        except Exception as e:
            # Best effort: one bad image must not abort the whole run.
            print(f"Error processing {fake_file.name}: {str(e)}")
            continue

        metrics['file_name'] = fake_file.name
        all_results.append(metrics)

    # Create summary DataFrame
    df = pd.DataFrame(all_results)

    # Save detailed results
    df.to_csv(metrics_dir / "detailed_metrics.csv", index=False)

    # Compute and save summary statistics; describe() raises on a DataFrame
    # without columns, which is what an empty result list produces.
    if not df.empty:
        df.describe().to_csv(metrics_dir / "summary_statistics.csv")

    # Save averages in a more readable format
    avg_metrics = df.mean(numeric_only=True)
    summary_dict = {
        'average_metrics': avg_metrics.to_dict(),
        'total_images_processed': len(all_results),
        'timestamp': datetime.now().isoformat()
    }

    with open(metrics_dir / "summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary_dict, f, indent=4)

    print(f"\nResults saved to: {output_dir}")
    print("\nAverage Metrics:")
    print("-" * 50)
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

    return df, benchmark_id

# Example usage
if __name__ == "__main__":
    # Example paths: a generated image and its ground truth from one test run.
    fake_path = "../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images/combined_200035_fake_B.png"
    real_path = "../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images/combined_200035_real_B.png"
    
    # Single image evaluation
    evaluator = ParcelEvaluator(
        min_area=50,
        color_dist_threshold=30,
        win_size_for_ssim=3,
        save_visualizations=True
    )
    
    metrics, benchmark_id = evaluator.evaluate(
        real_path, 
        fake_path, 
        model_name="pix2pix_boston"
    )
    
    print(f"\nBenchmark ID: {benchmark_id}")
    print("\nMetrics for single image:")
    print("-" * 50)
    for metric, value in metrics.items():
        # Accept NumPy scalars as well as builtins: a metric returned as
        # np.float32 is not an instance of float and would otherwise be
        # skipped. List-valued entries are still not printed.
        if isinstance(value, (int, float, np.integer, np.floating)):
            print(f"{metric}: {value:.4f}")
    
    # For directory evaluation, uncomment and modify paths:
    #
    # df, benchmark_id = evaluate_directory(
    #     real_dir="path/to/real/images",
    #     fake_dir="path/to/fake/images",
    #     model_name="pix2pix_boston",
    #     pattern="*.png",
    #     min_area=50,
    #     color_dist_threshold=30
    # )

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]


Benchmark ID: pix2pix_boston_20241222_203730

Metrics for single image:
--------------------------------------------------
polygon_count_ratio: 1.0000
mean_area_ratio: 1.4528
mean_iou: 0.3143
mse: 0.0525
psnr: 12.7975
