In [2]:
import json
import os
import sys

sys.path.append('../')
sys.path.append('../../')
sys.path.append('../src')
sys.path.append('../prompts')
sys.path.append('../src/llmperf')

# Run multiple models through the benchmarking process

In [3]:
## chinese_rag_bundle
# results_dir = '../data/bundle_tests/chinese_rag_bundle/switching_time/20251031-163047.211562'
# results_dir = '../data/bundle_tests/chinese_rag_bundle/switching_time/20251031-180601.530151'

## 3.1 8b bundle
results_dir = '../data/bundle_tests/3.1_8b/switching_time/20251107-104802.785450'


# Analyze metrics through models

In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


from benchmarking.utils import read_perf_eval_json_files

## Read the input json file

# Analyze switching time

__Note:__ This analysis will work if a Bundle endpoint is used. Users will be able to test and compare performance metrics for different experts.

In [5]:
import pandas as pd
import re
from typing import Optional

def find_uuid(file_name: str) -> Optional[str]:
    match = re.search(r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}', file_name)
    uuid = None
    if match:
        uuid = match.group()
    else:
        raise ValueError(f"UUID not found in filename {file_name}")
        
    return uuid


# post processing individual request json files
def read_json_files_to_df(directory: str) -> pd.DataFrame:
    data = []

    # Iterate through all files in the directory
    for filename in os.listdir(directory):
        if filename.endswith('individual_responses.json'):
            model_name = '_'.join(filename.split('_')[2:3])
            input_tokens = filename.split('_')[3]
            concurrent_requests = int(filename.split('_')[5])
            file_path = os.path.join(directory, filename)
            # Read the JSON file
            with open(file_path, 'r') as file:
                json_data = json.load(file)
                num_requests = len(json_data)
                # Extract relevant fields from each JSON object and append to the data list
                for item in json_data:
                    if pd.isnull(item['error_code']):
                        data.append(
                            {
                                'start_time': item['start_time'],
                                'end_time': item['end_time'],
                                'server_ttft_s': item['server_ttft_s'],
                                'client_ttft_s': item['client_ttft_s'],
                                'model_name': model_name,
                                'uuid': find_uuid(filename),
                                'input_tokens': input_tokens,
                                'concurrent_requests': concurrent_requests,
                                'filename': filename,
                                'num_requests': num_requests,
                            }
                        )

    # Create a DataFrame from the data list
    df = pd.DataFrame(data)
    return df

# Get the DataFrame
df = read_json_files_to_df(results_dir)
# df_2 = read_json_files_to_df(results_dir_2)

# df = pd.concat([df_1, df_2], ignore_index=True)

# transforming str to date time for sorting
df['start_time'] = pd.to_datetime(df['start_time'])
df = df.sort_values(by=['start_time'])

# transforming back to str for plotting
df['start_time_short'] = df['start_time'].dt.strftime(date_format='%H:%M:%S')
# # Convert datetime to full string with microseconds
df['start_time_str'] = df['start_time'].dt.strftime('%Y-%m-%d %H:%M:%S.%f')
# df.head(10)

In [6]:
from typing import List, Tuple, Optional

def get_grouping_and_batching_info(df: pd.DataFrame) -> Tuple[List[int], List[int], pd.DataFrame]:
    """Generate grouping and batching info from DataFrame and add them as columns."""
    df = df.sort_values('end_time').reset_index(drop=True)
    df['group'] = (df['server_ttft_s'] != df['server_ttft_s'].shift()).cumsum()

    # Count requests per group
    consecutive_counts = (
        df.groupby(['group', 'server_ttft_s'])
        .size()
        .reset_index(name='consecutive_count')
    )

    # Lists at the group level
    requests_grouping = consecutive_counts['consecutive_count'].tolist()
    requests_batching = [1 << (x - 1).bit_length() for x in requests_grouping]

    # Map group-level info back to each row
    group_mapping = consecutive_counts[['group', 'consecutive_count']].set_index('group')['consecutive_count']
    batching_mapping = {grp: 1 << (count - 1).bit_length() for grp, count in group_mapping.items()}

    df['requests_grouping_per_request'] = df['group'].map(group_mapping)
    df['requests_batching_per_request'] = df['group'].map(batching_mapping)

    return requests_grouping, requests_batching, df.drop(columns=['group'])

dict_groupings = {}
dfs_with_batching = []

for filename in os.listdir(results_dir):
    if 'individual_responses' in filename:
        model_finame: str
        in_tok_finame: int
        out_tok_finame: int
        concurrency_finame: Optional[int]
        qps_finame: Optional[float]

        # model_finame, in_tok_finame, out_tok_finame, concurrency_finame, qps_finame = extract_file_info(
        #     filename
        # )
                
        df_file = df[df['filename'] == filename].copy()
        # df_file = df_file[df_file['error_code'].isnull()]
        
        requests_grouping, requests_batching, df_with_batching = get_grouping_and_batching_info(df_file)
        
        dfs_with_batching.append(df_with_batching[['uuid', 'start_time', 'end_time', 'server_ttft_s', 'requests_grouping_per_request', 'requests_batching_per_request']])
        
        dict_groupings[filename] = {
            'requests_grouping': requests_grouping,
            'requests_batching': requests_batching
        }
    
df_groupings = pd.DataFrame.from_dict(dict_groupings).T
df = df.merge(df_groupings, left_on='filename', right_index=True, how='left')
# print(df.sort_values(['filename','end_time'])[['filename','end_time','server_ttft_s','requests_grouping','requests_batching']].head(30))

dfs_with_batching = pd.concat(dfs_with_batching, ignore_index=True)
df = df.merge(dfs_with_batching, on=['uuid', 'start_time', 'end_time', 'server_ttft_s'], how='left')

df.sort_values(['start_time'])[['filename', 'start_time', 'end_time','server_ttft_s','requests_grouping','requests_batching','requests_grouping_per_request','requests_batching_per_request']].head(50)


Unnamed: 0,filename,start_time,end_time,server_ttft_s,requests_grouping,requests_batching,requests_grouping_per_request,requests_batching_per_request
0,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:04.135129,10:48:06.647972,1.270036,[1],[1],1,1
1,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:07.569415,10:48:08.537529,0.094227,"[1, 2]","[1, 2]",1,1
2,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:07.569683,10:48:10.013078,1.333117,"[1, 2]","[1, 2]",2,2
3,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:07.569854,10:48:10.011126,1.333117,"[1, 2]","[1, 2]",2,2
4,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:10.995325,10:48:13.411547,1.49737,"[3, 2]","[4, 2]",3,4
5,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:10.995534,10:48:13.717357,0.162506,"[3, 2]","[4, 2]",2,2
6,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:10.995690,10:48:13.717839,0.162506,"[3, 2]","[4, 2]",2,2
7,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:10.996208,10:48:13.412533,1.49737,"[3, 2]","[4, 2]",3,4
8,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:10.996423,10:48:13.409190,1.49737,"[3, 2]","[4, 2]",3,4
9,synthetic_0_Meta-Llama-3-1-8B-Instruct_3900_10...,2025-11-07 10:48:14.609951,10:48:17.546411,1.792799,"[1, 8]","[1, 8]",8,8


In [7]:
def calculate_switching_time(df: pd.DataFrame) -> pd.DataFrame:
    """
    Calculate switching time per uuid based on max requests_batching_per_request.

    Args:
        df (pd.DataFrame): DataFrame containing per-request info, must include:
            - uuid
            - server_ttft_s
            - requests_batching_per_request

    Returns:
        pd.DataFrame: DataFrame with columns:
            uuid, max_requests_batching_per_request, switching_time
    """
    results = []

    for uuid, group in df.groupby('uuid'):
        
        group = group.sort_values(by='start_time')
        
        # Get max batching for this uuid
        max_batching = group['requests_batching_per_request'].max()

        # Filter only rows with that max batching value
        max_batch_rows = group[group['requests_batching_per_request'] == max_batching]

        # Calculate switching time
        highest_ttft = max_batch_rows['server_ttft_s'].max()
        lowest_ttft = max_batch_rows['server_ttft_s'].min()
        switching_time = highest_ttft - lowest_ttft

        results.append({
            'uuid': uuid,
            'model': group['model_name'].iloc[0],
            'start_time': group['start_time'].iloc[0],
            'input_tokens': group['input_tokens'].iloc[0],
            'num_requests': group['num_requests'].iloc[0],
            'concurrent_requests': group['concurrent_requests'].iloc[0],
            'max_requests_batching_per_request': max_batching,
            'switching_time': switching_time
        })

    return pd.DataFrame(results)

# Now calculate switching times per uuid
df_switching = calculate_switching_time(df)
df_switching.sort_values(by='start_time').head(50)

Unnamed: 0,uuid,model,start_time,input_tokens,num_requests,concurrent_requests,max_requests_batching_per_request,switching_time
65,b2086129-572b-466f-85d2-a9960c5938f4,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:04.135129,3900,1,1,1,0.0
75,d21809b7-7f81-4185-91cc-dbb4e62a37cc,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:07.569415,3900,3,3,2,0.0
71,c7b1ff1f-3728-465c-8b5e-21d92818145b,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:10.995325,3900,5,5,4,0.0
6,191ddf48-590d-4cdb-837f-37afd08f8445,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:14.609951,3900,9,9,8,0.0
32,5f16410c-1bd8-4993-ae52-748afc308e3d,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:18.426996,3900,16,16,16,0.0
27,582cb8ee-8be9-408d-b873-6f63ee5985af,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:24.614671,8000,1,1,1,0.0
44,76c334b3-329b-46bc-9e6d-bbd6c1d8d3ac,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:27.682627,8000,3,3,4,0.0
69,c48211ce-02ad-4b56-b169-e5311ac2bff1,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:31.343475,8000,5,5,4,0.0
82,eeaaaaaa-8ac3-4864-842a-1234e1368ede,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:35.587984,8000,9,9,8,0.0
1,06ea17ae-d4a0-4ccf-bde5-d1b0b97308f6,Meta-Llama-3-1-8B-Instruct,2025-11-07 10:48:41.732494,8000,16,16,16,0.0


In [12]:
import pandas as pd
import plotly.express as px
import plotly.io as pio

model_family = {
    # 'Llama-8b': [
    #     'Meta-Llama-3-1-8B-Instruct',
    #     'Hermes-3-Llama-3-1-8B',
    #     'LLaMa3-1-8B-Legal-ThaiCCL-Combine',
    #     'Llama-3-1-8B-Magpie-Align-SFT-v0-2',
    #     'Llama-3-1-EIRAI-8B',
    #     'Llama-3-1-OpenScholar-8B',
    #     'natsumura-assistant-1-0-llama-3-1-8b',
    #     'llama-3-1-tulu-2-8b',
    #     'narrativAIV2',
    #     'Llama-3-1-Storm-8B'
    # ],
    # 'Qwen-32b': [
    #     'Qwen3-32B',
    #     'DMind-1',
    #     'DeepSWE-Preview',
    #     'Gazal-R1-32B-sft-merged-preview',
    #     'INFIndo-Qwen3-32B-Preview',
    #     'Qwen3-32B-abliterated',
    #     'Qwen3-32B-bf16',
    #     'Qwen3-32B-Copy',
    #     'Smoothie-Qwen3-32B',
    #     'UIGEN-T3-32B-Preview'
    # ],
    # 'Qwen-14b': [
    #     'Qwen2-5-14B',
    #     '14B-Qwen2-5-Freya-x1',
    #     'EVA-Qwen2-5-14B-v0-0',
    #     'General-Reasoner-Qwen2-5-14B',
    #     'Qwen2-5-Coder-14B',
    #     'SauerkrautLM-v2-14b-SFT',
    #     'SuperNova-Medius',
    #     'Tsunami-1-0-14B-Instruct',
    #     'Virtuoso-Small',
    #     'legml-v0-1'
    # ]
    'Chinese-rag': [
        'Qwen3-32B',    
        'DeepSeek-V3-1-Terminus',  
    ],
    '3.d1_8b': [
        'Meta-Llama-3-1-8B-Instruct',    
    ],
}

# Create config column for every row
df['config'] = df[['model_name', 'input_tokens', 'concurrent_requests']] \
    .apply(lambda x: f'{x[0]}-{x[1]}-{x[2]}', axis=1)

df = df.sort_values(by='start_time_str').reset_index(drop=True)
uuids_per_family = {}

for family_name, members in model_family.items():
    df_family = df[df['model_name'].isin(members)].copy().reset_index(drop=True)
    if df_family.empty:
        print(f"⚠️ No data for {family_name}, skipping...")
        continue

    # Identify switch point for this family
    switch_idx = df_family[(df_family['num_requests'] == 5) &
                           (df_family['concurrent_requests'] == 1)].index.min()
    
    # Filter df_family with records that start with switch_idx
    if switch_idx != -1:
        df_family_filtered = df_family[df_family.index >= switch_idx] 
    uuids_per_family[family_name] = df_family_filtered['uuid'].unique()
    
    if pd.isna(switch_idx):
        print(f"⚠️ No switch point found for {family_name}, skipping...")
        continue

    # Add numeric index for plotting
    df_family['time_idx'] = range(len(df_family))

    # Create scatter plot with config shown in hover for all points
    fig = px.scatter(
        df_family,
        x="time_idx",
        y="server_ttft_s",
        color="config",
        hover_data=[
            "model_name",
            "input_tokens",
            "concurrent_requests",
            "start_time_str",
            "uuid",
            "config",  # show for every point
            "requests_grouping_per_request",
            "requests_batching_per_request",
        ],
        title=f"Scatter Plot of Server TTFT Over Time — {family_name}",
        labels={"time_idx": "Time", "server_ttft_s": "Server TTFT (s)"}
    )

    # Warm-up zone
    fig.add_vrect(
        x0=0, x1=switch_idx - df_family.index.min(),
        fillcolor="orange", opacity=0.1,
        annotation_text="Warm up zone", annotation_position="top left"
    )

    # Switching time zone
    fig.add_vrect(
        x0=switch_idx - df_family.index.min(), x1=len(df_family) - 1,
        fillcolor="blue", opacity=0.05,
        annotation_text="Switching time zone", annotation_position="top left"
    )

    # Custom hovertemplate to display config nicely
    fig.update_traces(
        hovertemplate=(
            "<b>Model config: %{customdata[5]}</b><extra></extra><br>"
            "Server TTFT (s): %{y}<br>"
            "Model: %{customdata[0]}<br>"
            "Input Tokens: %{customdata[1]}<br>"
            "Grouping: %{customdata[6]}<br>"
            "Batching: %{customdata[7]}<br>"
            "Start Time: %{customdata[3]}<br>"
        )
    )

    # Custom x-axis with timestamps
    fig.update_xaxes(
        tickvals=df_family['time_idx'][::max(1, len(df_family) // 20)],
        ticktext=df_family['start_time_str'][::max(1, len(df_family) // 20)],
        tickangle=90
    )

    # Layout settings
    fig.update_layout(
        yaxis_autorange=True,
        legend_title="Model Name + SS + Batch size",
        legend=dict(x=1.02, y=1, bgcolor="rgba(255,255,255,0.8)"),
        margin=dict(l=40, r=200, t=60, b=120),
        hovermode="closest"
    )

    # Show in browser and save
    fig.show(renderer="browser")
    file_name = f"scatter_plot_{family_name}.html"
    pio.write_html(fig, file_name)
    print(f"✅ Saved: {file_name}")


⚠️ No data for Chinese-rag, skipping...
✅ Saved: scatter_plot_3.d1_8b.html


In [9]:
parent_dir = os.path.dirname(results_dir)
current_run = results_dir.split('/')[-1]
consolidated_results_path = os.path.join(parent_dir,f'consolidated_results/{current_run}.xlsx')


df_consolidated_results = pd.read_excel(consolidated_results_path)
df_consolidated_results['model'] = df_consolidated_results.model.str.replace('_', '-')

In [9]:
df_consolidated_with_switching_time = df_consolidated_results.merge(df_switching[['uuid','switching_time']], on='uuid', how='left')

In [11]:
df_consolidated_with_switching_time.head(50)

Unnamed: 0,uuid,name,model,num_input_tokens,num_output_tokens,num_concurrent_requests,server_ttft_s_min,server_ttft_s_mean,server_ttft_s_p50,server_ttft_s_max,...,num_requests_started,num_completed_requests,num_completed_requests_per_min,number_errors,error_code_frequency,requests_grouping,requests_batching,request_batching_frequencies,representative_batch_size,switching_time
0,4631ede2-adc5-4b86-bba4-24c86cff9144,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,78.7198,78.8826,78.7569,79.4029,...,5,5,0.7307,0,{},"[1, 1, 1, 1, 1]","[1, 1, 1, 1, 1]",{1: 5},1,0.68313
1,e086559b-f3dd-4340-9a4f-ad86f70fa541,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.2918,79.2918,79.2918,79.2918,...,1,1,0.7292,0,{},[1],[1],{1: 1},1,0.0
2,3e80ae0a-29db-4cc4-8447-f945fe023519,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.3102,79.3102,79.3102,79.3102,...,1,1,0.7289,0,{},[1],[1],{1: 1},1,0.0
3,c1ad3eba-0ea6-4e3f-bf45-31cce57b91f8,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.3196,79.3196,79.3196,79.3196,...,1,1,0.7292,0,{},[1],[1],{1: 1},1,0.0
4,b2367e9f-e819-4656-9849-1d2005ff065c,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.3447,79.3447,79.3447,79.3447,...,1,1,0.7281,0,{},[1],[1],{1: 1},1,0.0
5,1eead728-88db-49d9-8be9-84bda01ce626,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.3811,79.3811,79.3811,79.3811,...,1,1,0.7273,0,{},[1],[1],{1: 1},1,0.0
6,9e40b953-9657-47d3-883b-f0c7a3ce7459,synthetic_0_DeepSeek-V3-1-Terminus_128000_100_...,DeepSeek-V3-1-Terminus,128000,100,1,79.2899,79.2899,79.2899,79.2899,...,1,1,0.7278,0,{},[1],[1],{1: 1},1,0.0
7,fdf09d34-850e-4161-8e02-5f18cc02b544,synthetic_0_Qwen3-32B_8000_100_1_stream_fdf09d...,Qwen3-32B,8000,100,1,0.4261,0.4375,0.4285,0.4758,...,5,5,38.5199,0,{},"[1, 1, 1, 1, 1]","[1, 1, 1, 1, 1]",{1: 5},1,0.049667
8,118729eb-3b39-4f29-bfad-ca1d03a9376b,synthetic_0_Qwen3-32B_8000_100_1_stream_118729...,Qwen3-32B,8000,100,1,0.4732,0.4732,0.4732,0.4732,...,1,1,29.2774,0,{},[1],[1],{1: 1},1,0.0
9,e893c619-46da-430e-a691-5b607d5e43d3,synthetic_0_Qwen3-32B_8000_100_5_stream_e893c6...,Qwen3-32B,8000,100,5,0.4257,1.2775,1.4748,1.5332,...,25,25,100.5136,0,{},"[4, 1, 4, 1, 4, 1, 4, 1, 4, 1]","[4, 1, 4, 1, 4, 1, 4, 1, 4, 1]","{4: 20, 1: 5}",4,0.058704


In [78]:
# plots per model family and batch sizes

import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import plotly.io as pio

# --- Create combined model-num_input_tokens column ---
df_consolidated_with_switching_time["model_token"] = (
    df_consolidated_with_switching_time["model"] + "-" + df_consolidated_with_switching_time["num_input_tokens"].astype(str)
)

# --- Sort by num_input_tokens then model ---
df_consolidated_with_switching_time = df_consolidated_with_switching_time.sort_values(by=["num_input_tokens", "model"])

# --- Create figures per model family ---
for family_name, models in model_family.items():
    df_family = df_consolidated_with_switching_time[df_consolidated_with_switching_time["model"].isin(models)]
    df_family = df_family[df_family.uuid.isin(uuids_per_family[family_name])]

    num_concurrents = sorted(df_family["num_concurrent_requests"].unique())
    fig = sp.make_subplots(
        rows=len(num_concurrents),
        cols=3,  # <-- 3 columns now
        shared_xaxes=False,
        vertical_spacing=0.20,
        subplot_titles=[
            f"Switching Time - {family_name} (Batch Size={ncr})" if i % 3 == 0 else
            f"TTFT - {family_name} (Batch Size={ncr})" if i % 3 == 1 else
            f"Output tok/s - {family_name} (Batch Size={ncr})"
            for ncr in num_concurrents for i in range(3)
        ]
    )

    for row_idx, ncr in enumerate(num_concurrents, start=1):
        df_row = df_family[df_family["num_concurrent_requests"] == ncr]

        # Col 1: Switching Time
        fig1 = px.line(
            df_row,
            x="model_token",
            y="switching_time",
            color="num_input_tokens",
            markers=True
        )
        for trace in fig1.data:
            fig.add_trace(trace, row=row_idx, col=1)

        # Col 2: TTFT
        fig2 = px.line(
            df_row,
            x="model_token",
            y="server_ttft_s_p50",
            color="num_input_tokens",
            markers=True
        )
        for trace in fig2.data:
            fig.add_trace(trace, row=row_idx, col=2)

        # Col 3: Output tok/s
        fig3 = px.line(
            df_row,
            x="model_token",
            y="server_output_token_per_s_p50",
            color="num_input_tokens",
            markers=True
        )
        for trace in fig3.data:
            fig.add_trace(trace, row=row_idx, col=3)

    # Rotate all x labels, add standoff, and prevent clipping
    fig.update_xaxes(
        tickangle=60,
        automargin=True,
        title_standoff=20
    )

    # Legend on the left, vertical
    fig.update_layout(
        height=500 * len(num_concurrents),
        width=2000,  # wider to fit 3 columns
        title_text=f"Performance Plots for {family_name}",
        showlegend=True,
        legend=dict(
            orientation="v",
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=-0.15
        ),
        margin=dict(t=120, b=180, l=150, r=40)
    )

    fig.show(renderer="browser")
    file_name = f"{family_name}_performance_plots.html"
    pio.write_html(fig, file_name)
    print(f"✅ Saved: {file_name}")


✅ Saved: Chinese-rag_performance_plots.html


In [40]:
# grouped plots for all families in one figure

import pandas as pd
import plotly.graph_objects as go
import plotly.subplots as sp
import plotly.io as pio

# --- Create combined model-num_input_tokens column ---
df_consolidated_with_switching_time["model_token"] = (
    df_consolidated_with_switching_time["model"] + "-" +
    df_consolidated_with_switching_time["num_input_tokens"].astype(str)
)

# --- Sort by num_input_tokens then model ---
df_consolidated_with_switching_time = df_consolidated_with_switching_time.sort_values(
    by=["num_input_tokens", "model"]
)

# --- Metrics mapping for easy loop ---
metrics = {
    "switching_time": "Switching Time",
    "server_ttft_s_p50": "TTFT",
    "server_output_token_per_s_p50": "Output tok/s"
}

# --- Create subplot grid: 3 rows (families) × 3 cols (metrics) ---
fig = sp.make_subplots(
    rows=len(model_family),
    cols=3,
    subplot_titles=[
        f"{metric_name} - {family_name}"
        for family_name in model_family.keys()
        for metric_name in metrics.values()
    ],
    vertical_spacing=0.15,   # more space between rows
    horizontal_spacing=0.08
)

# --- Loop over families and metrics ---
for row_idx, (family_name, models) in enumerate(model_family.items(), start=1):
    df_family = df_consolidated_with_switching_time[
        df_consolidated_with_switching_time["model"].isin(models)
    ]
    df_family = df_family[df_family.uuid.isin(uuids_per_family[family_name])]

    # Group per model_token + num_input_tokens
    grouped = df_family.groupby(["model_token", "num_input_tokens"])

    # Compute median, min, max per metric across num_concurrent_requests
    agg_df = grouped.agg({
        "switching_time": ["median", "min", "max"],
        "server_ttft_s_p50": ["median", "min", "max"],
        "server_output_token_per_s_p50": ["median", "min", "max"]
    }).reset_index()

    # Flatten column names
    agg_df.columns = [
        "model_token", "num_input_tokens",
        "switching_time_median", "switching_time_min", "switching_time_max",
        "ttft_median", "ttft_min", "ttft_max",
        "outtok_median", "outtok_min", "outtok_max"
    ]

    # Map internal to column names
    metric_map = {
        "switching_time": ("switching_time_median", "switching_time_min", "switching_time_max"),
        "server_ttft_s_p50": ("ttft_median", "ttft_min", "ttft_max"),
        "server_output_token_per_s_p50": ("outtok_median", "outtok_min", "outtok_max")
    }

    for col_idx, (metric, _) in enumerate(metrics.items(), start=1):
        median_col, min_col, max_col = metric_map[metric]

        for token_size, df_token in agg_df.groupby("num_input_tokens"):
            legend_name = f"{family_name}-{token_size}"

            # Add filled area between min and max
            fig.add_trace(
                go.Scatter(
                    x=df_token["model_token"],
                    y=df_token[max_col],
                    mode="lines",
                    line=dict(width=0),
                    hoverinfo="skip",
                    name=legend_name,
                    legendgroup=str(token_size),
                ),
                row=row_idx, col=col_idx
            )
            fig.add_trace(
                go.Scatter(
                    x=df_token["model_token"],
                    y=df_token[min_col],
                    mode="lines",
                    line=dict(width=0),
                    fill="tonexty",
                    name=legend_name,
                    legendgroup=str(token_size),
                    hoverinfo="skip"
                ),
                row=row_idx, col=col_idx
            )

            # Add median line with markers
            fig.add_trace(
                go.Scatter(
                    x=df_token["model_token"],
                    y=df_token[median_col],
                    mode="lines+markers",
                    name=legend_name,
                    legendgroup=str(token_size),
                ),
                row=row_idx, col=col_idx
            )

# --- Layout tweaks ---
fig.update_xaxes(
    tickangle=60,
    automargin=True,
    title_standoff=30  # push axis title further from ticks
)
fig.update_layout(
    height=750 * len(model_family),   # more height per row
    width=2100,
    title_text="Performance Comparison Across Model Families",
    showlegend=True,
    legend=dict(
        orientation="v",
        yanchor="middle",
        y=0.5,
        xanchor="left",
        x=-0.12
    ),
    margin=dict(t=120, b=250, l=80, r=40)  # extra bottom margin
)

fig.show(renderer="browser")
pio.write_html(fig, "all_families_performance_plots.html")
print("✅ Saved: all_families_performance_plots.html")


✅ Saved: all_families_performance_plots.html
