6 changes: 3 additions & 3 deletions helion/_compiler/helper_function.py
@@ -80,9 +80,9 @@ def create_combine_function_wrapper(
Args:
combine_fn: The original combine function
is_tuple_input: Whether the input is a tuple
target_format: Either 'tuple' or 'unpacked' format
- 'tuple': expects (left_tuple, right_tuple) for tuple inputs
- 'unpacked': expects (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) for tuple inputs
target_format: Either 'tuple' or 'unpacked'. The 'tuple' option expects
(left_tuple, right_tuple) inputs, while 'unpacked' expects
(left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) inputs

Returns:
A wrapper function that converts between the formats
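To make the two target_format conventions concrete, here is a hedged illustration (not part of this PR; the combine function and the commented wrapper call are hypothetical):

```python
# Hypothetical combine function in 'tuple' format: it receives two tuples.
def combine(left, right):
    return (left[0] + right[0], max(left[1], right[1]))

# The same logic in 'unpacked' format: all tuple elements arrive flattened.
def combine_unpacked(left0, left1, right0, right1):
    return (left0 + right0, max(left1, right1))

# Assuming the documented signature, a wrapper produced with
# target_format="unpacked" would accept the flattened arguments while reusing
# the tuple-format function above:
# wrapped = create_combine_function_wrapper(combine, is_tuple_input=True,
#                                           target_format="unpacked")
# wrapped(1.0, 2.0, 3.0, 4.0)  # behaves like combine((1.0, 2.0), (3.0, 4.0))
```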
1 change: 1 addition & 0 deletions helion/autotuner/base_search.py
@@ -863,6 +863,7 @@ def wait_for_all(

Args:
futures: A list of PrecompileFuture objects.
desc: Optional description used for the progress display.

Returns:
A list of boolean values indicating completion status.
7 changes: 5 additions & 2 deletions helion/language/creation_ops.py
@@ -36,6 +36,7 @@ def zeros(
Args:
shape: A list of sizes (or tile indices which are implicitly converted to sizes)
dtype: Data type of the tensor (default: torch.float32)
device: Device must match the current compile environment device

Returns:
torch.Tensor: A device tensor of the given shape and dtype filled with zeros
@@ -82,6 +83,7 @@ def full(
shape: A list of sizes (or tile indices which are implicitly converted to sizes)
value: The value to fill the tensor with
dtype: The data type of the tensor (default: torch.float32)
device: Device must match the current compile environment device

Returns:
torch.Tensor: A device tensor of the given shape and dtype filled with value
@@ -192,9 +194,10 @@ def arange(
automatically using the current kernel's device and index dtype.

Args:
*args: Variable arguments passed to torch.arange(start, end, step).
args: Positional arguments passed to torch.arange(start, end, step).
dtype: Data type of the result tensor (defaults to kernel's index dtype)
**kwargs: Additional keyword arguments passed to torch.arange
device: Device must match the current compile environment device
kwargs: Additional keyword arguments passed to torch.arange

Returns:
torch.Tensor: 1D tensor containing the sequence
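A hedged usage sketch of the creation ops whose docstrings are updated above; the kernel, shapes, and names are illustrative assumptions modeled on Helion's public examples, not part of this PR:

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def scaled_copy(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=torch.float32, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        # zeros/full accept tile indices as sizes; the device defaults to the
        # current compile environment device and must match it if passed.
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        two = hl.full([tile_m, tile_n], 2.0, dtype=torch.float32)
        out[tile_m, tile_n] = acc + two * x[tile_m, tile_n]
    return out

# hl.arange forwards its positional args to torch.arange and likewise defaults
# to the kernel's device and index dtype, per the docstring above.
```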
1 change: 1 addition & 0 deletions helion/language/random_ops.py
@@ -32,6 +32,7 @@ def rand(
Args:
shape: A list of sizes for the output tensor
seed: A single element int64 tensor or int literal
device: Device must match the current compile environment device

Returns:
torch.Tensor: A device tensor of float32 dtype filled with uniform random values in [0, 1)
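A hedged dropout-style sketch of hl.rand based on the documented signature; the kernel, threshold, and keyword usage are assumptions, not part of this PR:

```python
import torch
import helion
import helion.language as hl

@helion.kernel()
def random_zero(x: torch.Tensor, seed: int) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    for tile_m, tile_n in hl.tile([m, n]):
        # Uniform [0, 1) values for this tile; the device defaults to (and
        # must match) the current compile environment device.
        r = hl.rand([tile_m, tile_n], seed=seed)
        out[tile_m, tile_n] = x[tile_m, tile_n] * (r >= 0.5)
    return out
```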
17 changes: 10 additions & 7 deletions helion/runtime/kernel.py
@@ -268,7 +268,7 @@ def autotune(
Args:
args: Example arguments used for benchmarking during autotuning.
force: If True, force full autotuning even if a config is provided.
**options: Additional options for autotuning.
options: Additional keyword options forwarded to the autotuner.

Returns:
Config: The best configuration found during autotuning.
@@ -490,7 +490,7 @@ def autotune(
Args:
args: Example arguments used for benchmarking during autotuning.
force: If True, force full autotuning even if a config is provided.
**kwargs: Additional options for autotuning.
kwargs: Additional keyword options forwarded to the autotuner.

Returns:
Config: The best configuration found during autotuning.
@@ -678,13 +678,16 @@ def kernel(

Args:
fn: The function to be wrapped by the Kernel. If None, a decorator is returned.
config: A single configuration to use for the kernel. See :class:`~helion.Config` for details.
configs: A list of configurations to use for the kernel. Can only specify one of config or configs.
See :class:`~helion.Config` for details.
config: A single configuration to use for the kernel. Refer to the
``helion.Config`` class for details.
configs: A list of configurations to use for the kernel. Can only specify
one of config or configs. Refer to the ``helion.Config`` class for
details.
key: Optional callable returning a hashable that augments the specialization key.
settings: Keyword arguments representing settings for the Kernel.
Can also use settings=Settings(...) to pass a Settings object directly.
See :class:`~helion.Settings` for available options.
Can also use settings=Settings(...) to pass a Settings object
directly. Refer to the ``helion.Settings`` class for available
options.

Returns:
object: A Kernel object or a decorator that returns a Kernel object.
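A hedged end-to-end sketch of the decorator and autotune entry points documented above; the kernel body follows Helion's public add example, and the specific Config fields are assumptions, not part of this PR:

```python
import torch
import helion
import helion.language as hl

# Pin a single config (skips autotuning); the block_sizes value is an assumption.
@helion.kernel(config=helion.Config(block_sizes=[32, 32]))
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out

# Alternatively, decorate with @helion.kernel() and autotune with example args:
# a = torch.randn(2048, 2048, device="cuda")
# b = torch.randn_like(a)
# best_config = add.autotune((a, b), force=True)  # extra keyword options are forwarded to the autotuner
```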
11 changes: 8 additions & 3 deletions helion/runtime/triton_helpers.py
@@ -138,9 +138,12 @@ def triton_wait_multiple_signal(
sync_before: tl.constexpr = False, # pyright: ignore[reportArgumentType]
) -> None:
"""
Simultenuoslly wait for multiple global memory barrier to reach the expected value.
Simultaneously wait for multiple global memory barriers to reach the
expected value.

This function implements each thread in a CTA spin-waits and continuously checks a memory location until it reaches the expected value, providing synchronization across CTAs.
Each thread in a CTA spin-waits and continuously checks its assigned memory
location until it reaches the expected value, providing synchronization
across CTAs.

Args:
addr: Memory addresses of the barriers to wait on (Maximum 32 barriers)
@@ -149,7 +152,9 @@
sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
scope: Scope of the atomic operation. Options: "gpu", "sys"
op: Atomic operation type: "ld", "atomic_cas"
skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
skip_sync: Skip CTA synchronization after acquiring the barrier
(default False).
sync_before: Add a CTA sync before the wait (default False)
"""
tl.static_assert(
(sem == "acquire" or sem == "relaxed") or sem == "release",
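For context, a minimal conceptual sketch of the spin-wait pattern on a single barrier; this is not the multi-barrier helper's implementation, and the argument names and atomic-load idiom below are assumptions:

```python
import triton
import triton.language as tl

@triton.jit
def wait_single_signal(addr, expect, skip_sync: tl.constexpr):
    # Poll the flag at `addr` with acquire semantics until it holds `expect`
    # (an atomic add of zero is used here as an atomic load).
    flag = tl.atomic_add(addr, 0, sem="acquire")
    while flag != expect:
        flag = tl.atomic_add(addr, 0, sem="acquire")
    if not skip_sync:
        tl.debug_barrier()  # re-synchronize the CTA after observing the barrier
```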