In [132]:
from llm_execution_time_predictor.train_utils import (
    build_stage_features,
    train_linear_predictor,
    train_tree_predictor,
)
import pandas as pd
import json

In [133]:
import pandas as pd
import os
# Load every profiling trace for the selected model / tensor-parallel config.
# The directory has held files such as decode/prefill profiling runs, prefix-cache
# prefill profiles and workload traces (e.g. splitwise_code_rps_*.jsonl).
model_name = "Qwen_Qwen3_4B_TP_1"
PROFILE_DIR = os.path.join("profile_output", model_name)

# Sorted so the concatenation order (and hence any seeded train/val split later)
# does not depend on the filesystem's listing order; only JSONL traces are read.
files = sorted(f for f in os.listdir(PROFILE_DIR) if f.endswith(".jsonl"))
files_folder = [os.path.join(PROFILE_DIR, f) for f in files]
dfs = [pd.read_json(f, lines=True) for f in files_folder]

# Bookkeeping columns that are not used as features or targets are dropped.
combined_df = pd.concat(dfs, ignore_index=True).drop(
    columns=["throughput", "timestamp", "process_id"]
)

In [134]:
# Persist the merged raw profile to CSV (no index column).
combined_df.to_csv(path_or_buf="combined_profile_results.csv", index=False)

In [135]:
# Quick look at the merged profile; each row appears to be one profiled forward
# step (see forward_mode / latency). pandas truncates the middle of the display.
combined_df

Unnamed: 0,batch_size,total_token_length,skew,combined_seq_lens,cached_prefix_lens,new_extend_lens,total_extend_len,latency,forward_mode,cache_percent,chunked
0,1,1,0.0,[1],[0],[1],1.0,0.012013,prefill,,
1,1,1,0.5,[1],[0],[1],1.0,0.011156,prefill,,
2,1,1,1.0,[1],[0],[1],1.0,0.010913,prefill,,
3,1,1,1.5,[1],[0],[1],1.0,0.010928,prefill,,
4,1,2,0.0,[2],[0],[2],2.0,0.011433,prefill,,
...,...,...,...,...,...,...,...,...,...,...,...
11839,1,1735,,[1735],[1734],[1],,0.008086,decode,,
11840,1,1736,,[1736],[1735],[1],,0.008114,decode,,
11841,1,1737,,[1737],[1736],[1],,0.008091,decode,,
11842,1,1738,,[1738],[1737],[1],,0.008102,decode,,


In [136]:
import numpy as np

# Rows missing total_extend_len get it recomputed from the per-request extend lengths.
extend_len_totals = combined_df["new_extend_lens"].apply(sum)
combined_df["total_extend_len"] = combined_df["total_extend_len"].fillna(extend_len_totals)

# Mean combined sequence length per row; any non-list cell maps to 0.
combined_df["input_len"] = combined_df["combined_seq_lens"].apply(
    lambda seq_lens: 0 if not isinstance(seq_lens, list) else np.mean(seq_lens)
)

# Stage-specific views consumed by the mean-length baseline models below.
forward_modes = combined_df["forward_mode"]
prefill_df = combined_df[forward_modes == "prefill"]
decode_df = combined_df[forward_modes == "decode"]

In [137]:
import numpy as np
import pandas as pd

def _safe_list(x):
    """Guard against NaN or scalar entries that sneak in."""
    return x if isinstance(x, (list, tuple, np.ndarray)) else [x]

def build_stage_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Feature-engineer latency predictors for *both* prefill & decode rows.
    The function assumes the raw frame still contains list-columns:
      - combined_seq_lens
      - cached_prefix_lens
      - new_extend_lens
    and extra scalar columns such as batch_size, latency, skew, cache_percent.
    """
    df = df.copy()

    # ────────────────────────────────────────────────────────
    # 1.  Sequence-length distribution stats
    # ────────────────────────────────────────────────────────
    df["len_max"] = df["combined_seq_lens"].apply(lambda x: np.max(_safe_list(x)))
    df["len_min"] = df["combined_seq_lens"].apply(lambda x: np.min(_safe_list(x)))
    df["len_std"] = df["combined_seq_lens"].apply(lambda x: np.std(_safe_list(x)))
    df["len_p90"] = df["combined_seq_lens"].apply(
        lambda x: np.percentile(_safe_list(x), 90)
    )
    df["len_p95"] = df["combined_seq_lens"].apply(
        lambda x: np.percentile(_safe_list(x), 95)
    )

    # ────────────────────────────────────────────────────────
    # 2.  Cached-prefix stats
    # ────────────────────────────────────────────────────────
    df["cached_sum"] = df["cached_prefix_lens"].apply(
        lambda x: np.sum(_safe_list(x))
    )
    df["cached_max"] = df["cached_prefix_lens"].apply(
        lambda x: np.max(_safe_list(x))
    )
    df["cached_ratio"] = df["cached_sum"] / df["total_token_length"].clip(lower=1)

    # ────────────────────────────────────────────────────────
    # 3.  Extension-stats (per call)
    # ────────────────────────────────────────────────────────
    df["extend_sum"] = df["new_extend_lens"].apply(lambda x: np.sum(_safe_list(x)))
    df["extend_max"] = df["new_extend_lens"].apply(lambda x: np.max(_safe_list(x)))
    df["extend_mean"] = df["new_extend_lens"].apply(lambda x: np.mean(_safe_list(x)))
    df["extend_std"] = df["new_extend_lens"].apply(lambda x: np.std(_safe_list(x)))
    df["extend_p90"] = df["new_extend_lens"].apply(
        lambda x: np.percentile(_safe_list(x), 90)
    )

    # ────────────────────────────────────────────────────────
    # 4.  Skew / imbalance and memory-pressure proxies
    # ────────────────────────────────────────────────────────
    df["imbalance"] = df["len_max"] / df["len_min"].replace(0, np.nan)
    df["cache_percent"] = df.get("cache_percent", np.nan)  # may already exist

    # ────────────────────────────────────────────────────────
    # 5.  Stage flag
    # ────────────────────────────────────────────────────────
    df["is_prefill"] = (df["forward_mode"] == "prefill").astype(int)

    # ────────────────────────────────────────────────────────
    # 6.  “Classic” cost proxies (now using len_max instead of avg)
    # ────────────────────────────────────────────────────────
    # ATTENTION-FLOPs proxy: O(batch * len_max²) for prefill, O(batch * len_max) for decode
    df["prod_ext_ctx"] = np.where(
        df["is_prefill"] == 1,
        df["batch_size"] * (df["len_max"] ** 2),
        df["batch_size"] * df["len_max"],
    )

    # Tokens newly processed this step
    df["num_new_tokens"] = np.where(
        df["is_prefill"] == 1,
        df["extend_sum"],       # sum of prompt tokens
        df["batch_size"],       # one per sequence in decode
    )

    # Total context tokens “live” during this step
    df["num_context_tokens"] = df["batch_size"] * df["len_max"]

    # ────────────────────────────────────────────────────────
    # 7.  Target
    # ────────────────────────────────────────────────────────
    df["time"] = df["latency"]

    # ------- EXTRA SEQ / EXTENSION STATS -------
    df["len_mean"]   = df["combined_seq_lens"].apply(lambda x: np.mean(_lst(x)))
    df["len_median"] = df["combined_seq_lens"].apply(lambda x: np.median(_lst(x)))
    df["len_range"]  = df["len_max"] - df["len_min"]
    df["len_p99"]    = df["combined_seq_lens"].apply(lambda x: np.percentile(_lst(x), 99))
    df["len_cv"]     = df["len_std"] / df["len_mean"].clip(lower=1)

    df["extend_min"]   = df["new_extend_lens"].apply(lambda x: np.min(_lst(x)))
    df["extend_median"]= df["new_extend_lens"].apply(lambda x: np.median(_lst(x)))
    df["extend_p99"]   = df["new_extend_lens"].apply(lambda x: np.percentile(_lst(x), 99))
    df["extend_cv"]    = df["extend_std"] / df["extend_mean"].clip(lower=1)

    # ------- RATIOS & INTERACTIONS -------
    df["prompt_ratio"]     = df["extend_sum"] / df["total_token_length"].clip(lower=1)
    df["cached_peak_ratio"]= df["cached_max"] / df["len_max"].clip(lower=1)
    df["B_len_mean"]       = df["batch_size"] * df["len_mean"]
    df["B_len_max_sq"]     = df["batch_size"] * (df["len_max"] ** 2)
    df["cache_len_prod"]   = df["cache_percent"] * df["len_max"]

    # ------- LOG-SPACE -------
    for col in ["len_max", "prod_ext_ctx", "num_context_tokens"]:
        df[f"log_{col}"] = np.log1p(df[col])


    # ────────────────────────────────────────────────────────
    # 8.  Select final columns
    # ────────────────────────────────────────────────────────
    feature_cols = [
        # token & attention proxies
        "num_new_tokens", "prod_ext_ctx", "num_context_tokens",
        # sequence-distribution
        "len_max", "len_min", "len_std", "len_p90", "len_p95",
        # cache stats
        "cached_sum", "cached_max", "cached_ratio",
        # extension stats
        "extend_max", "extend_mean", "extend_std", "extend_p90",
        # imbalance / batch
        "batch_size", "imbalance", "skew",
        # memory pressure
        "cache_percent",
        # stage
        "is_prefill",
        # sequence shape
        "len_mean","len_median","len_range","len_p99","len_cv",
        # extension shape
        "extend_min","extend_median","extend_p99","extend_cv",
        # ratios & interactions
        "prompt_ratio","cached_peak_ratio",
        "B_len_mean","B_len_max_sq","cache_len_prod",
        # log space
        "log_len_max","log_prod_ext_ctx","log_num_context_tokens",
    ]

    # Keep any hardware-specific knobs if present (they’re cheap to one-hot later)
#    hw_cols = [c for c in ("gpu_name", "num_gpu", "dtype", "flash_attn_flag") if c in df]
    return df[feature_cols]


In [138]:
from sklearn.linear_model import LinearRegression
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor
import pandas as pd
import json
from typing import List, cast


def preprocess_input_for_prediction(
    batch_size, avg_context_len, gpu, mode="prefill"
) -> List[float]:
    """
    Build the 4-element feature vector used by the mean-length baseline models.

    The order matches the training columns:
    [num_new_tokens, prod_ext_ctx, num_context_tokens, batch_size].

    ``mode == "prefill"`` treats every context token as new and uses a quadratic
    attention proxy; any other mode is treated as decode (one new token per
    request, linear attention proxy). ``gpu`` is accepted but currently unused.
    """
    context_tokens = avg_context_len * batch_size
    if mode != "prefill":
        return [batch_size, batch_size * avg_context_len, context_tokens, batch_size]
    return [
        batch_size * avg_context_len,
        batch_size * (avg_context_len**2),
        context_tokens,
        batch_size,
    ]


def build_stage_features(df: pd.DataFrame, stage: str) -> pd.DataFrame:
    """
    Build the four-feature training frame (plus target) for one inference stage.

    Only the mean sequence length ``input_len`` is used to describe a batch.
    TODO: use the actual per-request sequence lengths instead of the mean.

    Returns
    -------
    pd.DataFrame
        Columns, in order:
        - num_new_tokens: tokens processed this step (all prompt tokens for
          prefill, one per request for decode)
        - prod_ext_ctx: attention-cost proxy (B·L² for prefill, B·L for decode)
        - num_context_tokens: B·L, total context held in the batch
        - batch_size: degree of parallelism
        - time: latency target, copied from ``latency``

    Raises
    ------
    ValueError
        If ``stage`` is neither "prefill" nor "decode".
    """
    frame = df.copy()
    if stage not in ("prefill", "decode"):
        raise ValueError("stage must be either 'prefill' or 'decode'")

    bs = frame["batch_size"]
    ctx = frame["input_len"]
    prefill = stage == "prefill"

    frame["num_new_tokens"] = bs * ctx if prefill else bs
    frame["prod_ext_ctx"] = bs * (ctx ** 2) if prefill else bs * ctx
    frame["num_context_tokens"] = bs * ctx
    frame["time"] = frame["latency"]

    columns = ["num_new_tokens", "prod_ext_ctx", "num_context_tokens", "batch_size", "time"]
    return frame[columns]


def train_linear_predictor(train_df: pd.DataFrame, name):
    """
    Fit an ordinary-least-squares latency model on the four engineered features.

    Inputs are num_new_tokens, prod_ext_ctx, num_context_tokens and batch_size
    (cast to float32); the target is ``time`` in seconds. Metrics are computed
    on the training rows themselves (no hold-out) and printed in milliseconds.

    Returns the fitted ``LinearRegression``.
    """
    feature_cols = ["num_new_tokens", "prod_ext_ctx", "num_context_tokens", "batch_size"]
    features = train_df[feature_cols].to_numpy(dtype=np.float32)
    target = train_df["time"].to_numpy(dtype=np.float32)

    model = LinearRegression()
    model.fit(features, target)
    fitted = model.predict(features)

    print(f"Linear Regression: {name}")
    rmse_ms = np.sqrt(mean_squared_error(target, fitted)) * 1000
    print(f"Train RMSE: {rmse_ms:.2f}ms")
    mae_ms = mean_absolute_error(target, fitted) * 1000
    print(f"Train MAE: {mae_ms:.2f}ms")
    r2 = r2_score(target, fitted)
    print(f"Train R2: {r2:.4f}")
    return model


def train_tree_predictor(train_df: pd.DataFrame, name):
    """
    Fit a small random-forest latency model on the four engineered features.

    Inputs are num_new_tokens, prod_ext_ctx, num_context_tokens and batch_size
    (cast to float32); the target is ``time`` in seconds. Metrics are computed on
    the training rows themselves (no hold-out) and printed in milliseconds.

    Parameters
    ----------
    train_df : pd.DataFrame
        Frame containing the four feature columns above and ``time``.
    name : str
        Label printed alongside the metrics (e.g. "prefill").

    Returns
    -------
    RandomForestRegressor
        The fitted model.
    """
    # Extract features and target
    X_train = train_df[
        ["num_new_tokens", "prod_ext_ctx", "num_context_tokens", "batch_size"]
    ].to_numpy(dtype=np.float32)
    y_train = train_df["time"].to_numpy(dtype=np.float32)

    # 10-tree random forest; seeded so repeated runs produce the same model.
    tree_model = RandomForestRegressor(
        n_estimators=10, random_state=42, min_samples_leaf=2, max_depth=12
    )
    tree_model.fit(X_train, y_train)

    # Predict and evaluate on the training data
    y_pred_tree = tree_model.predict(X_train)

    print(f"Random Forest: {name}")
    print(
        f"Train RMSE: {np.sqrt(mean_squared_error(y_train, y_pred_tree)) * 1000:.2f}ms"
    )
    print(f"Train MAE: {mean_absolute_error(y_train, y_pred_tree) * 1000:.2f}ms")
    print(f"Train R2: {r2_score(y_train, y_pred_tree):.4f}")
    return tree_model


In [139]:
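# NOTE: placeholder cell with no code. There is no 'batch_lens' column in this data; the list-valued columns (combined_seq_lens, cached_prefix_lens, new_extend_lens) are converted to Python lists inside `train_latency_models`.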

In [140]:
# Four-feature (mean input length) training frames, one per stage.
train_df_prefill, train_df_decode = (
    build_stage_features(prefill_df, stage="prefill"),
    build_stage_features(decode_df, stage="decode"),
)

In [141]:
# Linear baseline per stage; the printed metrics are on the training rows only.
lr_model_prefill, lr_model_decode = (
    train_linear_predictor(train_df_prefill, "prefill"),
    train_linear_predictor(train_df_decode, "decode"),
)

Linear Regression: prefill
Train RMSE: 127.58ms
Train MAE: 101.62ms
Train R2: 0.3447
Linear Regression: decode
Train RMSE: 2.33ms
Train MAE: 1.27ms
Train R2: 0.9716


In [142]:
# Random-forest baseline per stage on the same four features.
tr_model_prefill, tr_model_decode = (
    train_tree_predictor(train_df_prefill, "prefill"),
    train_tree_predictor(train_df_decode, "decode"),
)

Decision Tree: prefill
Train RMSE: 56.07ms
Train MAE: 25.04ms
Train R2: 0.8734
Decision Tree: decode
Train RMSE: 1.23ms
Train MAE: 0.12ms
Train R2: 0.9921


In [143]:
import numpy as np
import pandas as pd
from typing import Tuple
from llm_execution_time_predictor.train_utils import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
)
import matplotlib.pyplot as plt
import altair as alt
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from plotly.subplots import make_subplots

# ---------------------------------------------------------------------
# 1.  Stage-aware feature builder
# ---------------------------------------------------------------------
def build_stage_features(df: pd.DataFrame, *, stage: str) -> pd.DataFrame:
    """
    Build the full latency-prediction feature frame for a single stage.

    Parameters
    ----------
    df : pd.DataFrame
        Raw profile rows. Must contain the list-valued columns
        ``combined_seq_lens``, ``cached_prefix_lens`` and ``new_extend_lens``
        (NaN / scalar cells are treated as 1-element lists) and the scalar
        columns ``batch_size``, ``total_token_length``, ``latency`` and ``skew``.
        ``cache_percent`` is optional and filled with NaN when absent.
    stage : {"prefill", "decode"}
        Selects how ``num_new_tokens`` and ``prod_ext_ctx`` are derived.

    Returns
    -------
    pd.DataFrame
        The engineered features plus the ``time`` target (copied from
        ``latency``), followed by whichever of ``gpu_name``, ``num_gpu``,
        ``dtype`` and ``flash_attn_flag`` exist in ``df``.

    Raises
    ------
    ValueError
        If ``stage`` is not "prefill" or "decode".
    """
    if stage not in ("prefill", "decode"):
        raise ValueError("stage must be 'prefill' or 'decode'")

    out = df.copy()

    def per_row(column, reducer):
        # Coerce NaN / scalar cells into a 1-element list, then reduce each row.
        return out[column].apply(
            lambda v: reducer(v if isinstance(v, (list, tuple, np.ndarray)) else [v])
        )

    def pct(q):
        return lambda values: np.percentile(values, q)

    reducers_by_source = {
        "combined_seq_lens": {
            "len_max": np.max, "len_min": np.min, "len_std": np.std,
            "len_p90": pct(90), "len_p95": pct(95),
            "len_mean": np.mean, "len_median": np.median, "len_p99": pct(99),
        },
        "cached_prefix_lens": {
            "cached_sum": np.sum, "cached_max": np.max,
        },
        "new_extend_lens": {
            "extend_sum": np.sum, "extend_max": np.max, "extend_mean": np.mean,
            "extend_std": np.std, "extend_p90": pct(90),
            "extend_min": np.min, "extend_median": np.median, "extend_p99": pct(99),
        },
    }
    for source, reducers in reducers_by_source.items():
        for feature, reducer in reducers.items():
            out[feature] = per_row(source, reducer)

    batch = out["batch_size"]
    len_max = out["len_max"]

    out["cached_ratio"] = out["cached_sum"] / out["total_token_length"].clip(lower=1)
    out["imbalance"] = len_max / out["len_min"].replace(0, np.nan)
    out["cache_percent"] = out.get("cache_percent", np.nan)

    if stage == "decode":
        out["num_new_tokens"] = batch                 # one new token per sequence
        out["prod_ext_ctx"] = batch * len_max         # attention ~ O(B·L)
    else:
        out["num_new_tokens"] = out["extend_sum"]     # all prompt tokens
        out["prod_ext_ctx"] = batch * (len_max ** 2)  # attention ~ O(B·L²)
    out["num_context_tokens"] = batch * len_max
    out["time"] = out["latency"]

    out["len_range"] = len_max - out["len_min"]
    out["len_cv"] = out["len_std"] / out["len_mean"].clip(lower=1)
    out["extend_cv"] = out["extend_std"] / out["extend_mean"].clip(lower=1)

    out["prompt_ratio"] = out["extend_sum"] / out["total_token_length"].clip(lower=1)
    out["cached_peak_ratio"] = out["cached_max"] / len_max.clip(lower=1)
    out["B_len_mean"] = batch * out["len_mean"]
    out["B_len_max_sq"] = batch * (len_max ** 2)
    out["cache_len_prod"] = out["cache_percent"] * len_max

    for base in ("len_max", "prod_ext_ctx", "num_context_tokens"):
        out["log_" + base] = np.log1p(out[base])

    selected = [
        "num_new_tokens", "prod_ext_ctx", "num_context_tokens",
        "len_max", "len_min", "len_std", "len_p90", "len_p95",
        "cached_sum", "cached_max", "cached_ratio",
        "extend_max", "extend_mean", "extend_std", "extend_p90",
        "batch_size", "imbalance", "skew",
        "cache_percent",
        "len_mean", "len_median", "len_range", "len_p99", "len_cv",
        "extend_min", "extend_median", "extend_p99", "extend_cv",
        "prompt_ratio", "cached_peak_ratio",
        "B_len_mean", "B_len_max_sq", "cache_len_prod",
        "log_len_max", "log_prod_ext_ctx", "log_num_context_tokens",
        "time",
    ]
    # Hardware knobs are carried through untouched when present.
    hardware = [c for c in ("gpu_name", "num_gpu", "dtype", "flash_attn_flag") if c in out]
    return out[selected + hardware]
from sklearn.model_selection import train_test_split
import xgboost as xgb
from typing import Dict

def find_closest_bin(value, bins: Dict):
    """
    Return the entry of the smallest bin edge that is >= ``value``.

    ``bins`` maps an upper bin edge (e.g. a token length or a batch size) to a
    reference value. Edges are visited in ascending order, so the result does not
    depend on the dict's insertion order. Values above every edge fall back to
    the entry of the largest edge.

    Raises
    ------
    ValueError
        If ``bins`` is empty.
    """
    if not bins:
        raise ValueError("bins must be a non-empty mapping")
    edges = sorted(bins)
    for edge in edges:
        if value <= edge:
            return bins[edge]
    return bins[edges[-1]]


def find_closest_bin_for_all_lens(values, bins: Dict[int, int]):
    """Apply ``find_closest_bin`` to every element of ``values``; returns a list."""
    return [find_closest_bin(v, bins) for v in values]

def _tree_model_label(model_type: str) -> str:
    """Human-readable name of the regressor that ``model_type`` selects."""
    return {"lgm": "LGBM", "xgboost": "XGBoost"}.get(model_type, "RandomForest")


def _make_tree_regressor(model_type: str):
    """Instantiate the unfitted regressor selected by ``model_type``."""
    if model_type == "lgm":
        # Imported lazily so the xgboost / random-forest paths work without lightgbm.
        import lightgbm as lgb
        return lgb.LGBMRegressor(min_data_in_leaf=1, verbose=-1)
    if model_type == "xgboost":
        return xgb.XGBRegressor(
            n_estimators=100, random_state=42, min_child_weight=1, max_depth=12
        )
    return RandomForestRegressor(
        n_estimators=10, random_state=42, min_samples_leaf=2, max_depth=12
    )


def _latency_metrics_ms(y_true, y_pred):
    """Return ``(rmse_ms, mae_ms, r2)`` for latencies given in seconds."""
    rmse = np.sqrt(mean_squared_error(y_true, y_pred)) * 1000
    mae = mean_absolute_error(y_true, y_pred) * 1000
    return rmse, mae, r2_score(y_true, y_pred)


def _plot_validation_report(
    X_val, y_val, val_pred, train_metrics, val_metrics,
    *, stage, model_type, model_name, latency_normalization_dict,
):
    """Show a 2x2 Plotly figure: MAE by length, MAE heatmap, normalized MAE, metrics table."""
    df_eval = pd.DataFrame({
        "true": y_val * 1000,
        "pred": val_pred * 1000,
        "len_max": X_val["len_max"],
        "batch_size": X_val["batch_size"],
    })
    df_eval["error_ms"] = (df_eval["true"] - df_eval["pred"]).abs()

    # Mean absolute error per sequence-length quantile bin (up to 6 bins).
    df_eval["len_bin"] = pd.qcut(df_eval["len_max"], q=6, duplicates="drop")
    mae_by_len = df_eval.groupby("len_bin", observed=False)["error_ms"].mean().reset_index()
    mae_by_len["len_center"] = mae_by_len["len_bin"].apply(lambda iv: (iv.left + iv.right) / 2)

    has_normalization = bool(latency_normalization_dict)
    if has_normalization:
        # NOTE(review): train_latency_models keys the decode dict by batch_size,
        # but it is looked up here by sequence-length bin centre — confirm intent.
        mae_by_len["error_ms_percent"] = [
            err / find_closest_bin(center, latency_normalization_dict) * 100
            for err, center in zip(mae_by_len["error_ms"], mae_by_len["len_center"])
        ]

    # Batch-size x sequence-length heatmap on 16 equal-width bins per axis,
    # with ordered string categories so the axes sort numerically.
    df_eval["bs_bin"] = pd.cut(df_eval["batch_size"], bins=16)
    df_eval["len_bin2"] = pd.cut(df_eval["len_max"], bins=16)
    for src, dst in (("bs_bin", "bs_bin_str"), ("len_bin2", "len_bin2_str")):
        df_eval[dst] = pd.Categorical(
            df_eval[src].astype(str),
            categories=[str(cat) for cat in df_eval[src].cat.categories],
            ordered=True,
        )
    pivot = df_eval.pivot_table(
        index="bs_bin_str",
        columns="len_bin2_str",
        values="error_ms",
        aggfunc="mean",
        observed=False,
    )

    fig = make_subplots(
        rows=2, cols=2,
        specs=[
            [{"type": "xy"}, {"type": "xy"}],
            [{"type": "xy"}, {"type": "domain"}],
        ],
        subplot_titles=(
            "MAE vs. Sequence Length",
            "MAE Heatmap",
            "Normalized MAE vs. SeqLen",
            "Metrics",
        ),
        horizontal_spacing=0.3,
        vertical_spacing=0.5,
    )

    # 1) MAE vs. sequence length
    fig.add_trace(
        go.Scatter(
            x=mae_by_len["len_center"],
            y=mae_by_len["error_ms"],
            mode="markers+lines",
            name="MAE",
        ),
        row=1, col=1,
    )
    fig.update_xaxes(title_text="Seq Length", row=1, col=1)
    fig.update_yaxes(title_text="MAE (ms)", row=1, col=1)

    # 2) Heatmap
    fig.add_trace(
        go.Heatmap(
            z=pivot.values,
            x=pivot.columns,
            y=pivot.index,
            colorbar=dict(
                title="MAE (ms)",
                len=0.5,      # half the height of the subplot
                y=0.75,       # 75% up the entire figure
                yanchor="middle",
            ),
        ),
        row=1, col=2,
    )
    fig.update_xaxes(title_text="Seq Length", row=1, col=2, tickangle=-45)
    fig.update_yaxes(title_text="Batch Size", row=1, col=2)

    # 3) Error as a percentage of the reference latency (skipped without a reference)
    if has_normalization:
        fig.add_trace(
            go.Scatter(
                x=mae_by_len["len_center"],
                y=mae_by_len["error_ms_percent"],
                mode="markers+lines",
                name="Normalized MAE by Avg Seq Len",
            ),
            row=2, col=1,
        )
    fig.update_xaxes(title_text="Seq Length", row=2, col=1)
    fig.update_yaxes(title_text="Error (%)", row=2, col=1)

    # 4) Metrics table
    (train_rmse, train_mae, train_r2), (val_rmse, val_mae, val_r2) = train_metrics, val_metrics
    fig.add_trace(
        go.Table(
            header=dict(values=["Metric", "Train", "Validation"]),
            cells=dict(values=[
                ["RMSE", "MAE", "R²"],
                [f"{train_rmse:.2f}", f"{train_mae:.2f}", f"{train_r2:.2f}"],
                [f"{val_rmse:.2f}", f"{val_mae:.2f}", f"{val_r2:.2f}"],
            ]),
        ),
        row=2, col=2,
    )

    fig.update_layout(
        height=600,
        width=1000,
        title_text=f"{stage} {model_type} {model_name} Simulation Accuracy",
        showlegend=False,
        margin=dict(l=50, r=50, t=80, b=50),
    )
    fig.show()


def train_tree_predictor(train_df: pd.DataFrame, stage: str, model_type="lgm", train_split_eval=True, model_name="Qwen3", latency_normalization_dict=None, val_fraction: float = 0.01) -> object:
    """
    Fit a tree-ensemble latency regressor for one stage and report its accuracy.

    Parameters
    ----------
    train_df : pd.DataFrame
        Feature frame; every column except ``time`` is a feature and ``time``
        (seconds) is the target. With ``train_split_eval`` it must also contain
        ``len_max`` and ``batch_size`` (used by the diagnostic plots).
    stage : str
        Label used in printed output and the figure title.
    model_type : str
        "lgm" → LightGBM, "xgboost" → XGBoost, anything else → RandomForest.
    train_split_eval : bool
        If True, hold out ``val_fraction`` of the rows, print train/validation
        metrics and show a Plotly diagnostics figure. Otherwise fit on all rows
        and print training metrics only.
    model_name : str
        Used only in the figure title.
    latency_normalization_dict : dict or None
        Maps a bin upper edge → reference latency in ms, used to express the
        per-length MAE as a percentage; the normalized panel is skipped if empty/None.
    val_fraction : float
        Fraction of rows held out for validation (default 0.01).

    Returns
    -------
    The fitted regressor.
    """
    X = train_df.drop(columns=["time"])
    y = train_df["time"]
    model = _make_tree_regressor(model_type)
    label = _tree_model_label(model_type)

    if not train_split_eval:
        model.fit(X, y)
        print(f"[{stage}] training rows =", len(train_df))
        rmse, mae, r2 = _latency_metrics_ms(y, model.predict(X))
        print(f"[{stage}] {label}")
        print(f"  Train RMSE: {rmse:.2f} ms")
        print(f"  Train MAE : {mae:.2f} ms")
        print(f"  Train R²  : {r2:.4f}")
        return model

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=val_fraction, random_state=42
    )
    model.fit(X_train, y_train)
    print(f"[{stage}] training rows =", len(X_train), "validation rows =", len(X_val))

    train_metrics = _latency_metrics_ms(y_train, model.predict(X_train))
    val_pred = model.predict(X_val)
    val_metrics = _latency_metrics_ms(y_val, val_pred)
    train_rmse, train_mae, train_r2 = train_metrics
    val_rmse, val_mae, val_r2 = val_metrics

    print(f"[{stage}] {label}")
    print(f"Train RMSE: {train_rmse:.2f} ms")
    print(f"Train MAE : {train_mae:.2f} ms")
    print(f"Train R²  : {train_r2:.2f}")
    print(f"Val   RMSE: {val_rmse:.2f} ms")
    print(f"Val   MAE : {val_mae:.2f} ms")
    print(f"Val   R²  : {val_r2:.2f}")

    _plot_validation_report(
        X_val, y_val, val_pred, train_metrics, val_metrics,
        stage=stage,
        model_type=model_type,
        model_name=model_name,
        latency_normalization_dict=latency_normalization_dict,
    )
    return model

# ---------------------------------------------------------------------
# 3.  End-to-end pipeline
# ---------------------------------------------------------------------
def _prefill_latency_reference_ms(df: pd.DataFrame) -> dict:
    """Map total_token_length → latency (ms) for bs=1, zero-skew, 0%-cache prefill rows."""
    mask = (
        (df["batch_size"] == 1)
        & (df["skew"] == 0)
        & (df["forward_mode"] == "prefill")
        & (df["cache_percent"] == 0.0)
    )
    reference = (
        df[mask]
        .filter(["total_token_length", "latency"])
        .set_index("total_token_length")
        * 1000
    )
    return reference.to_dict()["latency"]


def _decode_latency_reference_ms(df: pd.DataFrame) -> dict:
    """Map batch_size → mean latency (ms) of zero-skew decode rows with > 500 total tokens."""
    mask = (
        (df["skew"] == 0)
        & (df["forward_mode"] == "decode")
        & (df["total_token_length"] > 500)
    )
    return (df[mask].groupby("batch_size")["latency"].mean() * 1000).to_dict()


def train_latency_models(raw_df: pd.DataFrame, model_type="lgm", train_split_eval=True, model_name="Qwen3") -> Tuple:
    """
    Preprocess → feature-engineer → train separate tree models for prefill and decode.

    Parameters
    ----------
    raw_df : pd.DataFrame
        Combined profile rows (not modified). List columns may be real lists or
        their string form (e.g. after a CSV round-trip).
    model_type, train_split_eval, model_name
        Forwarded to ``train_tree_predictor``.

    Returns
    -------
    tuple
        ``(prefill_model, decode_model)``.
    """
    import ast  # local import: only needed to parse list columns read back from CSV

    df = raw_df.copy()

    # Parse list-typed columns FIRST: after a CSV round-trip they arrive as
    # strings like "[1, 2]" and every later step needs real lists.
    # ast.literal_eval only accepts Python literals, unlike eval().
    list_cols = ["combined_seq_lens", "cached_prefix_lens", "new_extend_lens"]
    for c in list_cols:
        df[c] = df[c].apply(lambda v: ast.literal_eval(v) if isinstance(v, str) else v)

    # fill in total_extend_len if missing (np.sum tolerates NaN cells)
    df["total_extend_len"] = df["total_extend_len"].fillna(df["new_extend_lens"].apply(np.sum))

    # split by stage
    prefill_df = df[df["forward_mode"] == "prefill"]
    decode_df = df[df["forward_mode"] == "decode"]

    # build features
    train_df_prefill = build_stage_features(prefill_df, stage="prefill")
    train_df_decode = build_stage_features(decode_df, stage="decode")

    # Reference latencies used to express validation error as a percentage;
    # derived from the frame passed in, not from any notebook global.
    prefill_latency_for_normalization = _prefill_latency_reference_ms(df)
    decode_latency_for_normalization = _decode_latency_reference_ms(df)

    # train tree models
    model_prefill = train_tree_predictor(
        train_df_prefill, "prefill",
        model_type=model_type,
        train_split_eval=train_split_eval,
        model_name=model_name,
        latency_normalization_dict=prefill_latency_for_normalization,
    )
    model_decode = train_tree_predictor(
        train_df_decode, "decode",
        model_type=model_type,
        train_split_eval=train_split_eval,
        model_name=model_name,
        latency_normalization_dict=decode_latency_for_normalization,
    )
    return model_prefill, model_decode


# Train LightGBM prefill/decode models on the merged profile with hold-out diagnostics.
model_prefill, model_decode = train_latency_models(
    combined_df,
    model_type="lgm",
    train_split_eval=True,
    model_name=model_name,
)

[prefill] training rows = 5592 validation rows = 57
[prefill] LGBM
Train RMSE: 1.04 ms
Train MAE : 0.60 ms
Train R²  : 1.00
Val   RMSE: 2.13 ms
Val   MAE : 1.15 ms
Val   R²  : 1.00


[decode] training rows = 6133 validation rows = 62
[decode] LGBM
Train RMSE: 0.09 ms
Train MAE : 0.05 ms
Train R²  : 1.00
Val   RMSE: 0.11 ms
Val   MAE : 0.05 ms
Val   R²  : 1.00


In [144]:
forward_mode_to_filter = "prefill"

# Reference curve: bs=1, zero-skew, 0%-cache prefill latency (ms) keyed by prompt length.
is_reference_row = (
    (combined_df["batch_size"] == 1)
    & (combined_df["skew"] == 0)
    & (combined_df["forward_mode"] == forward_mode_to_filter)
    & (combined_df["cache_percent"] == 0.0)
)
reference_rows = combined_df[is_reference_row].filter(["total_token_length", "latency"])
df_bs1_prefill = reference_rows.set_index("total_token_length") * 1000

prefill_latency_for_normalization = df_bs1_prefill.to_dict()["latency"]
print(prefill_latency_for_normalization)


{1: 12.461132369935001, 2: 11.688116006553, 4: 11.735336855053001, 8: 11.697398498654001, 16: 11.520516127347001, 32: 11.301406659185002, 64: 11.661914177238001, 128: 11.711215600371002, 256: 16.398821026086, 512: 22.928835824131003, 1024: 40.913268923759006, 2048: 79.179731197655, 4096: 159.32115819305102, 8192: 354.56836968660303, 10240: 471.14018164575106, 13000: 639.3684521317481, 16384: 868.3463800698511}


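<!-- Earlier run, kept for reference. NOTE: several rows below report RMSE < MAE, which is impossible (RMSE >= MAE always holds), so these "RMSE" figures were presumably not root-mean-square errors; do not compare them directly with the current output.
[prefill] RandomForest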
Train RMSE: 0.08 ms
Train MAE : 0.94 ms
Train R²  : 1.00
Val   RMSE: 0.04 ms
Val   MAE : 0.74 ms
Val   R²  : 1.00
[decode] RandomForest
Train RMSE: 0.03 ms
Train MAE : 0.08 ms
Train R²  : 0.99
Val   RMSE: 0.06 ms
Val   MAE : 0.27 ms
Val   R²  : 0.95 -->



In [145]:
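# Earlier results, kept for reference. NOTE: several rows report RMSE < MAE, which is impossible (RMSE >= MAE always holds), so these "RMSE" values were presumably not root-mean-square errors; do not compare them directly with the current output.
# [prefill] RandomForest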
# Train RMSE: 0.08 ms
# Train MAE : 0.94 ms
# Train R²  : 1.00
# Val   RMSE: 0.04 ms
# Val   MAE : 0.74 ms
# Val   R²  : 1.00
# [decode] RandomForest
# Train RMSE: 0.03 ms
# Train MAE : 0.08 ms
# Train R²  : 0.99
# Val   RMSE: 0.06 ms
# Val   MAE : 0.27 ms
# Val   R²  : 0.95

# [prefill] LGBM
# Train RMSE: 0.03 ms
# Train MAE : 0.59 ms
# Train R²  : 1.00
# Val   RMSE: 0.03 ms
# Val   MAE : 0.58 ms
# Val   R²  : 1.00
# [decode] LGBM
# Train RMSE: 0.00 ms
# Train MAE : 0.05 ms
# Train R²  : 1.00
# Val   RMSE: 0.00 ms
# Val   MAE : 0.05 ms
# Val   R²  : 1.00
