From 1528287e03cd3becedd99251031c616c5b64e289 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Wed, 5 Nov 2025 14:43:30 -0800 Subject: [PATCH 01/10] pyright -> pyrefly --- .github/workflows/lint.yml | 2 +- .pre-commit-config.yaml | 9 +++++---- lint.sh | 12 ++++++------ pyproject.toml | 10 +++++----- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6270fb64b..dae418310 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -40,7 +40,7 @@ jobs: - name: Install lint dependencies run: | source .venv/bin/activate - uv pip install pyright + uv pip install pyrefly uv pip install .'[dev]' - name: Run pre-commit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c0dabf13..2203b06e4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,8 +50,9 @@ repos: additional_dependencies: - tomli -- repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.407 +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 0.0.1 hooks: - - id: pyright - language: system + - id: pyrefly-typecheck-system + name: Pyrefly (type checking) + pass_filenames: false diff --git a/lint.sh b/lint.sh index af0436474..102cc6bf5 100755 --- a/lint.sh +++ b/lint.sh @@ -9,13 +9,13 @@ fi if [ "$ACTION" = "install" ]; then set -ex - pip install ruff==0.14.2 pyright==1.1.407 + pip install ruff==0.14.2 pyrefly==0.40.0 exit 0 fi -if ! (which ruff > /dev/null && which pyright > /dev/null); +if ! (which ruff > /dev/null && which pyrefly > /dev/null); then - echo "ruff/pyright not installed. Run ./lint.sh install" + echo "ruff/pyrefly not installed. Run ./lint.sh install" exit 1 fi @@ -37,21 +37,21 @@ if [ "$ACTION" = "fix" ]; then run ruff format run ruff check --fix - run pyright + run pyrefly check fi if [ "$ACTION" = "unsafe" ]; then run ruff format run ruff check --fix --unsafe-fixes - run pyright + run pyrefly check fi if [ "$ACTION" = "check" ]; then run ruff format --check --diff run ruff check --no-fix - run pyright + run pyrefly check fi if [ "$ERRORS" != "" ]; diff --git a/pyproject.toml b/pyproject.toml index 93032300a..795a0dc49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,11 +100,11 @@ allow-direct-references = true [tool.hatch.version] source = "vcs" -[tool.pyright] -include = ["helion", "examples"] -exclude = ["test"] -extraPaths = ["triton/python", "../pytorch", "../pytorch-hg", "../pytorch-nightly"] -pythonVersion = "3.10" +[tool.pyrefly] +project-includes = ["helion", "examples"] +project-excludes = ["test"] +search-path = ["triton/python", "../pytorch", "../pytorch-hg", "../pytorch-nightly"] +python-version = "3.10" [tool.codespell] ignore-words = "scripts/dictionary.txt" From 9f5ccc444eb0b71edba2b23d2ed7b2de525ec7c5 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 13:55:15 -0800 Subject: [PATCH 02/10] Update pyrefly to 0.42.0. 
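
Keep the pyrefly pin identical in .github/workflows/lint.yml and lint.sh so
local lint runs match CI. To verify locally, the same commands this series
already uses (suggested check, not part of the diff):

    pip install pyrefly==0.42.0
    pyrefly check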
--- .github/workflows/lint.yml | 2 +- lint.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dae418310..262ac84f2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -40,7 +40,7 @@ jobs: - name: Install lint dependencies run: | source .venv/bin/activate - uv pip install pyrefly + uv pip install pyrefly==0.42.0 uv pip install .'[dev]' - name: Run pre-commit diff --git a/lint.sh b/lint.sh index 102cc6bf5..04fb0d218 100755 --- a/lint.sh +++ b/lint.sh @@ -9,7 +9,7 @@ fi if [ "$ACTION" = "install" ]; then set -ex - pip install ruff==0.14.2 pyrefly==0.40.0 + pip install ruff==0.14.2 pyrefly==0.42.0 exit 0 fi From b7189065f2b38c3295ad2d7227245f0efb9a039c Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 13:59:13 -0800 Subject: [PATCH 03/10] Add type annotations and remove redundant casts. --- helion/_compiler/aten_lowering.py | 4 ++-- helion/_compiler/device_function.py | 10 ++-------- helion/_compiler/inductor_lowering.py | 5 ++--- helion/_compiler/program_id.py | 2 +- helion/_compiler/type_propagation.py | 2 +- helion/autotuner/config_spec.py | 2 +- helion/language/random_ops.py | 4 ++-- helion/runtime/kernel.py | 2 +- helion/runtime/ref_mode.py | 6 +++--- 9 files changed, 15 insertions(+), 22 deletions(-) diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py index f3e5e0991..c6967c462 100644 --- a/helion/_compiler/aten_lowering.py +++ b/helion/_compiler/aten_lowering.py @@ -52,7 +52,7 @@ def get_masked_value(self, node: Node) -> float | bool | None: def _env_arg(ctx: LoweringContext, node: Node) -> Argument: - return cast("Argument", ctx.env[node]) + return ctx.env[node] @dataclasses.dataclass @@ -520,7 +520,7 @@ def _codegen_rng_op( block_size = env.block_sizes[block_id].size dim_names.append(device_fn.literal_expr(block_size)) - offset_parts = [] + offset_parts: list[str] = [] for i in range(ndim): # Create the index variable with proper broadcasting diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 6cec94000..49955f10b 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from ..runtime.config import Config - from ..runtime.config import IndexingLiteral from .device_ir import HelperFunctionGraphInfo from .generate_ast import GenerateAST from .indexing_strategy import IndexingStrategy @@ -259,7 +258,6 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.rng_seed_buffer_param_name = None def get_indexing_strategy(self, index: int) -> IndexingStrategy: - from typing import cast from .indexing_strategy import IndexingStrategy from .indexing_strategy import PointerIndexingStrategy @@ -271,9 +269,7 @@ def get_indexing_strategy(self, index: int) -> IndexingStrategy: if isinstance(self._indexing_config, str): # Single string: all loads/stores use the same strategy if not self.indexing_strategies: - strategy = IndexingStrategy.select( - cast("IndexingLiteral", self._indexing_config) - ) + strategy = IndexingStrategy.select(self._indexing_config) else: strategy = self.indexing_strategies[0] elif isinstance(self._indexing_config, list) and self._indexing_config: @@ -282,9 +278,7 @@ def get_indexing_strategy(self, index: int) -> IndexingStrategy: f"Load/Store operation {idx} exceeds indexing config length " f"{len(self._indexing_config)}. Please specify indexing for all loads and stores." 
                )
-            strategy = IndexingStrategy.select(
-                cast("IndexingLiteral", self._indexing_config[idx])
-            )
+            strategy = IndexingStrategy.select(self._indexing_config[idx])
         else:
             # Empty/default: use pointer
             strategy = PointerIndexingStrategy()
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index c0df920a4..347daae02 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -36,7 +36,6 @@
 from torch.fx.experimental import proxy_tensor
 from torch.fx.experimental.sym_node import SymNode
 from torch.fx.interpreter import Interpreter
-from torch.fx.node import Argument
 from torch.fx.node import Node
 from torch.fx.node import map_arg
 
@@ -725,7 +724,7 @@ def __init__(self, api_func: object) -> None:
 
     def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
         assert not node.kwargs
-        ast_args = [*map_arg(node.args, lambda arg: cast("Argument", ctx.env[arg]))]
+        ast_args = [*map_arg(node.args, lambda arg: ctx.env[arg])]
         proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])]
 
         env = CompileEnvironment.current()
@@ -927,7 +926,7 @@ class GraphInterpreter(LoweringContext, Interpreter):
     def __init__(self, graph: torch.fx.Graph, cg: CodegenInterface) -> None:
         super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False)
         self.cg = cg
-        self.env = cast("dict[Node, Argument]", self.env)
+        self.env = self.env
 
     def to_ast(self, value: object) -> ast.AST:
         """
diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py
index 2db56d4a6..5ea6db6d5 100644
--- a/helion/_compiler/program_id.py
+++ b/helion/_compiler/program_id.py
@@ -281,7 +281,7 @@ def codegen(self, state: CodegenState) -> None:
         assignments = []
 
         # Generate size variables for all dimensions (except the last which doesn't need one)
-        num_blocks = []
+        num_blocks: list[str] = []
         for i in range(num_dims - 1):
             num_block_var = new_var(f"num_blocks_{i}", dce=True)
             assignments.append(
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 8ae1e423e..e3a62af4e 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -412,7 +412,7 @@ def __init__(self, origin: Origin, fake_value: torch.Tensor) -> None:
             CompileEnvironment.current().add_kernel_tensor_size(fake_value.size())
 
     def __str__(self) -> str:
-        shape = []
+        shape: list[str] = []
         for s in self.fake_value.size():
             if isinstance(s, torch.SymInt):
                 shape.append(
diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py
index 29b07676f..4e2afbb86 100644
--- a/helion/autotuner/config_spec.py
+++ b/helion/autotuner/config_spec.py
@@ -373,7 +373,7 @@ def __init__(
         assert self.min_size <= self.max_size
 
     def __repr__(self) -> str:
-        fields = []
+        fields: list[str] = []
         for field, default in (
             ("block_id", None),
             ("size_hint", None),
diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py
index 4b65b6154..bd852886f 100644
--- a/helion/language/random_ops.py
+++ b/helion/language/random_ops.py
@@ -89,7 +89,7 @@ def _rand_codegen(state: CodegenState) -> ast.AST:
     seed_ast = state.ast_arg(1)
 
     index_vars = []
-    size_names = []
+    size_names: list[str] = []
     for i in range(ndim):
         size = tensor_shape[i]
         block_id = env.get_block_id(size)
@@ -113,7 +113,7 @@ def _rand_codegen(state: CodegenState) -> ast.AST:
     if ndim == 1:
         offset_expr = expr_from_string(index_vars[0])
     else:
-        offset_parts = []
+        offset_parts: list[str] = []
         for i in range(ndim):
             broadcast_slice = StackIndexingStrategy.get_element_broadcast_slice(i, ndim)
             broadcasted_index = f"{index_vars[i]}{broadcast_slice}"
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index 32f5f3ee9..73f815e9f 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -757,7 +757,7 @@ def _render_input_arg_assignment(name: str, value: object) -> list[str]:
     output_lines.extend(["", "def helion_repro_caller():"])
     output_lines.append("    torch.manual_seed(0)")
 
-    arg_names = []
+    arg_names: list[str] = []
     for i, value in enumerate(args):
         var_name = sig_param_names[i]
 
diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py
index 06ec890ac..548c8d51d 100644
--- a/helion/runtime/ref_mode.py
+++ b/helion/runtime/ref_mode.py
@@ -115,7 +115,7 @@ class RefModeTorchFunctionMode(BaseTorchFunctionMode):
     def __init__(self) -> None:
         super().__init__()
         # Map functions to their handlers
-        self._func_handlers = {
+        self._func_handlers: dict[Callable[..., object], Callable[..., object]] = {
             torch.addmm: lambda args, kwargs: self._handle_mm_with_bias(
                 args, kwargs, torch.mm, "addmm"
             ),
@@ -147,7 +147,7 @@ def __init__(self) -> None:
         }
 
         # Map method names to their handlers for tensor methods
-        self._method_handlers = {
+        self._method_handlers: dict[str, Callable[..., object]] = {
             **{
                 method: lambda args, kwargs, m=method: self._handle_factory_method(
                     args, kwargs, m, has_fill=False
@@ -164,7 +164,7 @@ def __init__(self) -> None:
 
     def __torch_function__(
         self,
-        func: object,
+        func: Callable[..., object],
         types: list[type[object]],
         args: tuple[object, ...] = (),
         kwargs: dict[str, object] | None = None,

From ff9fb8f9c7715b480e8122cfca3c7772ccc626da Mon Sep 17 00:00:00 2001
From: Rebecca Chen
Date: Mon, 17 Nov 2025 14:09:44 -0800
Subject: [PATCH 04/10] Remove pyright: ignore directives and other pyright references.
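
With pyrefly as the only checker, the `# pyright: ignore[...]` line comments
and `# pyright:` file-level directives are dead weight, so this patch deletes
them everywhere they appear. One way to confirm no stragglers remain
(suggested check assuming a POSIX shell; not part of the patch):

    grep -rn "pyright" helion examples benchmarks docs test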
---
 README.md                                   |   2 +-
 benchmarks/run.py                           |  20 +-
 docs/conf.py                                |   2 +-
 examples/all_gather_matmul.py               |   2 +-
 examples/all_reduce.py                      |   4 +-
 examples/attention.py                       |   4 +-
 examples/blackwell_attention.py             |   4 +-
 examples/fused_linear_jsd.py                |   2 +-
 examples/geglu.py                           |   8 +-
 examples/grouped_gemm.py                    |   4 +-
 examples/jagged_hstu_attn.py                |   6 +-
 examples/jagged_layer_norm.py               |   4 +-
 examples/jagged_mean.py                     |   4 +-
 examples/jagged_softmax.py                  |   2 +-
 examples/jagged_sum.py                      |   2 +-
 examples/jsd.py                             |   2 +-
 examples/layer_norm.py                      |   2 +-
 examples/swiglu.py                          |   8 +-
 helion/_compat.py                           |   6 +-
 helion/_compiler/ast_extension.py           |  20 +-
 helion/_compiler/ast_read_writes.py         |   2 +-
 helion/_compiler/aten_lowering.py           |  58 +++---
 helion/_compiler/compile_environment.py     |  13 +-
 helion/_compiler/device_function.py         |  19 +-
 helion/_compiler/device_ir.py               |  44 +++--
 helion/_compiler/generate_ast.py            |  14 +-
 helion/_compiler/helper_function.py         |   2 +-
 helion/_compiler/host_function.py           |   4 +-
 helion/_compiler/indexing_strategy.py       |   8 +-
 helion/_compiler/inductor_lowering.py       |  46 ++---
 helion/_compiler/inductor_lowering_extra.py |  42 ++---
 helion/_compiler/lift_closures.py           |   2 +-
 helion/_compiler/node_masking.py            |   2 +-
 helion/_compiler/program_id.py              |   2 +-
 helion/_compiler/roll_reduction.py          |  12 +-
 helion/_compiler/source_location.py         |  22 +--
 helion/_compiler/static_loop_unroller.py    |   4 +-
 helion/_compiler/tensor_utils.py            |  16 +-
 helion/_compiler/tile_strategy.py           |   4 +-
 helion/_compiler/traceback_compat.py        |  24 +--
 helion/_compiler/type_propagation.py        | 192 ++++++++++----------
 helion/_logging/_internal.py                |   4 +-
 helion/_testing.py                          |  34 ++--
 helion/autotuner/base_search.py             |   2 +-
 helion/autotuner/block_id_sequence.py       |   8 +-
 helion/autotuner/local_cache.py             |  14 +-
 helion/autotuner/progress_bar.py            |   2 +-
 helion/language/_decorators.py              |  18 +-
 helion/language/_tracing_ops.py             |  16 +-
 helion/language/atomic_ops.py               |  36 ++--
 helion/language/loops.py                    |   2 +-
 helion/language/memory_ops.py               |   2 +-
 helion/language/reduce_ops.py               |   4 +-
 helion/language/ref_tile.py                 |  10 +-
 helion/language/signal_wait.py              |  10 +-
 helion/language/stack_tensor.py             |   8 +-
 helion/language/tile_proxy.py               |   2 +-
 helion/language/view_ops.py                 |   4 +-
 helion/runtime/kernel.py                    |  29 ++-
 helion/runtime/precompile_shim.py           |   2 +-
 helion/runtime/settings.py                  |  10 +-
 helion/runtime/triton_helpers.py            |   6 +-
 test/test_examples.expected                 |   8 +-
 63 files changed, 405 insertions(+), 466 deletions(-)

diff --git a/README.md b/README.md
index 4b40ef01b..cb576e793 100644
--- a/README.md
+++ b/README.md
@@ -366,7 +366,7 @@ code take effect without needing to reinstall.
 
 ## Linting
 
-We use `pre-commit` to run ruff, pyright, and other checks automatically.
+We use `pre-commit` to run ruff, pyrefly, and other checks automatically.
 
 - One-time setup (installs the git hook):
 ```bash
diff --git a/benchmarks/run.py b/benchmarks/run.py
index 750833422..969253494 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -1,5 +1,3 @@
-# pyright: reportMissingImports=false
-
 """Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.
 
 Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
@@ -105,7 +103,7 @@ class RunResult:
 #   - Single kernel with args: (tritonbench_module, helion_module, helion_func, args_dict)
 #   - Multiple kernels: (tritonbench_module, [(helion_module, helion_func), ...])
 #   - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
-KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
+KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {
     # <kernel_name>: (<tritonbench_module>, <helion_module>, <helion_func>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
     "addmm": (
@@ -663,7 +661,7 @@ def check_and_setup_tritonbench() -> None:
     installing_marker = (benchmarks_dir / ".tritonbench_installing").resolve()
 
     try:
-        import tritonbench  # pyright: ignore[reportMissingImports]
+        import tritonbench
 
         module_file = getattr(tritonbench, "__file__", None)
         tb_repo_path = tritonbench_path.resolve()
@@ -785,7 +783,7 @@ def is_local(path: Path) -> bool:
     importlib.invalidate_caches()
 
     try:
-        import tritonbench  # pyright: ignore[reportMissingImports]
+        import tritonbench
 
         print("Tritonbench installed successfully.", file=sys.stderr)
         if installing_marker.exists():
@@ -841,11 +839,11 @@ def run_kernel(
         tritonbench_module = mapping[0]
         module_path = mapping[1]
         func_name = mapping[2]
-        operator_args = mapping[3]  # pyright: ignore[reportGeneralTypeIssues]
+        operator_args = mapping[3]
         variants = [(module_path, func_name)]
     else:
         # Without args
-        assert len(mapping) == 3  # Type narrowing for pyright
+        assert len(mapping) == 3
         tritonbench_module, module_path, func_name = mapping
         variants = [(module_path, func_name)]
 
@@ -873,9 +871,7 @@ def run_kernel_variants(
     """Run kernel variants in the same benchmark run."""
 
     # Import tritonbench components
-    from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
-        get_parser,
-    )
+    from tritonbench.utils.parser import get_parser
     from tritonbench.utils.triton_op import BenchmarkOperator
     from tritonbench.utils.triton_op import BenchmarkOperatorMetrics
 
@@ -944,9 +940,7 @@ def run_kernel_variants(
             sys.exit(1)
 
     # Import register_benchmark API
-    from tritonbench.utils.triton_op import (  # pyright: ignore[reportMissingImports]
-        register_benchmark,
-    )
+    from tritonbench.utils.triton_op import register_benchmark
 
     # Register all variants as separate methods
     for module_path, func_name in variants:
diff --git a/docs/conf.py b/docs/conf.py
index f09283635..c339fbd23 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -9,7 +9,7 @@
 from typing import Callable
 from typing import Protocol
 
-import pytorch_sphinx_theme2  # pyright: ignore[reportMissingImports]
+import pytorch_sphinx_theme2
 
 # -- Path setup --------------------------------------------------------------
 
diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py
index 1f4a0445c..23103a9e3 100644
--- a/examples/all_gather_matmul.py
+++ b/examples/all_gather_matmul.py
@@ -208,7 +208,7 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
     dist_group = dist.group.WORLD
     if dist_group is None:
         raise RuntimeError("No distributed group available")
-    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(  # pyright: ignore[reportCallIssue]
+    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
         golden_a, [b], gather_dim=0, group_name=dist_group.group_name
     )
     torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
diff --git a/examples/all_reduce.py b/examples/all_reduce.py
index 7396e9b9c..3e4715b9d 100644
--- a/examples/all_reduce.py
+++ b/examples/all_reduce.py
@@ -72,7 +72,7 @@ def dev_array_to_tensor_short(
     Returns:
         PyTorch tensor created from the device pointer
     """
-    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)  # pyright: ignore[reportAttributeAccessIssue]
+    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)
 
 
 # %%
@@ -228,7 +228,7 @@ def reference_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
         a_shared_clone.copy_(a_shared)
 
-        return torch.ops.symm_mem.one_shot_all_reduce(  # pyright: ignore[reportCallIssue]
+        return torch.ops.symm_mem.one_shot_all_reduce(
             a_shared_clone, "sum", dist_group.group_name
         )
 
diff --git a/examples/attention.py b/examples/attention.py
index 5a8a1c32e..244c54701 100644
--- a/examples/attention.py
+++ b/examples/attention.py
@@ -94,9 +94,9 @@ def attention(
 # ---------------------
 
 # %%
-attention_dynamic: object = helion.kernel(  # pyright: ignore[reportCallIssue]
+attention_dynamic: object = helion.kernel(
     attention.fn,
-    configs=attention.configs,  # pyright: ignore[reportArgumentType]
+    configs=attention.configs,
     static_shapes=False,
 )
 """
diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py
index d96697071..9d1a87677 100644
--- a/examples/blackwell_attention.py
+++ b/examples/blackwell_attention.py
@@ -158,7 +158,7 @@ def blackwell_attention_kernel(
             qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32)
             m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
             if VECT_MUL == 2 or VECT_MUL == 3:
-                qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])  # pyright: ignore[reportArgumentType]
+                qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
             else:
                 qk = qk * qk_scale - m_ij[:, None]
 
@@ -267,7 +267,7 @@ def ref_attention(
         atol=0.1,
         rtol=0.1,
     )
-    dur: float = do_bench(lambda: blackwell_attention(q, k, v))  # pyright: ignore[reportArgumentType, reportAssignmentType]
+    dur: float = do_bench(lambda: blackwell_attention(q, k, v))
     print(
         f"{z=} {h=} {n_ctx=} {head_dim=} tflops={z * h * n_ctx * n_ctx * head_dim * 4 / dur * 1e-9:.2f}"
     )
diff --git a/examples/fused_linear_jsd.py b/examples/fused_linear_jsd.py
index 83b90af16..0ce9f0bdc 100644
--- a/examples/fused_linear_jsd.py
+++ b/examples/fused_linear_jsd.py
@@ -85,7 +85,7 @@ def fused_linear_jsd_fwd_tritonbench(
     label: torch.Tensor | None = None,
 ) -> Callable[[], torch.Tensor]:
     assert label is None
-    baseline_op = tb_op.baseline_op  # pyright: ignore[reportAttributeAccessIssue]
+    baseline_op = tb_op.baseline_op
     beta = baseline_op.jsd.beta
     ignore_index = baseline_op.jsd.ignore_index
     temperature = baseline_op.temperature
diff --git a/examples/geglu.py b/examples/geglu.py
index 88301b8cd..ab8d49e99 100644
--- a/examples/geglu.py
+++ b/examples/geglu.py
@@ -343,9 +343,9 @@ def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable:
 
     # Extract configuration from tritonbench operator
     config = Config(
-        hidden_size=tb_op.hidden_size,  # pyright: ignore[reportAttributeAccessIssue]
-        intermediate_size=tb_op.intermediate_size,  # pyright: ignore[reportAttributeAccessIssue]
-        hidden_act=tb_op.hidden_act,  # pyright: ignore[reportAttributeAccessIssue]
+        hidden_size=tb_op.hidden_size,
+        intermediate_size=tb_op.intermediate_size,
+        hidden_act=tb_op.hidden_act,
     )
 
     # Create Helion model
@@ -353,7 +353,7 @@ def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable:
 
     # Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness
     # LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP)
-    baseline_model = tb_op.baseline_model  # pyright: ignore[reportAttributeAccessIssue]
+    baseline_model = tb_op.baseline_model
 
     # Copy gate projection weights
     helion_mlp.gate_proj.weight.data.copy_(baseline_model.gate_proj.weight.data)
diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index 0780c67cc..bf02f80a6 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -177,12 +177,12 @@ def grouped_gemm_jagged_persistent(
             tile_in_group = local_tile * num_workers + worker_id
             if tile_in_group < num_group_tiles:
                 # Convert linear tile index to 2D (M, N) tile coordinates
-                m_tile_idx = tile_in_group % num_m_tiles  # pyright: ignore[reportOperatorIssue]
+                m_tile_idx = tile_in_group % num_m_tiles
                 n_tile_idx = tile_in_group // num_m_tiles
 
                 # Compute global memory indices for current tile
                 base_row = group_start + m_tile_idx * BLOCK_M
-                base_col = n_tile_idx * BLOCK_N  # pyright: ignore[reportOperatorIssue]
+                base_col = n_tile_idx * BLOCK_N
 
                 # Generate row and column index ranges for tile access
                 row_idx = base_row + hl.arange(BLOCK_M)
diff --git a/examples/jagged_hstu_attn.py b/examples/jagged_hstu_attn.py
index 5384474d0..4e81c40d9 100644
--- a/examples/jagged_hstu_attn.py
+++ b/examples/jagged_hstu_attn.py
@@ -22,9 +22,7 @@
 import helion.language as hl
 
 try:
-    from generative_recommenders.ops.triton.triton_hstu_attention import (  # pyright: ignore[reportMissingImports]
-        triton_hstu_mha,
-    )
+    from generative_recommenders.ops.triton.triton_hstu_attention import triton_hstu_mha
 
     HAS_HAMMER = True
 except ImportError:
@@ -249,7 +247,7 @@ def _triton_hstu_mha(
     num_targets: torch.Tensor | None,
     max_seq_len: int,
 ) -> torch.Tensor:
-    return triton_hstu_mha(  # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue]
+    return triton_hstu_mha(
         max_seq_len,
         alpha=1.0 / v.size(2) ** 2,
         q=q,
diff --git a/examples/jagged_layer_norm.py b/examples/jagged_layer_norm.py
index 3a0544abd..e857ef143 100644
--- a/examples/jagged_layer_norm.py
+++ b/examples/jagged_layer_norm.py
@@ -192,7 +192,7 @@ def reference_jagged_layer_norm_pytorch(
         [
             torch.nn.functional.layer_norm(
                 x_values[x_offsets[i] : x_offsets[i + 1], :],
-                x_values[x_offsets[i] : x_offsets[i + 1], :].shape,  # pyright: ignore[reportArgumentType]
+                x_values[x_offsets[i] : x_offsets[i + 1], :].shape,
                 eps=eps,
             )
             for i in range(x_offsets.shape[0] - 1)
@@ -225,7 +225,7 @@ def jagged_layer_norm_tritonbench(
         Callable that returns normalized tensor values
     """
     x_values = x._values
-    x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]
+    x_offsets = x._offsets
 
     return lambda: jagged_layer_norm_kernel(x_values, x_offsets, eps=1e-6)
 
diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py
index 1fdd041dd..b279f6b5b 100644
--- a/examples/jagged_mean.py
+++ b/examples/jagged_mean.py
@@ -166,13 +166,13 @@ def jagged_mean_tritonbench(
         Callable that returns tensor of shape (B, M) with mean values per row and feature
     """
     x_values = x._values
-    x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]
+    x_offsets = x._offsets
 
     feature_counts = torch.full(
         (B,),
         M,
         dtype=torch.int32,
-        device=x_values.device,  # pyright: ignore[reportAttributeAccessIssue]
+        device=x_values.device,
     )
 
     return lambda: jagged_mean_kernel(x_values, x_offsets, feature_counts, M)
diff --git a/examples/jagged_softmax.py b/examples/jagged_softmax.py
index 2487b1d30..ee80fec93 100644
--- a/examples/jagged_softmax.py
+++ b/examples/jagged_softmax.py
@@ -163,7 +163,7 @@ def jagged_softmax_tritonbench(
     Returns:
         Callable that returns tensor of shape (N, M), where N = total number of rows in the jagged tensor
     """
-    return lambda: jagged_softmax_kernel(x._values, x._offsets)  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
+    return lambda: jagged_softmax_kernel(x._values, x._offsets)
 
 
 # %%
diff --git a/examples/jagged_sum.py b/examples/jagged_sum.py
index d1e0fcdae..a56d018d7 100644
--- a/examples/jagged_sum.py
+++ b/examples/jagged_sum.py
@@ -145,7 +145,7 @@ def jagged_sum_tritonbench(
         Callable that returns tensor of shape (B, M) with mean values per row and feature
     """
     x_values = x._values
-    x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]
+    x_offsets = x._offsets
 
     return lambda: jagged_sum_kernel(x_values, x_offsets)
 
diff --git a/examples/jsd.py b/examples/jsd.py
index 6ade1b7e3..979bdf70f 100644
--- a/examples/jsd.py
+++ b/examples/jsd.py
@@ -318,7 +318,7 @@ def jsd_tritonbench(tb_op: object, log_q: Tensor, log_p: Tensor) -> Callable:
         Callable: A callable that runs the JSD kernel
     """
 
-    baseline_model = tb_op.baseline_op  # pyright: ignore[reportAttributeAccessIssue]
+    baseline_model = tb_op.baseline_op
 
     helion_jsd = HelionJSD(
         beta=baseline_model.beta,
diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index 809df10f6..65f9c2c47 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -133,7 +133,7 @@ def layer_norm_bwd(
             grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
 
             if compute_bias_grad:
-                grad_b_acc += torch.sum(dy_mb, dim=0)  # pyright: ignore[reportPossiblyUnboundVariable]
+                grad_b_acc += torch.sum(dy_mb, dim=0)
 
             wdy = weight_cta * dy_mb
             c1 = torch.sum(x_hat * wdy, dim=-1) / n
diff --git a/examples/swiglu.py b/examples/swiglu.py
index f597d41ea..7ab6e36af 100644
--- a/examples/swiglu.py
+++ b/examples/swiglu.py
@@ -312,9 +312,9 @@ def swiglu_tritonbench(tb_op: object, x: Tensor) -> Callable:
 
     # Extract configuration from tritonbench operator
     config = Config(
-        hidden_size=tb_op.hidden_size,  # pyright: ignore[reportAttributeAccessIssue]
-        intermediate_size=tb_op.intermediate_size,  # pyright: ignore[reportAttributeAccessIssue]
-        hidden_act=tb_op.hidden_act,  # pyright: ignore[reportAttributeAccessIssue]
+        hidden_size=tb_op.hidden_size,
+        intermediate_size=tb_op.intermediate_size,
+        hidden_act=tb_op.hidden_act,
     )
 
     # Create Helion model
@@ -322,7 +322,7 @@ def swiglu_tritonbench(tb_op: object, x: Tensor) -> Callable:
 
     # Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness
     # LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP)
-    baseline_model = tb_op.baseline_op  # pyright: ignore[reportAttributeAccessIssue]
+    baseline_model = tb_op.baseline_op
 
     # Copy gate projection weights
     helion_mlp.gate_proj.weight.data.copy_(baseline_model.gate_proj.weight.data)
diff --git a/helion/_compat.py b/helion/_compat.py
index b7a2bae4c..ed1a3c666 100644
--- a/helion/_compat.py
+++ b/helion/_compat.py
@@ -251,9 +251,7 @@ def _min_dot_size(
         return (16, 16, 16)
 
     if torch.xpu.is_available():
-        from triton.backends.intel.compiler import (  # pyright: ignore[reportMissingImports]
-            min_dot_size as min_dot_size_xpu,
-        )
+        from triton.backends.intel.compiler import min_dot_size as min_dot_size_xpu
 
         device_properties = torch.xpu.get_device_properties()
         gpu_target_info = {
@@ -265,7 +263,7 @@ def _min_dot_size(
         dot_size_val = min_dot_size_xpu(gpu_target_info)(
             torch_dtype_to_tl(lhs), torch_dtype_to_tl(rhs)
         )
-        return tuple(int(v) for v in dot_size_val)  # pyright: ignore[reportReturnType]
+        return tuple(int(v) for v in dot_size_val)
 
     from triton.backends.nvidia.compiler import min_dot_size as min_dot_size_cuda
 
diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py
index 74efef57a..90a4bd447 100644
--- a/helion/_compiler/ast_extension.py
+++ b/helion/_compiler/ast_extension.py
@@ -137,7 +137,7 @@ class Wrapper(ExtendedAST, cls):
 
 
 def create(cls: type[_T], **fields: object) -> _T:
-    result = get_wrapper_cls(cls)(**fields, _location=current_location())  # pyright: ignore[reportCallIssue]
+    result = get_wrapper_cls(cls)(**fields, _location=current_location())
     assert isinstance(result, ExtendedAST)
     result._location.to_ast(result)
     return typing.cast("_T", result)
@@ -215,7 +215,7 @@ def make_unique(m: re.Match[str]) -> str:
     def _replace(node: _R) -> _R:
         # Handle lists by recursively transforming each element
         if isinstance(node, list):
-            return [_replace(item) for item in node]  # pyright: ignore[reportReturnType]
+            return [_replace(item) for item in node]
 
         # Pass through non-AST nodes unchanged (e.g., strings, numbers)
         if not isinstance(node, ast.AST):
@@ -223,14 +223,14 @@ def _replace(node: _R) -> _R:
 
         # Replace placeholder names with their corresponding AST nodes
         if isinstance(node, ast.Name) and node.id in mapping:
-            return mapping[node.id]  # pyright: ignore[reportReturnType]
+            return mapping[node.id]
 
         # Recursively transform all child nodes and wrap in ExtendedAST subclass
         cls = get_wrapper_cls(type(node))
-        return location.to_ast(  # pyright: ignore[reportReturnType]
+        return location.to_ast(
             cls(
                 **{field: _replace(getattr(node, field)) for field in node._fields},
-                _location=location,  # pyright: ignore[reportCallIssue]
+                _location=location,
             )
         )
 
@@ -256,7 +256,7 @@ def convert(node: ast.AST) -> ast.AST:
             return cls(
                 **{field: convert(getattr(node, field)) for field in node._fields},
                 **{attr: getattr(node, attr) for attr in node._attributes},
-                _location=location,  # pyright: ignore[reportCallIssue]
+                _location=location,
             )
         elif isinstance(node, list):
             return [convert(item) for item in node]
@@ -290,9 +290,7 @@ def visit(self, node: ast.AST) -> ast.AST:
             )
 
 
-class _TupleParensRemovedUnparser(
-    ast._Unparser  # pyright: ignore[reportAttributeAccessIssue]
-):
+class _TupleParensRemovedUnparser(ast._Unparser):
     def visit_Tuple(self, node: ast.Tuple) -> None:
         if _needs_to_remove_tuple_parens and isinstance(
             getattr(node, "ctx", None), ast.Store
@@ -308,7 +306,7 @@ def visit_Tuple(self, node: ast.Tuple) -> None:
 
 
 class _LocationAnnotatingOutputLines(OutputLines):
-    def __init__(self, parent: ast._Unparser) -> None:  # pyright: ignore[reportAttributeAccessIssue]
+    def __init__(self, parent: ast._Unparser) -> None:
         super().__init__(parent)
         self._cache: dict[tuple[str, int, int], tuple[str, ...]] = {}
         self._last_location_key: tuple[str, int, int] | None = None
@@ -427,7 +425,7 @@ def maybe_newline(self) -> None:  # type: ignore[override]
 
         return super().maybe_newline()
 
-    def traverse(self, node: ast.AST | list[ast.AST]) -> None:  # pyright: ignore[reportSignatureIssue]
+    def traverse(self, node: ast.AST | list[ast.AST]) -> None:
         if (
             self._output_origin_lines
             and isinstance(node, ExtendedAST)
diff --git a/helion/_compiler/ast_read_writes.py b/helion/_compiler/ast_read_writes.py
index 7821d7564..a5c5d7d10 100644
--- a/helion/_compiler/ast_read_writes.py
+++ b/helion/_compiler/ast_read_writes.py
@@ -43,7 +43,7 @@ class ReadWrites(typing.NamedTuple):
     reads: dict[str, int]
     writes: dict[str, int]
 
-    def __iter__(self) -> typing.Iterator[str]:  # pyright: ignore[reportIncompatibleMethodOverride]
+    def __iter__(self) -> typing.Iterator[str]:
         return iter({**self.reads, **self.writes})
 
     @staticmethod
diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py
index c6967c462..944d1a52c 100644
--- a/helion/_compiler/aten_lowering.py
+++ b/helion/_compiler/aten_lowering.py
@@ -116,9 +116,7 @@ def register_lowering(
     return lowering
 
 
-sym_size_lowering = register_lowering(
-    torch.ops.aten.sym_size.int  # pyright: ignore[reportAttributeAccessIssue]
-)
+sym_size_lowering = register_lowering(torch.ops.aten.sym_size.int)
 
 
@sym_size_lowering.register_codegen("triton")
@@ -143,7 +141,7 @@ def codegen_getitem(ctx: LoweringContext, node: Node) -> object:
 
 
 full_lowering = register_lowering(
-    torch.ops.aten.full.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.full.default,
     masked_value_fn=lambda n: (
         n.args[1] if isinstance(n.args[1], (int, float, bool)) else None
     ),
@@ -163,7 +161,7 @@ def codegen_full(ctx: LoweringContext, node: Node) -> object:
     if isinstance(value_ast, (int, float, bool)):
         value_ast = expr_from_string(constant_repr(value_ast))
     assert isinstance(value_ast, ast.AST), value_ast
-    shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size])  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
+    shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size])
     return expr_from_string(
         f"tl.full({shape_str}, {{value}}, {triton_type(dtype)})",
         value=value_ast,
@@ -171,7 +169,7 @@ def codegen_full(ctx: LoweringContext, node: Node) -> object:
 
 
 unsqueeze_lowering = register_lowering(
-    torch.ops.aten.unsqueeze.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.unsqueeze.default,
     masked_value_fn=passthrough_masked_value,
 )
 
@@ -182,7 +180,7 @@ def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object:
     tensor, dim = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
     assert isinstance(tensor, ast.AST)
     assert isinstance(dim, int)
-    ndim = node.args[0].meta["val"].ndim  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    ndim = node.args[0].meta["val"].ndim
     if dim < 0:
         dim += ndim
     assert 0 <= dim <= ndim, f"Invalid dim {dim} for tensor with {ndim} dims"
@@ -195,15 +193,15 @@ def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object:
 
 
 squeeze_lowering = register_lowering(
-    torch.ops.aten.squeeze.dim,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.squeeze.dim,
     masked_value_fn=passthrough_masked_value,
 )
 view_lowering = register_lowering(
-    torch.ops.aten.view.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.view.default,
     masked_value_fn=passthrough_masked_value,
 )
 reshape_lowering = register_lowering(
-    torch.ops.aten.reshape.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.reshape.default,
     masked_value_fn=passthrough_masked_value,
 )
 
@@ -222,7 +220,7 @@ def codegen_view(ctx: LoweringContext, node: Node) -> object:
 
 
 permute_lowering = register_lowering(
-    torch.ops.aten.permute.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.permute.default,
     masked_value_fn=passthrough_masked_value,
 )
 
@@ -232,7 +230,7 @@ def codegen_permute(ctx: LoweringContext, node: Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, dims = map_arg(node.args, lambda arg: _env_arg(ctx, arg))
     assert isinstance(tensor, ast.AST)
-    dims = [*dims]  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
+    dims = [*dims]
     assert {*dims} == {*range(len(dims))}, dims
     return expr_from_string(
         f"tl.permute({{tensor}}, {dims!r})",
@@ -241,7 +239,7 @@ def codegen_permute(ctx: LoweringContext, node: Node) -> object:
 
 
 stack_lowering = register_lowering(
-    torch.ops.aten.stack.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.stack.default,
     masked_value_fn=passthrough_masked_value,
 )
 
@@ -252,7 +250,7 @@ def codegen_stack(ctx: LoweringContext, node: Node) -> object:
     dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
 
     assert isinstance(tensors, (list, tuple))
-    tensor_asts = [ctx.env[t] for t in tensors]  # pyright: ignore[reportArgumentType]
+    tensor_asts = [ctx.env[t] for t in tensors]
 
     n = len(tensor_asts)
     if n == 0:
@@ -300,7 +298,7 @@ def codegen_stack(ctx: LoweringContext, node: Node) -> object:
 
 
 expand_lowering = register_lowering(
-    torch.ops.aten.expand.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.expand.default,
     masked_value_fn=passthrough_masked_value,
 )
 
@@ -313,9 +311,9 @@ def codegen_expand(ctx: LoweringContext, node: Node) -> object:
     val = node.meta["val"]
     assert isinstance(val, torch.Tensor)
     shape = [*val.size()]
-    if node.args[0].meta["val"].ndim != len(shape):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    if node.args[0].meta["val"].ndim != len(shape):
         broadcasting = [":"] * len(shape)
-        for i in range(len(shape) - node.args[0].meta["val"].ndim):  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+        for i in range(len(shape) - node.args[0].meta["val"].ndim):
             broadcasting[i] = "None"
         tensor = expr_from_string(
             f"{{tensor}}[{', '.join(broadcasting)}]", tensor=tensor
@@ -381,13 +379,13 @@ def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST:
             "FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead."
         )
 
-    lhs_shape = list(lhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
-    rhs_shape = list(rhs_node.meta["val"].size())  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+    lhs_shape = list(lhs_node.meta["val"].size())
+    rhs_shape = list(rhs_node.meta["val"].size())
     acc_shape = (
         list(acc_node.meta["val"].size())
         if (with_acc and acc_node is not None)
         else None
-    )  # pyright: ignore[reportOptionalMemberAccess]
+    )
 
     # Extract expected output dtype from FX node to match PyTorch eager mode behavior
     out_dtype: torch.dtype | None = None
@@ -409,11 +407,11 @@ def reduce_3d_dot(ctx: LoweringContext, node: Node, with_acc: bool) -> ast.AST:
 
 
 bmm_lowering = register_lowering(
-    torch.ops.aten.bmm.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.bmm.default,
     apply_dot_requirements,
 )
 mm_lowering = register_lowering(
-    torch.ops.aten.mm.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.mm.default,
     apply_dot_requirements,
 )
 
@@ -427,7 +425,7 @@ def codegen_mm(ctx: LoweringContext, node: Node) -> ast.AST:
 
 
 addmm_lowering = register_lowering(
-    torch.ops.aten.addmm.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.addmm.default,
     apply_dot_requirements,
 )
 
@@ -439,7 +437,7 @@ def codegen_addmm(ctx: LoweringContext, node: Node) -> ast.AST:
 
 
 baddbmm_lowering = register_lowering(
-    torch.ops.aten.baddbmm.default,  # pyright: ignore[reportAttributeAccessIssue]
+    torch.ops.aten.baddbmm.default,
     apply_dot_requirements,
 )
 
@@ -450,9 +448,7 @@ def codegen_baddbmm(ctx: LoweringContext, node: Node) -> ast.AST:
     return reduce_3d_dot(ctx, node, True)
 
 
-iota_lowering = register_lowering(
-    torch.ops.prims.iota.default  # pyright: ignore[reportAttributeAccessIssue]
-)
+iota_lowering = register_lowering(torch.ops.prims.iota.default)
@iota_lowering.register_codegen("triton") @@ -578,9 +574,7 @@ def _codegen_rng_op( return rng_expr -rand_lowering = register_lowering( - torch.ops.aten.rand.default # pyright: ignore[reportAttributeAccessIssue] -) +rand_lowering = register_lowering(torch.ops.aten.rand.default) @rand_lowering.register_codegen("triton") @@ -588,9 +582,7 @@ def codegen_rand(ctx: LoweringContext, node: Node) -> object: return _codegen_rng_op(ctx, node, "rand") -randn_lowering = register_lowering( - torch.ops.aten.randn.default # pyright: ignore[reportAttributeAccessIssue] -) +randn_lowering = register_lowering(torch.ops.aten.randn.default) @randn_lowering.register_codegen("triton") diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index a1a77afbf..626fad5b2 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -262,7 +262,7 @@ def cached_create_unbacked_symint( A consistent unbacked symint for the given key """ - key = tuple([x._sympy_() if hasattr(x, "_sympy_") else x for x in key]) # pyright: ignore[reportAttributeAccessIssue] + key = tuple([x._sympy_() if hasattr(x, "_sympy_") else x for x in key]) result = self._symint_cache.get(key) if result is None: result = self.create_unbacked_symint(hint) @@ -316,10 +316,7 @@ def to_fake(self, obj: object, origin: Origin) -> object: return [self.to_fake(e, origin) for e in obj] if isinstance(obj, tuple) and hasattr(obj, "_fields"): return type(obj)( - **{ - k: self.to_fake(e, origin) - for k, e in obj._asdict().items() # pyright: ignore[reportAttributeAccessIssue] - } + **{k: self.to_fake(e, origin) for k, e in obj._asdict().items()} ) if isinstance(obj, tuple): return tuple(self.to_fake(e, origin) for e in obj) @@ -327,7 +324,7 @@ def to_fake(self, obj: object, origin: Origin) -> object: return {k: self.to_fake(e, origin) for k, e in obj.items()} if dataclasses.is_dataclass(obj): return dataclasses.replace( - obj, # pyright: ignore[reportArgumentType] + obj, **{ k: self.to_fake(getattr(obj, k), origin) for k in obj.__dataclass_fields__ @@ -369,7 +366,7 @@ def size_hint(self, n: int | torch.SymInt) -> int: # hint will be wrong since we assign a default value to unbacked symbols. Return a default hint. 
            return 8192
 
-        return int(self.shape_env.size_hint(n._sympy_()))  # pyright: ignore[reportArgumentType]
+        return int(self.shape_env.size_hint(n._sympy_()))
     assert isinstance(n, int)
     return n
 
@@ -640,7 +637,7 @@ def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr:
 
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
-    return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+    return any(n.name.startswith("u") for n in expr.free_symbols)
 
 
 def format_shape(shape: tuple[object, ...]) -> str:
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 49955f10b..c12d56d06 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -78,7 +78,7 @@ def find_block_size_symbols(
     non_block_size_symbols = set()
 
     for symbol in expr.free_symbols:
-        origin_info = hf.expr_to_origin.get(symbol)  # pyright: ignore[reportArgumentType]
+        origin_info = hf.expr_to_origin.get(symbol)
         if origin_info is None or not isinstance(origin_info.origin, BlockSizeOrigin):
             non_block_size_symbols.add(symbol)
         else:
@@ -258,7 +258,6 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.rng_seed_buffer_param_name = None
 
     def get_indexing_strategy(self, index: int) -> IndexingStrategy:
-
         from .indexing_strategy import IndexingStrategy
         from .indexing_strategy import PointerIndexingStrategy
 
@@ -378,7 +377,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
         if expr in expr_to_origin:
             return self._lift_sympy_arg(expr)
         replacements = {}
-        for sym in sorted(expr.free_symbols, key=lambda x: x.name):  # pyright: ignore[reportAttributeAccessIssue]
+        for sym in sorted(expr.free_symbols, key=lambda x: x.name):
             assert isinstance(sym, sympy.Symbol)
             if sym in self.expr_to_var_info:
                 replacements[sym] = sympy.Symbol(
@@ -411,7 +410,7 @@ def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
 
     def user_sympy_expr(self, expr: sympy.Expr) -> str:
         """A sympy expression that flows into user computations."""
         replacements = {}
-        for sym in sorted(expr.free_symbols, key=lambda s: s.name):  # pyright: ignore[reportAttributeAccessIssue]
+        for sym in sorted(expr.free_symbols, key=lambda s: s.name):
             assert isinstance(sym, sympy.Symbol)
             block_idx = CompileEnvironment.current().get_block_id(sym)
             if block_idx is not None:
@@ -552,16 +551,16 @@ def _format_constexpr_value(self, value: object) -> str:
 
         # Handle sympy expressions (sanitize by replacing triton_helpers functions)
         if isinstance(value, sympy.Expr):
-            sanitized = value.replace(  # pyright: ignore[reportAttributeAccessIssue]
+            sanitized = value.replace(
                 lambda node: isinstance(node, sympy.Function)
                 and getattr(node.func, "__name__", "")
                 == "triton_helpers.div_floor_integer",
-                lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
-            ).replace(  # pyright: ignore[reportAttributeAccessIssue]
+                lambda node: sympy.floor(node.args[0] / node.args[1]),
+            ).replace(
                 lambda node: isinstance(node, sympy.Function)
                 and getattr(node.func, "__name__", "")
                 == "triton_helpers.remainder_integer",
-                lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
+                lambda node: sympy.Mod(node.args[0], node.args[1]),
             )
             expr = cast("sympy.Expr", sanitized)
             return HostFunction.current().sympy_expr(expr)
@@ -684,9 +683,7 @@ def dead_code_elimination(self) -> None:
 
         # drop any unused args
         args_to_remove = {
-            arg.name
-            for arg in self.arguments
-            if arg.name not in rw.reads  # pyright: ignore[reportPossiblyUnboundVariable]
+            arg.name for arg in self.arguments if arg.name not in rw.reads
         }
         if args_to_remove:
             self.arguments = [
diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 6e5ca394f..d7612e4b3 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -106,7 +106,7 @@ def _get_proxy_slot(
         if obj not in tracker:
             origin = HostFunction.current().tensor_to_origin[obj]
             assert origin.is_host()
-            tracker[obj] = proxy = tracer.create_proxy(  # pyright: ignore[reportArgumentType]
+            tracker[obj] = proxy = tracer.create_proxy(
                 "call_function",
                 _tracing_ops._host_tensor,
                 (origin.host_str(),),
@@ -120,7 +120,7 @@ def _get_proxy_slot(
         tracker = tracer.symnode_tracker
         if obj not in tracker:
             debug_name = CompileEnvironment.current().sympy_debug(obj._sympy_())
-            tracker[obj] = proxy = tracer.create_proxy(  # pyright: ignore[reportArgumentType]
+            tracker[obj] = proxy = tracer.create_proxy(
                 "call_function",
                 _tracing_ops._get_symnode,
                 (debug_name,),
@@ -129,7 +129,7 @@ def _get_proxy_slot(
             )
             proxy.node.meta["val"] = obj
             proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._get_symnode)
-            proxy.force = lambda: proxy  # pyright: ignore[reportAttributeAccessIssue]
+            proxy.force = lambda: proxy
         return transform(tracker[obj])
 
     return get_proxy_slot(obj, tracer, default, transform)
@@ -138,9 +138,9 @@ def _get_proxy_slot(
         preserve_node_meta(),
         patch.object(proxy_tensor, "get_proxy_slot", _get_proxy_slot),
         patch.object(
-            torch.fx.proxy,  # pyright: ignore[reportAttributeAccessIssue]
+            torch.fx.proxy,
             "_COPY_META_FIELDS",
-            [*torch.fx.proxy._COPY_META_FIELDS, "location"],  # pyright: ignore[reportAttributeAccessIssue]
+            [*torch.fx.proxy._COPY_META_FIELDS, "location"],
         ),
         patch.object(torch, "matmul", torch_matmul_replacement),
         patch.object(
@@ -530,7 +530,7 @@ def _assign(self, target: ast.AST, value: object) -> None:
                 if isinstance(n, ast.Starred):
                     raise exc.StarredArgsNotSupportedOnDevice
 
-                self._assign(n, value[i])  # pyright: ignore[reportIndexIssue]
+                self._assign(n, value[i])
         elif isinstance(target, ast.Subscript):
             dst = self.visit(target.value)
             assert isinstance(value, torch.Tensor)
@@ -1059,7 +1059,7 @@ def visit_Slice(self, node: ast.Slice) -> slice | torch.Tensor:
         # Convert slice to hl.arange when step is None or 1 and we have both bounds
         # This allows FX tracing to handle slice operations with dynamic bounds
         if lower is not None and upper is not None and (step is None or step == 1):
-            return hl.arange(lower, upper)  # pyright: ignore[reportArgumentType]
+            return hl.arange(lower, upper)
 
         return slice(lower, upper, step)
 
@@ -1106,7 +1106,7 @@ def visit_Assign(self, node: ast.Assign) -> None:
                 raise exc.NonTensorSubscriptAssign(lhs_type, rhs_type)
             assert isinstance(target.value, ExtendedAST)
             assert target.value._type_info is not None
-            target_origin = target.value._type_info.origin  # pyright: ignore[reportOptionalMemberAccess]
+            target_origin = target.value._type_info.origin
             if not target_origin.is_host() and not isinstance(
                 target.value._type_info, StackTensorType
             ):
@@ -1139,9 +1139,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
             )
 
         return hl.store(
-            self.visit(target.value),  # pyright: ignore[reportArgumentType]
+            self.visit(target.value),
             self._subscript_slice_proxy(target.slice),
-            val,  # pyright: ignore[reportArgumentType]
+            val,
         )
 
     def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
@@ -1168,39 +1168,37 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
         if isinstance(type_info, SequenceType):
             index_value = self.visit(node.slice)
             if isinstance(index_value, int):
-                return self.visit(value)[index_value]  # pyright: ignore[reportIndexIssue]
+                return self.visit(value)[index_value]
             raise exc.InvalidSequenceSubscription(node.slice)
         if isinstance(type_info, StackTensorType):
-            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
+            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
         if type_info is not None and type_info.origin.is_host():
-            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
-        return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
+            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
+        return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))
 
     def visit_Call(self, node: ast.Call) -> object:
         args = []
         kwargs = {}
         for arg in node.args:
             if isinstance(arg, ast.Starred):
-                args.extend(self.visit(arg.value))  # pyright: ignore[reportArgumentType]
+                args.extend(self.visit(arg.value))
             else:
                 args.append(self.visit(arg))
         for kwarg in node.keywords:
             if kwarg.arg is None:
-                kwargs.update(self.visit(kwarg.value))  # pyright: ignore[reportArgumentType,reportCallIssue]
+                kwargs.update(self.visit(kwarg.value))
             else:
                 kwargs[kwarg.arg] = self.visit(kwarg.value)
         if isinstance(
-            (
-                func_type_info := node.func._type_info  # pyright: ignore[reportAttributeAccessIssue]
-            ),
+            (func_type_info := node.func._type_info),
             CallableType,
         ) and (replacement := get_device_func_replacement(func_type_info.value)):
             func = replacement
         else:
             func = self.visit(node.func)
-        return _CheckForIndexCalls.retry_call(func, args, kwargs)  # pyright: ignore[reportArgumentType]
+        return _CheckForIndexCalls.retry_call(func, args, kwargs)
 
     def visit_Attribute(self, node: ast.Attribute) -> object:
         return getattr(self.visit(node.value), node.attr)
@@ -1260,13 +1258,13 @@ def visit_For(self, node: ast.For) -> None:
             self.device_ir.add_root_graph(
                 _make_fx(lambda: WalkDeviceAST(self.device_ir).visit(node))
             )
-            iter_type = node.iter._type_info  # pyright: ignore[reportAttributeAccessIssue]
+            iter_type = node.iter._type_info
             assert isinstance(iter_type, IterType)
             inner = iter_type.inner
             if isinstance(inner, SequenceType):
-                block_ids = [x.block_id for x in inner.unpack()]  # pyright: ignore[reportAttributeAccessIssue]
+                block_ids = [x.block_id for x in inner.unpack()]
             else:
-                block_ids = [inner.block_id]  # pyright: ignore[reportAttributeAccessIssue]
+                block_ids = [inner.block_id]
             self.device_ir.grid_block_ids.append(block_ids)
         else:
             self.generic_visit(node)
diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py
index 51681e9a1..c629e9711 100644
--- a/helion/_compiler/generate_ast.py
+++ b/helion/_compiler/generate_ast.py
@@ -218,7 +218,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
                 fields[field] = self.visit(old_value)
             else:
                 fields[field] = old_value
-        return node.new(fields)  # pyright: ignore[reportReturnType]
+        return node.new(fields)
 
     def visit_For(self, node: ast.For) -> ast.AST | None:
         assert isinstance(node, ExtendedAST)
@@ -239,7 +239,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
                     )
                 )
                 self.device_function.body.extend(
-                    self.device_function.pid.codegen_pid_init()  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+                    self.device_function.pid.codegen_pid_init()
                 )
                 if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
                     body = []
@@ -288,7 +288,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
                         self,
                         fx_node=None,
                         proxy_args=[*bound.arguments.values()],
-                        ast_args=None,  # pyright: ignore[reportArgumentType]
+                        ast_args=None,
                     )
                     codegen_fn(state)
 
@@ -318,7 +318,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
                 block.append(
                     create(
                         ast.If,
-                        test=self.device_function.pid.codegen_test(state),  # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
+                        test=self.device_function.pid.codegen_test(state),
                         body=body,
                         orelse=self.next_else_block,
                     )
@@ -329,7 +329,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None:
                     self.device_function
                 )
                 if persistent_body is not None:
-                    self.device_function.body = persistent_body  # pyright: ignore[reportAttributeAccessIssue]
+                    self.device_function.body = persistent_body
             self.device_function.dead_code_elimination()
             if not self.device_function.preamble and not self.device_function.body:
                 raise exc.EmptyDeviceLoopAfterDCE
@@ -376,7 +376,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
                 isinstance(x, TileIndexType) for x in type_info.unpack()
             ):
                 values = type_info.unpack()
-                block_infos = [env.block_sizes[x.block_id] for x in values]  # pyright: ignore[reportAttributeAccessIssue]
+                block_infos = [env.block_sizes[x.block_id] for x in values]
                 return expr_from_string(
                     self.host_function.literal_expr(
                         [x.from_config(self.device_function.config) for x in block_infos]
@@ -411,7 +411,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
             proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs)
             ast_params.apply_defaults()
             proxy_params.apply_defaults()
-            return codegen_fn(  # pyright: ignore[reportReturnType]
+            return codegen_fn(
                 CodegenState(
                     self,
                     None,
diff --git a/helion/_compiler/helper_function.py b/helion/_compiler/helper_function.py
index baccb5483..8193f60c3 100644
--- a/helion/_compiler/helper_function.py
+++ b/helion/_compiler/helper_function.py
@@ -55,7 +55,7 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType:
     """
     from ..runtime.kernel import Kernel
 
-    return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn  # pyright: ignore[reportReturnType]
+    return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn
 
 
 def extract_helper_function_name(helper_fn: object) -> str:
diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py
index 4238bc3ad..606b6a4f7 100644
--- a/helion/_compiler/host_function.py
+++ b/helion/_compiler/host_function.py
@@ -191,7 +191,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
         if expr in self.expr_to_origin:
             return self.expr_to_origin[expr].origin.host_str()
         replacements = {}
-        for sym in sorted(expr.free_symbols, key=lambda x: x.name):  # pyright: ignore[reportAttributeAccessIssue]
+        for sym in sorted(expr.free_symbols, key=lambda x: x.name):
             assert isinstance(sym, sympy.Symbol)
             origin = self.expr_to_origin[sym].origin
             replacements[sym] = sympy.Symbol(origin.host_str(), integer=True)
@@ -212,7 +212,7 @@ def debug_str(self) -> str:
         result = [
             print_ast(
                 self.location.to_ast(
-                    ast.FunctionDef(self.name, self.args, self.body, [], None)  # pyright: ignore[reportCallIssue]
+                    ast.FunctionDef(self.name, self.args, self.body, [], None)
                 )
             ),
             self.device_ir.debug_str(),
diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py
index f9d05bee6..983106081 100644
--- a/helion/_compiler/indexing_strategy.py
+++ b/helion/_compiler/indexing_strategy.py
@@ -49,7 +49,7 @@ def _get_padded_iota_original_length(
     index_node = state.fx_node.args[1][index_position]  # type: ignore[union-attr, index]
     if (
         isinstance(index_node, torch.fx.Node)
-        and index_node.target == torch.ops.prims.iota.default  # pyright: ignore[reportAttributeAccessIssue]
+        and index_node.target == torch.ops.prims.iota.default
         and isinstance(length_arg := index_node.args[0], int)
         and length_arg != next_power_of_2(length_arg)
     ):
@@ -165,7 +165,7 @@ def codegen_load(
             f"tl.load({name} + {{offset}}, {{mask}}{extra})",
             offset=indexing.index_expr,
             mask=indexing.mask_expr,
-            ev=eviction_policy,  # pyright: ignore[reportArgumentType]
+            ev=eviction_policy,
         )
 
     def codegen_store(
@@ -211,7 +211,7 @@ def codegen_load(
             expr_from_string(
                 f"tl.load({{block_ptr}}, boundary_check={indexing.boundary_check(state)}, padding_option='zero'{extra})",
                 block_ptr=indexing.make_block_ptr(state),
-                ev=eviction_policy,  # pyright: ignore[reportArgumentType]
+                ev=eviction_policy,
             ),
         )
 
@@ -515,7 +515,7 @@ def codegen_load(
             base=dev_ptrs_ast,
             offset=indexing.index_expr,
             mask=mask_expr,
-            ev=eviction_policy,  # pyright: ignore[reportArgumentType]
+            ev=eviction_policy,
         )
 
     @staticmethod
diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 347daae02..e9575c01d 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -15,9 +15,7 @@
 from torch._dynamo.convert_frame import compile_lock
 from torch._inductor import config as inductor_config
 from torch._inductor import ir
-from torch._inductor.codegen.simd import (
-    SIMDKernelFeatures,  # pyright: ignore[reportPrivateImportUsage]
-)
+from torch._inductor.codegen.simd import SIMDKernelFeatures
 from torch._inductor.codegen.triton import TritonKernel
 from torch._inductor.codegen.triton import TritonOverrides
 from torch._inductor.graph import GraphLowering
@@ -172,10 +170,10 @@ def convert_arg(arg: Node) -> TensorBox:
         with node.meta["location"], graph_lowering.set_current_node(node):
             try:
                 result = graph_lowering.call_function(
-                    node.target,  # pyright: ignore[reportArgumentType]
-                    *map_arg((node.args, node.kwargs), convert_arg),  # pyright: ignore[reportArgumentType]
+                    node.target,
+                    *map_arg((node.args, node.kwargs), convert_arg),
                 )
-            except torch._inductor.exc.LoweringException as e:  # pyright: ignore[reportAttributeAccessIssue]
+            except torch._inductor.exc.LoweringException as e:
                 # Wrap in Helion exception to get location automatically
                 raise InductorLoweringError(str(e)) from e
             if not isinstance(result, tuple):
@@ -194,9 +192,7 @@ def convert_arg(arg: Node) -> TensorBox:
                 buffer_name_to_output_index[buffer.get_name()] = i
 
         new_buffers = graph_lowering.buffers[prior_buffers:]
-        assert (
-            buffer in new_buffers  # pyright: ignore[reportPossiblyUnboundVariable]
-        )
+        assert buffer in new_buffers
         nodes = []
         extra_input_names = []
         new_node: torch.fx.Node
@@ -632,7 +628,7 @@ def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
 
         if len(inputs) == 1:
             repr_input = inputs[0]
-        elif node.meta["orig_node"].target == torch.ops.aten.var_mean.correction:  # pyright: ignore[reportAttributeAccessIssue]
+        elif node.meta["orig_node"].target == torch.ops.aten.var_mean.correction:
             assert len(inputs) == 2
             # `inputs[0]` is the original input tensor to var_mean
             repr_input = inputs[0]
@@ -669,7 +665,7 @@ def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object:
 
         # Non-looped reductions compute the value inline; cast now to ensure the
         # result dtype matches torch.* semantics reflected in meta["val"].dtype.
- desired_dtype = node.meta["val"].dtype # pyright: ignore[reportAttributeAccessIssue] + desired_dtype = node.meta["val"].dtype return cast_ast(result_ast, desired_dtype) def get_masked_value(self, node: torch.fx.Node) -> float | bool | None: @@ -687,7 +683,7 @@ def _get_reduction_dims(node: torch.fx.Node, fake_input: torch.Tensor) -> list[i dims = node.kwargs.get("dim", node.kwargs.get("dims")) if dims is None: - schema = node.meta["original_aten"]._schema # pyright: ignore[reportAttributeAccessIssue] + schema = node.meta["original_aten"]._schema assert isinstance(schema, torch._C.FunctionSchema) for index, arg in enumerate(schema.arguments): if arg.name in {"dim", "dims"}: @@ -743,8 +739,8 @@ def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object: CodegenState( ctx.cg, fx_node=node, - proxy_args=proxy_args, # pyright: ignore[reportArgumentType] - ast_args=ast_args, # pyright: ignore[reportArgumentType] + proxy_args=proxy_args, + ast_args=ast_args, ), ) @@ -982,11 +978,11 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str: ): # Skip pure view ops; their dtype matches their input, which we've likely asserted already if node.op == "call_function" and node.target in ( - torch.ops.aten.unsqueeze.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.view.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.reshape.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.expand.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.unsqueeze.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten.expand.default, + torch.ops.aten.permute.default, ): return name expected_dtype = val.dtype @@ -1013,17 +1009,12 @@ def _collect_multi_outputs( assert "output_nodes" in node.meta output_nodes = node.meta["output_nodes"] outputs: list[object | None] = [None] * len(output_nodes) - all_nodes = { - n.name: n - for n in self.module.graph.nodes # pyright: ignore[reportAttributeAccessIssue,reportGeneralTypeIssues] - } + all_nodes = {n.name: n for n in self.module.graph.nodes} for idx, node_name in output_nodes.items(): if node_name == node.name: # This is the last node - outputs[idx] = ( # pyright: ignore[reportArgumentType,reportCallIssue] - last_node_result - ) + outputs[idx] = last_node_result else: # This is an extra node - get its result from env if node_name in all_nodes: @@ -1108,8 +1099,7 @@ def codegen_call_with_graph( placeholders = graph.find_nodes(op="placeholder") for arg, placeholder in zip(args, placeholders, strict=True): if all( - user.target == torch.ops.aten.sym_size.int # pyright: ignore[reportAttributeAccessIssue] - for user in placeholder.users + user.target == torch.ops.aten.sym_size.int for user in placeholder.users ): # TODO(jansel): we should remove these sym_size-only args from the graph new_args.append(arg) diff --git a/helion/_compiler/inductor_lowering_extra.py b/helion/_compiler/inductor_lowering_extra.py index d9b7c4e71..23d11bf50 100644 --- a/helion/_compiler/inductor_lowering_extra.py +++ b/helion/_compiler/inductor_lowering_extra.py @@ -36,15 +36,15 @@ def fp32_fallback_lowering(x: object) -> object: # Operations that need fp32 fallbacks due to libdevice/tl_math limitations FP32_FALLBACK_OPS_UNARY = [ - torch.ops.aten.rsqrt.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.sqrt.default, # pyright: ignore[reportAttributeAccessIssue] - 
torch.ops.aten.sin.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.cos.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.log.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.tanh.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.log1p.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.expm1.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.exp.default, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.rsqrt.default, + torch.ops.aten.sqrt.default, + torch.ops.aten.sin.default, + torch.ops.aten.cos.default, + torch.ops.aten.log.default, + torch.ops.aten.tanh.default, + torch.ops.aten.log1p.default, + torch.ops.aten.expm1.default, + torch.ops.aten.exp.default, ] # Register fp32 fallback lowerings for ops that don't support fp16/bfloat16 @@ -62,25 +62,25 @@ def patch_inductor_lowerings() -> Generator[None, Any, Any]: affecting the global state, especially in cases where Helion is missing support for a specific lowering. """ - original_lowerings = torch._inductor.lowering.lowerings.copy() # pyright: ignore[reportAttributeAccessIssue] + original_lowerings = torch._inductor.lowering.lowerings.copy() try: - torch._inductor.lowering.lowerings.update(inductor_lowering_dispatch) # pyright: ignore[reportAttributeAccessIssue] + torch._inductor.lowering.lowerings.update(inductor_lowering_dispatch) yield finally: - torch._inductor.lowering.lowerings = original_lowerings # pyright: ignore[reportAttributeAccessIssue] + torch._inductor.lowering.lowerings = original_lowerings -register_inductor_lowering = torch._inductor.lowering.register_lowering # pyright: ignore[reportAttributeAccessIssue] +register_inductor_lowering = torch._inductor.lowering.register_lowering def var_mean_helper_( - x: torch._inductor.ir.TensorBox, # pyright: ignore[reportAttributeAccessIssue] + x: torch._inductor.ir.TensorBox, *, axis: list[int] | None, correction: float | None, keepdim: bool, return_mean: bool, -) -> torch._inductor.ir.TensorBox: # pyright: ignore[reportAttributeAccessIssue] +) -> torch._inductor.ir.TensorBox: from torch._inductor.lowering import var_mean_sum_ from torch._prims_common import get_computation_dtype @@ -102,16 +102,16 @@ def var_mean_helper_( @register_inductor_lowering( - [torch.ops.aten.var.correction], # pyright: ignore[reportAttributeAccessIssue] + [torch.ops.aten.var.correction], lowering_dict=inductor_lowering_dispatch, ) def var_( - x: torch._inductor.ir.TensorBox, # pyright: ignore[reportAttributeAccessIssue] + x: torch._inductor.ir.TensorBox, axis: list[int] | None = None, *, correction: float | None = None, keepdim: bool = False, -) -> torch._inductor.ir.TensorBox: # pyright: ignore[reportAttributeAccessIssue] +) -> torch._inductor.ir.TensorBox: return var_mean_helper_( x, axis=axis, @@ -122,16 +122,16 @@ def var_( @register_inductor_lowering( - torch.ops.aten.var_mean.correction, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.var_mean.correction, lowering_dict=inductor_lowering_dispatch, ) def var_mean( - x: torch._inductor.ir.TensorBox, # pyright: ignore[reportAttributeAccessIssue] + x: torch._inductor.ir.TensorBox, axis: list[int] | None = None, *, correction: float | None = None, keepdim: bool = False, -) -> torch._inductor.ir.TensorBox: # pyright: ignore[reportAttributeAccessIssue] +) -> torch._inductor.ir.TensorBox: return var_mean_helper_( x, axis=axis, diff --git a/helion/_compiler/lift_closures.py 
b/helion/_compiler/lift_closures.py index 9ed301793..67598f5f7 100644 --- a/helion/_compiler/lift_closures.py +++ b/helion/_compiler/lift_closures.py @@ -67,4 +67,4 @@ def wrapper(*args: object, **kwargs: object) -> object: new_func: FunctionType | None = None closure_contents: list[object] = [] - return wrapper # pyright: ignore[reportReturnType] + return wrapper diff --git a/helion/_compiler/node_masking.py b/helion/_compiler/node_masking.py index 46c62289d..cf4562b5f 100644 --- a/helion/_compiler/node_masking.py +++ b/helion/_compiler/node_masking.py @@ -126,7 +126,7 @@ def feeds_reduction_input(node: torch.fx.Node) -> bool: input_node ): continue - node.replace_all_uses_with(input_node) # pyright: ignore[reportArgumentType] + node.replace_all_uses_with(input_node) graph.erase_node(node) diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py index 5ea6db6d5..12ec1c944 100644 --- a/helion/_compiler/program_id.py +++ b/helion/_compiler/program_id.py @@ -139,7 +139,7 @@ class ForEachProgramID(ProgramIDs): Represent multiple top level for loops in the Helion kernel. Turns into `if` statements in generated code. """ - shared_pid_var: str # pyright: ignore[reportGeneralTypeIssues,reportIncompatibleVariableOverride] + shared_pid_var: str cases: list[ProgramIDs] = dataclasses.field(default_factory=list) pid_info: list[PIDInfo] = dataclasses.field(default_factory=list, init=False) diff --git a/helion/_compiler/roll_reduction.py b/helion/_compiler/roll_reduction.py index 262abcfa3..1b950a3b2 100644 --- a/helion/_compiler/roll_reduction.py +++ b/helion/_compiler/roll_reduction.py @@ -37,7 +37,7 @@ _duplicate_ops: tuple[object, ...] = ( _host_tensor, _get_symnode, - torch.ops.aten.sym_size.int, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.sym_size.int, ) # Ops that write to memory and should be treated specially when determining @@ -115,7 +115,7 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool: return False if node.target in _duplicate_ops: - if node.target is torch.ops.aten.sym_size.int: # pyright: ignore[reportAttributeAccessIssue] + if node.target is torch.ops.aten.sym_size.int: arg = node.args[0] assert isinstance(arg, torch.fx.Node) return self.should_go_in_inner_graph(arg) @@ -319,10 +319,10 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool: # Check multiple matmul-family operations if node.target not in ( - torch.ops.aten.mm.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.addmm.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.bmm.default, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.baddbmm.default, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.mm.default, + torch.ops.aten.addmm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.baddbmm.default, hl_dot, ): return False diff --git a/helion/_compiler/source_location.py b/helion/_compiler/source_location.py index cd77d5f82..249102597 100644 --- a/helion/_compiler/source_location.py +++ b/helion/_compiler/source_location.py @@ -70,20 +70,20 @@ def from_ast(node: ast.AST) -> SourceLocation: code = host_function.fn.__code__ offset = code.co_firstlineno - 1 return SourceLocation( - node.lineno + offset, # pyright: ignore[reportAttributeAccessIssue] - node.col_offset + host_function.column_offset, # pyright: ignore[reportAttributeAccessIssue] - node.end_lineno + offset, # pyright: ignore[reportAttributeAccessIssue] - node.end_col_offset + host_function.column_offset, # pyright: 
ignore[reportAttributeAccessIssue] + node.lineno + offset, + node.col_offset + host_function.column_offset, + node.end_lineno + offset, + node.end_col_offset + host_function.column_offset, filename=code.co_filename, name=code.co_name, ) def to_ast(self, node: _T) -> _T: - if "lineno" in node._attributes: # pyright: ignore[reportAttributeAccessIssue] - node.lineno = self.lineno # pyright: ignore[reportAttributeAccessIssue] - node.col_offset = self.colno # pyright: ignore[reportAttributeAccessIssue] - node.end_lineno = self.end_lineno # pyright: ignore[reportAttributeAccessIssue] - node.end_col_offset = self.end_colno # pyright: ignore[reportAttributeAccessIssue] + if "lineno" in node._attributes: + node.lineno = self.lineno + node.col_offset = self.colno + node.end_lineno = self.end_lineno + node.end_col_offset = self.end_colno return node def __str__(self) -> str: @@ -96,9 +96,9 @@ def format(self) -> str: return format_frame_summary(self) def _key(self) -> tuple[str, int | None, int, int, int]: - return (self.filename, self.lineno, self.colno, self.end_lineno, self.end_colno) # pyright: ignore[reportReturnType] + return (self.filename, self.lineno, self.colno, self.end_lineno, self.end_colno) - def __hash__(self) -> int: # pyright: ignore[reportIncompatibleVariableOverride] + def __hash__(self) -> int: return hash(self._key()) def __eq__(self, other: object) -> bool: diff --git a/helion/_compiler/static_loop_unroller.py b/helion/_compiler/static_loop_unroller.py index 15f822934..e4049f1ed 100644 --- a/helion/_compiler/static_loop_unroller.py +++ b/helion/_compiler/static_loop_unroller.py @@ -25,9 +25,7 @@ class StaticLoopUnroller(ast.NodeTransformer): def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]: # Generic visit to handle nested loops - node = self.generic_visit( # pyright: ignore[reportAssignmentType] - node - ) + node = self.generic_visit(node) # Check if this is a static loop that can be unrolled if static_values := self._extract_static_values(node.iter): diff --git a/helion/_compiler/tensor_utils.py b/helion/_compiler/tensor_utils.py index 0c0cc5035..d736b2366 100644 --- a/helion/_compiler/tensor_utils.py +++ b/helion/_compiler/tensor_utils.py @@ -13,14 +13,14 @@ class _PadTensorFactoryMode(TorchDispatchMode): """Dispatch mode that pads tensor factory size arguments.""" _SIZE_ARG_INDEX: ClassVar[dict[Callable[..., torch.Tensor], int]] = { - torch.ops.aten.zeros.default: 0, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.ones.default: 0, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.empty.memory_format: 0, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.full.default: 0, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.new_empty.default: 1, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.new_full.default: 1, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.new_zeros.default: 1, # pyright: ignore[reportAttributeAccessIssue] - torch.ops.aten.new_ones.default: 1, # pyright: ignore[reportAttributeAccessIssue] + torch.ops.aten.zeros.default: 0, + torch.ops.aten.ones.default: 0, + torch.ops.aten.empty.memory_format: 0, + torch.ops.aten.full.default: 0, + torch.ops.aten.new_empty.default: 1, + torch.ops.aten.new_full.default: 1, + torch.ops.aten.new_zeros.default: 1, + torch.ops.aten.new_ones.default: 1, } def __torch_dispatch__( diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 67a3d6169..84a19bdad 100644 --- a/helion/_compiler/tile_strategy.py +++ 
b/helion/_compiler/tile_strategy.py @@ -734,12 +734,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})" ), ] - mask_statement = self._setup_mask( # pyright: ignore[reportAttributeAccessIssue] + mask_statement = self._setup_mask( state, block_idx, block_size, index_var, end ) if mask_statement is not None: extra_body.append(mask_statement) - body[:] = [*extra_body, *body] # pyright: ignore[reportArgumentType,reportCallIssue] + body[:] = [*extra_body, *body] body = [for_node] assert for_node is not None return DeviceLoopState( diff --git a/helion/_compiler/traceback_compat.py b/helion/_compiler/traceback_compat.py index 11e3edf7b..d4d74bea9 100644 --- a/helion/_compiler/traceback_compat.py +++ b/helion/_compiler/traceback_compat.py @@ -32,9 +32,7 @@ def _ensure_original_line(fs: traceback.FrameSummary) -> None: # Same public behaviour as 3.11's property: # "return the line as-is from the source, without modifying whitespace". - fs._original_line = ( # pyright: ignore[reportAttributeAccessIssue] - raw - ) + fs._original_line = raw def _byte_offset_to_character_offset(s: str, offset: int) -> int: @@ -90,23 +88,17 @@ def normalize(off: int) -> int: statement = tree.body[0] if isinstance(statement, ast.Expr): - expr = ( - statement.expr # pyright: ignore[reportAttributeAccessIssue] - ) + expr = statement.expr # # 1. Binary operator (a + b, a * b, ...) # if isinstance(expr, ast.BinOp): - operator_start = normalize( - expr.left.end_col_offset # pyright: ignore[reportArgumentType] - ) + operator_start = normalize(expr.left.end_col_offset) operator_end = normalize(expr.right.col_offset) operator_str = segment[operator_start:operator_end] operator_offset = len(operator_str) - len(operator_str.lstrip()) - left_anchor = ( - expr.left.end_col_offset + operator_offset # pyright: ignore[reportOptionalOperand] - ) + left_anchor = expr.left.end_col_offset + operator_offset right_anchor = left_anchor + 1 if ( operator_offset + 1 < len(operator_str) @@ -130,12 +122,8 @@ def normalize(off: int) -> int: # 2. Subscript (a[index]) # if isinstance(expr, ast.Subscript): - left_anchor = normalize( - expr.value.end_col_offset # pyright: ignore[reportArgumentType] - ) - right_anchor = normalize( - expr.slice.end_col_offset + 1 # pyright: ignore[reportOptionalOperand] - ) + left_anchor = normalize(expr.value.end_col_offset) + right_anchor = normalize(expr.slice.end_col_offset + 1) while left_anchor < len(segment) and ( (ch := segment[left_anchor]).isspace() or ch != "[" diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index e3a62af4e..bd2ae28b9 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -65,7 +65,7 @@ class _VisitMethod(Protocol): @staticmethod - def __call__(self: object, node: ast.AST) -> TypeInfo: ... # pyright: ignore[reportSelfClsParameterName] + def __call__(self: object, node: ast.AST) -> TypeInfo: ... 
_T = TypeVar("_T") @@ -135,7 +135,7 @@ def maybe_get(self, name: str) -> TypeInfo | None: except exc.UndefinedVariable: return None - def set(self, name: str, type_info: TypeInfo) -> None: # pyright: ignore[reportIncompatibleMethodOverride] + def set(self, name: str, type_info: TypeInfo) -> None: self.variables[name] = type_info def merge(self, other: LocalScope | dict[str, TypeInfo]) -> LocalScope: @@ -245,9 +245,9 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo: origin, dict( zip( - value._fields, # pyright: ignore[reportAttributeAccessIssue] + value._fields, cls._unpack_example( - value._asdict().items(), # pyright: ignore[reportAttributeAccessIssue] + value._asdict().items(), origin, ), strict=False, @@ -543,7 +543,7 @@ def propagate_setitem( def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: if origin.is_host(): try: - return TypeInfo.from_example(self.fake_value[key.proxy()], origin) # pyright: ignore[reportArgumentType] + return TypeInfo.from_example(self.fake_value[key.proxy()], origin) except NotImplementedError: raise exc.TypeInferenceError( f"Subscript not supported on {self!s} with key={key!s}" @@ -604,7 +604,7 @@ def populate_symbol_origins(self, origin: Origin) -> None: class TensorAttributeType(TypeInfo): - origin: AttributeOrigin # pyright: ignore[reportIncompatibleVariableOverride] + origin: AttributeOrigin tensor: TensorType def __init__(self, origin: AttributeOrigin, tensor: TensorType) -> None: @@ -701,7 +701,7 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: try: - return TypeInfo.from_example(self.value[key.as_literal()], origin) # pyright: ignore[reportIndexIssue] + return TypeInfo.from_example(self.value[key.as_literal()], origin) except NotImplementedError: pass return super().propagate_getitem(key, origin) @@ -725,7 +725,7 @@ def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: def unpack(self) -> list[TypeInfo]: try: - it = iter(self.value) # pyright: ignore[reportArgumentType,reportCallIssue] + it = iter(self.value) except TypeError: return super().unpack() return [TypeInfo.from_example(x, self.origin) for x in it] @@ -744,7 +744,7 @@ def __str__(self) -> str: class ConfigFragmentType(LiteralType): """TypeInfo for config fragments are treated as constant literals during compilation.""" - value: ConfigSpecFragment # pyright: ignore[reportIncompatibleVariableOverride] + value: ConfigSpecFragment def __init__(self, origin: Origin, fragment: ConfigSpecFragment) -> None: assert isinstance(fragment, ConfigSpecFragment) @@ -756,7 +756,7 @@ class CallableType(LiteralType): def __init__(self, origin: Origin, value: Callable[..., object]) -> None: super().__init__(origin, value) - self.value = value # pyright: ignore[reportIncompatibleVariableOverride] + self.value = value def __str__(self) -> str: return f"{type(self).__name__}({self.name})" @@ -771,7 +771,7 @@ def name(self) -> str: except AttributeError: return str(self.value) - def propagate_call( # pyright: ignore[reportIncompatibleMethodOverride] + def propagate_call( self, args: tuple[TypeInfo, ...], kwargs: dict[str, TypeInfo], origin: Origin ) -> TypeInfo | None: if self.value is breakpoint: @@ -881,7 +881,7 @@ class PythonModuleType(LiteralType): def __init__(self, origin: Origin, value: types.ModuleType) -> None: super().__init__(origin, value) - self.value = value # pyright: ignore[reportIncompatibleVariableOverride] + self.value = value def 
__str__(self) -> str: return f"{type(self).__name__}({self.value.__name__})" @@ -975,7 +975,7 @@ def populate_symbol_origins(self, origin: Origin) -> None: class SymIntType(NumericType): - value: torch.SymInt # pyright: ignore[reportIncompatibleVariableOverride] + value: torch.SymInt @classmethod def new_unbacked(cls, origin: Origin) -> Self: @@ -995,7 +995,7 @@ def proxy(self) -> torch.SymInt | int: class SymFloatType(NumericType): - value: torch.SymFloat # pyright: ignore[reportIncompatibleVariableOverride] + value: torch.SymFloat @classmethod def new_unbacked(cls, origin: Origin) -> Self: @@ -1012,7 +1012,7 @@ def python_type(self) -> type[float]: class SymBoolType(NumericType): - value: torch.SymBool # pyright: ignore[reportIncompatibleVariableOverride] + value: torch.SymBool @classmethod def new_unbacked(cls, origin: Origin) -> Self: @@ -1055,15 +1055,15 @@ def __init__(self, origin: Origin, block_id: int) -> None: def proxy(self) -> object: with proxy_tensor.disable_proxy_modes_tracing(): - fake_mode = torch._C._unset_dispatch_mode( # pyright: ignore[reportAttributeAccessIssue] - torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue] + fake_mode = torch._C._unset_dispatch_mode( + torch._C._TorchDispatchModeKey.FAKE ) try: - with torch._C._DisableTorchDispatch(): # pyright: ignore[reportAttributeAccessIssue] + with torch._C._DisableTorchDispatch(): return Tile(self.block_id) finally: assert fake_mode is not None - torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue] + torch._C._set_dispatch_mode(fake_mode) @staticmethod def allocate( @@ -1215,9 +1215,9 @@ def propagate_setitem( k := key.value, (int, str) ): if k in elements: - elements[k] = elements[k].merge(value) # pyright: ignore[reportArgumentType,reportCallIssue] + elements[k] = elements[k].merge(value) else: - elements[k] = value # pyright: ignore[reportArgumentType,reportCallIssue] + elements[k] = value return self return super().propagate_setitem(key, value, origin) @@ -1228,13 +1228,13 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: pass else: try: - result = self.element_types[literal_key] # pyright: ignore[reportArgumentType,reportCallIssue,reportIndexIssue] + result = self.element_types[literal_key] except (KeyError, IndexError) as e: raise exc.TypeInferenceError(f"{type(e).__name__}: {e}") from None if isinstance(result, TypeInfo): return result if type(result) is self.python_type: # sliced! - return type(self)(origin=origin, element_types=result) # pyright: ignore[reportArgumentType] + return type(self)(origin=origin, element_types=result) return super().propagate_getitem(key, origin) def truth_value(self) -> bool: @@ -1245,7 +1245,7 @@ def tree_map(self, fn: Callable[[TypeInfo], object]) -> object: class SequenceType(CollectionType): - element_types: list[TypeInfo] | tuple[TypeInfo, ...] # pyright: ignore[reportIncompatibleVariableOverride] + element_types: list[TypeInfo] | tuple[TypeInfo, ...] 
def __str__(self) -> str: start, *_, end = repr(self.element_types) @@ -1309,7 +1309,7 @@ def tree_map( class DictType(CollectionType): - element_types: dict[str | int, TypeInfo] # pyright: ignore[reportIncompatibleVariableOverride] + element_types: dict[str | int, TypeInfo] def __str__(self) -> str: items = ", ".join(f"{k!r}: {v!s}" for k, v in self.element_types.items()) @@ -1362,12 +1362,12 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: class StackTensorType(ClassType): - element_types: dict[str, TypeInfo] # pyright: ignore[reportIncompatibleVariableOverride] + element_types: dict[str, TypeInfo] - def proxy(self) -> StackTensor: # pyright: ignore[reportIncompatibleMethodOverride] + def proxy(self) -> StackTensor: with proxy_tensor.disable_proxy_modes_tracing(): - fake_mode = torch._C._unset_dispatch_mode( # pyright: ignore[reportAttributeAccessIssue] - torch._C._TorchDispatchModeKey.FAKE # pyright: ignore[reportAttributeAccessIssue] + fake_mode = torch._C._unset_dispatch_mode( + torch._C._TorchDispatchModeKey.FAKE ) try: assert isinstance(self.element_types["tensor_like"], TensorType) @@ -1378,7 +1378,7 @@ def proxy(self) -> StackTensor: # pyright: ignore[reportIncompatibleMethodOverr ) finally: assert fake_mode is not None - torch._C._set_dispatch_mode(fake_mode) # pyright: ignore[reportAttributeAccessIssue] + torch._C._set_dispatch_mode(fake_mode) def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: if isinstance(other, StackTensorType): @@ -1444,7 +1444,7 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: class SliceType(CollectionType): - element_types: slice # pyright: ignore[reportIncompatibleVariableOverride] + element_types: slice @property def lower(self) -> TypeInfo: @@ -1496,41 +1496,41 @@ def _eval_unary(op: ast.unaryop, value: object) -> object: if isinstance(op, ast.Not): return not value if isinstance(op, ast.UAdd): - return +value # pyright: ignore[reportOperatorIssue] + return +value if isinstance(op, ast.USub): - return -value # pyright: ignore[reportOperatorIssue] + return -value if isinstance(op, ast.Invert): - return ~value # pyright: ignore[reportOperatorIssue] + return ~value raise AssertionError(f"{type(op).__name__} unknown unary op") def _eval_binary(op: ast.operator, left: object, right: object) -> object: if isinstance(op, ast.Add): - return left + right # pyright: ignore[reportOperatorIssue] + return left + right if isinstance(op, ast.Sub): - return left - right # pyright: ignore[reportOperatorIssue] + return left - right if isinstance(op, ast.Mult): - return left * right # pyright: ignore[reportOperatorIssue] + return left * right if isinstance(op, ast.Div): - return left / right # pyright: ignore[reportOperatorIssue] + return left / right if isinstance(op, ast.FloorDiv): - return left // right # pyright: ignore[reportOperatorIssue] + return left // right if isinstance(op, ast.Mod): - return left % right # pyright: ignore[reportOperatorIssue] + return left % right if isinstance(op, ast.Pow): - return left**right # pyright: ignore[reportOperatorIssue] + return left**right if isinstance(op, ast.LShift): - return left << right # pyright: ignore[reportOperatorIssue] + return left << right if isinstance(op, ast.RShift): - return left >> right # pyright: ignore[reportOperatorIssue] + return left >> right if isinstance(op, ast.BitOr): - return left | right # pyright: ignore[reportOperatorIssue] + return left | right if isinstance(op, ast.BitXor): - return left ^ right # pyright: 
ignore[reportOperatorIssue] + return left ^ right if isinstance(op, ast.BitAnd): - return left & right # pyright: ignore[reportOperatorIssue] + return left & right if isinstance(op, ast.MatMult): - return left @ right # pyright: ignore[reportOperatorIssue] + return left @ right raise AssertionError(f"{type(op).__name__} unknown binary op") @@ -1540,21 +1540,21 @@ def _eval_compare(op: ast.cmpop, left: object, right: object) -> object: if isinstance(op, ast.NotEq): return left != right if isinstance(op, ast.Lt): - return left < right # pyright: ignore[reportOperatorIssue] + return left < right if isinstance(op, ast.LtE): - return left <= right # pyright: ignore[reportOperatorIssue] + return left <= right if isinstance(op, ast.Gt): - return left > right # pyright: ignore[reportOperatorIssue] + return left > right if isinstance(op, ast.GtE): - return left >= right # pyright: ignore[reportOperatorIssue] + return left >= right if isinstance(op, ast.Is): return left is right if isinstance(op, ast.IsNot): return left is not right if isinstance(op, ast.In): - return left in right # pyright: ignore[reportOperatorIssue] + return left in right if isinstance(op, ast.NotIn): - return left not in right # pyright: ignore[reportOperatorIssue] + return left not in right raise AssertionError(f"{type(op).__name__} unknown compare op") @@ -1801,22 +1801,22 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None: ) from None return self._assign(lhs.value, unpacked) if isinstance(lhs, (ast.Tuple, ast.List)): - lhs = lhs.elts # pyright: ignore[reportAssignmentType] + lhs = lhs.elts elements: list[TypeInfo] try: elements = rhs.unpack() except NotImplementedError: if isinstance(rhs, TileIndexType): raise exc.FailedToUnpackTile from None - raise exc.FailedToUnpackTupleAssign(len(lhs), rhs) from None # pyright: ignore[reportArgumentType] + raise exc.FailedToUnpackTupleAssign(len(lhs), rhs) from None used_star = False idx = 0 - for elt in lhs: # pyright: ignore[reportGeneralTypeIssues] + for elt in lhs: if isinstance(elt, ast.Starred): # TODO(jansel): need to test this assert not used_star, "multiple `*` in assignment" used_star = True - star_len = len(elements) - len(lhs) + 1 # pyright: ignore[reportArgumentType] + star_len = len(elements) - len(lhs) + 1 assert star_len >= 0, "wrong number of elements to unpack" self._assign( elt.value, @@ -1898,9 +1898,9 @@ def _list_or_tuple(self, node: ast.List | ast.Tuple) -> TypeInfo: cls(elements), ) - visit_List: _VisitMethod = _list_or_tuple # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Tuple: _VisitMethod = _list_or_tuple # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Set: _VisitMethod = _unsupported(set) # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_List: _VisitMethod = _list_or_tuple + visit_Tuple: _VisitMethod = _list_or_tuple + visit_Set: _VisitMethod = _unsupported(set) def visit_Dict(self, node: ast.Dict) -> TypeInfo: assert len(node.keys) == len(node.values) @@ -1930,7 +1930,7 @@ def visit_Name(self, node: ast.Name) -> TypeInfo: raise exc.CannotReadDeviceVariableOnHost(node.id) return result - visit_Starred: _VisitMethod = generic_visit # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_Starred: _VisitMethod = generic_visit def visit_Expr(self, node: ast.Expr) -> TypeInfo: return self.visit(node.value) @@ -2059,7 +2059,7 @@ def visit_Call(self, node: ast.Call) -> TypeInfo: "Failed to unpack */** args to function, got: " + ", 
".join(map(str, unhandled)) ) - return func.propagate_call(tuple(args), kwargs, self.origin()) # pyright: ignore[reportReturnType] + return func.propagate_call(tuple(args), kwargs, self.origin()) def visit_IfExp(self, node: ast.IfExp) -> TypeInfo: test = self.visit(node.test) @@ -2167,22 +2167,22 @@ def visit_Assert(self, node: ast.Assert) -> TypeInfo: self.visit(node.msg) return NoType(origin=self.origin()) - visit_Raise: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Delete: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_Raise: _VisitMethod = generic_statement + visit_Delete: _VisitMethod = generic_statement def visit_Pass(self, node: ast.Pass) -> TypeInfo: return NoType(origin=self.origin()) - visit_TypeAlias: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType, reportIncompatibleMethodOverride] - visit_Import: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_ImportFrom: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_TypeAlias: _VisitMethod = generic_statement + visit_Import: _VisitMethod = generic_statement + visit_ImportFrom: _VisitMethod = generic_statement def visit_Global(self, node: ast.Global) -> TypeInfo: # Global statements don't need child visiting since they only declare names return NoType(origin=self.origin()) # TODO(jansel): support lambda - visit_Lambda: _VisitMethod = generic_visit # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_Lambda: _VisitMethod = generic_visit ################################################################ # Control flow @@ -2238,7 +2238,7 @@ def visit_For(self, node: ast.For) -> TypeInfo: ) if device_loop: if node.orelse: - raise exc.DeviceLoopElseBlock(fn.__qualname__) # pyright: ignore[reportPossiblyUnboundVariable] + raise exc.DeviceLoopElseBlock(fn.__qualname__) if self.device_loop_depth == 0: self.func.set_local_types(parent_scope.extract_locals()) @@ -2275,8 +2275,8 @@ def visit_While(self, node: ast.While) -> TypeInfo: self.scope.merge_if_else(body, orelse) return NoType(origin=self.origin()) - visit_Break: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Continue: _VisitMethod = generic_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_Break: _VisitMethod = generic_statement + visit_Continue: _VisitMethod = generic_statement def visit_Try(self, node: ast.Try) -> TypeInfo: self.scope.merge(self._body(node.body)) @@ -2288,7 +2288,7 @@ def visit_Try(self, node: ast.Try) -> TypeInfo: self.scope.overwrite(self._body(node.finalbody)) return NoType(origin=self.origin()) - visit_TryStar: _VisitMethod = visit_Try # pyright: ignore[reportAssignmentType, reportIncompatibleMethodOverride] + visit_TryStar: _VisitMethod = visit_Try def _not_on_device_statement(self, node: ast.AST) -> TypeInfo: if self.device_loop_depth: @@ -2297,9 +2297,9 @@ def _not_on_device_statement(self, node: ast.AST) -> TypeInfo: self.visit(child_node) return NoType(origin=self.origin()) - visit_ExceptHandler: _VisitMethod = _not_on_device_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_With: _VisitMethod = _not_on_device_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Return: 
_VisitMethod = _not_on_device_statement # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_ExceptHandler: _VisitMethod = _not_on_device_statement + visit_With: _VisitMethod = _not_on_device_statement + visit_Return: _VisitMethod = _not_on_device_statement def _not_supported(self, node: ast.AST) -> TypeInfo: raise exc.StatementNotSupported(type(node).__name__) @@ -2361,29 +2361,29 @@ def visit_ListComp(self, node: ast.ListComp) -> TypeInfo: return self._evaluate_comprehension(node.generators[0], node.elt) # TODO(jansel): need to implement these - visit_SetComp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_GeneratorExp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_DictComp: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_SetComp: _VisitMethod = _not_supported + visit_GeneratorExp: _VisitMethod = _not_supported + visit_DictComp: _VisitMethod = _not_supported # TODO(jansel): support closure functions defined on host - visit_FunctionDef: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - - visit_ClassDef: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Yield: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_YieldFrom: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_AsyncFunctionDef: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_AsyncFor: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_AsyncWith: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Await: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_Match: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchValue: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchSingleton: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchSequence: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchStar: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchMapping: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchClass: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchAs: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] - visit_MatchOr: _VisitMethod = _not_supported # pyright: ignore[reportAssignmentType,reportIncompatibleMethodOverride] + visit_FunctionDef: _VisitMethod = _not_supported + + visit_ClassDef: _VisitMethod = _not_supported + visit_Yield: _VisitMethod = _not_supported + visit_YieldFrom: _VisitMethod = _not_supported + visit_AsyncFunctionDef: _VisitMethod = _not_supported + visit_AsyncFor: _VisitMethod = _not_supported + visit_AsyncWith: 
_VisitMethod = _not_supported + visit_Await: _VisitMethod = _not_supported + visit_Match: _VisitMethod = _not_supported + visit_MatchValue: _VisitMethod = _not_supported + visit_MatchSingleton: _VisitMethod = _not_supported + visit_MatchSequence: _VisitMethod = _not_supported + visit_MatchStar: _VisitMethod = _not_supported + visit_MatchMapping: _VisitMethod = _not_supported + visit_MatchClass: _VisitMethod = _not_supported + visit_MatchAs: _VisitMethod = _not_supported + visit_MatchOr: _VisitMethod = _not_supported def _to_proxy(arg: TypeInfo) -> object: diff --git a/helion/_logging/_internal.py b/helion/_logging/_internal.py index f3363f13c..d5f17937c 100644 --- a/helion/_logging/_internal.py +++ b/helion/_logging/_internal.py @@ -86,9 +86,9 @@ class LazyString: def __init__( self, func: Callable[P, str], *args: P.args, **kwargs: P.kwargs ) -> None: - self.func: Callable[P, str] = func # pyright: ignore[reportGeneralTypeIssues,reportInvalidTypeForm] + self.func: Callable[P, str] = func self.args: tuple[object, ...] = args self.kwargs: object = kwargs def __str__(self) -> str: - return self.func(*self.args, **self.kwargs) # pyright: ignore[reportCallIssue] + return self.func(*self.args, **self.kwargs) diff --git a/helion/_testing.py b/helion/_testing.py index c5411d080..41d2f023e 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -39,7 +39,7 @@ def _get_triton_backend() -> str | None: try: - return triton.runtime.driver.active.get_current_target().backend # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + return triton.runtime.driver.active.get_current_target().backend except Exception: return None @@ -126,12 +126,12 @@ def skipIfNormalMode(reason: str) -> Callable[[Callable], Callable]: def skipIfRocm(reason: str) -> Callable[[Callable], Callable]: """Skip test if running with rocm""" - return unittest.skipIf(torch.version.hip is not None, reason) # pyright: ignore[reportAttributeAccessIssue] + return unittest.skipIf(torch.version.hip is not None, reason) def skipIfXPU(reason: str) -> Callable[[Callable], Callable]: """Skip test if running with Intel XPU""" - return unittest.skipIf(torch.xpu.is_available(), reason) # pyright: ignore[reportAttributeAccessIssue] + return unittest.skipIf(torch.xpu.is_available(), reason) def skipIfCpu(reason: str) -> Callable[[Callable], Callable]: @@ -335,14 +335,14 @@ def counting_skip_test(*args: object, **kwargs: object) -> object: self._run_ref_count = self._run_ref_tracker.__enter__() # Patch pytest.raises to count calls - if RefEagerTestBase._original_pytest_raises is None: # pyright: ignore[reportAttributeAccessIssue] + if RefEagerTestBase._original_pytest_raises is None: RefEagerTestBase._original_pytest_raises = pytest.raises def counting_pytest_raises(*args: object, **kwargs: object) -> object: """Wrapper for pytest.raises that counts calls but still runs the original logic.""" RefEagerTestBase._assert_raises_count += 1 - assert RefEagerTestBase._original_pytest_raises is not None # pyright: ignore[reportAttributeAccessIssue] - return RefEagerTestBase._original_pytest_raises(*args, **kwargs) # pyright: ignore[reportAttributeAccessIssue] + assert RefEagerTestBase._original_pytest_raises is not None + return RefEagerTestBase._original_pytest_raises(*args, **kwargs) pytest.raises = counting_pytest_raises # type: ignore[assignment] @@ -398,7 +398,7 @@ def tearDown(self) -> None: # Assert that either run_ref was called or the test was skipped if not is_skipped and self._run_ref_count[0] == 0: self.fail( # type: 
ignore[attr-defined] - f"Test {self._testMethodName} did not call run_ref and was not skipped" # pyright: ignore[reportAttributeAccessIssue] + f"Test {self._testMethodName} did not call run_ref and was not skipped" ) if not is_skipped: @@ -416,12 +416,12 @@ def tearDown(self) -> None: RefEagerTestBase._original_assert_greater_func( # type: ignore[misc] total_assertions, 0, - f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, skipTest, assertTrue, assertFalse, or assertGreater", # type: ignore[attr-defined] ) else: # Fallback if original not available assert total_assertions > 0, ( - f"Test {self._testMethodName} did not call any assertion methods" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + f"Test {self._testMethodName} did not call any assertion methods" # type: ignore[attr-defined] ) finally: # Restore the original assert_close function @@ -439,8 +439,8 @@ def tearDown(self) -> None: self.skipTest = RefEagerTestBase._original_skip_test_func # Restore the original pytest.raises function - if RefEagerTestBase._original_pytest_raises is not None: # pyright: ignore[reportAttributeAccessIssue] - pytest.raises = RefEagerTestBase._original_pytest_raises # pyright: ignore[reportAttributeAccessIssue] + if RefEagerTestBase._original_pytest_raises is not None: + pytest.raises = RefEagerTestBase._original_pytest_raises # Restore the original assertTrue function if RefEagerTestBase._original_assert_true_func is not None: @@ -506,9 +506,9 @@ def assertIsInstance( def import_path(filename: Path) -> types.ModuleType: module_name = f"{__name__}.{filename.stem}" if module_name not in sys.modules: - spec = importlib.util.spec_from_file_location(module_name, filename) # pyright: ignore[reportAttributeAccessIssue] + spec = importlib.util.spec_from_file_location(module_name, filename) assert spec is not None - module = importlib.util.module_from_spec(spec) # pyright: ignore[reportAttributeAccessIssue] + module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) sys.modules[module_name] = module @@ -523,7 +523,7 @@ def code_and_output( bound = fn.bind(args) if is_ref_mode_enabled(bound.kernel.settings): if kwargs: - config = Config(**kwargs) # pyright: ignore[reportArgumentType] + config = Config(**kwargs) bound._config = config result = fn(*args) # Return the original kernel source code @@ -531,9 +531,7 @@ def code_and_output( return code, result if kwargs: - config = Config( - **kwargs # pyright: ignore[reportArgumentType] - ) + config = Config(**kwargs) elif fn.configs: (config,) = fn.configs else: @@ -677,7 +675,7 @@ def run_example( repeat = compute_repeat(bench_fns[0]) timings = interleaved_bench(bench_fns, repeat=repeat, desc="Benchmarking") all_times = dict(zip(all_benchmarks.keys(), timings, strict=True)) - best_baseline_time = min(all_times[name] for name in baselines) # pyright: ignore[reportArgumentType] + best_baseline_time = min(all_times[name] for name in baselines) # Print results print(f"\n{'=' * 65}\nBenchmark Results\n{'=' * 65}", file=sys.stderr) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index ae2c5648e..f83977129 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -1091,7 +1091,7 @@ def 
_wait_for_all_step( # Wait for at least one to finish or time out timeout = min([f.seconds_left() for f in running], default=0.0) - handles = [f.process.sentinel for f in running] # pyright: ignore[reportOptionalMemberAccess] + handles = [f.process.sentinel for f in running] if handles and timeout > 0: connection.wait(handles, timeout) remaining: list[PrecompileFuture] = [] diff --git a/helion/autotuner/block_id_sequence.py b/helion/autotuner/block_id_sequence.py index babc82e47..0663eeb86 100644 --- a/helion/autotuner/block_id_sequence.py +++ b/helion/autotuner/block_id_sequence.py @@ -72,14 +72,14 @@ def _reindex(self) -> None: new_index[block_id] = i self._block_id_to_index = new_index - def __getitem__(self, index: int) -> _BlockIdItemT: # pyright: ignore[reportIncompatibleMethodOverride] + def __getitem__(self, index: int) -> _BlockIdItemT: return self._data[index] - def __setitem__(self, index: int, value: _BlockIdItemT) -> None: # pyright: ignore[reportIncompatibleMethodOverride] + def __setitem__(self, index: int, value: _BlockIdItemT) -> None: self._data[index] = value self._reindex() # could be faster, but uncommon case - def __delitem__(self, index: int) -> None: # pyright: ignore[reportIncompatibleMethodOverride] + def __delitem__(self, index: int) -> None: del self._data[index] self._reindex() # could be faster, but uncommon case @@ -163,7 +163,7 @@ def _normalize( values = () new_values = [] - map_aggregate(values, new_values.append) # pyright: ignore[reportArgumentType] + map_aggregate(values, new_values.append) values = new_values elif not isinstance(values, (list, tuple, type(None))): raise InvalidConfig( diff --git a/helion/autotuner/local_cache.py b/helion/autotuner/local_cache.py index bb4110ab9..c917aae9a 100644 --- a/helion/autotuner/local_cache.py +++ b/helion/autotuner/local_cache.py @@ -12,9 +12,7 @@ import uuid import torch -from torch._inductor.runtime.cache_dir_utils import ( - cache_dir, # pyright: ignore[reportPrivateImportUsage] -) +from torch._inductor.runtime.cache_dir_utils import cache_dir from ..runtime.config import Config from .base_cache import AutotuneCacheBase @@ -71,21 +69,21 @@ def _generate_key(self) -> LooseAutotuneCacheKey: dev.type == "xpu" and getattr(torch, "xpu", None) is not None and torch.xpu.is_available() - ): # pyright: ignore[reportAttributeAccessIssue] + ): device_properties = torch.xpu.get_device_properties(dev) hardware = device_properties.name - runtime_name = device_properties.driver_version # pyright: ignore[reportAttributeAccessIssue] + runtime_name = device_properties.driver_version break # CUDA/ROCm path if dev.type == "cuda" and torch.cuda.is_available(): device_properties = torch.cuda.get_device_properties(dev) - if torch.version.cuda is not None: # pyright: ignore[reportAttributeAccessIssue] + if torch.version.cuda is not None: hardware = device_properties.name runtime_name = str(torch.version.cuda) - elif torch.version.hip is not None: # pyright: ignore[reportAttributeAccessIssue] + elif torch.version.hip is not None: hardware = device_properties.gcnArchName - runtime_name = torch.version.hip # pyright: ignore[reportAttributeAccessIssue] + runtime_name = torch.version.hip break assert hardware is not None and runtime_name is not None diff --git a/helion/autotuner/progress_bar.py b/helion/autotuner/progress_bar.py index 44e27b941..3868e744a 100644 --- a/helion/autotuner/progress_bar.py +++ b/helion/autotuner/progress_bar.py @@ -54,7 +54,7 @@ def iter_with_progress( When ``False`` the iterable is returned unchanged so there 
is zero overhead; when ``True`` a Rich progress bar is rendered. """ - if (not enabled) or torch._utils_internal.is_fb_unit_test(): # pyright: ignore[reportAttributeAccessIssue] + if (not enabled) or torch._utils_internal.is_fb_unit_test(): yield from iterable return diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index 89a073f63..f3fc2ccf6 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -32,7 +32,7 @@ class _Decorator(Protocol): def __call__(self, fn: _C) -> _C: ... - class _NoReturnDecorator(Protocol, Generic[_T]): # pyright: ignore[reportInvalidTypeVarUse] + class _NoReturnDecorator(Protocol, Generic[_T]): def __call__(self, fn: Callable[..., _T]) -> object: ... @@ -198,7 +198,7 @@ def wrapper(*args: object, **kwargs: object) -> object: cast("Callable[..., object]", fn) ) api._ref_fn = None - return wrapper # pyright: ignore[reportReturnType] + return wrapper return _impl @@ -218,7 +218,7 @@ def _impl(fake_fn: Callable[..., object]) -> Callable[..., Never]: ) return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def type_propagation( @@ -231,7 +231,7 @@ def _impl(type_fn: Callable[..., TypeInfo]) -> Callable[..., Never]: original_fn._type_function = type_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def prepare_args( @@ -249,7 +249,7 @@ def _impl( original_fn._prepare_args = prep_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def codegen( @@ -266,7 +266,7 @@ def _impl(codegen_fn: Callable[[CodegenState], object]) -> Callable[..., Never]: original_fn._codegen[backend] = codegen_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def get_masked_value( @@ -284,7 +284,7 @@ def _impl( original_fn._get_masked_value = mask_value_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def register_to_device_ir( @@ -298,7 +298,7 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: original_fn._to_device_ir = to_device_ir_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def ref( @@ -314,7 +314,7 @@ def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]: original_fn._ref_fn = ref_fn return _no_call - return _impl # pyright: ignore[reportReturnType] + return _impl def _default_type_function( diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index 0d5939b35..af9e978a7 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -42,14 +42,14 @@ def _get_symnode(debug_name: str) -> int: @_decorators.codegen(_get_symnode, "triton") def _(state: CodegenState) -> ast.AST: - val = state.fx_node.meta["val"] # pyright: ignore[reportOptionalMemberAccess] + val = state.fx_node.meta["val"] # Handle the case where val is a regular integer (e.g., from reduction_loops config) if isinstance(val, int): return expr_from_string(str(val)) assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val - if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None: # pyright: ignore[reportArgumentType] + if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None: block_size_var = state.device_function.block_size_var(block_idx) if block_size_var is None: return expr_from_string("1") @@ -85,7 +85,7 @@ def _for_loop( @_decorators.codegen(_for_loop, "triton") def _(state: CodegenState) -> None: - return 
HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] + return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) @has_side_effect @@ -102,7 +102,7 @@ def _while_loop( @_decorators.codegen(_while_loop, "triton") def _(state: CodegenState) -> None: - return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] + return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) @has_side_effect @@ -114,7 +114,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]: @_decorators.codegen(_if, "triton") def _(state: CodegenState) -> None: - return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) # pyright: ignore[reportArgumentType,reportCallIssue] + return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) # Note we can't DCE phi nodes because there may be a loop carry dependency not captured in the outer graph @@ -184,7 +184,7 @@ def _and(left: object, right: object) -> object: def _(state: CodegenState) -> None: return expr_from_string( "{lhs} and {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) - ) # pyright: ignore[reportReturnType] + ) @_decorators.register_fake(_and) @@ -237,7 +237,7 @@ def _(left: object, right: object) -> object: def _(state: CodegenState) -> None: return expr_from_string( "{lhs} or {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) - ) # pyright: ignore[reportReturnType] + ) @_decorators.api() @@ -345,7 +345,7 @@ def _(value: _T) -> _T: if isinstance(value, torch.Tensor): return torch.empty_like(value) if isinstance(value, torch.SymInt): - return CompileEnvironment.current().create_unbacked_symint() # pyright: ignore[reportReturnType] + return CompileEnvironment.current().create_unbacked_symint() if isinstance(value, (int, float, bool)) or value is None: return value raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}") diff --git a/helion/language/atomic_ops.py b/helion/language/atomic_ops.py index 2aa16331b..59eb3798d 100644 --- a/helion/language/atomic_ops.py +++ b/helion/language/atomic_ops.py @@ -244,9 +244,9 @@ def _convert_value_to_target_dtype(val: object) -> torch.Tensor: prev_chunks: list[torch.Tensor] = [] def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None: - prev_val = t[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev_val = t[idx_tuple].clone() val_tensor = _convert_value_to_target_dtype(v) - t[idx_tuple] = t[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType] + t[idx_tuple] = t[idx_tuple] + val_tensor prev_chunks.append(prev_val.reshape(-1)) _ref_apply(target, index, apply, value) @@ -257,9 +257,9 @@ def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None: return flat_prev.reshape(ret_shape) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val_tensor = _convert_value_to_target_dtype(value) - target[idx_tuple] = target[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType] + target[idx_tuple] = target[idx_tuple] + val_tensor return prev @@ -333,13 +333,13 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val = ( value if isinstance(value, torch.Tensor) else torch.as_tensor(value, 
dtype=target.dtype, device=target.device) ) - target[idx_tuple] = val # pyright: ignore[reportArgumentType] + target[idx_tuple] = val return prev @@ -410,13 +410,13 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val = ( value if isinstance(value, torch.Tensor) else torch.as_tensor(value, dtype=target.dtype, device=target.device) ) - target[idx_tuple] = target[idx_tuple] & val # pyright: ignore[reportArgumentType] + target[idx_tuple] = target[idx_tuple] & val return prev @@ -484,13 +484,13 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val = ( value if isinstance(value, torch.Tensor) else torch.as_tensor(value, dtype=target.dtype, device=target.device) ) - target[idx_tuple] = target[idx_tuple] | val # pyright: ignore[reportArgumentType] + target[idx_tuple] = target[idx_tuple] | val return prev @@ -558,13 +558,13 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val = ( value if isinstance(value, torch.Tensor) else torch.as_tensor(value, dtype=target.dtype, device=target.device) ) - target[idx_tuple] = target[idx_tuple] ^ val # pyright: ignore[reportArgumentType] + target[idx_tuple] = target[idx_tuple] ^ val return prev @@ -629,7 +629,7 @@ def _( def apply(t: torch.Tensor, idx: tuple, v: object) -> None: t[idx] = torch.maximum( t[idx], torch.as_tensor(v, dtype=t[idx].dtype, device=t.device) - ) # pyright: ignore[reportArgumentType] + ) _ref_apply(target, index, apply, value) @@ -699,13 +699,13 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() val = ( value if isinstance(value, torch.Tensor) else torch.as_tensor(value, dtype=target.dtype, device=target.device) ) - target[idx_tuple] = torch.minimum(target[idx_tuple], val) # pyright: ignore[reportArgumentType] + target[idx_tuple] = torch.minimum(target[idx_tuple], val) return prev @@ -794,7 +794,7 @@ def _( else: processed_index.append(idx) idx_tuple = tuple(processed_index) - prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] + prev = target[idx_tuple].clone() exp_t = ( expected if isinstance(expected, torch.Tensor) @@ -805,8 +805,8 @@ def _( if isinstance(value, torch.Tensor) else torch.as_tensor(value, dtype=target.dtype, device=target.device) ) - mask = target[idx_tuple] == exp_t # pyright: ignore[reportArgumentType] - target[idx_tuple] = torch.where(mask, val_t, target[idx_tuple]) # pyright: ignore[reportArgumentType] + mask = target[idx_tuple] == exp_t + target[idx_tuple] = torch.where(mask, val_t, target[idx_tuple]) return prev diff --git a/helion/language/loops.py b/helion/language/loops.py index 506391850..97cff66c3 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -734,7 +734,7 @@ def _( size = None # data dependent size if step_part is None: step_part = 1 - results.append(GridIndexType.allocate(size, origin, step_part)) # pyright: ignore[reportArgumentType] + results.append(GridIndexType.allocate(size, origin, step_part)) _add_config_choices( [x.block_id for x in results], diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py 
index 88bad4532..f45ec5947 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -306,7 +306,7 @@ def _( from .ref_tile import RefTile if extra_mask is None: - return tensor[tuple(index)] # pyright: ignore[reportArgumentType] + return tensor[tuple(index)] # Create zero result matching mask shape result = torch.zeros(extra_mask.shape, dtype=tensor.dtype, device=tensor.device) diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index 0412dbe7a..70343eafc 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -396,7 +396,7 @@ def _( # For single tensor or single other value, use mask_node_inputs from .._compiler.node_masking import mask_node_inputs - mask_node_inputs(actual_node, other=other) # pyright: ignore[reportArgumentType] + mask_node_inputs(actual_node, other=other) # Create output tensors with reduced shape if is_tuple_input: @@ -522,7 +522,7 @@ def _create_reduce_expression( return expr_from_string( template, input_tensor=input_tensor, - dim_value=ast.Constant(value=dim), # pyright: ignore[reportArgumentType] + dim_value=ast.Constant(value=dim), ) diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py index 63a51441c..5c6ed152f 100644 --- a/helion/language/ref_tile.py +++ b/helion/language/ref_tile.py @@ -155,7 +155,7 @@ def _handle_getitem( assert isinstance(tensor, torch.Tensor) slice_index = convert_tile_indices_to_slices(index) - return tensor[slice_index] # pyright: ignore[reportArgumentType] + return tensor[slice_index] @classmethod def _handle_setitem( @@ -170,7 +170,7 @@ def _handle_setitem( assert isinstance(value, (int, float, bool, torch.Tensor)) slice_index = convert_tile_indices_to_slices(index) - target_shape = tensor[slice_index].shape # pyright: ignore[reportArgumentType] + target_shape = tensor[slice_index].shape # Slice value tensor to match target shape if needed if ( @@ -181,17 +181,17 @@ def _handle_setitem( slices = create_shape_matching_slices(value.shape, target_shape) value = value[slices] - tensor[slice_index] = value # pyright: ignore[reportArgumentType] + tensor[slice_index] = value return None - def __repr__(self, tensor_contents: None = None) -> str: # pyright: ignore[reportIncompatibleMethodOverride] + def __repr__(self, tensor_contents: None = None) -> str: return f"RefTile({self._slice!r})" def __index__(self) -> int: return self.block_size @property - def index(self) -> torch.Tensor: # pyright: ignore[reportIncompatibleMethodOverride] + def index(self) -> torch.Tensor: """Return tensor of indices for .index attribute access in ref mode.""" from .._compiler.compile_environment import CompileEnvironment diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 43386daca..8d0d3f42c 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -158,8 +158,8 @@ def _(state: CodegenState) -> ast.AST: else: raise NotImplementedError(f"Unsupported signal pad type: {type(signal_pad)}") - signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType] - update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType] + signal_expr = ast.Constant(value=signal) + update_expr = ast.Constant(value=update) is_scalar = len(shape) == 0 @@ -322,12 +322,12 @@ def _(state: CodegenState) -> ast.AST: is_scalar = len(shape) == 0 - signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType] + signal_expr = ast.Constant(value=signal) if wait_for is not None: - wait_for_expr = 
ast.Constant(value=wait_for) # pyright: ignore[reportArgumentType] + wait_for_expr = ast.Constant(value=wait_for) else: wait_for_expr = ast.Constant(value=0) - skip_sync_expr = ast.Constant(value=skip_sync) # pyright: ignore[reportArgumentType] + skip_sync_expr = ast.Constant(value=skip_sync) if wait_for is not None: call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={{bar_addrs}}, expect={{wait_for}}, update={{signal}}, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not {{skip_sync}}))" diff --git a/helion/language/stack_tensor.py b/helion/language/stack_tensor.py index 8b025a9b5..ee23d6835 100644 --- a/helion/language/stack_tensor.py +++ b/helion/language/stack_tensor.py @@ -76,13 +76,13 @@ def device(self) -> torch.device: def shape(self) -> torch.Size: return self.dev_ptrs.shape + self.tensor_like.shape - def __getitem__( # pyright: ignore[reportIncompatibleMethodOverride] + def __getitem__( self, index: list[object] | torch.Tensor, ) -> torch.Tensor: raise exc.NotInsideKernel - def __setitem__( # pyright ignore[reportIncompatibleMethodOverride] + def __setitem__( self, index: list[object] | torch.Tensor, value: torch.Tensor | bool | float, @@ -92,7 +92,7 @@ def __setitem__( # pyright ignore[reportIncompatibleMethodOverride] def new_empty( self, *args: Sequence[int | torch.SymInt], **kwargs: dict ) -> torch.Tensor: - return self.tensor_like.new_empty(*args, **kwargs) # pyright: ignore[reportCallIssue] + return self.tensor_like.new_empty(*args, **kwargs) # TODO(joydddd): Implement this to support StackTensor in ref mode. # def as_tuple_of_tensor(self) -> tuple[torch.Tensor, ...]: @@ -201,7 +201,7 @@ def _(tensor_like: TypeInfo, dev_ptrs: TypeInfo, *, origin: Origin) -> TypeInfo: "tensor_like": tensor_like, } - return StackTensorType(origin, element_types) # pyright: ignore[reportArgumentType] + return StackTensorType(origin, element_types) @_decorators.register_to_device_ir(_stack_tensor) diff --git a/helion/language/tile_proxy.py b/helion/language/tile_proxy.py index 0e5b2f5f7..055a80d4e 100644 --- a/helion/language/tile_proxy.py +++ b/helion/language/tile_proxy.py @@ -102,7 +102,7 @@ def _prepare_index(index: object) -> list[object]: assert isinstance(index, Tile) return [index] - def __repr__(self, tensor_contents: None = None) -> str: # pyright: ignore[reportIncompatibleMethodOverride] + def __repr__(self, tensor_contents: None = None) -> str: return f"Tile({self.block_id!r})" @classmethod diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index 863877be5..dccf56bad 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -90,7 +90,7 @@ def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: @_decorators.codegen(subscript, "triton") def _(state: CodegenState) -> ast.AST: output_keys = [] - for val in state.proxy_arg(1): # pyright: ignore[reportGeneralTypeIssues] + for val in state.proxy_arg(1): if val is None: output_keys.append("None") elif isinstance(val, slice) and repr(val) == "slice(None, None, None)": @@ -105,7 +105,7 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(subscript) def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: - return tensor[indices] # pyright: ignore[reportArgumentType] + return tensor[indices] @_decorators.get_masked_value(subscript) diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 73f815e9f..ae5c008f5 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -124,8 +124,7 
@@ def __init__( self.settings: Settings = settings or Settings() self._key_fn: Callable[..., Hashable] | None = key self.configs: list[Config] = [ - Config(**c) if isinstance(c, dict) else c # pyright: ignore[reportArgumentType] - for c in configs or [] + Config(**c) if isinstance(c, dict) else c for c in configs or [] ] self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {} self._specialize_extra: dict[ @@ -456,7 +455,7 @@ def to_triton_code( config = self._require_implicit_config() with self.env: if not isinstance(config, Config): - config = Config(**config) # pyright: ignore[reportArgumentType] + config = Config(**config) self.env.config_spec.normalize(config) root = generate_ast(self.host_function, config, emit_repro_caller) if output_origin_lines is None: @@ -481,9 +480,7 @@ def compile_config( if config is None: config = self._require_implicit_config() if not isinstance(config, Config): - config = Config( - **config # pyright: ignore[reportArgumentType] - ) + config = Config(**config) if (rv := self._compile_cache.get(config)) is not None: return rv try: @@ -576,9 +573,7 @@ def set_config(self, config: ConfigLike) -> None: config: The configuration to set. """ if not isinstance(config, Config): - config = Config( - **config # pyright: ignore[reportArgumentType] - ) + config = Config(**config) self._run = self.compile_config(config) self._config = config @@ -646,7 +641,7 @@ def _require_implicit_config(self) -> Config: raise RuntimeError("no config provided and no implicit config available") return config - def run_ref(self, *args: object) -> _R: # pyright: ignore[reportReturnType] + def run_ref(self, *args: object) -> _R: # Unwrap ConstExpr arguments clean_args = [] for arg in args: @@ -933,7 +928,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable: _specialization_extractors: dict[ type[object] | str, Callable[[Kernel, object], Hashable] -] = { # pyright: ignore[reportAssignmentType] +] = { torch.Tensor: _tensor_key, torch.nn.Parameter: _tensor_key, FakeTensor: _tensor_key, @@ -945,13 +940,13 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable: str: lambda fn, x: x, list: _sequence_key, tuple: _sequence_key, - dict: lambda fn, x: _mapping_key(fn, x, type(x)), # pyright: ignore[reportArgumentType] - "namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)), # pyright: ignore[reportAttributeAccessIssue] - "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)), # pyright: ignore[reportArgumentType] + dict: lambda fn, x: _mapping_key(fn, x, type(x)), + "namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)), + "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)), types.FunctionType: _function_key, types.BuiltinFunctionType: lambda fn, x: x, torch.fx.GraphModule: _graph_module_key, - ConstExpr: lambda fn, x: x.value, # pyright: ignore[reportAttributeAccessIssue] + ConstExpr: lambda fn, x: x.value, type(None): lambda fn, x: None, } @@ -989,8 +984,8 @@ def _find_device(args: tuple[object, ...]) -> torch.device: def _maybe_skip_dtype_check_in_meta_registrations() -> ( contextlib.AbstractContextManager[None, None] ): - if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"): # pyright: ignore[reportAttributeAccessIssue] - return torch.fx.experimental._config.patch( # pyright: ignore[reportAttributeAccessIssue] + if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"): + return 
torch.fx.experimental._config.patch( skip_dtype_check_in_meta_registrations=True ) return contextlib.nullcontext() diff --git a/helion/runtime/precompile_shim.py b/helion/runtime/precompile_shim.py index 6d72dba78..f49dd4f33 100644 --- a/helion/runtime/precompile_shim.py +++ b/helion/runtime/precompile_shim.py @@ -32,7 +32,7 @@ def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], None]: parts so we can wrap it in a subprocess to handle configs that hang in Triton compile and never return. """ - device = _find_device([*args, *kwargs.values()]) # pyright: ignore[reportArgumentType] + device = _find_device([*args, *kwargs.values()]) kwargs["debug"] = ( kwargs.get("debug", fn.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" ) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index b48f3b99f..affec721b 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -248,7 +248,7 @@ def default_autotuner_fn( f"{', '.join(cache_classes.keys())}" ) - return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) # pyright: ignore[reportArgumentType] + return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) def _get_autotune_random_seed() -> int: @@ -278,7 +278,7 @@ class _Settings: cast("DotPrecision", "tf32"), mapping={k: k for k in ("tf32", "tf32x3", "ieee")}, ) - ) # pyright: ignore[reportAssignmentType] + ) static_shapes: bool = dataclasses.field( default_factory=functools.partial(_env_get_bool, "HELION_STATIC_SHAPES", True) ) @@ -315,7 +315,7 @@ class _Settings: "0": None, }, ) - ) # pyright: ignore[reportAssignmentType] + ) autotune_precompile_jobs: int | None = dataclasses.field( default_factory=functools.partial( _env_get_optional_int, @@ -378,7 +378,7 @@ class _Settings: cast("AutotuneEffort", "full"), mapping={key: key for key in ("none", "quick", "full")}, ) - ) # pyright: ignore[reportAssignmentType] + ) allow_warp_specialize: bool = dataclasses.field( default_factory=functools.partial( _env_get_bool, "HELION_ALLOW_WARP_SPECIALIZE", True @@ -484,7 +484,7 @@ def __init__(self, **settings: object) -> None: Initialize the Settings object with the provided dictionary of settings. """ - super().__init__(**settings) # pyright: ignore[reportArgumentType] + super().__init__(**settings) self._check_ref_eager_mode_before_print_output_code() diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py index 88e4cc25b..834bd0140 100644 --- a/helion/runtime/triton_helpers.py +++ b/helion/runtime/triton_helpers.py @@ -64,7 +64,7 @@ def triton_wait_signal( scope: tl.constexpr, op: tl.constexpr, skip_sync: tl.constexpr, - sync_before: tl.constexpr = False, # pyright: ignore[reportArgumentType] + sync_before: tl.constexpr = False, ) -> None: """ Wait for a global memory barrier to reach the expected value. @@ -83,7 +83,7 @@ def triton_wait_signal( sync_before: Add a CTA sync before the wait (default: False) """ tl.static_assert( - addr.type.is_ptr(), # pyright: ignore[reportAttributeAccessIssue] + addr.type.is_ptr(), "Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'? 
", ) @@ -135,7 +135,7 @@ def triton_wait_multiple_signal( scope: tl.constexpr, op: tl.constexpr, skip_sync: tl.constexpr, - sync_before: tl.constexpr = False, # pyright: ignore[reportArgumentType] + sync_before: tl.constexpr = False, ) -> None: """ Simultaneously wait for multiple global memory barriers to reach the diff --git a/test/test_examples.expected b/test/test_examples.expected index 24865b2ad..dbb880280 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -2051,7 +2051,7 @@ def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_pa v_12 = v_11 > add_2 # src[grouped_gemm.py:N]: if tile_in_group < num_group_tiles: # src[grouped_gemm.py:N]: # Convert linear tile index to 2D (M, N) tile coordinates - # src[grouped_gemm.py:N]: m_tile_idx = tile_in_group % num_m_tiles # pyright: ignore[reportOperatorIssue] + # src[grouped_gemm.py:N]: m_tile_idx = tile_in_group % num_m_tiles # src[grouped_gemm.py:N-N]: ... if v_12: v_8_copy_0_copy = v_8_copy_0 @@ -2060,7 +2060,7 @@ def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_pa v_8_copy_0_copy_0 = v_8_copy_0_copy group_start_copy_0_copy_0_copy_0 = group_start_copy_0_copy_0_copy group_end_copy_0_copy_0_copy_0 = group_end_copy_0_copy_0_copy - # src[grouped_gemm.py:N]: m_tile_idx = tile_in_group % num_m_tiles # pyright: ignore[reportOperatorIssue] + # src[grouped_gemm.py:N]: m_tile_idx = tile_in_group % num_m_tiles v_13 = tl.cast(v_8_copy_0_copy_0, tl.int64) v_14 = add_2 % v_13 v_15 = tl.full([], 0, tl.int32) @@ -2079,7 +2079,7 @@ def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, A_pa v_25 = tl.cast(v_22, tl.int64) v_26 = v_25 * _BLOCK_SIZE_0__2 v_27 = group_start_copy_0_copy_0_copy_0 + v_26 - # src[grouped_gemm.py:N]: base_col = n_tile_idx * BLOCK_N # pyright: ignore[reportOperatorIssue] + # src[grouped_gemm.py:N]: base_col = n_tile_idx * BLOCK_N _BLOCK_SIZE_1_ = _BLOCK_SIZE_1 v_28 = tl.cast(v_24, tl.int64) v_29 = v_28 * _BLOCK_SIZE_1_ @@ -4141,7 +4141,7 @@ def _helion_layer_norm_bwd(weight, x, grad_out, mean, rstd, grad_x, grad_weight_ v_5 = v_2 * v_4 sum_1 = tl.cast(tl.sum(v_5, 0), tl.float32) grad_w_acc = grad_w_acc_copy_0 + sum_1 - # src[layer_norm.py:N]: grad_b_acc += torch.sum(dy_mb, dim=0) # pyright: ignore[reportPossiblyUnboundVariable] + # src[layer_norm.py:N]: grad_b_acc += torch.sum(dy_mb, dim=0) sum_2 = tl.cast(tl.sum(v_2, 0), tl.float32) grad_b_acc = grad_b_acc_copy_0 + sum_2 # src[layer_norm.py:N]: wdy = weight_cta * dy_mb From a085515e35676f050e59bf480fb34a3407d1d95d Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 14:30:37 -0800 Subject: [PATCH 05/10] Remove search paths that pyrefly can't find. --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 225e7264a..70cf212c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,6 @@ source = "vcs" [tool.pyrefly] project-includes = ["helion", "examples"] project-excludes = ["test"] -search-path = ["triton/python", "../pytorch", "../pytorch-hg", "../pytorch-nightly"] python-version = "3.10" [tool.codespell] From 4e681626fc459139d77e3756c00e5f19c5e8a08e Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 14:30:55 -0800 Subject: [PATCH 06/10] Add pyrefly ignore directives. 
--- examples/all_reduce.py | 1 + examples/attention.py | 1 + examples/blackwell_attention.py | 4 + examples/exp.py | 2 +- examples/fp8_attention.py | 1 + examples/fused_linear_jsd.py | 1 + examples/geglu.py | 6 +- examples/grpo_loss.py | 2 +- examples/jagged_hstu_attn.py | 1 + examples/jagged_layer_norm.py | 1 + examples/jagged_mean.py | 2 + examples/jagged_softmax.py | 1 + examples/jagged_sum.py | 1 + examples/jsd.py | 1 + examples/layer_norm.py | 3 +- examples/matmul.py | 4 +- examples/rms_norm.py | 2 +- examples/softmax.py | 2 +- examples/swiglu.py | 6 +- helion/_compat.py | 2 + helion/_compiler/ast_extension.py | 16 +++- helion/_compiler/aten_lowering.py | 6 ++ helion/_compiler/compile_environment.py | 9 ++- helion/_compiler/device_function.py | 11 ++- helion/_compiler/device_ir.py | 30 ++++++- helion/_compiler/generate_ast.py | 7 ++ helion/_compiler/helper_function.py | 1 + helion/_compiler/host_function.py | 4 + helion/_compiler/indexing_strategy.py | 3 + helion/_compiler/inductor_lowering.py | 12 ++- helion/_compiler/inductor_lowering_extra.py | 12 +++ helion/_compiler/matmul_utils.py | 3 + helion/_compiler/program_id.py | 1 + helion/_compiler/source_location.py | 9 +++ helion/_compiler/static_loop_unroller.py | 1 + helion/_compiler/tensor_utils.py | 1 + helion/_compiler/tile_strategy.py | 7 ++ helion/_compiler/traceback_compat.py | 26 ++++-- helion/_compiler/type_propagation.py | 89 +++++++++++++++++++++ helion/_logging/_internal.py | 1 + helion/_testing.py | 19 ++++- helion/autotuner/base_search.py | 4 + helion/autotuner/block_id_sequence.py | 12 ++- helion/autotuner/config_spec.py | 5 +- helion/language/_decorators.py | 8 ++ helion/language/_tracing_ops.py | 10 +++ helion/language/constexpr.py | 2 + helion/language/creation_ops.py | 1 + helion/language/loops.py | 4 + helion/language/memory_ops.py | 8 +- helion/language/random_ops.py | 1 + helion/language/reduce_ops.py | 4 + helion/language/ref_tile.py | 5 +- helion/language/signal_wait.py | 4 + helion/language/stack_tensor.py | 3 + helion/language/view_ops.py | 2 + helion/runtime/kernel.py | 34 +++++++- helion/runtime/precompile_shim.py | 1 + helion/runtime/settings.py | 3 + helion/runtime/triton_helpers.py | 3 + 60 files changed, 397 insertions(+), 29 deletions(-) diff --git a/examples/all_reduce.py b/examples/all_reduce.py index 3e4715b9d..afb27f326 100644 --- a/examples/all_reduce.py +++ b/examples/all_reduce.py @@ -72,6 +72,7 @@ def dev_array_to_tensor_short( Returns: PyTorch tensor created from the device pointer """ + # pyrefly: ignore [missing-attribute] return cpp_mod.from_blob(dev_array_ptr, shape, dtype) diff --git a/examples/attention.py b/examples/attention.py index 244c54701..9ff1acc08 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -94,6 +94,7 @@ def attention( # --------------------- # %% +# pyrefly: ignore [no-matching-overload] attention_dynamic: object = helion.kernel( attention.fn, configs=attention.configs, diff --git a/examples/blackwell_attention.py b/examples/blackwell_attention.py index 9d1a87677..ebb6cf2e2 100644 --- a/examples/blackwell_attention.py +++ b/examples/blackwell_attention.py @@ -76,6 +76,7 @@ def _fma_f32x2(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tenso # %% +# pyrefly: ignore [no-matching-overload] @helion.kernel( configs=[ helion.Config( @@ -158,6 +159,7 @@ def blackwell_attention_kernel( qk = hl.dot(q_i, k_j.T, out_dtype=torch.float32) m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale) if VECT_MUL == 2 or VECT_MUL == 3: + # pyrefly: ignore 
[bad-argument-type] qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None]) else: qk = qk * qk_scale - m_ij[:, None] @@ -169,6 +171,7 @@ def blackwell_attention_kernel( if SUBTILING: acc0, acc1 = hl.split( + # pyrefly: ignore [no-matching-overload] acc.reshape([tile_m, 2, Dv // 2]).permute(0, 2, 1) ) if VECT_MUL == 1 or VECT_MUL == 3: @@ -267,6 +270,7 @@ def ref_attention( atol=0.1, rtol=0.1, ) + # pyrefly: ignore [bad-assignment] dur: float = do_bench(lambda: blackwell_attention(q, k, v)) print( f"{z=} {h=} {n_ctx=} {head_dim=} tflops={z * h * n_ctx * n_ctx * head_dim * 4 / dur * 1e-9:.2f}" diff --git a/examples/exp.py b/examples/exp.py index ba35c341b..a3c076031 100644 --- a/examples/exp.py +++ b/examples/exp.py @@ -67,7 +67,7 @@ def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor: # %% class ExpFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: object, x: torch.Tensor, ) -> torch.Tensor: diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index b62a054f9..96f47f966 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -106,6 +106,7 @@ def fp8_attention_kernel( # Final normalization acc = acc / l_i[:, None] # Convert to FP8 before writing to output + # pyrefly: ignore [unsupported-operation] out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn) return out diff --git a/examples/fused_linear_jsd.py b/examples/fused_linear_jsd.py index 0ce9f0bdc..8efb24e3a 100644 --- a/examples/fused_linear_jsd.py +++ b/examples/fused_linear_jsd.py @@ -85,6 +85,7 @@ def fused_linear_jsd_fwd_tritonbench( label: torch.Tensor | None = None, ) -> Callable[[], torch.Tensor]: assert label is None + # pyrefly: ignore [missing-attribute] baseline_op = tb_op.baseline_op beta = baseline_op.jsd.beta ignore_index = baseline_op.jsd.ignore_index diff --git a/examples/geglu.py b/examples/geglu.py index ab8d49e99..86f875f06 100644 --- a/examples/geglu.py +++ b/examples/geglu.py @@ -144,7 +144,7 @@ def geglu_bwd(grad_out: Tensor, a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]: class GEGLUFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 a: Tensor, b: Tensor, @@ -343,8 +343,11 @@ def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable: # Extract configuration from tritonbench operator config = Config( + # pyrefly: ignore [missing-attribute] hidden_size=tb_op.hidden_size, + # pyrefly: ignore [missing-attribute] intermediate_size=tb_op.intermediate_size, + # pyrefly: ignore [missing-attribute] hidden_act=tb_op.hidden_act, ) @@ -353,6 +356,7 @@ def geglu_tritonbench(tb_op: object, x: Tensor) -> Callable: # Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness # LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP) + # pyrefly: ignore [missing-attribute] baseline_model = tb_op.baseline_model # Copy gate projection weights diff --git a/examples/grpo_loss.py b/examples/grpo_loss.py index 94d810a6d..11923a11c 100644 --- a/examples/grpo_loss.py +++ b/examples/grpo_loss.py @@ -336,7 +336,7 @@ class GrpoLossFunction(torch.autograd.Function): """Custom autograd function for GRPO loss with forward and backward passes.""" @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: object, logits: torch.Tensor, old_logp: torch.Tensor | None, diff --git a/examples/jagged_hstu_attn.py b/examples/jagged_hstu_attn.py index 4e81c40d9..f7054196a 100644 --- 
a/examples/jagged_hstu_attn.py +++ b/examples/jagged_hstu_attn.py @@ -22,6 +22,7 @@ import helion.language as hl try: + # pyrefly: ignore [missing-import] from generative_recommenders.ops.triton.triton_hstu_attention import triton_hstu_mha HAS_HAMMER = True diff --git a/examples/jagged_layer_norm.py b/examples/jagged_layer_norm.py index e857ef143..621ace8a2 100644 --- a/examples/jagged_layer_norm.py +++ b/examples/jagged_layer_norm.py @@ -225,6 +225,7 @@ def jagged_layer_norm_tritonbench( Callable that returns normalized tensor values """ x_values = x._values + # pyrefly: ignore [missing-attribute] x_offsets = x._offsets return lambda: jagged_layer_norm_kernel(x_values, x_offsets, eps=1e-6) diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index b279f6b5b..da7f4e421 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -166,12 +166,14 @@ def jagged_mean_tritonbench( Callable that returns tensor of shape (B, M) with mean values per row and feature """ x_values = x._values + # pyrefly: ignore [missing-attribute] x_offsets = x._offsets feature_counts = torch.full( (B,), M, dtype=torch.int32, + # pyrefly: ignore [missing-attribute] device=x_values.device, ) return lambda: jagged_mean_kernel(x_values, x_offsets, feature_counts, M) diff --git a/examples/jagged_softmax.py b/examples/jagged_softmax.py index ee80fec93..f4e4bdfb6 100644 --- a/examples/jagged_softmax.py +++ b/examples/jagged_softmax.py @@ -163,6 +163,7 @@ def jagged_softmax_tritonbench( Returns: Callable that returns tensor of shape (N, M), where N = total number of rows in the jagged tensor """ + # pyrefly: ignore [missing-attribute] return lambda: jagged_softmax_kernel(x._values, x._offsets) diff --git a/examples/jagged_sum.py b/examples/jagged_sum.py index a56d018d7..0bb08ddb6 100644 --- a/examples/jagged_sum.py +++ b/examples/jagged_sum.py @@ -145,6 +145,7 @@ def jagged_sum_tritonbench( Callable that returns tensor of shape (B, M) with mean values per row and feature """ x_values = x._values + # pyrefly: ignore [missing-attribute] x_offsets = x._offsets return lambda: jagged_sum_kernel(x_values, x_offsets) diff --git a/examples/jsd.py b/examples/jsd.py index 979bdf70f..8394a8552 100644 --- a/examples/jsd.py +++ b/examples/jsd.py @@ -318,6 +318,7 @@ def jsd_tritonbench(tb_op: object, log_q: Tensor, log_p: Tensor) -> Callable: Callable: A callable that runs the JSD kernel """ + # pyrefly: ignore [missing-attribute] baseline_model = tb_op.baseline_op helion_jsd = HelionJSD( diff --git a/examples/layer_norm.py b/examples/layer_norm.py index 65f9c2c47..02e7614d3 100644 --- a/examples/layer_norm.py +++ b/examples/layer_norm.py @@ -133,6 +133,7 @@ def layer_norm_bwd( grad_w_acc += torch.sum(dy_mb * x_hat, dim=0) if compute_bias_grad: + # pyrefly: ignore [unbound-name] grad_b_acc += torch.sum(dy_mb, dim=0) wdy = weight_cta * dy_mb @@ -155,7 +156,7 @@ def layer_norm_bwd( # %% class LayerNormFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 x: torch.Tensor, normalized_shape: list[int], diff --git a/examples/matmul.py b/examples/matmul.py index b3c3ca4d8..ed210f301 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -191,7 +191,7 @@ def addmm_bwd( # %% class MatMulFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 mat1: Tensor, mat2: Tensor, @@ -220,7 +220,7 @@ def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor: class 
AddMMFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 bias: Tensor, mat1: Tensor, diff --git a/examples/rms_norm.py b/examples/rms_norm.py index 0d1342ae5..e486e25dc 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -119,7 +119,7 @@ def rms_norm_bwd( # %% class RMSNormFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 x: torch.Tensor, weight: torch.Tensor, diff --git a/examples/softmax.py b/examples/softmax.py index b148896df..dd68137b7 100644 --- a/examples/softmax.py +++ b/examples/softmax.py @@ -128,7 +128,7 @@ def softmax_bwd( class SoftmaxFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 x: torch.Tensor, ) -> torch.Tensor: diff --git a/examples/swiglu.py b/examples/swiglu.py index 7ab6e36af..86ff6c8ed 100644 --- a/examples/swiglu.py +++ b/examples/swiglu.py @@ -130,7 +130,7 @@ def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]: class SwigluFunction(torch.autograd.Function): @staticmethod - def forward( + def forward( # pyrefly: ignore [bad-override] ctx: Any, # noqa: ANN401 x1: Tensor, x2: Tensor, @@ -312,8 +312,11 @@ def swiglu_tritonbench(tb_op: object, x: Tensor) -> Callable: # Extract configuration from tritonbench operator config = Config( + # pyrefly: ignore [missing-attribute] hidden_size=tb_op.hidden_size, + # pyrefly: ignore [missing-attribute] intermediate_size=tb_op.intermediate_size, + # pyrefly: ignore [missing-attribute] hidden_act=tb_op.hidden_act, ) @@ -322,6 +325,7 @@ def swiglu_tritonbench(tb_op: object, x: Tensor) -> Callable: # Copy weights from tritonbench baseline model (LlamaMLP) to ensure fairness # LlamaMLP has: gate_proj, up_proj, down_proj (same structure as our HelionGEGLUMLP) + # pyrefly: ignore [missing-attribute] baseline_model = tb_op.baseline_op # Copy gate projection weights diff --git a/helion/_compat.py b/helion/_compat.py index ed1a3c666..935edee5a 100644 --- a/helion/_compat.py +++ b/helion/_compat.py @@ -251,6 +251,7 @@ def _min_dot_size( return (16, 16, 16) if torch.xpu.is_available(): + # pyrefly: ignore [missing-import] from triton.backends.intel.compiler import min_dot_size as min_dot_size_xpu device_properties = torch.xpu.get_device_properties() @@ -263,6 +264,7 @@ def _min_dot_size( dot_size_val = min_dot_size_xpu(gpu_target_info)( torch_dtype_to_tl(lhs), torch_dtype_to_tl(rhs) ) + # pyrefly: ignore [bad-return] return tuple(int(v) for v in dot_size_val) from triton.backends.nvidia.compiler import min_dot_size as min_dot_size_cuda diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py index 90a4bd447..68c3c8a38 100644 --- a/helion/_compiler/ast_extension.py +++ b/helion/_compiler/ast_extension.py @@ -137,6 +137,7 @@ class Wrapper(ExtendedAST, cls): def create(cls: type[_T], **fields: object) -> _T: + # pyrefly: ignore [unexpected-keyword] result = get_wrapper_cls(cls)(**fields, _location=current_location()) assert isinstance(result, ExtendedAST) result._location.to_ast(result) @@ -215,6 +216,7 @@ def make_unique(m: re.Match[str]) -> str: def _replace(node: _R) -> _R: # Handle lists by recursively transforming each element if isinstance(node, list): + # pyrefly: ignore [bad-return] return [_replace(item) for item in node] # Pass through non-AST nodes unchanged (e.g., strings, numbers) @@ -227,9 +229,11 @@ def 
_replace(node: _R) -> _R: # Recursively transform all child nodes and wrap in ExtendedAST subclass cls = get_wrapper_cls(type(node)) + # pyrefly: ignore [bad-return] return location.to_ast( cls( **{field: _replace(getattr(node, field)) for field in node._fields}, + # pyrefly: ignore [unexpected-keyword] _location=location, ) ) @@ -256,9 +260,11 @@ def convert(node: ast.AST) -> ast.AST: return cls( **{field: convert(getattr(node, field)) for field in node._fields}, **{attr: getattr(node, attr) for attr in node._attributes}, + # pyrefly: ignore [unexpected-keyword] _location=location, ) elif isinstance(node, list): + # pyrefly: ignore [bad-return] return [convert(item) for item in node] else: return node @@ -290,15 +296,21 @@ def visit(self, node: ast.AST) -> ast.AST: ) -class _TupleParensRemovedUnparser(ast._Unparser): +class _TupleParensRemovedUnparser( + # pyrefly: ignore [missing-attribute] + ast._Unparser +): def visit_Tuple(self, node: ast.Tuple) -> None: if _needs_to_remove_tuple_parens and isinstance( getattr(node, "ctx", None), ast.Store ): if len(node.elts) == 1: # single-element tuple + # pyrefly: ignore [missing-attribute] self.traverse(node.elts[0]) + # pyrefly: ignore [missing-attribute] self.write(",") else: # multi-element tuple + # pyrefly: ignore [missing-attribute] self.interleave(lambda: self.write(", "), self.traverse, node.elts) return # For everything else fall back to default behavior @@ -306,6 +318,7 @@ def visit_Tuple(self, node: ast.Tuple) -> None: class _LocationAnnotatingOutputLines(OutputLines): + # pyrefly: ignore [missing-attribute] def __init__(self, parent: ast._Unparser) -> None: super().__init__(parent) self._cache: dict[tuple[str, int, int], tuple[str, ...]] = {} @@ -407,6 +420,7 @@ def __init__( if output_origin_lines: self.output = _LocationAnnotatingOutputLines(self) else: + # pyrefly: ignore [bad-assignment] self.output = OutputLines(self) self._source = self.output self._output_origin_lines = output_origin_lines diff --git a/helion/_compiler/aten_lowering.py b/helion/_compiler/aten_lowering.py index 944d1a52c..eff835c10 100644 --- a/helion/_compiler/aten_lowering.py +++ b/helion/_compiler/aten_lowering.py @@ -161,6 +161,7 @@ def codegen_full(ctx: LoweringContext, node: Node) -> object: if isinstance(value_ast, (int, float, bool)): value_ast = expr_from_string(constant_repr(value_ast)) assert isinstance(value_ast, ast.AST), value_ast + # pyrefly: ignore [not-iterable] shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size]) return expr_from_string( f"tl.full({shape_str}, {{value}}, {triton_type(dtype)})", @@ -180,6 +181,7 @@ def codegen_unsqueeze(ctx: LoweringContext, node: Node) -> object: tensor, dim = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) assert isinstance(tensor, ast.AST) assert isinstance(dim, int) + # pyrefly: ignore [missing-attribute] ndim = node.args[0].meta["val"].ndim if dim < 0: dim += ndim @@ -230,6 +232,7 @@ def codegen_permute(ctx: LoweringContext, node: Node) -> object: assert not node.kwargs, "getitem kwargs not supported" tensor, dims = map_arg(node.args, lambda arg: _env_arg(ctx, arg)) assert isinstance(tensor, ast.AST) + # pyrefly: ignore [not-iterable] dims = [*dims] assert {*dims} == {*range(len(dims))}, dims return expr_from_string( @@ -250,6 +253,7 @@ def codegen_stack(ctx: LoweringContext, node: Node) -> object: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) assert isinstance(tensors, (list, tuple)) + # pyrefly: ignore [bad-index] tensor_asts = [ctx.env[t] for t in tensors] n = 
len(tensor_asts) @@ -311,8 +315,10 @@ def codegen_expand(ctx: LoweringContext, node: Node) -> object: val = node.meta["val"] assert isinstance(val, torch.Tensor) shape = [*val.size()] + # pyrefly: ignore [missing-attribute] if node.args[0].meta["val"].ndim != len(shape): broadcasting = [":"] * len(shape) + # pyrefly: ignore [missing-attribute] for i in range(len(shape) - node.args[0].meta["val"].ndim): broadcasting[i] = "None" tensor = expr_from_string( diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 626fad5b2..9628fe0c9 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -83,6 +83,7 @@ def __init__( from ..autotuner.config_spec import ConfigSpec super().__init__() + # pyrefly: ignore [read-only] self.device = device self.settings = settings self.index_dtype: torch.dtype = ( @@ -316,7 +317,11 @@ def to_fake(self, obj: object, origin: Origin) -> object: return [self.to_fake(e, origin) for e in obj] if isinstance(obj, tuple) and hasattr(obj, "_fields"): return type(obj)( - **{k: self.to_fake(e, origin) for k, e in obj._asdict().items()} + **{ + k: self.to_fake(e, origin) + # pyrefly: ignore [missing-attribute] + for k, e in obj._asdict().items() + } ) if isinstance(obj, tuple): return tuple(self.to_fake(e, origin) for e in obj) @@ -366,6 +371,7 @@ def size_hint(self, n: int | torch.SymInt) -> int: # hint will be wrong since we assign a default value to unbacked symbols. Return a default hint. return 8192 + # pyrefly: ignore [no-matching-overload] return int(self.shape_env.size_hint(n._sympy_())) assert isinstance(n, int) return n @@ -637,6 +643,7 @@ def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr: def _has_unbacked(expr: sympy.Expr) -> bool: + # pyrefly: ignore [missing-attribute] return any(n.name.startswith("u") for n in expr.free_symbols) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index c12d56d06..4f64aac87 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -78,10 +78,13 @@ def find_block_size_symbols( non_block_size_symbols = set() for symbol in expr.free_symbols: + # pyrefly: ignore [no-matching-overload] origin_info = hf.expr_to_origin.get(symbol) if origin_info is None or not isinstance(origin_info.origin, BlockSizeOrigin): + # pyrefly: ignore [bad-argument-type] non_block_size_symbols.add(symbol) else: + # pyrefly: ignore [unsupported-operation] block_sizes[symbol] = origin_info.origin.block_id return block_sizes, non_block_size_symbols @@ -301,6 +304,7 @@ def allocate_rng_seed(self) -> int: # Ensure seed buffer parameter name exists if self.rng_seed_buffer_param_name is None: + # pyrefly: ignore [bad-assignment] self.rng_seed_buffer_param_name = self.new_var("rng_seed_buffer") return seed_index @@ -353,6 +357,7 @@ def try_map_block_symbols_to_vars(self, expr: sympy.Expr) -> sympy.Expr | None: var_map[symbol] = sympy.Symbol(block_var, integer=True) # Successfully mapped all symbols + # pyrefly: ignore [bad-return] return expr.xreplace(var_map) def merge_variable_names(self, a: str, b: str) -> None: @@ -416,6 +421,7 @@ def user_sympy_expr(self, expr: sympy.Expr) -> str: if block_idx is not None: replacements[sym] = self.tile_strategy.user_size(block_idx) if replacements: + # pyrefly: ignore [bad-assignment] expr = expr.xreplace(replacements) return self.sympy_expr(expr) @@ -683,7 +689,10 @@ def dead_code_elimination(self) -> None: # drop any unused args args_to_remove = { - arg.name 
for arg in self.arguments if arg.name not in rw.reads + arg.name + for arg in self.arguments + # pyrefly: ignore [unbound-name] + if arg.name not in rw.reads } if args_to_remove: self.arguments = [ diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index d7612e4b3..658e62f4b 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -106,6 +106,7 @@ def _get_proxy_slot( if obj not in tracker: origin = HostFunction.current().tensor_to_origin[obj] assert origin.is_host() + # pyrefly: ignore [unsupported-operation] tracker[obj] = proxy = tracer.create_proxy( "call_function", _tracing_ops._host_tensor, @@ -120,6 +121,7 @@ def _get_proxy_slot( tracker = tracer.symnode_tracker if obj not in tracker: debug_name = CompileEnvironment.current().sympy_debug(obj._sympy_()) + # pyrefly: ignore [unsupported-operation] tracker[obj] = proxy = tracer.create_proxy( "call_function", _tracing_ops._get_symnode, @@ -129,6 +131,7 @@ def _get_proxy_slot( ) proxy.node.meta["val"] = obj proxy.node.meta["lowering"] = APIFuncLowering(_tracing_ops._get_symnode) + # pyrefly: ignore [missing-attribute] proxy.force = lambda: proxy return transform(tracker[obj]) return get_proxy_slot(obj, tracer, default, transform) @@ -305,7 +308,10 @@ def emit_condition( ) -> ast.expr: with state.codegen.set_statements(target_statements): cond_outputs = codegen_call_with_graph( - state.codegen, cond_info.graph, args + state.codegen, + cond_info.graph, + # pyrefly: ignore [bad-argument-type] + args, ) if len(cond_outputs) != 1: raise exc.InternalError( @@ -530,6 +536,7 @@ def _assign(self, target: ast.AST, value: object) -> None: if isinstance(n, ast.Starred): raise exc.StarredArgsNotSupportedOnDevice + # pyrefly: ignore [bad-index] self._assign(n, value[i]) elif isinstance(target, ast.Subscript): dst = self.visit(target.value) @@ -817,6 +824,7 @@ def build_subgraph( proxy_out = tracer.create_proxy( "call_function", _tracing_ops._for_loop, + # pyrefly: ignore [bad-argument-type] *args_to_proxies(tracer, args), ) proxy_tensor.track_tensor_tree( @@ -890,6 +898,7 @@ def build_body( proxy_out = tracer.create_proxy( "call_function", _tracing_ops._while_loop, + # pyrefly: ignore [bad-argument-type] *args_to_proxies(tracer, args), ) proxy_tensor.track_tensor_tree( @@ -950,6 +959,7 @@ def build_body( proxy_out = tracer.create_proxy( "call_function", _tracing_ops._if, + # pyrefly: ignore [bad-argument-type] *args_to_proxies(tracer, args), ) proxy_tensor.track_tensor_tree( @@ -1059,6 +1069,7 @@ def visit_Slice(self, node: ast.Slice) -> slice | torch.Tensor: # Convert slice to hl.arange when step is None or 1 and we have both bounds # This allows FX tracing to handle slice operations with dynamic bounds if lower is not None and upper is not None and (step is None or step == 1): + # pyrefly: ignore [bad-argument-type] return hl.arange(lower, upper) return slice(lower, upper, step) @@ -1139,8 +1150,10 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None: ) return hl.store( + # pyrefly: ignore [bad-argument-type] self.visit(target.value), self._subscript_slice_proxy(target.slice), + # pyrefly: ignore [bad-argument-type] val, ) @@ -1168,12 +1181,16 @@ def visit_Subscript(self, node: ast.Subscript) -> object: if isinstance(type_info, SequenceType): index_value = self.visit(node.slice) if isinstance(index_value, int): + # pyrefly: ignore [bad-index] return self.visit(value)[index_value] raise exc.InvalidSequenceSubscription(node.slice) if isinstance(type_info, StackTensorType): + # 
pyrefly: ignore [bad-argument-type] return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) if type_info is not None and type_info.origin.is_host(): + # pyrefly: ignore [bad-argument-type] return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice)) + # pyrefly: ignore [bad-argument-type] return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice)) def visit_Call(self, node: ast.Call) -> object: @@ -1181,23 +1198,29 @@ def visit_Call(self, node: ast.Call) -> object: kwargs = {} for arg in node.args: if isinstance(arg, ast.Starred): + # pyrefly: ignore [bad-argument-type] args.extend(self.visit(arg.value)) else: args.append(self.visit(arg)) for kwarg in node.keywords: if kwarg.arg is None: + # pyrefly: ignore [no-matching-overload] kwargs.update(self.visit(kwarg.value)) else: kwargs[kwarg.arg] = self.visit(kwarg.value) if isinstance( - (func_type_info := node.func._type_info), + ( + # pyrefly: ignore [missing-attribute] + func_type_info := node.func._type_info + ), CallableType, ) and (replacement := get_device_func_replacement(func_type_info.value)): func = replacement else: func = self.visit(node.func) + # pyrefly: ignore [bad-argument-type] return _CheckForIndexCalls.retry_call(func, args, kwargs) def visit_Attribute(self, node: ast.Attribute) -> object: @@ -1258,12 +1281,15 @@ def visit_For(self, node: ast.For) -> None: self.device_ir.add_root_graph( _make_fx(lambda: WalkDeviceAST(self.device_ir).visit(node)) ) + # pyrefly: ignore [missing-attribute] iter_type = node.iter._type_info assert isinstance(iter_type, IterType) inner = iter_type.inner if isinstance(inner, SequenceType): + # pyrefly: ignore [missing-attribute] block_ids = [x.block_id for x in inner.unpack()] else: + # pyrefly: ignore [missing-attribute] block_ids = [inner.block_id] self.device_ir.grid_block_ids.append(block_ids) else: diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index c629e9711..f07f41ded 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -218,6 +218,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: fields[field] = self.visit(old_value) else: fields[field] = old_value + # pyrefly: ignore [bad-return] return node.new(fields) def visit_For(self, node: ast.For) -> ast.AST | None: @@ -239,6 +240,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None: ) ) self.device_function.body.extend( + # pyrefly: ignore [missing-attribute] self.device_function.pid.codegen_pid_init() ) if node._root_id < len(self.host_function.device_ir.root_ids) - 1: @@ -288,6 +290,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None: self, fx_node=None, proxy_args=[*bound.arguments.values()], + # pyrefly: ignore [bad-argument-type] ast_args=None, ) @@ -318,6 +321,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None: block.append( create( ast.If, + # pyrefly: ignore [missing-attribute] test=self.device_function.pid.codegen_test(state), body=body, orelse=self.next_else_block, @@ -329,6 +333,7 @@ def visit_For(self, node: ast.For) -> ast.AST | None: self.device_function ) if persistent_body is not None: + # pyrefly: ignore [bad-assignment] self.device_function.body = persistent_body self.device_function.dead_code_elimination() if not self.device_function.preamble and not self.device_function.body: @@ -376,6 +381,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: isinstance(x, TileIndexType) for x in type_info.unpack() ): values = type_info.unpack() + # pyrefly: ignore [missing-attribute] block_infos = 
[env.block_sizes[x.block_id] for x in values] return expr_from_string( self.host_function.literal_expr( @@ -411,6 +417,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST: proxy_params = api._signature.bind(*proxy_args, **proxy_kwargs) ast_params.apply_defaults() proxy_params.apply_defaults() + # pyrefly: ignore [bad-return] return codegen_fn( CodegenState( self, diff --git a/helion/_compiler/helper_function.py b/helion/_compiler/helper_function.py index 8193f60c3..eb80acd31 100644 --- a/helion/_compiler/helper_function.py +++ b/helion/_compiler/helper_function.py @@ -55,6 +55,7 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType: """ from ..runtime.kernel import Kernel + # pyrefly: ignore [bad-return] return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py index 606b6a4f7..e1cf5ff00 100644 --- a/helion/_compiler/host_function.py +++ b/helion/_compiler/host_function.py @@ -79,6 +79,7 @@ def __init__( ) -> None: super().__init__() env = CompileEnvironment.current() + # pyrefly: ignore [read-only] self.fn = fn self.constexpr_args = constexpr_args self.location: SourceLocation = UnknownLocation() @@ -139,8 +140,10 @@ def get_decorator_name(decorator: ast.expr) -> str: def global_scope_origin(self, name: str) -> AttributeOrigin: if SOURCE_MODULE not in self.global_imports: + # pyrefly: ignore [missing-attribute] module_name = self.fn.__globals__["__name__"] module = sys.modules[module_name] + # pyrefly: ignore [missing-attribute] assert module.__dict__ is self.fn.__globals__ self.global_imports[SOURCE_MODULE] = GlobalImport( value=module, @@ -152,6 +155,7 @@ def global_scope_origin(self, name: str) -> AttributeOrigin: def import_from_module( self, module_scope: dict[str, object], name: str ) -> AttributeOrigin: + # pyrefly: ignore [missing-attribute] if module_scope is self.fn.__globals__: return self.global_scope_origin(name) module_name = module_scope["__name__"] diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 983106081..1586a2cd6 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -165,6 +165,7 @@ def codegen_load( f"tl.load({name} + {{offset}}, {{mask}}{extra})", offset=indexing.index_expr, mask=indexing.mask_expr, + # pyrefly: ignore [bad-argument-type] ev=eviction_policy, ) @@ -211,6 +212,7 @@ def codegen_load( expr_from_string( f"tl.load({{block_ptr}}, boundary_check={indexing.boundary_check(state)}, padding_option='zero'{extra})", block_ptr=indexing.make_block_ptr(state), + # pyrefly: ignore [bad-argument-type] ev=eviction_policy, ), ) @@ -515,6 +517,7 @@ def codegen_load( base=dev_ptrs_ast, offset=indexing.index_expr, mask=mask_expr, + # pyrefly: ignore [bad-argument-type] ev=eviction_policy, ) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index e9575c01d..615cb353b 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -170,9 +170,12 @@ def convert_arg(arg: Node) -> TensorBox: with node.meta["location"], graph_lowering.set_current_node(node): try: result = graph_lowering.call_function( + # pyrefly: ignore [bad-argument-type] node.target, + # pyrefly: ignore [bad-argument-type] *map_arg((node.args, node.kwargs), convert_arg), ) + # pyrefly: ignore [implicit-import] except torch._inductor.exc.LoweringException as e: # Wrap in Helion exception to get location automatically raise 
InductorLoweringError(str(e)) from e @@ -318,6 +321,7 @@ def create_extra_node( ) with proxy_tensor.disable_proxy_modes_tracing(): node.meta["val"] = torch.empty( + # pyrefly: ignore [no-matching-overload] [*map(to_symint, buffer.get_size())], dtype=buffer.get_dtype(), device=buffer.get_device(), @@ -739,7 +743,9 @@ def codegen(self, ctx: LoweringContext, node: torch.fx.Node) -> object: CodegenState( ctx.cg, fx_node=node, + # pyrefly: ignore [bad-argument-type] proxy_args=proxy_args, + # pyrefly: ignore [bad-argument-type] ast_args=ast_args, ), ) @@ -1009,7 +1015,11 @@ def _collect_multi_outputs( assert "output_nodes" in node.meta output_nodes = node.meta["output_nodes"] outputs: list[object | None] = [None] * len(output_nodes) - all_nodes = {n.name: n for n in self.module.graph.nodes} + all_nodes = { + n.name: n + # pyrefly: ignore [missing-attribute] + for n in self.module.graph.nodes + } for idx, node_name in output_nodes.items(): if node_name == node.name: diff --git a/helion/_compiler/inductor_lowering_extra.py b/helion/_compiler/inductor_lowering_extra.py index 23d11bf50..db3adf0c5 100644 --- a/helion/_compiler/inductor_lowering_extra.py +++ b/helion/_compiler/inductor_lowering_extra.py @@ -62,30 +62,37 @@ def patch_inductor_lowerings() -> Generator[None, Any, Any]: affecting the global state, especially in cases where Helion is missing support for a specific lowering. """ + # pyrefly: ignore [implicit-import] original_lowerings = torch._inductor.lowering.lowerings.copy() try: + # pyrefly: ignore [implicit-import] torch._inductor.lowering.lowerings.update(inductor_lowering_dispatch) yield finally: + # pyrefly: ignore [implicit-import] torch._inductor.lowering.lowerings = original_lowerings +# pyrefly: ignore [implicit-import] register_inductor_lowering = torch._inductor.lowering.register_lowering def var_mean_helper_( + # pyrefly: ignore [implicit-import] x: torch._inductor.ir.TensorBox, *, axis: list[int] | None, correction: float | None, keepdim: bool, return_mean: bool, + # pyrefly: ignore [implicit-import] ) -> torch._inductor.ir.TensorBox: from torch._inductor.lowering import var_mean_sum_ from torch._prims_common import get_computation_dtype out_dtype = x.get_dtype() compute_dtype = get_computation_dtype(out_dtype) + # pyrefly: ignore [bad-assignment] x = to_dtype(x, compute_dtype, copy=False) kwargs = { @@ -98,6 +105,7 @@ def var_mean_helper_( # TODO(yf225): support Welford reduction in Helion, then switch back to use Inductor `var_mean_helper_()`. 
output = var_mean_sum_(**kwargs) output = tuple(to_dtype(o, out_dtype, copy=False) for o in output) + # pyrefly: ignore [bad-return] return output[0] if not return_mean else output @@ -106,11 +114,13 @@ def var_mean_helper_( lowering_dict=inductor_lowering_dispatch, ) def var_( + # pyrefly: ignore [implicit-import] x: torch._inductor.ir.TensorBox, axis: list[int] | None = None, *, correction: float | None = None, keepdim: bool = False, + # pyrefly: ignore [implicit-import] ) -> torch._inductor.ir.TensorBox: return var_mean_helper_( x, @@ -126,11 +136,13 @@ def var_( lowering_dict=inductor_lowering_dispatch, ) def var_mean( + # pyrefly: ignore [implicit-import] x: torch._inductor.ir.TensorBox, axis: list[int] | None = None, *, correction: float | None = None, keepdim: bool = False, + # pyrefly: ignore [implicit-import] ) -> torch._inductor.ir.TensorBox: return var_mean_helper_( x, diff --git a/helion/_compiler/matmul_utils.py b/helion/_compiler/matmul_utils.py index ad694a9e3..52869c2e5 100644 --- a/helion/_compiler/matmul_utils.py +++ b/helion/_compiler/matmul_utils.py @@ -150,6 +150,7 @@ def _pad_tensor( ) cur_size *= 2 shape = [cur_size, other_dim] if pad_dim == 0 else [other_dim, cur_size] + # pyrefly: ignore [bad-argument-type] x = expr_from_string(f"tl.reshape({{x}}, {shape_str(shape)})", x=x) return x @@ -289,6 +290,7 @@ def emit_tl_dot_with_padding( ] for dim, axis, min_dim, other in acc_pad_specs: if pad_needed[dim] and (cur := dims[dim]): + # pyrefly: ignore [unbound-name] acc_pad = _pad_tensor(acc_pad, axis, cur, min_dim, other) result = _emit_tl_dot( @@ -319,6 +321,7 @@ def emit_tl_dot_with_padding( if pad_needed[dim] and (cur := dims[dim]): assert dim in ("m", "n"), f"dim must be 'm' or 'n', got {dim}" cur_size = min_dim + # pyrefly: ignore [unbound-name] while cur_size > cur: cur_size //= 2 shape = shape_fn(cur_size, other) diff --git a/helion/_compiler/program_id.py b/helion/_compiler/program_id.py index 12ec1c944..0cb32fd6f 100644 --- a/helion/_compiler/program_id.py +++ b/helion/_compiler/program_id.py @@ -139,6 +139,7 @@ class ForEachProgramID(ProgramIDs): Represent multiple top level for loops in the Helion kernel. Turns into `if` statements in generated code. 
""" + # pyrefly: ignore [bad-override] shared_pid_var: str cases: list[ProgramIDs] = dataclasses.field(default_factory=list) pid_info: list[PIDInfo] = dataclasses.field(default_factory=list, init=False) diff --git a/helion/_compiler/source_location.py b/helion/_compiler/source_location.py index 249102597..6782aeaf7 100644 --- a/helion/_compiler/source_location.py +++ b/helion/_compiler/source_location.py @@ -70,19 +70,28 @@ def from_ast(node: ast.AST) -> SourceLocation: code = host_function.fn.__code__ offset = code.co_firstlineno - 1 return SourceLocation( + # pyrefly: ignore [missing-attribute] node.lineno + offset, + # pyrefly: ignore [missing-attribute] node.col_offset + host_function.column_offset, + # pyrefly: ignore [missing-attribute] node.end_lineno + offset, + # pyrefly: ignore [missing-attribute] node.end_col_offset + host_function.column_offset, filename=code.co_filename, name=code.co_name, ) def to_ast(self, node: _T) -> _T: + # pyrefly: ignore [missing-attribute] if "lineno" in node._attributes: + # pyrefly: ignore [missing-attribute] node.lineno = self.lineno + # pyrefly: ignore [missing-attribute] node.col_offset = self.colno + # pyrefly: ignore [missing-attribute] node.end_lineno = self.end_lineno + # pyrefly: ignore [missing-attribute] node.end_col_offset = self.end_colno return node diff --git a/helion/_compiler/static_loop_unroller.py b/helion/_compiler/static_loop_unroller.py index e4049f1ed..2b42de7de 100644 --- a/helion/_compiler/static_loop_unroller.py +++ b/helion/_compiler/static_loop_unroller.py @@ -25,6 +25,7 @@ class StaticLoopUnroller(ast.NodeTransformer): def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]: # Generic visit to handle nested loops + # pyrefly: ignore [bad-assignment] node = self.generic_visit(node) # Check if this is a static loop that can be unrolled diff --git a/helion/_compiler/tensor_utils.py b/helion/_compiler/tensor_utils.py index d736b2366..da088b59f 100644 --- a/helion/_compiler/tensor_utils.py +++ b/helion/_compiler/tensor_utils.py @@ -23,6 +23,7 @@ class _PadTensorFactoryMode(TorchDispatchMode): torch.ops.aten.new_ones.default: 1, } + # pyrefly: ignore [bad-override] def __torch_dispatch__( self, func: Callable[..., torch.Tensor], diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 84a19bdad..a87a3a28a 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -357,6 +357,7 @@ def _fold_tile_end_op( class FlattenedTileStrategy(BlockSizeTileStrategy): """Collapse all dimensions into single flat iteration space.""" + # pyrefly: ignore [bad-override] block_size: SymIntLike def __init__( @@ -410,6 +411,7 @@ def _codegen_common( total_numel = sympy.S.One statements = [] + # pyrefly: ignore [bad-assignment] for i, block_idx in enumerate(self._reorder(block_ids)): numel = env.block_sizes[block_idx].numel block_index_var = self.index_var(block_idx) @@ -428,6 +430,7 @@ def _codegen_common( f"{mask_var} = {offsets_var} < ({state.sympy_expr(total_numel)})" ) ) + # pyrefly: ignore [bad-return] return block_size_var, offsets_var, total_numel, statements def codegen_grid(self, state: CodegenState) -> DeviceGridState: @@ -562,6 +565,7 @@ def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]: class _BaseNDTileStrategy(BlockSizeTileStrategy): + # pyrefly: ignore [bad-override] block_size: list[SymIntLike] def __init__( @@ -631,6 +635,7 @@ def codegen_grid(self, state: CodegenState) -> DeviceGridState: state.add_statement( f"{index_var} = {offset_var} 
+ tl.zeros([1], {dtype})" ) + # pyrefly: ignore [missing-attribute] mask_statement = self._setup_mask( state, block_idx, block_size, index_var, numel ) @@ -734,11 +739,13 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})" ), ] + # pyrefly: ignore [missing-attribute] mask_statement = self._setup_mask( state, block_idx, block_size, index_var, end ) if mask_statement is not None: extra_body.append(mask_statement) + # pyrefly: ignore [unsupported-operation] body[:] = [*extra_body, *body] body = [for_node] assert for_node is not None diff --git a/helion/_compiler/traceback_compat.py b/helion/_compiler/traceback_compat.py index d4d74bea9..ed6a5b2c4 100644 --- a/helion/_compiler/traceback_compat.py +++ b/helion/_compiler/traceback_compat.py @@ -32,6 +32,7 @@ def _ensure_original_line(fs: traceback.FrameSummary) -> None: # Same public behaviour as 3.11's property: # "return the line as-is from the source, without modifying whitespace". + # pyrefly: ignore [missing-attribute] fs._original_line = raw @@ -88,17 +89,26 @@ def normalize(off: int) -> int: statement = tree.body[0] if isinstance(statement, ast.Expr): - expr = statement.expr + expr = ( + # pyrefly: ignore [missing-attribute] + statement.expr + ) # # 1. Binary operator (a + b, a * b, ...) # if isinstance(expr, ast.BinOp): - operator_start = normalize(expr.left.end_col_offset) + operator_start = normalize( + # pyrefly: ignore [bad-argument-type] + expr.left.end_col_offset + ) operator_end = normalize(expr.right.col_offset) operator_str = segment[operator_start:operator_end] operator_offset = len(operator_str) - len(operator_str.lstrip()) - left_anchor = expr.left.end_col_offset + operator_offset + left_anchor = ( + # pyrefly: ignore [unsupported-operation] + expr.left.end_col_offset + operator_offset + ) right_anchor = left_anchor + 1 if ( operator_offset + 1 < len(operator_str) @@ -122,8 +132,14 @@ def normalize(off: int) -> int: # 2. 
Subscript (a[index]) # if isinstance(expr, ast.Subscript): - left_anchor = normalize(expr.value.end_col_offset) - right_anchor = normalize(expr.slice.end_col_offset + 1) + left_anchor = normalize( + # pyrefly: ignore [bad-argument-type] + expr.value.end_col_offset + ) + right_anchor = normalize( + # pyrefly: ignore [unsupported-operation] + expr.slice.end_col_offset + 1 + ) while left_anchor < len(segment) and ( (ch := segment[left_anchor]).isspace() or ch != "[" diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index bd2ae28b9..09a74bdff 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -90,6 +90,7 @@ def get(self, name: str) -> TypeInfo: def _get(self, name: str) -> TypeInfo: try: + # pyrefly: ignore [missing-attribute] value = self.function.fn.__globals__[name] except KeyError: if hasattr(builtins, name): @@ -135,6 +136,7 @@ def maybe_get(self, name: str) -> TypeInfo | None: except exc.UndefinedVariable: return None + # pyrefly: ignore [bad-override] def set(self, name: str, type_info: TypeInfo) -> None: self.variables[name] = type_info @@ -245,6 +247,7 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo: origin, dict( zip( + # pyrefly: ignore [missing-attribute] value._fields, cls._unpack_example( value._asdict().items(), @@ -298,6 +301,7 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo: ) attrs[compute_unit_literal] = SymIntType(attr_origin, sym) + # pyrefly: ignore [bad-argument-type] return ClassType(origin, attrs) raise exc.UnsupportedPythonType(type(value).__name__) @@ -543,6 +547,7 @@ def propagate_setitem( def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: if origin.is_host(): try: + # pyrefly: ignore [bad-index] return TypeInfo.from_example(self.fake_value[key.proxy()], origin) except NotImplementedError: raise exc.TypeInferenceError( @@ -604,6 +609,7 @@ def populate_symbol_origins(self, origin: Origin) -> None: class TensorAttributeType(TypeInfo): + # pyrefly: ignore [bad-override] origin: AttributeOrigin tensor: TensorType @@ -701,6 +707,7 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: try: + # pyrefly: ignore [bad-index] return TypeInfo.from_example(self.value[key.as_literal()], origin) except NotImplementedError: pass @@ -720,11 +727,13 @@ def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: float, bool, ): + # pyrefly: ignore [bad-argument-type] return NumericType.subtype(self.python_type).new_unbacked(self.origin) return super().merge(other, var_name=var_name) def unpack(self) -> list[TypeInfo]: try: + # pyrefly: ignore [no-matching-overload] it = iter(self.value) except TypeError: return super().unpack() @@ -744,6 +753,7 @@ def __str__(self) -> str: class ConfigFragmentType(LiteralType): """TypeInfo for config fragments are treated as constant literals during compilation.""" + # pyrefly: ignore [bad-override] value: ConfigSpecFragment def __init__(self, origin: Origin, fragment: ConfigSpecFragment) -> None: @@ -752,6 +762,7 @@ def __init__(self, origin: Origin, fragment: ConfigSpecFragment) -> None: class CallableType(LiteralType): + # pyrefly: ignore [bad-override] value: Callable[..., object] def __init__(self, origin: Origin, value: Callable[..., object]) -> None: @@ -771,6 +782,7 @@ def name(self) -> str: except AttributeError: return str(self.value) + # pyrefly: ignore [bad-override] def 
propagate_call( self, args: tuple[TypeInfo, ...], kwargs: dict[str, TypeInfo], origin: Origin ) -> TypeInfo | None: @@ -877,6 +889,7 @@ def _raise_shape_specializing(*args: object) -> None: class PythonModuleType(LiteralType): + # pyrefly: ignore [bad-override] value: types.ModuleType def __init__(self, origin: Origin, value: types.ModuleType) -> None: @@ -975,6 +988,7 @@ def populate_symbol_origins(self, origin: Origin) -> None: class SymIntType(NumericType): + # pyrefly: ignore [bad-override] value: torch.SymInt @classmethod @@ -995,6 +1009,7 @@ def proxy(self) -> torch.SymInt | int: class SymFloatType(NumericType): + # pyrefly: ignore [bad-override] value: torch.SymFloat @classmethod @@ -1012,6 +1027,7 @@ def python_type(self) -> type[float]: class SymBoolType(NumericType): + # pyrefly: ignore [bad-override] value: torch.SymBool @classmethod @@ -1215,8 +1231,10 @@ def propagate_setitem( k := key.value, (int, str) ): if k in elements: + # pyrefly: ignore [bad-index, unsupported-operation] elements[k] = elements[k].merge(value) else: + # pyrefly: ignore [unsupported-operation] elements[k] = value return self return super().propagate_setitem(key, value, origin) @@ -1228,12 +1246,14 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: pass else: try: + # pyrefly: ignore [bad-index] result = self.element_types[literal_key] except (KeyError, IndexError) as e: raise exc.TypeInferenceError(f"{type(e).__name__}: {e}") from None if isinstance(result, TypeInfo): return result if type(result) is self.python_type: # sliced! + # pyrefly: ignore [bad-argument-type] return type(self)(origin=origin, element_types=result) return super().propagate_getitem(key, origin) @@ -1245,6 +1265,7 @@ def tree_map(self, fn: Callable[[TypeInfo], object]) -> object: class SequenceType(CollectionType): + # pyrefly: ignore [bad-override] element_types: list[TypeInfo] | tuple[TypeInfo, ...] 
def __str__(self) -> str: @@ -1309,6 +1330,7 @@ def tree_map( class DictType(CollectionType): + # pyrefly: ignore [bad-override] element_types: dict[str | int, TypeInfo] def __str__(self) -> str: @@ -1362,8 +1384,10 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo: class StackTensorType(ClassType): + # pyrefly: ignore [bad-override] element_types: dict[str, TypeInfo] + # pyrefly: ignore [bad-override] def proxy(self) -> StackTensor: with proxy_tensor.disable_proxy_modes_tracing(): fake_mode = torch._C._unset_dispatch_mode( @@ -1444,6 +1468,7 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: class SliceType(CollectionType): + # pyrefly: ignore [bad-override] element_types: slice @property @@ -1496,40 +1521,56 @@ def _eval_unary(op: ast.unaryop, value: object) -> object: if isinstance(op, ast.Not): return not value if isinstance(op, ast.UAdd): + # pyrefly: ignore [unsupported-operation] return +value if isinstance(op, ast.USub): + # pyrefly: ignore [unsupported-operation] return -value if isinstance(op, ast.Invert): + # pyrefly: ignore [unsupported-operation] return ~value raise AssertionError(f"{type(op).__name__} unknown unary op") def _eval_binary(op: ast.operator, left: object, right: object) -> object: if isinstance(op, ast.Add): + # pyrefly: ignore [unsupported-operation] return left + right if isinstance(op, ast.Sub): + # pyrefly: ignore [unsupported-operation] return left - right if isinstance(op, ast.Mult): + # pyrefly: ignore [unsupported-operation] return left * right if isinstance(op, ast.Div): + # pyrefly: ignore [unsupported-operation] return left / right if isinstance(op, ast.FloorDiv): + # pyrefly: ignore [unsupported-operation] return left // right if isinstance(op, ast.Mod): + # pyrefly: ignore [unsupported-operation] return left % right if isinstance(op, ast.Pow): + # pyrefly: ignore [unsupported-operation] return left**right if isinstance(op, ast.LShift): + # pyrefly: ignore [unsupported-operation] return left << right if isinstance(op, ast.RShift): + # pyrefly: ignore [unsupported-operation] return left >> right if isinstance(op, ast.BitOr): + # pyrefly: ignore [unsupported-operation] return left | right if isinstance(op, ast.BitXor): + # pyrefly: ignore [unsupported-operation] return left ^ right if isinstance(op, ast.BitAnd): + # pyrefly: ignore [unsupported-operation] return left & right if isinstance(op, ast.MatMult): + # pyrefly: ignore [unsupported-operation] return left @ right raise AssertionError(f"{type(op).__name__} unknown binary op") @@ -1540,20 +1581,26 @@ def _eval_compare(op: ast.cmpop, left: object, right: object) -> object: if isinstance(op, ast.NotEq): return left != right if isinstance(op, ast.Lt): + # pyrefly: ignore [unsupported-operation] return left < right if isinstance(op, ast.LtE): + # pyrefly: ignore [unsupported-operation] return left <= right if isinstance(op, ast.Gt): + # pyrefly: ignore [unsupported-operation] return left > right if isinstance(op, ast.GtE): + # pyrefly: ignore [unsupported-operation] return left >= right if isinstance(op, ast.Is): return left is right if isinstance(op, ast.IsNot): return left is not right if isinstance(op, ast.In): + # pyrefly: ignore [not-iterable] return left in right if isinstance(op, ast.NotIn): + # pyrefly: ignore [not-iterable] return left not in right raise AssertionError(f"{type(op).__name__} unknown compare op") @@ -1698,6 +1745,7 @@ def _bool_op(self, op: ast.boolop, left: TypeInfo, right: TypeInfo) -> TypeInfo: and left.python_type == 
right.python_type and (pt := left.python_type) in (int, float, bool) ): + # pyrefly: ignore [bad-argument-type] return NumericType.subtype(pt).new_unbacked(self.origin()) raise exc.TypeInferenceError( f"{type(op).__name__} not supported on {left!s} and {right!s}" @@ -1801,6 +1849,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None: ) from None return self._assign(lhs.value, unpacked) if isinstance(lhs, (ast.Tuple, ast.List)): + # pyrefly: ignore [bad-assignment] lhs = lhs.elts elements: list[TypeInfo] try: @@ -1808,14 +1857,17 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None: except NotImplementedError: if isinstance(rhs, TileIndexType): raise exc.FailedToUnpackTile from None + # pyrefly: ignore [bad-argument-type] raise exc.FailedToUnpackTupleAssign(len(lhs), rhs) from None used_star = False idx = 0 + # pyrefly: ignore [not-iterable] for elt in lhs: if isinstance(elt, ast.Starred): # TODO(jansel): need to test this assert not used_star, "multiple `*` in assignment" used_star = True + # pyrefly: ignore [bad-argument-type] star_len = len(elements) - len(lhs) + 1 assert star_len >= 0, "wrong number of elements to unpack" self._assign( @@ -1898,8 +1950,11 @@ def _list_or_tuple(self, node: ast.List | ast.Tuple) -> TypeInfo: cls(elements), ) + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_List: _VisitMethod = _list_or_tuple + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Tuple: _VisitMethod = _list_or_tuple + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Set: _VisitMethod = _unsupported(set) def visit_Dict(self, node: ast.Dict) -> TypeInfo: @@ -1930,6 +1985,7 @@ def visit_Name(self, node: ast.Name) -> TypeInfo: raise exc.CannotReadDeviceVariableOnHost(node.id) return result + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Starred: _VisitMethod = generic_visit def visit_Expr(self, node: ast.Expr) -> TypeInfo: @@ -2059,6 +2115,7 @@ def visit_Call(self, node: ast.Call) -> TypeInfo: "Failed to unpack */** args to function, got: " + ", ".join(map(str, unhandled)) ) + # pyrefly: ignore [bad-argument-type, bad-return] return func.propagate_call(tuple(args), kwargs, self.origin()) def visit_IfExp(self, node: ast.IfExp) -> TypeInfo: @@ -2167,14 +2224,19 @@ def visit_Assert(self, node: ast.Assert) -> TypeInfo: self.visit(node.msg) return NoType(origin=self.origin()) + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Raise: _VisitMethod = generic_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Delete: _VisitMethod = generic_statement def visit_Pass(self, node: ast.Pass) -> TypeInfo: return NoType(origin=self.origin()) + # pyrefly: ignore [bad-assignment] visit_TypeAlias: _VisitMethod = generic_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Import: _VisitMethod = generic_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_ImportFrom: _VisitMethod = generic_statement def visit_Global(self, node: ast.Global) -> TypeInfo: @@ -2182,6 +2244,7 @@ def visit_Global(self, node: ast.Global) -> TypeInfo: return NoType(origin=self.origin()) # TODO(jansel): support lambda + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Lambda: _VisitMethod = generic_visit ################################################################ @@ -2275,7 +2338,9 @@ def visit_While(self, node: ast.While) -> TypeInfo: self.scope.merge_if_else(body, orelse) return NoType(origin=self.origin()) + # pyrefly: ignore [bad-assignment, 
bad-param-name-override] visit_Break: _VisitMethod = generic_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Continue: _VisitMethod = generic_statement def visit_Try(self, node: ast.Try) -> TypeInfo: @@ -2288,6 +2353,7 @@ def visit_Try(self, node: ast.Try) -> TypeInfo: self.scope.overwrite(self._body(node.finalbody)) return NoType(origin=self.origin()) + # pyrefly: ignore [bad-assignment] visit_TryStar: _VisitMethod = visit_Try def _not_on_device_statement(self, node: ast.AST) -> TypeInfo: @@ -2297,8 +2363,11 @@ def _not_on_device_statement(self, node: ast.AST) -> TypeInfo: self.visit(child_node) return NoType(origin=self.origin()) + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_ExceptHandler: _VisitMethod = _not_on_device_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_With: _VisitMethod = _not_on_device_statement + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Return: _VisitMethod = _not_on_device_statement def _not_supported(self, node: ast.AST) -> TypeInfo: @@ -2361,28 +2430,48 @@ def visit_ListComp(self, node: ast.ListComp) -> TypeInfo: return self._evaluate_comprehension(node.generators[0], node.elt) # TODO(jansel): need to implement these + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_SetComp: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_GeneratorExp: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_DictComp: _VisitMethod = _not_supported # TODO(jansel): support closure functions defined on host + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_FunctionDef: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_ClassDef: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Yield: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_YieldFrom: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_AsyncFunctionDef: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_AsyncFor: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_AsyncWith: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Await: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_Match: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchValue: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchSingleton: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchSequence: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchStar: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchMapping: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchClass: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchAs: _VisitMethod = _not_supported + # pyrefly: ignore [bad-assignment, bad-param-name-override] visit_MatchOr: _VisitMethod = _not_supported diff --git a/helion/_logging/_internal.py 
b/helion/_logging/_internal.py index d5f17937c..23354cdb7 100644 --- a/helion/_logging/_internal.py +++ b/helion/_logging/_internal.py @@ -86,6 +86,7 @@ class LazyString: def __init__( self, func: Callable[P, str], *args: P.args, **kwargs: P.kwargs ) -> None: + # pyrefly: ignore [invalid-type-var] self.func: Callable[P, str] = func self.args: tuple[object, ...] = args self.kwargs: object = kwargs diff --git a/helion/_testing.py b/helion/_testing.py index 41d2f023e..e595c1b84 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -39,6 +39,7 @@ def _get_triton_backend() -> str | None: try: + # pyrefly: ignore [missing-attribute] return triton.runtime.driver.active.get_current_target().backend except Exception: return None @@ -213,6 +214,7 @@ def tracked_run_ref(self: BoundKernel, *args: object) -> object: run_ref_count[0] += 1 return original_run_ref(self, *args) + # pyrefly: ignore [bad-assignment] BoundKernel.run_ref = tracked_run_ref try: @@ -302,6 +304,7 @@ def setUp(self) -> None: # Patch torch.testing.assert_close to count calls if RefEagerTestBase._original_assert_close_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_assert_close_func = torch.testing.assert_close def counting_assert_close(*args: object, **kwargs: object) -> None: @@ -312,6 +315,7 @@ def counting_assert_close(*args: object, **kwargs: object) -> None: # Patch self.assertRaises to count calls if RefEagerTestBase._original_assert_raises_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_assert_raises_func = self.assertRaises def counting_assert_raises(*args: object, **kwargs: object) -> object: @@ -322,6 +326,7 @@ def counting_assert_raises(*args: object, **kwargs: object) -> object: # Patch self.skipTest to count calls if RefEagerTestBase._original_skip_test_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_skip_test_func = self.skipTest def counting_skip_test(*args: object, **kwargs: object) -> object: @@ -336,6 +341,7 @@ def counting_skip_test(*args: object, **kwargs: object) -> object: # Patch pytest.raises to count calls if RefEagerTestBase._original_pytest_raises is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_pytest_raises = pytest.raises def counting_pytest_raises(*args: object, **kwargs: object) -> object: @@ -348,6 +354,7 @@ def counting_pytest_raises(*args: object, **kwargs: object) -> object: # Patch self.assertTrue to count calls if RefEagerTestBase._original_assert_true_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_assert_true_func = self.assertTrue def counting_assert_true(*args: object, **kwargs: object) -> None: @@ -358,6 +365,7 @@ def counting_assert_true(*args: object, **kwargs: object) -> None: # Patch self.assertFalse to count calls if RefEagerTestBase._original_assert_false_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_assert_false_func = self.assertFalse def counting_assert_false(*args: object, **kwargs: object) -> None: @@ -368,6 +376,7 @@ def counting_assert_false(*args: object, **kwargs: object) -> None: # Patch self.assertGreater to count calls if RefEagerTestBase._original_assert_greater_func is None: + # pyrefly: ignore [bad-assignment] RefEagerTestBase._original_assert_greater_func = self.assertGreater def counting_assert_greater(*args: object, **kwargs: object) -> None: @@ -398,6 +407,7 @@ def tearDown(self) -> None: # Assert that either run_ref was called or the test was skipped if not is_skipped and 
self._run_ref_count[0] == 0: self.fail( # type: ignore[attr-defined] + # pyrefly: ignore [missing-attribute] f"Test {self._testMethodName} did not call run_ref and was not skipped" ) @@ -506,8 +516,10 @@ def assertIsInstance( def import_path(filename: Path) -> types.ModuleType: module_name = f"{__name__}.{filename.stem}" if module_name not in sys.modules: + # pyrefly: ignore [implicit-import] spec = importlib.util.spec_from_file_location(module_name, filename) assert spec is not None + # pyrefly: ignore [implicit-import] module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) @@ -523,6 +535,7 @@ def code_and_output( bound = fn.bind(args) if is_ref_mode_enabled(bound.kernel.settings): if kwargs: + # pyrefly: ignore [bad-argument-type] config = Config(**kwargs) bound._config = config result = fn(*args) @@ -531,7 +544,10 @@ def code_and_output( return code, result if kwargs: - config = Config(**kwargs) + config = Config( + # pyrefly: ignore [bad-argument-type] + **kwargs + ) elif fn.configs: (config,) = fn.configs else: @@ -673,6 +689,7 @@ def run_example( all_benchmarks = {**kernels, **baselines} bench_fns = [functools.partial(fn, *args) for fn in all_benchmarks.values()] repeat = compute_repeat(bench_fns[0]) + # pyrefly: ignore [bad-argument-type] timings = interleaved_bench(bench_fns, repeat=repeat, desc="Benchmarking") all_times = dict(zip(all_benchmarks.keys(), timings, strict=True)) best_baseline_time = min(all_times[name] for name in baselines) diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index f83977129..7c2a38b6f 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -851,8 +851,10 @@ def rebenchmark( repeat = min(1000, max(3, base_repeat)) iterator = [functools.partial(m.fn, *self.args) for m in members] if self.settings.autotune_progress_bar: + # pyrefly: ignore [bad-argument-type] new_timings = interleaved_bench(iterator, repeat=repeat, desc=desc) else: + # pyrefly: ignore [bad-argument-type] new_timings = interleaved_bench(iterator, repeat=repeat) for m, t in zip(members, new_timings, strict=True): m.perfs.append(t) @@ -1091,6 +1093,7 @@ def _wait_for_all_step( # Wait for at least one to finish or time out timeout = min([f.seconds_left() for f in running], default=0.0) + # pyrefly: ignore [missing-attribute] handles = [f.process.sentinel for f in running] if handles and timeout > 0: connection.wait(handles, timeout) @@ -1287,6 +1290,7 @@ def _consume_result(self, *, raise_on_raise: bool) -> None: self.search.kernel.maybe_log_repro( self.search.log.warning, self.search.args, self.config ) + # pyrefly: ignore [unbound-name] elif not ignore_errors: self.search.log.debug(formatted) self.search.kernel.maybe_log_repro( diff --git a/helion/autotuner/block_id_sequence.py b/helion/autotuner/block_id_sequence.py index 0663eeb86..da5b8c04a 100644 --- a/helion/autotuner/block_id_sequence.py +++ b/helion/autotuner/block_id_sequence.py @@ -72,13 +72,16 @@ def _reindex(self) -> None: new_index[block_id] = i self._block_id_to_index = new_index + # pyrefly: ignore [bad-override] def __getitem__(self, index: int) -> _BlockIdItemT: return self._data[index] + # pyrefly: ignore [bad-override] def __setitem__(self, index: int, value: _BlockIdItemT) -> None: self._data[index] = value self._reindex() # could be faster, but uncommon case + # pyrefly: ignore [bad-override] def __delitem__(self, index: int) -> None: del self._data[index] self._reindex() # could be faster, but uncommon case 
@@ -120,7 +123,11 @@ def disable_block_id(self, block_id: int) -> None: self._reindex() def config_get( - self, config: list[_T], block_id: int, default: _D = None + self, + config: list[_T], + block_id: int, + # pyrefly: ignore [bad-function-definition] + default: _D = None, ) -> _T | _D: """ Get the config value for the given block_id, or return default if not found. @@ -178,6 +185,7 @@ def _normalize( if len(values) < size: try: for spec in self._data[len(values) :]: + # pyrefly: ignore [bad-argument-type] values.append(spec._fill_missing()) except NotImplementedError: raise InvalidConfig( @@ -186,7 +194,9 @@ def _normalize( f"Did you forget to specify block sizes for all your hl.tile() dimensions?" ) from None for i, spec in enumerate(self._data): + # pyrefly: ignore [unsupported-operation] values[i] = spec._normalize(f"config[{name}][{i}]", values[i]) + # pyrefly: ignore [bad-return] return values def _remove_duplicates(self) -> None: diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 4e2afbb86..624e74eea 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -107,7 +107,9 @@ class ConfigSpec: ) indexing: ListOf = dataclasses.field( default_factory=lambda: ListOf( - EnumFragment(choices=ConfigSpec._valid_indexing_types()), length=0 + # pyrefly: ignore [unbound-name] + EnumFragment(choices=ConfigSpec._valid_indexing_types()), + length=0, ) ) @@ -320,6 +322,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf if not config.get(name): config.pop(name, None) self.normalize(config) + # pyrefly: ignore [bad-argument-type] return helion.Config(**config) diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index f3fc2ccf6..c709c45e4 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -198,6 +198,7 @@ def wrapper(*args: object, **kwargs: object) -> object: cast("Callable[..., object]", fn) ) api._ref_fn = None + # pyrefly: ignore [bad-return] return wrapper return _impl @@ -218,6 +219,7 @@ def _impl(fake_fn: Callable[..., object]) -> Callable[..., Never]: ) return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -231,6 +233,7 @@ def _impl(type_fn: Callable[..., TypeInfo]) -> Callable[..., Never]: original_fn._type_function = type_fn return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -249,6 +252,7 @@ def _impl( original_fn._prepare_args = prep_fn return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -266,6 +270,7 @@ def _impl(codegen_fn: Callable[[CodegenState], object]) -> Callable[..., Never]: original_fn._codegen[backend] = codegen_fn return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -284,6 +289,7 @@ def _impl( original_fn._get_masked_value = mask_value_fn return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -298,6 +304,7 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: original_fn._to_device_ir = to_device_ir_fn return _no_call + # pyrefly: ignore [bad-return] return _impl @@ -314,6 +321,7 @@ def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]: original_fn._ref_fn = ref_fn return _no_call + # pyrefly: ignore [bad-return] return _impl diff --git a/helion/language/_tracing_ops.py b/helion/language/_tracing_ops.py index af9e978a7..a4ddae9a8 100644 --- a/helion/language/_tracing_ops.py +++ b/helion/language/_tracing_ops.py @@ -42,6 +42,7 @@ def _get_symnode(debug_name: str) -> int: @_decorators.codegen(_get_symnode, "triton") def _(state: 
CodegenState) -> ast.AST: + # pyrefly: ignore [missing-attribute] val = state.fx_node.meta["val"] # Handle the case where val is a regular integer (e.g., from reduction_loops config) @@ -49,6 +50,7 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(str(val)) assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val + # pyrefly: ignore [bad-argument-type] if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None: block_size_var = state.device_function.block_size_var(block_idx) if block_size_var is None: @@ -85,6 +87,7 @@ def _for_loop( @_decorators.codegen(_for_loop, "triton") def _(state: CodegenState) -> None: + # pyrefly: ignore [bad-index] return HostFunction.current().device_ir.graphs[state.proxy_arg(0)].codegen(state) @@ -102,6 +105,7 @@ def _while_loop( @_decorators.codegen(_while_loop, "triton") def _(state: CodegenState) -> None: + # pyrefly: ignore [bad-index] return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) @@ -114,6 +118,7 @@ def _if(test: object, graph_id: int, args: list[object]) -> list[object]: @_decorators.codegen(_if, "triton") def _(state: CodegenState) -> None: + # pyrefly: ignore [bad-index] return HostFunction.current().device_ir.graphs[state.proxy_arg(1)].codegen(state) @@ -182,6 +187,7 @@ def _and(left: object, right: object) -> object: @_decorators.codegen(_and, "triton") def _(state: CodegenState) -> None: + # pyrefly: ignore [bad-return] return expr_from_string( "{lhs} and {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) ) @@ -235,6 +241,7 @@ def _(left: object, right: object) -> object: @_decorators.codegen(_or, "triton") def _(state: CodegenState) -> None: + # pyrefly: ignore [bad-return] return expr_from_string( "{lhs} or {rhs}", lhs=state.ast_arg(0), rhs=state.ast_arg(1) ) @@ -343,10 +350,13 @@ def _new_var(value: _T, /) -> _T: @_decorators.register_fake(_new_var) def _(value: _T) -> _T: if isinstance(value, torch.Tensor): + # pyrefly: ignore [bad-return] return torch.empty_like(value) if isinstance(value, torch.SymInt): + # pyrefly: ignore [bad-return] return CompileEnvironment.current().create_unbacked_symint() if isinstance(value, (int, float, bool)) or value is None: + # pyrefly: ignore [bad-return] return value raise NotImplementedError(f"Unsupported type for _new_var: {type(value)}") diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 8c73171f5..4528d9e05 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -112,8 +112,10 @@ def _convert_specializable( on_symint: Callable[[torch.SymInt], int] = lambda symint: symint.__int__(), ) -> _T: if isinstance(value, torch.SymInt): + # pyrefly: ignore [bad-return] return on_symint(value) if isinstance(value, int): + # pyrefly: ignore [bad-return] return value if isinstance(value, (torch.Size, tuple, list)): try: diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index 62075efd1..1da27f194 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -210,6 +210,7 @@ def arange( env = CompileEnvironment.current() if dtype is None: dtype = env.index_dtype + # pyrefly: ignore [no-matching-overload] return torch.arange( *args, **kwargs, diff --git a/helion/language/loops.py b/helion/language/loops.py index 97cff66c3..e8bda9815 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -329,6 +329,7 @@ def _( ) block_size_list = Tile._tiles_to_sizes(block_size_list) + # pyrefly: ignore [unbound-name] if unpack: target 
= getattr(parent, "target", None) if isinstance(target, (ast.Tuple, ast.List)) and len(target.elts) > 1: @@ -375,6 +376,7 @@ def _( ) ], ) + # pyrefly: ignore [unbound-name] if unpack: (result,) = results else: @@ -734,6 +736,7 @@ def _( size = None # data dependent size if step_part is None: step_part = 1 + # pyrefly: ignore [bad-argument-type] results.append(GridIndexType.allocate(size, origin, step_part)) _add_config_choices( @@ -746,6 +749,7 @@ def _( ) ], ) + # pyrefly: ignore [unbound-name] if unpack: (result,) = results else: diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index f45ec5947..eac0acde6 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -132,6 +132,7 @@ def _( for i, idx in enumerate(index): if isinstance(idx, RefTile): idx = idx.index + # pyrefly: ignore [bad-argument-type] indices.append(idx) if isinstance(idx, torch.Tensor): tensor_idx_positions.append(i) @@ -139,9 +140,12 @@ def _( # Handle broadcasting for multiple tensor indices if len(tensor_idx_positions) > 1: grids = torch.meshgrid( - *(indices[i] for i in tensor_idx_positions), indexing="ij" + # pyrefly: ignore [bad-argument-type] + *(indices[i] for i in tensor_idx_positions), + indexing="ij", ) for i, grid in zip(tensor_idx_positions, grids, strict=False): + # pyrefly: ignore [unsupported-operation] indices[i] = grid if extra_mask is not None: @@ -163,6 +167,7 @@ def _( else: idx_val = int(idx) if isinstance(idx, torch.SymInt) else idx valid_indices.append( + # pyrefly: ignore [no-matching-overload] torch.full( (mask_count,), idx_val, dtype=torch.long, device=tensor.device ) @@ -306,6 +311,7 @@ def _( from .ref_tile import RefTile if extra_mask is None: + # pyrefly: ignore [bad-argument-type] return tensor[tuple(index)] # Create zero result matching mask shape diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py index bd852886f..35d385003 100644 --- a/helion/language/random_ops.py +++ b/helion/language/random_ops.py @@ -118,6 +118,7 @@ def _rand_codegen(state: CodegenState) -> ast.AST: broadcast_slice = StackIndexingStrategy.get_element_broadcast_slice(i, ndim) broadcasted_index = f"{index_vars[i]}{broadcast_slice}" if i < ndim - 1: + # pyrefly: ignore [no-matching-overload] stride_expr = " * ".join(map("({})".format, size_names[i + 1 :])) offset_parts.append(f"{broadcasted_index} * {stride_expr}") else: diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index 70343eafc..8d359d847 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -161,6 +161,7 @@ def wrapped_combine_fn2( left_tuple: tuple[torch.Tensor, ...], right_tuple: tuple[torch.Tensor, ...], ) -> tuple[torch.Tensor, ...]: + # pyrefly: ignore [bad-return] return original_fn(*left_tuple, *right_tuple) combine_fn = wrapped_combine_fn2 @@ -196,6 +197,7 @@ def wrapped_combine_fn2( ] # Iterate over all combinations of non-reduced dimensions + # pyrefly: ignore [no-matching-overload] for idx in itertools.product(*index_iterators): # Gather values along reduction dimensions values_list = [] @@ -396,6 +398,7 @@ def _( # For single tensor or single other value, use mask_node_inputs from .._compiler.node_masking import mask_node_inputs + # pyrefly: ignore [bad-argument-type] mask_node_inputs(actual_node, other=other) # Create output tensors with reduced shape @@ -522,6 +525,7 @@ def _create_reduce_expression( return expr_from_string( template, input_tensor=input_tensor, + # pyrefly: ignore [bad-argument-type] 
dim_value=ast.Constant(value=dim), ) diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py index 5c6ed152f..787e9faf7 100644 --- a/helion/language/ref_tile.py +++ b/helion/language/ref_tile.py @@ -155,6 +155,7 @@ def _handle_getitem( assert isinstance(tensor, torch.Tensor) slice_index = convert_tile_indices_to_slices(index) + # pyrefly: ignore [bad-index] return tensor[slice_index] @classmethod @@ -170,6 +171,7 @@ def _handle_setitem( assert isinstance(value, (int, float, bool, torch.Tensor)) slice_index = convert_tile_indices_to_slices(index) + # pyrefly: ignore [bad-index] target_shape = tensor[slice_index].shape # Slice value tensor to match target shape if needed @@ -181,6 +183,7 @@ def _handle_setitem( slices = create_shape_matching_slices(value.shape, target_shape) value = value[slices] + # pyrefly: ignore [unsupported-operation] tensor[slice_index] = value return None @@ -191,7 +194,7 @@ def __index__(self) -> int: return self.block_size @property - def index(self) -> torch.Tensor: + def index(self) -> torch.Tensor: # pyrefly: ignore [bad-override] """Return tensor of indices for .index attribute access in ref mode.""" from .._compiler.compile_environment import CompileEnvironment diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 8d0d3f42c..a36344661 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -158,7 +158,9 @@ def _(state: CodegenState) -> ast.AST: else: raise NotImplementedError(f"Unsupported signal pad type: {type(signal_pad)}") + # pyrefly: ignore [bad-argument-type] signal_expr = ast.Constant(value=signal) + # pyrefly: ignore [bad-argument-type] update_expr = ast.Constant(value=update) is_scalar = len(shape) == 0 @@ -322,8 +324,10 @@ def _(state: CodegenState) -> ast.AST: is_scalar = len(shape) == 0 + # pyrefly: ignore [bad-argument-type] signal_expr = ast.Constant(value=signal) if wait_for is not None: + # pyrefly: ignore [bad-argument-type] wait_for_expr = ast.Constant(value=wait_for) else: wait_for_expr = ast.Constant(value=0) diff --git a/helion/language/stack_tensor.py b/helion/language/stack_tensor.py index ee23d6835..33d34c1a7 100644 --- a/helion/language/stack_tensor.py +++ b/helion/language/stack_tensor.py @@ -76,6 +76,7 @@ def device(self) -> torch.device: def shape(self) -> torch.Size: return self.dev_ptrs.shape + self.tensor_like.shape + # pyrefly: ignore [bad-override] def __getitem__( self, index: list[object] | torch.Tensor, @@ -92,6 +93,7 @@ def __setitem__( def new_empty( self, *args: Sequence[int | torch.SymInt], **kwargs: dict ) -> torch.Tensor: + # pyrefly: ignore [no-matching-overload] return self.tensor_like.new_empty(*args, **kwargs) # TODO(joydddd): Implement this to support StackTensor in ref mode. 
@@ -201,6 +203,7 @@ def _(tensor_like: TypeInfo, dev_ptrs: TypeInfo, *, origin: Origin) -> TypeInfo: "tensor_like": tensor_like, } + # pyrefly: ignore [bad-argument-type] return StackTensorType(origin, element_types) diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index dccf56bad..da4678f6d 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -90,6 +90,7 @@ def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: @_decorators.codegen(subscript, "triton") def _(state: CodegenState) -> ast.AST: output_keys = [] + # pyrefly: ignore [not-iterable] for val in state.proxy_arg(1): if val is None: output_keys.append("None") @@ -105,6 +106,7 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(subscript) def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: + # pyrefly: ignore [bad-index] return tensor[indices] diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index ae5c008f5..c81e6be8b 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -90,10 +90,13 @@ def _check(tensor: torch.Tensor) -> None: over_limit = True tree_map_only(torch.Tensor, _check, args) + # pyrefly: ignore [unbound-name] if index_dtype is None: # Auto-select when not provided return torch.int64 if over_limit else torch.int32 if over_limit: + # pyrefly: ignore [unbound-name] raise exc.InputTensorNumelExceedsIndexType(index_dtype=index_dtype) + # pyrefly: ignore [unbound-name] return index_dtype @@ -119,12 +122,15 @@ def __init__( assert isinstance(fn, types.FunctionType) assert_no_conflicts(fn) self.name: str = fn.__name__ + # pyrefly: ignore [read-only] self.fn: types.FunctionType = fn self.signature: inspect.Signature = inspect.signature(fn) self.settings: Settings = settings or Settings() self._key_fn: Callable[..., Hashable] | None = key self.configs: list[Config] = [ - Config(**c) if isinstance(c, dict) else c for c in configs or [] + # pyrefly: ignore [bad-argument-type] + Config(**c) if isinstance(c, dict) else c + for c in configs or [] ] self._bound_kernels: dict[BoundKernelInMemoryCacheKey, BoundKernel] = {} self._specialize_extra: dict[ @@ -392,8 +398,12 @@ def __init__( patch_inductor_lowerings(), ): try: + # pyrefly: ignore [bad-assignment] self.host_function: HostFunction = HostFunction( - self.kernel.fn, self.fake_args, constexpr_args + # pyrefly: ignore [bad-argument-type] + self.kernel.fn, + self.fake_args, + constexpr_args, ) except Exception: config = self.env.config_spec.default_config() @@ -455,8 +465,10 @@ def to_triton_code( config = self._require_implicit_config() with self.env: if not isinstance(config, Config): + # pyrefly: ignore [bad-argument-type] config = Config(**config) self.env.config_spec.normalize(config) + # pyrefly: ignore [bad-argument-type] root = generate_ast(self.host_function, config, emit_repro_caller) if output_origin_lines is None: output_origin_lines = self.settings.output_origin_lines @@ -480,7 +492,10 @@ def compile_config( if config is None: config = self._require_implicit_config() if not isinstance(config, Config): - config = Config(**config) + config = Config( + # pyrefly: ignore [bad-argument-type] + **config + ) if (rv := self._compile_cache.get(config)) is not None: return rv try: @@ -573,7 +588,10 @@ def set_config(self, config: ConfigLike) -> None: config: The configuration to set. 
""" if not isinstance(config, Config): - config = Config(**config) + config = Config( + # pyrefly: ignore [bad-argument-type] + **config + ) self._run = self.compile_config(config) self._config = config @@ -641,6 +659,7 @@ def _require_implicit_config(self) -> Config: raise RuntimeError("no config provided and no implicit config available") return config + # pyrefly: ignore [bad-return] def run_ref(self, *args: object) -> _R: # Unwrap ConstExpr arguments clean_args = [] @@ -928,6 +947,7 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable: _specialization_extractors: dict[ type[object] | str, Callable[[Kernel, object], Hashable] + # pyrefly: ignore [bad-assignment] ] = { torch.Tensor: _tensor_key, torch.nn.Parameter: _tensor_key, @@ -940,12 +960,16 @@ def _graph_module_key(fn: Kernel, obj: torch.fx.GraphModule) -> Hashable: str: lambda fn, x: x, list: _sequence_key, tuple: _sequence_key, + # pyrefly: ignore [bad-argument-type] dict: lambda fn, x: _mapping_key(fn, x, type(x)), + # pyrefly: ignore [missing-attribute] "namedtuple": lambda fn, x: _mapping_key(fn, x._asdict(), type(x)), + # pyrefly: ignore [no-matching-overload] "dataclass": lambda fn, x: _mapping_key(fn, dataclasses.asdict(x), type(x)), types.FunctionType: _function_key, types.BuiltinFunctionType: lambda fn, x: x, torch.fx.GraphModule: _graph_module_key, + # pyrefly: ignore [missing-attribute] ConstExpr: lambda fn, x: x.value, type(None): lambda fn, x: None, } @@ -984,7 +1008,9 @@ def _find_device(args: tuple[object, ...]) -> torch.device: def _maybe_skip_dtype_check_in_meta_registrations() -> ( contextlib.AbstractContextManager[None, None] ): + # pyrefly: ignore [implicit-import] if hasattr(torch.fx.experimental._config, "skip_dtype_check_in_meta_registrations"): + # pyrefly: ignore [implicit-import, missing-attribute] return torch.fx.experimental._config.patch( skip_dtype_check_in_meta_registrations=True ) diff --git a/helion/runtime/precompile_shim.py b/helion/runtime/precompile_shim.py index f49dd4f33..d8ddfe0b8 100644 --- a/helion/runtime/precompile_shim.py +++ b/helion/runtime/precompile_shim.py @@ -32,6 +32,7 @@ def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], None]: parts so we can wrap it in a subprocess to handle configs that hang in Triton compile and never return. """ + # pyrefly: ignore [bad-argument-type] device = _find_device([*args, *kwargs.values()]) kwargs["debug"] = ( kwargs.get("debug", fn.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index affec721b..8f674feb9 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -163,6 +163,7 @@ def _get_autotune_log_level() -> int: if text.lstrip("+-").isdigit(): return int(text) upper = text.upper() + # pyrefly: ignore [deprecated] level = logging.getLevelName(upper) if isinstance(level, int): return level @@ -248,6 +249,7 @@ def default_autotuner_fn( f"{', '.join(cache_classes.keys())}" ) + # pyrefly: ignore [bad-argument-type] return cache_cls(autotuner_cls(bound_kernel, args, **kwargs)) @@ -484,6 +486,7 @@ def __init__(self, **settings: object) -> None: Initialize the Settings object with the provided dictionary of settings. 
""" + # pyrefly: ignore [bad-argument-type] super().__init__(**settings) self._check_ref_eager_mode_before_print_output_code() diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py index 834bd0140..4b95ffea4 100644 --- a/helion/runtime/triton_helpers.py +++ b/helion/runtime/triton_helpers.py @@ -64,6 +64,7 @@ def triton_wait_signal( scope: tl.constexpr, op: tl.constexpr, skip_sync: tl.constexpr, + # pyrefly: ignore [bad-function-definition] sync_before: tl.constexpr = False, ) -> None: """ @@ -83,6 +84,7 @@ def triton_wait_signal( sync_before: Add a CTA sync before the wait (default: False) """ tl.static_assert( + # pyrefly: ignore [missing-attribute] addr.type.is_ptr(), "Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'? ", ) @@ -135,6 +137,7 @@ def triton_wait_multiple_signal( scope: tl.constexpr, op: tl.constexpr, skip_sync: tl.constexpr, + # pyrefly: ignore [bad-function-definition] sync_before: tl.constexpr = False, ) -> None: """ From d918f34cfa430d215151a0fc2a62c11d85a67b40 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 15:28:24 -0800 Subject: [PATCH 07/10] Update AGENTS.md and README.md. Update references to pyright and pyre, respectively. --- AGENTS.md | 2 +- CONTRIBUTING.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 85ee4b793..2df89cb7a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,7 +26,7 @@ This document explains how to work effectively in this repository. - Imports: Sorted by Ruff/isort; single import per line. - Helion import pattern: `import helion; import helion.language as hl` (do not `import helion as hl`). - Modules/files: snake_case; tests `test_*.py`; examples `*.py` with `main()`. -- Run `./lint.sh fix` before pushing; CI uses Ruff and Pyright. +- Run `./lint.sh fix` before pushing; CI uses Ruff and Pyrefly. ## Testing Guidelines diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4ea0c9a62..6e13c3071 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,7 +28,7 @@ outlined on that page and do not file a public issue. ## Coding Style * Code is formatted and checked using [ruff](https://docs.astral.sh/ruff/formatter/). -* All files must be typed with [pyre](https://pyre-check.org/). +* All files must be typed with [pyrefly](https://pyrefly.org/). * Run `./lint.sh install && ./lint.sh` to check your code. Many formatting issues can be fixed automatically. ## License From 48ef0a169b35b90a57be8380b3c2155f5cf4baac Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 15:36:45 -0800 Subject: [PATCH 08/10] Add back some missing import suppressions. 
--- benchmarks/run.py | 4 ++++ docs/conf.py | 1 + 2 files changed, 5 insertions(+) diff --git a/benchmarks/run.py b/benchmarks/run.py index 969253494..cb1002a39 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -661,6 +661,7 @@ def check_and_setup_tritonbench() -> None: installing_marker = (benchmarks_dir / ".tritonbench_installing").resolve() try: + # pyrefly: ignore [missing-import] import tritonbench module_file = getattr(tritonbench, "__file__", None) @@ -783,6 +784,7 @@ def is_local(path: Path) -> bool: importlib.invalidate_caches() try: + # pyrefly: ignore [missing-import] import tritonbench print("Tritonbench installed successfully.", file=sys.stderr) @@ -871,6 +873,7 @@ def run_kernel_variants( """Run kernel variants in the same benchmark run.""" # Import tritonbench components + # pyrefly: ignore [missing-import] from tritonbench.utils.parser import get_parser from tritonbench.utils.triton_op import BenchmarkOperator from tritonbench.utils.triton_op import BenchmarkOperatorMetrics @@ -940,6 +943,7 @@ def run_kernel_variants( sys.exit(1) # Import register_benchmark API + # pyrefly: ignore [missing-import] from tritonbench.utils.triton_op import register_benchmark # Register all variants as separate methods diff --git a/docs/conf.py b/docs/conf.py index c339fbd23..a4cb49963 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,6 +9,7 @@ from typing import Callable from typing import Protocol +# pyrefly: ignore [missing-import] import pytorch_sphinx_theme2 # -- Path setup -------------------------------------------------------------- From 630bf3625ee22a9442cf54f9eda1d539602a31e9 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 15:37:38 -0800 Subject: [PATCH 09/10] Remove pyrefly version from lint.yml. --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 262ac84f2..dae418310 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -40,7 +40,7 @@ jobs: - name: Install lint dependencies run: | source .venv/bin/activate - uv pip install pyrefly==0.42.0 + uv pip install pyrefly uv pip install .'[dev]' - name: Run pre-commit From fd16c57d620e987b39a4923f6e24dca62d0b0c92 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Mon, 17 Nov 2025 15:47:31 -0800 Subject: [PATCH 10/10] Check benchmarks/ and docs/ directories and add corresponding suppressions. 
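Adding benchmarks/ and docs/ to project-includes brings those directories
under pyrefly's CI check, which surfaces more unresolved tritonbench
imports; the hunks below suppress them inline. As a sketch, a local
spot-check of the newly covered entry points (assuming the pyrefly CLI
accepts explicit file paths) would be:

    pyrefly check benchmarks/run.py docs/conf.py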
---
 benchmarks/run.py | 8 ++++++++
 pyproject.toml | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index cb1002a39..5ccb5cdd4 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -103,6 +103,7 @@ class RunResult:
 # - Single kernel with args: (tritonbench_module, helion_module, helion_func, args_dict)
 # - Multiple kernels: (tritonbench_module, [(helion_module, helion_func), ...])
 # - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
+# pyrefly: ignore [bad-assignment]
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {
     # <kernel_name>: (<tritonbench_module>, <helion_module>, <helion_func>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
@@ -875,7 +876,11 @@ def run_kernel_variants(
     # Import tritonbench components
     # pyrefly: ignore [missing-import]
     from tritonbench.utils.parser import get_parser
+
+    # pyrefly: ignore [missing-import]
     from tritonbench.utils.triton_op import BenchmarkOperator
+
+    # pyrefly: ignore [missing-import]
     from tritonbench.utils.triton_op import BenchmarkOperatorMetrics

     # Get the tritonbench operator name, stripping -bwd suffix for backward operators
@@ -1092,11 +1097,14 @@ def accuracy_fail_hook(
     )

     try:
+        # pyrefly: ignore [missing-import]
         from tritonbench.run import run as tritonbench_run
     except ImportError:
         try:
+            # pyrefly: ignore [missing-import]
             from tritonbench.utils.run_utils import tritonbench_run
         except ImportError:
+            # pyrefly: ignore [missing-import]
             from pytorch.tritonbench.run import run as tritonbench_run

     with tempfile.NamedTemporaryFile(mode="w+t", suffix=".csv") as tmp:
diff --git a/pyproject.toml b/pyproject.toml
index 70cf212c7..3caa85c4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -120,7 +120,7 @@ allow-direct-references = true
 source = "vcs"

 [tool.pyrefly]
-project-includes = ["helion", "examples"]
+project-includes = ["helion", "benchmarks", "docs", "examples"]
 project-excludes = ["test"]
 python-version = "3.10"
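For contributors, the resulting local workflow after this series (per the
AGENTS.md and CONTRIBUTING.md updates above) is, as a sketch:

    ./lint.sh install   # install the pinned ruff and pyrefly versions
    ./lint.sh           # check mode: ruff plus pyrefly
    ./lint.sh fix       # auto-fix formatting issues before pushing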