2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -40,7 +40,7 @@ jobs:
- name: Install lint dependencies
run: |
source .venv/bin/activate
uv pip install pyright
uv pip install pyrefly
uv pip install .'[dev]'

- name: Run pre-commit
9 changes: 5 additions & 4 deletions .pre-commit-config.yaml
@@ -50,8 +50,9 @@ repos:
additional_dependencies:
- tomli

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.407
- repo: https://github.com/facebook/pyrefly-pre-commit
Contributor: Can we use pyrefly with the version in the config? I think @lolpack added something?

Contributor Author: I tried switching to the second option here that's supposed to let you specify the version, but I couldn't get it to work. Pyrefly starts producing a lot of import errors (on torch, triton, etc.) that suggest it's no longer able to find dependencies.

Contributor Author: I opened an issue against pyrefly-pre-commit: facebook/pyrefly-pre-commit#7.
rev: 0.0.1
hooks:
- id: pyright
language: system
- id: pyrefly-typecheck-system
name: Pyrefly (type checking)
pass_filenames: false
2 changes: 1 addition & 1 deletion AGENTS.md
@@ -26,7 +26,7 @@ This document explains how to work effectively in this repository.
- Imports: Sorted by Ruff/isort; single import per line.
- Helion import pattern: `import helion; import helion.language as hl` (do not `import helion as hl`).
- Modules/files: snake_case; tests `test_*.py`; examples `*.py` with `main()`.
- Run `./lint.sh fix` before pushing; CI uses Ruff and Pyright.
- Run `./lint.sh fix` before pushing; CI uses Ruff and Pyrefly.

## Testing Guidelines

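For readers skimming the diff: the Helion import convention called out above pairs the package import with its language namespace. A minimal sketch in the spirit of the examples/add.py kernel referenced by benchmarks/run.py (details of the real kernel may differ):

```python
import torch

import helion
import helion.language as hl  # per the guideline; never `import helion as hl`


@helion.kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Tile over the output so Helion can autotune block sizes.
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out
```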
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -28,7 +28,7 @@ outlined on that page and do not file a public issue.

## Coding Style
* Code is formatted and checked using [ruff](https://docs.astral.sh/ruff/formatter/).
* All files must be typed with [pyre](https://pyre-check.org/).
* All files must be typed with [pyrefly](https://pyrefly.org/).
* Run `./lint.sh install && ./lint.sh` to check your code. Many formatting issues can be fixed automatically.

## License
2 changes: 1 addition & 1 deletion README.md
@@ -366,7 +366,7 @@ code take effect without needing to reinstall.

## Linting

We use `pre-commit` to run ruff, pyright, and other checks automatically.
We use `pre-commit` to run ruff, pyrefly, and other checks automatically.

– One-time setup (installs the git hook):
```bash
32 changes: 19 additions & 13 deletions benchmarks/run.py
@@ -1,5 +1,3 @@
# pyright: reportMissingImports=false

"""Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.

Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
@@ -105,7 +103,8 @@ class RunResult:
# - Single kernel with args: (tritonbench_module, helion_module, helion_func, args_dict)
# - Multiple kernels: (tritonbench_module, [(helion_module, helion_func), ...])
# - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = { # pyright: ignore[reportAssignmentType]
# pyrefly: ignore [bad-assignment]
KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"addmm": (
@@ -663,7 +662,8 @@ def check_and_setup_tritonbench() -> None:
installing_marker = (benchmarks_dir / ".tritonbench_installing").resolve()

try:
import tritonbench # pyright: ignore[reportMissingImports]
# pyrefly: ignore [missing-import]
import tritonbench

module_file = getattr(tritonbench, "__file__", None)
tb_repo_path = tritonbench_path.resolve()
@@ -785,7 +785,8 @@ def is_local(path: Path) -> bool:
importlib.invalidate_caches()

try:
import tritonbench # pyright: ignore[reportMissingImports]
# pyrefly: ignore [missing-import]
import tritonbench

print("Tritonbench installed successfully.", file=sys.stderr)
if installing_marker.exists():
@@ -841,11 +842,11 @@ def run_kernel(
tritonbench_module = mapping[0]
module_path = mapping[1]
func_name = mapping[2]
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
operator_args = mapping[3]
variants = [(module_path, func_name)]
else:
# Without args
assert len(mapping) == 3 # Type narrowing for pyright
assert len(mapping) == 3
tritonbench_module, module_path, func_name = mapping
variants = [(module_path, func_name)]

@@ -873,10 +874,13 @@ def run_kernel_variants(
"""Run kernel variants in the same benchmark run."""

# Import tritonbench components
from tritonbench.utils.parser import ( # pyright: ignore[reportMissingImports]
get_parser,
)
# pyrefly: ignore [missing-import]
from tritonbench.utils.parser import get_parser

# pyrefly: ignore [missing-import]
from tritonbench.utils.triton_op import BenchmarkOperator

# pyrefly: ignore [missing-import]
from tritonbench.utils.triton_op import BenchmarkOperatorMetrics

# Get the tritonbench operator name, stripping -bwd suffix for backward operators
@@ -944,9 +948,8 @@ def run_kernel_variants(
sys.exit(1)

# Import register_benchmark API
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
register_benchmark,
)
# pyrefly: ignore [missing-import]
from tritonbench.utils.triton_op import register_benchmark

# Register all variants as separate methods
for module_path, func_name in variants:
@@ -1094,11 +1097,14 @@ def accuracy_fail_hook(
)

try:
# pyrefly: ignore [missing-import]
from tritonbench.run import run as tritonbench_run
except ImportError:
try:
# pyrefly: ignore [missing-import]
from tritonbench.utils.run_utils import tritonbench_run
except ImportError:
# pyrefly: ignore [missing-import]
from pytorch.tritonbench.run import run as tritonbench_run

with tempfile.NamedTemporaryFile(mode="w+t", suffix=".csv") as tmp:
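Taken together, these hunks show the mechanical translation the PR applies everywhere: pyright suppressions are inline `# pyright: ignore[rule]` comments, while the pyrefly equivalents use pyrefly's own error codes and usually sit on the preceding line (inline also works, as the `bad-override` ignores further down show). A minimal sketch with hypothetical variables:

```python
# Before: pyright suppression, inline on the offending line.
x: int = "oops"  # pyright: ignore[reportAssignmentType]

# After: pyrefly suppression, typically on the line above...
# pyrefly: ignore [bad-assignment]
y: int = "oops"

# ...or inline on the same line.
z: int = "oops"  # pyrefly: ignore [bad-assignment]
```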
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -9,7 +9,8 @@
from typing import Callable
from typing import Protocol

import pytorch_sphinx_theme2 # pyright: ignore[reportMissingImports]
# pyrefly: ignore [missing-import]
import pytorch_sphinx_theme2

# -- Path setup --------------------------------------------------------------

2 changes: 1 addition & 1 deletion examples/all_gather_matmul.py
@@ -208,7 +208,7 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
dist_group = dist.group.WORLD
if dist_group is None:
raise RuntimeError("No distributed group available")
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue]
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
golden_a, [b], gather_dim=0, group_name=dist_group.group_name
)
torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
5 changes: 3 additions & 2 deletions examples/all_reduce.py
@@ -72,7 +72,8 @@ def dev_array_to_tensor_short(
Returns:
PyTorch tensor created from the device pointer
"""
return cpp_mod.from_blob(dev_array_ptr, shape, dtype) # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
return cpp_mod.from_blob(dev_array_ptr, shape, dtype)


# %%
@@ -228,7 +229,7 @@ def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
a_shared_clone.copy_(a_shared)

return torch.ops.symm_mem.one_shot_all_reduce( # pyright: ignore[reportCallIssue]
return torch.ops.symm_mem.one_shot_all_reduce(
a_shared_clone, "sum", dist_group.group_name
)

5 changes: 3 additions & 2 deletions examples/attention.py
@@ -94,9 +94,10 @@ def attention(
# ---------------------

# %%
attention_dynamic: object = helion.kernel( # pyright: ignore[reportCallIssue]
# pyrefly: ignore [no-matching-overload]
attention_dynamic: object = helion.kernel(
attention.fn,
configs=attention.configs, # pyright: ignore[reportArgumentType]
configs=attention.configs,
static_shapes=False,
)
"""
8 changes: 6 additions & 2 deletions examples/blackwell_attention.py
@@ -76,6 +76,7 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:


# %%
# pyrefly: ignore [no-matching-overload]
@helion.kernel(
configs=[
helion.Config(
@@ -158,7 +159,8 @@ def blackwell_attention_kernel(
qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32)
m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
if VECT_MUL == 2 or VECT_MUL == 3:
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) # pyright: ignore[reportArgumentType]
# pyrefly: ignore [bad-argument-type]
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
else:
qk = qk * qk_scale - m_ij[:, None]

@@ -169,6 +171,7 @@

if SUBTILING:
acc0, acc1 = hl.split(
# pyrefly: ignore [no-matching-overload]
acc.reshape([tile_m, 2, Dv // 2]).permute(0, 2, 1)
)
if VECT_MUL == 1 or VECT_MUL == 3:
@@ -267,7 +270,8 @@ def ref_attention(
atol=0.1,
rtol=0.1,
)
dur: float = do_bench(lambda: blackwell_attention(q, k, v)) # pyright: ignore[reportArgumentType, reportAssignmentType]
# pyrefly: ignore [bad-assignment]
dur: float = do_bench(lambda: blackwell_attention(q, k, v))
print(
f"{z=} {h=} {n_ctx=} {head_dim=} tflops={z * h * n_ctx * n_ctx * head_dim * 4 / dur * 1e-9:.2f}"
)
2 changes: 1 addition & 1 deletion examples/exp.py
@@ -67,7 +67,7 @@ def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor:
# %%
class ExpFunction(torch.autograd.Function):
@staticmethod
def forward(
def forward( # pyrefly: ignore [bad-override]
ctx: object,
x: torch.Tensor,
) -> torch.Tensor:
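The `bad-override` ignore on `forward` recurs in geglu.py and grpo_loss.py below for the same reason: `torch.autograd.Function.forward` is declared as accepting `*args, **kwargs`, so a concrete signature narrows it and pyrefly flags the override even though this is the idiomatic way to write autograd functions. A hedged standalone sketch of the pattern (the `Square` function is hypothetical):

```python
from typing import Any

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(  # pyrefly: ignore [bad-override]
        ctx: Any,  # the base class types forward as (*args, **kwargs)
        x: torch.Tensor,
    ) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx: Any, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_out


# Usage: y = Square.apply(torch.randn(3, requires_grad=True))
```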
1 change: 1 addition & 0 deletions examples/fp8_attention.py
@@ -106,6 +106,7 @@ def fp8_attention_kernel(
# Final normalization
acc = acc / l_i[:, None]
# Convert to FP8 before writing to output
# pyrefly: ignore [unsupported-operation]
out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
return out

3 changes: 2 additions & 1 deletion examples/fused_linear_jsd.py
@@ -85,7 +85,8 @@ def fused_linear_jsd_fwd_tritonbench(
label: torch.Tensor | None = None,
) -> Callable[[], torch.Tensor]:
assert label is None
baseline_op = tb_op.baseline_op # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
baseline_op = tb_op.baseline_op
beta = baseline_op.jsd.beta
ignore_index = baseline_op.jsd.ignore_index
temperature = baseline_op.temperature
14 changes: 9 additions & 5 deletions examples/geglu.py
@@ -144,7 +144,7 @@ def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]:

class GEGLUFunction(torch.autograd.Function):
@staticmethod
def forward(
def forward( # pyrefly: ignore [bad-override]
ctx: Any, # noqa: ANN401
a: Tensor,
b: Tensor,
@@ -343,17 +343,21 @@ def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable:

# Extract configuration from tritonbench operator
config = Config(
hidden_size=tb_op.hidden_size, # pyright: ignore[reportAttributeAccessIssue]
intermediate_size=tb_op.intermediate_size, # pyright: ignore[reportAttributeAccessIssue]
hidden_act=tb_op.hidden_act, # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
hidden_size=tb_op.hidden_size,
# pyrefly: ignore [missing-attribute]
intermediate_size=tb_op.intermediate_size,
# pyrefly: ignore [missing-attribute]
hidden_act=tb_op.hidden_act,
)

# Create Helion model
helion_mlp = HelionGEGLUMLP(config).to(x.device).to(x.dtype)

# Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness
# LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP)
baseline_model = tb_op.baseline_model # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
baseline_model = tb_op.baseline_model

# Copy gate projection weights
helion_mlp.gate_proj.weight.data.copy_(baseline_model.gate_proj.weight.data)
4 changes: 2 additions & 2 deletions examples/grouped_gemm.py
@@ -177,12 +177,12 @@ def grouped_gemm_jagged_persistent(
tile_in_group = local_tile * num_workers + worker_id
if tile_in_group < num_group_tiles:
# Convert linear tile index to 2D (M, N) tile coordinates
m_tile_idx = tile_in_group % num_m_tiles # pyright: ignore[reportOperatorIssue]
m_tile_idx = tile_in_group % num_m_tiles
n_tile_idx = tile_in_group // num_m_tiles

# Compute global memory indices for current tile
base_row = group_start + m_tile_idx * BLOCK_M
base_col = n_tile_idx * BLOCK_N # pyright: ignore[reportOperatorIssue]
base_col = n_tile_idx * BLOCK_N

# Generate row and column index ranges for tile access
row_idx = base_row + hl.arange(BLOCK_M)
2 changes: 1 addition & 1 deletion examples/grpo_loss.py
@@ -336,7 +336,7 @@ class GrpoLossFunction(torch.autograd.Function):
"""Custom autograd function for GRPO loss with forward and backward passes."""

@staticmethod
def forward(
def forward( # pyrefly: ignore [bad-override]
ctx: object,
logits: torch.Tensor,
old_logp: torch.Tensor | None,
7 changes: 3 additions & 4 deletions examples/jagged_hstu_attn.py
@@ -22,9 +22,8 @@
import helion.language as hl

try:
from generative_recommenders.ops.triton.triton_hstu_attention import ( # pyright: ignore[reportMissingImports]
triton_hstu_mha,
)
# pyrefly: ignore [missing-import]
from generative_recommenders.ops.triton.triton_hstu_attention import triton_hstu_mha

HAS_HAMMER = True
except ImportError:
@@ -249,7 +248,7 @@ def _triton_hstu_mha(
num_targets: torch.Tensor | None,
max_seq_len: int,
) -> torch.Tensor:
return triton_hstu_mha( # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue]
return triton_hstu_mha(
max_seq_len,
alpha=1.0 / v.size(2) ** 2,
q=q,
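tritonbench in benchmarks/run.py and generative_recommenders here are both optional at runtime, so the new `missing-import` ignores pair with the existing try/except gating rather than replacing it. A minimal sketch of the idiom (the `HAS_TRITONBENCH` flag is hypothetical; this file uses `HAS_HAMMER`):

```python
try:
    # pyrefly: ignore [missing-import]
    import tritonbench  # optional; benchmarks/run.py installs it on demand

    HAS_TRITONBENCH = True
except ImportError:
    HAS_TRITONBENCH = False
```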
5 changes: 3 additions & 2 deletions examples/jagged_layer_norm.py
@@ -192,7 +192,7 @@ def reference_jagged_layer_norm_pytorch(
[
torch.nn.functional.layer_norm(
x_values[x_offsets[i] : x_offsets[i + 1], :],
x_values[x_offsets[i] : x_offsets[i + 1], :].shape, # pyright: ignore[reportArgumentType]
x_values[x_offsets[i] : x_offsets[i + 1], :].shape,
eps=eps,
)
for i in range(x_offsets.shape[0] - 1)
@@ -225,7 +225,8 @@ def jagged_layer_norm_tritonbench(
Callable that returns normalized tensor values
"""
x_values = x._values
x_offsets = x._offsets # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
x_offsets = x._offsets

return lambda: jagged_layer_norm_kernel(x_values, x_offsets, eps=1e-6)

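The `missing-attribute` ignores on `x._offsets` here and in jagged_mean.py, jagged_softmax.py, and jagged_sum.py all stem from the same thing: jagged NestedTensor storage is exposed through private attributes that do not appear on the public `Tensor` type. A hedged sketch, assuming a jagged-layout nested tensor:

```python
import torch

# Two variable-length rows packed into one flat buffer.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)

# pyrefly: ignore [missing-attribute]
values = nt._values  # the flat (8, 4) value buffer
# pyrefly: ignore [missing-attribute]
offsets = nt._offsets  # row boundaries: tensor([0, 3, 8])
```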
6 changes: 4 additions & 2 deletions examples/jagged_mean.py
@@ -166,13 +166,15 @@ def jagged_mean_tritonbench(
Callable that returns tensor of shape (B, M) with mean values per row and feature
"""
x_values = x._values
x_offsets = x._offsets # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
x_offsets = x._offsets

feature_counts = torch.full(
(B,),
M,
dtype=torch.int32,
device=x_values.device, # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
device=x_values.device,
)
return lambda: jagged_mean_kernel(x_values, x_offsets, feature_counts, M)

3 changes: 2 additions & 1 deletion examples/jagged_softmax.py
@@ -163,7 +163,8 @@ def jagged_softmax_tritonbench(
Returns:
Callable that returns tensor of shape (N, M), where N = total number of rows in the jagged tensor
"""
return lambda: jagged_softmax_kernel(x._values, x._offsets) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
return lambda: jagged_softmax_kernel(x._values, x._offsets)


# %%
3 changes: 2 additions & 1 deletion examples/jagged_sum.py
@@ -145,7 +145,8 @@ def jagged_sum_tritonbench(
Callable that returns tensor of shape (B, M) with mean values per row and feature
"""
x_values = x._values
x_offsets = x._offsets # pyright: ignore[reportAttributeAccessIssue]
# pyrefly: ignore [missing-attribute]
x_offsets = x._offsets

return lambda: jagged_sum_kernel(x_values, x_offsets)
