diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dae418310..0491d371d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -46,4 +46,6 @@ jobs: - name: Run pre-commit run: | source .venv/bin/activate + # TODO(rchen152): Delete this when warnings do not cause failures + sed -i '/search-path/d' pyproject.toml pre-commit run --all-files diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 9628fe0c9..09d63af11 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -243,6 +243,7 @@ def create_unbacked_symint(self, hint: int = 8192) -> torch.SymInt: sym = self.shape_env.create_unbacked_symint(source=source) # TODO(jansel): this is a hack to get us past some == 1 checks # we should probably have a better way to handle this + # type: ignore [unsupported-operation] self.shape_env.var_to_val[sym._sympy_()] = sympy.sympify(hint) return sym @@ -639,6 +640,7 @@ def _to_sympy(x: int | torch.SymInt | sympy.Expr) -> sympy.Expr: return sympy.Integer(x) if isinstance(x, sympy.Expr): return x + # type: ignore [missing-attribute] return sympy.sympify(x) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 4f64aac87..95d48ab3e 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -557,6 +557,7 @@ def _format_constexpr_value(self, value: object) -> str: # Handle sympy expressions (sanitize by replacing triton_helpers functions) if isinstance(value, sympy.Expr): + # type: ignore [missing-attribute] sanitized = value.replace( lambda node: isinstance(node, sympy.Function) and getattr(node.func, "__name__", "") diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 615cb353b..d223cbd51 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -344,6 +344,7 @@ def _unpack_symint(x: torch.SymInt | int) -> sympy.Expr: if isinstance(x, torch.SymInt): return x._sympy_() if isinstance(x, int): + # type: ignore [bad-return] return sympy.sympify(x) raise TypeError(f"Expected SymInt or int, got {type(x)}") diff --git a/pyproject.toml b/pyproject.toml index 3caa85c4a..bc76c5127 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ source = "vcs" project-includes = ["helion", "benchmarks", "docs", "examples"] project-excludes = ["test"] python-version = "3.10" +search-path = ["../pytorch"] [tool.codespell] ignore-words = "scripts/dictionary.txt"