In [230]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import re
import logging


def compute_average_deviation(deviation_frequency, deviation_sum, velocity_threshold=0.0):
    """
    Compute the average magnitude of deviations per velocity bin for each dataset,
    considering only velocities above a specified threshold.

    Args:
        deviation_frequency (dict): {dataset_name: {bin_label: count}}
        deviation_sum (dict): {dataset_name: {bin_label: sum_of_deviations}}
        velocity_threshold (float): The minimum ground truth velocity to consider (in m/s).

    Returns:
        pd.DataFrame: DataFrame with 'Velocity Bin' as index and datasets as columns containing average deviations.
    """
    # Extract all unique bin labels across all datasets
    bin_labels = set()
    for dataset in deviation_frequency:
        bin_labels.update(deviation_frequency[dataset].keys())
    # Convert all bin_labels to strings if they aren't already
    bin_labels = sorted(bin_labels, key=lambda x: float(x.split('-')[0]) if isinstance(x, str) else float(x))

    # Initialize DataFrame
    avg_dev_df = pd.DataFrame(index=bin_labels)

    for dataset in deviation_frequency:
        avg_devs = []
        for bin_label in bin_labels:
            # Extract the lower bound of the bin
            if isinstance(bin_label, str):
                try:
                    lower_str = bin_label.split('-')[0]
                    lower = float(lower_str)
                except (IndexError, ValueError) as e:
                    print(f"Invalid bin label format: '{bin_label}'. Error: {e}. Assigning average deviation as 0.0.")
                    avg_devs.append(0.0)
                    continue
            elif isinstance(bin_label, (float, int)):
                lower = float(bin_label)
            else:
                print(f"Unsupported bin label type: {type(bin_label)}. Assigning average deviation as 0.0.")
                avg_devs.append(0.0)
                continue

            # Apply velocity threshold
            if lower < velocity_threshold:
                avg_devs.append(None)  # Use None to indicate exclusion
                continue

            freq = deviation_frequency[dataset].get(bin_label, 0)
            sum_dev = deviation_sum[dataset].get(bin_label, 0.0)
            avg = sum_dev / freq if freq > 0 else 0.0
            avg_devs.append(avg)
        avg_dev_df[dataset] = avg_devs

    # Reset index to have 'Velocity Bin' as a column
    avg_dev_df = avg_dev_df.reset_index().rename(columns={'index': 'Velocity Bin'})

    # Drop bins below the threshold (i.e., where any dataset has None)
    if velocity_threshold > 0.0:
        # Only keep rows where all datasets have non-None values
        avg_dev_df = avg_dev_df.dropna(subset=avg_dev_df.columns[1:])
        # Alternatively, if you want to keep bins where at least one dataset meets the threshold:
        # avg_dev_df = avg_dev_df.dropna(subset=['Velocity Bin'])
        # However, since we assign None to bins below the threshold for all datasets,
        # dropping rows with any NaN should suffice.
    
    return avg_dev_df





def compute_error_metrics(predicted, ground_truth):
    """
    Compute common error metrics between predicted and ground truth velocity fields.

    Args:
        predicted (numpy array): Predicted velocity values.
        ground_truth (numpy array): Ground truth velocity values.

    Returns:
        dict: Dictionary containing MAE, MSE, RMSE, and R-squared.
    """
    predicted_flat = predicted.flatten()
    ground_truth_flat = ground_truth.flatten()

    mae = mean_absolute_error(ground_truth_flat, predicted_flat)
    mse = mean_squared_error(ground_truth_flat, predicted_flat)
    rmse = np.sqrt(mse)
    r2 = r2_score(ground_truth_flat, predicted_flat)

    return {"MAE": mae, "MSE": mse, "RMSE": rmse, "R2": r2}


def comprehensive_analysis(predicted_image_path, ground_truth_image_path, velocity_threshold=0.5, deviation_threshold=0.01):
    """
    Perform comprehensive analysis between predicted and ground truth velocity fields.

    Args:
        predicted_image_path (str): Path to the predicted velocity field image.
        ground_truth_image_path (str): Path to the ground truth velocity field image.
        velocity_threshold (float): Threshold for high-velocity regions (m/s).
        deviation_threshold (float): Threshold to define a significant deviation.

    Returns:
        tuple: (error_metrics, aggregate_data)
            - error_metrics (dict): Computed error metrics.
            - aggregate_data (dict): Data accumulated for aggregate plots.
    """
    # Extract velocity fields
    try:
        predicted_field = np.array(Image.open(predicted_image_path).convert('L')) / 255.0
        ground_truth_field = np.array(Image.open(ground_truth_image_path).convert('L')) / 255.0
    except Exception as e:
        print(f"Error loading images:\nPredicted: {predicted_image_path}\nGround Truth: {ground_truth_image_path}\nException: {e}")
        return None, None

    # Compute centerline profiles
    def extract_centerline(velocity_field):
        center_column = velocity_field.shape[1] // 2
        centerline = velocity_field[:, center_column]
        y = np.arange(velocity_field.shape[0])
        return y, centerline

    y_pred, pred_centerline = extract_centerline(predicted_field)
    y_gt, gt_centerline = extract_centerline(ground_truth_field)

    # Compute deviations
    deviations = pred_centerline - gt_centerline

    # Compute error metrics
    error_metrics = compute_error_metrics(pred_centerline, gt_centerline)

    # Define where deviations occur based on the threshold
    deviation_mask = np.abs(deviations) > deviation_threshold

    # Accumulate data for aggregate plots
    aggregate_data = {
        "overall_deviations": deviations,
        "high_velocity_deviations": deviations[gt_centerline > velocity_threshold],
        "deviation_histogram": deviations,
        "deviation_mask": deviation_mask,
        "ground_truth_velocities": gt_centerline
    }

    return error_metrics, aggregate_data

def process_image(dataset_name, image_name, model_folder, ground_truth_folder, velocity_threshold, deviation_threshold=0.01):
    """
    Process a single image: perform analysis and accumulate data.

    Args:
        dataset_name (str): Name of the dataset.
        image_name (str): Filename of the image.
        model_folder (str): Path to the model's prediction folder.
        ground_truth_folder (str): Path to the ground truth images folder.
        velocity_threshold (float): Threshold for high-velocity regions (m/s).
        deviation_threshold (float): Threshold to define a significant deviation.

    Returns:
        tuple: (metrics_record, aggregate_data) or None if skipped.
            - metrics_record (dict): Error metrics for the image.
            - aggregate_data (dict): Data for aggregate plots.
    """
    predicted_image_path = os.path.join(model_folder, image_name)
    ground_truth_image_path = os.path.join(ground_truth_folder, image_name)

    if not os.path.exists(ground_truth_image_path):
        print(f"Ground truth for '{image_name}' not found in dataset '{dataset_name}'. Skipping.")
        return None

    # Perform comprehensive analysis
    error_metrics, aggregate_data = comprehensive_analysis(
        predicted_image_path,
        ground_truth_image_path,
        velocity_threshold=velocity_threshold,
        deviation_threshold=deviation_threshold
    )

    if error_metrics is None or aggregate_data is None:
        print(f"Analysis failed for '{image_name}' in dataset '{dataset_name}'. Skipping.")
        return None

    # Prepare metrics record
    metrics_record = {
        "Dataset": dataset_name,
        "Image": image_name,
        "MAE": error_metrics["MAE"],
        "MSE": error_metrics["MSE"],
        "RMSE": error_metrics["RMSE"],
        "R2": error_metrics["R2"]
    }

    return (metrics_record, aggregate_data)

def batch_comprehensive_analysis_parallel(dataset_paths, ground_truth_folder, velocity_threshold=0.5, deviation_threshold=0.01, n_jobs=-1):
    """
    Perform comprehensive analysis for multiple datasets against the same ground truth using parallel processing.

    Args:
        dataset_paths (dict): Dictionary with dataset names as keys and model prediction folder paths as values.
        ground_truth_folder (str): Path to the folder containing ground truth images.
        velocity_threshold (float): Threshold for high-velocity regions (m/s).
        deviation_threshold (float): Threshold to define a significant deviation.
        n_jobs (int): Number of jobs for parallel processing. -1 means using all processors.

    Returns:
        dict: Contains average_metrics_df and accumulated data for aggregate plots.
    """
    # Initialize a list to store metrics
    metrics_list = []

    # Initialize dictionaries to accumulate data for aggregate plots
    aggregate_deviations = {}
    aggregate_high_velocity_deviations = {}
    aggregate_deviation_histograms = {}
    
    # Initialize dictionaries for deviation frequency and sum per velocity bin
    deviation_frequency = {}
    deviation_sum = {}
    
    # Initialize a dictionary for ground truth velocity density
    ground_truth_density = {}
    for dataset in dataset_paths:
        ground_truth_density[dataset] = []

    # Prepare tasks
    tasks = []
    for dataset_name, model_folder in dataset_paths.items():
        print(f"\nPreparing to analyze Dataset: '{dataset_name}'")
        if not os.path.isdir(model_folder):
            print(f"Model prediction folder '{model_folder}' for dataset '{dataset_name}' does not exist. Skipping.")
            continue
        for image_name in os.listdir(model_folder):
            if image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
                tasks.append((dataset_name, image_name, model_folder, ground_truth_folder, velocity_threshold, deviation_threshold))

    # Execute tasks in parallel
    print(f"\nStarting parallel processing with {n_jobs if n_jobs != -1 else 'all available'} jobs...")
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_image)(*task) for task in tasks
    )

    # Process results
    for result in results:
        if result is None:
            continue
        metrics_record, aggregate_data = result
        metrics_list.append(metrics_record)

        dataset_name = metrics_record["Dataset"]

        # Accumulate deviations
        if dataset_name not in aggregate_deviations:
            aggregate_deviations[dataset_name] = aggregate_data["overall_deviations"]
        else:
            aggregate_deviations[dataset_name] = np.concatenate((aggregate_deviations[dataset_name],
                                                                 aggregate_data["overall_deviations"]))

        # Accumulate high-velocity deviations
        if dataset_name not in aggregate_high_velocity_deviations:
            aggregate_high_velocity_deviations[dataset_name] = aggregate_data["high_velocity_deviations"]
        else:
            aggregate_high_velocity_deviations[dataset_name] = np.concatenate((aggregate_high_velocity_deviations[dataset_name],
                                                                                 aggregate_data["high_velocity_deviations"]))

        # Accumulate deviation histograms
        if dataset_name not in aggregate_deviation_histograms:
            aggregate_deviation_histograms[dataset_name] = aggregate_data["deviation_histogram"]
        else:
            aggregate_deviation_histograms[dataset_name] = np.concatenate((aggregate_deviation_histograms[dataset_name],
                                                                           aggregate_data["deviation_histogram"]))

        # Accumulate deviation frequencies and sums
        if dataset_name not in deviation_frequency:
            deviation_frequency[dataset_name] = {}
        if dataset_name not in deviation_sum:
            deviation_sum[dataset_name] = {}
        
        gt_velocities = aggregate_data["ground_truth_velocities"]
        deviation_mask = aggregate_data["deviation_mask"]
        deviations = aggregate_data["overall_deviations"]

        for gt_v, dev, dev_abs in zip(gt_velocities, deviation_mask, np.abs(deviations)):
            # Collect ground truth velocities for density plot
            ground_truth_density[dataset_name].append(gt_v)

            if dev:
                # Round ground truth velocity to 2 decimal places to match binning
                gt_v_rounded = round(gt_v, 2)
                # Update frequency
                if gt_v_rounded in deviation_frequency[dataset_name]:
                    deviation_frequency[dataset_name][gt_v_rounded] += 1
                else:
                    deviation_frequency[dataset_name][gt_v_rounded] = 1
                # Update sum of deviations
                if gt_v_rounded in deviation_sum[dataset_name]:
                    deviation_sum[dataset_name][gt_v_rounded] += dev_abs
                else:
                    deviation_sum[dataset_name][gt_v_rounded] = dev_abs

    # Create DataFrame from the list
    metrics_df = pd.DataFrame(metrics_list)

    # Debugging: Inspect the first few entries of metrics_df
    print("\n--- Metrics DataFrame Head ---")
    print(metrics_df.head())

    print("\n--- Metrics DataFrame Info ---")
    print(metrics_df.info())

    # Data Validation: Ensure all required columns are present and have correct data types
    required_columns = ["Dataset", "Image", "MAE", "MSE", "RMSE", "R2"]
    missing_columns = set(required_columns) - set(metrics_df.columns)
    if missing_columns:
        raise ValueError(f"The following required columns are missing from the metrics DataFrame: {missing_columns}")

    # Check for non-numeric data in error metrics columns
    for col in ["MAE", "MSE", "RMSE", "R2"]:
        if not pd.api.types.is_numeric_dtype(metrics_df[col]):
            raise TypeError(f"Column '{col}' must contain numeric data. Found data type: {metrics_df[col].dtype}")

    # Inspect the 'Dataset' column for any anomalies
    print("\n--- Unique Dataset Names ---")
    print(metrics_df['Dataset'].unique())

    print("\n--- Entries in 'Dataset' Containing Image Filenames ---")
    image_extensions = ['.png', '.jpg', '.jpeg', '.tif', '.tiff']
    pattern = '|'.join([ext.replace('.', r'\.') for ext in image_extensions])
    problematic_entries = metrics_df[metrics_df['Dataset'].str.contains(pattern, regex=True)]
    print(problematic_entries)

    print("\n--- Length of 'Dataset' Entries ---")
    print(metrics_df['Dataset'].str.len().describe())

    # Check for any unusually long 'Dataset' entries
    unusually_long = metrics_df[metrics_df['Dataset'].str.len() > 50]
    if not unusually_long.empty:
        print("\n--- Unusually Long 'Dataset' Entries ---")
        print(unusually_long)
    else:
        print("\nNo unusually long 'Dataset' entries found.")

    # Remove entries where 'Dataset' contains image filenames
    if not problematic_entries.empty:
        print("\nRemoving problematic entries from 'metrics_df'...")
        metrics_df = metrics_df[~metrics_df['Dataset'].str.contains(pattern, regex=True)]
        print(f"Remaining entries after removal: {metrics_df.shape[0]}")
    else:
        print("\nNo problematic 'Dataset' entries found. Proceeding...")

    # Compute average error metrics per dataset by selecting only numeric columns
    try:
        average_metrics = metrics_df.groupby("Dataset")[["MAE", "MSE", "RMSE", "R2"]].mean().reset_index()
    except Exception as e:
        print(f"\nError during groupby operation: {e}")
        raise

    return {
        "average_metrics_df": average_metrics,
        "aggregate_deviations": aggregate_deviations,
        "aggregate_high_velocity_deviations": aggregate_high_velocity_deviations,
        "aggregate_deviation_histograms": aggregate_deviation_histograms,
        "deviation_frequency": deviation_frequency,
        "deviation_sum": deviation_sum,
        "ground_truth_density": ground_truth_density
    }

def plot_fixed_velocity_bin_metrics(deviation_frequency, deviation_sum, output_path, bin_label="0.00-0.10 m/s"):
    """
    Plot frequency, sum, and density of deviations within a specific velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        deviation_sum (dict): Dictionary with dataset names as keys and another dict as values
                              mapping ground truth velocity to sum of deviation magnitudes.
        output_path (str): Path to save the fixed velocity bin metrics plot.
        bin_label (str): The specific velocity bin to focus on (e.g., "0.00-0.10 m/s").

    Returns:
        None
    """
    # Define the specific velocity bin
    target_bin = bin_label  # e.g., "0.00-0.10 m/s"

    # Prepare data for plotting
    plot_data = []
    for dataset in deviation_frequency:
        # Extract the lower and upper bounds from bin_label
        lower, upper = map(float, target_bin.split('-'))
        # Since bins are defined with two decimal places, ensure consistency
        lower = round(lower, 2)
        upper = round(upper, 2)

        # Count frequency
        freq = 0
        for v, count in deviation_frequency[dataset].items():
            if lower <= v < upper:
                freq += count

        # Sum of deviations
        sum_dev = 0.0
        for v, s_dev in deviation_sum[dataset].items():
            if lower <= v < upper:
                sum_dev += s_dev

        # Total deviations for density
        total_devs = sum(deviation_frequency[dataset].values())
        density = (freq / total_devs) * 100 if total_devs > 0 else 0

        plot_data.append({
            "Dataset": dataset,
            "Frequency": freq,
            "Sum of Deviations": sum_dev,
            "Density (%)": density
        })

    plot_df = pd.DataFrame(plot_data)

    # Set the order for metrics
    metrics_order = ["Frequency", "Sum of Deviations", "Density (%)"]

    # Melt the DataFrame for seaborn plotting
    plot_df_melted = plot_df.melt(id_vars="Dataset", value_vars=metrics_order, var_name="Metric", value_name="Value")

    # Plotting
    plt.figure(figsize=(12, 8))
    sns.barplot(data=plot_df_melted, x="Dataset", y="Value", hue="Metric")
    plt.xlabel("Dataset", fontsize=14)
    plt.ylabel("Value", fontsize=14)
    plt.title(f"Deviation Metrics within {target_bin} Velocity Bin", fontsize=16)
    plt.legend(title="Metric", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def plot_average_deviation_per_bin(avg_dev_df, output_path, velocity_threshold=0.0):
    """
    Plot the average magnitude of deviations per velocity bin for each dataset,
    considering only velocities above a specified threshold.

    Args:
        avg_dev_df (pd.DataFrame): DataFrame with 'Velocity Bin' and datasets as columns containing average deviations.
        output_path (str): Path to save the average deviation plot.
        velocity_threshold (float): The minimum ground truth velocity to consider (in m/s).

    Returns:
        None
    """
    plt.figure(figsize=(15, 8))
    
    # Melt the DataFrame for seaborn
    plot_df_melted = avg_dev_df.melt(id_vars="Velocity Bin", var_name="Dataset", value_name="Average Deviation (m/s)")
    
    # Convert 'Velocity Bin' to numeric by extracting the lower bound
    plot_df_melted['Velocity Lower Bound'] = plot_df_melted['Velocity Bin'].apply(
        lambda x: float(x.split('-')[0]) if isinstance(x, str) else float(x)
    )
    
    # Sort by 'Velocity Lower Bound' for accurate plotting
    plot_df_melted = plot_df_melted.sort_values('Velocity Lower Bound')
    
    # Initialize the seaborn line plot
    sns.lineplot(
        data=plot_df_melted,
        x='Velocity Lower Bound',
        y='Average Deviation (m/s)',
        hue='Dataset',
        marker='o'
    )
    
    plt.xlabel("Ground Truth Velocity Lower Bound (m/s)", fontsize=14)
    plt.ylabel("Average Deviation Magnitude (m/s)", fontsize=14)
    plt.title(f"Average Magnitude of Deviations per Dataset (Above {velocity_threshold} m/s)", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Set x-axis limits starting from the velocity threshold
    plt.xlim(velocity_threshold, 0.56)  # Adjust 0.56 if your max velocity changes
    
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
    
    print(f"Average deviation per bin plot saved to '{output_path}'.")



def plot_deviation_density_per_velocity_bin(deviation_frequency, output_path, bin_size=0.05):
    """
    Plot the density of deviations per ground truth velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        output_path (str): Path to save the deviation density plot.
        bin_size (float): Size of each velocity bin.

    Returns:
        None
    """
    # Define fixed velocity bins from 0 to 0.56 m/s with bin size 0.05 m/s
    min_v = 0.0
    max_v = 0.56
    bins = np.arange(min_v, max_v + bin_size, bin_size)
    bin_labels = [f"{bins[i]:.2f}-{bins[i+1]:.2f}" for i in range(len(bins)-1)]

    # Prepare DataFrame for plotting
    plot_data = []
    for dataset, freq_dict in deviation_frequency.items():
        # Initialize counts for each bin
        bin_counts = np.zeros(len(bins)-1, dtype=int)
        for v, count in freq_dict.items():
            # Find the appropriate bin
            bin_index = np.digitize(v, bins) - 1  # digitize returns indices starting at 1
            if 0 <= bin_index < len(bin_counts):
                bin_counts[bin_index] += count
        # Calculate total deviations for density normalization
        total_deviations = bin_counts.sum()
        if total_deviations > 0:
            bin_density = (bin_counts / total_deviations) * 100  # Percentage
        else:
            bin_density = bin_counts  # All zeros
        for i in range(len(bin_density)):
            plot_data.append({
                "Dataset": dataset,
                "Velocity Bin": bin_labels[i],
                "Density (%)": bin_density[i]
            })

    plot_df = pd.DataFrame(plot_data)

    # Pivot the DataFrame for easier plotting
    pivot_df = plot_df.pivot(index="Velocity Bin", columns="Dataset", values="Density (%)").fillna(0)

    # Plotting
    pivot_df.plot(kind='bar', figsize=(15, 8))
    plt.xlabel("Ground Truth Velocity Bin (m/s)", fontsize=14)
    plt.ylabel("Density of Deviations (%)", fontsize=14)
    plt.title("Density of Velocity Deviations per Ground Truth Velocity Bin with Treshold=0.01", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def plot_sum_deviation_per_velocity_bin(deviation_sum, output_path, bin_size=0.05):
    """
    Plot the sum of deviation magnitudes per ground truth velocity bin for each dataset.

    Args:
        deviation_sum (dict): Dictionary with dataset names as keys and another dict as values
                              mapping ground truth velocity to sum of deviation magnitudes.
        output_path (str): Path to save the sum deviation plot.
        bin_size (float): Size of each velocity bin.

    Returns:
        None
    """
    # Define fixed velocity bins from 0 to 0.56 m/s with bin size 0.05 m/s
    min_v = 0.0
    max_v = 0.56
    bins = np.arange(min_v, max_v + bin_size, bin_size)
    bin_labels = [f"{bins[i]:.2f}-{bins[i+1]:.2f}" for i in range(len(bins)-1)]

    # Prepare DataFrame for plotting
    plot_data = []
    for dataset, sum_dict in deviation_sum.items():
        # Initialize sums for each bin
        bin_sums = np.zeros(len(bins)-1, dtype=float)
        for v, sum_dev in sum_dict.items():
            # Find the appropriate bin
            bin_index = np.digitize(v, bins) - 1  # digitize returns indices starting at 1
            if 0 <= bin_index < len(bin_sums):
                bin_sums[bin_index] += sum_dev
        for i in range(len(bin_sums)):
            plot_data.append({
                "Dataset": dataset,
                "Velocity Bin": bin_labels[i],
                "Sum of Deviations": bin_sums[i]
            })

    plot_df = pd.DataFrame(plot_data)

    # Pivot the DataFrame for easier plotting
    pivot_df = plot_df.pivot(index="Velocity Bin", columns="Dataset", values="Sum of Deviations").fillna(0)

    # Plotting
    pivot_df.plot(kind='bar', figsize=(15, 8))
    plt.xlabel("Ground Truth Velocity Bin (m/s)", fontsize=14)
    plt.ylabel("Sum of Velocity Deviations (m/s)", fontsize=14)
    plt.title("Sum of Velocity Deviations per Ground Truth Velocity Bin with Treshold=0.01", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def plot_deviation_frequency(deviation_frequency, output_path, bin_size=0.05):
    """
    Plot the frequency of deviations per ground truth velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        output_path (str): Path to save the deviation frequency plot.
        bin_size (float): Size of each velocity bin.

    Returns:
        None
    """
    # Define fixed velocity bins from 0 to 0.56 m/s with bin size 0.05 m/s
    min_v = 0.0
    max_v = 0.56
    bins = np.arange(min_v, max_v + bin_size, bin_size)
    bin_labels = [f"{bins[i]:.2f}-{bins[i+1]:.2f}" for i in range(len(bins)-1)]

    # Prepare DataFrame for plotting
    plot_data = []
    for dataset, freq_dict in deviation_frequency.items():
        # Initialize counts for each bin
        bin_counts = np.zeros(len(bins)-1, dtype=int)
        for v, count in freq_dict.items():
            # Find the appropriate bin
            bin_index = np.digitize(v, bins) - 1  # digitize returns indices starting at 1
            if 0 <= bin_index < len(bin_counts):
                bin_counts[bin_index] += count
        for i in range(len(bin_counts)):
            plot_data.append({
                "Dataset": dataset,
                "Velocity Bin": bin_labels[i],
                "Frequency": bin_counts[i]
            })

    plot_df = pd.DataFrame(plot_data)

    # Pivot the DataFrame for easier plotting
    pivot_df = plot_df.pivot(index="Velocity Bin", columns="Dataset", values="Frequency").fillna(0)

    # Plotting
    pivot_df.plot(kind='bar', figsize=(15, 8))
    plt.xlabel("Ground Truth Velocity Bin (m/s)", fontsize=14)
    plt.ylabel("Frequency of Deviations", fontsize=14)
    plt.title("Frequency of Velocity Deviations per Ground Truth Velocity Bin with Treshold=0.01", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_fixed_velocity_bin_metrics(deviation_frequency, deviation_sum, output_path, bin_label="0.00-0.10 m/s"):
    """
    Plot frequency, sum, and density of deviations within a specific velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        deviation_sum (dict): Dictionary with dataset names as keys and another dict as values
                              mapping ground truth velocity to sum of deviation magnitudes.
        output_path (str): Path to save the fixed velocity bin metrics plot.
        bin_label (str): The specific velocity bin to focus on (e.g., "0.00-0.10 m/s").

    Returns:
        None
    """
    # Define the specific velocity bin
    target_bin = bin_label  # e.g., "0.00-0.10 m/s"

    # Prepare data for plotting
    plot_data = []
    for dataset in deviation_frequency:
        # Extract the lower and upper bounds from bin_label
        try:
            lower_str, upper_str = target_bin.split('-')
            # Remove any non-numeric characters from upper_str
            upper_str = ''.join(filter(lambda x: x.isdigit() or x == '.', upper_str))
            lower = float(lower_str)
            upper = float(upper_str)
        except Exception as e:
            print(f"Error parsing bin_label '{bin_label}': {e}")
            continue

        # Count frequency
        freq = 0
        for v, count in deviation_frequency[dataset].items():
            if lower <= v < upper:
                freq += count

        # Sum of deviations
        sum_dev = 0.0
        for v, s_dev in deviation_sum[dataset].items():
            if lower <= v < upper:
                sum_dev += s_dev

        # Total deviations for density
        total_devs = sum(deviation_frequency[dataset].values())
        density = (freq / total_devs) * 100 if total_devs > 0 else 0

        plot_data.append({
            "Dataset": dataset,
            "Frequency": freq,
            "Sum of Deviations": sum_dev,
            "Density (%)": density
        })

    plot_df = pd.DataFrame(plot_data)

    # Check if plot_df is empty
    if plot_df.empty:
        print(f"No data available for the bin '{bin_label}'. Skipping plot.")
        return

    # Set the order for metrics
    metrics_order = ["Frequency", "Sum of Deviations", "Density (%)"]

    # Melt the DataFrame for seaborn plotting
    plot_df_melted = plot_df.melt(id_vars="Dataset", value_vars=metrics_order, var_name="Metric", value_name="Value")

    # Plotting
    plt.figure(figsize=(12, 8))
    sns.barplot(data=plot_df_melted, x="Dataset", y="Value", hue="Metric")
    plt.xlabel("Dataset", fontsize=14)
    plt.ylabel("Value", fontsize=14)
    plt.title(f"Deviation Metrics within {target_bin} Velocity Bin", fontsize=16)
    plt.legend(title="Metric", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

import re  # Add this import at the top of your script

def plot_fixed_velocity_bin_metrics(deviation_frequency, deviation_sum, output_path, bin_label="0.00-0.10 m/s"):
    """
    Plot frequency, sum, and density of deviations within a specific velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        deviation_sum (dict): Dictionary with dataset names as keys and another dict as values
                              mapping ground truth velocity to sum of deviation magnitudes.
        output_path (str): Path to save the fixed velocity bin metrics plot.
        bin_label (str): The specific velocity bin to focus on (e.g., "0.00-0.10 m/s").

    Returns:
        None
    """
    # Define the specific velocity bin
    target_bin = bin_label  # e.g., "0.00-0.10 m/s"

    # Use regular expressions to extract numeric boundaries
    match = re.match(r"([0-9.]+)-([0-9.]+)", target_bin)
    if not match:
        print(f"Invalid bin_label format: '{bin_label}'. Expected format 'lower-upper m/s'.")
        return

    lower_str, upper_str = match.groups()
    try:
        lower = float(lower_str)
        upper = float(upper_str)
    except ValueError as e:
        print(f"Error converting bin boundaries to float: {e}")
        return

    # Prepare data for plotting
    plot_data = []
    for dataset in deviation_frequency:
        # Count frequency
        freq = 0
        for v, count in deviation_frequency[dataset].items():
            if lower <= v < upper:
                freq += count

        # Sum of deviations
        sum_dev = 0.0
        for v, s_dev in deviation_sum[dataset].items():
            if lower <= v < upper:
                sum_dev += s_dev

        # Total deviations for density
        total_devs = sum(deviation_frequency[dataset].values())
        density = (freq / total_devs) * 100 if total_devs > 0 else 0

        plot_data.append({
            "Dataset": dataset,
            "Frequency": freq,
            "Sum of Deviations": sum_dev,
            "Density (%)": density
        })

    plot_df = pd.DataFrame(plot_data)

    # Check if plot_df is empty
    if plot_df.empty:
        print(f"No data available for the bin '{bin_label}'. Skipping plot.")
        return

    # Set the order for metrics
    metrics_order = ["Frequency", "Sum of Deviations", "Density (%)"]

    # Melt the DataFrame for seaborn plotting
    plot_df_melted = plot_df.melt(id_vars="Dataset", value_vars=metrics_order, var_name="Metric", value_name="Value")

    # Plotting
    plt.figure(figsize=(12, 8))
    sns.barplot(data=plot_df_melted, x="Dataset", y="Value", hue="Metric")
    plt.xlabel("Dataset", fontsize=14)
    plt.ylabel("Value", fontsize=14)
    plt.title(f"Deviation Metrics within {target_bin} Velocity Bin", fontsize=16)
    plt.legend(title="Metric", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_ground_truth_velocity_density_continuous_first_dataset(ground_truth_density, output_path):
    """
    Plot the continuous density distribution of ground truth velocities for the first dataset.
    
    Args:
        ground_truth_density (dict): Dictionary with dataset names as keys and lists of ground truth velocities as values.
        output_path (str): Path to save the continuous ground truth velocity density plot.
    
    Returns:
        None
    """
    plt.figure(figsize=(15, 8))
    
    # Check if ground_truth_density is not empty
    if not ground_truth_density:
        print("The ground_truth_density dictionary is empty. No data to plot.")
        return
    
    # Get the first dataset
    first_dataset = next(iter(ground_truth_density))
    velocities = ground_truth_density[first_dataset]
    
    if len(velocities) == 0:
        print(f"No ground truth velocities available for dataset '{first_dataset}'. Skipping plot.")
        return
    
    # Define color palette
    color = sns.color_palette("tab10", n_colors=1)[0]
    
    # Plot KDE for the first dataset
    sns.kdeplot(
        velocities, 
        label=first_dataset, 
        shade=True, 
        color=color,
        alpha=0.6,
        bw_adjust=1  # Adjust bandwidth for smoothness
    )
    
    plt.xlabel("Ground Truth Velocity (m/s)", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title(f"Continuous Density Distribution of Ground Truth Velocities", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xlim(0, 0.56)  # Set x-axis limits from 0 to 0.56 m/s
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
    
    print(f"Continuous ground truth velocity density plot for '{first_dataset}' saved to '{output_path}'.")


def plot_deviation_density_per_velocity_bin(deviation_frequency, output_path, bin_size=0.05):
    """
    Plot the density of deviations per ground truth velocity bin for each dataset.

    Args:
        deviation_frequency (dict): Dictionary with dataset names as keys and another dict as values
                                    mapping ground truth velocity to deviation count.
        output_path (str): Path to save the deviation density plot.
        bin_size (float): Size of each velocity bin.

    Returns:
        None
    """
    # Define fixed velocity bins from 0 to 0.56 m/s with bin size 0.05 m/s
    min_v = 0.0
    max_v = 0.56
    bins = np.arange(min_v, max_v + bin_size, bin_size)
    bin_labels = [f"{bins[i]:.2f}-{bins[i+1]:.2f}" for i in range(len(bins)-1)]

    # Prepare DataFrame for plotting
    plot_data = []
    for dataset, freq_dict in deviation_frequency.items():
        # Initialize counts for each bin
        bin_counts = np.zeros(len(bins)-1, dtype=int)
        for v, count in freq_dict.items():
            # Find the appropriate bin
            bin_index = np.digitize(v, bins) - 1  # digitize returns indices starting at 1
            if 0 <= bin_index < len(bin_counts):
                bin_counts[bin_index] += count
        # Calculate total deviations for density normalization
        total_deviations = bin_counts.sum()
        if total_deviations > 0:
            bin_density = (bin_counts / total_deviations) * 100  # Percentage
        else:
            bin_density = bin_counts  # All zeros
        for i in range(len(bin_density)):
            plot_data.append({
                "Dataset": dataset,
                "Velocity Bin": bin_labels[i],
                "Density (%)": bin_density[i]
            })

    plot_df = pd.DataFrame(plot_data)

    # Pivot the DataFrame for easier plotting
    pivot_df = plot_df.pivot(index="Velocity Bin", columns="Dataset", values="Density (%)").fillna(0)

    # Plotting
    pivot_df.plot(kind='bar', figsize=(15, 8))
    plt.xlabel("Ground Truth Velocity Bin (m/s)", fontsize=14)
    plt.ylabel("Density of Deviations (%)", fontsize=14)
    plt.title("Density of Velocity Deviations per Ground Truth Velocity Bin", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def generate_aggregate_plots(aggregate_results, aggregate_plots_folder, bin_size=0.05, velocity_threshold=0.35):
    """
    Generate aggregate plots comparing datasets, including average deviation per bin considering a velocity threshold.

    Args:
        aggregate_results (dict): Contains accumulated data for aggregate plots.
        aggregate_plots_folder (str): Directory to save the aggregate plots.
        bin_size (float): Size of each velocity bin for the frequency and sum plots.
        velocity_threshold (float): The minimum ground truth velocity to consider (in m/s).

    Returns:
        None
    """
    os.makedirs(aggregate_plots_folder, exist_ok=True)

    # Unpack aggregate data
    aggregate_deviations = aggregate_results["aggregate_deviations"]
    aggregate_high_velocity_deviations = aggregate_results["aggregate_high_velocity_deviations"]
    aggregate_deviation_histograms = aggregate_results["aggregate_deviation_histograms"]
    deviation_frequency = aggregate_results["deviation_frequency"]
    deviation_sum = aggregate_results["deviation_sum"]
    ground_truth_density = aggregate_results["ground_truth_density"]

    # Generate Aggregate Overall Deviation Plot
    aggregate_overall_deviation_path = os.path.join(aggregate_plots_folder, "aggregate_overall_deviation.png")
    plot_aggregate_overall_deviation(aggregate_deviations, aggregate_overall_deviation_path, bins=50)

    # Generate Aggregate High-Velocity Deviation Plot
    aggregate_high_velocity_deviation_path = os.path.join(aggregate_plots_folder, "aggregate_high_velocity_deviation.png")
    plot_aggregate_high_velocity_deviation(aggregate_high_velocity_deviations, aggregate_high_velocity_deviation_path, bins=50)

    # Generate Aggregate Error by Fixed Velocity Regions (Frequency-based)
    aggregate_error_by_regions_frequency_path = os.path.join(aggregate_plots_folder, "aggregate_error_by_fixed_velocity_regions_frequency.png")
    plot_deviation_frequency(deviation_frequency, aggregate_error_by_regions_frequency_path, bin_size=bin_size)

    # Generate Aggregate Error by Fixed Velocity Regions (Sum-based)
    aggregate_error_by_regions_sum_path = os.path.join(aggregate_plots_folder, "aggregate_error_by_fixed_velocity_regions_sum.png")
    plot_sum_deviation_per_velocity_bin(deviation_sum, aggregate_error_by_regions_sum_path, bin_size=bin_size)

    # Generate Aggregate Error by Fixed Velocity Regions (Density-based)
    aggregate_error_by_regions_density_path = os.path.join(aggregate_plots_folder, "aggregate_error_by_fixed_velocity_regions_density.png")
    plot_deviation_density_per_velocity_bin(deviation_frequency, aggregate_error_by_regions_density_path, bin_size=bin_size)

    # Generate Fixed Velocity Bin Metrics Plot for 0.00-0.10 m/s
    fixed_velocity_bin = "0.00-0.10 m/s"
    aggregate_error_by_regions_0_0_1_path = os.path.join(aggregate_plots_folder, "aggregate_error_by_fixed_velocity_regions_0_0_1.png")
    plot_fixed_velocity_bin_metrics(deviation_frequency, deviation_sum, aggregate_error_by_regions_0_0_1_path, bin_label=fixed_velocity_bin)

    # Generate Ground Truth Velocity Density Plot (Continuous for First Dataset)
    ground_truth_velocity_density_continuous_first_dataset_path = os.path.join(
        aggregate_plots_folder, "ground_truth_velocity_density_continuous_first_dataset.png"
    )
    plot_ground_truth_velocity_density_continuous_first_dataset(
        aggregate_results["ground_truth_density"], 
        ground_truth_velocity_density_continuous_first_dataset_path
    )

    # **New Steps:** Compute Average Deviations with Velocity Threshold and Generate the Plot
    # Compute the average deviation per bin per dataset considering the velocity threshold
    avg_dev_df = compute_average_deviation(deviation_frequency, deviation_sum, velocity_threshold=velocity_threshold)
    
    # Define the output path for the average deviation plot
    average_deviation_plot_path = os.path.join(aggregate_plots_folder, f"average_deviation_per_bin_per_dataset_above_{velocity_threshold}_m_s.png")
    
    # Generate the plot with the velocity threshold
    plot_average_deviation_per_bin(avg_dev_df, average_deviation_plot_path, velocity_threshold=velocity_threshold)

    # Generate Aggregate Deviation Histogram
    aggregate_deviation_histogram_path = os.path.join(aggregate_plots_folder, "aggregate_deviation_histogram.png")
    plot_aggregate_deviation_histogram(aggregate_deviation_histograms, aggregate_deviation_histogram_path, bins=50)

    print("\nAggregate plots generated and saved successfully.")


def plot_aggregate_overall_deviation(aggregate_deviations, output_path, bins=50):
    plt.figure(figsize=(12, 8))
    colors = plt.cm.tab10.colors  # Up to 10 distinct colors

    for idx, (dataset, deviations) in enumerate(aggregate_deviations.items()):
        color = colors[idx % len(colors)]
        plt.hist(deviations, bins=bins, alpha=0.5, color=color, label=dataset, density=True)

    plt.xlabel("Velocity Deviation (m/s)", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title("Aggregate Velocity Deviations Across All Datasets", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def plot_aggregate_high_velocity_deviation(aggregate_high_velocity_deviations, output_path, bins=50):
    plt.figure(figsize=(12, 8))
    colors = plt.cm.tab10.colors  # Up to 10 distinct colors

    for idx, (dataset, deviations) in enumerate(aggregate_high_velocity_deviations.items()):
        color = colors[idx % len(colors)]
        plt.hist(deviations, bins=bins, alpha=0.5, color=color, label=dataset, density=True)

    plt.xlabel("High-Velocity Deviation (m/s)", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title("Aggregate High-Velocity Deviations Across All Datasets", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def plot_aggregate_deviation_histogram(aggregate_deviation_histograms, output_path, bins=50):
    plt.figure(figsize=(12, 8))
    colors = plt.cm.tab10.colors  # Up to 10 distinct colors

    for idx, (dataset, deviations) in enumerate(aggregate_deviation_histograms.items()):
        color = colors[idx % len(colors)]
        plt.hist(deviations, bins=bins, alpha=0.5, color=color, label=dataset, density=True)

    plt.xlabel("Velocity Deviation (m/s)", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.title("Aggregate Deviation Histograms Across All Datasets", fontsize=16)
    plt.legend(title="Dataset", bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

In [231]:
# Define model prediction folders
model_paths = {
    "DAT_x2_Scalar_No_Physics": "/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/DAT_x2_Scalar_No_Physics",
    "DAT_X2_Vector_No_Physics": "/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/DAT_X2_Vector_No_Physics",
    "DAT_X2_Vector_umass_2": "/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/DAT_X2_Vector_umass_2"
}

# Define ground truth images folder
ground_truth_folder = "/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/GroundTruth"

# Define output folder for analysis results
output_folder = "/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/output_folder"

deviation_threshold = 0.01 # Adjust based on your requirements


# Perform batch comprehensive analysis with parallel processing
analysis_results = batch_comprehensive_analysis_parallel(
    dataset_paths=model_paths,
    ground_truth_folder=ground_truth_folder,
    velocity_threshold=0.4, 
    deviation_threshold=deviation_threshold,
 # Adjust as needed
    n_jobs=-1  # Use all available processors
)

# Extract average metrics DataFrame
average_metrics_df = analysis_results["average_metrics_df"]

# Debugging: Inspect the average_metrics_df
print("\n--- Average Metrics DataFrame ---")
print(average_metrics_df)

# Save average error metrics to CSV
metrics_csv_path = os.path.join(output_folder, "average_error_metrics.csv")
average_metrics_df.to_csv(metrics_csv_path, index=False)
print(f"\nAverage error metrics per dataset saved to '{metrics_csv_path}'.")

# Define paths for aggregate plots
aggregate_plots_folder = os.path.join(output_folder, "aggregate_plots")
os.makedirs(aggregate_plots_folder, exist_ok=True)

# Generate Aggregate Plots
generate_aggregate_plots(analysis_results, aggregate_plots_folder, bin_size=0.05)

print("\nAll analyses and aggregate plots completed successfully.")


Preparing to analyze Dataset: 'DAT_x2_Scalar_No_Physics'

Preparing to analyze Dataset: 'DAT_X2_Vector_No_Physics'

Preparing to analyze Dataset: 'DAT_X2_Vector_umass_2'

Starting parallel processing with all available jobs...
Ground truth for 'XZ_127_35.png' not found in dataset 'DAT_x2_Scalar_No_Physics'. Skipping.
Ground truth for 'XZ_30_35.png' not found in dataset 'DAT_x2_Scalar_No_Physics'. Skipping.
Ground truth for 'XZ_127_38.png' not found in dataset 'DAT_x2_Scalar_No_Physics'. Skipping.
Ground truth for 'XZ_30_34.png' not found in dataset 'DAT_x2_Scalar_No_Physics'. Skipping.
Ground truth for 'XZ_127_34.png' not found in dataset 'DAT_x2_Scalar_No_Physics'. Skipping.

--- Metrics DataFrame Head ---
                    Dataset         Image       MAE       MSE      RMSE  \
0  DAT_x2_Scalar_No_Physics  XZ_63_35.png  0.004820  0.000117  0.010814   
1  DAT_x2_Scalar_No_Physics  XZ_43_35.png  0.003922  0.000115  0.010740   
2  DAT_x2_Scalar_No_Physics  XZ_83_39.png  0.002206  0.00


`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(


Continuous ground truth velocity density plot for 'DAT_x2_Scalar_No_Physics' saved to '/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/output_folder/aggregate_plots/ground_truth_velocity_density_continuous_first_dataset.png'.
Average deviation per bin plot saved to '/home/vittorio/Scrivania/ResShift_4_scale/data/AnalysisVelocity/output_folder/aggregate_plots/average_deviation_per_bin_per_dataset_above_0.35_m_s.png'.

Aggregate plots generated and saved successfully.

All analyses and aggregate plots completed successfully.
