In [None]:
import wandb

# W&B project that holds the benchmark runs analysed in this notebook.
entity = "sayakpaul"
project = "keras-xla-benchmarks"

api = wandb.Api()
runs = api.runs(f"{entity}/{project}")
print(f"Total runs: {len(runs)}")

In [None]:
from model_mapping import MODEL_NAME_MAPPING

# Flatten MODEL_NAME_MAPPING (presumably family -> iterable of variant names)
# into a single flat list of variant names.
all_variants = [
    variant
    for family_variants in MODEL_NAME_MAPPING.values()
    for variant in family_variants
]
all_variants

In [None]:
import pandas as pd

# Keep only runs whose variant name *contains* one of these names (substring
# match, as in the original filter). The analysis was originally restricted to
# the first two entries of `all_variants`; widen this slice, or list variant
# names explicitly, to compare more models.
VARIANTS_OF_INTEREST = all_variants[:2]

# Column order of `viz_df`. Passed explicitly so the frame keeps this schema
# even when no run matches the filter.
VIZ_COLUMNS = [
    "model_family",
    "model_variant",
    "resolution",
    "xla",
    "accelerator",
    "flop (giga)",
    "params (million)",
    "throughput (samples/sec)",
]

# One dict per kept run; the DataFrame is built once at the end instead of
# maintaining eight parallel lists that must stay the same length.
records = []
for run in runs:
    run_config = run.config
    variant = run_config["variant"]
    if not any(name in variant for name in VARIANTS_OF_INTEREST):
        continue

    # NOTE(review): `_json_dict` is a private wandb attribute; may need
    # updating if the wandb client changes.
    run_summary = run.summary._json_dict
    records.append(
        {
            "model_family": run_config["family"],
            "model_variant": variant,
            "resolution": run_config["resolution"],
            "xla": run_config["xla"],
            # Accelerator = second "-"-separated token after the last "@" in
            # the run name.
            "accelerator": run.name.split("@")[-1].split("-")[1],
            "flop (giga)": run_summary["FLOPs (giga)"],
            "params (million)": run_summary["Num parameters (million)"],
            "throughput (samples/sec)": run_summary["Throughput (samples/sec)"],
        }
    )

viz_df = pd.DataFrame(records, columns=VIZ_COLUMNS)
viz_df.head()

In [None]:
def plot_topk_per_accelerator(
    accelerator="a100", topk=10, resolution=320, xla_status=True, df=None
):
    """Return the `topk` highest-throughput runs for one benchmark setting.

    Despite its name, this function does not plot anything. It only selects
    rows, which the notebook then passes to `plot_metrics`.

    Args:
        accelerator: Value to match in the "accelerator" column.
        topk: Maximum number of rows to return.
        resolution: Value to match in the "resolution" column.
        xla_status: Value to match in the "xla" column.
        df: Frame to select from. Defaults to the notebook-level `viz_df`.

    Returns:
        A DataFrame with at most `topk` rows, ordered by
        "throughput (samples/sec)" from highest to lowest. The original row
        labels are preserved.
    """
    if df is None:
        df = viz_df

    # Boolean masks rather than a `query` string built with an f-string, so
    # values are compared directly instead of being spliced into an expression.
    mask = (
        (df["accelerator"] == accelerator)
        & (df["resolution"] == resolution)
        & (df["xla"] == xla_status)
    )
    return df[mask].nlargest(topk, "throughput (samples/sec)")

In [None]:
# Adapted from
# https://github.com/nlp-with-transformers/notebooks/blob/main/08_model-compression.ipynb

import matplotlib.pyplot as plt


def plot_metrics(df, savefig=False):
    """Scatter-plot throughput against FLOPs, one series per model variant.

    Marker area is `5 * params (million)`. The title uses the first
    "accelerator", "resolution" and "xla" values found in `df`, so `df` should
    hold a single setting (for example, the output of
    `plot_topk_per_accelerator`).

    Args:
        df: DataFrame with columns "model_variant", "flop (giga)",
            "throughput (samples/sec)", "params (million)", "accelerator",
            "resolution" and "xla".
        savefig: If False, show the figure. If True, save it instead as
            "<accelerator>_<resolution>_<xla>.png" in the working directory.

    Raises:
        ValueError: If `df` is empty (nothing to plot, and no values for the
            title).
    """
    if df.empty:
        raise ValueError("plot_metrics() received an empty DataFrame; nothing to plot.")

    # A fresh figure per call. With savefig=True, plt.show() is never called,
    # so drawing through the pyplot state machine could pile several calls'
    # points onto one set of axes.
    fig, ax = plt.subplots()

    # Iterate unique variants: iterating the raw column would draw a variant
    # that occupies several rows once per row and duplicate its legend entry.
    for model_variant in df["model_variant"].unique():
        filtered = df[df["model_variant"] == model_variant]
        ax.scatter(
            filtered["flop (giga)"],
            filtered["throughput (samples/sec)"],
            alpha=0.5,
            s=filtered["params (million)"] * 5,
            label=model_variant,
            marker="o",
        )

    legend = ax.legend(bbox_to_anchor=(1, 1))
    # Matplotlib 3.7 renamed `legendHandles` to `legend_handles` (the old name
    # was later removed). Support both.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # Uniform legend marker size, independent of the per-point sizes above.
    for handle in handles:
        handle.set_sizes([20])

    ax.set_ylabel("Throughput (samples/sec)", fontsize=14)
    # The x values are a FLOP count (W&B key "FLOPs (giga)"), not a per-second rate.
    ax.set_xlabel("FLOPs (giga)", fontsize=14)

    accelerator_name = df["accelerator"].unique()[0]
    resolution = df["resolution"].unique()[0]
    xla_status = df["xla"].unique()[0]
    ax.set_title(
        f"Accelerator: {accelerator_name}, Resolution: {resolution}, XLA: {xla_status}",
        fontsize=14,
    )
    if not savefig:
        plt.show()
    else:
        plot_name = f"{accelerator_name}_{resolution}_{xla_status}.png"
        fig.savefig(plot_name, dpi=300, bbox_inches="tight")

In [None]:
# Top runs on A100 using the selector's defaults (top 10, resolution 320,
# XLA on), plotted as FLOPs vs. throughput.
a100_df = plot_topk_per_accelerator(accelerator="a100")
plot_metrics(df=a100_df)

In [None]:
# Top runs on A100 at resolution 640 (default top-k, XLA on).
a100_res_640_df = plot_topk_per_accelerator(accelerator="a100", resolution=640)
plot_metrics(df=a100_res_640_df)

In [None]:
# Row label of the highest-throughput run within each
# (resolution, accelerator) group.
grouped = (
    viz_df
    .groupby(["resolution", "accelerator"])["throughput (samples/sec)"]
    .idxmax()
)
# Look those labels up (all columns) to get one "best" row per group.
result = viz_df.loc[grouped]
result