In [162]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path

In [163]:
# Directory holding the tuning-run parquet output, resolved to an absolute path.
root = Path("..", "tune").resolve()

In [164]:
# Load every tuning result under `root` into a single frame.
df = pd.read_parquet(path=root)

In [165]:
# Drop the device metadata columns. Reassigning with errors="ignore" (instead of
# `del df[...]`) keeps this cell safe to re-run: a second `del` would raise
# KeyError because the columns are already gone.
df = df.drop(columns=["device_cap", "device_name"], errors="ignore")

In [166]:
# Show tuning rows ordered from fastest to slowest wall-clock time.
df.sort_values(by="wall_time", ascending=True)

Unnamed: 0,block_j,block_k,warp,num_stages,n,h,d,cuda_time,wall_time,dtype,method
299860,16,32,2,4,32,8,32,0.077824,0.097275,torch.bfloat16,fwd
319809,16,32,2,4,32,8,32,0.077824,0.097275,torch.bfloat16,fwd
319747,16,16,2,4,32,8,32,0.077824,0.097752,torch.bfloat16,fwd
299798,16,16,2,4,32,8,32,0.077824,0.097752,torch.bfloat16,fwd
294882,16,32,8,3,32,2,32,0.078496,0.097990,torch.bfloat16,fwd
...,...,...,...,...,...,...,...,...,...,...,...
5674,16,16,1,1,1024,8,64,346.864807,347.433090,torch.bfloat16,bwd_dkvb
5676,16,16,1,1,1024,8,64,363.462646,364.422321,torch.bfloat16,bwd_dkvb
5675,16,16,1,1,1024,8,64,374.146057,374.810457,torch.bfloat16,bwd_dkvb
5673,16,16,1,1,1024,8,64,411.280396,411.404133,torch.bfloat16,bwd_dkvb


In [167]:
# Average the remaining numeric columns (the timings) over rows that share the
# same key combination; the keys become the MultiIndex of `mean_df`.
group_keys = [
    "n", "h", "d",
    "block_j", "block_k", "warp", "num_stages",
    "method", "dtype",
]
mean_df = df.groupby(group_keys).mean()

In [169]:
# Print the 5 configurations with the lowest mean cuda_time for every
# (method, n, h, d). Grouping by a scalar index level yields scalar keys;
# `groupby(["method"])` (a length-1 list) yields 1-tuples in pandas >= 2,
# which is why headers used to print as "('bwd_dkvb',)".
for method, method_df in mean_df.groupby(level="method"):
    print(f"method={method}")
    for n, n_df in method_df.groupby(level="n"):
        print(f"  n={n}")
        for _, hd_df in n_df.groupby(level=["h", "d"]):
            print(hd_df.sort_values("cuda_time")[["cuda_time"]].head(5))

('bwd_dkvb',)
(32,)
                                                                 cuda_time
n  h d  block_j block_k warp num_stages method   dtype                    
32 1 32 16      16      8    6          bwd_dkvb torch.bfloat16   0.118566
                64      8    4          bwd_dkvb torch.bfloat16   0.126880
                             6          bwd_dkvb torch.bfloat16   0.127712
                        4    6          bwd_dkvb torch.bfloat16   0.130048
        32      16      1    4          bwd_dkvb torch.bfloat16   0.134029
                                                                 cuda_time
n  h d  block_j block_k warp num_stages method   dtype                    
32 1 64 16      32      2    4          bwd_dkvb torch.bfloat16   0.110573
        32      32      4    1          bwd_dkvb torch.bfloat16   0.122507
        16      64      8    4          bwd_dkvb torch.bfloat16   0.124928
                16      4    6          bwd_dkvb torch.bfloat16   0.124949
     

In [None]:
# `create_parameter_grid` is not defined anywhere in this notebook; the config
# builder is `create_config_lookup`.
# NOTE(review): both `create_config_lookup` and `FastParameterLookup` are
# defined in later cells, so this cell fails on Restart & Run All.
# TODO move it below the FastParameterLookup cell. The saved output below is
# stale (the cell is unexecuted) and may not match a fresh run.
lu = FastParameterLookup(create_config_lookup(df[df["method"] == "fwd"]))

lu.get_parameters(64, 1, 32)

{'block_j': 64, 'block_k': 128, 'warp': 8, 'num_stages': 4}

In [None]:
import pandas as pd
import yaml
from typing import Dict


def create_config_lookup(df: pd.DataFrame, metric: str = "cuda_time") -> str:
    """Build a YAML table mapping each (n, h, d) to its fastest configuration.

    Rows are grouped by (n, h, d) plus the configuration columns
    (block_j, block_k, warp, num_stages). ``metric`` is averaged within each
    group, and for every (n, h, d) the configuration with the lowest mean is
    kept. Groups whose mean is NaN (no valid measurement) are skipped.

    Args:
        df: Tuning results with columns n, h, d, block_j, block_k, warp,
            num_stages and ``metric``. Rows are NOT split by any other column
            (e.g. method or dtype) -- presumably the caller pre-filters; confirm
            when several dtypes are present.
        metric: Column to minimise. Defaults to "cuda_time".

    Returns:
        YAML text with a ``grid_points`` section (sorted unique n, h and d
        values among the selected rows) and a ``settings`` section keyed by
        ``"n,h,d"`` holding the chosen configuration as ints.
    """
    size_cols = ["n", "h", "d"]
    config_cols = ["block_j", "block_k", "warp", "num_stages"]

    # Mean metric per (size, config). An all-NaN group would make idxmin
    # return NaN (deprecated; a future pandas raises), so drop those first.
    mean_times = (
        df.groupby(size_cols + config_cols)[metric]
        .mean()
        .dropna()
        .reset_index()
    )

    # For each (n, h, d), keep the row with the minimum mean metric.
    best_params = mean_times.loc[mean_times.groupby(size_cols)[metric].idxmin()]

    grid = {
        # Derive grid points from the selected rows so that every listed value
        # has at least one setting behind it.
        "grid_points": {
            col: sorted(best_params[col].unique().tolist()) for col in size_cols
        },
        "settings": {},
    }

    for row in best_params.itertuples(index=False):
        key = f"{int(row.n)},{int(row.h)},{int(row.d)}"
        grid["settings"][key] = {col: int(getattr(row, col)) for col in config_cols}

    return yaml.dump(grid, sort_keys=False, indent=2)


# `create_parameter_grid` is not defined anywhere in this notebook; the function
# defined in this cell is `create_config_lookup`.
cfg = create_config_lookup(df[df["method"] == "fwd"])
# NOTE(review): FastParameterLookup is defined in a later cell; this line only
# works after that cell has run. TODO reorder for Restart & Run All.
lu = FastParameterLookup(cfg)
lu.get_parameters(64, 8, 32)

{'block_j': 32, 'block_k': 128, 'warp': 2, 'num_stages': 4}

In [None]:
# Write one tuned-parameter YAML file per kernel method.
# NOTE(review): `cfg_dir` is not defined anywhere in this notebook; it must be a
# pathlib.Path set before this cell runs. TODO define it in the config cell.
cfg_dir.mkdir(parents=True, exist_ok=True)

config_yamls = {}
for method in ("fwd", "bwd_dq", "bwd_dkvb"):
    config_yamls[method] = create_config_lookup(df[df["method"] == method])
    (cfg_dir / f"{method}.yaml").write_text(config_yamls[method])

In [182]:
# Dump the generated fwd YAML config text.
print(f"{cfg}")

grid_points:
  n:
  - 32
  - 64
  - 128
  - 256
  - 512
  - 1024
  h:
  - 1
  - 2
  - 4
  - 8
  d:
  - 32
  - 64
settings:
  32,1,32:
    block_j: 64
    block_k: 64
    warp: 8
    num_stages: 2
  32,1,64:
    block_j: 32
    block_k: 64
    warp: 2
    num_stages: 3
  32,2,32:
    block_j: 16
    block_k: 128
    warp: 1
    num_stages: 2
  32,2,64:
    block_j: 32
    block_k: 64
    warp: 1
    num_stages: 3
  32,4,32:
    block_j: 16
    block_k: 64
    warp: 1
    num_stages: 4
  32,4,64:
    block_j: 16
    block_k: 128
    warp: 1
    num_stages: 6
  32,8,32:
    block_j: 16
    block_k: 16
    warp: 1
    num_stages: 3
  32,8,64:
    block_j: 128
    block_k: 64
    warp: 4
    num_stages: 1
  64,1,32:
    block_j: 16
    block_k: 16
    warp: 2
    num_stages: 6
  64,1,64:
    block_j: 16
    block_k: 64
    warp: 2
    num_stages: 4
  64,2,32:
    block_j: 16
    block_k: 16
    warp: 8
    num_stages: 2
  64,2,64:
    block_j: 16
    block_k: 32
    warp: 2
    num_stages: 

In [None]:
from typing import Dict
from bisect import bisect_right


class FastParameterLookup:
    """Snap an (n, h, d) query onto a tuned grid and return its parameters.

    The config (YAML text, or an already-parsed mapping) must provide:

    * ``grid_points`` -- lists of the tuned ``n``, ``h`` and ``d`` values;
    * ``settings`` -- a mapping from ``"n,h,d"`` strings to parameter dicts.

    Each query coordinate is snapped *down* to the largest grid value that is
    <= the query; queries below the smallest grid value are clamped to it.
    This is a floor lookup, not a nearest-neighbour one.
    """

    def __init__(self, config_str: "str | Dict"):
        """Parse the config and build the lookup tables.

        Args:
            config_str: YAML text (e.g. the output of ``create_config_lookup``)
                or an already-parsed mapping with the same structure.

        Raises:
            ValueError: if any of the ``n``/``h``/``d`` grid lists is empty.
        """
        # Only parse when given text; a pre-parsed mapping is used as-is.
        if isinstance(config_str, str):
            config = yaml.safe_load(config_str)
        else:
            config = config_str

        grid = config["grid_points"]
        self.n_vals = sorted(grid["n"])
        self.d_vals = sorted(grid["d"])
        self.h_vals = sorted(grid["h"])
        for axis, vals in (("n", self.n_vals), ("h", self.h_vals), ("d", self.d_vals)):
            if not vals:
                # Fail early instead of with an opaque IndexError in get_parameters.
                raise ValueError(f"grid_points[{axis!r}] is empty")

        # Nested table {n -> {d -> {h -> params}}}. Keys are parsed as floats;
        # integer grid values still index it because 64 == 64.0 and
        # hash(64) == hash(64.0).
        self.lookup = {}
        for key, params in config["settings"].items():
            n, h, d = map(float, key.split(","))
            self.lookup.setdefault(n, {}).setdefault(d, {})[h] = params

    def get_parameters(self, n: float, h: float, d: float) -> Dict[str, int]:
        """Return a copy of the parameters for the grid point at or below (n, h, d).

        Raises:
            KeyError: if the snapped grid point has no entry in ``settings``.
                The grid lists are per-axis, so not every combination of grid
                values is guaranteed to have been tuned.
        """
        # bisect_right(...) - 1 indexes the last grid value <= the query;
        # i == 0 means the query is below the grid, so clamp to the first value.
        i = bisect_right(self.n_vals, n)
        n_grid = self.n_vals[i - 1] if i > 0 else self.n_vals[0]

        i = bisect_right(self.d_vals, d)
        d_grid = self.d_vals[i - 1] if i > 0 else self.d_vals[0]

        i = bisect_right(self.h_vals, h)
        h_grid = self.h_vals[i - 1] if i > 0 else self.h_vals[0]

        try:
            params = self.lookup[n_grid][d_grid][h_grid]
        except KeyError:
            raise KeyError(
                f"no tuned parameters for grid point n={n_grid}, h={h_grid}, "
                f"d={d_grid} (query n={n}, h={h}, d={d})"
            ) from None
        # Return a copy so callers can modify it without corrupting the table.
        return params.copy()

In [None]:
# Build the lookup from the fwd YAML config generated earlier.
lu = FastParameterLookup(config_str=cfg)

In [188]:
# Micro-benchmark a single lookup on the `lu` instance built in the previous cell.
%timeit lu.get_parameters(64, 8, 32)


687 ns ± 52.6 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
