In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import logging
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.style as mplstyle
import numpy as np
from matplotlib import pyplot as plt
from numpy.typing import NDArray

from fedmoe_plots.plotting_utils import (
    configure_logging_for_jupyter,
    run_matplotlib_preamble,
)
from fedmoe_plots.wall_time_model import ExperimentWallTime

configure_logging_for_jupyter()

# Shared plotting setup from fedmoe_plots; its three return values are unpacked
# as `color_palette`, `line_styles` and `patterns` for the cells below.
color_palette, line_styles, patterns = run_matplotlib_preamble(
    custom_fonts=True,  # keep Montserrat, with LaTeX fallbacks
    use_inverted_style=False,
    backend="inline",  # Jupyter inline backend
)

log = logging.getLogger("wall_clock_time_model.ipynb")

# Figure sizes as (width, height); trailing comments presumably record earlier values.
# NOTE(review): the plotting cells in this notebook hard-code figsize=(10, 5)
# instead of using these — confirm whether they are still needed.
fig_size = (5, 4)  # (4, 3)
fig_size1 = fig_size2 = (5, 4)  # (8, 3)

log.info("Font configuration applied successfully!")
log.info("Font family: %s", mpl.rcParams["font.family"])
# Only the first three serif fonts are logged to keep the output short.
log.info("Font serif: %s...", mpl.rcParams["font.serif"][:3])

In [None]:
# Legend label for each method key used in `experiments`.
# NOTE(review): `\texttt{...}` is only interpreted when LaTeX text rendering
# (text.usetex) is enabled — confirm run_matplotlib_preamble turns it on.
methods_to_label = dict(
    ddp=r"\texttt{DDP}",
    fedavg=r"\texttt{FedAvg}",
    localadam=r"\texttt{LocalAdam}",
    desync=r"\texttt{DES-LOC-Adam}",
)

In [None]:
# Stand-alone example configuration: a 135M-parameter model on 4 workers.
# NOTE(review): the plotting loops below rebind the name `experiment`, so this
# instance appears to be unused — confirm before relying on it.
# Magnitudes use plain scientific notation. The previous `X * 10eN` literals were
# 10x too large, because `10e6` is 10**7 (e.g. `135 * 10e6` was 1.35B, not 135M).
experiment = ExperimentWallTime(
    precision="fp16",
    dataset_size=int(40e9),  # 40B (presumably tokens); was 400B
    n_model_parameters=int(135e6),  # 135M; was 1.35B
    n_workers=4,
    worker_flops_per_second=1e12,  # 1 TFLOP/s; was 10 TFLOP/s
    worker_mfu=0.5,
    # Same value as the previous `10e-3`; presumably seconds (10 ms).
    # NOTE(review): if 1 ms was intended, use 1e-3.
    p2p_network_latency=1e-2,
    equivalent_communication_steps=100,
)

In [None]:
# Scenario parameters for the wall-clock comparison. The active values describe
# the 1.3B-parameter run; commented lines are alternative presets (135M, 1.7B,
# GAIA-2 8.4B, 100B).
# All magnitudes use plain scientific notation (`1e9` == 10**9). The previous
# `X * 10eN` literals were 10x too large: `1.3 * 10e9` is 13e9, not 1.3B, and
# `312 * 10e12` is 3.12 PFLOP/s, not 312 TFLOP/s.
# model_size = int(135e6)  # 135M
# model_size = int(1.7e9)  # 1.7B
model_size = int(1.3e9)  # 1.3B
# model_size = int(8.4e9)  # GAIA-2 (8.4B)
# model_size = int(100e9)  # 100B
global_batch_size = 1024  # sequences per step
# global_batch_size = 256  # GAIA-2
sequence_length = 2048  # presumably tokens per sequence
# sequence_length = 12600  # GAIA-2
# n_sequential_steps = 3_072  # 135M
# n_sequential_steps = 20_000  # 1.7B
n_sequential_steps = 50_688  # 1.3B
# n_sequential_steps = 460_000  # GAIA-2 (8.4B)
# Total training tokens across all steps.
dataset_size = global_batch_size * sequence_length * n_sequential_steps
n_workers = 4
# n_workers = 256  # GAIA-2 (8.4B)
# worker_flops_per_second = 2 * 312e12  # 2xH100
# NOTE(review): 312 TFLOP/s is the A100 dense fp16/bf16 peak; confirm the
# accelerator intended by the "H100" label.
worker_flops_per_second = 312e12  # 1xH100
# worker_mfu = 0.36
# worker_mfu = 0.91
# worker_mfu = 0.40  # GAIA-2 (8.4B)
worker_mfu = 0.24  # 1.3B
precision = "fp16"
# Same value as the previous `10e-5`; presumably seconds (100 µs).
# NOTE(review): if 10 µs was intended (10eN meaning 10**N, as above), use 1e-5.
p2p_network_latency = 1e-4
fedavg_sync_frequency = 256
localadam_sync_frequency = 256
desync_sync_frequencies = [256, 3 * 256, 6 * 256]

# How many synchronisation rounds each method performs over the whole run.
# Operands are positive ints, so `//` equals the previous int(a / b).
communication_steps: dict[str, int] = {
    "fedavg": n_sequential_steps // fedavg_sync_frequency,
    # 3x FedAvg's count. NOTE(review): presumably parameters plus both Adam
    # moments are synchronised — confirm.
    "localadam": 3 * n_sequential_steps // localadam_sync_frequency,
    # One set of rounds per entry in desync_sync_frequencies, each at its own period.
    "desync": sum(n_sequential_steps // freq for freq in desync_sync_frequencies),
    # Synchronises on every step.
    "ddp": n_sequential_steps,
}

# Everything except the communication schedule is shared by all methods.
shared_experiment_kwargs = dict(
    precision=precision,
    dataset_size=dataset_size,
    n_model_parameters=model_size,
    n_workers=n_workers,
    worker_flops_per_second=worker_flops_per_second,
    worker_mfu=worker_mfu,
    p2p_network_latency=p2p_network_latency,
)
# Insertion order (fedavg, localadam, desync, ddp) determines the colour each
# method gets in the plots, which index color_palette by position.
experiments: dict[str, ExperimentWallTime] = {
    method: ExperimentWallTime(
        **shared_experiment_kwargs,
        equivalent_communication_steps=steps,
    )
    for method, steps in communication_steps.items()
}

In [None]:
# 10,000 log-spaced P2P bandwidths from 5e4 to 5e11. Presumably bit/s, since the
# plots convert them to Gbit/s — NOTE(review): confirm the units total_time() expects.
bandwidths = 5 * np.logspace(start=4, stop=11, num=10_000)

In [None]:
# Total wall-clock time vs. P2P bandwidth for every method (log-log).
# bit/s -> Gbit/s, assuming `bandwidths` is in bit/s. The previous factor `10e-9`
# equals 1e-8 and shifted the x-axis 10x to the right.
bandwidths_gbits = 1e-9 * bandwidths

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("P2P Bandwidth (Gbits/s)")
ax.set_ylabel("Wall-Clock time (s)")
ax.grid(which="both", linestyle="--", linewidth=0.5)
for i, (method, experiment) in enumerate(experiments.items()):
    ax.plot(
        bandwidths_gbits,
        [experiment.total_time(b) for b in bandwidths],
        label=methods_to_label[method],
        color=color_palette[i],
    )
ax.legend()
plt.show()

In [None]:
# Compute utilization (compute time / total wall-clock time) vs. P2P bandwidth.
# bit/s -> Gbit/s, assuming `bandwidths` is in bit/s. The previous factor `10e-9`
# equals 1e-8 and shifted the x-axis 10x to the right.
bandwidths_gbits = 1e-9 * bandwidths

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale("log")
ax.set_xlabel("P2P Bandwidth (Gbits/s)")
ax.set_ylabel("Compute Utilization")
ax.grid(which="both", linestyle="--", linewidth=0.5)
for i, (method, experiment) in enumerate(experiments.items()):
    compute_time = experiment.compute_time()
    ax.plot(
        bandwidths_gbits,
        [compute_time / experiment.total_time(b) for b in bandwidths],
        label=methods_to_label[method],
        color=color_palette[i],
    )
ax.legend()
plt.show()

In [None]:
# Communication time vs. P2P bandwidth, taken as total wall-clock time minus
# pure compute time (log-log).
# bit/s -> Gbit/s, assuming `bandwidths` is in bit/s. The previous factor `10e-9`
# equals 1e-8 and shifted the x-axis 10x to the right.
bandwidths_gbits = 1e-9 * bandwidths

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("P2P Bandwidth (Gbits/s)")
ax.set_ylabel("Communication time (s)")
ax.grid(which="both", linestyle="--", linewidth=0.5)
for i, (method, experiment) in enumerate(experiments.items()):
    compute_time = experiment.compute_time()
    ax.plot(
        bandwidths_gbits,
        [experiment.total_time(b) - compute_time for b in bandwidths],
        label=methods_to_label[method],
        color=color_palette[i],
    )
ax.legend()
plt.show()

In [None]:
# Pure compute time per method. It does not depend on bandwidth, so each curve
# is a horizontal line across the bandwidth range.
# bit/s -> Gbit/s, assuming `bandwidths` is in bit/s. The previous factor `10e-9`
# equals 1e-8 and shifted the x-axis 10x to the right.
bandwidths_gbits = 1e-9 * bandwidths

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("P2P Bandwidth (Gbits/s)")
ax.set_ylabel("Compute time (s)")
ax.grid(which="both", linestyle="--", linewidth=0.5)
for i, (method, experiment) in enumerate(experiments.items()):
    compute_time = experiment.compute_time()
    ax.plot(
        bandwidths_gbits,
        np.full(bandwidths.shape, compute_time),
        label=methods_to_label[method],
        color=color_palette[i],
    )
ax.legend()
plt.show()

In [None]:
# Compute utilization vs. P2P bandwidth. A few sampled bandwidths are annotated
# with the utilization gap between the best and the worst method.
# bit/s -> Gbit/s, assuming `bandwidths` is in bit/s. The previous factor `10e-9`
# equals 1e-8 and shifted the x-axis 10x to the right.
bandwidths_gbits = 1e-9 * bandwidths

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xscale("log")
ax.set_xlabel("P2P Bandwidth (Gbits/s)")
ax.set_ylabel("Compute Utilization")
ax.grid(which="both", linestyle="--", linewidth=0.5)
# x position (Gbit/s) -> {method: compute utilization at that bandwidth}
lines: dict[float, dict[str, float]] = {}
fewer_bandwidths = 4 * np.logspace(4, 11, 10, base=10)
offset = 0.012  # vertical gap between the top curve and its annotation
for i, (method, experiment) in enumerate(experiments.items()):
    # Compute time must come from *this* experiment before it is used below.
    # Previously it was read before being assigned, so the first method reused a
    # stale `compute_time` left by an earlier cell and each later method used
    # its predecessor's value.
    compute_time = experiment.compute_time()
    for b in fewer_bandwidths:
        lines.setdefault(1e-9 * b, {})[method] = compute_time / experiment.total_time(b)
    ax.plot(
        bandwidths_gbits,
        [compute_time / experiment.total_time(b) for b in bandwidths],
        label=methods_to_label[method],
        color=color_palette[i],
    )
# Put a small text box above the highest curve at each sampled bandwidth, showing
# how far the best method's utilization is above the worst one's.
# Utilizations are ratios (compute_time / total_time), not percentages, so the
# `%` format spec multiplies by 100. The old f"{max_delta:.2f}%" printed "0.35%"
# for a 35-percentage-point gap.
for x_data, methods_dict in lines.items():
    min_value = min(methods_dict.values())
    max_value = max(methods_dict.values())
    max_delta = max_value - min_value
    ax.text(
        x_data,
        max_value + offset,
        f"{max_delta:.2%}",
        fontsize=8,
        ha="center",
        va="bottom",
        color=color_palette[7],
        bbox={
            "boxstyle": "round,pad=0.2",
            "facecolor": "white",
            "edgecolor": color_palette[7],
            "alpha": 0.8,
        },
    )

ax.legend()
plt.show()