In [107]:
import os
import json
import pandas as pd
import altair as alt

In [108]:
def load_reference_jsons(root_path, filter = ""):
    """Recursively load every ``*.json`` file under ``root_path`` into a DataFrame.

    Each successfully parsed file is appended as one record and the records are
    passed to ``pd.DataFrame`` once at the end (so, assuming each file holds a JSON
    object, its top-level keys become columns). Files that fail to open or parse
    are reported with ``print`` and skipped, so one bad file does not abort the load.

    Args:
        root_path: Directory to walk recursively.
        filter: Substring a directory path must contain for the JSON files directly
            inside it to be loaded. It is matched against the full ``dirpath``
            yielded by ``os.walk`` (which starts with ``root_path`` itself). The
            default ``""`` matches every directory. (The name shadows the builtin
            ``filter`` but is kept for backward compatibility with keyword callers.)

    Returns:
        A ``pandas.DataFrame`` with one row per loaded file, in a deterministic
        sorted top-down traversal order. Empty if no matching files were found,
        including when ``root_path`` does not exist (``os.walk`` yields nothing).
    """
    records = []

    for dirpath, dirnames, filenames in os.walk(root_path):
        # os.walk yields entries in arbitrary filesystem order. Sorting dirnames in
        # place (required for top-down walks) and filenames below makes the row order,
        # and hence df.head() and any positional access, reproducible across runs.
        dirnames.sort()
        if filter not in dirpath:
            continue
        for file in sorted(filenames):
            if not file.endswith(".json"):
                continue
            file_path = os.path.join(dirpath, file)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    records.append(json.load(f))
            except Exception as e:
                # Deliberately best-effort: report and keep loading the other files.
                print(f"Error loading {file_path}: {e}")

    return pd.DataFrame(records)

In [109]:
# Load all reference (baseline) benchmark results: every JSON file found in a directory
# whose path contains "reference" under the latest slurm log directory.
root_path = "../../../slurm_logs/latest"
df = load_reference_jsons(root_path, "reference")

# os.walk silently yields nothing for a missing or mistyped directory, which would
# otherwise only surface later as a confusing KeyError on 'algorithm' when plotting.
assert not df.empty, f"No reference JSON results found under {root_path!r}"
assert "algorithm" in df.columns, "Loaded results have no 'algorithm' column"

df.head()

Unnamed: 0,world_size,m,n,k,validate,benchmark,datatype,output_file,M,N,K,ms,flops,gemm_ms,gemm_experiments,communication_ms,communication_experiments,algorithm
0,1,8192,8192,28672,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,8192,28672,35.196939,109.335948,34.061827,126,0.795417,126,torch_dist_all_gather
1,4,8192,1024,14336,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,14336,4.14997,231.826401,2.60686,126,1.297372,126,torch_dist_all_gather
2,8,8192,4608,4608,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4608,36864,4.955925,561.57808,3.769611,126,1.062334,126,torch_dist_all_reduce
3,8,8192,4096,1792,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,4096,14336,2.600427,369.967232,1.519383,126,0.979543,126,torch_dist_all_reduce
4,1,8192,8192,28672,True,True,fp32,/work1/amd/muhaawad/git/amd/pdp/iris/slurm_log...,8192,8192,28672,34.759633,110.711487,34.461598,126,0.136722,126,torch_dist_all_reduce


In [110]:
# Throughput per algorithm: one facet column per problem shape, one bar per world size.
#
# Units: the `flops` column is TFLOP/s, not GFLOP/s. In the loaded data, the
# world_size=1 M8192N8192K28672 row ran in 35.197 ms, and 2*M*N*K FLOP / 35.197 ms
# ≈ 109.3e12 FLOP/s, while that row reports flops = 109.34.
#
# The y-axis is pinned to [0, 600] so charts for different algorithms are directly
# comparable. It is only widened if some result exceeds 600, so bars never overflow.
flops_axis_max = max(600, float(df["flops"].max()))

for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()
    # Problem-shape label such as "M8192N4096K14336", built from the M/N/K columns.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Readable collective name from the algorithm id. Fall back to the raw id so an
    # unrecognised algorithm never gets an empty or suffix-only title.
    collectives = [
        label
        for key, label in (("all_gather", "All Gather"), ("all_reduce", "All Reduce"))
        if key in algorithm
    ]
    title = " + ".join(collectives) or algorithm
    if 'torch_dist' in algorithm:
        title += ' (Torch Dist/RCCL)'

    filtered_df = filtered_df.sort_values(by=["M", "N", "K", "world_size"])
    # Facet order follows the numeric M/N/K sort above instead of string order.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("world_size:O", title="World Size"),
        y=alt.Y(
            "flops:Q",
            title="Throughput (TFLOP/s)",
            scale=alt.Scale(domain=[0, flops_axis_max]),
        ),
        color=alt.Color("world_size:N", title="World Size"),
        column=alt.Column("shape:N", title="", sort=shape_order),
        tooltip=["shape", "world_size", "flops"]
    ).properties(
        title=alt.TitleParams(text=title, anchor='middle', fontSize=18, font='Helvetica'),
        height=300,
    ).configure_axisX(
        labelAngle=0
    )
    chart.display()

    # Distinct file stem: other cells save per-algorithm charts as f'{algorithm}.*',
    # which previously overwrote the files written here.
    stem = f'{algorithm}_world_size_scaling'
    chart.save(f'{stem}.svg')
    chart.save(f'{stem}.png', scale_factor=4)
    chart.save(f'{stem}.pdf')
    


In [111]:
# Throughput per algorithm, regrouped: one facet column per world size, one bar per
# problem shape.
#
# Units: the `flops` column is TFLOP/s, not GFLOP/s. In the loaded data, the
# world_size=1 M8192N8192K28672 row ran in 35.197 ms, and 2*M*N*K FLOP / 35.197 ms
# ≈ 109.3e12 FLOP/s, while that row reports flops = 109.34.
#
# The y-axis is pinned to [0, 600] so charts for different algorithms are directly
# comparable. It is only widened if some result exceeds 600, so bars never overflow.
flops_axis_max = max(600, float(df["flops"].max()))

for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()
    # Problem-shape label such as "M8192N4096K14336", built from the M/N/K columns.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Readable collective name from the algorithm id. Fall back to the raw id so an
    # unrecognised algorithm never gets an empty or suffix-only title.
    collectives = [
        label
        for key, label in (("all_gather", "All Gather"), ("all_reduce", "All Reduce"))
        if key in algorithm
    ]
    title = " + ".join(collectives) or algorithm
    if 'torch_dist' in algorithm:
        title += ' (Torch Dist/RCCL)'

    filtered_df = filtered_df.sort_values(by=["M", "N", "K", "world_size"])
    # Bar order follows the numeric M/N/K sort above instead of string order.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("shape:N", title="", sort=shape_order),
        y=alt.Y(
            "flops:Q",
            title="Throughput (TFLOP/s)",
            scale=alt.Scale(domain=[0, flops_axis_max]),
        ),
        # Shape is already on the x-axis, so the colour legend would be redundant.
        color=alt.Color("shape:N", legend=None),
        column=alt.Column("world_size:N", title="World Size"),
        tooltip=["shape", "world_size", "flops"]
    ).properties(
        title=alt.TitleParams(text=title, anchor='middle', fontSize=18, font='Helvetica'),
        height=300,
    ).configure_axisX(
        # Shape labels are long; angle them so neighbouring labels do not collide.
        labelAngle=45
    )

    chart.save(f'{algorithm}.svg')
    chart.save(f'{algorithm}.png', scale_factor=4)
    chart.save(f'{algorithm}.pdf')
    display(chart)
    


In [112]:
# GEMM time per algorithm: one facet column per problem shape, one bar per world size.
# Displayed only; not saved to disk.
for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()
    # Problem-shape label such as "M8192N4096K14336", built from the M/N/K columns.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Readable collective name from the algorithm id. Fall back to the raw id so an
    # unrecognised algorithm never gets an empty or suffix-only title.
    collectives = [
        label
        for key, label in (("all_gather", "All Gather"), ("all_reduce", "All Reduce"))
        if key in algorithm
    ]
    title = " + ".join(collectives) or algorithm
    if 'torch_dist' in algorithm:
        title += ' (Torch Dist/RCCL)'

    filtered_df = filtered_df.sort_values(by=["M", "N", "K", "world_size"])
    # Facet order follows the numeric M/N/K sort above instead of string order.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("world_size:O", title="World Size"),
        y=alt.Y("gemm_ms:Q", title="GEMM Time (ms)"),
        color=alt.Color("world_size:N", title="World Size"),
        column=alt.Column("shape:N", title="", sort=shape_order),
        tooltip=["shape", "world_size", "flops", "communication_ms", "gemm_ms"]
    ).properties(
        title=alt.TitleParams(text=title, anchor='middle', fontSize=18, font='Helvetica'),
        height=300,
    ).configure_axisX(
        labelAngle=0
    )
    chart.display()
    


In [113]:
# Communication time per algorithm: one facet column per problem shape, one bar per
# world size. Displayed only; not saved to disk.
for algorithm in df['algorithm'].unique():
    filtered_df = df[df['algorithm'] == algorithm].copy()
    # Problem-shape label such as "M8192N4096K14336", built from the M/N/K columns.
    filtered_df["shape"] = (
        "M" + filtered_df["M"].astype(str)
        + "N" + filtered_df["N"].astype(str)
        + "K" + filtered_df["K"].astype(str)
    )

    # Readable collective name from the algorithm id. Fall back to the raw id so an
    # unrecognised algorithm never gets an empty or suffix-only title.
    collectives = [
        label
        for key, label in (("all_gather", "All Gather"), ("all_reduce", "All Reduce"))
        if key in algorithm
    ]
    title = " + ".join(collectives) or algorithm
    if 'torch_dist' in algorithm:
        title += ' (Torch Dist/RCCL)'

    filtered_df = filtered_df.sort_values(by=["M", "N", "K", "world_size"])
    # Facet order follows the numeric M/N/K sort above instead of string order.
    shape_order = filtered_df["shape"].unique().tolist()

    chart = alt.Chart(filtered_df).mark_bar().encode(
        x=alt.X("world_size:O", title="World Size"),
        y=alt.Y("communication_ms:Q", title="Communication Time (ms)"),
        color=alt.Color("world_size:N", title="World Size"),
        column=alt.Column("shape:N", title="", sort=shape_order),
        tooltip=["shape", "world_size", "flops", "communication_ms", "gemm_ms"]
    ).properties(
        title=alt.TitleParams(text=title, anchor='middle', fontSize=18, font='Helvetica'),
        height=300,
    ).configure_axisX(
        labelAngle=0
    )
    chart.display()
    
