In [3]:
from rich.pretty import pprint
from datasets import load_dataset
from litellm import completion
from triton_eval.utils import run_script_on_gpu, get_tests

In [4]:
ds = load_dataset("tcapelle/annotated_dataset_o3_train_pytorch_triton", split="train")

In [5]:
idx = 200
triton_code, pt_code= ds[idx]["final_triton_code"], ds[idx]["final_pytorch_code"]

In case we want to run the tests from pytorch with the generated triton kernel

In [6]:
prompt = """Grab the name of the function being tested inside this code. 

For example:

```python
import math
import torch

# Global device standard
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def kernel_bw_pytorch(grad_out: torch.Tensor, act_inputs: torch.Tensor, activation_grad, N: int) -> torch.Tensor:
total_cols = act_inputs.size(1)
if N > total_cols:
raise ValueError(f"N (got {N}) cannot be larger than the number of columns (got {total_cols})")

# Initialize output gradient tensor with zeros (ensuring device consistency with DEVICE)
grad_act = torch.zeros_like(grad_out, dtype=act_inputs.dtype, device=DEVICE)

# Process only the valid region (first N columns)
valid_act_inputs = act_inputs[:, :N]
valid_grad_out = grad_out[:, :N]

# Compute the activation gradient for the valid region
computed_grad = activation_grad(valid_act_inputs)

# Element-wise multiplication with grad_out
grad_act[:, :N] = computed_grad * valid_grad_out

return grad_act

########################

def test_kernel_bw():
results = {}

# Define a simple activation gradient function for testing, e.g., for f(x)=x^3 then f'(x)=3*x^2
activation_grad = lambda x: 3 * x.pow(2)

# Test Case 1: Even case, where the valid region covers the entire width
M, L = 2, 8
N = L # valid region covers entire width
act_inputs = torch.arange(M * L, dtype=torch.float32, device=DEVICE).reshape(M, L)
grad_out = torch.ones((M, L), dtype=torch.float32, device=DEVICE)
pytorch_out = kernel_bw_pytorch(grad_out, act_inputs, activation_grad, N)
expected = 3 * act_inputs.pow(2) # expected: activation_grad(act_inputs) * grad_out
results['even_full'] = {'pytorch': pytorch_out, 'expected': expected}

# Test Case 2: Partial valid region, only first N columns are processed
M, L = 3, 10
N = 6
act_inputs = torch.linspace(-5, 4, steps=M * L, dtype=torch.float32, device=DEVICE).reshape(M, L)
grad_out = torch.full((M, L), 2.0, dtype=torch.float32, device=DEVICE)
pytorch_out = kernel_bw_pytorch(grad_out, act_inputs, activation_grad, N)
expected_partial = torch.zeros((M, L), dtype=torch.float32, device=DEVICE)
expected_partial[:, :N] = activation_grad(act_inputs[:, :N]) * 2.0
results['partial_valid'] = {'pytorch': pytorch_out, 'expected': expected_partial}

# Test Case 3: Full valid region with non-trivial random inputs
M, L = 4, 5
N = 5
act_inputs = torch.randn(M, L, dtype=torch.float32, device=DEVICE)
grad_out = torch.randn(M, L, dtype=torch.float32, device=DEVICE)
pytorch_out = kernel_bw_pytorch(grad_out, act_inputs, activation_grad, N)
expected_full = activation_grad(act_inputs) * grad_out
results['full_random'] = {'pytorch': pytorch_out, 'expected': expected_full}

return results


# Running tests and printing the results (only printing the test_results dictionary)
test_results = test_kernel_bw()
print(test_results)
```


You should return `kernel_bw_pytorch` as this is the function being tested.
"""


In [7]:
from pydantic import BaseModel, Field

class PytorchCode(BaseModel):
    tested_function_name: str = Field(description="The name of the function being tested.")

def get_name(row, column: str="final_triton_code"):
    code = row[column]
    if code == "" or code is None:
        return {f"{column}_entrypoint": None}
    result = completion(
        model="gpt-4.1-mini", 
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content":  "Extract the name of the function being tested from this code: \n\n# Code:" + code}], 
        response_format=PytorchCode,
        max_tokens=50)
    out = PytorchCode.model_validate_json(result.choices[0].message.content)
    return {f"{column}_entrypoint": out.tested_function_name}

In [9]:
ds = ds.map(get_name, num_proc=10, fn_kwargs={"column": "final_pytorch_code"})

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)
Map (num_proc=10): 100%|██████████| 864/864 [00:51<00:00, 16.92 examples/s]


In [10]:
ds = ds.map(get_name, num_proc=10, fn_kwargs={"column": "final_triton_code"})

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)
Map (num_proc=10): 100%|██████████| 864/864 [01:00<00:00, 14.30 examples/s]


In [15]:
ds.save_to_disk("ds_with_entrypoints")

Saving the dataset (1/1 shards): 100%|██████████| 864/864 [00:00<00:00, 21459.70 examples/s]


In [34]:
from datasets import load_from_disk
ds = load_from_disk("ds_with_entrypoints")

In [35]:
ds[0]

{'final_triton_code': 'import torch\nimport triton\nimport triton.language as tl\n\n# Global device standard\nDEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")\n\n\n@triton.jit\ndef _matmul_kernel(A, B, C, M, N, K, **meta):\n    """Triton kernel for matrix multiplication using tiling."""\n    # Tiling sizes\n    TILE_M = meta[\'BLOCK_M\']\n    TILE_N = meta[\'BLOCK_N\']\n    TILE_K = 128\n    \n    # Indices for output tile computed by the current program instance\n    m = tl.program_id(0) * TILE_M + tl.arange(0, TILE_M)\n    n = tl.program_id(1) * TILE_N + tl.arange(0, TILE_N)\n    \n    # Initialize the accumulator for the resultant tile\n    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n    \n    # Loop over the K dimension with tiles\n    for k in range(0, K, TILE_K):\n        a = tl.load(A + m[:, None] * K + k, mask=[m[:, None] < M, None], other=0.0)\n        b = tl.load(B + k * N + n, mask=[None, n < N], other=0.0)\n        acc += tl.dot(a, b)\n    \n   

In [36]:
import re
def grab_function_definitions(code):
    return re.findall(r"def\s+(\w+)\s*\(", code)

In [37]:
# grab all function definitions (any code that starts with "def")

pt_code = ds[0]["final_pytorch_code"]
triton_code = ds[0]["final_triton_code"]
pt_entrypoint = ds[0]["final_pytorch_code_entrypoint"]
triton_entrypoint = ds[0]["final_triton_code_entrypoint"]


print(grab_function_definitions(pt_code))
print(grab_function_definitions(triton_code))
print([pt_entrypoint, triton_entrypoint])

['matmul_pytorch', 'test_matmul_pytorch']
['_matmul_kernel', 'matmul_triton', 'test_matmul_triton']
['matmul_pytorch', 'matmul_triton']


In [38]:
import re, copy

def remove_suffix(name):
    suffixes = ["_pytorch", "_triton", "pytorch", "triton", "_pt", "torch", "pt", "_python", "python","_py", "py"]
    for suffix in suffixes:
        if name.endswith(suffix):
            name = name[:-len(suffix)]
            break # Stop after removing the first matching suffix

    return name

In [39]:
remove_suffix("matmul_pytorch"), remove_suffix("matmul_py")

('matmul', 'matmul')

In [40]:
for i, row in enumerate(ds):
    if i > 25:
        break
    pt_defs = grab_function_definitions(row["final_pytorch_code"])
    triton_defs = grab_function_definitions(row["final_triton_code"])
    new_entrypoint = remove_suffix(row["final_pytorch_code_entrypoint"])
    print(f"{i:03d}:  {new_entrypoint} -> {pt_defs} | {triton_defs} ")

000:  matmul -> ['matmul_pytorch', 'test_matmul_pytorch'] | ['_matmul_kernel', 'matmul_triton', 'test_matmul_triton'] 
001:  jagged_2_softmax -> ['jagged_2_softmax', 'gelu', 'test_jagged_2_softmax', 'test_gelu', 'run_all_tests'] | ['jagged_2_softmax_kernel', 'test_jagged_2_softmax_triton', 'test_gelu_triton', 'run_all_tests'] 
002:  fancy_function -> ['fancy_function', 'test_fancy_function'] | ['fancy_function_triton', 'test_fancy_function'] 
003:  pytorch_unpack64 -> ['pytorch_unpack64', 'test_unpack64', 'float_to_bits'] | ['unpack64_kernel_inner', 'kernel_unpack64', 'triton_unpack64', 'test_unpack64', 'float_to_bits'] 
004:  fifth_order_bwd -> ['fifth_order_bwd_pytorch', 'test_fifth_order_bwd'] | ['fifth_order_bwd_triton', 'test_fifth_order_bwd_triton'] 
005:  paged_attn -> ['paged_attn', 'test_paged_attn'] | ['paged_attn_triton', 'test_paged_attn_triton'] 
006:  gelu_glu -> ['gelu_glu_pytorch', 'test_gelu_glu_pytorch'] | ['_gelu_glu_fwd_kernel', 'gelu_glu_triton', 'test_gelu_glu_tri

In [44]:
def rename_entrypoints(row):
    pt_code = row["final_pytorch_code"]
    triton_code = row["final_triton_code"]
    pt_entrypoint = row["final_pytorch_code_entrypoint"]
    triton_entrypoint = row["final_triton_code_entrypoint"]
    try:
        new_entrypoint = remove_suffix(pt_entrypoint)
        pt_code = pt_code.replace(pt_entrypoint, new_entrypoint)
        triton_code = triton_code.replace(triton_entrypoint, new_entrypoint)
        return {
            "final_pytorch_code_renamed": pt_code, 
            "final_triton_code_renamed": triton_code,
            "entrypoint": new_entrypoint,
        }
    except Exception as e:
        print(f"Error on row {i}: {e}")
        return {"final_pytorch_code_renamed": pt_code, "final_triton_code_renamed": triton_code, "entrypoint": None}

In [45]:
ds = ds.map(rename_entrypoints, num_proc=4)

Map (num_proc=4):   0%|          | 0/864 [00:00<?, ? examples/s]

Error on row 26: 'NoneType' object has no attribute 'endswith'Error on row 26: 'NoneType' object has no attribute 'endswith'

Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'Error on row 26: 'NoneType' object has no attribute 'endswith'

Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'
Error on row 26: 'NoneType' object has no attribute 'endswith'


Map (num_proc=4): 100%|██████████| 864/864 [00:00<00:00, 5606.13 examples/s]


In [51]:
print(grab_function_definitions(ds[0]["final_pytorch_code"])[0])
print("="*30)
print(grab_function_definitions(ds[0]["final_pytorch_code_renamed"])[0])

matmul_pytorch
matmul


we are done!

In [55]:
ds.column_names

['final_triton_code',
 'final_pytorch_code',
 'final_pytorch_code_entrypoint',
 'final_triton_code_entrypoint',
 'final_pytorch_code_renamed',
 'final_triton_code_renamed',
 'entrypoint']

In [57]:
ds = ds.remove_columns(["final_triton_code", "final_pytorch_code", "final_pytorch_code_entrypoint", "final_triton_code_entrypoint"])
ds = ds.rename_column("final_pytorch_code_renamed", "pytorch_code")
ds = ds.rename_column("final_triton_code_renamed", "triton_code")
ds.push_to_hub("tcapelle/annotated_dataset_renamed_all")

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 47.58ba/s]


Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/tcapelle/annotated_dataset_renamed_all/commit/aa70e436236ba8263030dcb990d15179c43c33d9', commit_message='Upload dataset', commit_description='', oid='aa70e436236ba8263030dcb990d15179c43c33d9', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/tcapelle/annotated_dataset_renamed_all', endpoint='https://huggingface.co', repo_type='dataset', repo_id='tcapelle/annotated_dataset_renamed_all'), pr_revision=None, pr_num=None)

## CPU test

In [5]:
success, results, file_name = run_script_on_gpu(pt_code, test_content="", file_name="test.py", gpu_id=None)

In [6]:
if success:
    print(results.stdout)

{'test_case_1': tensor([[19., 22.],
        [43., 50.]], device='cuda:0'), 'test_case_2': tensor([[ 58.,  64.],
        [139., 154.]], device='cuda:0'), 'test_case_3': {'result': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],
        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],
        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],
        ...,
        [  0.7112,  11.1433,   7.8263,  ...,  -8.2718,  -5.5668,  -6.1661],
        [ 17.1974,  -6.1684,   1.1457,  ...,  -6.9263, -12.8880,   5.2832],
        [-10.5624,   2.1081, -10.1488,  ...,   7.4583,  -1.6897,  -1.7082]],
       device='cuda:0'), 'expected': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],
        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],
        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],
        ...,
        [  0.7112,  11.1433,   7.8263,  ...,  -8.2718,  -5.5668,  -6.1661

## GPU test

In [7]:
success_gpu, results_gpu, _ = run_script_on_gpu(pt_code, test_content="", file_name="test.py", gpu_id=0)

In [8]:
if success_gpu:
    print(results_gpu.stdout)

{'test_case_1': tensor([[19., 22.],
        [43., 50.]], device='cuda:0'), 'test_case_2': tensor([[ 58.,  64.],
        [139., 154.]], device='cuda:0'), 'test_case_3': {'result': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],
        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],
        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],
        ...,
        [  0.7112,  11.1433,   7.8263,  ...,  -8.2718,  -5.5668,  -6.1661],
        [ 17.1974,  -6.1684,   1.1457,  ...,  -6.9263, -12.8880,   5.2832],
        [-10.5624,   2.1081, -10.1488,  ...,   7.4583,  -1.6897,  -1.7082]],
       device='cuda:0'), 'expected': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],
        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],
        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],
        ...,
        [  0.7112,  11.1433,   7.8263,  ...,  -8.2718,  -5.5668,  -6.1661

## Map

In [44]:
from concurrent.futures import ProcessPoolExecutor, as_completed

def run_one(row, gpus=[0, 1]):
    triton_code, pt_code = row["final_triton_code"], row["final_pytorch_code"]

    with ProcessPoolExecutor(max_workers=2) as executor:
        future_to_file = {
            executor.submit(run_script_on_gpu, pt_code, test_content="", file_name="test.py", gpu_id=gpus[0]): "pytorch",
            executor.submit(run_script_on_gpu, triton_code, test_content="", file_name="test.py", gpu_id=gpus[1]): "triton"
        }
        for future in as_completed(future_to_file):
            file_name = future_to_file[future]
            success, results, _ = future.result()
            if file_name == "pytorch":
                success_pytorch = success
                results_pytorch = results
            else:
                success_triton = success
                results_triton = results
    
    outputs_match = results_pytorch.stdout == results_triton.stdout

    return {"pytorch_runs": success_pytorch, 
            "pytorch_output": {"stdout": results_pytorch.stdout, "stderr": results_pytorch.stderr}, 
            "triton_runs": success_triton, 
            "triton_output": {"stdout": results_triton.stdout, "stderr": results_triton.stderr}, 
            "outputs_match": outputs_match}


In [45]:
sample_ds = ds.select(range(10))
sample_ds = sample_ds.map(run_one, num_proc=4)

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

Map (num_proc=4): 100%|██████████| 10/10 [00:21<00:00,  2.20s/ examples]


In [55]:
pprint(sample_ds[4])

## Extract Test