# Expert Density Analysis for Federated Mixture of Experts

This notebook analyzes the impact of expert density (number of local experts per client) on model performance in federated learning scenarios with Mixture of Experts (MoE) architectures.

## Objectives

1. **Data Collection**: Download metrics from WandB for different expert density configurations
2. **Performance Analysis**: Compare perplexity vs. token consumption across configurations
3. **Efficiency Evaluation**: Assess the trade-offs between expert density and training efficiency
4. **Configuration Impact**: Understand how overlapping factors and client counts affect performance

## Key Metrics

- **Expert Density**: Number of local experts per client
- **Overlapping Factor**: Factor determining expert sharing across clients
- **Final Perplexity**: Model performance at the end of training
- **Total Tokens**: Computational cost measure
- **Efficiency**: Ratio of final perplexity to total tokens consumed

## Analysis Workflow

1. Load runs from WandB, keeping only those whose display name starts with one of the configured run UUID prefixes
2. Extract configuration parameters for each run
3. Download server and client metrics data
4. Compute perplexity vs. token relationships
5. Aggregate results by configuration parameters
6. Generate visualizations and summary statistics

In [None]:
import json
import logging
import operator
from datetime import UTC, datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from fedmoe_plots.data_analysis import (
    ColumnNotFoundError,
    get_device_throughput_series,
    get_perplexity_versus_tokens,
)
from fedmoe_plots.plotting_utils import configure_logging_for_jupyter
from fedmoe_plots.wandb_utils import (
    ClientRunNotFoundError,
    download_photon_metrics,
    get_clientrun_property_from_config,
    get_experts_global_batch_size,
    get_n_local_experts,
    get_non_experts_global_batch_size,
    get_run_uuid_from_config,
    remove_runs_by_regex,
)

# Set up logging for the notebook session before any logger is used.
# NOTE(review): `configure_logging_for_jupyter` is a project helper; its exact
# handler/format configuration is not visible from this notebook.
configure_logging_for_jupyter()

# Logger shared by every cell below.
log = logging.getLogger("experts_density.ipynb")

# When True, the DataFrame and plotting cells use only `complete_runs`
# (runs whose total token count is not below the 1e9 threshold).
EXCLUDE_INCOMPLETE_RUNS = True
# Reference batch size; the configuration analysis divides it by the
# overlapping factor / client count to enumerate expected local batch sizes.
BASE_BATCH_SIZE = 512

In [None]:
# Prefixes of the run display names that belong to this analysis.
run_uuid_1 = "4e_rho_bspol_scratch"
run_uuid_2 = "8e_rho_bspol_scratch"

# One anchored alternative per prefix: "(^prefix_1)|(^prefix_2)".
display_name_regex = "|".join(
    f"(^{prefix})" for prefix in (run_uuid_1, run_uuid_2)
)

api = wandb.Api(timeout=100)
runs = api.runs(
    path="camlsys/photon",
    filters={"display_name": {"$regex": display_name_regex}},
)

In [None]:
# Log the name, ID and state of every run returned by the query above,
# as a quick sanity check of what was matched.
for run in runs:
    log.info("Run name: %s, ID: %s, State: %s", run.name, run.id, run.state)

In [None]:
def _unique_client_config_values(get_property_fn):
    """Return the set of values ``get_property_fn`` yields across ``runs``.

    Every run is handed, together with ``get_property_fn``, to
    ``get_clientrun_property_from_config`` and the results are collected
    into a set so duplicates collapse.
    """
    return {
        get_clientrun_property_from_config(run, get_property_fn=get_property_fn)
        for run in runs
    }


# Distinct values of each configuration knob across the selected runs.
unique_n_total_clients = _unique_client_config_values(
    lambda config: config["fl"]["n_total_clients"],
)
unique_local_batch_size = _unique_client_config_values(
    lambda config: config["llm_config"]["global_train_batch_size"],
)
unique_overlapping_factor = _unique_client_config_values(
    lambda config: config["fl"]["experts_overlapping_factor"],
)
unique_n_total_experts = _unique_client_config_values(
    lambda config: config["llm_config"]["model"]["ffn_config"]["ff_n_experts"],
)
unique_n_local_experts = _unique_client_config_values(get_n_local_experts)

log.info(
    "Unique n_total_clients: %s, "
    "unique_local_batch_size: %s, "
    "unique_overlapping_factor: %s, "
    "unique_n_total_experts: %s, "
    "unique_n_local_experts: %s",
    unique_n_total_clients,
    unique_local_batch_size,
    unique_overlapping_factor,
    unique_n_total_experts,
    unique_n_local_experts,
)

In [None]:
# Data collection and processing for expert density analysis
#
# For each distinct run UUID among `runs`: read the configuration, download
# the metrics, derive the perplexity-vs-tokens and throughput series, and
# append one summary dict to `results_list`. A run that fails at any stage is
# logged and skipped so that it cannot stop the collection of the others.
import re


def _as_list(values):
    """Return ``values`` as a plain list, preferring its own ``tolist()``."""
    return values.tolist() if hasattr(values, "tolist") else list(values)


log.info("🔄 Starting data collection and processing...")
unique_run_uuids = {
    str(
        get_clientrun_property_from_config(
            run,
            get_property_fn=operator.itemgetter("run_uuid"),
        ),
    )
    for run in runs
}
log.info("Found %s runs to process", len(unique_run_uuids))
results_list = []

# Iterate in sorted order so that the order of `results_list` (and hence the
# colours/markers assigned by the plotting cells) is the same on every
# execution instead of following set iteration order.
for i, unique_run_uuid in enumerate(sorted(unique_run_uuids)):
    log.info(
        "📊 Processing run %s/%s: %s",
        i + 1,
        len(unique_run_uuids),
        unique_run_uuid,
    )

    run: wandb.apis.public.Run | None = (  # pyright: ignore[reportAttributeAccessIssue]
        None
    )
    try:
        # Get any run that matches the unique run UUID. A default is passed
        # to `next` so that "no match" is handled explicitly here instead of
        # raising StopIteration into the generic handler below.
        run = next(
            (r for r in runs if get_run_uuid_from_config(r) == unique_run_uuid),
            None,
        )
        if run is None:
            log.warning(
                "   ⚠️ No run found with UUID %s, skipping",
                unique_run_uuid,
            )
            continue

        # Extract configuration parameters
        config = run.config
        run_uuid = get_clientrun_property_from_config(
            run,
            get_property_fn=operator.itemgetter("run_uuid"),
        )

        n_total_clients = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["fl"]["n_total_clients"],
        )

        n_total_experts = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"]["ffn_config"][
                "ff_n_experts"
            ],
        )

        overlapping_factor = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["fl"]["experts_overlapping_factor"],
        )

        global_train_batch_size = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"][
                "global_train_batch_size"
            ],
        )

        # Calculate derived metrics
        # NOTE(review): these read `run.config` directly, whereas the values
        # above go through `get_clientrun_property_from_config`; confirm both
        # see the same configuration.
        n_local_experts = get_n_local_experts(config)
        experts_global_batch_size = get_experts_global_batch_size(config)
        non_experts_global_batch_size = get_non_experts_global_batch_size(config)

        log.info(
            "📋 Configuration: TE=%s, LE=%s, Clients=%s, OF=%s",
            n_total_experts,
            n_local_experts,
            n_total_clients,
            overlapping_factor,
        )

        # Download and process data
        log.info("🔽 Downloading data for %s...", run_uuid)

        # Regex selecting every run whose name starts with this UUID. The UUID
        # is escaped and followed by ".*" so that the pattern means "prefix
        # match" regardless of whether it is applied with match, search or
        # full-match semantics.
        # NOTE(review): `remove_runs_by_regex` is a project helper; confirm
        # what "removing" entails before relying on this cell.
        runs_of_this_uuid_regex = f"^{re.escape(str(run_uuid))}.*"

        try:
            # Download photon metrics for the run
            assert run_uuid is not None, "Run UUID must not be None"
            assert isinstance(
                run_uuid,
                str,
            ), f"Run UUID must be a string not a {type(run_uuid)}"
            _s_df, clients_df = download_photon_metrics(
                base_name="camlsys/photon",
                run_uuid=run_uuid,
            )

            # Try to get perplexity vs tokens data
            assert n_total_clients is not None, "Total clients must not be None"
            assert isinstance(
                n_total_clients,
                int,
            ), f"Total clients must be an integer not a {type(n_total_clients)}"
            tokens, perplexity = get_perplexity_versus_tokens(
                client_metrics_df=clients_df,
                n_clients_per_round=n_total_clients,
            )

            if len(tokens) == 0 or len(perplexity) == 0:
                log.warning("❌ Empty data arrays for %s", run_uuid)
                continue

            # Try to get the throughput data
            steps, throughput = get_device_throughput_series(
                client_metrics_df=clients_df,
                moving_window=10,
            )

            # Calculate metrics (both series are non-empty at this point)
            final_perplexity = perplexity.iloc[-1]
            total_tokens = tokens.iloc[-1]
            n_data_points = len(tokens)

            log.info(
                (
                    "   ✅ Data processed: %d points,"
                    " Final perplexity: %.4f, Total tokens: %.0f"
                ),
                n_data_points,
                final_perplexity,
                total_tokens,
            )

            # Store results
            results_list.append(
                {
                    "run_uuid": str(run_uuid),
                    "n_total_clients": n_total_clients,
                    "n_total_experts": n_total_experts,
                    "n_local_experts": n_local_experts,
                    "overlapping_factor": overlapping_factor,
                    "global_train_batch_size": global_train_batch_size,
                    "experts_global_batch_size": experts_global_batch_size,
                    "non_experts_global_batch_size": non_experts_global_batch_size,
                    "final_perplexity": final_perplexity,
                    "total_tokens": total_tokens,
                    "n_data_points": n_data_points,
                    "tokens": _as_list(tokens),
                    "perplexity": _as_list(perplexity),
                    "steps": _as_list(steps),
                    "throughput": _as_list(throughput),
                },
            )

        except ClientRunNotFoundError:
            # Log exception with stack trace
            log.exception(
                "   ⚠️  Client run not found for %s",
                run_uuid,
                stack_info=True,
            )

            # Remove this run from WandB runs to avoid further processing
            remove_runs_by_regex("camlsys/photon", runs_of_this_uuid_regex)
            continue

        except ColumnNotFoundError:
            log.exception(
                "   ⚠️ Column not found in client metrics DataFrame for %s. "
                "We assume this run crashed and doesn't have the expected data.",
                run_uuid,
                stack_info=True,
            )

            # Remove this run from WandB runs to avoid further processing
            remove_runs_by_regex("camlsys/photon", runs_of_this_uuid_regex)
            continue

        except Exception:
            log.exception("   ❌ Error processing %s", run_uuid, stack_info=True)
            continue

    except Exception:
        # `run` is still None if the lookup itself raised, so fall back to the
        # UUID for the message instead of failing inside the handler.
        log.exception(
            "   ❌ Error extracting config for run %s",
            run.name if run is not None else unique_run_uuid,
            stack_info=True,
        )
        continue

log.info("\n✅ Data collection completed!")
log.info(
    "Successfully processed %d out of %d runs",
    len(results_list),
    len(unique_run_uuids),
)

In [None]:
# Enhanced summary with configuration analysis
#
# For every total-expert count, enumerate the expected ("theoretical")
# (overlapping factor, clients, local batch size) combinations and compare
# them with the combinations actually present in `results_list`.
log.info("\n=== Summary ===")
log.info(
    "Successfully processed %d out of %d runs",
    len(results_list),
    len(unique_run_uuids),
)

# Analyze planned configurations if we have results
if results_list:
    log.info("\n📊 CONFIGURATION ANALYSIS 📊")
    log.info("=" * 60)

    # Group by n_total_experts to analyze configurations
    experts_groups = {}
    for result in results_list:
        experts_groups.setdefault(result["n_total_experts"], []).append(result)

    for n_experts in sorted(experts_groups):
        runs_for_experts = experts_groups[n_experts]
        log.info("\n🔢 Total Experts: %d", n_experts)
        log.info("   Number of runs: %d", len(runs_for_experts))

        # Collect the theoretical configurations: every pair of powers of two
        # (overlap <= clients <= n_experts), each with the two batch sizes
        # obtained by dividing BASE_BATCH_SIZE by either member of the pair.
        powers_of_two = [
            value for value in range(1, n_experts + 1) if (value & (value - 1)) == 0
        ]
        theoretical_configs: set[tuple[int, int, int]] = set()
        for overlap in powers_of_two:
            for n_clients in powers_of_two:
                if n_clients < overlap:
                    continue
                theoretical_configs.add(
                    (overlap, n_clients, BASE_BATCH_SIZE // overlap),
                )
                theoretical_configs.add(
                    (overlap, n_clients, BASE_BATCH_SIZE // n_clients),
                )

        # Analyze the configuration space for this expert count
        actual_configs: set[tuple[int, int, int]] = {
            (
                run_result["overlapping_factor"],
                run_result["n_total_clients"],
                run_result["global_train_batch_size"],
            )
            for run_result in runs_for_experts
        }

        # Coverage is measured on the theoretical configurations that are
        # actually present, not on the raw set sizes: runs with a
        # configuration outside the theoretical set must not count as
        # covering a theoretical one.
        missing_configs = theoretical_configs - actual_configs
        unexpected_configs = actual_configs - theoretical_configs
        n_covered = len(theoretical_configs) - len(missing_configs)

        log.info("   Theoretical configurations: %d", len(theoretical_configs))
        log.info("   Actual configurations: %d", len(actual_configs))

        if missing_configs:
            # `missing_configs` non-empty implies `theoretical_configs` is too.
            coverage = (n_covered / len(theoretical_configs)) * 100
            log.info("   Coverage: %.1f%% ⚠️", coverage)
        else:
            log.info("   Coverage: 100%% ✅")

        if unexpected_configs:
            log.warning(
                "   ⚠️ Configurations outside the theoretical set (%d):",
                len(unexpected_configs),
            )
            # Sorted by text form because these values come straight from the
            # run configs and are not guaranteed to be mutually comparable.
            for unexpected in sorted(unexpected_configs, key=str):
                log.warning("      • %s", unexpected)

        # Show detailed configuration breakdown by showing first those that are missing
        # and then those that are present
        if missing_configs:
            log.warning("   ⚠️ Missing configurations (%d):", len(missing_configs))
            for missing in sorted(missing_configs):
                log.warning(
                    "      • Overlapping Factor: %.1f, Clients: %d, "
                    "Local Batch Size: %d",
                    missing[0],
                    missing[1],
                    missing[2],
                )
        else:
            log.info("   All theoretical configurations are present")

# Check for runs that didn't reach 1 billion tokens and log warnings
#
# Splits `results_list` into `incomplete_runs` / `complete_runs` around the
# 1e9-token threshold and logs a report for each group plus a summary.
incomplete_runs = []
complete_runs = []
if results_list:
    log.info("\n⚠️  TOKEN COUNT ANALYSIS ⚠️")
    log.info("=" * 60)

    billion_tokens = 1e9

    # Route every result into exactly one of the two buckets.
    for result in results_list:
        bucket = (
            incomplete_runs
            if result["total_tokens"] < billion_tokens
            else complete_runs
        )
        bucket.append(result)

    if incomplete_runs:
        log.warning(
            "🔴 WARNING: %d run(s) did NOT reach 1 billion tokens:",
            len(incomplete_runs),
        )
        log.info("-" * 60)

        # (log template, result key) pairs, in the order they are reported.
        configuration_lines = (
            ("      • Total Experts: %s", "n_total_experts"),
            ("      • Local Experts: %s", "n_local_experts"),
            ("      • Total Clients: %s", "n_total_clients"),
            ("      • Overlapping Factor: %.1f", "overlapping_factor"),
            ("      • Expert Global Batch Size: %s", "experts_global_batch_size"),
            (
                "      • Non-Expert Global Batch Size: %s",
                "non_experts_global_batch_size",
            ),
            ("      • Local Batch Size: %s", "global_train_batch_size"),
            ("      • Final Perplexity: %.4f", "final_perplexity"),
            ("      • Data Points: %s", "n_data_points"),
        )

        for i, result in enumerate(incomplete_runs, 1):
            run_tokens = result["total_tokens"]
            log.info("\n%s. Run UUID: %s", i, result["run_uuid"])
            log.info(
                "   Total Tokens: %s (%.3fB)",
                format(run_tokens, ","),
                run_tokens / 1e9,
            )
            log.info(
                "   Completion: %.1f%% of 1B tokens",
                run_tokens / billion_tokens * 100,
            )
            log.info("   📋 Full Configuration:")
            for template, key in configuration_lines:
                log.info(template, result[key])

    if complete_runs:
        log.info(
            "✅ %d run(s) successfully reached 1+ billion tokens:",
            len(complete_runs),
        )
        for result in complete_runs:
            run_tokens = result["total_tokens"]
            log.info(
                "   • %s: %d tokens (%.3fB)",
                result["run_uuid"],
                run_tokens,
                run_tokens / 1e9,
            )

    log.info("\n📊 Token Count Summary:")
    log.info("   • Complete runs (≥1B tokens): %d", len(complete_runs))
    log.info("   • Incomplete runs (<1B tokens): %d", len(incomplete_runs))
    # `results_list` is known to be non-empty inside this branch.
    avg_tokens = sum(r["total_tokens"] for r in results_list) / len(results_list)
    log.info(
        "   • Average tokens across all runs: %.0f (%.3fB)",
        avg_tokens,
        avg_tokens / 1e9,
    )

# Convert to DataFrame for easier analysis
# Only the complete runs are kept when EXCLUDE_INCOMPLETE_RUNS is set.
this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list
if not this_cell_results_list:
    log.info("No results to analyze")
    results_df = pd.DataFrame()
else:
    results_df = pd.DataFrame(this_cell_results_list)

In [None]:
# Dump complete run UUIDs to file
#
# Writes two timestamped files in the current working directory: a JSON file
# with the UUIDs plus a few details per run, and a text file with one
# "run_uuid,n_total_experts" line per run.


def _to_builtin(value):
    """Return ``value`` as a built-in scalar when it provides ``.item()``.

    Values read out of a pandas Series or NumPy array can be NumPy scalars,
    and ``json`` refuses some of them (for example ``numpy.int64``). Anything
    without an ``item`` method is returned unchanged.
    """
    return value.item() if hasattr(value, "item") else value


if complete_runs:
    # Extract just the UUIDs from complete runs
    complete_run_uuids = [entry["run_uuid"] for entry in complete_runs]

    # Create output filename with timestamp
    timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
    output_file = Path(f"complete_run_uuids_{timestamp}.json")

    # Fields copied into the per-run details of the JSON file.
    detail_keys = (
        "run_uuid",
        "n_total_experts",
        "n_local_experts",
        "overlapping_factor",
        "final_perplexity",
        "total_tokens",
    )

    # Save to JSON file for easy reading
    output_data = {
        "timestamp": timestamp,
        "total_complete_runs": len(complete_run_uuids),
        "total_runs_analyzed": len(results_list),
        "complete_run_uuids": complete_run_uuids,
        "run_details": [
            {key: _to_builtin(entry[key]) for key in detail_keys}
            for entry in complete_runs
        ],
    }

    with output_file.open("w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)

    log.info("✅ Complete run UUIDs saved to: %s", output_file)
    log.info("📊 Summary:")
    log.info("   - Total complete runs: %d", len(complete_run_uuids))
    log.info("   - Total runs analyzed: %d", len(results_list))
    log.info("   - Complete run UUIDs with total experts:")
    # Dedicated loop names so the notebook-level `run` / `i` are left alone.
    for position, entry in enumerate(complete_runs, 1):
        log.info(
            "     %d. %s (TE: %d)",
            position,
            entry["run_uuid"],
            entry["n_total_experts"],
        )

    # Also save a simple text file with just the UUIDs and total experts (one per line)
    txt_file = Path(f"complete_run_uuids_{timestamp}.txt")
    with txt_file.open("w", encoding="utf-8") as f:
        f.write("# Complete Run UUIDs with Total Experts\n")
        f.write("# Format: run_uuid,n_total_experts\n")
        f.writelines(
            f"{entry['run_uuid']},{entry['n_total_experts']}\n"
            for entry in complete_runs
        )

    log.info("📝 Also saved text version with total experts to: %s", txt_file)

else:
    log.warning("❌ No complete runs found to save")
    log.info("   All %d runs are incomplete (< 1B tokens)", len(results_list))

In [None]:
# Enhanced Expert Density Analysis with Independent, High-Quality Plots
# Plot 1: one perplexity-vs-tokens curve per run.
log.info("Generating enhanced Expert Density Analysis with improved cosmetics...")

this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if this_cell_results_list:
    # Quick token count check before plotting
    billion_tokens = 1e9
    incomplete_runs = [
        r for r in this_cell_results_list if r["total_tokens"] < billion_tokens
    ]

    if incomplete_runs:
        log.warning(
            "\n ATTENTION: %s of %s runs didn't reach 1B tokens!",
            len(incomplete_runs),
            len(this_cell_results_list),
        )
        log.info("   This may affect training convergence analysis.\n")

    # Set up enhanced plotting style
    plt.style.use("seaborn-v0_8-whitegrid")
    colors = plt.cm.Set1(  # pyright: ignore[reportAttributeAccessIssue]
        np.linspace(0, 1, len(this_cell_results_list)),
    )

    # ==== PLOT 1: Training Progress - Perplexity vs Total Tokens ====
    plt.figure(figsize=(14, 8))

    for i, result in enumerate(this_cell_results_list):
        tokens = np.array(result["tokens"])
        perplexity = np.array(result["perplexity"])

        # Runs below the token threshold are flagged and drawn more lightly.
        is_incomplete = result["total_tokens"] < billion_tokens
        if is_incomplete:
            incomplete_indicator = " ⚠️"
            linestyle, alpha, linewidth = "--", 0.7, 2.0
        else:
            incomplete_indicator = ""
            linestyle, alpha, linewidth = "-", 0.9, 2.5

        # Enhanced legend label with key information
        legend_label = (
            f"TE: {result['n_total_experts']}, LE: {result['n_local_experts']}, "
            f"TC: {result['n_total_clients']}, OF: {result['overlapping_factor']:.1f}, "
            f"EgBS: {result['experts_global_batch_size']},"
            f" nEgBS: {result['non_experts_global_batch_size']}{incomplete_indicator}"
        )

        # Marker shape changes every three curves.
        if i < 3:
            marker = "o"
        elif i < 6:
            marker = "s"
        else:
            marker = "^"

        plt.plot(
            tokens,
            perplexity,
            label=legend_label,
            linewidth=linewidth,
            alpha=alpha,
            color=colors[i],
            linestyle=linestyle,
            marker=marker,
            markersize=4,
            markevery=max(1, len(tokens) // 15),
            markerfacecolor="white",
            markeredgewidth=1.5,
            markeredgecolor=colors[i],
        )

    plt.xlabel("Total Tokens", fontsize=14, fontweight="bold")
    plt.ylabel("Language Perplexity", fontsize=14, fontweight="bold")
    plt.title(
        (
            "Training Progress: Perplexity vs Total Tokens\n"
            "Expert Density Analysis with Batch Size Configurations"
        ),
        fontsize=16,
        fontweight="bold",
        pad=20,
    )
    plt.yscale("log")
    legend_columns = 1 if len(this_cell_results_list) <= 4 else 2
    plt.legend(
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=legend_columns,
        fontsize=10,
        loc="upper right",
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
# Enhanced Expert Density Analysis with Independent, High-Quality Plots
# Plot 2: final perplexity against local experts, coloured by overlapping factor.
log.info("Generating enhanced Expert Density Analysis with improved cosmetics...")

this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if this_cell_results_list:
    # Quick token count check before plotting
    billion_tokens = 1e9
    incomplete_runs = [
        r for r in this_cell_results_list if r["total_tokens"] < billion_tokens
    ]

    if incomplete_runs:
        log.warning(
            "\n ATTENTION: %s of %s runs didn't reach 1B tokens!",
            len(incomplete_runs),
            len(this_cell_results_list),
        )
        log.info("   This may affect training convergence analysis.\n")

    # Set up enhanced plotting style
    plt.style.use("seaborn-v0_8-whitegrid")
    colors = plt.cm.Set1(  # pyright: ignore[reportAttributeAccessIssue]
        np.linspace(0, 1, len(this_cell_results_list)),
    )
    # ==== PLOT 2: Final Performance Scatter ====
    plt.figure(figsize=(12, 8))

    # Shared colour scale limits, identical for every point.
    overlap_min = min(r["overlapping_factor"] for r in this_cell_results_list)
    overlap_max = max(r["overlapping_factor"] for r in this_cell_results_list)

    # Create scatter plot with enhanced styling
    for result in this_cell_results_list:
        if result["total_tokens"] >= billion_tokens:
            marker = "o"
            alpha = 0.8
        else:
            marker = "^"
            alpha = 0.6
        if result["total_tokens"] < billion_tokens:
            edge_color = "darkred"
            edge_width = 2.5
        else:
            edge_color = "black"
            edge_width = 1.5

        scatter = plt.scatter(
            result["n_local_experts"],
            result["final_perplexity"],
            c=result["overlapping_factor"],
            s=200,
            alpha=alpha,
            cmap="viridis",
            edgecolors=edge_color,
            linewidth=edge_width,
            marker=marker,
            vmin=overlap_min,
            vmax=overlap_max,
        )

    # Enhanced annotations
    annotation_box = {
        "boxstyle": "round,pad=0.3",
        "facecolor": "white",
        "alpha": 0.8,
        "edgecolor": "gray",
    }
    annotation_arrow = {
        "arrowstyle": "->",
        "connectionstyle": "arc3,rad=0.1",
        "alpha": 0.6,
    }
    for result in this_cell_results_list:
        incomplete_indicator = ""
        if result["total_tokens"] < billion_tokens:
            incomplete_indicator = " ⚠️"
        annotation_text = (
            f"TE: {result['n_total_experts']}, LE:{result['n_local_experts']}, "
            f"TC: {result['n_total_clients']}, "
            f"EBS: {result['experts_global_batch_size']},"
            f" NEBS: {result['non_experts_global_batch_size']}, "
            f"{result['total_tokens'] / 1e9:.2f}B tokens{incomplete_indicator}"
        )
        plt.annotate(
            annotation_text,
            (result["n_local_experts"], result["final_perplexity"]),
            xytext=(8, 8),
            textcoords="offset points",
            fontsize=9,
            bbox=annotation_box,
            arrowprops=annotation_arrow,
        )

    plt.xlabel("Number of Local Experts", fontsize=14, fontweight="bold")
    plt.ylabel("Final Perplexity", fontsize=14, fontweight="bold")
    plt.title(
        (
            "Final Performance vs Expert Configuration\n"
            "Color = Overlapping Factor | Red edges = Incomplete runs (<1B tokens)"
        ),
        fontsize=16,
        fontweight="bold",
        pad=20,
    )

    # The colour bar is built from the last scatter; all share vmin/vmax.
    cbar = plt.colorbar(scatter, shrink=0.8, aspect=20)
    cbar.set_label("Overlapping Factor", fontsize=12, fontweight="bold")
    cbar.ax.tick_params(labelsize=10)

    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
# Enhanced Expert Density Analysis with Independent, High-Quality Plots
# Plot 3: final perplexity against total tokens, marker size by local experts.
log.info("Generating enhanced Expert Density Analysis with improved cosmetics...")

this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if this_cell_results_list:
    # Quick token count check before plotting
    billion_tokens = 1e9
    incomplete_runs = [
        r for r in this_cell_results_list if r["total_tokens"] < billion_tokens
    ]

    if incomplete_runs:
        log.warning(
            "\n ATTENTION: %s of %s runs didn't reach 1B tokens!",
            len(incomplete_runs),
            len(this_cell_results_list),
        )
        log.info("   This may affect training convergence analysis.\n")

    # Set up enhanced plotting style
    plt.style.use("seaborn-v0_8-whitegrid")
    colors = plt.cm.Set1(  # pyright: ignore[reportAttributeAccessIssue]
        np.linspace(0, 1, len(this_cell_results_list)),
    )
    # ==== PLOT 3: Training Efficiency Analysis ====
    plt.figure(figsize=(12, 8))

    for i, result in enumerate(this_cell_results_list):
        # Incomplete runs get a flag in the label and a red, thicker edge.
        if result["total_tokens"] < billion_tokens:
            incomplete_indicator = " ⚠️"
            alpha = 0.7
            edge_color = "darkred"
            edge_width = 2.5
        else:
            incomplete_indicator = ""
            alpha = 0.8
            edge_color = "black"
            edge_width = 1.5

        point_label = (
            f"TE: {result['n_total_experts']}, "
            f"LE: {result['n_local_experts']}, "
            f"TC: {result['n_total_clients']}, "
            f"OF: {result['overlapping_factor']:.1f}"
            f"{incomplete_indicator}"
        )
        # Size proportional to local experts
        point_size = result["n_local_experts"] * 80 + 120

        plt.scatter(
            result["total_tokens"] / 1e9,  # Convert to billions for readability
            result["final_perplexity"],
            s=point_size,
            alpha=alpha,
            color=colors[i],
            edgecolors=edge_color,
            linewidth=edge_width,
            label=point_label,
        )

    plt.xlabel("Total Tokens Consumed (Billions)", fontsize=14, fontweight="bold")
    plt.ylabel("Final Perplexity", fontsize=14, fontweight="bold")
    plt.title(
        (
            "Training Efficiency Analysis\n"
            "Marker Size ∝ Local Experts | Red edges = Incomplete runs"
        ),
        fontsize=16,
        fontweight="bold",
        pad=20,
    )

    efficiency_legend_title = (
        "TE: Total Experts, LE: Local Experts, TC: Total Clients,"
        " OF: Overlap Factor"
        "\n⚠️ = Incomplete runs (<1B tokens)"
    )
    plt.legend(
        title=efficiency_legend_title,
        frameon=True,
        fancybox=True,
        shadow=True,
        fontsize=10,
        title_fontsize=11,
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)
    plt.tight_layout()
    plt.show()

    # ==== PLOT 4: Convergence Rate Comparison ====
    # Perplexity against token count rescaled to [0, 1] per run, so curves of
    # different lengths can be compared on a common axis.
    plt.figure(figsize=(14, 8))

    for i, result in enumerate(this_cell_results_list):
        tokens = np.array(result["tokens"])
        perplexity = np.array(result["perplexity"])

        # A single point cannot be rescaled to a range; skip such runs.
        if len(tokens) <= 1:
            continue

        tokens_norm = (tokens - tokens.min()) / (tokens.max() - tokens.min())

        if result["total_tokens"] < billion_tokens:
            incomplete_indicator = " ⚠️"
            linestyle, alpha, linewidth = "--", 0.7, 2.0
        else:
            incomplete_indicator = ""
            linestyle, alpha, linewidth = "-", 0.9, 2.5

        curve_label = (
            f"TE: {result['n_total_experts']}, "
            f"LE: {result['n_local_experts']}, "
            f"TC: {result['n_total_clients']}, "
            f"OF: {result['overlapping_factor']:.1f}"
            f"{incomplete_indicator}"
        )

        # Marker shape changes every three curves.
        if i < 3:
            marker = "o"
        elif i < 6:
            marker = "s"
        else:
            marker = "^"

        plt.plot(
            tokens_norm,
            perplexity,
            label=curve_label,
            linewidth=linewidth,
            alpha=alpha,
            linestyle=linestyle,
            color=colors[i],
            marker=marker,
            markersize=4,
            markevery=max(1, len(tokens_norm) // 20),
            markerfacecolor="white",
            markeredgewidth=1.5,
            markeredgecolor=colors[i],
        )

    plt.xlabel(
        "Normalized Training Progress (0 = Start, 1 = End)",
        fontsize=14,
        fontweight="bold",
    )
    plt.ylabel("Language Perplexity", fontsize=14, fontweight="bold")
    plt.title(
        (
            "Convergence Rate Comparison\n"
            "Normalized Training Progress with Expert Configurations"
        ),
        fontsize=16,
        fontweight="bold",
        pad=20,
    )
    plt.yscale("log")
    convergence_legend_title = (
        "TE: Total Experts, LE: Local Experts, TC: Total Clients,"
        " OF: Overlap Factor"
        "\n ⚠️ = Incomplete runs (<1B tokens)"
    )
    plt.legend(
        title=convergence_legend_title,
        frameon=True,
        fancybox=True,
        shadow=True,
        fontsize=10,
        title_fontsize=11,
        ncol=1 if len(this_cell_results_list) <= 4 else 2,
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)
    plt.tight_layout()
    plt.show()

    # ==== PLOT 5: Configuration Summary Heatmap ====
    # One heatmap column per run and one row per configuration field. Cell colour
    # is the field value min-max normalised across runs; the raw value is printed
    # on top of every cell.
    plt.figure(figsize=(10, 6))

    # Row labels, in the same order as the per-run values collected below.
    row_labels = [
        "Total Experts",
        "Local Experts",
        "Overlap Factor",
        "Expert Global BS",
        "Non-Expert Global BS",
        "Final Perplexity",
        "Total Tokens (B)",
    ]
    # Rows whose annotation is printed as a whole number. "Overlap Factor" is
    # intentionally excluded: the rest of this notebook formats it with one
    # decimal, and an int() cast would truncate fractional factors (1.5 -> "1").
    integer_rows = {
        row_labels.index(name)
        for name in (
            "Total Experts",
            "Local Experts",
            "Expert Global BS",
            "Non-Expert Global BS",
        )
    }
    overlap_row = row_labels.index("Overlap Factor")

    # Create a summary matrix for visualization (shape: runs x fields)
    config_data = []
    config_labels = []

    for result in this_cell_results_list:
        config_data.append(
            [
                result["n_total_experts"],
                result["n_local_experts"],
                result["overlapping_factor"],
                result["experts_global_batch_size"],
                result["non_experts_global_batch_size"],
                result["final_perplexity"],
                result["total_tokens"] / 1e9,
            ],
        )
        # Runs below the token budget get a warning sign appended to their label
        incomplete_mark = "⚠️" if result["total_tokens"] < billion_tokens else ""
        config_labels.append(f"{result['run_uuid'][:8]}...{incomplete_mark}")

    config_matrix = np.array(config_data)

    # Normalize each field to [0, 1] across runs for better heatmap visualization
    config_matrix_norm = np.zeros_like(config_matrix, dtype=float)
    for field_idx, field_values in enumerate(config_matrix.T):
        col_min, col_max = field_values.min(), field_values.max()
        if col_max > col_min:
            config_matrix_norm[:, field_idx] = (field_values - col_min) / (
                col_max - col_min
            )
        else:
            # All runs share the same value: paint the row in the neutral colour
            config_matrix_norm[:, field_idx] = 0.5

    # Transposed so that runs lie along x and configuration fields along y
    im = plt.imshow(config_matrix_norm.T, cmap="RdYlBu_r", aspect="auto", alpha=0.8)

    # Set labels
    plt.xticks(range(len(config_labels)), config_labels, rotation=45, ha="right")
    plt.yticks(range(len(row_labels)), row_labels)

    # Add text annotations with actual values
    for run_idx, row_idx in np.ndindex(config_matrix.shape):
        value = config_matrix[run_idx, row_idx]
        if row_idx == overlap_row:
            text = f"{value:.1f}"
        elif row_idx in integer_rows:
            text = f"{int(value)}"
        else:
            text = f"{value:.2f}"
        plt.text(
            run_idx,
            row_idx,
            text,
            ha="center",
            va="center",
            # White text on cells from the upper half of the normalised range
            color="white" if config_matrix_norm[run_idx, row_idx] > 0.5 else "black",
            fontweight="bold",
            fontsize=9,
        )

    plt.title(
        (
            "Configuration Summary Heatmap\n"
            "Normalized values with actual values overlaid"
        ),
        fontsize=16,
        fontweight="bold",
        pad=20,
    )

    cbar = plt.colorbar(im, shrink=0.8, aspect=20)
    cbar.set_label(
        "Normalized Value (0 = Min, 1 = Max)",
        fontsize=12,
        fontweight="bold",
    )

    plt.tight_layout()
    plt.show()

    # Recap of the plots produced by this cell, one log record per line.
    for recap_line in (
        "✅ Enhanced analysis with improved cosmetics completed!",
        "📊 Generated 5 independent plots:",
        "   1. Training Progress (Perplexity vs Tokens)",
        "   2. Final Performance Scatter",
        "   3. Training Efficiency Analysis",
        "   4. Convergence Rate Comparison",
        "   5. Configuration Summary Heatmap",
    ):
        log.info(recap_line)

    # Remind the reader how many of the plotted runs were flagged as incomplete.
    if incomplete_runs:
        n_flagged_runs = len(incomplete_runs)
        log.warning(
            "⚠️  Note: %s run(s) marked as incomplete in all plots",
            n_flagged_runs,
        )

else:
    log.info("❌ No results available for analysis")

In [None]:
# Enhanced Expert Density Analysis - Throughput vs Steps Analysis
# Setup for the throughput dashboard: selects the runs to analyse, flags the
# ones below the 1B-token budget and assigns one colour per local-expert count.
log.info("Generating Throughput vs Steps Analysis with focus on Local Experts...")

# Use only the complete runs when the notebook-level switch asks for it.
this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if len(this_cell_results_list) > 0:
    # Quick token count check before plotting
    billion_tokens = 1e9
    incomplete_runs = [
        run for run in this_cell_results_list if run["total_tokens"] < billion_tokens
    ]

    n_incomplete = len(incomplete_runs)
    if n_incomplete > 0:
        log.warning(
            "\n ATTENTION: %s of %s runs didn't reach 1B tokens!",
            n_incomplete,
            len(this_cell_results_list),
        )
        log.info("   This may affect throughput analysis.\n")

    # Set up enhanced plotting style with consistent colors for local experts
    plt.style.use("seaborn-v0_8-whitegrid")

    # Create color mapping based on unique local experts: the Set1 colormap is
    # sampled at evenly spaced points, one sample per distinct expert count.
    unique_local_experts = sorted(
        {run["n_local_experts"] for run in this_cell_results_list},
    )
    n_unique_local_experts = len(unique_local_experts)
    color_map = plt.cm.Set1(  # pyright: ignore[reportAttributeAccessIssue]
        np.linspace(0, 1, n_unique_local_experts),
    )
    local_experts_color_dict = dict(
        zip(unique_local_experts, color_map, strict=True),
    )

    log.info(
        "Found %s unique local expert configurations: %s",
        n_unique_local_experts,
        unique_local_experts,
    )

    # ==== PLOT 6: Enhanced Throughput Analysis focused on Local Experts ====
    # Opens the figure that the 2x3 grid of throughput panels is drawn into.
    plt.figure(figsize=(18, 12))

    # Subplot 1: Raw throughput over training steps (colored by local experts)
    plt.subplot(2, 3, 1)
    fewest_local_experts = min(unique_local_experts)
    most_local_experts = max(unique_local_experts)
    for result in this_cell_results_list:
        steps = np.array(result["steps"])
        throughput = np.array(result["throughput"])
        if len(steps) == 0 or len(throughput) == 0:
            # Nothing to draw for a run with an empty step or throughput series
            continue

        # Color based on local experts
        n_le = result["n_local_experts"]
        color = local_experts_color_dict[n_le]

        # Incomplete runs are dashed, thinner, more transparent and carry a
        # warning sign in their legend entry.
        if result["total_tokens"] < billion_tokens:
            incomplete_indicator = " ⚠️"
            linestyle, alpha, linewidth = "--", 0.7, 2.0
        else:
            incomplete_indicator = ""
            linestyle, alpha, linewidth = "-", 0.9, 2.5

        # Enhanced legend label with key information
        legend_label = (
            f"LE: {n_le}, "
            f"TE: {result['n_total_experts']}, "
            f"TC: {result['n_total_clients']}, "
            f"OF: {result['overlapping_factor']:.1f}"
            f"{incomplete_indicator}"
        )

        # Circle for the smallest expert count, square for the largest and
        # triangle for everything in between (the smallest is tested first).
        if n_le == fewest_local_experts:
            line_marker = "o"
        elif n_le == most_local_experts:
            line_marker = "s"
        else:
            line_marker = "^"

        plt.plot(
            steps,
            throughput,
            label=legend_label,
            linewidth=linewidth,
            alpha=alpha,
            color=color,
            linestyle=linestyle,
            marker=line_marker,
            markersize=3,
            # Roughly 20 markers per curve, regardless of its length
            markevery=max(1, len(steps) // 20),
            markerfacecolor="white",
            markeredgewidth=1.2,
            markeredgecolor=color,
        )

    axis_label_font = {"fontsize": 12, "fontweight": "bold"}
    plt.xlabel("Training Steps", **axis_label_font)
    plt.ylabel("Device Throughput (tokens/s)", **axis_label_font)
    plt.title(
        "Device Throughput vs Training Steps\nGrouped by Number of Local Experts",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    plt.legend(
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=1,
        fontsize=9,
        loc="best",
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)

    # Calculate statistics by local experts.
    # throughput_stats maps each local-expert count to the list of per-run mean
    # throughputs; runs with an empty throughput series are left out.
    throughput_stats = {}
    for result in this_cell_results_list:
        throughput = np.array(result["throughput"])
        if len(throughput) == 0:
            continue
        avg_throughput = np.mean(throughput)
        throughput_stats.setdefault(result["n_local_experts"], []).append(
            avg_throughput,
        )

    # Calculate mean and std for each local expert configuration; the raw
    # per-run means are kept under "values" for the box plot and the log summary.
    local_experts_summary = {
        n_le: {
            "mean": np.mean(run_means),
            "std": np.std(run_means),
            "count": len(run_means),
            "values": run_means,
        }
        for n_le, run_means in throughput_stats.items()
    }

    # Subplot 2: Average throughput by local experts with error bars
    plt.subplot(2, 3, 2)
    local_experts_list = sorted(local_experts_summary.keys())
    means = [local_experts_summary[le]["mean"] for le in local_experts_list]
    stds = [local_experts_summary[le]["std"] for le in local_experts_list]
    colors = [local_experts_color_dict[le] for le in local_experts_list]
    # Bars sit at the expert count itself (e.g. x = 1, 2, 4, 8), not at 0..n-1,
    # so the value labels below must reuse exactly these x positions.
    bar_positions = [int(le) for le in local_experts_list]

    bars = plt.bar(
        bar_positions,
        means,
        yerr=stds,
        color=colors,
        alpha=0.8,
        edgecolor="black",
        linewidth=1.2,
        capsize=5,
        error_kw={"linewidth": 2, "ecolor": "black"},
    )

    # Add value labels on bars, slightly above the top of each error bar.
    # The gap is 2% of the tallest bar; guarded so an empty summary does not
    # raise on max().
    label_gap = max(means) * 0.02 if means else 0.0
    for le, x_pos, mean, std in zip(
        local_experts_list,
        bar_positions,
        means,
        stds,
        strict=True,
    ):
        count = local_experts_summary[le]["count"]
        plt.text(
            x_pos,
            mean + std + label_gap,
            f"{mean:.1f}±{std:.1f}\n(n={count})",
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10,
        )

    plt.xlabel("Number of Local Experts", fontsize=12, fontweight="bold")
    plt.ylabel("Average Throughput (tokens/s)", fontsize=12, fontweight="bold")
    plt.title(
        "Average Throughput by Local Experts\nError bars show ±1 standard deviation",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5, axis="y")

    # Subplot 3: Detailed scatter plot with individual runs
    plt.subplot(2, 3, 3)
    for result in this_cell_results_list:
        throughput = np.array(result["throughput"])
        if len(throughput) == 0:
            continue

        avg_throughput = np.mean(throughput)
        color = local_experts_color_dict[result["n_local_experts"]]

        # Marker/alpha key off "reached the budget", edge styling keys off
        # "below the budget"; the two tests are kept separate on purpose.
        reached_budget = result["total_tokens"] >= billion_tokens
        below_budget = result["total_tokens"] < billion_tokens
        marker, alpha = ("o", 0.8) if reached_budget else ("^", 0.6)
        edge_color, edge_width = ("darkred", 2.5) if below_budget else ("black", 1.5)

        plt.scatter(
            result["n_local_experts"],
            avg_throughput,
            c=[color],
            s=150,
            alpha=alpha,
            edgecolors=edge_color,
            linewidth=edge_width,
            marker=marker,
        )

    scatter_axis_font = {"fontsize": 12, "fontweight": "bold"}
    plt.xlabel("Number of Local Experts", **scatter_axis_font)
    plt.ylabel("Average Throughput (tokens/s)", **scatter_axis_font)
    plt.title(
        "Individual Run Throughput\nGrouped by Local Experts",
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5)

    # Pad the x-range by half a unit on each side of the smallest and largest
    # expert count (no jitter is applied to the points themselves).
    x_lower = min(unique_local_experts) - 0.5
    x_upper = max(unique_local_experts) + 0.5
    plt.xlim(x_lower, x_upper)

    # Subplot 4: Box plot showing distribution by local experts
    plt.subplot(2, 3, 4)
    # One box per local-expert count, built from the per-run mean throughputs.
    throughput_by_le = []
    box_tick_labels = []
    for n_le in local_experts_list:
        throughput_by_le.append(local_experts_summary[n_le]["values"])
        box_tick_labels.append(str(n_le))

    bp = plt.boxplot(
        throughput_by_le,
        tick_labels=box_tick_labels,
        patch_artist=True,
        notch=True,
        showmeans=True,
    )

    # Color the boxes according to local experts (same palette as the bar chart)
    for box_patch, box_color in zip(bp["boxes"], colors, strict=True):
        box_patch.set_facecolor(box_color)
        box_patch.set_alpha(0.7)

    box_axis_font = {"fontsize": 12, "fontweight": "bold"}
    plt.xlabel("Number of Local Experts", **box_axis_font)
    plt.ylabel("Average Throughput (tokens/s)", **box_axis_font)
    plt.title(
        (
            "Throughput Distribution by Local Experts\n"
            "Box plots with mean (triangle) and median (line)"
        ),
        fontsize=14,
        fontweight="bold",
        pad=15,
    )
    plt.grid(alpha=0.3, linestyle="-", linewidth=0.5, axis="y")

    # Subplot 5: Throughput coefficient of variation by local experts
    plt.subplot(2, 3, 5)
    # cv_by_le maps each local-expert count to the per-run coefficients of
    # variation (std / mean of the throughput series). Runs with fewer than two
    # samples are skipped; a non-positive mean yields a CV of 0.
    cv_by_le = {}
    for result in this_cell_results_list:
        le = result["n_local_experts"]
        throughput = np.array(result["throughput"])
        if len(throughput) > 1:
            mean_throughput = np.mean(throughput)
            cv = np.std(throughput) / mean_throughput if mean_throughput > 0 else 0
            cv_by_le.setdefault(le, []).append(cv)

    if cv_by_le:
        le_list = sorted(cv_by_le.keys())
        cv_means = [np.mean(cv_by_le[le]) for le in le_list]
        cv_stds = [
            np.std(cv_by_le[le]) if len(cv_by_le[le]) > 1 else 0 for le in le_list
        ]
        colors_cv = [local_experts_color_dict[le] for le in le_list]
        # Bars sit at the expert count itself, not at 0..n-1, so the value
        # labels below must reuse exactly these x positions.
        cv_bar_positions = [int(le) for le in le_list]

        bars = plt.bar(
            cv_bar_positions,
            cv_means,
            yerr=cv_stds,
            color=colors_cv,
            alpha=0.8,
            edgecolor="black",
            linewidth=1.2,
            capsize=5,
            error_kw={"linewidth": 2, "ecolor": "black"},
        )

        # Add value labels just above each error bar (gap: 2% of the tallest bar)
        cv_label_gap = max(cv_means) * 0.02
        for le, x_pos, mean, std in zip(
            le_list,
            cv_bar_positions,
            cv_means,
            cv_stds,
            strict=True,
        ):
            count = len(cv_by_le[le])
            plt.text(
                x_pos,
                mean + std + cv_label_gap,
                f"{mean:.3f}±{std:.3f}\n(n={count})",
                ha="center",
                va="bottom",
                fontweight="bold",
                fontsize=10,
            )

        plt.xlabel("Number of Local Experts", fontsize=12, fontweight="bold")
        plt.ylabel(
            "Throughput Coefficient of Variation",
            fontsize=12,
            fontweight="bold",
        )
        plt.title(
            "Throughput Stability by Local Experts\n(Lower CV = More Stable)",
            fontsize=14,
            fontweight="bold",
            pad=15,
        )
        plt.grid(alpha=0.3, linestyle="-", linewidth=0.5, axis="y")

    # Subplot 6: Summary statistics table (axes hidden, only the table is shown)
    plt.subplot(2, 3, 6)
    plt.axis("off")

    # Create summary table: one row per local-expert count, in ascending order
    headers = [
        "Local\nExperts",
        "Count",
        "Mean\nThroughput",
        "Std Dev",
        "CV Mean",
        "CV Std",
    ]
    summary_les = sorted(local_experts_summary.keys())

    table_data = []
    for le in summary_les:
        stats = local_experts_summary[le]
        # Default to an empty list so that groups without any CV sample are
        # reported as "N/A" instead of a made-up CV of 0.000.
        cv_stats = cv_by_le.get(le, [])
        table_data.append(
            [
                str(le),
                str(stats["count"]),
                f"{stats['mean']:.1f}",
                f"{stats['std']:.1f}",
                f"{np.mean(cv_stats):.3f}" if cv_stats else "N/A",
                # A spread needs at least two CV samples
                f"{np.std(cv_stats):.3f}" if len(cv_stats) > 1 else "N/A",
            ],
        )

    # Create table
    table = plt.table(
        cellText=table_data,
        colLabels=headers,
        cellLoc="center",
        loc="center",
        bbox=[0, 0.3, 1, 0.7],
    )
    table.auto_set_font_size(value=False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Color the rows according to local experts; table row 0 is the header,
    # so data rows start at index 1.
    for row_idx, le in enumerate(summary_les, start=1):
        color = local_experts_color_dict[le]
        for col_idx in range(len(headers)):
            cell = table[row_idx, col_idx]
            cell.set_facecolor(color)
            cell.set_alpha(0.3)

    plt.title(
        "Throughput Summary Statistics\nby Number of Local Experts",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    plt.tight_layout()
    plt.show()

    # Generate enhanced throughput analysis summary focused on local experts:
    # one logged section per local-expert count, followed by its member runs.
    log.info("\n=== Throughput Analysis Summary by Local Experts ===")

    for le in sorted(local_experts_summary.keys()):
        stats = local_experts_summary[le]
        cv_stats = cv_by_le.get(le, [])
        run_means = stats["values"]

        log.info("\n📊 %s Local Experts:", le)
        log.info("   Number of runs: %s", stats["count"])
        log.info(
            "   Average throughput: %.2f ± %.2f tokens/s",
            stats["mean"],
            stats["std"],
        )
        log.info(
            "   Range: %.2f - %.2f tokens/s",
            min(run_means),
            max(run_means),
        )
        # The stability line is only emitted when CV samples exist for the group
        if cv_stats:
            log.info(
                "   Stability (CV): %.4f ± %.4f",
                np.mean(cv_stats),
                np.std(cv_stats),
            )

        # Show individual run details
        log.info("   Individual runs:")
        for result in this_cell_results_list:
            if result["n_local_experts"] != le:
                continue
            throughput = np.array(result["throughput"])
            if len(throughput) == 0:
                continue

            avg_tp = np.mean(throughput)
            incomplete_mark = "⚠️" if result["total_tokens"] < billion_tokens else "✅"
            log.info(
                "     %s %s...: %.2f tokens/s (TE:%s, OF:%.1f)",
                incomplete_mark,
                result["run_uuid"][:12],
                avg_tp,
                result["n_total_experts"],
                result["overlapping_factor"],
            )

    # Statistical analysis
    if len(local_experts_summary) > 1:
        log.info("\n🔍 Statistical Analysis:")

        # Find best and worst performing local expert configurations
        best_le = max(
            local_experts_summary.keys(),
            key=lambda x: local_experts_summary[x]["mean"],
        )
        worst_le = min(
            local_experts_summary.keys(),
            key=lambda x: local_experts_summary[x]["mean"],
        )

        log.info(
            "   🚀 Best average throughput: %s local experts (%.2f tokens/s)",
            best_le,
            local_experts_summary[best_le]["mean"],
        )
        log.info(
            "   Worst average throughput: %s local experts (%.2f tokens/s)",
            worst_le,
            local_experts_summary[worst_le]["mean"],
        )

        # Performance difference
        perf_diff = (
            local_experts_summary[best_le]["mean"]
            - local_experts_summary[worst_le]["mean"]
        )
        perf_ratio = (
            local_experts_summary[best_le]["mean"]
            / local_experts_summary[worst_le]["mean"]
        )
        log.info(
            "   📈 Performance difference: %.2f tokens/s (%.2f x speedup)",
            perf_diff,
            perf_ratio,
        )

        # Most stable configuration
        if cv_by_le:
            most_stable_le = min(cv_by_le.keys(), key=lambda x: np.mean(cv_by_le[x]))
            log.info(
                "   📊 Most stable: %s local experts (CV: %.4f)",
                most_stable_le,
                np.mean(cv_by_le[most_stable_le]),
            )

    log.info("✅ Enhanced throughput analysis by local experts completed!")

else:
    log.info("❌ No results available for throughput analysis")

In [None]:
# Overview cell: logs the head of the results table and draws a 2x2 grid of
# scatter plots relating final perplexity to the configuration parameters.
this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if len(this_cell_results_list) > 0:
    # NOTE(review): everything below reads from results_df, which this cell does
    # not filter; whether it already honours EXCLUDE_INCOMPLETE_RUNS depends on
    # how it was built upstream - confirm.
    log.info("Results summary:")
    preview_columns = [
        "run_uuid",
        "n_total_clients",
        "global_train_batch_size",
        "n_local_experts",
        "overlapping_factor",
        "final_perplexity",
        "total_tokens",
    ]
    log.info(results_df[preview_columns].head())

    plt.figure(figsize=(12, 8))

    # One entry per panel, in subplot order:
    # (x column, y column, colour column, colormap,
    #  x label, y label, title, colorbar label)
    scatter_panels = [
        # Plot 1: Final Perplexity vs Experts Density (Local Experts)
        (
            "n_local_experts",
            "final_perplexity",
            "overlapping_factor",
            "viridis",
            "Experts Density (Local Experts per Client)",
            "Final Perplexity",
            "Final Perplexity vs Experts Density",
            "Overlapping Factor",
        ),
        # Plot 2: Final Perplexity vs Total Clients
        (
            "n_total_clients",
            "final_perplexity",
            "n_local_experts",
            "plasma",
            "Total Clients",
            "Final Perplexity",
            "Final Perplexity vs Total Clients",
            "Local Experts per Client",
        ),
        # Plot 3: Experts Density vs Overlapping Factor
        (
            "overlapping_factor",
            "n_local_experts",
            "final_perplexity",
            "coolwarm",
            "Overlapping Factor",
            "Experts Density",
            "Experts Density vs Overlapping Factor",
            "Final Perplexity",
        ),
        # Plot 4: Total Tokens vs Final Perplexity
        (
            "total_tokens",
            "final_perplexity",
            "n_local_experts",
            "tab10",
            "Total Tokens",
            "Final Perplexity",
            "Total Tokens vs Final Perplexity",
            "Experts Density",
        ),
    ]

    for panel_number, panel in enumerate(scatter_panels, start=1):
        (
            x_column,
            y_column,
            colour_column,
            colormap_name,
            x_label,
            y_label,
            panel_title,
            colorbar_label,
        ) = panel

        plt.subplot(2, 2, panel_number)
        scatter = plt.scatter(
            results_df[x_column],
            results_df[y_column],
            c=results_df[colour_column],
            cmap=colormap_name,
            s=100,
            alpha=0.7,
        )
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(panel_title)
        plt.colorbar(scatter, label=colorbar_label)
        plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

else:
    log.info("No results to plot. Check if data was successfully processed.")

In [None]:
# Plot detailed perplexity curves
# First half of a 2x2 figure: perplexity against raw token counts, and the same
# curves against token counts rescaled to [0, 1] per run.


this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if len(this_cell_results_list) > 0:
    plt.figure(figsize=(15, 10))

    # Plot 1: Perplexity vs Tokens for different expert densities
    plt.subplot(2, 2, 1)
    for result in this_cell_results_list:
        if not (result["tokens"] and result["perplexity"]):
            continue
        tokens = np.array(result["tokens"])
        perplexity = np.array(result["perplexity"])
        # Only plot non-NaN values
        valid_mask = ~np.isnan(perplexity)
        if not np.any(valid_mask):
            continue

        label = (
            f"UUID: {result['run_uuid'][:8]}..."
            f" (LE: {result['n_local_experts']},"
            f" TC: {result['n_total_clients']},"
            f" OF: {result['overlapping_factor']})"
        )
        plt.plot(
            tokens[valid_mask],
            perplexity[valid_mask],
            label=label,
            linewidth=2,
            alpha=0.8,
        )

    plt.xlabel("Total Tokens")
    plt.ylabel("Perplexity")
    plt.title("Perplexity vs Tokens by Expert Configuration")
    # Legend is anchored outside the axes, to the right of the panel
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.grid(alpha=0.3)
    plt.yscale("log")

    # Plot 2: Convergence comparison (normalized tokens)
    plt.subplot(2, 2, 2)
    for result in this_cell_results_list:
        if not (result["tokens"] and result["perplexity"]):
            continue
        tokens = np.array(result["tokens"])
        perplexity = np.array(result["perplexity"])
        valid_mask = ~np.isnan(perplexity)
        if not np.any(valid_mask) or len(tokens[valid_mask]) == 0:
            continue

        # Normalize tokens to [0, 1] for comparison; the small epsilon keeps
        # the division finite when all valid token counts are equal.
        valid_tokens = tokens[valid_mask]
        first_token_count = valid_tokens.min()
        tokens_norm = (valid_tokens - first_token_count) / (
            valid_tokens.max() - first_token_count + 1e-8
        )
        label = (
            f"LE: {result['n_local_experts']},"
            f" TC: {result['n_total_clients']},"
            f" OF: {result['overlapping_factor']}"
        )
        plt.plot(
            tokens_norm,
            perplexity[valid_mask],
            label=label,
            linewidth=2,
            alpha=0.8,
        )

    plt.xlabel("Normalized Training Progress")
    plt.ylabel("Perplexity")
    plt.title("Convergence Comparison (Normalized)")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.yscale("log")

    # Plot 3: Final performance summary
    plt.subplot(2, 2, 3)
    # Group by experts density for box plot: one box per distinct number of
    # local experts found in results_df, in ascending order.
    expert_densities = sorted(results_df["n_local_experts"].unique())
    perplexity_by_density = [
        results_df[results_df["n_local_experts"] == ed]["final_perplexity"].to_numpy()
        for ed in expert_densities
    ]

    # Label every box with its actual expert count; without explicit labels the
    # ticks would read 1..N (box positions) under an axis titled "Local Experts".
    plt.boxplot(
        perplexity_by_density,
        tick_labels=[str(ed) for ed in expert_densities],
    )
    plt.xlabel("Local Experts")
    plt.ylabel("Final Perplexity")
    plt.title("Final Perplexity Distribution by Local Experts")
    plt.xticks(rotation=45)
    plt.grid(alpha=0.3)

    # Plot 4: Efficiency analysis (final perplexity vs total tokens)
    plt.subplot(2, 2, 4)
    for result in this_cell_results_list:
        n_le = result["n_local_experts"]
        # Size proportional to expert density (with a constant base size)
        marker_area = n_le * 50 + 50
        point_label = (
            f"LE: {n_le}, "
            f"TC: {result['n_total_clients']}, "
            f"OF: {result['overlapping_factor']}"
        )
        plt.scatter(
            result["total_tokens"],
            result["final_perplexity"],
            s=marker_area,
            alpha=0.7,
            label=point_label,
        )

    plt.xlabel("Total Tokens")
    plt.ylabel("Final Perplexity")
    plt.title("Efficiency: Final Perplexity vs Training Cost\n(Size ∝ Local Experts)")
    plt.legend()
    plt.grid(alpha=0.3)

    # Last panel of the 2x2 figure: lay it out and render it
    plt.tight_layout()
    plt.show()

    # Summary statistics
    log.info("\n=== Summary Statistics ===")
    log.info("Number of configurations analyzed: %s", len(results_df))
    log.info(
        "Local experts range: %s - %s",
        results_df["n_local_experts"].min(),
        results_df["n_local_experts"].max(),
    )
    log.info(
        "Overlapping factor range: %s - %s",
        results_df["overlapping_factor"].min(),
        results_df["overlapping_factor"].max(),
    )
    log.info(
        "Final perplexity range: %s - %s",
        results_df["final_perplexity"].min(),
        results_df["final_perplexity"].max(),
    )
    log.info(
        "Best performing configuration (lowest perplexity): UUID %s",
        results_df.loc[results_df["final_perplexity"].idxmin(), "run_uuid"],
    )
    results_df["efficiency"] = (
        results_df["final_perplexity"] / results_df["total_tokens"]
    )
    log.info(
        "Most efficient configuration (lowest perplexity/token ratio): UUID %s",
        results_df.loc[results_df["efficiency"].idxmin(), "run_uuid"],
    )
else:
    log.info("No results available for plotting.")

In [None]:
# Create summary table and insights
# This first part logs a compact per-run table sorted by final perplexity.


this_cell_results_list = complete_runs if EXCLUDE_INCOMPLETE_RUNS else results_list

if len(this_cell_results_list) > 0:
    log.info("=== Expert Density Analysis Summary Table ===")

    # Columns pulled from results_df for the clean summary table
    summary_columns = [
        "run_uuid",
        "n_total_clients",
        "global_train_batch_size",
        "n_local_experts",
        "overlapping_factor",
        "n_total_experts",
        "final_perplexity",
        "total_tokens",
    ]
    # Column order used for display, chosen for readability
    display_columns = [
        "run_uuid_short",
        "n_total_clients",
        "n_total_experts",
        "n_local_experts",
        "overlapping_factor",
        "global_train_batch_size",
        "final_perplexity",
        "total_tokens",
    ]

    summary_df = (
        results_df[summary_columns]
        # The full UUID is replaced by its first 12 characters plus an ellipsis
        .assign(run_uuid_short=lambda frame: frame["run_uuid"].str[:12] + "...")
        .drop(columns="run_uuid")
        .loc[:, display_columns]
        # Sort by final perplexity for easy comparison (best run first)
        .sort_values("final_perplexity")
    )

    log.info(summary_df.to_string(index=False, float_format="%.4f"))

    log.info("\n=== Key Insights ===")

    # Best and worst performing configurations (lowest / highest final perplexity)
    best_idx = results_df["final_perplexity"].idxmin()
    worst_idx = results_df["final_perplexity"].idxmax()

    # Both runs are reported with the same four-line layout
    for section_heading, row_label in (
        ("🏆 Best Performance:", best_idx),
        ("\n📉 Worst Performance:", worst_idx),
    ):
        log.info(section_heading)
        log.info("   UUID: %s", results_df.loc[row_label, "run_uuid"])
        log.info("   Local Experts: %s", results_df.loc[row_label, "n_local_experts"])
        log.info(
            "   Overlapping Factor: %s",
            results_df.loc[row_label, "overlapping_factor"],
        )
        log.info(
            "   Final Perplexity: %.4f",
            results_df.loc[row_label, "final_perplexity"],
        )

    # Correlation analysis: pairwise correlations between the columns below
    log.info("\n🔍 Correlation Analysis:")
    correlation_columns = [
        "n_local_experts",
        "overlapping_factor",
        "n_total_clients",
        "global_train_batch_size",
        "final_perplexity",
        "total_tokens",
    ]
    correlations = results_df[correlation_columns].corr()

    # Strongest relationships first, regardless of sign
    perplexity_corr = correlations["final_perplexity"].sort_values(
        key=abs,
        ascending=False,
    )
    log.info("   Correlation with Final Perplexity:")
    for var, corr in perplexity_corr.items():
        if var == "final_perplexity":
            # Skip the self-correlation entry
            continue
        log.info("   - %s: %.3f", var, corr)

    # Efficiency analysis
    if "efficiency" in results_df.columns:
        most_efficient_idx = results_df["efficiency"].idxmin()
        log.info("\n⚡ Most Efficient Configuration:")
        log.info("   UUID: %s", results_df.loc[most_efficient_idx, "run_uuid"])
        log.info(
            "   Local Experts: %s",
            results_df.loc[most_efficient_idx, "n_local_experts"],
        )
        log.info(
            "   Overlapping Factor: %s",
            results_df.loc[most_efficient_idx, "overlapping_factor"],
        )
        log.info(
            "   Efficiency (Perplexity/Token): %.2e",
            results_df.loc[most_efficient_idx, "efficiency"],
        )

    # Local experts impact
    log.info("\n📊 Local Experts Impact:")
    experts_groups = results_df.groupby("n_local_experts")["final_perplexity"].agg(
        ["mean", "std", "count"],
    )
    log.info("   Average Final Perplexity by Local Experts:")
    for n_experts, stats in experts_groups.iterrows():
        log.info(
            "   - %d Local Experts: %.4f ± %.4f (n=%d)",
            n_experts,
            stats["mean"],
            stats["std"],
            int(stats["count"]),
        )

else:
    log.info("No results available for summary analysis.")