In [1]:
import os
import json
import pandas as pd
import altair as alt

In [2]:
def load_reference_jsons(root_path, filter=""):
    """Load every ``*.json`` file under ``root_path`` into a single DataFrame.

    The tree is walked recursively. A directory's ``*.json`` files are read only
    if the substring ``filter`` occurs in that directory's full path. The match
    is against the whole ``dirpath``, including ``root_path`` itself, so when
    ``root_path`` already contains ``filter`` every directory matches. Each file's
    decoded JSON value becomes one record passed to ``pd.DataFrame``.

    Parameters
    ----------
    root_path : str or os.PathLike
        Directory to walk.
    filter : str, optional
        Substring a directory path must contain for its files to be read.
        The default ``""`` matches every directory. (The name shadows the
        builtin ``filter`` but is kept for caller compatibility.)

    Returns
    -------
    pd.DataFrame
        One row per successfully parsed file, ordered by sorted directory and
        file name so repeated runs give the same row order. The frame is empty
        when nothing was loaded.

    Files that cannot be opened or parsed are reported with ``print`` and
    skipped. Loading is best-effort.
    """
    records = []

    for dirpath, dirnames, filenames in os.walk(root_path):
        # os.walk yields entries in arbitrary filesystem order. Sorting dirnames
        # in place fixes the traversal order of subdirectories, so the resulting
        # row order is reproducible.
        dirnames.sort()
        if filter not in dirpath:
            continue
        for file in sorted(filenames):
            if not file.endswith(".json"):
                continue
            file_path = os.path.join(dirpath, file)
            try:
                # JSON text is UTF-8. Do not depend on the platform's locale encoding.
                with open(file_path, "r", encoding="utf-8") as f:
                    records.append(json.load(f))
            except Exception as e:
                print(f"Error loading {file_path}: {e}")

    return pd.DataFrame(records)

In [3]:
# Load the JSON benchmark records found under the reference run logs and preview them.
root_path = "../../../slurm_logs/reference-latest"
df = load_reference_jsons(root_path=root_path, filter="reference")
df.head()

Unnamed: 0,world_size,m,n,k,validate,benchmark,datatype,output_file,M,N,K,ms,flops,algorithm
0,4,8192,4096,3584,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,14336,4.344059,221.468625,torch_dist_all_reduce
1,8,8192,4096,2304,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,18432,2.750998,449.637013,torch_dist_all_reduce
2,2,8192,4096,8192,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,16384,9.226184,119.172957,torch_dist_all_reduce
3,4,8192,4096,3072,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,12288,3.998046,206.259164,torch_dist_all_reduce
4,1,8192,4096,22528,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,22528,15.531021,97.342502,torch_dist_all_reduce


In [4]:
# One bar chart per algorithm: FLOPS vs. world size, with one facet column per
# problem shape (shape columns ordered by K).
for algorithm in df["algorithm"].dropna().unique():
    # dropna() above: a record without an "algorithm" key shows up as NaN, and
    # the substring checks below would raise TypeError on a float.
    algorithm_name = str(algorithm)

    # Build the "M<M>N<N>K<K>" label column-wise. A row-wise apply(axis=1) hands
    # each row over as one Series. On an all-numeric frame that Series is float,
    # which would produce labels like "M8192.0N4096.0K14336.0".
    filtered_df = (
        df[df["algorithm"] == algorithm]
        .assign(shape=lambda d: (
            "M" + d["M"].astype(str)
            + "N" + d["N"].astype(str)
            + "K" + d["K"].astype(str)
        ))
        .sort_values(by=["K", "world_size"])
    )

    # Human-readable title. Parts are joined with " + " so a name matching both
    # collectives stays legible. If neither matches, the raw algorithm name is
    # used instead of a blank title.
    collectives = []
    if "all_gather" in algorithm_name:
        collectives.append("All Gather")
    if "all_reduce" in algorithm_name:
        collectives.append("All Reduce")
    title = " + ".join(collectives) or algorithm_name
    if "torch_dist" in algorithm_name:
        title += " (Torch Dist/RCCL)"

    # Facet columns follow the K-sorted order computed above instead of
    # Altair's default ascending sort of the label strings.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("world_size:O", title="World Size"),
        y=alt.Y("flops:Q", title="FLOPS (GFLOP/s)"),
        color=alt.Color("world_size:N", title="World Size"),
        column=alt.Column("shape:N", title="", sort=shape_order),
        tooltip=["shape", "world_size", "flops"],
    ).properties(
        title=title,
        height=300,
    ).configure_axisX(
        labelAngle=0,
    ).configure_title(
        anchor="middle",
        fontSize=18,
        font="Helvetica",
    )

    display(chart)


In [5]:
# One bar chart per algorithm: FLOPS per problem shape, with one facet column
# per world size.
# NOTE(review): the per-algorithm preparation here mirrors the other per-algorithm
# plotting cell; a shared helper would keep the two views in sync.
for algorithm in df["algorithm"].dropna().unique():
    # dropna() above: a record without an "algorithm" key shows up as NaN, and
    # the substring checks below would raise TypeError on a float.
    algorithm_name = str(algorithm)

    # Build the "M<M>N<N>K<K>" label column-wise. A row-wise apply(axis=1) hands
    # each row over as one Series. On an all-numeric frame that Series is float,
    # which would produce labels like "M8192.0N4096.0K14336.0".
    filtered_df = (
        df[df["algorithm"] == algorithm]
        .assign(shape=lambda d: (
            "M" + d["M"].astype(str)
            + "N" + d["N"].astype(str)
            + "K" + d["K"].astype(str)
        ))
        .sort_values(by=["K", "shape"])
    )

    # Human-readable title. Parts are joined with " + " so a name matching both
    # collectives stays legible. If neither matches, the raw algorithm name is
    # used instead of a blank title.
    collectives = []
    if "all_gather" in algorithm_name:
        collectives.append("All Gather")
    if "all_reduce" in algorithm_name:
        collectives.append("All Reduce")
    title = " + ".join(collectives) or algorithm_name
    if "torch_dist" in algorithm_name:
        title += " (Torch Dist/RCCL)"

    # Bars within each world-size column follow the K-sorted shape order
    # instead of Altair's default ascending sort of the label strings.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("shape:N", title="", sort=shape_order),
        y=alt.Y("flops:Q", title="FLOPS (GFLOP/s)"),
        # Colour only distinguishes the bars; the x axis already names them.
        color=alt.Color("shape:N", legend=None),
        column=alt.Column("world_size:N", title="World Size"),
        tooltip=["shape", "world_size", "flops"],
    ).properties(
        title=title,
        height=300,
    ).configure_axisX(
        labelAngle=45,
    ).configure_title(
        anchor="middle",
        fontSize=18,
        font="Helvetica",
    )

    display(chart)
