# Scikit-learn Multiclass Evaluation

Comprehensive evaluation of multiclass predictions against ground truth annotations.

In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    matthews_corrcoef,
    precision_recall_fscore_support,
)

warnings.filterwarnings("ignore")

# Global plot configuration
plt.rcParams["figure.facecolor"] = "white"
plt.rcParams["axes.facecolor"] = "white"
plt.rcParams["savefig.facecolor"] = "white"
plt.rcParams["savefig.bbox"] = "tight"
plt.rcParams["font.size"] = 10
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["axes.labelsize"] = 11
plt.rcParams["xtick.labelsize"] = 9
plt.rcParams["ytick.labelsize"] = 9
plt.rcParams["legend.fontsize"] = 10

# Set seaborn style
sns.set_style("whitegrid")
sns.set_palette("husl")
%matplotlib inline

# CONFIGURE BASE DIRECTORY - Change this to your project directory
BASE_DIR = Path.cwd()  # Current directory
# BASE_DIR = Path('/Users/tod/Desktop/LMM_POC')  # Or set specific path

print(f"📁 Base directory: {BASE_DIR}")

In [None]:
# Build the evaluation DataFrame `df` with a "pred" column (model output) and
# an "annotator" column (ground truth). Every later cell reads these two
# column names, so any data source chosen below must provide them.

# Option 1: Load from CSV file
# Uncomment and modify the path below to load your data
# INPUT_FILE = BASE_DIR / 'your_predictions.csv'
# df = pd.read_csv(INPUT_FILE)
# print(f"✅ Loaded data from: {INPUT_FILE}")

# Option 2: Load from specific path
# df = pd.read_csv(Path('/path/to/your/data.csv'))

# Option 3: Use existing DataFrame from previous cell
# df = your_existing_dataframe

# Option 4: Create sample data for testing
np.random.seed(42)  # fixed seed so the synthetic data is reproducible
n_samples = 1000
classes = ["class_A", "class_B", "class_C", "class_D", "class_E"]

# Generate ground truth with an imbalanced class distribution
annotator = np.random.choice(classes, n_samples, p=[0.3, 0.25, 0.2, 0.15, 0.1])

# Generate predictions with some errors: about 20% of the rows are re-drawn
# uniformly from all classes. A re-draw can land on the true class again
# (1 in 5 chance), so the effective error rate is roughly 16%, not 20%.
pred = annotator.copy()
error_mask = np.random.random(n_samples) < 0.2
pred[error_mask] = np.random.choice(classes, error_mask.sum())

# Create DataFrame using the column names the rest of the notebook expects
df = pd.DataFrame({"pred": pred, "annotator": annotator})

# Persist the sample next to the other outputs of this notebook
INPUT_FILE = BASE_DIR / "sample_data.csv"
df.to_csv(INPUT_FILE, index=False)
print(f"💾 Sample data saved to: {INPUT_FILE}")

print(f"\nDataFrame shape: {df.shape}")
print(f"Unique classes in ground truth: {df['annotator'].nunique()}")
print(f"Unique classes in predictions: {df['pred'].nunique()}")
df.head(10)

In [None]:
# Provide the evaluation DataFrame `df` with the columns "pred" (model output)
# and "annotator" (ground truth).

# Option 1: Load from CSV file
# df = pd.read_csv(BASE_DIR / 'your_predictions.csv')

# Option 2: Use existing DataFrame
# df = your_existing_dataframe

# Option 3: Create sample data for testing
np.random.seed(42)  # fixed seed keeps the synthetic sample reproducible
n_samples = 1000
classes = ["class_A", "class_B", "class_C", "class_D", "class_E"]
class_probabilities = [0.3, 0.25, 0.2, 0.15, 0.1]

# Ground truth: imbalanced draw over the five classes
annotator = np.random.choice(classes, n_samples, p=class_probabilities)

# Predictions: start from the truth, then re-draw about 20% of the rows
# uniformly from all classes. A re-draw may hit the true class again, so the
# share of actual mismatches ends up below 20%.
error_mask = np.random.random(n_samples) < 0.2
n_redrawn = error_mask.sum()
pred = annotator.copy()
pred[error_mask] = np.random.choice(classes, n_redrawn)

# Assemble the frame (column order: prediction first, then ground truth)
df = pd.DataFrame({"pred": pred, "annotator": annotator})

n_truth_classes = df["annotator"].nunique()
n_pred_classes = df["pred"].nunique()
print(f"DataFrame shape: {df.shape}")
print(f"Unique classes in ground truth: {n_truth_classes}")
print(f"Unique classes in predictions: {n_pred_classes}")
df.head(10)

## Class Distribution

In [None]:
# Class distribution: side-by-side bar charts of label counts in the ground
# truth (left) and in the predictions (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (DataFrame column, panel title, bar colour) for each panel, left to right
distribution_panels = (
    ("annotator", "Ground Truth Distribution", "skyblue"),
    ("pred", "Prediction Distribution", "lightcoral"),
)

for panel_ax, (column, panel_title, bar_color) in zip(axes, distribution_panels):
    label_counts = df[column].value_counts()
    label_counts.plot(kind="bar", ax=panel_ax, color=bar_color)
    panel_ax.set_title(panel_title, fontweight="bold")
    panel_ax.set_xlabel("Class")
    panel_ax.set_ylabel("Count")
    panel_ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
plt.show()

## Calculate All Metrics

In [None]:
# Compute overall, macro-averaged and weighted-averaged metrics for the
# predictions in `df` and print a summary. Results are collected in `metrics`;
# the per-class arrays `precision`, `recall`, `f1`, `support` are kept for the
# per-class analysis cells that follow.

# Get predictions and ground truth
y_true = df["annotator"].to_numpy()
y_pred = df["pred"].to_numpy()

# Identify valid classes (only those in ground truth)
valid_classes = np.unique(y_true)

# Check for invalid predictions (labels that never occur in the ground truth)
invalid_mask = ~np.isin(y_pred, valid_classes)
n_invalid = invalid_mask.sum()

if n_invalid > 0:
    invalid_classes = set(np.unique(y_pred)) - set(valid_classes)
    print(
        f"⚠️ WARNING: Found {n_invalid} predictions ({n_invalid / len(y_pred) * 100:.2f}%) with invalid classes"
    )
    print(f"   Invalid classes being predicted: {invalid_classes}")
    for cls in invalid_classes:
        count = (y_pred == cls).sum()
        print(f"     '{cls}': {count} occurrences")
    print("   These will be counted as errors but excluded from visualizations\n")

# Calculate all metrics (including invalid predictions as errors)
metrics = {}

# Basic accuracy (invalid predictions count as errors)
metrics["accuracy"] = accuracy_score(y_true, y_pred)
metrics["balanced_accuracy"] = balanced_accuracy_score(y_true, y_pred)

# Agreement metrics
metrics["cohen_kappa"] = cohen_kappa_score(y_true, y_pred)
metrics["matthews_corrcoef"] = matthews_corrcoef(y_true, y_pred)

# Per-class metrics - only for valid classes
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=valid_classes, average=None, zero_division=0
)

# Macro (unweighted) and weighted (by support) averages over valid classes.
# One call per averaging mode returns precision, recall and F1 together, so
# there is no need to recompute the same result once per metric.
for averaging in ("macro", "weighted"):
    avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=valid_classes, average=averaging, zero_division=0
    )
    metrics[f"precision_{averaging}"] = avg_precision
    metrics[f"recall_{averaging}"] = avg_recall
    metrics[f"f1_{averaging}"] = avg_f1

# Store valid classes for later use
metrics["valid_classes"] = valid_classes
metrics["n_invalid_predictions"] = n_invalid

# Display results
print("=" * 60)
print("OVERALL METRICS")
print("=" * 60)
print(f"Accuracy:                {metrics['accuracy']:.4f}")
print(f"Balanced Accuracy:       {metrics['balanced_accuracy']:.4f}")
print(f"Cohen's Kappa:           {metrics['cohen_kappa']:.4f}")
print(f"Matthews Corr Coef:      {metrics['matthews_corrcoef']:.4f}")
if n_invalid > 0:
    print(
        f"Invalid Predictions:     {n_invalid} ({n_invalid / len(y_pred) * 100:.2f}%)"
    )

print("\n" + "=" * 60)
print("AVERAGED METRICS (Valid Classes Only)")
print("=" * 60)
print("\nMacro Averages (unweighted):")
print(f"  Precision: {metrics['precision_macro']:.4f}")
print(f"  Recall:    {metrics['recall_macro']:.4f}")
print(f"  F1-Score:  {metrics['f1_macro']:.4f}")

print("\nWeighted Averages (by support):")
print(f"  Precision: {metrics['precision_weighted']:.4f}")
print(f"  Recall:    {metrics['recall_weighted']:.4f}")
print(f"  F1-Score:  {metrics['f1_weighted']:.4f}")

## Per-Class Metrics and Classification Report

In [None]:
# Build `class_df`: one row per valid class with its support, accuracy,
# precision, recall and F1, sorted by F1 (best first), then print a summary.
class_metrics = []

for cls, cls_precision, cls_recall, cls_f1 in zip(valid_classes, precision, recall, f1):
    in_class = y_true == cls
    # Share of this class's samples predicted correctly; predictions of an
    # invalid class simply fail the equality test and count as errors.
    hits = y_pred[in_class] == cls

    class_metrics.append(
        {
            "class": cls,
            "support": in_class.sum(),
            "accuracy": hits.mean(),
            "precision": cls_precision,
            "recall": cls_recall,
            "f1": cls_f1,
        }
    )

class_df = pd.DataFrame(class_metrics).sort_values("f1", ascending=False)

# Shared pieces of the printed report
rule = "=" * 70
summary_columns = ["class", "f1", "precision", "recall", "support"]

# Display summary statistics
print(rule)
print("CLASS PERFORMANCE SUMMARY")
print(rule)
print(f"Total number of valid classes: {len(class_df)}")
print("\nF1-Score Statistics:")
print(class_df["f1"].describe().to_string())

n_invalid_reported = metrics["n_invalid_predictions"]
if n_invalid_reported > 0:
    print(f"\n⚠️ Note: {n_invalid_reported} predictions with invalid classes")
    print("         are counted as errors but not shown as separate classes")

# Show top and bottom performers
top_10 = class_df.nlargest(10, "f1")
print("\n" + rule)
print("TOP 10 BEST PERFORMING CLASSES")
print(rule)
print(top_10[summary_columns].to_string(index=False))

bottom_10 = class_df.nsmallest(10, "f1")
print("\n" + rule)
print("TOP 10 WORST PERFORMING CLASSES")
print(rule)
print(bottom_10[summary_columns].to_string(index=False))

# Visualise per-class precision / recall / F1 from `class_df`.
# With more than 30 classes individual bars become unreadable, so the cell
# switches to histograms plus top/bottom-20 bar charts; otherwise it draws one
# bar per class. Each figure has one panel per metric, described here once.
metric_panels = [
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("f1", "F1-Score"),
]

if len(class_df) > 30:
    print(
        f"\n📊 Detected {len(class_df)} classes - using optimized visualizations for clarity"
    )

    # Figure 1: Distribution of metrics (histogram + mean marker per metric)
    hist_colors = ["skyblue", "lightcoral", "lightgreen"]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (column, label), hist_color in zip(axes, metric_panels, hist_colors):
        column_mean = class_df[column].mean()
        ax.hist(
            class_df[column], bins=30, color=hist_color, edgecolor="black", alpha=0.7
        )
        ax.axvline(
            column_mean,
            color="red",
            linestyle="--",
            label=f"Mean: {column_mean:.3f}",
        )
        ax.set_xlabel(label)
        ax.set_ylabel("Number of Classes")
        ax.set_title(f"Distribution of {label} Across Classes")
        ax.legend()
        ax.grid(axis="y", alpha=0.3)

    plt.suptitle("Performance Metrics Distribution", fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Figure 2: Top 20 performers using seaborn horizontal bar plots
    fig, axes = plt.subplots(1, 3, figsize=(15, 8))
    top_20 = class_df.nlargest(20, "f1")
    for ax, (column, label) in zip(axes, metric_panels):
        sns.barplot(data=top_20, y="class", x=column, ax=ax, orient="h")
        ax.set_xlabel(label)
        ax.set_title(f"Top 20 Classes - {label}", fontweight="bold")
        ax.set_ylabel("")

    plt.suptitle("Top 20 Best Performing Classes", fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Figure 3: Bottom 20 performers. No extra size check is needed: this
    # branch only runs with more than 30 classes, so there are always enough.
    bottom_colors = ["#FF6B6B", "#FF9999", "#FFB3B3"]
    fig, axes = plt.subplots(1, 3, figsize=(15, 8))
    bottom_20 = class_df.nsmallest(20, "f1")
    for ax, (column, label), bar_color in zip(axes, metric_panels, bottom_colors):
        sns.barplot(
            data=bottom_20,
            y="class",
            x=column,
            ax=ax,
            orient="h",
            color=bar_color,
        )
        ax.set_xlabel(label)
        ax.set_title(f"Bottom 20 Classes - {label}", fontweight="bold")
        ax.set_ylabel("")

    plt.suptitle("Bottom 20 Worst Performing Classes", fontweight="bold")
    plt.tight_layout()
    plt.show()

else:
    # Original visualization for fewer classes using seaborn: one bar per class
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (column, label) in zip(axes, metric_panels):
        sns.barplot(data=class_df, x="class", y=column, ax=ax)
        ax.set_title(f"{label} by Class", fontweight="bold")
        ax.set_xlabel("Class")
        ax.set_ylabel(label)
        ax.tick_params(axis="x", rotation=45)

    plt.suptitle("Per-Class Performance Metrics", fontweight="bold")
    plt.tight_layout()
    plt.show()

In [None]:
# Text report of per-class precision / recall / F1 / support, restricted to
# the classes present in the ground truth. zero_division=0 reports 0 for
# metrics whose denominator is zero.
report_text = classification_report(
    y_true,
    y_pred,
    labels=valid_classes,
    zero_division=0,
)
print(report_text)

## Confusion Matrix and Per-Class Performance Analysis

In [None]:
# Create confusion matrix (only for valid classes from ground truth).
# Rows are actual classes, columns are predicted classes; predictions whose
# label is not in `valid_classes` have no column and are therefore left out.
cm = confusion_matrix(y_true, y_pred, labels=valid_classes)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=valid_classes,
    yticklabels=valid_classes,
    cbar_kws={"label": "Count"},
)
plt.title("Confusion Matrix (Valid Classes Only)", fontweight="bold")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()

# Calculate and display accuracy for each class.
# Divide by the true number of samples per class rather than by the matrix
# row sums: the row sums omit samples predicted as an invalid class, which
# would silently drop those errors and inflate the reported accuracy.
print("\nPer-Class Accuracy:")
print("=" * 40)
class_support = np.array([(y_true == cls).sum() for cls in valid_classes])
class_accuracy = cm.diagonal() / class_support
for i, cls in enumerate(valid_classes):
    print(f"{cls:<20}: {class_accuracy[i]:.4f}")

# Show impact of invalid predictions if any
if metrics["n_invalid_predictions"] > 0:
    print(
        f"\n⚠️ Note: {metrics['n_invalid_predictions']} invalid predictions are counted as errors"
    )
    print("         but not shown in the confusion matrix above")

In [None]:
# Rebuild `class_df` (one row per valid class: support, accuracy, precision,
# recall, F1; sorted by F1, best first) and print the summary tables, so the
# Seaborn figures below can be run from this cell onwards.
class_metrics = []

for cls, cls_precision, cls_recall, cls_f1 in zip(valid_classes, precision, recall, f1):
    in_class = y_true == cls
    # Share of this class's samples predicted correctly; predictions of an
    # invalid class simply fail the equality test and count as errors.
    hits = y_pred[in_class] == cls

    class_metrics.append(
        {
            "class": cls,
            "support": in_class.sum(),
            "accuracy": hits.mean(),
            "precision": cls_precision,
            "recall": cls_recall,
            "f1": cls_f1,
        }
    )

class_df = pd.DataFrame(class_metrics).sort_values("f1", ascending=False)

# Shared pieces of the printed report
rule = "=" * 70
summary_columns = ["class", "f1", "precision", "recall", "support"]

# Display summary statistics
print(rule)
print("CLASS PERFORMANCE SUMMARY")
print(rule)
print(f"Total number of valid classes: {len(class_df)}")
print("\nF1-Score Statistics:")
print(class_df["f1"].describe().to_string())

n_invalid_reported = metrics["n_invalid_predictions"]
if n_invalid_reported > 0:
    print(f"\n⚠️ Note: {n_invalid_reported} predictions with invalid classes")
    print("         are counted as errors but not shown as separate classes")

# Show top and bottom performers
top_10 = class_df.nlargest(10, "f1")
print("\n" + rule)
print("TOP 10 BEST PERFORMING CLASSES")
print(rule)
print(top_10[summary_columns].to_string(index=False))

bottom_10 = class_df.nsmallest(10, "f1")
print("\n" + rule)
print("TOP 10 WORST PERFORMING CLASSES")
print(rule)
print(bottom_10[summary_columns].to_string(index=False))

# SEABORN VISUALIZATIONS FOR MANY CLASSES
print(f"\n📊 Creating visualizations for {len(class_df)} classes using Seaborn...")

score_columns = ["precision", "recall", "f1"]
score_labels = ["Precision", "Recall", "F1-Score"]

# Long-format copy of the per-class scores (one row per class and metric);
# a later figure in this cell reads it as well.
class_df_melted = class_df.melt(
    id_vars=["class", "support", "accuracy"],
    value_vars=score_columns,
    var_name="metric",
    value_name="score",
)

# Figure 1: Distribution Analysis with Seaborn (2x2 grid)
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
(ax1, ax2), (ax3, ax4) = axes

# Scores only, melted into the default "variable" / "value" columns
metrics_for_violin = class_df[score_columns]
metrics_long = pd.melt(metrics_for_violin)

# Top-left: violin plots for distributions
sns.violinplot(data=metrics_long, x="variable", y="value", ax=ax1, palette="Set2")
ax1.set_title("Score Distributions (Violin Plot)", fontweight="bold")
ax1.set_xlabel("Metric", fontweight="bold")
ax1.set_ylabel("Score", fontweight="bold")
ax1.set_xticklabels(score_labels)

# Top-right: box plots with the individual classes overlaid as small points
sns.boxplot(data=metrics_long, x="variable", y="value", ax=ax2, palette="muted")
sns.swarmplot(
    data=metrics_long, x="variable", y="value", ax=ax2, color="black", alpha=0.3, size=2
)
ax2.set_title("Score Distributions with Outliers", fontweight="bold")
ax2.set_xlabel("Metric", fontweight="bold")
ax2.set_ylabel("Score", fontweight="bold")
ax2.set_xticklabels(score_labels)

# Bottom-left: KDE plots for smooth distributions, one curve per metric
for column, curve_label in zip(score_columns, score_labels):
    sns.kdeplot(data=class_df[column], ax=ax3, label=curve_label, fill=True, alpha=0.6)
ax3.set_title("Kernel Density Estimation", fontweight="bold")
ax3.set_xlabel("Score", fontweight="bold")
ax3.set_ylabel("Density", fontweight="bold")
ax3.legend()
ax3.set_xlim([0, 1])

# Bottom-right: precision vs recall, point size = support, colour = F1
sns.scatterplot(
    data=class_df,
    x="precision",
    y="recall",
    size="support",
    hue="f1",
    palette="RdYlGn",
    ax=ax4,
    sizes=(20, 200),
    alpha=0.7,
)
ax4.plot([0, 1], [0, 1], "k--", alpha=0.3)  # Diagonal reference line
ax4.set_title("Precision vs Recall (sized by support)", fontweight="bold")
ax4.set_xlabel("Precision", fontweight="bold")
ax4.set_ylabel("Recall", fontweight="bold")
ax4.set_xlim([0, 1])
ax4.set_ylim([0, 1])

plt.suptitle("Performance Distribution Analysis", fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# Figure 2: Top/Bottom Performers with Seaborn
# Left panel: the 20 classes with the highest F1; right panel: the 20 with
# the lowest. Each class gets grouped bars for precision, recall and F1.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

top_20 = class_df.nlargest(20, "f1")
bottom_20 = class_df.nsmallest(20, "f1")

# (axis, classes to show, palette, title, title colour, legend position)
performer_panels = [
    (axes[0], top_20, "viridis", "Top 20 Best Performing Classes", "darkgreen", "lower right"),
    (axes[1], bottom_20, "rocket_r", "Bottom 20 Worst Performing Classes", "darkred", "upper right"),
]

for panel_ax, subset, palette_name, panel_title, title_color, legend_loc in performer_panels:
    # Long format so seaborn can group the three metrics by hue
    subset_long = subset.melt(
        id_vars=["class"],
        value_vars=["precision", "recall", "f1"],
        var_name="metric",
        value_name="score",
    )
    sns.barplot(
        data=subset_long,
        y="class",
        x="score",
        hue="metric",
        ax=panel_ax,
        palette=palette_name,
        orient="h",
    )
    panel_ax.set_xlabel("Score", fontweight="bold")
    panel_ax.set_ylabel("")
    panel_ax.set_title(panel_title, fontweight="bold", color=title_color)
    panel_ax.legend(title="Metric", loc=legend_loc)
    panel_ax.set_xlim([0, 1])

plt.suptitle("Best vs Worst Performers", fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# Figure 3: Heatmap Analysis
# Only drawn with more than 20 classes. Top panel: metric heatmap of the 30
# best classes by F1. Bottom panel: all classes grouped into equal-width F1
# bins, showing the mean of each metric per bin.
if len(class_df) > 20:
    heatmap_columns = ["precision", "recall", "f1"]
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))

    # Top 30 classes heatmap (metrics as rows, classes as columns)
    top_30 = class_df.nlargest(30, "f1")
    heatmap_data_top = top_30.set_index("class")[heatmap_columns].T

    sns.heatmap(
        heatmap_data_top,
        annot=False,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        ax=axes[0],
        cbar_kws={"label": "Score"},
        linewidths=0.5,
    )
    axes[0].set_xlabel("")
    axes[0].set_ylabel("Metric", fontweight="bold")
    axes[0].set_title("Top 30 Classes - Performance Heatmap", fontweight="bold")
    axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha="right")

    # All classes summary heatmap (binned)
    # The bin index lives in its own Series instead of a new `class_df`
    # column, so this plotting helper cannot leak into later cells or into
    # the exported per-class CSV.
    n_bins = min(50, len(class_df))
    f1_bin = pd.cut(class_df["f1"], bins=n_bins, labels=False).rename("f1_bin")
    binned_stats = class_df.groupby(f1_bin)[heatmap_columns].mean()

    sns.heatmap(
        binned_stats.T,
        annot=False,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        ax=axes[1],
        cbar_kws={"label": "Score"},
        xticklabels=False,
    )
    axes[1].set_xlabel(
        f"Classes (binned by F1 score, {len(class_df)} total)", fontweight="bold"
    )
    axes[1].set_ylabel("Metric", fontweight="bold")
    axes[1].set_title("All Classes - Binned Performance Heatmap", fontweight="bold")

    plt.suptitle("Performance Heatmaps", fontweight="bold", y=1.02)
    plt.tight_layout()
    plt.show()

# Figure 4: Clustermap for Pattern Discovery
# Rows (classes) are grouped by hierarchical clustering of their precision /
# recall / F1 profile; metric columns keep their order (col_cluster=False).
# Skipped above 100 classes to stay readable, and below 2 classes because
# there is nothing to cluster with a single row.
if 2 <= len(class_df) <= 100:
    # Prepare data for clustering
    cluster_data = class_df.set_index("class")[["precision", "recall", "f1"]]

    # Create clustermap. seaborn builds its own figure here (sized via
    # `figsize`), so no plt.figure() call precedes it — that would only leave
    # an extra empty figure in the output.
    g = sns.clustermap(
        cluster_data,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        col_cluster=False,
        linewidths=0.5,
        cbar_kws={"label": "Score"},
        figsize=(10, 12),
    )
    g.ax_heatmap.set_xlabel("Metric", fontweight="bold")
    g.ax_heatmap.set_title(
        "Hierarchical Clustering of Class Performance", fontweight="bold", pad=20
    )
    plt.show()

# Figure 5: Pairplot for Relationships
# Scatter matrix of the per-class metrics and support, with KDE curves on the
# diagonal. Restricted to at most 50 classes.
n_classes_for_pairplot = len(class_df)
if n_classes_for_pairplot <= 50:
    pair_vars = ["precision", "recall", "f1", "support"]
    pair_grid = sns.pairplot(
        class_df[pair_vars],
        diag_kind="kde",
        plot_kws={"alpha": 0.6},
        height=2.5,
    )
    pair_grid.fig.suptitle("Metric Relationships", fontweight="bold", y=1.02)
    plt.show()

# Figure 6: Strip/Swarm Plot for All Classes
# Only drawn with more than 30 classes: one jittered point per class and
# metric, with a translucent violin per metric drawn over it.
if len(class_df) > 30:
    fig, ax = plt.subplots(figsize=(14, 6))

    # `class_df_melted` already holds exactly the precision / recall / f1
    # rows, so it is plotted directly without further filtering.
    sns.stripplot(
        data=class_df_melted,
        x="score",
        y="metric",
        hue="metric",
        palette="Set2",
        size=3,
        alpha=0.6,
        ax=ax,
    )

    # Add violin plot overlay
    sns.violinplot(
        data=class_df_melted,
        x="score",
        y="metric",
        palette="Set2",
        inner=None,
        alpha=0.3,
        ax=ax,
    )

    ax.set_xlabel("Score", fontweight="bold")
    ax.set_ylabel("Metric", fontweight="bold")
    ax.set_title(
        f"Score Distribution for All {len(class_df)} Classes", fontweight="bold"
    )
    # The hue legend only repeats the y-axis labels, so drop it. A legend may
    # not have been created at all, in which case there is nothing to remove.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_xlim([0, 1])

    plt.tight_layout()
    plt.show()

print("\n✅ Seaborn visualizations complete!")

## Error Analysis

In [None]:
# Summarise misclassifications: overall error counts, the ten most frequent
# "actual → predicted" confusions, and a bar chart of those patterns.

# Column names used by every other cell of this notebook
pred_col = "pred"
truth_col = "annotator"

# Optional figure export. Notebook-level settings are honoured when they have
# been defined; otherwise the figure is only displayed, not written to disk.
save_figures = globals().get("SAVE_FIGURES", False)
output_dir = Path(globals().get("OUTPUT_DIR", BASE_DIR))
figure_dpi = globals().get("FIGURE_DPI", 150)

# Find misclassifications
errors = df[df[pred_col] != df[truth_col]].copy()

n_total = len(df)
n_errors = len(errors)
n_correct = n_total - n_errors
# Guard the percentages so an empty DataFrame reports 0% instead of raising
error_pct = n_errors / n_total * 100 if n_total else 0.0
correct_pct = n_correct / n_total * 100 if n_total else 0.0

print(f"Total errors: {n_errors:,} ({error_pct:.2f}%)")
print(f"Total correct: {n_correct:,} ({correct_pct:.2f}%)")

if n_errors > 0:
    # Create error pattern
    errors["error_pattern"] = (
        errors[truth_col].astype(str) + " → " + errors[pred_col].astype(str)
    )

    # Top error patterns
    error_counts = errors["error_pattern"].value_counts().head(10)

    print("\nTop 10 Most Common Misclassifications:")
    print("=" * 50)
    for pattern, count in error_counts.items():
        percentage = count / n_errors * 100
        print(f"{pattern:<30} : {count:>5} ({percentage:>5.1f}%)")

    # Plot error patterns
    plt.figure(figsize=(12, 6))
    error_counts.plot(kind="barh", color="salmon")
    plt.title("Top 10 Error Patterns", fontsize=12, fontweight="bold")
    plt.xlabel("Count")
    plt.ylabel("Error Pattern (Actual → Predicted)")
    plt.grid(axis="x", alpha=0.3)
    plt.tight_layout()

    if save_figures:
        fig_path = output_dir / "error_patterns.png"
        plt.savefig(fig_path, dpi=figure_dpi, bbox_inches="tight")
        print(f"\n💾 Figure saved to: {fig_path}")

    plt.show()

## Export Results

In [None]:
# Write the evaluation results to CSV files in BASE_DIR:
#   evaluation_metrics.csv  - one row with the overall metrics
#   per_class_metrics.csv   - one row per valid class
#   confusion_matrix.csv    - counts, rows = actual, columns = predicted
import json

# Save metrics to CSV.
# `metrics["valid_classes"]` is a NumPy array; written as-is it would land in
# the CSV as the array's printed form, which is awkward to parse back. Store
# it as a JSON list instead and leave every other entry untouched.
metrics_row = {
    key: (
        json.dumps(np.asarray(value).tolist(), ensure_ascii=False, default=str)
        if key == "valid_classes"
        else value
    )
    for key, value in metrics.items()
}
metrics_path = BASE_DIR / "evaluation_metrics.csv"
pd.DataFrame([metrics_row]).to_csv(metrics_path, index=False)
print(f"Metrics saved to: {metrics_path}")

# Save per-class metrics. "f1_bin" is a helper column the binned heatmap may
# have added to class_df for plotting; it is not a metric, so leave it out.
per_class_path = BASE_DIR / "per_class_metrics.csv"
class_df.drop(columns=["f1_bin"], errors="ignore").to_csv(per_class_path, index=False)
print(f"Per-class metrics saved to: {per_class_path}")

# Save confusion matrix (with valid classes only)
cm_path = BASE_DIR / "confusion_matrix.csv"
cm_df = pd.DataFrame(cm, index=valid_classes, columns=valid_classes)
cm_df.to_csv(cm_path)
print(f"Confusion matrix saved to: {cm_path}")

print("\n✅ All results exported successfully!")