From 9e05dba478519fd033e04d7ef1a0d86049750fd9 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Wed, 15 Oct 2025 14:49:27 -0700
Subject: [PATCH] Docstring updates

stack-info: PR: https://github.com/pytorch/helion/pull/952, branch: jansel/stack/197
---
 helion/_compiler/helper_function.py |  6 +++---
 helion/autotuner/base_search.py     |  1 +
 helion/language/creation_ops.py     |  7 +++++--
 helion/language/random_ops.py       |  1 +
 helion/runtime/kernel.py            | 17 ++++++++++-------
 helion/runtime/triton_helpers.py    | 11 ++++++++---
 6 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/helion/_compiler/helper_function.py b/helion/_compiler/helper_function.py
index c1b597763..baccb5483 100644
--- a/helion/_compiler/helper_function.py
+++ b/helion/_compiler/helper_function.py
@@ -80,9 +80,9 @@ def create_combine_function_wrapper(
     Args:
         combine_fn: The original combine function
        is_tuple_input: Whether the input is a tuple
-        target_format: Either 'tuple' or 'unpacked' format
-            - 'tuple': expects (left_tuple, right_tuple) for tuple inputs
-            - 'unpacked': expects (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) for tuple inputs
+        target_format: Either 'tuple' or 'unpacked'. The 'tuple' option expects
+            (left_tuple, right_tuple) inputs, while 'unpacked' expects
+            (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) inputs

     Returns:
         A wrapper function that converts between the formats
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index bed4b324a..6091ff77c 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -863,6 +863,7 @@ def wait_for_all(

     Args:
         futures: A list of PrecompileFuture objects.
+        desc: Optional description used for the progress display.

     Returns:
         A list of boolean values indicating completion status.
diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py
index ed95c25b2..52d897d22 100644
--- a/helion/language/creation_ops.py
+++ b/helion/language/creation_ops.py
@@ -36,6 +36,7 @@ def zeros(
     Args:
         shape: A list of sizes (or tile indices which are implicitly converted to sizes)
        dtype: Data type of the tensor (default: torch.float32)
+        device: Device for the tensor; must match the current compile environment's device

     Returns:
         torch.Tensor: A device tensor of the given shape and dtype filled with zeros
@@ -82,6 +83,7 @@ def full(
         shape: A list of sizes (or tile indices which are implicitly converted to sizes)
        value: The value to fill the tensor with
        dtype: The data type of the tensor (default: torch.float32)
+        device: Device for the tensor; must match the current compile environment's device

     Returns:
         torch.Tensor: A device tensor of the given shape and dtype filled with value
@@ -192,9 +194,10 @@ def arange(
     automatically using the current kernel's device and index dtype.

     Args:
-        *args: Variable arguments passed to torch.arange(start, end, step).
+        args: Positional arguments passed to torch.arange(start, end, step).
        dtype: Data type of the result tensor (defaults to kernel's index dtype)
-        **kwargs: Additional keyword arguments passed to torch.arange
+        device: Device for the tensor; must match the current compile environment's device
+        kwargs: Additional keyword arguments passed to torch.arange

     Returns:
         torch.Tensor: 1D tensor containing the sequence
diff --git a/helion/language/random_ops.py b/helion/language/random_ops.py
index 412ec1f6c..bb1403c20 100644
--- a/helion/language/random_ops.py
+++ b/helion/language/random_ops.py
@@ -32,6 +32,7 @@ def rand(
     Args:
         shape: A list of sizes for the output tensor
        seed: A single element int64 tensor or int literal
+        device: Device for the tensor; must match the current compile environment's device

     Returns:
         torch.Tensor: A device tensor of float32 dtype filled with uniform random values in [0, 1)
diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py
index 2ea4224c6..b875caebf 100644
--- a/helion/runtime/kernel.py
+++ b/helion/runtime/kernel.py
@@ -268,7 +268,7 @@ def autotune(
     Args:
         args: Example arguments used for benchmarking during autotuning.
        force: If True, force full autotuning even if a config is provided.
-        **options: Additional options for autotuning.
+        options: Additional keyword options forwarded to the autotuner.

     Returns:
         Config: The best configuration found during autotuning.
@@ -490,7 +490,7 @@ def autotune(
     Args:
         args: Example arguments used for benchmarking during autotuning.
        force: If True, force full autotuning even if a config is provided.
-        **kwargs: Additional options for autotuning.
+        kwargs: Additional keyword options forwarded to the autotuner.

     Returns:
         Config: The best configuration found during autotuning.
@@ -678,13 +678,16 @@ def kernel(
     Args:
         fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
-        config: A single configuration to use for the kernel. See :class:`~helion.Config` for details.
-        configs: A list of configurations to use for the kernel. Can only specify one of config or configs.
-            See :class:`~helion.Config` for details.
+        config: A single configuration to use for the kernel. Refer to the
+            ``helion.Config`` class for details.
+        configs: A list of configurations to use for the kernel. Can only specify
+            one of config or configs. Refer to the ``helion.Config`` class for
+            details.
        key: Optional callable returning a hashable that augments the specialization key.
        settings: Keyword arguments representing settings for the Kernel.
-            Can also use settings=Settings(...) to pass a Settings object directly.
-            See :class:`~helion.Settings` for available options.
+            Can also use settings=Settings(...) to pass a Settings object
+            directly. Refer to the ``helion.Settings`` class for available
+            options.

     Returns:
         object: A Kernel object or a decorator that returns a Kernel object.
diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py
index 87354d45a..88e4cc25b 100644
--- a/helion/runtime/triton_helpers.py
+++ b/helion/runtime/triton_helpers.py
@@ -138,9 +138,12 @@ def triton_wait_multiple_signal(
     sync_before: tl.constexpr = False,  # pyright: ignore[reportArgumentType]
 ) -> None:
     """
-    Simultenuoslly wait for multiple global memory barrier to reach the expected value.
+    Simultaneously wait for multiple global memory barriers to reach the
+    expected value.

-    This function implements each thread in a CTA spin-waits and continuously checks a memory location until it reaches the expected value, providing synchronization across CTAs.
+    Each thread in a CTA spin-waits and continuously checks its assigned memory
+    location until it reaches the expected value, providing synchronization
+    across CTAs.

     Args:
         addr: Memory addresses of the barriers to wait on (Maximum 32 barriers)
@@ -149,7 +152,9 @@
         sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "ld", "atomic_cas"
-        skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
+        skip_sync: Skip CTA synchronization after acquiring the barrier
+            (default: False).
+        sync_before: Perform a CTA synchronization before waiting (default: False).
     """
    tl.static_assert(
        (sem == "acquire" or sem == "relaxed") or sem == "release",
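
For reviewers, the creation-op and kernel docstrings touched above describe APIs that compose as in the sketch below. It is adapted from the matmul pattern in the Helion README and is illustrative only, not part of this patch; the shapes, dtypes, and the note that the first call compiles (and autotunes, per Settings) are assumptions, not guarantees made by these docstrings.

import torch
import helion
import helion.language as hl

@helion.kernel()  # config=/configs=/settings= keywords are documented in kernel.py above
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, f"size mismatch {k} != {k2}"
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        # hl.zeros allocates the accumulator on the current compile
        # environment's device, per the `device` note added in creation_ops.py
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc.to(x.dtype)
    return out

x = torch.randn([1024, 1024], device="cuda", dtype=torch.float16)
y = torch.randn([1024, 1024], device="cuda", dtype=torch.float16)
out = matmul(x, y)  # first call compiles; autotuning behavior follows Settings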