# μP Complete-P Scaling Analysis

This notebook analyzes the scaling behavior of dense and Mixture of Experts (MoE) models trained with μP (Maximal Update Parameterization) and the Complete-P depth-scaling methodology. It examines how the learning rate and batch size affect final perplexity across model widths, depths, and expert counts, and whether the best values stay stable as the models are scaled up.
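
The sketch below shows how a width multiplier and a depth multiplier turn a tuned base model into a scaled one. It is a simplified illustration, not the training code: it assumes the standard μP rule for Adam (hidden-weight learning rates divided by the width multiplier) and Complete-P with α = 1 (each residual branch's output divided by the depth multiplier, with no depth correction to hidden learning rates), as we understand them from the μP and CompleteP papers. `complete_p_scaling` and `ScaledHyperparams` are hypothetical names; the run's `mup_config` is the authoritative source for the exact rules. The relation `d_model = d_model_base * width_multiplier` is the inverse of how the notebook recovers `d_model_base` from each run's config.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScaledHyperparams:
    """Illustrative per-model quantities derived from a tuned base model."""

    d_model: int
    hidden_lr: float  # Adam learning rate for hidden (matrix-like) weights
    residual_branch_scale: float  # multiplier applied to each residual block output


def complete_p_scaling(
    d_model_base: int,
    base_lr: float,
    width_multiplier: int,
    depth_multiplier: int,
) -> ScaledHyperparams:
    # Assumed μP rule (Adam): hidden-weight learning rates shrink as 1 / width.
    # Assumed Complete-P rule (alpha = 1): residual branches scale as 1 / depth,
    # and the depth correction to hidden learning rates, depth ** (alpha - 1), is 1.
    return ScaledHyperparams(
        d_model=d_model_base * width_multiplier,
        hidden_lr=base_lr / width_multiplier,
        residual_branch_scale=1.0 / depth_multiplier,
    )


scaled = complete_p_scaling(
    d_model_base=128, base_lr=3e-3, width_multiplier=4, depth_multiplier=2
)
print(scaled)
```

Under these assumptions the base learning rate stored in the config is what should transfer across multipliers, which is what the plots in this notebook check.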

## Overview

The notebook analyzes training runs from the `camlsys/photon` Weights & Biases project, examining:

- **Model Architectures**: Dense models (MPT variants) and MoE models with 4 or 8 experts
- **Scaling Dimensions**: Width multipliers, depth multipliers, and base model dimensions
- **Training Hyperparameters**: Learning rates and batch sizes
- **Model Configurations**: Use of Peri-LN, embedding normalization, and bias settings
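
All of these dimensions are read from each run's nested W&B config. The sketch below flattens them the same way the notebook's helpers do, using the config keys that appear in the code cells; `describe_run` and the toy config values are hypothetical. As in `_get_n_experts`, a run without `ff_n_experts` is treated as dense (1 expert).

```python
from typing import Any


def describe_run(config: dict[str, Any]) -> dict[str, Any]:
    """Flatten the fields this notebook groups on from a nested run config."""
    model = config["llm_config"]["model"]
    mup = model["mup_config"]
    # Dense runs have no ffn_config.ff_n_experts entry; count them as 1 expert.
    ffn = model.get("ffn_config") or {}
    n_experts = ffn.get("ff_n_experts", 1)
    return {
        # The base width is the actual width divided by the μP width multiplier.
        "d_model_base": model["d_model"] // mup["mup_width_multiplier"],
        "n_total_experts": n_experts,
        "depth_multiplier": mup["completep_depth_multiplier"],
        "width_multiplier": mup["mup_width_multiplier"],
        "peri_ln": model["use_peri_norm"],
        "embedding_ln": model["use_embedding_norm"],
        "no_bias": model["no_bias"],
    }


# Hypothetical config for a 4-expert MoE run at 2x width and 2x depth.
toy_config = {
    "llm_config": {
        "model": {
            "d_model": 512,
            "ffn_config": {"ff_n_experts": 4},
            "mup_config": {"mup_width_multiplier": 2, "completep_depth_multiplier": 2},
            "use_peri_norm": True,
            "use_embedding_norm": False,
            "no_bias": True,
        }
    }
}
print(describe_run(toy_config))
```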

## Key Analysis Components

1. **Data Collection**: Downloads training metrics from Weights & Biases for runs whose display names match a list of regular expressions
2. **Parameter Counting**: Computes total trainable, expert, non-expert, embedding, and non-embedding parameter counts
3. **Performance Evaluation**: Extracts perplexity-versus-tokens curves for each training run
4. **Hyperparameter Selection**: Identifies, for each configuration, the learning rate and batch size with the lowest mean final perplexity
5. **Comparative Analysis**: Groups results by base model configuration and scaling multipliers
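
The selection of the best learning rate per multiplier configuration can be sketched with a pandas `groupby`: average the final perplexity of repeated runs at each learning rate, then keep the learning rate with the lowest mean. The sweep below is hypothetical data; the notebook does the equivalent with explicit loops before plotting.

```python
import pandas as pd

# Hypothetical sweep: two multiplier configurations, three learning rates,
# with a repeated run at one learning rate.
sweep = pd.DataFrame(
    {
        "depth_multiplier": [1, 1, 1, 1, 2, 2, 2],
        "width_multiplier": [1, 1, 1, 1, 2, 2, 2],
        "learning_rate": [1e-3, 3e-3, 3e-3, 1e-2, 1e-3, 3e-3, 1e-2],
        "final_perplexity": [31.0, 28.5, 28.9, 30.2, 27.4, 25.9, 26.8],
    }
)

# Mean, std, and number of runs per learning rate; std is NaN for single runs.
stats = (
    sweep.groupby(["depth_multiplier", "width_multiplier", "learning_rate"])[
        "final_perplexity"
    ]
    .agg(["mean", "std", "count"])
    .reset_index()
)

# For each multiplier configuration, keep the row with the lowest mean perplexity.
best = stats.loc[
    stats.groupby(["depth_multiplier", "width_multiplier"])["mean"].idxmin()
]
print(best[["depth_multiplier", "width_multiplier", "learning_rate", "mean"]])
```

If μP and Complete-P transfer as intended, the selected learning rate should stay roughly the same across multiplier configurations, as it does in this toy data.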

## Visualizations

The notebook generates comparative plots showing:
- **Mean Final Perplexity vs Learning Rate**: Mean ± standard deviation of final perplexity at each learning rate, with the best learning rate of each scaling configuration marked by a dashed vertical line
- **Mean Final Perplexity vs Batch Size**: The same view for the global training batch size, marking the best batch size of each scaling configuration

Results are grouped by base model configuration (d_model_base, number of experts, normalization and bias settings), with one line per depth/width multiplier combination.
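
Each panel described above reduces to a log-scaled error-bar plot plus a marker at the minimum. The sketch below draws one such line from hypothetical statistics. It places the value label with `ax.get_xaxis_transform()`, a blended transform whose x coordinate is in data units and whose y coordinate is in axes units, so the label stays inside the panel regardless of the perplexity range.

```python
import io

import matplotlib

matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical per-learning-rate statistics for one multiplier configuration.
lrs = np.array([1e-3, 3e-3, 1e-2])
mean_ppl = np.array([31.0, 28.7, 30.2])
std_ppl = np.array([0.0, 0.28, 0.0])

fig, ax = plt.subplots(figsize=(6, 4))
ax.errorbar(
    lrs, mean_ppl, yerr=std_ppl, fmt="o-", capsize=5, label="depth x 1, width x 1"
)

# Mark the learning rate with the lowest mean perplexity.
best_lr = lrs[np.argmin(mean_ppl)]
ax.axvline(best_lr, linestyle="--", alpha=0.7)
# x in data coordinates, y in axes coordinates (0 = bottom, 1 = top).
ax.text(
    best_lr,
    0.95,
    f"{best_lr:.1e}",
    transform=ax.get_xaxis_transform(),
    rotation=90,
    verticalalignment="top",
    horizontalalignment="right",
)
ax.set_xscale("log")
ax.set_xlabel("Learning Rate")
ax.set_ylabel("Mean Final Perplexity ± Std")
ax.legend()

# Render to memory instead of a file.
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close(fig)
```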

## Key Metrics

- **Final Perplexity**: Primary performance metric, taken as the last logged perplexity of each run
- **Total Tokens**: Number of training tokens seen; runs with fewer than 1B tokens are excluded when `EXCLUDE_INCOMPLETE_RUNS` is `True`
- **Parameter Counts**: Breakdown of trainable, expert, and non-expert parameters
- **Hyperparameter Ranges**: Learning rates, batch sizes, and scaling multipliers tested
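
Reducing a run to these metrics is a few lines of code. This sketch assumes, as the notebook does when it reads the last entry as the total, that the token axis is cumulative; `summarize_curve` is a hypothetical helper and the curves in the usage example are made up.

```python
BILLION_TOKENS = 1e9  # completeness threshold used by the notebook


def summarize_curve(tokens: list[float], perplexity: list[float]) -> dict:
    """Reduce a perplexity-vs-tokens curve to the metrics used in the analysis."""
    if not tokens or len(tokens) != len(perplexity):
        raise ValueError("need non-empty, equal-length token and perplexity curves")
    # Assumes cumulative token counts, so the last entry is the total seen.
    total_tokens = tokens[-1]
    return {
        "final_perplexity": perplexity[-1],
        "total_tokens": total_tokens,
        "n_data_points": len(tokens),
        "complete": total_tokens >= BILLION_TOKENS,
    }


print(summarize_curve([2.5e8, 5e8, 1e9], [40.0, 33.0, 30.5]))
print(summarize_curve([1e8, 2e8], [50.0, 45.0]))
```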

This analysis tests whether optimal learning rates and batch sizes transfer across scaled dense and MoE models under the μP Complete-P methodology, to inform scaling strategies for MoE models.

In [None]:
import logging
import operator

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from fedmoe_plots.data_analysis import ColumnNotFoundError, get_perplexity_versus_tokens
from fedmoe_plots.parameter_counting import compute_parameter_counts
from fedmoe_plots.plotting_utils import configure_logging_for_jupyter
from fedmoe_plots.wandb_utils import (
    download_wandb_whole_history,
    get_clientrun_property_from_config,
    get_run_uuid_from_config,
)

configure_logging_for_jupyter()

log = logging.getLogger("mup_completep_scaling.ipynb")

EXCLUDE_INCOMPLETE_RUNS = True

In [None]:
FOUR_EXPERTS_RUNS = [
    "(^tune-sigma_moe_4e_44m)",
    "(^tune-deptoe_4e_256)",
    "(^deptoe_4e_44m_mpt-base)",
    "(^tune-deptoe_4e_128)",
    "(^tune-deptoe_4e_depth)",
    "(^tune-deptoe_4e_max)",
    "(^tune-deptoe_4e_width)",
]

EIGHT_EXPERTS_RUNS = [
    "(^tune-deptoe_8e_128)",
    "(^tune-deptoe_8e_256)",
    "(^tune-deptoe_8e_base)",
    "(^tune-deptoe_8e_mpt-base)",
]

RUNS_REGEX = [
    "(^tune-mpt_18m)",
    "(^tune-mpt_1B)",
    "(^tune-mpt_200m)",
    "(^tune-mpt_3m)",
    "(^tune-peri_mpt_18m)",
    "(^tune-peri_mpt_200m)",
    "(^tune-peri_mpt_3m)",
    *FOUR_EXPERTS_RUNS,
    *EIGHT_EXPERTS_RUNS,
]
RUNS_NO_PERI_LN = [
    "(^tune-mpt_3m)",
    "(^tune-mpt_18m)",
    "(^tune-mpt_200m)",
    "(^tune-mpt_1B)",
]
RUNS_W_PERI_LN = set(RUNS_REGEX) - set(RUNS_NO_PERI_LN)
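
The run patterns above are anchored prefixes joined with `|` into a single `$regex` filter. The sketch uses Python's `re` as a stand-in for the server-side matching, assuming the two behave the same for patterns this simple; the display names are hypothetical. It shows that an anchored pattern rejects names with a different prefix but still matches longer names that merely start with it, such as `tune-deptoe_4e_1280`.

```python
import re

RUN_PATTERNS = ["(^tune-mpt_3m)", "(^tune-peri_mpt_3m)", "(^tune-deptoe_4e_128)"]
combined = re.compile("|".join(RUN_PATTERNS))

# Hypothetical display names.
names = [
    "tune-mpt_3m_lr3e-3",  # matches (^tune-mpt_3m)
    "tune-peri_mpt_3m_bs64",  # matches (^tune-peri_mpt_3m) only
    "tune-deptoe_4e_1280",  # prefix match: (^tune-deptoe_4e_128) also accepts this
    "sweep-tune-mpt_3m",  # rejected: ^ requires the name to start with the prefix
]
matched = [name for name in names if combined.search(name)]
print(matched)
```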

In [None]:
api = wandb.Api(timeout=100)
full_runs_regex = "|".join(RUNS_REGEX)
runs = api.runs(
    path="camlsys/photon",
    filters={"display_name": {"$regex": full_runs_regex}},
)
log.info("Found %d runs matching the regex: %s", len(runs), full_runs_regex)

In [None]:
def _get_n_experts(config: dict) -> int:
    try:
        n_experts = config["llm_config"]["model"]["ffn_config"]["ff_n_experts"]
        assert n_experts is not None, "ff_n_experts should not be None"
        assert isinstance(
            n_experts,
            int,
        ), f"ff_n_experts should be an integer and not of type {type(n_experts)}"

    except KeyError:
        return 1
    else:
        return n_experts


def _get_d_model_base(config: dict) -> int:
    d_model = config["llm_config"]["model"]["d_model"]
    assert d_model is not None, "d_model should not be None"
    assert isinstance(d_model, int), "d_model should be an integer"
    width_multiplier = config["llm_config"]["model"]["mup_config"][
        "mup_width_multiplier"
    ]
    assert width_multiplier is not None, "mup_width_multiplier should not be None"
    assert isinstance(
        width_multiplier,
        int,
    ), "mup_width_multiplier should be an integer"
    return d_model // width_multiplier

In [None]:
unique_depth_multiplier = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["model"]["mup_config"][
            "completep_depth_multiplier"
        ],
    )
    for run in runs
}
unique_width_multiplier = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["model"]["mup_config"][
            "mup_width_multiplier"
        ],
    )
    for run in runs
}
unique_lr = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["optimizer"]["lr"],
    )
    for run in runs
}
unique_batch_size = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["global_train_batch_size"],
    )
    for run in runs
}
unique_n_experts = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=_get_n_experts,
    )
    for run in runs
}
unique_peri_ln = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["model"]["use_peri_norm"],
    )
    for run in runs
}
unique_embedding_ln = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["model"][
            "use_embedding_norm"
        ],
    )
    for run in runs
}
unique_no_bias = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=lambda config: config["llm_config"]["model"]["no_bias"],
    )
    for run in runs
}
unique_d_model_base = {
    get_clientrun_property_from_config(
        run,
        get_property_fn=_get_d_model_base,
    )
    for run in runs
}
log.info(
    "Unique depth multipliers: %s",
    ", ".join(str(dm) for dm in unique_depth_multiplier if dm is not None),
)
log.info(
    "Unique width multipliers: %s",
    ", ".join(str(wm) for wm in unique_width_multiplier if wm is not None),
)
log.info(
    "Unique learning rates: %s",
    ", ".join(str(lr) for lr in unique_lr if lr is not None),
)
log.info(
    "Unique batch sizes: %s",
    ", ".join(str(bs) for bs in unique_batch_size if bs is not None),
)
log.info(
    "Unique number of experts: %s",
    ", ".join(str(ne) for ne in unique_n_experts if ne is not None),
)
log.info(
    "Unique use_peri_norm: %s",
    ", ".join(str(p) for p in unique_peri_ln if p is not None),
)
log.info(
    "Unique use_embedding_norm: %s",
    ", ".join(str(en) for en in unique_embedding_ln if en is not None),
)
log.info(
    "Unique no_bias: %s",
    ", ".join(str(nb) for nb in unique_no_bias if nb is not None),
)
log.info(
    "Unique d_model_base: %s",
    ", ".join(str(dm) for dm in unique_d_model_base if dm is not None),
)

In [None]:
# Data collection: extract each run's configuration, parameter counts, and
# perplexity-versus-tokens curve

log.info("🔄 Starting data collection and processing...")
log.info("Found %s runs to process", len(runs))

results_list = []

for i, run in enumerate(runs):
    log.info(
        "📊 Processing run %s/%s: %s",
        i + 1,
        len(runs),
        get_run_uuid_from_config(run),
    )
    try:
        # Extract configuration parameters (depth multiplier, width multiplier,
        # learning rate, batch size, number of experts, hidden dimension of the model,
        # run UUID)
        run_uuid = get_run_uuid_from_config(run)
        depth_multiplier = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"]["mup_config"][
                "completep_depth_multiplier"
            ],
        )
        width_multiplier = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"]["mup_config"][
                "mup_width_multiplier"
            ],
        )
        learning_rate = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["optimizer"]["lr"],
        )
        global_train_batch_size = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"][
                "global_train_batch_size"
            ],
        )
        n_total_experts = get_clientrun_property_from_config(
            run,
            get_property_fn=_get_n_experts,
        )
        d_model_base = get_clientrun_property_from_config(
            run,
            get_property_fn=_get_d_model_base,
        )
        peri_ln = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"][
                "use_peri_norm"
            ],
        )
        embedding_ln = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"][
                "use_embedding_norm"
            ],
        )
        no_bias = get_clientrun_property_from_config(
            run,
            get_property_fn=lambda config: config["llm_config"]["model"]["no_bias"],
        )

        # Calculate derived metrics (number of total trainable parameters, number of
        # non-expert parameters, number of expert parameters, number of non embedding
        # parameters, number of embedding parameters)
        run_config = run.config
        assert run_config is not None, "Run config must not be None"
        assert isinstance(
            run_config,
            dict,
        ), f"Run config must be a dictionary, not {type(run_config)}"
        llm_config = run_config.get("llm_config", {})
        assert llm_config is not None, "Model config must not be None"
        assert isinstance(
            llm_config,
            dict,
        ), f"Model config must be a dictionary, not {type(llm_config)}"
        parameter_counts = compute_parameter_counts(llm_config)
        n_trainable_parameters = parameter_counts.n_trainable
        n_non_expert_parameters = parameter_counts.n_non_experts
        n_expert_parameters = parameter_counts.n_experts
        n_non_embedding_parameters = parameter_counts.n_non_embedding
        n_embedding_parameters = parameter_counts.n_embedding

        # Download and process data
        log.info("🔽 Downloading data for %s...", run_uuid)

        try:
            # Download photon metrics for the run
            assert run_uuid is not None, "Run UUID must not be None"
            assert isinstance(
                run_uuid,
                str,
            ), f"Run UUID must be a string not a {type(run_uuid)}"
            run_df = download_wandb_whole_history(
                run=run,
            )

            # Try to get perplexity vs tokens data
            tokens, perplexity = get_perplexity_versus_tokens(
                client_metrics_df=run_df,
                n_clients_per_round=1,
            )

            if len(tokens) == 0 or len(perplexity) == 0:
                log.warning("❌ Empty data arrays for %s", run_uuid)
                continue

            # Calculate metrics
            # Both series are non-empty here, so take the last logged values
            final_perplexity = perplexity.iloc[-1]
            total_tokens = tokens.iloc[-1]
            n_data_points = len(tokens)

            log.info(
                (
                    "   ✅ Data processed: %d points,"
                    " Final perplexity: %.4f, Total tokens: %.0f"
                ),
                n_data_points,
                final_perplexity,
                total_tokens,
            )

            # Store results
            results_list.append(
                {
                    "run_uuid": str(run_uuid),
                    "d_model_base": d_model_base,
                    "n_total_experts": n_total_experts,
                    "depth_multiplier": depth_multiplier,
                    "width_multiplier": width_multiplier,
                    "peri_ln": peri_ln,
                    "embedding_ln": embedding_ln,
                    "no_bias": no_bias,
                    "learning_rate": learning_rate,
                    "global_train_batch_size": global_train_batch_size,
                    "n_trainable_parameters": n_trainable_parameters,
                    "n_non_expert_parameters": n_non_expert_parameters,
                    "n_expert_parameters": n_expert_parameters,
                    "n_non_embedding_parameters": n_non_embedding_parameters,
                    "n_embedding_parameters": n_embedding_parameters,
                    "final_perplexity": final_perplexity,
                    "total_tokens": total_tokens,
                    "n_data_points": n_data_points,
                    "tokens": (
                        tokens.tolist() if hasattr(tokens, "tolist") else list(tokens)
                    ),
                    "perplexity": (
                        perplexity.tolist()
                        if hasattr(perplexity, "tolist")
                        else list(perplexity)
                    ),
                },
            )

        except ColumnNotFoundError:
            log.exception(
                "   ⚠️ Column not found in client metrics DataFrame for %s. "
                "We assume this run crashed and doesn't have the expected data.",
                run_uuid,
                stack_info=True,
            )

            # Permanently delete the crashed run from W&B (this cannot be undone) so
            # later queries skip it; runs that are still running are left untouched
            if run.state != "running":
                log.info(
                    "Removing run %s with display name %s.",
                    run.id,
                    run.display_name,
                )
                run.delete()
            continue

        except Exception:
            log.exception("   ❌ Error processing %s", run_uuid, stack_info=True)
            continue

    except Exception:
        assert run is not None, "Run must not be None"
        log.exception(
            "   ❌ Error extracting config for run %s",
            run.name,
            stack_info=True,
        )
        continue

log.info("\n✅ Data collection completed!")
log.info("Successfully processed %d out of %d runs", len(results_list), len(runs))

In [None]:
log.info("\n=== Summary ===")
log.info(
    "Successfully processed %d out of %d runs",
    len(results_list),
    len(runs),
)

incomplete_runs = []
complete_runs = []

# Check for runs that didn't reach 1 billion tokens and log warnings
if results_list:
    log.info("\n⚠️  TOKEN COUNT ANALYSIS ⚠️")
    log.info("=" * 60)

    billion_tokens = 1e9

    for result in results_list:
        if result["total_tokens"] < billion_tokens:
            incomplete_runs.append(result)
        else:
            complete_runs.append(result)

    if incomplete_runs:
        log.warning(
            "🔴 WARNING: %d run(s) did NOT reach 1 billion tokens:",
            len(incomplete_runs),
        )
        log.info("-" * 60)

        for i, result in enumerate(incomplete_runs, 1):
            log.info("\n%s. Run UUID: %s", i, result["run_uuid"])
            log.info(
                "   Total Tokens: %s (%.3fB)",
                format(result["total_tokens"], ","),
                result["total_tokens"] / 1e9,
            )
            log.info(
                "   Completion: %.1f%% of 1B tokens",
                result["total_tokens"] / billion_tokens * 100,
            )
            log.info("   📋 Full Configuration:")
            log.info("      • Run UUID: %s", result["run_uuid"])
            log.info(
                "      • Hidden dimension of the base model: %s",
                result["d_model_base"],
            )
            log.info("      • Use Peri-LN: %s", result["peri_ln"])
            log.info("      • Use Embedding-LN: %s", result["embedding_ln"])
            log.info("      • No Bias: %s", result["no_bias"])
            log.info("      • Total Experts: %s", result["n_total_experts"])
            log.info("      • Depth Multiplier: %s", result["depth_multiplier"])
            log.info("      • Width Multiplier: %s", result["width_multiplier"])
            log.info("      • Learning Rate: %s", result["learning_rate"])
            log.info(
                "      • Global Train Batch Size: %s",
                result["global_train_batch_size"],
            )
            log.info(
                "      • Trainable Parameters: %s",
                format(result["n_trainable_parameters"], ","),
            )
            log.info(
                "      • Non-Expert Parameters: %s",
                format(result["n_non_expert_parameters"], ","),
            )
            log.info(
                "      • Expert Parameters: %s",
                format(result["n_expert_parameters"], ","),
            )
            log.info("      • Final Perplexity: %.4f", result["final_perplexity"])
            log.info("      • Data Points: %s", result["n_data_points"])

    if complete_runs:
        log.info(
            "✅ %d run(s) successfully reached 1+ billion tokens:",
            len(complete_runs),
        )
        for result in complete_runs:
            log.info(
                "   • %s: %d tokens (%.3fB)",
                result["run_uuid"],
                result["total_tokens"],
                result["total_tokens"] / 1e9,
            )

    log.info("\n📊 Token Count Summary:")
    log.info("   • Complete runs (≥1B tokens): %d", len(complete_runs))
    log.info("   • Incomplete runs (<1B tokens): %d", len(incomplete_runs))
    avg_tokens = sum(r["total_tokens"] for r in results_list) / len(results_list)
    log.info(
        "   • Average tokens across all runs: %.0f (%.3fB)",
        avg_tokens,
        avg_tokens / 1e9,
    )

this_cell_results_list = results_list
if EXCLUDE_INCOMPLETE_RUNS:
    this_cell_results_list = complete_runs

# Convert to DataFrame for easier analysis
if this_cell_results_list:
    results_df = pd.DataFrame(this_cell_results_list)
    log.info("\nResults DataFrame shape: %s", results_df.shape)
    log.info("\nConfiguration summary:")
    summary_cols = [
        "run_uuid",
        "d_model_base",
        "peri_ln",
        "embedding_ln",
        "no_bias",
        "n_total_experts",
        "depth_multiplier",
        "width_multiplier",
        "learning_rate",
        "global_train_batch_size",
        "final_perplexity",
        "n_trainable_parameters",
    ]
    log.info(results_df[summary_cols].to_string(index=False, float_format="%.4f"))
else:
    log.info("No results to analyze")
    results_df = pd.DataFrame()
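
Before plotting, the next cell drops outliers within each base group using the 1.5 × IQR rule (Tukey's fences) on final perplexity. The sketch below applies the same rule to hypothetical values; `drop_iqr_outliers` is a hypothetical helper. Because the filter is applied to the whole base group, a learning rate whose runs all diverged can disappear from the plot entirely, which flattens the edges of the sweep.

```python
import pandas as pd


def drop_iqr_outliers(values: pd.Series, k: float = 1.5) -> pd.Series:
    """Keep values inside [Q1 - k * IQR, Q3 + k * IQR]."""
    q1, q3 = values.quantile(0.25), values.quantile(0.75)
    iqr = q3 - q1
    return values[(values >= q1 - k * iqr) & (values <= q3 + k * iqr)]


# Hypothetical final perplexities with one diverged run at 250.
ppl = pd.Series([27.1, 27.4, 28.0, 28.3, 29.0, 250.0])
kept = drop_iqr_outliers(ppl)
print(kept.tolist())
```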

In [None]:
# Group the results by base configuration (d_model_base, n_experts, peri_ln,
# embedding_ln, no_bias) and then by depth and width multiplier. For each base
# group, plot the mean final perplexity (± std) versus the learning rate and versus
# the batch size, drawing one line per multiplier configuration.

if results_df.empty:
    log.warning("No results to analyze - results_df is empty")
else:
    log.info("🔍 Starting analysis of grouped results...")

    # Group by base configuration (width, experts, normalization, bias)
    base_groups = results_df.groupby(
        [
            "d_model_base",
            "n_total_experts",
            "peri_ln",
            "embedding_ln",
            "no_bias",
        ],
    )

    log.info("Found %d base groups", len(base_groups))

    # Replicate analysis: count configurations that were run more than once
    log.info("\n📊 REPLICATE ANALYSIS:")
    log.info("=" * 60)

    # Check for duplicate configurations across the entire dataset
    all_config_groups = results_df.groupby(
        [
            "d_model_base",
            "n_total_experts",
            "peri_ln",
            "embedding_ln",
            "no_bias",
            "depth_multiplier",
            "width_multiplier",
            "learning_rate",
            "global_train_batch_size",
        ],
    )

    total_duplicate_configs = 0
    total_unique_configs = len(all_config_groups)

    for _config, config_df in all_config_groups:
        if len(config_df) > 1:
            total_duplicate_configs += 1

    log.info("Total unique configurations: %d", total_unique_configs)
    log.info(
        "Configurations with multiple runs: %d (%.1f%%)",
        total_duplicate_configs,
        (
            total_duplicate_configs / total_unique_configs * 100
            if total_unique_configs > 0
            else 0
        ),
    )
    log.info(
        "Single-run configurations: %d (%.1f%%)",
        total_unique_configs - total_duplicate_configs,
        (
            (total_unique_configs - total_duplicate_configs)
            / total_unique_configs
            * 100
            if total_unique_configs > 0
            else 0
        ),
    )

    if total_duplicate_configs > 0:
        log.info("✅ Statistical analysis will show mean ± std for multiple runs")
        log.info("   Error bars represent standard deviation across multiple runs")
    else:
        log.info("   All configurations are single runs - no error bars will be shown")
    log.info("=" * 60)

    # Two rows (learning rate, batch size) and one 8-inch-wide column per base group.
    # squeeze=False keeps `axes` 2D even when there is a single base group.
    fig, axes = plt.subplots(
        2,
        len(base_groups),
        figsize=(8 * len(base_groups), 12),
        squeeze=False,
    )

    for group_idx, (
        (d_model_base_val, n_experts_val, peri_ln_val, embedding_ln_val, no_bias_val),
        base_group_df,
    ) in enumerate(
        base_groups,
    ):
        log.info(
            (
                "📊 Processing base configuration: d_model_base=%s, n_experts=%s,"
                " peri_ln=%s, embedding_ln=%s, no_bias=%s"
            ),
            d_model_base_val,
            n_experts_val,
            peri_ln_val,
            embedding_ln_val,
            no_bias_val,
        )

        # Exclude outliers using IQR method
        Q1 = base_group_df["final_perplexity"].quantile(0.25)
        Q3 = base_group_df["final_perplexity"].quantile(0.75)
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        # Filter out outliers
        base_group_df_filtered = base_group_df[
            (base_group_df["final_perplexity"] >= lower_bound)
            & (base_group_df["final_perplexity"] <= upper_bound)
        ]

        log.info(
            "   Excluded %d outliers (%.1f%% of data)",
            len(base_group_df) - len(base_group_df_filtered),
            (len(base_group_df) - len(base_group_df_filtered))
            / len(base_group_df)
            * 100,
        )

        # Group by multipliers configuration within this base group
        multiplier_groups = base_group_df_filtered.groupby(
            ["depth_multiplier", "width_multiplier"],
        )

        log.info("   Found %d multiplier configurations", len(multiplier_groups))

        # Plot 1: Best final perplexity vs learning rate
        ax1 = axes[0, group_idx]

        # Plot 2: Final perplexity vs batch size
        ax2 = axes[1, group_idx]

        colors = plt.cm.tab10(  # pyright: ignore[reportAttributeAccessIssue]
            np.linspace(0, 1, len(multiplier_groups)),
        )

        # Track optimal values for vertical lines
        optimal_lr_values = {}
        optimal_bs_values = {}

        for color_idx, ((depth_mult, width_mult), mult_group_df) in enumerate(
            multiplier_groups,
        ):
            label = f"depth x {depth_mult}, width x {width_mult}"
            color = colors[color_idx]

            # For perplexity vs learning rate: group by learning rate and compute the
            # mean and std of final perplexity, pooled over all batch sizes tried at
            # that learning rate
            lr_groups = mult_group_df.groupby("learning_rate")
            lr_values = []
            mean_perplexities_lr = []
            std_perplexities_lr = []

            for lr, lr_group in lr_groups:
                perp_values = lr_group["final_perplexity"]
                mean_perp = perp_values.mean()
                std_perp = perp_values.std() if len(perp_values) > 1 else 0.0

                lr_values.append(lr)
                mean_perplexities_lr.append(mean_perp)
                std_perplexities_lr.append(std_perp)

                # Log information about multiple runs for the same configuration
                if len(perp_values) > 1:
                    log.info(
                        "      LR %s: %d runs, mean=%.4f, std=%.4f",
                        lr,
                        len(perp_values),
                        mean_perp,
                        std_perp,
                    )

            # Sort by learning rate for proper line plotting
            lr_perp_pairs = sorted(
                zip(lr_values, mean_perplexities_lr, std_perplexities_lr, strict=True),
            )
            lr_values, mean_perplexities_lr, std_perplexities_lr = (
                zip(*lr_perp_pairs, strict=True) if lr_perp_pairs else ([], [], [])
            )

            # Plot learning rate vs mean perplexity with error bars
            if lr_values:
                ax1.errorbar(
                    lr_values,
                    mean_perplexities_lr,
                    yerr=std_perplexities_lr,
                    fmt="o-",
                    color=color,
                    label=label,
                    linewidth=2,
                    markersize=6,
                    capsize=5,
                    capthick=1.5,
                )

                # Find optimal learning rate for this multiplier group (lowest mean)
                min_perp_idx = np.argmin(mean_perplexities_lr)
                optimal_lr = lr_values[min_perp_idx]
                optimal_lr_values[depth_mult, width_mult] = (optimal_lr, color)

            # For perplexity vs batch size: group by batch size and compute the mean
            # and std of final perplexity, pooled over all learning rates tried at
            # that batch size
            bs_groups = mult_group_df.groupby("global_train_batch_size")
            bs_values = []
            mean_perplexities_bs = []
            std_perplexities_bs = []

            for bs, bs_group in bs_groups:
                perp_values = bs_group["final_perplexity"]
                mean_perp = perp_values.mean()
                std_perp = perp_values.std() if len(perp_values) > 1 else 0.0

                bs_values.append(bs)
                mean_perplexities_bs.append(mean_perp)
                std_perplexities_bs.append(std_perp)

                # Log information about multiple runs for the same configuration
                if len(perp_values) > 1:
                    log.info(
                        "      BS %s: %d runs, mean=%.4f, std=%.4f",
                        bs,
                        len(perp_values),
                        mean_perp,
                        std_perp,
                    )

            # Sort by batch size for proper line plotting
            bs_perp_pairs = sorted(
                zip(bs_values, mean_perplexities_bs, std_perplexities_bs, strict=True),
            )
            bs_values, mean_perplexities_bs, std_perplexities_bs = (
                zip(*bs_perp_pairs, strict=True) if bs_perp_pairs else ([], [], [])
            )

            # Plot batch size vs mean perplexity with error bars
            if bs_values:
                ax2.errorbar(
                    bs_values,
                    mean_perplexities_bs,
                    yerr=std_perplexities_bs,
                    fmt="s-",
                    color=color,
                    label=label,
                    linewidth=2,
                    markersize=6,
                    capsize=5,
                    capthick=1.5,
                )

                # Find optimal batch size for this multiplier group (lowest mean)
                min_perp_idx = np.argmin(mean_perplexities_bs)
                optimal_bs = bs_values[min_perp_idx]
                optimal_bs_values[depth_mult, width_mult] = (optimal_bs, color)

        # Add vertical lines for optimal learning rates with labels
        for optimal_lr, color in optimal_lr_values.values():
            ax1.axvline(
                x=optimal_lr,
                color=color,
                linestyle="--",
                alpha=0.7,
                linewidth=1.5,
            )
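            # Label the line near the top of the panel; y is given in axes
            # coordinates so the label stays visible for any perplexity range
            ax1.text(
                optimal_lr,
                0.95,
                f"{optimal_lr:.1e}",
                transform=ax1.get_xaxis_transform(),
                rotation=90,
                verticalalignment="top",
                horizontalalignment="right",
                color=color,
                fontsize=9,
                fontweight="bold",
            )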

        # Add vertical lines for optimal batch sizes with labels
        for optimal_bs, color in optimal_bs_values.values():
            ax2.axvline(
                x=optimal_bs,
                color=color,
                linestyle="--",
                alpha=0.7,
                linewidth=1.5,
            )
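            # Label the line near the top of the panel; y is given in axes
            # coordinates so the label stays visible for any perplexity range
            ax2.text(
                optimal_bs,
                0.95,
                f"{optimal_bs}",
                transform=ax2.get_xaxis_transform(),
                rotation=90,
                verticalalignment="top",
                horizontalalignment="right",
                color=color,
                fontsize=9,
                fontweight="bold",
            )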

        # Configure first subplot (learning rate vs perplexity)
        ax1.set_xlabel("Learning Rate", fontsize=12)
        ax1.set_ylabel("Mean Final Perplexity ± Std", fontsize=12)
        ax1.set_title(
            (
                f"Mean Perplexity vs Learning Rate\n"
                f"d_model_base={d_model_base_val}, n_experts={n_experts_val},\n"
                f"peri_ln={peri_ln_val}, embedding_ln={embedding_ln_val},"
                f" no_bias={no_bias_val}"
            ),
            fontsize=14,
        )
        ax1.set_xscale("log")
        ax1.grid(alpha=0.3)
        ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        # Configure second subplot (batch size vs perplexity)
        ax2.set_xlabel("Global Train Batch Size", fontsize=12)
        ax2.set_ylabel("Mean Final Perplexity ± Std", fontsize=12)
        ax2.set_title(
            (
                f"Mean Perplexity vs Batch Size\n"
                f"d_model_base={d_model_base_val}, n_experts={n_experts_val}"
            ),
            fontsize=14,
        )
        ax2.set_xscale("log")
        ax2.grid(alpha=0.3)
        ax2.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

        log.info("   ✅ Processed %s multiplier configurations", len(multiplier_groups))

    plt.tight_layout()
    plt.show()

    # Additional analysis: Print summary statistics
    log.info("\n📈 SUMMARY STATISTICS BY BASE CONFIGURATION:")
    log.info("=" * 70)

    for (
        d_model_base_val,
        n_experts_val,
        peri_ln_val,
        embedding_ln_val,
        no_bias_val,
    ), base_group_df in base_groups:
        log.info(
            (
                "\n🔹 Base config: d_model_base=%s, n_experts=%s,"
                " peri_ln=%s, embedding_ln=%s, no_bias=%s"
            ),
            d_model_base_val,
            n_experts_val,
            peri_ln_val,
            embedding_ln_val,
            no_bias_val,
        )
        log.info("   Total runs: %s", len(base_group_df))

        # Best overall perplexity in this base configuration (single best run)
        best_run = base_group_df.loc[base_group_df["final_perplexity"].idxmin()]
        log.info("   Best single run perplexity: %.4f", best_run["final_perplexity"])
        log.info(
            "   Best config: depth x %s, width x %s",
            best_run["depth_multiplier"],
            best_run["width_multiplier"],
        )
        log.info("   Best LR: %s", best_run["learning_rate"])
        log.info("   Best batch size: %s", best_run["global_train_batch_size"])

        # Statistical analysis: check for configurations with multiple runs
        config_groups = base_group_df.groupby(
            [
                "depth_multiplier",
                "width_multiplier",
                "learning_rate",
                "global_train_batch_size",
            ],
        )

        multiple_run_configs = []
        for config, config_df in config_groups:
            if len(config_df) > 1:
                perplexities = config_df["final_perplexity"]
                multiple_run_configs.append(
                    {
                        "config": config,
                        "n_runs": len(config_df),
                        "mean_perp": perplexities.mean(),
                        # Sample standard deviation (ddof=1, the pandas default)
                        "std_perp": perplexities.std(),
                        "min_perp": perplexities.min(),
                        "max_perp": perplexities.max(),
                    },
                )

        if multiple_run_configs:
            log.info(
                "   📊 Configurations with multiple runs (%d total):",
                len(multiple_run_configs),
            )
            for config_info in sorted(
                multiple_run_configs,
                key=operator.itemgetter("mean_perp"),
            ):
                depth_mult, width_mult, lr, bs = config_info["config"]
                log.info(
                    "      • depth x %s, width x %s, LR=%s, BS=%s: %d runs",
                    depth_mult,
                    width_mult,
                    lr,
                    bs,
                    config_info["n_runs"],
                )
                log.info(
                    "        Mean: %.4f ± %.4f, Range: [%.4f, %.4f]",
                    config_info["mean_perp"],
                    config_info["std_perp"],
                    config_info["min_perp"],
                    config_info["max_perp"],
                )
        else:
            log.info("       No configurations with multiple runs (all single runs)")

        # Best configuration by mean performance (for configs with multiple runs)
        if multiple_run_configs:
            best_config_info = min(
                multiple_run_configs,
                key=operator.itemgetter("mean_perp"),
            )
            depth_mult, width_mult, lr, bs = best_config_info["config"]
            log.info(
                "   🏆 Best mean config: depth x %s, width x %s, LR=%s, BS=%s",
                depth_mult,
                width_mult,
                lr,
                bs,
            )
            log.info(
                "      Mean perplexity: %.4f ± %.4f (%d runs)",
                best_config_info["mean_perp"],
                best_config_info["std_perp"],
                best_config_info["n_runs"],
            )

        # Range of perplexities
        perp_range = (
            base_group_df["final_perplexity"].max()
            - base_group_df["final_perplexity"].min()
        )
        log.info("   Perplexity range: %.4f", perp_range)

        # Multiplier configurations tested
        multiplier_configs = base_group_df[
            ["depth_multiplier", "width_multiplier"]
        ].drop_duplicates()
        log.info("   Multiplier configs tested: %s", len(multiplier_configs))
        for _, mult_row in multiplier_configs.iterrows():
            log.info(
                "      • depth x %s, width x %s",
                mult_row["depth_multiplier"],
                mult_row["width_multiplier"],
            )

    log.info("\n✅ Analysis completed successfully!")
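

# Sketch (not part of the original analysis): the loop over `config_groups`
# above gathers statistics for (depth, width, LR, batch size) configurations
# that were run more than once. The same aggregation can likely be expressed
# as a single pandas groupby/agg call. This assumes a DataFrame with the
# column names used above; `summarize_repeated_configs` and `CONFIG_KEYS` are
# hypothetical names introduced only for illustration.
import pandas as pd

CONFIG_KEYS = [
    "depth_multiplier",
    "width_multiplier",
    "learning_rate",
    "global_train_batch_size",
]


def summarize_repeated_configs(group_df: pd.DataFrame) -> pd.DataFrame:
    """Summarize final perplexity for configurations with more than one run.

    Returns one row per repeated configuration, sorted by mean perplexity
    (best first), which is the order the loop above logs them in.
    """
    stats = (
        group_df.groupby(CONFIG_KEYS)["final_perplexity"]
        .agg(
            n_runs="size",  # row count per config, like len(config_df)
            mean_perp="mean",
            std_perp="std",  # sample std (ddof=1), the pandas default
            min_perp="min",
            max_perp="max",
        )
        .reset_index()
    )
    # Keep only configurations that were repeated, as the loop above does.
    repeated = stats[stats["n_runs"] > 1]
    return repeated.sort_values("mean_perp").reset_index(drop=True)


# Usage (hypothetical): calling summarize_repeated_configs(base_group_df) in the
# per-base-config loop should give a table roughly equivalent to
# `multiple_run_configs`, with each config tuple unpacked into its own columns.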