# In [None]:
import os
import re
import cv2
import json
import math
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from shapely.geometry import Polygon
from shapely.ops import unary_union
from tqdm.auto import tqdm
from datetime import datetime
import seaborn as sns
import gc

# For geometry metrics
from shapely.ops import unary_union

# For image-level metrics
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio

# For KL-divergence
from scipy.stats import entropy as kl_divergence

###############################################################################
# Global config
###############################################################################

# Select the non-interactive Agg backend: figures in this script are only
# written to disk with savefig(), never displayed.
matplotlib.use('Agg')  # non-interactive backend for saving plots

# Number of fake/real pairs randomly sampled per test folder. A falsy value,
# or one >= the number of complete pairs, means every pair is processed.
SAMPLE_SIZE = 5
# Run label -> folder containing the test images. Each folder is expected to
# hold files named "combined_<n>_fake_B.png" / "combined_<n>_real_B.png".
# The part of the first label before its first '-' prefixes OUTPUT_DIR.
TEST_FOLDERS = {
    'brooklyn-boston-model': "../data/ny-brooklyn/ma-boston-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-charlotte-model': "../data/ny-brooklyn/nc-charlotte-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-manhattan-model': "../data/ny-brooklyn/ny-manhattan-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
    'brooklyn-pittsburgh-model': "../data/ny-brooklyn/pa-pittsburgh-p2p-500-150-v100/test_latest_500e-Brooklyn/images",
}

def cleanup_memory():
    """Run a full ``gc.collect()`` pass and discard its result.

    The count of unreachable objects reported by the collector is not used;
    this function always returns ``None``.
    """
    unreachable_found = gc.collect()
    del unreachable_found

def generate_output_dir(test_folders, sample_size):
    """Build (but do not create) a timestamped output directory path.

    The directory name is ``<prefix>_n<sample_size>_<YYYYmmdd_HHMM>`` under
    ``benchmark-output``. ``<prefix>`` is the part of the first key of
    ``test_folders`` before its first ``'-'`` (e.g. ``'brooklyn'`` for
    ``'brooklyn-boston-model'``).

    Args:
        test_folders: Mapping of run label -> folder path. Only the first key
            is inspected. If the mapping is empty, the prefix falls back to
            ``'benchmark'`` instead of raising ``IndexError``.
        sample_size: Number of pairs sampled per folder; embedded in the name.

    Returns:
        str: A relative path such as
        ``benchmark-output/brooklyn_n5_20240101_1200``.
    """
    first_label = next(iter(test_folders), "benchmark")
    base_name = first_label.split('-')[0]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    identifier = f"{base_name}_n{sample_size}_{timestamp}"
    return os.path.join("benchmark-output", identifier)

# Resolve the run's output directory once, at import time. The timestamp in
# the name comes from the moment this line executes.
OUTPUT_DIR = generate_output_dir(TEST_FOLDERS, SAMPLE_SIZE)
print(f"All results will be saved under: {OUTPUT_DIR}")

# Parcels are kept only if their polygon area (in pixel units, as measured by
# cv2.contourArea / shapely) is at least this large.
MIN_AREA = 50  # Minimum area for a valid parcel

###############################################################################
# 1) Parse color-coded image into dict[color] -> Polygon
###############################################################################

def _iter_color_masks(image_rgb):
    """Yield ``(color, mask, (dx, dy))`` for each distinct pixel value of an image.

    ``mask`` is a ``uint8`` binary mask cropped to the colour's bounding box,
    with a 1-pixel zero margin. Adding ``(dx, dy)`` to an ``(x, y)`` point in
    mask coordinates gives the point in full-image coordinates. Colours are
    yielded in order of first appearance (row-major), and each colour is a
    tuple of the array's scalar type, e.g. ``(np.uint8, np.uint8, np.uint8)``.
    """
    h, w = image_rgb.shape[:2]
    if h == 0 or w == 0:
        return
    # One row per pixel, so every distinct colour is found in one vectorised pass.
    pixels = np.asarray(image_rgb).reshape(h * w, -1)
    colors, first_index, inverse, counts = np.unique(
        pixels, axis=0, return_index=True, return_inverse=True, return_counts=True)
    # The shape of ``inverse`` varies across NumPy versions when ``axis`` is given.
    inverse = inverse.ravel()
    # After a stable argsort, the flat pixel indices of label k are
    # order[starts[k]:starts[k + 1]].
    order = np.argsort(inverse, kind='stable')
    starts = np.concatenate(([0], np.cumsum(counts)))

    for label in np.argsort(first_index):
        flat = order[starts[label]:starts[label + 1]]
        rows, cols = np.divmod(flat, w)
        top, left = int(rows.min()), int(cols.min())
        # Crop to the bounding box plus a zero border. Contour tracing then sees
        # the same connected components as on a full-size mask, at a fraction
        # of the memory and time.
        mask = np.zeros((int(rows.max()) - top + 3, int(cols.max()) - left + 3), dtype=np.uint8)
        mask[rows - top + 1, cols - left + 1] = 1
        yield tuple(colors[label]), mask, (left - 1, top - 1)


def parse_color_coded_image(image_rgb, min_area=MIN_AREA):
    """Extract one polygon per distinct colour from a colour-coded parcel image.

    Every distinct pixel value is treated as one parcel label. For each colour,
    the external contours of its pixels are traced with OpenCV and simplified
    with ``cv2.approxPolyDP`` (tolerance = 2% of the contour perimeter), then
    converted to shapely Polygons.

    A piece is dropped when:
      - its contour area is below ``min_area``;
      - the simplified shape has fewer than 3 vertices;
      - the simplified shape is invalid or smaller than ``min_area``.

    Surviving pieces are merged with ``unary_union``. If the result is a
    MultiPolygon, only its largest part is kept.

    Args:
        image_rgb: ``(H, W, C)`` image array, normally ``uint8`` RGB.
        min_area: Minimum polygon area, in pixel units, for a piece to be kept.

    Returns:
        dict: ``{colour tuple -> shapely Polygon}`` in order of each colour's
        first appearance. Colours with no surviving piece are omitted.
    """
    color2poly = {}
    for color, mask, offset in _iter_color_masks(image_rgb):
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        polys = []
        for cnt in contours:
            if cv2.contourArea(cnt) < min_area:
                continue
            epsilon = 0.02 * cv2.arcLength(cnt, True)
            approx = cv2.approxPolyDP(cnt, epsilon, True)
            # Shift from cropped-mask coordinates back to image coordinates.
            coords = approx.reshape(-1, 2) + offset
            # Heavy simplification of a thin sliver can leave < 3 vertices,
            # and Polygon() raises on such input.
            if len(coords) < 3:
                continue
            poly = Polygon(coords)
            if poly.is_valid and poly.area >= min_area:
                polys.append(poly)

        if not polys:
            continue

        merged = unary_union(polys)
        # Several disjoint components of one colour: keep the largest.
        if merged.geom_type == 'MultiPolygon':
            merged = max(merged.geoms, key=lambda g: g.area)
        color2poly[color] = merged

    return color2poly

###############################################################################
# 2) Compare dict[color]->Polygon across fake vs real
###############################################################################

def _unmatched_parcel_record(color, poly, *, in_fake):
    """Build the zero-score record for a colour that appears in only one image."""
    area = poly.area if poly.is_valid else 0
    return {
        'color': color,
        'iou': 0,
        'dice': 0,
        'hausdorff': None,
        'fake_area': area if in_fake else 0,
        'real_area': 0 if in_fake else area,
    }


def compare_color_coded_parcels(fake_color2poly, real_color2poly):
    """Match parcels by colour between a generated (fake) and a reference (real) image.

    For each colour present in both dicts, computes the IoU, Dice and Hausdorff
    distance of the two polygons; pairs where either polygon is invalid are
    skipped. A colour present in only one dict yields a record with
    ``iou = dice = 0`` and ``hausdorff = None``, so unmatched parcels lower the
    mean IoU/Dice.

    Args:
        fake_color2poly: ``{colour -> Polygon}`` parsed from the fake image.
        real_color2poly: ``{colour -> Polygon}`` parsed from the real image.

    Returns:
        dict with keys:
          - ``'per_color'``: list of ``{color, iou, dice, hausdorff,
            fake_area, real_area}``. Matched colours come first, then
            fake-only colours, then real-only colours.
          - ``'mean_iou'`` / ``'mean_dice'``: means over all records (0 if none).
          - ``'mean_hausdorff'``: mean over records with a distance (0 if none).
    """
    fake_colors = set(fake_color2poly)
    real_colors = set(real_color2poly)

    records = []
    for color in fake_colors & real_colors:
        fpoly = fake_color2poly[color]
        rpoly = real_color2poly[color]
        if not (fpoly.is_valid and rpoly.is_valid):
            continue
        inter_area = fpoly.intersection(rpoly).area
        union_area = fpoly.union(rpoly).area
        area_sum = fpoly.area + rpoly.area
        records.append({
            'color': color,
            'iou': inter_area / union_area if union_area > 0 else 0,
            'dice': 2.0 * inter_area / area_sum if area_sum > 0 else 0,
            # Polygon.hausdorff_distance has existed since Shapely 1.6. The old
            # fallback imported shapely.ops.hausdorff_distance, which does not exist.
            'hausdorff': fpoly.hausdorff_distance(rpoly),
            'fake_area': fpoly.area,
            'real_area': rpoly.area,
        })

    records.extend(_unmatched_parcel_record(c, fake_color2poly[c], in_fake=True)
                   for c in fake_colors - real_colors)
    records.extend(_unmatched_parcel_record(c, real_color2poly[c], in_fake=False)
                   for c in real_colors - fake_colors)

    if not records:
        return {'per_color': [], 'mean_iou': 0, 'mean_dice': 0, 'mean_hausdorff': 0}

    hausdorff_vals = [r['hausdorff'] for r in records if r['hausdorff'] is not None]
    return {
        'per_color': records,
        'mean_iou': np.mean([r['iou'] for r in records]),
        'mean_dice': np.mean([r['dice'] for r in records]),
        'mean_hausdorff': np.mean(hausdorff_vals) if hausdorff_vals else 0,
    }

###############################################################################
# 3) Optional image-based metrics (MSE, PSNR, SSIM, etc.)
###############################################################################

def _ssim_channels_last(real_f, fake_f):
    """SSIM treating the last axis of 3-D input as channels, across scikit-image versions."""
    if real_f.ndim != 3:
        return ssim(real_f, fake_f, data_range=1.0)
    try:
        # scikit-image 0.19 replaced ``multichannel`` with ``channel_axis``, and
        # recent releases no longer honour ``multichannel``. Passing it there
        # makes the colour axis count as spatial, so the default 7-px window
        # fails and SSIM used to be silently reported as 0.
        return ssim(real_f, fake_f, data_range=1.0, channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 only understands ``multichannel``.
        return ssim(real_f, fake_f, data_range=1.0, multichannel=True)


def compute_image_metrics(fake_rgb, real_rgb):
    """Compare two images pixel-wise with MSE, PSNR and SSIM.

    Both images are divided by 255 (assumes 8-bit input; TODO confirm for
    other dtypes), so PSNR and SSIM use ``data_range=1.0``.

    Args:
        fake_rgb: Generated image array.
        real_rgb: Reference image array; must have the same shape.

    Returns:
        dict with ``'mse'``, ``'psnr'`` and ``'ssim'``. A metric that cannot be
        computed (shape mismatch, SSIM error) is ``NaN`` rather than 0. An MSE
        of 0 means a perfect match and an SSIM of 0 is a real score, so 0 would
        silently bias aggregates; pandas reductions skip NaN.
    """
    if fake_rgb.shape != real_rgb.shape:
        print(f"Shape mismatch: fake {fake_rgb.shape} vs real {real_rgb.shape}")
        return {'mse': np.nan, 'psnr': np.nan, 'ssim': np.nan}

    fake_f = fake_rgb.astype(np.float32) / 255.0
    real_f = real_rgb.astype(np.float32) / 255.0

    mse_val = mean_squared_error(real_f, fake_f)
    psnr_val = peak_signal_noise_ratio(real_f, fake_f, data_range=1.0)
    try:
        ssim_val = _ssim_channels_last(real_f, fake_f)
    except ValueError as e:
        # e.g. an image smaller than the SSIM window
        print(f"SSIM error: {e}")
        ssim_val = np.nan

    return {
        'mse': mse_val,
        'psnr': psnr_val,
        'ssim': ssim_val
    }

###############################################################################
# 4) Visualization: side-by-side
###############################################################################

def create_side_by_side_visual(fake_rgb, real_rgb,
                               fake_color2poly, real_color2poly,
                               max_count=50):
    """Render the fake (left) and real (right) images with parcel outlines.

    Each panel shows its image, then outlines up to ``max_count`` of its
    polygons: red for fake, green for real. The polygons are picked after a
    ``random.shuffle`` of the colour keys, and invalid ones are skipped. Each
    panel title reports the total number of parcels in its dict, not only the
    number drawn.

    Returns:
        The matplotlib Figure. The caller is responsible for saving/closing it.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 6))
    panels = (
        (axes[0], fake_rgb, fake_color2poly, 'r', "FAKE: {} color-coded parcels"),
        (axes[1], real_rgb, real_color2poly, 'g', "REAL: {} color-coded parcels"),
    )
    for ax, image, color2poly, line_style, title_template in panels:
        ax.imshow(image)
        keys = list(color2poly.keys())
        random.shuffle(keys)
        for key in keys[:max_count]:
            outline = color2poly[key]
            if outline.is_valid:
                xs, ys = outline.exterior.xy
                ax.plot(xs, ys, line_style, linewidth=1)
        ax.set_title(title_template.format(len(keys)))

    plt.tight_layout()
    return fig

###############################################################################
# 5) Process a Single Pair
###############################################################################

def process_single_pair(fake_path, real_path):
    """Evaluate one fake/real image pair.

    Loads both images with OpenCV (BGR -> RGB) and parses each into
    ``{colour -> Polygon}``. It then compares parcels colour-by-colour,
    computes whole-image MSE/PSNR/SSIM and builds a side-by-side figure.

    Args:
        fake_path: Path to the generated image.
        real_path: Path to the reference image.

    Returns:
        tuple ``(rows, fig, file_id)``:
          - rows: one dict per colour record from the comparison. Each row
            also carries ``file_id``, both paths, the pair-level geometry means
            (``*_overall``) and the image metrics. Empty if nothing was found.
          - fig: the matplotlib Figure, or ``None`` if an image failed to load.
          - file_id: the second ``'_'``-separated token of the fake file name
            (``'123'`` for ``'combined_123_fake_B.png'``), or the whole name
            when it contains no ``'_'``.
        The shape is the same on every path, so callers can always unpack it.
    """
    base_name = os.path.basename(fake_path)
    file_id = base_name.split('_')[1] if '_' in base_name else base_name

    fake_bgr = cv2.imread(fake_path)
    real_bgr = cv2.imread(real_path)
    if fake_bgr is None or real_bgr is None:
        print(f"Error loading images: {fake_path}, {real_path}")
        # Previously returned a bare [], which broke the caller's 3-way unpack.
        return [], None, file_id

    fake_rgb = cv2.cvtColor(fake_bgr, cv2.COLOR_BGR2RGB)
    real_rgb = cv2.cvtColor(real_bgr, cv2.COLOR_BGR2RGB)

    fake_color2poly = parse_color_coded_image(fake_rgb, MIN_AREA)
    real_color2poly = parse_color_coded_image(real_rgb, MIN_AREA)

    color_metrics = compare_color_coded_parcels(fake_color2poly, real_color2poly)
    img_metrics = compute_image_metrics(fake_rgb, real_rgb)

    fig = create_side_by_side_visual(fake_rgb, real_rgb,
                                     fake_color2poly, real_color2poly)

    # Columns shared by every row of this pair, appended after the per-colour fields.
    pair_fields = {
        'file_id': file_id,
        'fake_path': fake_path,
        'real_path': real_path,
        'mean_iou_overall': color_metrics['mean_iou'],
        'mean_dice_overall': color_metrics['mean_dice'],
        'mean_hausdorff_overall': color_metrics['mean_hausdorff'],
        'mse': img_metrics['mse'],
        'psnr': img_metrics['psnr'],
        'ssim': img_metrics['ssim'],
    }
    results = [{**rec, **pair_fields} for rec in color_metrics['per_color']]

    return results, fig, file_id

###############################################################################
# 6) Process a Folder
###############################################################################

# Matches e.g. "combined_123_fake_B.png": group 1 = pair number, group 2 = side.
_PAIR_NAME_RE = re.compile(r'combined_(\d+)_(fake|real)_B\.png')


def process_test_folder(folder_path, output_dir, sample_size=SAMPLE_SIZE):
    """Evaluate (a random sample of) the fake/real image pairs in one folder.

    ``.png`` files are paired by the number in ``combined_<n>_fake_B.png`` /
    ``combined_<n>_real_B.png``; numbers missing either half are ignored.
    Pairs are sorted before sampling, so under a fixed ``random.seed`` the
    same subset is chosen regardless of ``os.listdir`` order. Each pair's
    figure is saved as ``<output_dir>/side_by_side_<file_id>.png``.

    Args:
        folder_path: Directory containing the test images.
        output_dir: Directory for the figures; created if missing.
        sample_size: Maximum number of pairs to evaluate. A falsy value means all.

    Returns:
        list: Result rows from every pair that was processed successfully.
        A pair that raises is reported on stdout and skipped.
    """
    print(f"Scanning folder: {folder_path}")
    pairs_dict = {}
    for fname in os.listdir(folder_path):
        if not fname.endswith('.png'):
            continue
        m = _PAIR_NAME_RE.search(fname)
        if not m:
            continue
        base_num, img_type = m.group(1), m.group(2)
        pair = pairs_dict.setdefault(base_num, {'fake': None, 'real': None})
        pair[img_type] = os.path.join(folder_path, fname)

    # Sorted: os.listdir order is arbitrary, which would make random.sample
    # non-reproducible even with a fixed seed.
    complete_pairs = sorted((d['fake'], d['real'])
                            for d in pairs_dict.values()
                            if d['fake'] and d['real'])
    if sample_size and sample_size < len(complete_pairs):
        print(f"Sampling {sample_size} from total {len(complete_pairs)} pairs")
        complete_pairs = random.sample(complete_pairs, sample_size)
    else:
        print(f"Processing all {len(complete_pairs)} pairs")

    os.makedirs(output_dir, exist_ok=True)

    results = []
    with tqdm(total=len(complete_pairs), desc="Processing pairs") as pbar:
        for fake_path, real_path in complete_pairs:
            try:
                recs, fig, file_id = process_single_pair(fake_path, real_path)
                if fig is not None:
                    try:
                        out_fig = os.path.join(output_dir, f"side_by_side_{file_id}.png")
                        fig.savefig(out_fig, dpi=150, bbox_inches='tight')
                    finally:
                        # Close even if saving fails, so open figures don't pile up.
                        plt.close(fig)
                results.extend(recs)
            except Exception as e:
                # Best effort: report the failing pair and continue with the rest.
                print(f"Error on pair: {fake_path}, {real_path} => {e}")
            finally:
                cleanup_memory()
            pbar.update(1)

    return results

###############################################################################
# 7) Generate Reports
###############################################################################

def _save_plot(path, figsize, title, draw):
    """Draw on a fresh figure via ``draw()``, title it and save it to ``path``.

    The figure is always closed. Errors are printed, not raised, so one failing
    plot does not prevent the others.
    """
    fig = None
    try:
        fig = plt.figure(figsize=figsize)
        draw()
        plt.title(title)
        plt.tight_layout()
        fig.savefig(path)
    except Exception as e:
        print(f"Error in create_visualizations ({os.path.basename(path)}): {e}")
    finally:
        if fig is not None:
            plt.close(fig)


def create_visualizations(df, output_dir):
    """Save a histogram per metric and a metric-correlation heatmap as PNG files.

    For each known metric column in ``df`` that has at least one finite value,
    writes ``<output_dir>/<metric>_dist.png`` (histogram with KDE). When two or
    more such metrics exist, their pairwise correlations are also written to
    ``<output_dir>/metric_correlations.png``.

    Values are coerced to numeric first, so ``None`` becomes NaN. +/-inf, such
    as the PSNR of two identical images, is treated as missing. This is best
    effort: each failure is printed and the remaining plots still run.

    Args:
        df: Results table, e.g. one row per colour record.
        output_dir: Existing directory to write the PNG files into.
    """
    candidates = ['iou', 'dice', 'hausdorff', 'fake_area', 'real_area',
                  'mean_iou_overall', 'mean_dice_overall', 'mean_hausdorff_overall',
                  'mse', 'psnr', 'ssim']
    try:
        present = [m for m in candidates if m in df.columns]
        numeric = (df[present]
                   .apply(pd.to_numeric, errors='coerce')
                   .replace([np.inf, -np.inf], np.nan))
    except Exception as e:
        print(f"Error in create_visualizations: {e}")
        return

    # All-missing columns (e.g. a hausdorff column of only None) have nothing to plot.
    metrics = [m for m in present if numeric[m].notna().any()]

    for metric in metrics:
        values = numeric[metric].dropna()
        _save_plot(os.path.join(output_dir, f"{metric}_dist.png"), (8, 6), metric,
                   lambda: sns.histplot(values, kde=True))

    if len(metrics) > 1:
        _save_plot(os.path.join(output_dir, "metric_correlations.png"), (12, 10),
                   "Metric Correlations",
                   lambda: sns.heatmap(numeric[metrics].corr(), annot=True,
                                       cmap='coolwarm', center=0))

def generate_reports(all_results, output_dir):
    """Write the detailed CSV, the aggregate JSON, per-folder means and plots.

    Outputs, all under ``output_dir``:
      - ``detailed_results.csv``: one row per entry of ``all_results``.
      - ``aggregate_metrics.json``: mean/std/min/max/median of each known
        metric column, ignoring missing values. Values are plain Python
        floats; a ``std`` over a single value stays NaN.
      - ``folderwise_metrics.csv``: per-``folder`` means of the metric
        columns, written only when a ``folder`` column exists.
      - the plots produced by ``create_visualizations``.

    Args:
        all_results: List of result-row dicts.
        output_dir: Existing directory to write into.
    """
    df = pd.DataFrame(all_results)
    out_csv = os.path.join(output_dir, "detailed_results.csv")
    df.to_csv(out_csv, index=False)
    print(f"Saved {len(df)} rows to {out_csv}")

    agg_cols = ['iou', 'dice', 'hausdorff', 'fake_area', 'real_area',
                'mean_iou_overall', 'mean_dice_overall', 'mean_hausdorff_overall',
                'mse', 'psnr', 'ssim']
    present = [c for c in agg_cols if c in df.columns]
    # Coerce to numeric: an all-None column (e.g. hausdorff) has object dtype,
    # which pandas refuses to average.
    numeric = df[present].apply(pd.to_numeric, errors='coerce')

    summary = {}
    for col in present:
        valid_vals = numeric[col].dropna()
        if valid_vals.empty:
            continue
        # float(): reductions return numpy scalars, and np.int64 (min/max of an
        # all-integer column such as iou when nothing matched) is not JSON-serialisable.
        summary[col] = {
            'mean': float(valid_vals.mean()),
            'std': float(valid_vals.std()),
            'min': float(valid_vals.min()),
            'max': float(valid_vals.max()),
            'median': float(valid_vals.median()),
        }

    with open(os.path.join(output_dir, "aggregate_metrics.json"), 'w') as f:
        json.dump(summary, f, indent=4)

    if 'folder' in df.columns:
        folderwise = numeric.groupby(df['folder']).mean().reset_index()
        folderwise.to_csv(os.path.join(output_dir, "folderwise_metrics.csv"), index=False)

    create_visualizations(df, output_dir)

###############################################################################
# 8) Main
###############################################################################

def main():
    """Benchmark every existing folder in ``TEST_FOLDERS`` and write global reports.

    Folders that do not exist are reported and skipped. Each remaining folder
    is processed into ``OUTPUT_DIR/<folder label>``, and its rows are tagged
    with a ``folder`` column. The combined rows are passed to
    ``generate_reports``, which writes into ``OUTPUT_DIR``.
    """
    print("Color-Coded Parcel Benchmarking")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    valid_folders = {}
    for folder_name, folder_path in TEST_FOLDERS.items():
        if os.path.exists(folder_path):
            valid_folders[folder_name] = folder_path
        else:
            print(f"Skipping {folder_name}: folder not found ({folder_path})")
    if not valid_folders:
        print("No valid folders found!")
        return

    all_results = []
    for folder_name, folder_path in valid_folders.items():
        print(f"\nProcessing {folder_name} => {folder_path}")
        out_subdir = os.path.join(OUTPUT_DIR, folder_name)
        os.makedirs(out_subdir, exist_ok=True)

        folder_res = process_test_folder(folder_path, out_subdir, SAMPLE_SIZE)
        # Tag rows so reports can group results per test folder.
        for row in folder_res:
            row['folder'] = folder_name
        all_results.extend(folder_res)

    if all_results:
        print("\nGenerating global summary...")
        generate_reports(all_results, OUTPUT_DIR)
    else:
        print("No results found at all.")

    print("Done!")

# Run the benchmark only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()