In [1]:
# Cell 1: All imports

import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime
import json
from tqdm.notebook import tqdm
from shapely.geometry import Polygon
from shapely.ops import unary_union
from skimage.metrics import (
    mean_squared_error, 
    structural_similarity as ssim, 
    peak_signal_noise_ratio
)
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Cell 2: Configuration Parameters

# Core Parcel Evaluation Configuration
PARCEL_CONFIG = {
    # Core Parameters
    'MIN_AREA': 50,                    # Minimum area for a parcel to be considered valid
    'COLOR_DIST_THRESHOLD': 30,        # Maximum color distance for matching
    'WIN_SIZE_FOR_SSIM': 3,           # Window size for SSIM calculation
    
    # Processing Parameters
    'CHUNK_SIZE': 100,                # Size of chunks for processing large images
    'CONTOUR_EPSILON_FACTOR': 0.02,   # Factor for polygon approximation
    
    # Visualization Parameters
    'SAVE_VISUALIZATIONS': True,      # Whether to save comparison visualizations
    'DPI': 150,                       # DPI for saved visualizations
    'FIGURE_SIZE': (12, 6),           # Size of comparison figures
    'POLYGON_LINE_WIDTH': 1,          # Width of polygon outlines in visualizations
    'REAL_POLYGON_COLOR': 'g',        # Color for real polygons
    'FAKE_POLYGON_COLOR': 'r'         # Color for fake polygons
}

# Directory Evaluation Configuration
DIR_EVAL_CONFIG = {
    # File Pattern Settings
    'REAL_SUFFIX': 'real_B.png',      # Suffix for real images
    'FAKE_SUFFIX': 'fake_B.png',      # Suffix for fake images
    'DEFAULT_PATTERN': '*.png',        # Default file pattern for searching
    
    # Output Settings
    'DEFAULT_BENCHMARK_DIR': 'benchmark-outputs',  # Default output directory
    'SAVE_DETAILED_METRICS': True,     # Whether to save per-image metrics
    'SAVE_SUMMARY_STATS': True,        # Whether to save summary statistics
    
    # Processing Settings
    'PARALLEL_PROCESSING': False,      # Whether to use parallel processing
    'NUM_WORKERS': 4,                  # Number of workers for parallel processing
    
    # Reporting Settings
    'REPORT_DECIMAL_PLACES': 4,        # Number of decimal places in reports
    'INCLUDE_TIMESTAMPS': True         # Whether to include timestamps in reports
}

# Multi-model Comparison Configuration
COMPARISON_CONFIG = {
    # Output Settings
    'DEFAULT_OUTPUT_DIR': 'benchmark-outputs/multi_model_comparison',
    'PLOTS_SUBDIRECTORY': 'comparative_plots',
    'REPORTS_SUBDIRECTORY': 'reports',
    
    # Visualization Settings
    'PLOT_DPI': 300,
    'BAR_PLOT_SIZE': (12, 6),
    'HEATMAP_SIZE': (15, 10),
    'VIOLIN_PLOT_SIZE': (12, 6),
    'COLOR_PALETTE': 'husl',          # seaborn color palette
    'PLOT_GRID_STYLE': '--',
    'PLOT_GRID_ALPHA': 0.7,
    
    # Metrics Configuration
    'KEY_METRICS': {
        'mean_iou': 'Mean IoU',
        'ssim': 'SSIM',
        'psnr': 'PSNR',
        'polygon_count_ratio': 'Polygon Count Ratio',
        'mean_area_ratio': 'Mean Area Ratio'
    },
    
    # Metrics where higher values are better
    'HIGHER_BETTER_METRICS': [
        'ssim', 
        'psnr', 
        'mean_iou'
    ],
    
    # Report Settings
    'DECIMAL_PLACES': 4,
    'INCLUDE_TIMESTAMPS': True,
    
    # Model Paths
    'MODEL_PATHS': {
        'brooklyn-boston-model': "../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
        'brooklyn-charlotte-model': "../data/ny-brooklyn/nc-charlotte-500-150-v100/test_latest_500e-Brooklyn/images",
        'brooklyn-manhattan-model': "../data/ny-brooklyn/ny-manhattan-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
        'brooklyn-pittsburgh-model': "../data/ny-brooklyn/pa-pittsburgh-p2p-500-150-v100/test_latest_500e-Brooklyn/images"
    }
}

In [3]:
# Cell 3: ParcelEvaluator Class

class ParcelEvaluator:
    def __init__(
        self,
        min_area=PARCEL_CONFIG['MIN_AREA'],
        color_dist_threshold=PARCEL_CONFIG['COLOR_DIST_THRESHOLD'],
        win_size_for_ssim=PARCEL_CONFIG['WIN_SIZE_FOR_SSIM'],
        save_visualizations=PARCEL_CONFIG['SAVE_VISUALIZATIONS']
    ):
        """Initialize the ParcelEvaluator with configuration parameters."""
        self.min_area = min_area
        self.color_dist_threshold = color_dist_threshold
        self.win_size_for_ssim = win_size_for_ssim
        self.save_visualizations = save_visualizations
        self.chunk_size = PARCEL_CONFIG['CHUNK_SIZE']
        self.contour_epsilon_factor = PARCEL_CONFIG['CONTOUR_EPSILON_FACTOR']

    def create_benchmark_id(self, model_name=None):
        """Create a unique identifier for this benchmark run."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        if model_name:
            return f"{model_name}_{timestamp}"
        return f"benchmark_{timestamp}"

    def setup_output_directory(self, base_dir, benchmark_id):
        """Create and setup output directory structure."""
        output_dir = Path(base_dir) / benchmark_id
        
        # Create subdirectories
        (output_dir / "visualizations").mkdir(parents=True, exist_ok=True)
        (output_dir / "metrics").mkdir(parents=True, exist_ok=True)
        
        return output_dir

    def save_config(self, output_dir):
        """Save evaluator configuration."""
        config = {
            'min_area': self.min_area,
            'color_dist_threshold': self.color_dist_threshold,
            'win_size_for_ssim': self.win_size_for_ssim,
            'timestamp': datetime.now().isoformat()
        }
        
        with open(output_dir / "metrics" / "config.json", 'w') as f:
            json.dump(config, f, indent=4)

    def load_and_preprocess(self, image_path):
        """Load and preprocess image to RGB."""
        img_bgr = cv2.imread(str(image_path))
        if img_bgr is None:
            raise ValueError(f"Could not load image: {image_path}")
        return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    def parse_color_coded_image(self, image_rgb):
        """Extract exact color polygons from real image."""
        h, w = image_rgb.shape[:2]
        color2mask = {}
        image_rgb = image_rgb.astype(np.float32)

        for y_start in range(0, h, self.chunk_size):
            y_end = min(y_start + self.chunk_size, h)
            for x_start in range(0, w, self.chunk_size):
                x_end = min(x_start + self.chunk_size, w)
                chunk = image_rgb[y_start:y_end, x_start:x_end]
                
                for y in range(chunk.shape[0]):
                    for x in range(chunk.shape[1]):
                        c = tuple(chunk[y, x])
                        if c not in color2mask:
                            color2mask[c] = np.zeros((h, w), dtype=np.uint8)
                        color2mask[c][y_start + y, x_start + x] = 1

        return self._masks_to_polygons(color2mask)

    def parse_fake_image(self, fake_rgb, target_colors):
        """Parse fake image using color threshold approach."""
        h, w = fake_rgb.shape[:2]
        fake_rgb = fake_rgb.astype(np.float32)
        color2mask = {tuple(float(x) for x in c): np.zeros((h, w), dtype=np.uint8) 
                     for c in target_colors}

        for y_start in tqdm(range(0, h, self.chunk_size), desc="Processing fake image", leave=False):
            y_end = min(y_start + self.chunk_size, h)
            for x_start in range(0, w, self.chunk_size):
                x_end = min(x_start + self.chunk_size, w)
                chunk = fake_rgb[y_start:y_end, x_start:x_end]
                
                for y in range(chunk.shape[0]):
                    for x in range(chunk.shape[1]):
                        pixel = chunk[y, x]
                        best_color = min(
                            color2mask.keys(),
                            key=lambda c: np.sqrt(np.sum((pixel - c) ** 2))
                        )
                        if np.sqrt(np.sum((pixel - best_color) ** 2)) < self.color_dist_threshold:
                            color2mask[best_color][y_start + y, x_start + x] = 1

        return self._masks_to_polygons(color2mask)

    def _masks_to_polygons(self, color2mask):
        """Convert masks to merged polygons."""
        color2poly = {}
        for color, mask in color2mask.items():
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            polys = []
            for cnt in contours:
                area = cv2.contourArea(cnt)
                if area < self.min_area:
                    continue
                ep = self.contour_epsilon_factor * cv2.arcLength(cnt, True)
                approx = cv2.approxPolyDP(cnt, ep, True)
                coords = np.squeeze(approx).reshape(-1, 2) if approx.size > 0 else np.array([])
                if coords.size > 0:
                    try:
                        p = Polygon(coords)
                        if p.is_valid and p.area >= self.min_area:
                            polys.append(p)
                    except Exception:
                        continue
                    
            if not polys:
                continue
                
            try:
                merged = unary_union(polys)
                if merged.geom_type == 'MultiPolygon':
                    big = max(merged.geoms, key=lambda g: g.area)
                    color2poly[color] = big
                else:
                    color2poly[color] = merged
            except Exception as e:
                print(f"Error merging polygons for color {color}: {str(e)}")
                continue
                
        return color2poly

    def compute_geometric_metrics(self, real_polys, fake_polys):
        """Compute geometric comparison metrics."""
        metrics = {
            'polygon_count_ratio': len(fake_polys) / len(real_polys) if real_polys else 0,
            'area_ratios': [],
            'iou_scores': []
        }
        
        for color, real_poly in real_polys.items():
            if color in fake_polys:
                fake_poly = fake_polys[color]
                # Area ratio
                area_ratio = fake_poly.area / real_poly.area if real_poly.area > 0 else 0
                metrics['area_ratios'].append(area_ratio)
                
                # IoU
                try:
                    intersection = real_poly.intersection(fake_poly).area
                    union = real_poly.union(fake_poly).area
                    iou = intersection / union if union > 0 else 0
                    metrics['iou_scores'].append(iou)
                except Exception:
                    continue

        metrics['mean_area_ratio'] = np.mean(metrics['area_ratios']) if metrics['area_ratios'] else 0
        metrics['mean_iou'] = np.mean(metrics['iou_scores']) if metrics['iou_scores'] else 0
        return metrics

    def compute_image_metrics(self, real_rgb, fake_rgb):
        """Compute traditional image comparison metrics."""
        if real_rgb.shape != fake_rgb.shape:
            return {'mse': float('inf'), 'psnr': 0, 'ssim': 0}
            
        real_f = real_rgb.astype(np.float32) / 255
        fake_f = fake_rgb.astype(np.float32) / 255
        
        metrics = {
            'mse': mean_squared_error(real_f, fake_f),
            'psnr': peak_signal_noise_ratio(real_f, fake_f, data_range=1.0),
            'ssim': ssim(real_f, fake_f, data_range=1.0, 
                        channel_axis=2, win_size=self.win_size_for_ssim)
        }
        return metrics

    def visualize_comparison(self, real_rgb, fake_rgb, real_polys, fake_polys, output_path):
        """Create visualization of the comparison."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=PARCEL_CONFIG['FIGURE_SIZE'])
        
        # Plot fake image and its polygons
        ax1.imshow(fake_rgb)
        for poly in fake_polys.values():
            if poly.is_valid:
                x, y = poly.exterior.xy
                ax1.plot(x, y, f"{PARCEL_CONFIG['FAKE_POLYGON_COLOR']}-", 
                        linewidth=PARCEL_CONFIG['POLYGON_LINE_WIDTH'])
        ax1.set_title('Generated Image with Detected Parcels')
        ax1.axis('off')
        
        # Plot real image and its polygons
        ax2.imshow(real_rgb)
        for poly in real_polys.values():
            if poly.is_valid:
                x, y = poly.exterior.xy
                ax2.plot(x, y, f"{PARCEL_CONFIG['REAL_POLYGON_COLOR']}-", 
                        linewidth=PARCEL_CONFIG['POLYGON_LINE_WIDTH'])
        ax2.set_title('Ground Truth with Parcels')
        ax2.axis('off')
        
        plt.tight_layout()
        plt.savefig(output_path, dpi=PARCEL_CONFIG['DPI'], bbox_inches='tight')
        plt.close()

    def evaluate(self, real_path, fake_path, benchmark_dir="benchmark-outputs", model_name=None):
        """Main evaluation function."""
        benchmark_id = self.create_benchmark_id(model_name)
        output_dir = self.setup_output_directory(benchmark_dir, benchmark_id)
        
        self.save_config(output_dir)
        
        real_rgb = self.load_and_preprocess(real_path)
        fake_rgb = self.load_and_preprocess(fake_path)
        
        real_polys = self.parse_color_coded_image(real_rgb)
        fake_polys = self.parse_fake_image(fake_rgb, list(real_polys.keys()))
        
        geometric_metrics = self.compute_geometric_metrics(real_polys, fake_polys)
        image_metrics = self.compute_image_metrics(real_rgb, fake_rgb)
        
        all_metrics = {**geometric_metrics, **image_metrics}
        
        metrics_df = pd.DataFrame([all_metrics])
        metrics_df.to_csv(output_dir / "metrics" / "single_image_metrics.csv", index=False)
        
        if self.save_visualizations:
            viz_path = output_dir / "visualizations" / f"{Path(fake_path).stem}_comparison.png"
            self.visualize_comparison(real_rgb, fake_rgb, real_polys, fake_polys, viz_path)
        
        return all_metrics, benchmark_id

In [4]:
# Cell 4: Directory Evaluation Functions

def evaluate_directory(
    real_dir, 
    fake_dir, 
    benchmark_dir=DIR_EVAL_CONFIG['DEFAULT_BENCHMARK_DIR'], 
    model_name=None, 
    pattern=DIR_EVAL_CONFIG['DEFAULT_PATTERN'], 
    **kwargs
):
    """Evaluate all matching images in directories with organized outputs."""
    evaluator = ParcelEvaluator(**kwargs)
    
    # Create benchmark ID and setup directories
    benchmark_id = evaluator.create_benchmark_id(model_name)
    output_dir = evaluator.setup_output_directory(benchmark_dir, benchmark_id)
    
    # Save configuration
    evaluator.save_config(output_dir)
    
    # Get all matching files
    real_dir = Path(real_dir)
    fake_dir = Path(fake_dir)
    
    # Find pairs of real and fake images
    real_files = []
    fake_files = []
    
    # Look for real_B images and their corresponding fake_B images
    for real_file in sorted(real_dir.glob(f"*{DIR_EVAL_CONFIG['REAL_SUFFIX']}")):
        fake_file = fake_dir / real_file.name.replace(
            DIR_EVAL_CONFIG['REAL_SUFFIX'], 
            DIR_EVAL_CONFIG['FAKE_SUFFIX']
        )
        if fake_file.exists():
            real_files.append(real_file)
            fake_files.append(fake_file)
    
    if not real_files:
        raise ValueError(f"No matching image pairs found in {real_dir} and {fake_dir}")
    
    # Store results
    all_results = []
    
    for real_file, fake_file in tqdm(zip(real_files, fake_files), 
                                    total=len(real_files),
                                    desc=f"Processing {model_name if model_name else 'images'}"):
        try:
            metrics, _ = evaluator.evaluate(
                real_file, 
                fake_file, 
                output_dir / "visualizations"
            )
            metrics['file_name'] = fake_file.name
            all_results.append(metrics)
        except Exception as e:
            print(f"Error processing {fake_file.name}: {str(e)}")
            continue
    
    # Create summary DataFrame
    df = pd.DataFrame(all_results)
    
    if DIR_EVAL_CONFIG['SAVE_DETAILED_METRICS']:
        df.to_csv(output_dir / "metrics" / "detailed_metrics.csv", index=False)
    
    if DIR_EVAL_CONFIG['SAVE_SUMMARY_STATS']:
        # Compute and save summary statistics
        summary_stats = df.describe()
        summary_stats.to_csv(output_dir / "metrics" / "summary_statistics.csv")
        
        # Save averages in a more readable format
        avg_metrics = df.mean(numeric_only=True).round(DIR_EVAL_CONFIG['REPORT_DECIMAL_PLACES'])
        summary_dict = {
            'model_name': model_name,
            'average_metrics': avg_metrics.to_dict(),
            'total_images_processed': len(all_results),
            'successful_evaluations': len(all_results),
            'failed_evaluations': len(real_files) - len(all_results)
        }
        
        if DIR_EVAL_CONFIG['INCLUDE_TIMESTAMPS']:
            summary_dict['timestamp'] = datetime.now().isoformat()
    
        with open(output_dir / "metrics" / "summary.json", 'w') as f:
            json.dump(summary_dict, f, indent=4)
    
    print(f"\nResults saved to: {output_dir}")
    print("\nAverage Metrics:")
    print("-" * 50)
    for metric, value in avg_metrics.items():
        if isinstance(value, (int, float)):
            print(f"{metric}: {value:.{DIR_EVAL_CONFIG['REPORT_DECIMAL_PLACES']}f}")
    
    return df, benchmark_id

def create_metric_summary(metrics_df, metric_name):
    """Create a summary of a specific metric from the results DataFrame."""
    summary = {
        'mean': metrics_df[metric_name].mean(),
        'std': metrics_df[metric_name].std(),
        'min': metrics_df[metric_name].min(),
        'max': metrics_df[metric_name].max(),
        'median': metrics_df[metric_name].median()
    }
    return {k: round(v, DIR_EVAL_CONFIG['REPORT_DECIMAL_PLACES']) for k, v in summary.items()}

def generate_model_report(df, model_name, output_dir):
    """Generate a detailed report for a single model's performance."""
    report = {
        'model_name': model_name,
        'number_of_images': len(df),
        'metrics': {}
    }
    
    # Get metrics (excluding non-numeric columns)
    metric_columns = df.select_dtypes(include=['float64', 'int64']).columns
    
    for metric in metric_columns:
        report['metrics'][metric] = create_metric_summary(df, metric)
    
    if DIR_EVAL_CONFIG['INCLUDE_TIMESTAMPS']:
        report['timestamp'] = datetime.now().isoformat()
    
    # Save report
    output_path = Path(output_dir) / f"{model_name}_detailed_report.json"
    with open(output_path, 'w') as f:
        json.dump(report, f, indent=4)
    
    return report

In [5]:
# Cell 5: Multi-model Comparison Functions

def evaluate_multiple_models(
    model_paths=COMPARISON_CONFIG['MODEL_PATHS'], 
    benchmark_dir=COMPARISON_CONFIG['DEFAULT_OUTPUT_DIR']
):
    """Evaluate multiple models and compare their performance."""
    
    # Create parent directory for this comparison
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    parent_dir = Path(benchmark_dir) / f"comparison_{timestamp}"
    parent_dir.mkdir(parents=True, exist_ok=True)
    
    # Create subdirectories
    plots_dir = parent_dir / COMPARISON_CONFIG['PLOTS_SUBDIRECTORY']
    reports_dir = parent_dir / COMPARISON_CONFIG['REPORTS_SUBDIRECTORY']
    plots_dir.mkdir(exist_ok=True)
    reports_dir.mkdir(exist_ok=True)
    
    # Store results for each model
    model_results = {}
    
    # Process each model
    for model_name, base_path in model_paths.items():
        print(f"\nProcessing {model_name}...")
        base_path = Path(base_path)
        
        # Evaluate this model
        df, benchmark_id = evaluate_directory(
            real_dir=base_path,
            fake_dir=base_path,
            benchmark_dir=parent_dir / model_name,
            model_name=model_name
        )
        
        # Generate detailed report
        report = generate_model_report(df, model_name, reports_dir)
        
        # Store results
        model_results[model_name] = {
            'metrics_df': df,
            'benchmark_id': benchmark_id,
            'avg_metrics': df.mean(numeric_only=True),
            'report': report
        }
    
    # Create comparison visualizations
    create_comparison_plots(model_results, plots_dir)
    
    # Save comparative summary
    save_comparative_summary(model_results, parent_dir)
    
    return model_results, parent_dir

def create_comparison_plots(model_results, output_dir):
    """Create comparative visualizations of model performances."""
    # Get consistent colors for models
    colors = sns.color_palette(COMPARISON_CONFIG['COLOR_PALETTE'], 
                             n_colors=len(model_results))
    model_colors = dict(zip(model_results.keys(), colors))
    
    # Create individual metric comparisons
    for metric, metric_label in COMPARISON_CONFIG['KEY_METRICS'].items():
        # Bar plot
        plt.figure(figsize=COMPARISON_CONFIG['BAR_PLOT_SIZE'])
        values = [results['avg_metrics'][metric] for results in model_results.values()]
        models = list(model_results.keys())
        
        bars = plt.bar(models, values, color=[model_colors[model] for model in models])
        
        plt.title(f'Comparison of {metric_label}', pad=20)
        plt.xticks(rotation=45, ha='right')
        plt.ylabel(metric_label)
        
        # Add value labels
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.{COMPARISON_CONFIG["DECIMAL_PLACES"]}f}',
                    ha='center', va='bottom')
        
        plt.grid(True, axis='y', linestyle=COMPARISON_CONFIG['PLOT_GRID_STYLE'], 
                alpha=COMPARISON_CONFIG['PLOT_GRID_ALPHA'])
        plt.tight_layout()
        plt.savefig(output_dir / f'{metric}_comparison.png', 
                   dpi=COMPARISON_CONFIG['PLOT_DPI'], 
                   bbox_inches='tight')
        plt.close()
        
        # Violin plot
        plt.figure(figsize=COMPARISON_CONFIG['VIOLIN_PLOT_SIZE'])
        data_dict = {model_name: results['metrics_df'][metric].values 
                    for model_name, results in model_results.items()}
        
        violin_parts = plt.violinplot([data_dict[model] for model in models],
                                    showmeans=True, showmedians=True)
        
        plt.title(f'Distribution of {metric_label} Across Images', pad=20)
        plt.xticks(range(1, len(models) + 1), models, rotation=45, ha='right')
        plt.ylabel(metric_label)
        plt.grid(True, axis='y', linestyle=COMPARISON_CONFIG['PLOT_GRID_STYLE'], 
                alpha=COMPARISON_CONFIG['PLOT_GRID_ALPHA'])
        
        plt.tight_layout()
        plt.savefig(output_dir / f'{metric}_distribution.png', 
                   dpi=COMPARISON_CONFIG['PLOT_DPI'], 
                   bbox_inches='tight')
        plt.close()
    
    # Create heatmap
    plt.figure(figsize=COMPARISON_CONFIG['HEATMAP_SIZE'])
    metrics_data = pd.DataFrame({
        model_name: results['avg_metrics'] 
        for model_name, results in model_results.items()
    })
    
    sns.heatmap(metrics_data, annot=True, 
                fmt=f'.{COMPARISON_CONFIG["DECIMAL_PLACES"]}f', 
                cmap='YlOrRd',
                cbar_kws={'label': 'Metric Value'})
    plt.title('Model Comparison Heatmap')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(output_dir / 'metrics_heatmap.png', 
                dpi=COMPARISON_CONFIG['PLOT_DPI'], 
                bbox_inches='tight')
    plt.close()

def save_comparative_summary(model_results, output_dir):
    """Save comparative summary of all models."""
    summary = {
        'models_compared': list(model_results.keys()),
        'model_metrics': {
            model_name: results['avg_metrics'].to_dict()
            for model_name, results in model_results.items()
        }
    }
    
    if COMPARISON_CONFIG['INCLUDE_TIMESTAMPS']:
        summary['timestamp'] = datetime.now().isoformat()
    
    # Determine best model for each metric
    best_models = {}
    for metric in COMPARISON_CONFIG['KEY_METRICS'].keys():
        metric_values = {model: metrics[metric] 
                        for model, metrics in summary['model_metrics'].items()}
        
        if metric in COMPARISON_CONFIG['HIGHER_BETTER_METRICS']:
            best_model = max(metric_values.items(), key=lambda x: x[1])
        else:  # For metrics where closer to 1.0 is better
            best_model = min(metric_values.items(), 
                           key=lambda x: abs(x[1] - 1.0))
        
        best_models[metric] = {
            'best_model': best_model[0],
            'value': round(best_model[1], COMPARISON_CONFIG['DECIMAL_PLACES']),
            'all_values': {k: round(v, COMPARISON_CONFIG['DECIMAL_PLACES']) 
                          for k, v in metric_values.items()}
        }
    
    summary['best_models'] = best_models
    
    # Calculate overall ranking
    model_scores = {model: 0 for model in model_results.keys()}
    for metric_result in best_models.values():
        model_scores[metric_result['best_model']] += 1
    
    summary['overall_ranking'] = dict(sorted(model_scores.items(), 
                                           key=lambda x: x[1], 
                                           reverse=True))
    
    # Save summary
    with open(output_dir / 'comparative_summary.json', 'w') as f:
        json.dump(summary, f, indent=4)
    
    # Print summary
    print("\nComparative Summary:")
    print("-" * 50)
    print("\nBest performing models by metric:")
    for metric, result in best_models.items():
        print(f"{metric}: {result['best_model']} "
              f"({result['value']:.{COMPARISON_CONFIG['DECIMAL_PLACES']}f})")
    
    print("\nOverall Ranking (number of metrics won):")
    for model, score in summary['overall_ranking'].items():
        print(f"{model}: {score}")

In [None]:
# Cell 6: Run the Multi-model Evaluation

# Run evaluation using the configured model paths
results, output_dir = evaluate_multiple_models()
print(f"\nAll results saved to: {output_dir}")

# Access results for specific models if needed
for model_name, model_data in results.items():
    print(f"\nModel: {model_name}")
    print("Average Metrics:")
    for metric, value in model_data['avg_metrics'].items():
        if isinstance(value, (int, float)):
            print(f"{metric}: {value:.4f}")


Processing brooklyn-boston-model...


Processing brooklyn-boston-model:   0%|          | 0/1000 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]

Processing fake image:   0%|          | 0/3 [00:00<?, ?it/s]