Visualise the Results for Python
Native and ONNX logits comparison

In [None]:
import os
import time
import torch
import pandas as pd
import sklearn.metrics
import breizhcrops
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim import Adam
from tqdm import tqdm 
from breizhcrops import BreizhCrops
from breizhcrops.models import TempCNN
import numpy as np
import onnxruntime as ort

# --- Configuration -----------------------------------------------------------
DATA_PATH = "/breizh_data"  # root directory handed to the BreizhCrops dataset
NUM_WORKERS = 4  # NOTE(review): defined but not passed to the DataLoader below
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEVEL = "L1C"
PRELOAD_RAM = False
BATCH_SIZE = 1024

# --- Native PyTorch model ----------------------------------------------------
tempcnn = breizhcrops.models.TempCNN()
model = tempcnn
# Plain-string device used throughout this cell (DEVICE above is not used here).
device = "cuda" if torch.cuda.is_available() else "cpu"
frh04 = BreizhCrops(region="frh04", root=DATA_PATH, level=LEVEL, preload_ram=PRELOAD_RAM)
# NOTE(review): relative checkpoint path; the comparison helpers further down
# load "/log_tempcnn/..." (absolute) - confirm both point at the same files.
model.load_state_dict(torch.load("log_tempcnn/tempcnn_py_model.pt",map_location=device))
model.to(device)
model.eval()

# shuffle=False keeps the batch order fixed, so the saved arrays stay row-aligned.
test_loader = DataLoader(frh04, batch_size=BATCH_SIZE, shuffle=False)

# --- ONNX Runtime session ----------------------------------------------------
sess_options = ort.SessionOptions()
# Restrict ONNX Runtime to a single thread for both intra- and inter-op work.
sess_options.intra_op_num_threads = 1
sess_options.inter_op_num_threads = 1
ort.set_default_logger_severity(2)  # presumably 2 == WARNING; confirm against ONNX Runtime docs
# Prefer CUDA when the installed runtime offers it, otherwise run on CPU only.
providers = (
    ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"]
)
onnx_model = "log_tempcnn/tempcnn_py_model.onnx"
session = ort.InferenceSession(onnx_model, providers=providers, sess_options=sess_options)

# Name of the first graph input, needed to build the feed dict for session.run.
input_name = session.get_inputs()[0].name

# Per-batch outputs of both backends, concatenated after the loop.
all_native_logits = []
all_onnx_logits = []

for batch in test_loader:
    # First element of each batch is fed to the model
    # (presumably the input time series - TODO confirm).
    x = batch[0].to(device)

    # Native
    with torch.no_grad():
        y_native = model(x).cpu().numpy()

    # ONNX
    # ONNX Runtime takes NumPy arrays, so the batch is copied back to host memory.
    x_np = x.cpu().numpy().astype(np.float32)
    y_onnx = session.run(None, {input_name: x_np})[0]

    all_native_logits.append(y_native)
    all_onnx_logits.append(y_onnx)

# Stack the per-batch outputs along the sample axis.
native_logits = np.concatenate(all_native_logits, axis=0)
onnx_logits   = np.concatenate(all_onnx_logits, axis=0)

# Persist both arrays in the current working directory.
np.save("native_logits_python.npy", native_logits)
np.save("onnx_logits_python.npy", onnx_logits)

In [None]:
# Notebook display: echo the concatenated native (PyTorch) outputs.
native_logits

In [None]:
# Notebook display: echo the concatenated ONNX Runtime outputs.
onnx_logits

Compute MAE, maximum absolute error, top-1 agreement, RMSE and median cosine similarity

In [None]:
from numpy.linalg import norm

# Element-wise gap between the two backends, reused by every error metric.
logit_delta = native_logits - onnx_logits
abs_delta = np.abs(logit_delta)

mae = np.mean(abs_delta)
max_err = np.max(abs_delta)

# Fraction of samples for which both backends pick the same arg-max class.
native_top1 = np.argmax(native_logits, 1)
onnx_top1 = np.argmax(onnx_logits, 1)
top1_agree = np.mean(native_top1 == onnx_top1)
print(f"MAE={mae:.2e}, Max|err|={max_err:.2e}, Top1 agreement={top1_agree*100:.4f}%")

rmse = np.sqrt(np.mean(logit_delta**2))


def _row_cosine(u, v):
    """Cosine similarity of two vectors; 1e-12 guards against zero norms."""
    return np.dot(u, v) / (norm(u) * norm(v) + 1e-12)


# Median over samples of the per-sample cosine similarity.
cos = np.median([_row_cosine(u, v) for u, v in zip(native_logits, onnx_logits)])
print(f"RMSE={rmse:.2e}, median cosine={cos:.6f}")

Prediction Maps for FRH04 and Belle-Ile Region subset of FRH04

In [None]:
import torch
import numpy as np
import onnxruntime as ort
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.metrics import classification_report, accuracy_score
from collections import defaultdict
import os
# Settings shared by the comparison helpers defined in the next cell.
DATA_PATH = "/breizh_data"  # root directory handed to the BreizhCrops dataset
LEVEL = "L1C"
PRELOAD_RAM = False
# Device the PyTorch model and its inputs are moved to.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
C = 9  # number of classes

In [None]:
def setup_pytorch_model():
    """Create the FRH04 dataset and a TempCNN restored from its checkpoint.

    Returns:
        tuple: ``(model, dataset)`` - the TempCNN moved to ``DEVICE`` and
        switched to evaluation mode, and the ``BreizhCrops`` dataset for
        region "frh04".

    NOTE(review): the checkpoint path is absolute ("/log_tempcnn/...") while
    the logits-export cell loads the relative "log_tempcnn/..." - confirm
    which location is intended.
    """
    dataset = BreizhCrops(region="frh04", root=DATA_PATH, level=LEVEL, preload_ram=PRELOAD_RAM)

    net = breizhcrops.models.TempCNN()
    checkpoint = torch.load("/log_tempcnn/tempcnn_py_model.pt", map_location=DEVICE)
    net.load_state_dict(checkpoint)

    net.to(DEVICE)
    net.eval()
    return net, dataset

def setup_onnx_model():
    """Open a single-threaded ONNX Runtime session for the exported TempCNN.

    Returns:
        ort.InferenceSession: session using CUDA (with CPU fallback) when the
        CUDA execution provider is available, otherwise CPU only.
    """
    ort.set_default_logger_severity(2)

    # One thread for intra-op and inter-op work.
    options = ort.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1

    if "CUDAExecutionProvider" in ort.get_available_providers():
        execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        execution_providers = ["CPUExecutionProvider"]

    return ort.InferenceSession(
        "/log_tempcnn/tempcnn_py_model.onnx",
        sess_options=options,
        providers=execution_providers,
    )

def compare_single_sample(pytorch_model, onnx_session, x, y_true, field_id, sample_idx, verbose=False):
    """Compare PyTorch and ONNX predictions on a single sample.

    Both backends' raw outputs go through the same softmax, so the reported
    probability differences reflect the backends and not differing
    post-processing. (Softmax is a no-op on normalised log-probabilities, so
    this is valid whether the model emits logits or log-probabilities.)

    Args:
        pytorch_model: Callable PyTorch model taking a batched tensor.
        onnx_session: ONNX Runtime session for the exported model.
        x: Un-batched input tensor; must live on the CPU because it is
            converted with ``.numpy()`` for ONNX Runtime.
        y_true: Ground-truth class index (plain int or 0-d tensor/array).
        field_id: Parcel identifier, stored unchanged in the result.
        sample_idx: Position of the sample in the dataset, stored unchanged.
        verbose: When true, print a short per-sample report.

    Returns:
        dict: predictions, confidences, correctness flags, the per-class
        probability vectors of both backends and their absolute differences.
    """
    # Labels have been observed serialised as "tensor(8)" in the CSV output
    # (see clean_tensor_int); storing a plain int keeps results serialisable
    # and makes the correctness flags ordinary booleans.
    y_true = int(y_true)

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name

    # Add the batch dimension once and derive both backend inputs from it.
    batched = x.unsqueeze(0)
    x_tensor = batched.to(DEVICE)
    x_numpy = batched.numpy().astype(np.float32)

    # PyTorch inference
    with torch.no_grad():
        raw_pt = pytorch_model(x_tensor)
        probs_pt = torch.softmax(raw_pt, dim=1)
        pred_pt = probs_pt.argmax(dim=1).item()
        conf_pt = probs_pt.max(dim=1).values.item()
        probs_pt_np = probs_pt.cpu().numpy()[0]

    # ONNX inference, followed by the same (numerically stable) softmax.
    raw_onnx = onnx_session.run([output_name], {input_name: x_numpy})[0][0]
    exp_onnx = np.exp(raw_onnx - np.max(raw_onnx))
    probs_onnx = exp_onnx / np.sum(exp_onnx)
    pred_onnx = int(np.argmax(probs_onnx))
    conf_onnx = float(np.max(probs_onnx))

    # Calculate differences
    prob_diff = np.abs(probs_pt_np - probs_onnx)
    max_prob_diff = np.max(prob_diff)

    results = {
        'sample_idx': sample_idx,
        'field_id': field_id,
        'ground_truth': y_true,
        'pytorch_pred': pred_pt,
        'onnx_pred': pred_onnx,
        'pytorch_conf': conf_pt,
        'onnx_conf': conf_onnx,
        'predictions_match': pred_pt == pred_onnx,
        'pytorch_correct': pred_pt == y_true,
        'onnx_correct': pred_onnx == y_true,
        'max_prob_diff': max_prob_diff,
        'probs_pytorch': probs_pt_np.copy(),
        'probs_onnx': probs_onnx.copy(),
        'prob_differences': prob_diff
    }

    if verbose:
        print(f"\nSample {sample_idx} (Field {field_id}):")
        print(f"  Ground Truth: {y_true}")
        print(f"  PyTorch:  Pred={pred_pt}, Conf={conf_pt:.4f}, Correct={pred_pt==y_true}")
        print(f"  ONNX:     Pred={pred_onnx}, Conf={conf_onnx:.4f}, Correct={pred_onnx==y_true}")
        print(f"  Match: {pred_pt == pred_onnx}, Max Prob Diff: {max_prob_diff:.6f}")

    return results

def analyze_disagreement_patterns(disagreements, frh04):
    """Print a breakdown of the samples on which the two models disagree.

    Reports the ground-truth class distribution, the ten most frequent
    "PyTorch class -> ONNX class" pairs, statistics of the maximum
    probability difference, and which backend (if any) was correct.

    Args:
        disagreements: List of result dicts as produced by
            ``compare_single_sample``; only entries whose predictions differ.
        frh04: Dataset object; only its ``classname`` lookup is used.

    Returns:
        None. With an empty list only the header and a zero count are
        printed, instead of failing on empty-array statistics and
        divisions by zero.
    """
    print("\n" + "="*80)
    print("DISAGREEMENT PATTERN ANALYSIS")
    print("="*80)

    total_disagreements = len(disagreements)
    print(f"Total disagreements: {total_disagreements}")
    if total_disagreements == 0:
        # Nothing to analyse; np.min/np.max and the percentages below need data.
        return

    # Ground truth class distribution of disagreements
    gt_classes = [d['ground_truth'] for d in disagreements]
    gt_counts = np.bincount(gt_classes, minlength=C)

    print(f"\nDisagreements by Ground Truth Class:")
    for class_idx in range(C):
        class_name = frh04.classname[class_idx]
        count = gt_counts[class_idx]
        percentage = count / total_disagreements * 100
        print(f"  {class_idx:>2} ({class_name:<20}): {count:>4} ({percentage:5.1f}%)")

    # Most common disagreement patterns, keyed by (pytorch class, onnx class)
    disagreement_patterns = defaultdict(int)
    for d in disagreements:
        disagreement_patterns[(int(d['pytorch_pred']), int(d['onnx_pred']))] += 1

    print(f"\nMost Common Disagreement Patterns:")
    # sorted() is stable, so ties keep their first-seen order.
    sorted_patterns = sorted(disagreement_patterns.items(), key=lambda item: item[1], reverse=True)
    for (pt_class, onnx_class), count in sorted_patterns[:10]:
        percentage = count / total_disagreements * 100
        pt_name = frh04.classname[pt_class][:15]
        onnx_name = frh04.classname[onnx_class][:15]
        print(f"  {pt_name:<15} → {onnx_name:<15}: {count:>4} ({percentage:5.1f}%)")

    # Probability difference analysis
    prob_diffs = [d['max_prob_diff'] for d in disagreements]
    print(f"\nProbability Difference Statistics:")
    print(f"  Mean: {np.mean(prob_diffs):.6f}")
    print(f"  Std:  {np.std(prob_diffs):.6f}")
    print(f"  Min:  {np.min(prob_diffs):.6f}")
    print(f"  Max:  {np.max(prob_diffs):.6f}")

    # Cases where both are wrong vs one is right
    both_wrong = sum(1 for d in disagreements if not d['pytorch_correct'] and not d['onnx_correct'])
    pt_right = sum(1 for d in disagreements if d['pytorch_correct'] and not d['onnx_correct'])
    onnx_right = sum(1 for d in disagreements if not d['pytorch_correct'] and d['onnx_correct'])

    print(f"\nCorrectness in Disagreements:")
    print(f"  Both wrong:     {both_wrong:>4} ({both_wrong/total_disagreements*100:5.1f}%)")
    print(f"  PyTorch right:  {pt_right:>4} ({pt_right/total_disagreements*100:5.1f}%)")
    print(f"  ONNX right:     {onnx_right:>4} ({onnx_right/total_disagreements*100:5.1f}%)")

def save_detailed_comparison(all_results, save_path="detailed_comparison.csv"):
    """Flatten the per-sample result dicts into a table and write it as CSV.

    Each row carries the scalar fields of one result plus, for every class
    index below ``C``, the PyTorch probability, the ONNX probability and
    their absolute difference.

    Args:
        all_results: List of result dicts from ``compare_single_sample``.
        save_path: Destination CSV path.

    Returns:
        pandas.DataFrame: the table that was written.
    """
    scalar_fields = (
        'sample_idx', 'field_id', 'ground_truth',
        'pytorch_pred', 'onnx_pred',
        'pytorch_conf', 'onnx_conf',
        'predictions_match', 'pytorch_correct', 'onnx_correct',
        'max_prob_diff',
    )

    def _flatten(entry):
        # Scalars first, then three columns per class, interleaved.
        record = {field: entry[field] for field in scalar_fields}
        for cls in range(C):
            record[f'prob_pytorch_class_{cls}'] = entry['probs_pytorch'][cls]
            record[f'prob_onnx_class_{cls}'] = entry['probs_onnx'][cls]
            record[f'prob_diff_class_{cls}'] = entry['prob_differences'][cls]
        return record

    table = pd.DataFrame([_flatten(entry) for entry in all_results])
    table.to_csv(save_path, index=False)
    print(f"\nDetailed results saved to: {save_path}")
    return table
    
def run_comprehensive_comparison(num_samples=None, save_results=True, verbose_samples=0):
    """Run comprehensive comparison between PyTorch and ONNX models.

    Loads both models, runs them sample by sample over the FRH04 dataset,
    prints overall accuracy/agreement figures and optionally writes the
    detailed CSV, the disagreement CSV, a text summary and a field mapping
    into the "maps" directory.

    Args:
        num_samples: Maximum number of samples to process; ``None`` means
            the whole dataset.
        save_results: Whether to write the output files.
        verbose_samples: The first ``verbose_samples`` samples are reported
            individually.

    Returns:
        tuple: ``(all_results, disagreements, field_ids)``.
    """
    print("Setting up models...")
    pytorch_model, frh04 = setup_pytorch_model()
    onnx_session = setup_onnx_model()
    print(f"Dataset size: {len(frh04)} samples")
    total_samples = len(frh04) if num_samples is None else min(num_samples, len(frh04))
    print(f"Processing {total_samples} samples...")
    all_results = []
    disagreements = []
    for i in range(total_samples):
        x, y_true, field_id = frh04[i]

        # Compare models on this sample
        result = compare_single_sample(
            pytorch_model, onnx_session, x, y_true, field_id, i,
            verbose=(i < verbose_samples))

        all_results.append(result)

        # Track disagreements
        if not result['predictions_match']:
            disagreements.append(result)

        # Progress update: the running agreement is measured against the
        # samples processed so far, not the total still to come.
        processed = i + 1
        if processed % 10000 == 0:
            current_agreement = (processed - len(disagreements)) / processed
            print(f"  Processed {processed}/{total_samples} samples. Current agreement: {current_agreement:.3f}")

    # Calculate overall metrics
    pytorch_preds = [r['pytorch_pred'] for r in all_results]
    onnx_preds = [r['onnx_pred'] for r in all_results]
    ground_truth = [r['ground_truth'] for r in all_results]
    field_ids = [r['field_id'] for r in all_results]

    pytorch_accuracy = accuracy_score(ground_truth, pytorch_preds)
    onnx_accuracy = accuracy_score(ground_truth, onnx_preds)
    agreement = np.mean([r['predictions_match'] for r in all_results])

    print("\n" + "="*80)
    print("OVERALL COMPARISON RESULTS")
    print("="*80)
    print(f"Samples processed: {total_samples:,}")
    print(f"PyTorch accuracy: {pytorch_accuracy:.6f}")
    print(f"ONNX accuracy:    {onnx_accuracy:.6f}")
    print(f"Model agreement:  {agreement:.6f} ({agreement*100:.2f}%)")
    print(f"Disagreements:    {len(disagreements):,}")

    # Analyze disagreement patterns
    if disagreements:
        analyze_disagreement_patterns(disagreements, frh04)

    # Save detailed results
    if save_results:
        os.makedirs("maps", exist_ok=True)
        df = save_detailed_comparison(all_results, "maps/detailed_pytorch_onnx_comparison.csv")
        # Save disagreements
        if disagreements:
            disagreement_df = df[~df['predictions_match']]
            disagreement_df.to_csv("maps/disagreements_only.csv", index=False)
            print(f"Disagreements saved to: maps/disagreements_only.csv")

        # Save summary
        save_summary_stats(all_results, frh04, field_ids, "maps/summary_stats.txt")

        # Create field mapping
        create_field_mapping(field_ids, all_results, frh04, "maps/field_id_mapping.csv")

    return all_results, disagreements, field_ids

def save_summary_stats(all_results, frh04, field_ids, save_path="summary_stats.txt"):
    """Write a plain-text summary of the PyTorch/ONNX comparison.

    The report holds sample count, both accuracies, the agreement rate,
    the ground-truth class distribution over the first ``C`` classes and
    mean/std of each backend's confidence.

    Args:
        all_results: List of result dicts from ``compare_single_sample``.
        frh04: Dataset object; only its ``classname`` lookup is used.
        field_ids: Accepted for interface compatibility; not used here.
        save_path: Destination text file.
    """
    truth = [entry['ground_truth'] for entry in all_results]
    torch_preds = [entry['pytorch_pred'] for entry in all_results]
    ort_preds = [entry['onnx_pred'] for entry in all_results]

    heavy_rule = "=" * 80 + "\n"
    light_rule = "-" * 50 + "\n"
    n_results = len(all_results)

    with open(save_path, 'w') as f:
        # Header
        f.write(heavy_rule)
        f.write("PYTORCH vs ONNX MODEL COMPARISON SUMMARY\n")
        f.write(heavy_rule + "\n")

        # Headline metrics
        f.write(f"Total samples processed: {n_results:,}\n")
        f.write(f"PyTorch accuracy: {accuracy_score(truth, torch_preds):.6f}\n")
        f.write(f"ONNX accuracy: {accuracy_score(truth, ort_preds):.6f}\n")
        match_rate = np.mean([entry['predictions_match'] for entry in all_results])
        f.write(f"Model agreement: {match_rate:.6f}\n\n")

        # Ground-truth class histogram
        f.write("CLASS DISTRIBUTION:\n")
        f.write(light_rule)
        per_class = np.bincount(truth, minlength=C)
        for cls in range(C):
            label = frh04.classname[cls]
            n_cls = per_class[cls]
            share = n_cls / n_results * 100
            f.write(f"{cls:>2} ({label:<25}): {n_cls:>6,} ({share:5.1f}%)\n")

        # Confidence of the winning class, per backend
        f.write("\nCONFIDENCE STATISTICS:\n")
        f.write(light_rule)
        torch_conf = [entry['pytorch_conf'] for entry in all_results]
        ort_conf = [entry['onnx_conf'] for entry in all_results]
        f.write(f"PyTorch confidence - Mean: {np.mean(torch_conf):.4f}, Std: {np.std(torch_conf):.4f}\n")
        f.write(f"ONNX confidence - Mean: {np.mean(ort_conf):.4f}, Std: {np.std(ort_conf):.4f}\n")

        f.write(f"\nFile generated on: {pd.Timestamp.now()}\n")

    print(f"Summary statistics saved to: {save_path}")

def create_field_mapping(field_ids, all_results, frh04, save_path="maps/field_id_mapping.csv"):
    """Write one CSV row per field with its labels and both predictions.

    Args:
        field_ids: Field identifiers, positionally aligned with ``all_results``.
        all_results: List of result dicts from ``compare_single_sample``.
        frh04: Dataset object; ``classname`` maps class index to class name.
        save_path: Destination CSV path.

    Returns:
        pandas.DataFrame: the mapping table that was written.
    """
    def _row(position, field_id):
        entry = all_results[position]
        truth_idx = entry['ground_truth']
        torch_idx = entry['pytorch_pred']
        ort_idx = entry['onnx_pred']
        # Key order determines the CSV column order.
        return {
            'field_id': field_id,
            'ground_truth_class_id': truth_idx,
            'ground_truth_class_name': frh04.classname[truth_idx],
            'predicted_class_id': torch_idx,
            'onnx_class_id': ort_idx,
            'predicted_class_name': frh04.classname[torch_idx],
            'onnx_class_name': frh04.classname[ort_idx],
            'confidence': entry['pytorch_conf'],
            'onnx_confidence': entry['onnx_conf'],
        }

    table = pd.DataFrame([_row(pos, fid) for pos, fid in enumerate(field_ids)])
    table.to_csv(save_path, index=False)
    print(f"Field mapping saved to: {save_path}")
    return table


if __name__ == "__main__":
    # Full-dataset run with all output files written; per-sample verbose
    # reporting stays at its default (off).
    results, disagreements, field_ids = run_comprehensive_comparison(
        num_samples=None,
        save_results=True,
    )

    # Summary of the files written by the run.
    for message in (
        "\nAnalysis complete!",
        "Files saved:",
        "   - maps/detailed_pytorch_onnx_comparison.csv",
        "   - maps/summary_stats.txt",
        "   - maps/field_id_mapping.csv",
    ):
        print(message)

    # The disagreement CSV only exists when at least one sample differed.
    if not disagreements:
        print("Perfect agreement - no disagreements to analyze!")
    else:
        print(" - maps/disagreements_only.csv (disagreement cases only)")

In [None]:
# Load both regions; only one of them supplies the parcel geometries below.
belle_ile = BreizhCrops(region="belle-ile", root=DATA_PATH, level=LEVEL, preload_ram=PRELOAD_RAM)
frh04 = BreizhCrops(region="frh04", root=DATA_PATH, level=LEVEL, preload_ram=PRELOAD_RAM)
# Parcel geometries used for the maps; swap in frh04.geodataframe() for the full region.
field_parcels_geodataframe = belle_ile.geodataframe() #frh04 or belle_ile

In [None]:
import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
import numpy as np
from pyproj import CRS
import math

In [None]:
# Defaults for make_maps (bound as its default arguments at definition time).
SAVE_DIR     = "maps"   # output directory for the rendered figures
# NOTE(review): the comparison step writes this file to
# "maps/detailed_pytorch_onnx_comparison.csv"; confirm this bare filename
# resolves to the same file from the working directory.
DETAIL_CSV   = "detailed_pytorch_onnx_comparison.csv"  # prediction file 
PRED_SOURCE  = "pytorch"   # choose: "onnx" or "pytorch"
FIXED_SCALE_KM = None  # None lets the scale bar pick its own length
TITLE_REGION = "FRH04"  # title text


# Make sure the output directory exists before any figure is saved.
os.makedirs(SAVE_DIR, exist_ok=True)

def clean_tensor_int(x):
    """Return *x* as an int, unwrapping a ``'tensor(N)'`` string if needed.

    Accepts ``'tensor(8)'``, ``'8'``, ``8`` or ``8.0`` and returns ``8``.
    """
    wrapper_open, wrapper_close = "tensor(", ")"
    is_wrapped = (
        isinstance(x, str)
        and x.startswith(wrapper_open)
        and x.endswith(wrapper_close)
    )
    if is_wrapped:
        x = x[len(wrapper_open):-len(wrapper_close)]
    # Go through float so strings such as "8.0" are accepted as well.
    return int(float(x))

def _safe_series(x):
    """Return *x* as an object-dtype Series with a fresh default index.

    Used to sidestep index-related issues with isna()/value_counts().
    """
    values = np.asarray(x, dtype=object)
    return pd.Series(values)

def add_north_arrow(ax, xy=(0.95, 0.92), length=0.10):
    """Draw a black north arrow with an 'N' label, in axes-fraction coordinates.

    Args:
        ax: Matplotlib axes to draw on.
        xy: Position of the arrow tip as (x, y) axes fractions.
        length: Arrow length as a fraction of the axes height.
    """
    tip_x = xy[0]
    tip_y = xy[1]
    arrow_style = dict(arrowstyle='-|>', lw=1.8, color='k')

    # Empty annotation text: only the arrow from the tail up to the tip is drawn.
    ax.annotate(
        '',
        xy=xy,
        xytext=(tip_x, tip_y - length),
        xycoords='axes fraction',
        textcoords='axes fraction',
        arrowprops=arrow_style,
    )
    # Label sits just above the arrow tip.
    ax.text(
        tip_x, tip_y + 0.02, 'N',
        transform=ax.transAxes,
        ha='center', va='bottom',
        fontsize=12, weight='bold',
    )


def _nice_number_m(x_m):
    """Snap a length in meters to a 1-2-5 x 10^n value; non-positive input gives 1.0."""
    if x_m <= 0:
        return 1.0
    exponent = int(np.floor(np.log10(x_m)))
    magnitude = 10 ** exponent
    mantissa = x_m / magnitude
    # First threshold the mantissa falls under selects the multiplier.
    for upper_bound, multiplier in ((1.5, 1), (3, 2), (7, 5)):
        if mantissa < upper_bound:
            return multiplier * magnitude
    return 10 * magnitude

def _meters_per_deg_lon_at_lat(lat_deg):
    """Return 111320.0 (meters per degree at the equator) scaled by cos(latitude)."""
    lat_rad = math.radians(lat_deg)
    return 111320.0 * math.cos(lat_rad)


def fmt_pct(x):
    """Format *x* with one decimal place followed by a percent sign."""
    return format(x, ".1f") + "%"


def add_qgis_scalebar_auto(ax, gdf_plotted, total_km=None, segments=4,
                           loc='lower right', pad_frac=0.05, edge_lw=0.8, textsize=10):
    """Draw a segmented black/white scale bar sized from the current axes extent.

    Args:
        ax: Matplotlib axes that already show the data; its x/y limits are read.
        gdf_plotted: GeoDataFrame that was plotted; only its ``crs`` is used.
        total_km: Bar length in kilometres. When ``None``, a 1-2-5 "nice"
            length of roughly one fifth of the visible width is chosen.
        segments: Number of alternating segments the bar is split into.
        loc: 'lower right', 'lower left' or 'upper right'; any other value
            is treated as 'upper left'.
        pad_frac: Padding from the axes edge as a fraction of the extent.
        edge_lw: Line width of segment outlines and tick marks.
        textsize: Font size of the labels.

    Raises:
        ValueError: If ``gdf_plotted`` has no CRS.
    """
    if gdf_plotted.crs is None:
        raise ValueError("GeoDataFrame has no CRS. Set a CRS or reproject before drawing scalebar.")

    minx, maxx = ax.get_xlim()
    miny, maxy = ax.get_ylim()
    width = maxx - minx
    height = maxy - miny

    crs = CRS.from_user_input(gdf_plotted.crs)
    if crs.is_projected:
        # Projected coordinates are treated as metres.
        meters_per_unit = 1.0
    else:
        # Geographic CRS: convert degrees of longitude at the mid latitude.
        mid_lat = 0.5 * (miny + maxy)
        meters_per_unit = max(_meters_per_deg_lon_at_lat(mid_lat), 1e-6)

    width_m_est = width * meters_per_unit
    if total_km is not None:
        meters_total = total_km * 1000.0
    else:
        meters_total = _nice_number_m(width_m_est / 5.0)
    bar_len_native = meters_total / meters_per_unit
    seg_native = bar_len_native / float(segments)
    bar_thick = 0.014 * height

    # Anchor position
    on_right = loc in ('lower right', 'upper right')
    on_bottom = loc in ('lower right', 'lower left')
    x0 = (maxx - pad_frac * width - bar_len_native) if on_right else (minx + pad_frac * width)
    y0 = (miny + pad_frac * height) if on_bottom else (maxy - pad_frac * height - 2.5 * bar_thick)
    label_above = on_bottom

    # Alternating segments, starting with black.
    for i in range(segments):
        ax.add_patch(Rectangle((x0 + i * seg_native, y0), seg_native, bar_thick,
                               facecolor=('k' if i % 2 == 0 else 'white'),
                               edgecolor='k', lw=edge_lw))

    # Ticks at start, middle and end of the bar.
    for tick_x in (x0, x0 + bar_len_native / 2, x0 + bar_len_native):
        ax.plot([tick_x, tick_x], [y0, y0 + bar_thick], color='k', lw=edge_lw)

    # Labels: the middle value uses the same unit as the end label, so a
    # sub-kilometre bar reads e.g. "0 / 250 / 500 m" rather than mixing units.
    if meters_total >= 1000:
        mid_label = f"{meters_total / 2000.0:g}"
        end_label = f"{(meters_total/1000.0):g} km"
    else:
        mid_label = f"{meters_total / 2.0:g}"
        end_label = f"{int(round(meters_total))} m"

    ty = (y0 + 1.9 * bar_thick) if label_above else (y0 - 0.9 * bar_thick)
    va = 'bottom' if label_above else 'top'

    def _label(x_pos, text):
        # Semi-transparent white box keeps labels readable over the map.
        ax.text(x_pos, ty, text, ha='center', va=va, fontsize=textsize,
                bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))

    _label(x0, "0")
    _label(x0 + bar_len_native / 2, mid_label)
    _label(x0 + bar_len_native, end_label)

def make_maps(field_parcels_geodataframe,
              detail_csv=DETAIL_CSV,
              pred_source=PRED_SOURCE,
              save_dir=SAVE_DIR,
              fixed_scale_km=FIXED_SCALE_KM,
              title_region=TITLE_REGION):
    """Render ground-truth, predicted and correct/incorrect parcel maps.

    The prediction CSV is joined to the parcel geometries on the parcel id
    and three figures are saved into ``save_dir``: "gr.png" (ground truth),
    "pr.png" (predictions) and "co.png" (correct vs incorrect).

    Args:
        field_parcels_geodataframe: GeoDataFrame with an ``id`` column and,
            optionally, a ``classname`` column holding class names.
        detail_csv: CSV with ``field_id``, ``ground_truth``,
            ``pytorch_pred`` and ``onnx_pred`` columns.
        pred_source: "onnx" or "pytorch" (case-insensitive).
        save_dir: Output directory; created if missing.
        fixed_scale_km: Scale-bar length in km, or ``None`` for automatic.
        title_region: Region name used in the figure titles.

    Raises:
        ValueError: If ``pred_source`` is neither "onnx" nor "pytorch".
    """
    # save_dir may differ from the module-level SAVE_DIR that was created
    # up front, so make sure it exists before any figure is written.
    os.makedirs(save_dir, exist_ok=True)
    gt_png = os.path.join(save_dir, "gr.png")
    pred_png = os.path.join(save_dir, "pr.png")
    ci_png = os.path.join(save_dir, "co.png")

    # ---------- prediction CSV -------------
    pred_df = pd.read_csv(detail_csv)
    # clean essential columns
    pred_df["field_id"]     = pred_df["field_id"].astype(int)
    pred_df["ground_truth"] = pred_df["ground_truth"].apply(clean_tensor_int)
    pred_df["pytorch_pred"] = pred_df["pytorch_pred"].astype(int)
    pred_df["onnx_pred"]    = pred_df["onnx_pred"].astype(int)

    # choose prediction source
    if pred_source.lower() == "onnx":
        pred_idx_col = "onnx_pred"
        title_suffix = "ONNX"
    elif pred_source.lower() == "pytorch":
        pred_idx_col = "pytorch_pred"
        title_suffix = "PyTorch"
    else:
        raise ValueError("PRED_SOURCE must be 'onnx' or 'pytorch'.")

    # slim frame for merge
    pred_merge = pred_df.loc[:, ["field_id", "ground_truth", pred_idx_col]].rename(
        columns={"field_id": "id", pred_idx_col: "pred_idx"}
    )

    # ---------- merge with GeoDataFrame ----------
    gdf = field_parcels_geodataframe.copy()
    # Left join: parcels without a prediction row keep NaN labels.
    gdf = gdf.merge(pred_merge, on="id", how="left")
    gdf_ll = gdf.to_crs(epsg=4326)

    # Class index -> name, taken from the most frequent classname per label.
    idx_to_name = {}
    for idx, sub in gdf_ll.dropna(subset=["ground_truth"]).groupby("ground_truth"):
        if "classname" in sub and not sub["classname"].isna().all():
            name = sub["classname"].mode(dropna=True)
            idx_to_name[int(idx)] = str(name.iloc[0]) if len(name) else f"class_{int(idx)}"

    # Any index without a known name falls back to "class_<idx>".
    all_idxs = set(gdf_ll["ground_truth"].dropna().astype(int)) \
               .union(set(gdf_ll["pred_idx"].dropna().astype(int)))
    for k in all_idxs:
        idx_to_name.setdefault(int(k), f"class_{int(k)}")

    # Map names
    gdf_ll["gt_class_name"]   = gdf_ll["ground_truth"].map(lambda v: idx_to_name.get(int(v), f"class_{int(v)}") if pd.notna(v) else None)
    gdf_ll["pred_class_name"] = gdf_ll["pred_idx"].map(lambda v: idx_to_name.get(int(v), f"class_{int(v)}") if pd.notna(v) else None)

    # When the geodataframe carries its own class names, they take precedence
    # for the ground truth.
    if "classname" in gdf_ll.columns and not gdf_ll["classname"].isna().all():
        gdf_ll["gt_class_name"] = gdf_ll["classname"].astype(str)
    classes = sorted(
        pd.unique(
            pd.concat([
                _safe_series(gdf_ll["gt_class_name"]),
                _safe_series(gdf_ll["pred_class_name"])
            ], ignore_index=True).dropna().astype(str)
        )
    )
    # pyplot.get_cmap replaces the deprecated (and since removed)
    # matplotlib.cm.get_cmap accessor.
    cmap = plt.get_cmap("tab20", len(classes))
    class2color = {cls: cmap(i) for i, cls in enumerate(classes)}

    # Keep categoricals aligned with classes
    gdf_ll["gt_class_name"]   = pd.Categorical(_safe_series(gdf_ll["gt_class_name"]).astype(str),   categories=classes, ordered=True)
    gdf_ll["pred_class_name"] = pd.Categorical(_safe_series(gdf_ll["pred_class_name"]).astype(str), categories=classes, ordered=True)

    # Colors
    gdf_ll["gt_color"]   = _safe_series(gdf_ll["gt_class_name"]).map(class2color)
    gdf_ll["pred_color"] = _safe_series(gdf_ll["pred_class_name"]).map(class2color)

    # Plot in Web Mercator so the scale bar works in projected units.
    gdf_m = gdf_ll.to_crs(epsg=3857)
    for col in ["gt_color", "pred_color"]:
        gdf_m[col] = gdf_ll[col]
    _scale_km = fixed_scale_km

    # =========================
    # 1) Ground-truth map
    # =========================
    gt_counts = (_safe_series(gdf_ll["gt_class_name"]).value_counts(dropna=True)
                 .reindex(classes, fill_value=0))
    gt_total  = int(gt_counts.sum())
    gt_pct    = gt_counts / max(gt_total, 1) * 100.0

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    gdf_m.plot(color=gdf_m["gt_color"], linewidth=0.05, edgecolor="none", ax=ax)
    print(gdf_m.total_bounds)
    print("Width in km:", (gdf_m.total_bounds[2] - gdf_m.total_bounds[0]) / 1000)
    ax.set_title(f"{title_region} — Ground Truth (BreizhCrops)", pad=12)
    ax.set_xlabel(""); ax.set_ylabel("")
    add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
    ax.set_aspect('equal')
    add_qgis_scalebar_auto(ax, gdf_m, total_km=_scale_km, segments=4, loc='lower right')

    handles_gt = []
    for cls in classes:
        cnt = int(gt_counts.loc[cls])
        pct = float(gt_pct.loc[cls])
        handles_gt.append(
            mpatches.Patch(facecolor=class2color[cls], edgecolor="none",
                           label=f"{cls} — {fmt_pct(pct)} ({cnt:,})")
        )
    ax.legend(handles=handles_gt, loc="center left", bbox_to_anchor=(1.02, 0.5),
              frameon=True, title="Class")
    fig.tight_layout(); fig.subplots_adjust(right=0.82)
    fig.savefig(gt_png, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show()
    plt.close(fig)

    # ==================================
    # 2) Predicted map (PyTorch or ONNX)
    # ==================================
    pred_counts = (_safe_series(gdf_ll["pred_class_name"]).value_counts(dropna=True)
                   .reindex(classes, fill_value=0))
    pred_total  = int(pred_counts.sum())
    pred_pct    = pred_counts / max(pred_total, 1) * 100.0

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    gdf_m.plot(color=gdf_m["pred_color"], linewidth=0.05, edgecolor="none", ax=ax)
    print(gdf_m.total_bounds)
    print("Width in km:", (gdf_m.total_bounds[2] - gdf_m.total_bounds[0]) / 1000)
    ax.set_title(f"{title_region} — Predicted — ({title_suffix})", pad=12)
    ax.set_xlabel(""); ax.set_ylabel("")
    add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
    ax.set_aspect('equal')
    add_qgis_scalebar_auto(ax, gdf_m, total_km=_scale_km, segments=4, loc='lower right')

    handles_pr = []
    for cls in classes:
        cnt = int(pred_counts.loc[cls])
        pct = float(pred_pct.loc[cls])
        handles_pr.append(
            mpatches.Patch(facecolor=class2color[cls], edgecolor="none",
                           label=f"{cls} — {fmt_pct(pct)} ({cnt:,})")
        )
    ax.legend(handles=handles_pr, loc="center left", bbox_to_anchor=(1.02, 0.5),
              frameon=True, title="Predicted Class")
    fig.tight_layout(); fig.subplots_adjust(right=0.82)

    fig.savefig(pred_png, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show()
    plt.close(fig)

    # ==========================================
    # 3) Correct vs Incorrect
    # ==========================================
    gdf_ll["correct"] = (_safe_series(gdf_ll["gt_class_name"]) == _safe_series(gdf_ll["pred_class_name"]))
    palette_ci = {True: "#1a9850", False: "#d73027"}
    gdf_ll["ci_color"] = gdf_ll["correct"].map(palette_ci)

    gdf_m = gdf_ll.to_crs(epsg=3857)
    gdf_m["ci_color"] = gdf_ll["ci_color"]
    print(gdf_m.total_bounds)
    print("Width in km:", (gdf_m.total_bounds[2] - gdf_m.total_bounds[0]) / 1000)

    # Percentages are relative to parcels that actually have a prediction.
    s_pred     = _safe_series(gdf_ll["pred_class_name"])
    n_total    = int(s_pred.notna().sum())
    n_correct  = int(_safe_series(gdf_ll["correct"]).sum())
    n_incorrect = n_total - n_correct
    p_correct   = 100.0 * n_correct / max(n_total, 1)
    p_incorrect = 100.0 * n_incorrect / max(n_total, 1)

    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    gdf_m.plot(color=gdf_m["ci_color"], linewidth=0.05, edgecolor="none", ax=ax)
    ax.set_title(f"{title_region} — Correct vs Incorrect — ({title_suffix})", pad=12)
    ax.set_xlabel(""); ax.set_ylabel("")
    add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
    ax.set_aspect('equal')
    add_qgis_scalebar_auto(ax, gdf_m, total_km=_scale_km, segments=4, loc='lower right')

    handles_ci = [
        mpatches.Patch(color=palette_ci[True],  label=f"Correct — {p_correct:.1f}% ({n_correct:,})"),
        mpatches.Patch(color=palette_ci[False], label=f"Incorrect — {p_incorrect:.1f}% ({n_incorrect:,})"),
    ]
    ax.legend(handles=handles_ci, loc="center left", bbox_to_anchor=(1.02, 0.5),
              frameon=True, title="Prediction")
    fig.tight_layout(); fig.subplots_adjust(right=0.72)
    fig.savefig(ci_png, dpi=300, bbox_inches="tight", facecolor="white")
    plt.show()
    plt.close(fig)

    # Report exactly the paths that were written above.
    print("Saved:", gt_png, pred_png, ci_png)


# Render the three maps for the parcels loaded earlier.
# NOTE(review): field_parcels_geodataframe is built from the Belle-Ile
# dataset above, while the title below says "FRH04" - confirm the pairing.
make_maps(
    field_parcels_geodataframe,
    detail_csv=DETAIL_CSV,
    pred_source="pytorch",     # pytorch or onnx 
    save_dir="maps",
    fixed_scale_km=None,       # None lets the scale bar pick its own length
    title_region="FRH04"    # appears in the figure title
)


#Belle-Ile

In [None]:
import os, math
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
from pyproj import CRS
import contextily as ctx
# Settings for the Belle-Ile maps; these rebind the names set for the FRH04 maps.
# NOTE(review): the comparison step writes this CSV under "maps/"; confirm
# this bare filename resolves to the same file from the working directory.
DETAIL_CSV   = "detailed_pytorch_onnx_comparison.csv"  
PRED_SOURCE  = "onnx"   # "pytorch" or "onnx"
SAVE_DIR     = "maps"
TITLE_REGION = "Belle-ile"
FIXED_SCALE_KM = None  # None lets the scale bar pick its own length
# Make sure the output directory exists before any figure is saved.
os.makedirs(SAVE_DIR, exist_ok=True)

def clean_tensor_int(x):
    """Return *x* as an int, unwrapping a ``'tensor(N)'`` string if needed."""
    if isinstance(x, str):
        looks_wrapped = x.startswith("tensor(") and x.endswith(")")
        if looks_wrapped:
            # Drop the 7-character "tensor(" prefix and the closing ")".
            x = x[7:-1]
    # Go through float so strings such as "8.0" are accepted as well.
    number = float(x)
    return int(number)

def fmt_pct(p):
    """Format a percentage for legends.

    Non-positive values give "0%", values below 0.1 give "<0.1%", and
    everything else is shown with one decimal place.
    """
    if p <= 0:
        return "0%"
    return "<0.1%" if p < 0.1 else f"{p:.1f}%"

def add_north_arrow(ax, xy=(0.95, 0.92), length=0.10):
    """Draw a white north arrow with an 'N' label, in axes-fraction coordinates.

    Args:
        ax: Matplotlib axes to draw on.
        xy: Position of the arrow tip as (x, y) axes fractions.
        length: Arrow length as a fraction of the axes height.
    """
    tip_x = xy[0]
    tip_y = xy[1]
    arrow_style = dict(arrowstyle='-|>', lw=1.8, color='w')

    # Empty annotation text: only the arrow from the tail up to the tip is drawn.
    ax.annotate(
        '',
        xy=xy,
        xytext=(tip_x, tip_y - length),
        xycoords='axes fraction',
        textcoords='axes fraction',
        arrowprops=arrow_style,
    )
    # Label sits just above the arrow tip.
    ax.text(
        tip_x, tip_y + 0.02, 'N',
        transform=ax.transAxes,
        ha='center', va='bottom',
        fontsize=12, weight='bold', color='w',
    )

def _nice_number_m(x_m):
    """Snap a length in meters to a 1-2-5 x 10^n value; non-positive input gives 1.0."""
    if x_m <= 0:
        return 1.0

    power = int(np.floor(np.log10(x_m)))
    scale = 10 ** power
    lead = x_m / scale

    # Choose the multiplier from the leading value.
    if lead < 1.5:
        step = 1
    elif lead < 3:
        step = 2
    elif lead < 7:
        step = 5
    else:
        step = 10
    return step * scale

def _meters_per_deg_lon_at_lat(lat_deg):
    """Return 111320.0 (meters per degree at the equator) scaled by cos(latitude)."""
    shrink = math.cos(math.radians(lat_deg))
    return 111320.0 * shrink

def add_qgis_scalebar_auto(ax, gdf_plotted, total_km=None, segments=4,
                           loc='lower left', pad_frac=0.05, edge_lw=0.8, textsize=10,
                           bgcolor='lightgray', bg_alpha=0.6, pad_px=5):
    """
    Draw a segmented scale bar sized from the CURRENT axes extent.

    The bar consists of ``segments`` alternating grey/white blocks with
    ticks and labels at its start ("0"), middle and end.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; its current x/y limits define the visible extent.
    gdf_plotted : GeoDataFrame
        Data shown on ``ax``; only its ``crs`` attribute is read.
    total_km : float or None
        Total bar length in kilometres. ``None`` picks a "nice" length near
        one fifth of the visible width.
    segments : int
        Number of alternating blocks; must be at least 1.
    loc : str
        'lower left', 'lower right' or 'upper right'. Any other value is
        treated as 'upper left'.
    pad_frac : float
        Gap between the bar and the axes edge, as a fraction of the extent.
    edge_lw : float
        Line width of block edges and ticks.
    textsize : float
        Font size of the labels.
    bgcolor, bg_alpha, pad_px
        Accepted for backward compatibility but currently unused: no
        background box is drawn behind the bar (each label has its own
        white box).

    Raises
    ------
    ValueError
        If ``gdf_plotted`` has no CRS, or ``segments`` is smaller than 1.

    Notes
    -----
    For a projected CRS the axis units are taken to be metres as-is.
    NOTE(review): in Web Mercator (EPSG:3857) map units overstate ground
    distance away from the equator; confirm whether that is acceptable.
    """
    if gdf_plotted.crs is None:
        raise ValueError("GeoDataFrame has no CRS. Set a CRS or reproject before drawing scalebar.")
    if segments < 1:
        raise ValueError(f"segments must be >= 1, got {segments!r}")
    crs = CRS.from_user_input(gdf_plotted.crs)

    minx, maxx = ax.get_xlim()
    miny, maxy = ax.get_ylim()
    width = maxx - minx
    height = maxy - miny

    # Metres represented by one axis unit along x.
    if crs.is_projected:
        m_per_unit = 1.0
    else:
        # Geographic CRS: convert degrees of longitude at the mid latitude;
        # the floor avoids dividing by ~0 near the poles.
        mid_lat = 0.5 * (miny + maxy)
        m_per_unit = max(_meters_per_deg_lon_at_lat(mid_lat), 1e-6)

    if total_km is not None:
        meters_total = total_km * 1000.0
    else:
        meters_total = _nice_number_m(width * m_per_unit / 5.0)

    bar_len_native = meters_total / m_per_unit   # bar length in axis units
    seg_native = bar_len_native / float(segments)
    bar_thick = 0.014 * height

    # Anchor (lower-left corner of the bar). Unknown loc falls back to
    # 'upper left', as both tests below are then False.
    on_right = loc in ('lower right', 'upper right')
    on_bottom = loc in ('lower left', 'lower right')
    if on_right:
        x0 = maxx - pad_frac * width - bar_len_native
    else:
        x0 = minx + pad_frac * width
    if on_bottom:
        y0 = miny + pad_frac * height
    else:
        y0 = maxy - pad_frac * height - 2.5 * bar_thick
    # Bars at the bottom carry labels above them, bars at the top below.
    label_above = on_bottom

    # Alternating blocks.
    for i in range(segments):
        ax.add_patch(Rectangle((x0 + i * seg_native, y0), seg_native, bar_thick,
                               facecolor=('grey' if i % 2 == 0 else 'white'),
                               edgecolor='grey', lw=edge_lw, zorder=2))

    x_mid = x0 + bar_len_native / 2
    x_end = x0 + bar_len_native

    # Ticks at start, middle and end.
    for x_tick in (x0, x_mid, x_end):
        ax.plot([x_tick, x_tick], [y0, y0 + bar_thick], color='k', lw=edge_lw, zorder=3)

    # Labels: the middle one is a bare kilometre value, the end one has a unit.
    label_y = (y0 + 1.9 * bar_thick) if label_above else (y0 - 0.9 * bar_thick)
    va = 'bottom' if label_above else 'top'
    mid_label = f"{meters_total / 2000.0:g}"
    if meters_total >= 1000:
        end_label = f"{(meters_total/1000.0):g} km"
    else:
        end_label = f"{int(round(meters_total))} m"
    for x_pos, label in ((x0, "0"), (x_mid, mid_label), (x_end, end_label)):
        ax.text(x_pos, label_y, label, ha='center', va=va, fontsize=textsize,
                bbox=dict(facecolor='white', edgecolor='none', alpha=0.7), zorder=4)


def add_bg(ax, crs, zoom="auto", attribution=True, alpha=1.0):
    """Add a Contextily basemap to ``ax`` and return the provider used.

    Candidate providers are looked up on ``ctx.providers`` in order; the
    first one that exists there is passed to ``ctx.add_basemap`` and its
    dotted name is returned. If none exists, ``ctx.add_basemap`` is called
    without a source and "default(OSM)" is returned. ``crs``, ``zoom``,
    ``attribution`` and ``alpha`` are forwarded unchanged.
    """
    candidates = (
        "Esri.WorldImagery", "CartoDB.Positron", "CartoDB.Voyager",
        "OpenStreetMap.Mapnik", "Esri.WorldTopoMap", "OpenTopoMap",
    )
    common = dict(crs=crs, zoom=zoom, attribution=attribution, alpha=alpha)

    def _lookup(dotted):
        # Walk the dotted path attribute by attribute; None if any is missing.
        node = ctx.providers
        for attr in dotted.split("."):
            node = getattr(node, attr, None)
            if node is None:
                break
        return node

    for dotted in candidates:
        source = _lookup(dotted)
        if source is None:
            continue
        ctx.add_basemap(ax, source=source, **common)
        return dotted

    ctx.add_basemap(ax, **common)
    return "default(OSM)"


# =========================
# 1) Ground-truth map
# =========================
gdf_source = field_parcels_geodataframe
gdf_3857 = gdf_source.to_crs(epsg=3857)

# Build class palette + counts/% for legend
classes = sorted(gdf_3857["classname"].dropna().astype(str).unique())
# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap("tab20", len(classes))
class2color = {cls: cmap(i) for i, cls in enumerate(classes)}
gt_counts = gdf_3857["classname"].astype(str).value_counts().reindex(classes, fill_value=0)
gt_total  = int(gt_counts.sum())
gt_pct    = (gt_counts / max(gt_total, 1)) * 100.0

# Per-parcel colour, computed once and reused for both drawing passes.
gt_colors = gdf_3857["classname"].astype(str).map(class2color)

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
gdf_3857.plot(color=gt_colors, linewidth=0.05, edgecolor="none", ax=ax, zorder=2)

# Fix the visible extent (data bounds plus a 2% margin) before adding tiles.
minx, miny, maxx, maxy = gdf_3857.total_bounds
pad_x = 0.02 * (maxx - minx)
pad_y = 0.02 * (maxy - miny)
ax.set_xlim(minx - pad_x, maxx + pad_x)
ax.set_ylim(miny - pad_y, maxy + pad_y)

# add basemap
_ = add_bg(ax, gdf_3857.crs, zoom="auto", attribution=True, alpha=1.0)
# NOTE(review): this second pass redraws the parcels after the basemap; it
# looks redundant with the zorder=2 pass above - confirm before removing.
gdf_3857.plot(color=gt_colors, linewidth=0.05, edgecolor="none", ax=ax, zorder=3)

# Title and scale bar take the cell-level settings, as the predicted map does.
ax.set_title(f"{TITLE_REGION} — Ground Truth (BreizhCrops)", pad=12)
ax.set_xlabel(""); ax.set_ylabel(""); ax.set_aspect("equal")
add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
add_qgis_scalebar_auto(ax, gdf_3857, total_km=FIXED_SCALE_KM, segments=4, loc='lower left')

# One legend entry per class: "<class> — <share> (<count>)".
handles_gt = [
    mpatches.Patch(
        facecolor=class2color[cls], edgecolor="none",
        label=f"{cls} — {fmt_pct(float(gt_pct.loc[cls]))} ({int(gt_counts.loc[cls]):,})",
    )
    for cls in classes
]
ax.legend(handles=handles_gt, loc="center left", bbox_to_anchor=(1.02, 0.5),
          frameon=True, title="Class")

# Leave room on the right for the legend placed outside the axes.
fig.tight_layout(); fig.subplots_adjust(right=0.82)
fig.savefig(os.path.join(SAVE_DIR, "gr_bille.png"), dpi=300, bbox_inches="tight", facecolor="white")
plt.show()
plt.close(fig)

# ==================================
# 2) Predicted map (PyTorch or ONNX)
# ==================================

# Colour for parcels that have no row in DETAIL_CSV after the left merge.
NO_PRED_RGBA = (0.74, 0.74, 0.74, 1.0)

pred = pd.read_csv(DETAIL_CSV)
pred["field_id"]     = pred["field_id"].astype(int)
# ground_truth goes through clean_tensor_int, which also accepts "tensor(N)" strings.
pred["ground_truth"] = pred["ground_truth"].apply(clean_tensor_int)
pred["pytorch_pred"] = pred["pytorch_pred"].astype(int)
pred["onnx_pred"]    = pred["onnx_pred"].astype(int)

pred_idx_col = "onnx_pred" if PRED_SOURCE.lower() == "onnx" else "pytorch_pred"

merge_df = pred.loc[:, ["field_id", "ground_truth", pred_idx_col]].rename(
    columns={"field_id": "id", pred_idx_col: "pred_idx"}
)
# Left merge keeps every parcel; parcels absent from the CSV get NaN predictions.
gdf = field_parcels_geodataframe.merge(merge_df, on="id", how="left")

# Map each class index to a name: the most frequent classname among parcels
# carrying that ground-truth index.
idx_to_name = {}
for idx, sub in gdf.dropna(subset=["ground_truth", "classname"]).groupby("ground_truth"):
    mode = sub["classname"].mode(dropna=True)
    idx_to_name[int(idx)] = str(mode.iloc[0]) if len(mode) else f"class_{int(idx)}"

# Indices that never got a name above fall back to a "class_<idx>" label.
all_idxs = set(gdf["ground_truth"].dropna().astype(int)) | set(gdf["pred_idx"].dropna().astype(int))
for k in all_idxs:
    idx_to_name.setdefault(int(k), f"class_{int(k)}")

# Predicted class names
gdf["pred_class_name"] = gdf["pred_idx"].map(
    lambda v: idx_to_name.get(int(v), f"class_{int(v)}") if pd.notna(v) else None
)
gdf_3857 = gdf.to_crs(epsg=3857)
classes = sorted(pd.unique(pd.concat([
    gdf_3857["classname"].dropna().astype(str),
    gdf_3857["pred_class_name"].dropna().astype(str)
])))
gdf_3857["pred_class_name"] = pd.Categorical(gdf_3857["pred_class_name"], categories=classes, ordered=True)

# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
cmap = plt.get_cmap("tab20", len(classes))
class2color = {cls: cmap(i) for i, cls in enumerate(classes)}

has_pred  = gdf_3857["pred_class_name"].notna()
n_no_pred = int((~has_pred).sum())
pred_counts = (
    gdf_3857.loc[has_pred, "pred_class_name"].astype(str)
    .value_counts().reindex(classes, fill_value=0)
)
pred_total  = int(pred_counts.sum())
pred_pct    = (pred_counts / max(pred_total, 1)) * 100.0

# Missing predictions become the string "nan" after astype(str); give those
# parcels a neutral colour instead of passing a NaN colour to matplotlib.
pred_colors = gdf_3857["pred_class_name"].astype(str).map(
    lambda name: class2color.get(name, NO_PRED_RGBA)
)

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
gdf_3857.plot(color=pred_colors, linewidth=0.05, edgecolor="none", ax=ax, zorder=2)

# Fix the visible extent (data bounds plus a 2% margin) before adding tiles.
minx, miny, maxx, maxy = gdf_3857.total_bounds
pad_x = 0.02 * (maxx - minx)
pad_y = 0.02 * (maxy - miny)
ax.set_xlim(minx - pad_x, maxx + pad_x)
ax.set_ylim(miny - pad_y, maxy + pad_y)
_ = add_bg(ax, gdf_3857.crs, zoom="auto", attribution=True, alpha=1.0)
# NOTE(review): this second pass redraws the parcels after the basemap; it
# looks redundant with the zorder=2 pass above - confirm before removing.
gdf_3857.plot(color=pred_colors, linewidth=0.05, edgecolor="none", ax=ax, zorder=3)

ax.set_title(f"{TITLE_REGION} — Predicted (TempCNN, {PRED_SOURCE.upper()})", pad=12)
ax.set_xlabel(""); ax.set_ylabel(""); ax.set_aspect("equal")
add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
add_qgis_scalebar_auto(ax, gdf_3857, total_km=FIXED_SCALE_KM, segments=4, loc='lower left')

# One legend entry per class: "<class> — <share> (<count>)".
handles_pr = [
    mpatches.Patch(
        facecolor=class2color[cls], edgecolor="none",
        label=f"{cls} — {fmt_pct(float(pred_pct.loc[cls]))} ({int(pred_counts.loc[cls]):,})",
    )
    for cls in classes
]
if n_no_pred:
    # Only shown when some parcels really lack a prediction.
    handles_pr.append(
        mpatches.Patch(facecolor=NO_PRED_RGBA, edgecolor="none",
                       label=f"No prediction ({n_no_pred:,})")
    )
ax.legend(handles=handles_pr, loc="center left", bbox_to_anchor=(1.02, 0.5),
          frameon=True, title="Predicted Class")

# Work on the figure object explicitly and close it afterwards so it does
# not stay open in memory.
fig.tight_layout(); fig.subplots_adjust(right=0.82)
out_path = os.path.join(SAVE_DIR, f"pred_{PRED_SOURCE.lower()}.png")
fig.savefig(out_path, dpi=300, bbox_inches="tight", facecolor="white")
plt.show()
plt.close(fig)
print("Saved:", out_path)

# ==========================================
# 3) Correct vs Incorrect
# ==========================================
# A parcel is "correct" when it has a prediction whose class name equals its
# ground-truth classname.
has_pred = gdf["pred_class_name"].notna()
gdf["correct"] = has_pred & (gdf["classname"].astype(str) == gdf["pred_class_name"].astype(str))
palette = {True: "#1a9850", False: "#d73027"}
# Parcels without a prediction are excluded from the counts below, so they
# get a neutral colour instead of the "incorrect" red.
NO_PRED_COLOR = "#bdbdbd"
gdf["color_ci"] = gdf["correct"].map(palette).where(has_pred, NO_PRED_COLOR)
gdf_3857 = gdf.to_crs(epsg=3857)

# Shares are relative to parcels that have a prediction.
n_total     = int(has_pred.sum())
n_correct   = int(gdf_3857["correct"].sum())
n_incorrect = n_total - n_correct
n_no_pred   = int((~has_pred).sum())
p_correct   = (100.0 * n_correct / max(n_total, 1))
p_incorrect = (100.0 * n_incorrect / max(n_total, 1))

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
gdf_3857.plot(color=gdf_3857["color_ci"], linewidth=0.05, edgecolor="none", ax=ax, zorder=2)

# Fix the visible extent (data bounds plus a 2% margin) before adding tiles.
minx, miny, maxx, maxy = gdf_3857.total_bounds
pad_x = 0.02 * (maxx - minx)
pad_y = 0.02 * (maxy - miny)
ax.set_xlim(minx - pad_x, maxx + pad_x)
ax.set_ylim(miny - pad_y, maxy + pad_y)
_ = add_bg(ax, gdf_3857.crs, zoom="auto", attribution=True, alpha=1.0)
# NOTE(review): this second pass redraws the parcels after the basemap; it
# looks redundant with the zorder=2 pass above - confirm before removing.
gdf_3857.plot(color=gdf_3857["color_ci"], linewidth=0.05, edgecolor="none", ax=ax, zorder=3)

ax.set_title(f"{TITLE_REGION} — Correct vs Incorrect — Python", pad=12)
ax.set_xlabel(""); ax.set_ylabel(""); ax.set_aspect("equal")
add_north_arrow(ax, xy=(0.95, 0.92), length=0.10)
add_qgis_scalebar_auto(ax, gdf_3857, total_km=FIXED_SCALE_KM, segments=4, loc='lower left')

handles = [
    mpatches.Patch(color=palette[True],  label=f"Correct — {fmt_pct(p_correct)} ({n_correct:,})"),
    mpatches.Patch(color=palette[False], label=f"Incorrect — {fmt_pct(p_incorrect)} ({n_incorrect:,})"),
]
if n_no_pred:
    # Only shown when some parcels really lack a prediction.
    handles.append(mpatches.Patch(color=NO_PRED_COLOR, label=f"No prediction ({n_no_pred:,})"))
ax.legend(handles=handles, loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=True, title="Prediction")

# Work on the figure object explicitly and close it afterwards so it does
# not stay open in memory.
fig.tight_layout(); fig.subplots_adjust(right=0.82)
out_path = os.path.join(SAVE_DIR, "correct_incorrect_python.png")
fig.savefig(out_path, dpi=300, bbox_inches="tight", facecolor="white")
plt.show()
plt.close(fig)
print("Saved:", out_path)
