In [1]:
"""
get scaling laws.

usage:
1. pick your flops
2. choose what you want to optimize that has a nice convex-like minima (e.g. width, lr, etc ...).
    - optionally in log space
3. decide what metric you care about optimizing (e.g. val loss)
    - also optionally in log space
4. make a function that:
   - takes flops as input
   - takes all your parameters you want to optimize
   - returns the metric you care about
5. feed everything into get_scaling_laws() and let it do its thing
   - under the hood it uses powell's method to optimize your params for each flop amount
   - you can tweak tolerance and max_iter if you want to be more conservative
   - it'll automatically fit scaling laws to everything and plot nicely in output dir
"""

import pathlib
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from scipy.optimize import minimize
from tqdm import tqdm


class ScalingParameter(typing.NamedTuple):
    name: str
    initial: float
    bounds: tuple[float, float]
    log_space: bool


class OptimizationResult(typing.NamedTuple):
    flops: float
    parameters: dict[str, float]
    metric_value: float
    iteration: int


class ScalingLaw(typing.NamedTuple):
    name: str
    slope: float
    intercept: float
    log_space: bool
    r2: float


class TrainFunction(typing.Protocol):
    def __call__(self, flops: float, *args, **kwargs: typing.Any) -> float: ...


def get_scaling_laws(
    params: list[ScalingParameter],
    flops_values: list[float],
    train_fn: TrainFunction,
    metric_name: str,
    goal: typing.Literal["minimize", "maximize"] = "minimize",
    tolerance: float = 0.5,
    log_metric: bool = False,
    verbose: bool = True,
) -> tuple[list[ScalingLaw], list[dict], list[OptimizationResult]]:
    # Collect final results and optimization history
    results = []
    all_history = []
    prev_params = None

    # Optimize parameters at each FLOPS value
    for flops in tqdm(flops_values, desc="Processing FLOPS", disable=not verbose):
        best_step, history = _optimize_parameters(
            flops=flops, params=params, train_fn=train_fn, goal=goal, prev_params=prev_params, verbose=verbose, tol=tolerance
        )
        prev_params = best_step.parameters
        all_history.extend(history)
        results.append({"flops": flops, "parameters": best_step.parameters, metric_name: best_step.metric_value})

    # Prepare log(FLOPS) for linear fits
    log_flops = np.log10(flops_values)

    # Compute scaling laws for parameters
    laws = []
    for param in params:
        values = [r["parameters"][param.name] for r in results]
        y = np.log10(values) if param.log_space else values
        slope, intercept, r_value, *_ = stats.linregress(log_flops, y)
        laws.append(ScalingLaw(param.name, slope, intercept, param.log_space, r_value**2))  # type: ignore

    # Compute scaling law for metric
    metric_values = [r[metric_name] for r in results]
    y = np.log10([abs(m) for m in metric_values]) if log_metric else metric_values
    slope, intercept, r_value, *_ = stats.linregress(log_flops, y)
    laws.append(ScalingLaw(metric_name, slope, intercept, log_metric, r_value**2))  # type: ignore

    # Visualize and save results
    _visualize_results(
        laws=laws,
        results=results,
        all_history=all_history,
        metric_name=metric_name,
        output_dir="scaling_results",
    )

    return laws, results, all_history


def _optimize_parameters(
    flops: float,
    params: list[ScalingParameter],
    train_fn: TrainFunction,
    goal: typing.Literal["minimize", "maximize"] = "minimize",
    prev_params: dict[str, float] | None = None,
    max_iter: int = 10,
    tol: float = 0.5,
    verbose: bool = True,
) -> tuple[OptimizationResult, list[OptimizationResult]]:
    # Keep track of all attempts for this FLOPS
    history = []
    pbar = tqdm(total=max_iter, desc=f"FLOPS={flops:.2e}") if verbose else None

    # Objective function for Powell's method
    def objective(x: np.ndarray) -> float:
        # Convert parameters back if they are in log space
        param_dict = {p.name: (10**xi if p.log_space else xi) for xi, p in zip(x, params)}

        # Run the training function and record results
        metric = train_fn(flops=flops, **param_dict)
        history.append(OptimizationResult(flops, param_dict.copy(), metric, len(history)))

        if verbose and pbar:
            param_str = ", ".join(f"{k}={v:.2e}" for k, v in param_dict.items())
            pbar.write(f"Iteration {len(history)}: {param_str}, metric={metric:.4f}")
            pbar.update(1)

        # Flip metric for maximization
        return metric * (-1 if goal == "maximize" else 1)

    # Determine initial guess
    x0 = [
        (np.log10(prev_params[p.name]) if p.log_space else prev_params[p.name])
        if (prev_params and p.name in prev_params)
        else (np.log10(p.initial) if p.log_space else p.initial)
        for p in params
    ]

    # Set bounds in log space if needed
    bounds = [(np.log10(b[0]), np.log10(b[1])) if p.log_space else b for p, b in zip(params, [p.bounds for p in params])]

    # Run Powell optimization
    try:
        minimize(objective, x0, method="Powell", bounds=bounds, tol=tol, options=dict(maxiter=max_iter))
    finally:
        if pbar:
            pbar.close()

    # Select best result depending on minimization or maximization
    best_step = max(history, key=lambda x: x.metric_value) if goal == "maximize" else min(history, key=lambda x: x.metric_value)

    return best_step, history


def _visualize_results(
    laws: list[ScalingLaw],
    results: list[dict],
    all_history: list[OptimizationResult],
    metric_name: str,
    output_dir: str = "scaling_results",
) -> None:
    # Create output directory
    output_path = pathlib.Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save full history as CSV
    pd.DataFrame(
        [
            {
                "flops": step.flops,
                "iteration": step.iteration,
                "metric_value": step.metric_value,
                **step.parameters,
            }
            for step in all_history
        ]
    ).to_csv(output_path / "optimization_history.csv", index=False)

    # Save scaling laws as CSV
    pd.DataFrame(
        [{"name": law.name, "slope": law.slope, "intercept": law.intercept, "log_space": law.log_space, "r2": law.r2} for law in laws]
    ).to_csv(output_path / "scaling_laws.csv", index=False)

    # Plot scaling laws for parameters and metric
    fig, axes = plt.subplots(nrows=len(laws), ncols=1, figsize=(12, 5 * len(laws)))
    axes = np.atleast_1d(axes)

    for idx, law in enumerate(laws):
        ax = axes[idx]
        x_values = [r["flops"] for r in results]
        y_values = [r["parameters"][law.name] if law.name != metric_name else r[metric_name] for r in results]

        # Log-scale FLOPS axis
        ax.set_xscale("log")

        # Use log scale for y if needed
        if law.log_space:
            ax.set_yscale("log")

        # Smooth curve for fitted law
        x_smooth = np.logspace(np.log10(min(x_values)), np.log10(max(x_values)), 100)
        if law.log_space:
            y_smooth = 10 ** (law.slope * np.log10(x_smooth) + law.intercept)
        else:
            y_smooth = law.slope * np.log10(x_smooth) + law.intercept

        ax.scatter(x_values, y_values, color="blue", s=100)
        ax.plot(x_smooth, y_smooth, color="red", linewidth=2)
        ax.set_title(f"{law.name} vs FLOPS (R² = {law.r2:.3f})")
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(output_path / "scaling_plots.png", bbox_inches="tight")
    plt.close()

    # Visualize parameter attempts for each FLOPS
    param_laws = [law for law in laws if law.name != metric_name]
    fig, axes = plt.subplots(nrows=len(param_laws), ncols=1, figsize=(12, 5 * len(param_laws)))
    axes = np.atleast_1d(axes)

    # Unique FLOPS values for consistent coloring
    unique_flops = sorted(set(step.flops for step in all_history))
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_flops)))  # type: ignore
    flops_color_map = dict(zip(unique_flops, colors))

    for idx, law in enumerate(param_laws):
        ax = axes[idx]
        attempts_by_flops = {}

        # Group all attempts by FLOPS
        for step in all_history:
            if step.flops not in attempts_by_flops:
                attempts_by_flops[step.flops] = {"params": [], "metrics": []}
            attempts_by_flops[step.flops]["params"].append(step.parameters[law.name])
            attempts_by_flops[step.flops]["metrics"].append(step.metric_value)

        # Plot attempts per FLOPS value
        for flops_val, attempts in attempts_by_flops.items():
            color = flops_color_map[flops_val]
            ax.scatter(attempts["params"], attempts["metrics"], alpha=0.5, color=color, label=f"FLOPS={flops_val:.1e}")

            # Highlight best result for that FLOPS
            optimal_result = next(r for r in results if r["flops"] == flops_val)
            ax.scatter(
                optimal_result["parameters"][law.name],
                optimal_result[metric_name],
                color=color,
                marker="*",
                s=200,
                edgecolor="black",
                linewidth=1.5,
                zorder=10,
            )

        # Log-scale parameter axis if needed
        if law.log_space:
            ax.set_xscale("log")

        ax.set_xlabel(law.name)
        ax.set_ylabel(metric_name)
        ax.set_title(f"Parameter Attempts: {law.name}")
        ax.grid(True)
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.savefig(output_path / "parameter_attempts.png", bbox_inches="tight")
    plt.close()

Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
                                                                                [A
Processing

Iteration 1: learning_rate=1.00e-02, batch_size=2.56e+02, num_layers=5.00e+00, metric=17.2622
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=10.0597
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=16.8585
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_s


                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
FLOPS=3.98e+12:   0%|                                    | 0/10 [00:00<?, ?it/s][A

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
                                                                                [A
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.00it/s]
             

Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=13.9076
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=24.4230
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batc


                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
         

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=16.9299
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=30.8900
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]

Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000



                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
FLOPS=1.58e+13: 100%|███████████████████████████| 10/10 [00:00<00:00, 89.67it/s][A
                                                                                [A
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  5.84it/s]
                                                                                
Processing F

Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.8


                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
         

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=19.8782
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=37.0111
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                
Processing FL

Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e


                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A


Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=22.9240
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=42.9577
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  5.61it/s]
                                                                                [A
Processing

Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-


                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
         

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=26.0938
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=48.7563
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]

Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000



                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                [A
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
FLOPS=1.00e+15: 12it [00:00, 120.45it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [0

Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.45e+01, metric=-100.0000
Iteration 20: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.66e+01, metric=-100.0000
Iteration 21: learning_rate=1.

                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]

Iteration 22: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.87e+01, metric=-100.0000
Iteration 23: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.92e+01, metric=-100.0000
Iteration 24: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.95e+01, metric=-100.0000



                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  8.75it/s]
FLOPS=1.00e+15: 25it [00:00, 126.45it/s][A
Processing FLOPS: 100%|███████████████████████████| 6/6 [00:00<00:00,  6.91it/s]


Iteration 25: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.97e+01, metric=-100.0000


In [4]:
def test_train_function(flops: float, learning_rate: float, batch_size: float, num_layers: float) -> float:
    flops_scale = np.log10(flops)
    optimal_lr = 0.1 * (10 ** (-flops_scale / 15))
    optimal_batch = 32 * (2 ** (flops_scale - 12))
    optimal_layers = 2 * (flops_scale - 11)

    penalties = [
        -2.0 * (flops_scale - 11) * (np.log10(learning_rate / optimal_lr)) ** 2,
        -0.5 * (flops_scale - 11) * (np.log10(batch_size / optimal_batch)) ** 2,
        -1.0 * (flops_scale - 11) * ((num_layers - optimal_layers) / optimal_layers) ** 2,
    ]

    if min(penalties) < -10 * (flops_scale - 11):
        return -100

    return 10 * np.log10(flops) - 100 + sum(penalties)

scaling_parameters = [
    ScalingParameter("learning_rate", 0.01, (1e-6, 1.0), True),
    ScalingParameter("batch_size", 256, (32, 16384), True),
    ScalingParameter("num_layers", 5, (1, 100), False),
]

laws, results, history = get_scaling_laws(
    params=scaling_parameters,
    flops_values=np.logspace(12, 15, 6).tolist(),
    train_fn=test_train_function,
    verbose=True,
    metric_name="performance",
)

Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s]
FLOPS=1.00e+12:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:   0%|                                   | 0/6 [00:00<?, ?it/s][A
          

Iteration 1: learning_rate=1.00e-02, batch_size=2.56e+02, num_layers=5.00e+00, metric=17.2622
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=10.0597
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=16.8585
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_s


FLOPS=3.98e+12:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processin

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=13.9076
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=24.4230
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s][A
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s]
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s]
                                                                                
Processing FLOPS:  17%|████▌                      | 1/6 [00:00<00:00,  5.50it/s]
                   

Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.8


FLOPS=1.58e+13:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s][A
                                                                                
Processin

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=16.9299
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=30.8900
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch


Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
FLOPS=1.58e+13: 14it [00:00, 139.23it/s][A
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                

Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.45e+01, metric=-100.0000
Iteration 20: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.66e+01, metric=-100.0000
Iteration 21: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.79e+01, metric=-100.0000


                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
Processing FLOPS:  33%|█████████                  | 2/6 [00:00<00:00,  7.39it/s]
                                                                                
FLOPS=1.58e+13: 25it [00:00, 191.25it/s]          | 2/6 [00:00<00:00,  7.39it/s]
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s]

Iteration 22: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.87e+01, metric=-100.0000
Iteration 23: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.92e+01, metric=-100.0000
Iteration 24: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.95e+01, metric=-100.0000
Iteration 25: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.97e+01, metric=-100.0000



FLOPS=6.31e+13:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processin

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=19.8782
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=37.0111
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s][A
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s]
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s]
                                                                                
Processing FLOPS:  50%|█████████████▌             | 3/6 [00:00<00:00,  7.45it/s]
FLOPS=6.31e+13: 14it [00:00, 135.82it/s][A
                                                        

Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.45e+01, metric=-100.0000
Iteration 20: learning_rate=1.


FLOPS=2.51e+14:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=22.9240
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=42.9577
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                

Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
[A                                                                             


Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s][A
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
FLOPS=2.51e+1

Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e

Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
                                                                                
FLOPS=2.51e+14: 22it [00:00, 113.80it/s]

Iteration 20: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.66e+01, metric=-100.0000
Iteration 21: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.79e+01, metric=-100.0000
Iteration 22: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.87e+01, metric=-100.0000
Iteration 23: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.92e+01, metric=-100.0000


Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
                                                                                
Processing FLOPS:  67%|██████████████████         | 4/6 [00:00<00:00,  7.30it/s]
FLOPS=2.51e+14: 24it [00:00, 113.05it/s][A
                                                                                
FLOPS=2.51e+14: 25it [00:00, 115.44it/s]█         | 4/6 [00:00<00:00,  7.30it/s]
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]

Iteration 24: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.95e+01, metric=-100.0000
Iteration 25: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.97e+01, metric=-100.0000



FLOPS=1.00e+15:   0%|                                    | 0/10 [00:00<?, ?it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processin

Iteration 1: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 2: learning_rate=1.96e-04, batch_size=2.56e+02, num_layers=5.00e+00, metric=26.0938
Iteration 3: learning_rate=5.11e-03, batch_size=2.56e+02, num_layers=5.00e+00, metric=48.7563
Iteration 4: learning_rate=2.61e-05, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 5: learning_rate=7.51e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 6: learning_rate=5.11e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 7: learning_rate=2.74e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000
Iteration 8: learning_rate=1.87e-06, batch_size=2.56e+02, num_layers=5.00e+00, metric=-100.0000


                                                                                
FLOPS=1.00e+15:  80%|█████████████████████▌     | 8/10 [00:00<00:00, 259.39it/s]

Iteration 9: learning_rate=1.87e-06, batch_size=3.47e+02, num_layers=5.00e+00, metric=-100.0000


Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s][A
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
[A                                     

Iteration 10: learning_rate=1.87e-06, batch_size=1.51e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 11: learning_rate=1.87e-06, batch_size=3.76e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 12: learning_rate=1.87e-06, batch_size=6.59e+03, num_layers=5.00e+00, metric=-100.0000
Iteration 13: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=5.00e+00, metric=-100.0000


                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                                                                                
Processing FLOPS:  83%|██████████████████████▌    | 5/6 [00:00<00:00,  6.01it/s]
                            

Iteration 14: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=3.88e+01, metric=-100.0000
Iteration 15: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=6.22e+01, metric=-100.0000
Iteration 16: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=7.66e+01, metric=-100.0000
Iteration 17: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=8.56e+01, metric=-100.0000
Iteration 18: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.11e+01, metric=-100.0000
Iteration 19: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.45e+01, metric=-100.0000
Iteration 20: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.66e+01, metric=-100.0000
Iteration 21: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.79e+01, metric=-100.0000
Iteration 22: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.87e+01, metric=-100.0000
Iteration 23: learning_rate=1.87e-06, batch_size=9.68e+03, num_layers=9.92e+01, metric=-100.0000
Iteration 24: learning_rate=1.

In [5]:
laws, results, history

([ScalingLaw(name='learning_rate', slope=0.0, intercept=-4.583592135001262, log_space=True, r2=0.0),
  ScalingLaw(name='batch_size', slope=0.0, intercept=2.4082399653118496, log_space=True, r2=0.0),
  ScalingLaw(name='num_layers', slope=0.0, intercept=5.0, log_space=False, r2=0.0),
  ScalingLaw(name='performance', slope=0.0, intercept=-100.0, log_space=False, r2=0.0)],
 [{'flops': 1000000000000.0,
   'parameters': {'learning_rate': 2.6086022527814775e-05,
    'batch_size': 256.0,
    'num_layers': 5.0},
   'performance': -100},
  {'flops': 3981071705534.969,
   'parameters': {'learning_rate': 2.6086022527814775e-05,
    'batch_size': 256.0,
    'num_layers': 5.0},
   'performance': -100},
  {'flops': 15848931924611.11,
   'parameters': {'learning_rate': 2.6086022527814775e-05,
    'batch_size': 256.0,
    'num_layers': 5.0},
   'performance': -100},
  {'flops': 63095734448019.43,
   'parameters': {'learning_rate': 2.6086022527814775e-05,
    'batch_size': 256.0,
    'num_layers': 5.0},