In [27]:
import json
import pandas as pd
import altair as alt

alt.data_transformers.enable("vegafusion")

DataTransformerRegistry.enable('vegafusion')

In [28]:
# Load data from JSON file
file_path = "../stream-k/gemm_tiles_one_shot_trace_rank0.json"
with open(file_path, 'r') as file:
    data = json.load(file)

# Prepare data for plotting
max_tile_id = max(cell["tile_id"] for cell in data)
print(f"JSON file contains {max_tile_id} tiles")
max_tile_id=64

all_timeline_data = []

for tile in data:
    tile_id = tile['tile_id']
    if tile_id > max_tile_id:
        continue
    categories = ['gemm', 'comm', 'poll', 'op']
    
    for category in categories:
        all_timeline_data.append({
            'Tile ID': str(tile_id),  # Convert to string for categorical sorting
            'Category': category,
            'Tile-Category': f"Tile {tile_id} - {category}",  # Unique row per category
            'Start': tile[f'{category}_begin'],
            'End': tile[f'{category}_end'],
            'Duration': tile[f'{category}_end'] - tile[f'{category}_begin']
        })

# Convert to DataFrame
df = pd.DataFrame(all_timeline_data)

JSON file contains 4095 tiles


In [29]:
def plot_timeline(df):
    # Estimate height per tile-category row
    tile_category_count = df['Tile-Category'].nunique()
    chart_height = max(400, tile_category_count * 15)

    # Create Altair chart with distinct y-axis entries per tile-category
    chart = alt.Chart(df).mark_bar().encode(
        x=alt.X('Start:Q', title='Time', scale=alt.Scale(nice=True)),
        x2='End:Q',
        y=alt.Y('Tile-Category:N', title='Tile & Category', sort=[]),
        color=alt.Color('Category:N', scale=alt.Scale(domain=['gemm', 'comm', 'poll', 'op'], range=['red', 'purple', 'green', 'blue'])),
        tooltip=['Tile ID', 'Category', 'Start', 'End', 'Duration']
    ).properties(
        width=900,
        height=chart_height
    ).interactive()

    # Enable scrolling
    chart = chart.configure_view(
        continuousHeight=400,
        strokeWidth=0
    )

    # Show the chart
    chart.show()




In [30]:
plot_timeline(df)