diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0cb162cb3..f41e186da 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,6 +35,8 @@ jobs: - name: Install PyTorch run: | source .venv/bin/activate + # Install nvidia-nvshmem-cu12 from cu129 index (missing on cu128) + uv pip install -U --pre nvidia-nvshmem-cu12 --index-url https://download.pytorch.org/whl/nightly/cu129 uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 - name: Install lint dependencies diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b51f21669..d8d8eada4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -113,6 +113,10 @@ jobs: uv pip install -U "torch==2.9.*" --index-url https://download.pytorch.org/whl/${{ matrix.runtime-version }} else # Default to nightly + if [[ "${{ matrix.runtime-version }}" == "cu128" ]]; then + # Install nvidia-nvshmem-cu12 from cu129 index (missing on cu128) + uv pip install -U --pre nvidia-nvshmem-cu12 --index-url https://download.pytorch.org/whl/nightly/cu129 + fi uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/${{ matrix.runtime-version }} fi diff --git a/benchmarks/run.py b/benchmarks/run.py index 8575703d5..5c4eab490 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -741,7 +741,7 @@ def load_kernel_config( def process_single_kernel_mapping( kernel_name: str, mapping: dict[str, Any] -) -> tuple[str, ...]: +) -> tuple[Any, ...]: """Process a single kernel mapping configuration.""" if not isinstance(mapping, dict): raise ValueError( @@ -785,11 +785,11 @@ def process_single_kernel_mapping( def merge_kernel_configs( - base_mappings: dict[str, tuple[str, ...]], + base_mappings: dict[str, tuple[Any, ...]], base_metrics: dict[str, dict[str, str]], - custom_mappings: dict[str, tuple[str, ...]], + custom_mappings: dict[str, tuple[Any, ...]], custom_metrics: dict[str, dict[str, str]], -) -> tuple[dict[str, tuple[str, ...]], dict[str, dict[str, str]]]: +) -> tuple[dict[str, tuple[Any, ...]], dict[str, dict[str, str]]]: """Merge custom kernel configurations with base configurations. Custom configs extend and can override base configs. diff --git a/helion/_logging/_internal.py b/helion/_logging/_internal.py index 23354cdb7..7dd5d2b40 100644 --- a/helion/_logging/_internal.py +++ b/helion/_logging/_internal.py @@ -2,9 +2,11 @@ from dataclasses import dataclass from dataclasses import field +import functools import logging import os from typing import Callable +from typing import Generic from typing import ParamSpec LOG_ENV_VAR = "HELION_LOGS" @@ -82,14 +84,11 @@ def init_logs() -> None: P = ParamSpec("P") -class LazyString: +class LazyString(Generic[P]): def __init__( self, func: Callable[P, str], *args: P.args, **kwargs: P.kwargs ) -> None: - # pyrefly: ignore [invalid-type-var] - self.func: Callable[P, str] = func - self.args: tuple[object, ...] = args - self.kwargs: object = kwargs + self._callable: Callable[[], str] = functools.partial(func, *args, **kwargs) def __str__(self) -> str: - return self.func(*self.args, **self.kwargs) + return self._callable()