In [94]:
from datasets import load_dataset

ds = load_dataset("tcapelle/train_ds_triton", split="train")

ds[0]

Generating train split: 100%|██████████| 828/828 [00:00<00:00, 67708.10 examples/s]


{'entrypoint': 'matmul',
 'triton_code': 'import torch\nimport triton\nimport triton.language as tl\n\n# Global device standard\nDEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")\n\n\n@triton.jit\ndef _matmul_kernel(A, B, C, M, N, K, **meta):\n    """Triton kernel for matrix multiplication using tiling."""\n    # Tiling sizes\n    TILE_M = meta[\'BLOCK_M\']\n    TILE_N = meta[\'BLOCK_N\']\n    TILE_K = 128\n    \n    # Indices for output tile computed by the current program instance\n    m = tl.program_id(0) * TILE_M + tl.arange(0, TILE_M)\n    n = tl.program_id(1) * TILE_N + tl.arange(0, TILE_N)\n    \n    # Initialize the accumulator for the resultant tile\n    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n    \n    # Loop over the K dimension with tiles\n    for k in range(0, K, TILE_K):\n        a = tl.load(A + m[:, None] * K + k, mask=[m[:, None] < M, None], other=0.0)\n        b = tl.load(B + k * N + n, mask=[None, n < N], other=0.0)\n        acc += tl.d

In [99]:
print(ds[0]["prompt"][0]["content"])
print("-"*100)
print(ds[0]["prompt"][1]["content"])


You are an expert in Triton programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. 

# Instructions
- Ensure that the wrapper function fully corresponds to the provided function information.
- Generate a detailed plan on how to convert and optimize the Pytorch code to a Triton kernel before writing the code.
- The reasoning process MUST BE enclosed within <think> and </think> tags."
- Reply with the thinking process and a single blob of code surrounded with ```python and ```.

----------------------------------------------------------------------------------------------------
Convert the following PyTorch code to a Triton kernel.
Pytorch code:
```python
import torch

# Global device standard
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Perform matrix multiplication using pure PyTorch with torch.matmul

In [97]:
print(ds[0]["tests"])

def test_matmul():
    """
    Test function for pure PyTorch matrix multiplication on DEVICE.

    Returns:
      dict: Dictionary storing test results for each test case.
    """
    results = {}

    # Test Case 1: Small square matrices
    A1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=DEVICE)
    B1 = torch.tensor([[5.0, 6.0], [7.0, 8.0]], device=DEVICE)
    C1 = matmul(A1, B1)
    results['test_case_1'] = C1

    # Test Case 2: Rectangular matrices
    A2 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=DEVICE)
    B2 = torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], device=DEVICE)
    C2 = matmul(A2, B2)
    results['test_case_2'] = C2

    # Test Case 3: Larger matrices
    torch.manual_seed(42)
    A3 = torch.randn(64, 128, device=DEVICE)
    B3 = torch.randn(128, 32, device=DEVICE)
    C3 = matmul(A3, B3)
    expected_C3 = torch.mm(A3, B3)
    results['test_case_3'] = {
        'result': C3,
        'expected': expected_C3,
        'close': torch.allclose(C

In [76]:
columns_to_keep = ["entrypoint", "triton_code", "pt_code_without_tests", "tests"]
ds = ds.select_columns(columns_to_keep)
# ds.push_to_hub("tcapelle/train_ds_triton", commit_message="Filter columns")


In [77]:
ds = ds.rename_column("pt_code_without_tests", "pt_code")

In [78]:
ds

Dataset({
    features: ['entrypoint', 'triton_code', 'pt_code', 'tests'],
    num_rows: 847
})

## Run via server

In [79]:
import time
import httpx

SERVER_URL = "http://127.0.0.1:9347"
RUN_CODE_ENDPOINT = f"{SERVER_URL}/run_code"

client = httpx.Client()

def send_run_request(client: httpx.Client, code: str, tests: str):
    """Sends a request to the /run_code endpoint."""
    payload = {
        "code": code,
        "tests": tests
    }
    try:
        # Use console.print for richer output
        # console.print(f"Sending request for code snippet (first 50 chars): {code[:50]}...") # Make logging less verbose
        start_time = time.monotonic()
        response = client.post(RUN_CODE_ENDPOINT, json=payload, timeout=180.0) # Adjusted timeout if needed
        end_time = time.monotonic()
        duration = end_time - start_time
        response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
        # console.print(f"Request successful (Status: [green]{response.status_code}[/green], Duration: {duration:.2f}s)") # Make logging less verbose
        return response.json(), duration # Return duration along with result
    except Exception as e:
        print(f"Error: {e}")
        return None, None

In [80]:
def test_sample(row):
    code = row["pt_code"]
    tests = row["tests"]
    response, duration = send_run_request(client, code, tests)
    return {"pt_runs": response["status_code"] == 0, "pt_stdout": response["stdout"], "pt_stderr": response["stderr"], "pt_exec_time": duration}

In [81]:
test_sample(ds[0])

{'pt_runs': True,
 'pt_stdout': "{'test_case_1': tensor([[19., 22.],\n        [43., 50.]], device='cuda:0'), 'test_case_2': tensor([[ 58.,  64.],\n        [139., 154.]], device='cuda:0'), 'test_case_3': {'result': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],\n        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],\n        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],\n        ...,\n        [  0.7112,  11.1433,   7.8263,  ...,  -8.2718,  -5.5668,  -6.1661],\n        [ 17.1974,  -6.1684,   1.1457,  ...,  -6.9263, -12.8880,   5.2832],\n        [-10.5624,   2.1081, -10.1488,  ...,   7.4583,  -1.6897,  -1.7082]],\n       device='cuda:0'), 'expected': tensor([[  9.3524,  20.1801,   1.3200,  ..., -21.0338,   3.0357,  -8.3879],\n        [ -5.5521,   5.0191, -26.5503,  ...,  -5.4739,  -7.3350,  -0.0405],\n        [  2.6591,  -5.7370,   2.5628,  ...,  22.7629,   1.0609,  -6.0721],\n        ...,\n        [  0.7112,  11.1433

In [82]:
test_ds_cuda = ds.map(test_sample, num_proc=20)

Map (num_proc=20): 100%|██████████| 847/847 [00:32<00:00, 26.02 examples/s] 


In [85]:
test_ds_cuda[0]

{'entrypoint': 'matmul',
 'triton_code': 'import torch\nimport triton\nimport triton.language as tl\n\n# Global device standard\nDEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")\n\n\n@triton.jit\ndef _matmul_kernel(A, B, C, M, N, K, **meta):\n    """Triton kernel for matrix multiplication using tiling."""\n    # Tiling sizes\n    TILE_M = meta[\'BLOCK_M\']\n    TILE_N = meta[\'BLOCK_N\']\n    TILE_K = 128\n    \n    # Indices for output tile computed by the current program instance\n    m = tl.program_id(0) * TILE_M + tl.arange(0, TILE_M)\n    n = tl.program_id(1) * TILE_N + tl.arange(0, TILE_N)\n    \n    # Initialize the accumulator for the resultant tile\n    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n    \n    # Loop over the K dimension with tiles\n    for k in range(0, K, TILE_K):\n        a = tl.load(A + m[:, None] * K + k, mask=[m[:, None] < M, None], other=0.0)\n        b = tl.load(B + k * N + n, mask=[None, n < N], other=0.0)\n        acc += tl.d

In [92]:
runs_cuda_pt = test_ds_cuda.filter(lambda x: x["pt_runs"])
runs_cuda_pt

Filter: 100%|██████████| 847/847 [00:00<00:00, 73975.00 examples/s]


Dataset({
    features: ['entrypoint', 'triton_code', 'pt_code', 'tests', 'pt_runs', 'pt_stdout', 'pt_stderr', 'pt_exec_time'],
    num_rows: 828
})

In [93]:
runs_cuda_pt.push_to_hub("tcapelle/train_ds_triton", commit_message="Run CUDA PT")

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 20.60ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/tcapelle/train_ds_triton/commit/1e4e797ae4bab5c2125ced4ff69783951401f7ff', commit_message='Run CUDA PT', commit_description='', oid='1e4e797ae4bab5c2125ced4ff69783951401f7ff', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/tcapelle/train_ds_triton', endpoint='https://huggingface.co', repo_type='dataset', repo_id='tcapelle/train_ds_triton'), pr_revision=None, pr_num=None)

## Run Locally

In [None]:
from tools import run_python_code

def run_with_tests(code: str, tests: str) -> dict:
    return run_python_code(code + "\n" + tests)


def test_sample_locally(row):
    code = row["pt_code"]
    tests = row["tests"]
    return run_with_tests(code, tests)

test_sample_locally(ds[0])


{'status_code': 0,
 'stdout': "{'test_case_1': tensor([[19., 22.],\n        [43., 50.]]), 'test_case_2': tensor([[ 58.,  64.],\n        [139., 154.]]), 'test_case_3': {'result': tensor([[ 2.4955e+00,  6.4769e+00,  1.6394e+00,  ...,  1.2249e+01,\n          2.1542e+01,  8.3416e+00],\n        [ 7.4504e+00,  1.9378e-02,  1.2411e+01,  ...,  8.3219e+00,\n          2.8858e+00,  1.4705e+00],\n        [ 2.8023e+00,  7.2151e+00,  3.0986e+00,  ...,  2.8608e+01,\n         -1.5909e+01, -2.4647e+01],\n        ...,\n        [-7.1271e+00, -1.0447e+01,  9.8994e+00,  ...,  8.3518e+00,\n         -7.8036e-01, -2.5926e+01],\n        [ 2.3954e+00,  1.7080e+01, -4.1753e+00,  ..., -5.8380e-01,\n          1.8727e+00,  2.1891e+00],\n        [-2.0062e+00, -4.0143e+00, -9.1468e+00,  ..., -1.9226e+01,\n         -1.0324e+01,  2.3399e+01]]), 'expected': tensor([[ 2.4955e+00,  6.4769e+00,  1.6394e+00,  ...,  1.2249e+01,\n          2.1542e+01,  8.3416e+00],\n        [ 7.4504e+00,  1.9378e-02,  1.2411e+01,  ...,  8.321

In [46]:
test_ds = ds.map(test_sample_locally, num_proc=4)

Map (num_proc=4): 100%|██████████| 847/847 [02:18<00:00,  6.10 examples/s]


In [49]:
issues_ds = test_ds.filter(lambda x: x["status_code"] != 0)

Filter: 100%|██████████| 847/847 [00:00<00:00, 94182.81 examples/s]


In [62]:
print(issues_ds[9]["stderr"])

Traceback (most recent call last):
  File "/Users/tcapelle/work/triton_eval/axolotl_dev/temp_files/0fc4e62c-3c92-44f0-a2f7-2cbf7718ce1a.py", line 183, in <module>
    test_results = test_attn_fwd_inner_torch()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tcapelle/work/triton_eval/axolotl_dev/temp_files/0fc4e62c-3c92-44f0-a2f7-2cbf7718ce1a.py", line 166, in test_attn_fwd_inner_torch
    acc3, l_i3, m_i3 = attn_fwd_inner(acc0.clone(), l_i0.clone(), m_i0.clone(),
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tcapelle/work/triton_eval/axolotl_dev/temp_files/0fc4e62c-3c92-44f0-a2f7-2cbf7718ce1a.py", line 67, in attn_fwd_inner
    k1 = safe_slice(K1_sub, 1, rel_index, BLOCK_N)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tcapelle/work/triton_eval/axolotl_dev/temp_files/0fc4e62c-3c92-44f0-a2f7-2cbf7718ce1a.py", line 18, in safe_slice
    return tensor.narrow(dim, start, block_size)
           ^^^^^^^^^^^^^^

## Old

In [9]:
print(ds[0]["entrypoint"])
print("="*100)
print(ds[0]["pt_code_without_tests"])


matmul
import torch

# Global device standard
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """
    Perform matrix multiplication using pure PyTorch with torch.matmul.

    Parameters:
      A (torch.Tensor): Tensor of shape (M, K) on DEVICE.
      B (torch.Tensor): Tensor of shape (K, N) on DEVICE.

    Returns:
      torch.Tensor: The product matrix of shape (M, N).
    """
    # Verify dimensions
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError('Both A and B must be 2D tensors.')
    M, K = A.shape
    K2, N = B.shape
    if K != K2:
        raise ValueError(f'Inner dimensions must match, got A: {A.shape}, B: {B.shape}')

    # Perform matrix multiplication using torch.matmul
    return torch.matmul(A, B)

########################




In [10]:
import re
def find_test_name(code: str) -> str:
    pattern = r"def test_(.*)\("
    match = re.search(pattern, code)
    if match:
        return "test_" + match.group(1)
    return ""

In [12]:
code = ds[0]["pytorch_code_fixed"]
find_test_name(code)

'test_matmul'

In [45]:
def split_at_tests(code: str, test_name: str) -> tuple[str, str]:
    pattern = f"def {test_name}"
    match_index = code.find(pattern)
    if match_index == -1:
        return code, ""
    code_without_tests = code[:match_index]
    tests = code[match_index:]
    return code_without_tests, tests

In [46]:
test_name = find_test_name(code)
pt_code, test_code = split_at_tests(code, test_name)
test_code

'def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:\n    """\n    Perform matrix multiplication using pure PyTorch with torch.matmul.\n\n    Parameters:\n      A (torch.Tensor): Tensor of shape (M, K) on DEVICE.\n      B (torch.Tensor): Tensor of shape (K, N) on DEVICE.\n\n    Returns:\n      torch.Tensor: The product matrix of shape (M, N).\n    """\n    # Verify dimensions\n    if A.dim() != 2 or B.dim() != 2:\n        raise ValueError(\'Both A and B must be 2D tensors.\')\n    M, K = A.shape\n    K2, N = B.shape\n    if K != K2:\n        raise ValueError(f\'Inner dimensions must match, got A: {A.shape}, B: {B.shape}\')\n\n    # Perform matrix multiplication using torch.matmul\n    return torch.matmul(A, B)\n\n########################\n\n'

In [47]:
def rename_test_entrypoint(test_code: str, entrypoint: str) -> str:
    test_name = find_test_name(test_code)
    print(f"Replacing `{test_name}` with `test_{entrypoint}`")  
    return test_code.replace(test_name, f"test_{entrypoint}")

In [48]:
entrypoint = ds[0]["entrypoint"]

In [49]:
rename_test_entrypoint(test_code, entrypoint)

Replacing `` with `test_matmul`


'test_matmuldtest_matmuletest_matmulftest_matmul test_matmulmtest_matmulatest_matmulttest_matmulmtest_matmulutest_matmulltest_matmul(test_matmulAtest_matmul:test_matmul test_matmulttest_matmulotest_matmulrtest_matmulctest_matmulhtest_matmul.test_matmulTtest_matmuletest_matmulntest_matmulstest_matmulotest_matmulrtest_matmul,test_matmul test_matmulBtest_matmul:test_matmul test_matmulttest_matmulotest_matmulrtest_matmulctest_matmulhtest_matmul.test_matmulTtest_matmuletest_matmulntest_matmulstest_matmulotest_matmulrtest_matmul)test_matmul test_matmul-test_matmul>test_matmul test_matmulttest_matmulotest_matmulrtest_matmulctest_matmulhtest_matmul.test_matmulTtest_matmuletest_matmulntest_matmulstest_matmulotest_matmulrtest_matmul:test_matmul\ntest_matmul test_matmul test_matmul test_matmul test_matmul"test_matmul"test_matmul"test_matmul\ntest_matmul test_matmul test_matmul test_matmul test_matmulPtest_matmuletest_matmulrtest_matmulftest_matmulotest_matmulrtest_matmulmtest_matmul test_matmulmt

In [53]:
def fix_test_entrypoint(row: dict) -> dict:
    if row["tests"] == "" and row["pt_code_without_tests"] != "":
        entrypoint = row["entrypoint"]
        pt_code_without_tests = row["pt_code_without_tests"]
        test_name = find_test_name(pt_code_without_tests)
        pt_code, test_code = split_at_tests(pt_code_without_tests, test_name)
        test_code = rename_test_entrypoint(test_code, entrypoint)
        return {
            "pt_code_without_tests": pt_code,
            "tests": test_code}
    return row

In [54]:
fixed_ds = filtered_ds.map(fix_test_entrypoint)

Map: 100%|██████████| 55/55 [00:00<00:00, 4474.05 examples/s]

Replacing `test_unpack64` with `test_pytorch_unpack64`
Replacing `test_custom_attention` with `test_custom_attention_forward`
Replacing `test_fourth_order_fwd` with `test_torch_fourth_order_fwd`
Replacing `test_layer_norm_fwd_fused` with `test_layer_norm_fwd_fused_`
Replacing `test__single_query_cached_kv_attention_v2_torch` with `test_torch_single_query_cached_kv_attention_v2`
Replacing `test_bwd_intra` with `test_pytorch_bwd_intra`
Replacing `test_elementwise_mul` with `test_elementwise_mul_`
Replacing `test_rms_norm_fwd` with `test_pytorch_rms_norm_fwd`
Replacing `test_rotary_embedding` with `test_apply_rotary_embedding`
Replacing `test_fused_chunk_delta_rule_bwd` with `test_fused_chunk_delta_rule_bwd_`
Replacing `test_bwd_block` with `test_torch_bwd_block`
Replacing `test_bwd_decay_global_cumsum` with `test_bwd_decay_global_cumsum_`
Replacing `test_attn_fwd_inner_torch` with `test_torch_attn_fwd_inner`
Replacing `test_atomic_kernel` with `test_atomic_kernel_`
Replacing `test_cross_




In [55]:
for row in fixed_ds:
    print(row["entrypoint"])
    print("="*100)
    print(row["pt_code_without_tests"])
    print("="*100)
    print(row["tests"])
    print("*"*100)

pytorch_unpack64
import torch

# Global device standard
DEVICE = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

import math

def pytorch_unpack64(merged: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """
    Decomposes a 64-bit unsigned integer tensor into two 32-bit floats by bitmasking and
    bit-level reinterpretation. The upper 32 bits produce the first float and the lower 32 bits produce the second.
    """
    if merged.dtype != torch.uint64:
        raise ValueError('Input tensor must be of dtype torch.uint64')

    # Extract lower and upper 32 bits
    mask_val = 0xffffffff  # Mask for lower 32 bits

    # Workaround: torch.bitwise_and for uint64 is not implemented on CUDA, so cast to CPU for bitwise ops
    if merged.device.type == 'cuda':
        merged_cpu = merged.cpu()
        lower = merged_cpu & mask_val
        upper = merged_cpu >> 32
        lower = lower.to(DEVICE)
        upper = upper.to(DEVICE)
    else:
        lower = merged & mask_val
     

## Fix DS

In [57]:
ds = load_dataset("tcapelle/train_ds_triton", split="train")
ds = ds.map(fix_test_entrypoint)

Map: 100%|██████████| 847/847 [00:00<00:00, 9050.53 examples/s]

Replacing `test_unpack64` with `test_pytorch_unpack64`
Replacing `test_custom_attention` with `test_custom_attention_forward`
Replacing `test_fourth_order_fwd` with `test_torch_fourth_order_fwd`
Replacing `test_layer_norm_fwd_fused` with `test_layer_norm_fwd_fused_`
Replacing `test__single_query_cached_kv_attention_v2_torch` with `test_torch_single_query_cached_kv_attention_v2`
Replacing `test_bwd_intra` with `test_pytorch_bwd_intra`
Replacing `test_elementwise_mul` with `test_elementwise_mul_`
Replacing `test_rms_norm_fwd` with `test_pytorch_rms_norm_fwd`
Replacing `test_rotary_embedding` with `test_apply_rotary_embedding`
Replacing `test_fused_chunk_delta_rule_bwd` with `test_fused_chunk_delta_rule_bwd_`
Replacing `test_bwd_block` with `test_torch_bwd_block`
Replacing `test_bwd_decay_global_cumsum` with `test_bwd_decay_global_cumsum_`
Replacing `test_attn_fwd_inner_torch` with `test_torch_attn_fwd_inner`
Replacing `test_atomic_kernel` with `test_atomic_kernel_`
Replacing `test_cross_




In [58]:
ds.push_to_hub("tcapelle/train_ds_triton", commit_message="Fix tests/entrypoint")

Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 20.14ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.86it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/tcapelle/train_ds_triton/commit/2a56b9158791b972a3e25d9727cd135ae21c2bed', commit_message='Fix tests/entrypoint', commit_description='', oid='2a56b9158791b972a3e25d9727cd135ae21c2bed', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/tcapelle/train_ds_triton', endpoint='https://huggingface.co', repo_type='dataset', repo_id='tcapelle/train_ds_triton'), pr_revision=None, pr_num=None)