116 changes: 116 additions & 0 deletions docs/deployment_autotuning.md
@@ -178,6 +178,122 @@ def my_kernel(x, y):

See {doc}`api/kernel` for the full decorator reference.

## Selective Shape Specialization

The `static_shapes` setting is all-or-nothing: either every dimension is
specialized (`static_shapes=True`) or dimensions are bucketed dynamically
(`static_shapes=False`). Sometimes you want finer control - specializing
only specific dimensions while keeping others dynamic.

Helion provides two APIs for selective shape specialization:

| API | Location | Effect |
|-----|----------|--------|
| `hl.specialize()` | Inside kernel | Dimension always specialized for all calls |
| `torch._dynamo.mark_static()` | Outside kernel | Dimension specialized only for marked tensors |

### `hl.specialize()` - Internal Specialization

Use {func}`~helion.language.specialize` inside the kernel to make specific
dimensions compile-time constants. This applies to **every call** to the kernel:

```python
import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=False)
def rms_norm_fwd(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
m, n = x.size()
hl.specialize(n) # hidden dimension becomes a compile-time constant
out = torch.empty_like(x)
for tile_m in hl.tile(m):
x_tile = x[tile_m, :].to(torch.float32)
x_squared = x_tile * x_tile
mean_x_squared = torch.mean(x_squared, dim=-1)
inv_rms = torch.rsqrt(mean_x_squared + eps)
normalized = x_tile * inv_rms[:, None]
out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
return out

# Every call specializes on n - different hidden sizes = different cache entries
weight_4096 = torch.randn([4096], device="cuda")
weight_2048 = torch.randn([2048], device="cuda")
result1 = rms_norm_fwd(torch.randn([2048, 4096], device="cuda"), weight_4096) # compiles for n=4096
result2 = rms_norm_fwd(torch.randn([1024, 4096], device="cuda"), weight_4096) # reuses n=4096
result3 = rms_norm_fwd(torch.randn([2048, 2048], device="cuda"), weight_2048) # compiles for n=2048
```

Use `hl.specialize()` when a dimension is performance-critical and you want
it specialized regardless of how the kernel is called.

### `torch._dynamo.mark_static()` - External Specialization

Use `torch._dynamo.mark_static()` **before** calling the kernel to specialize
dimensions on specific tensors. This is useful when you want the **same kernel**
to serve both dynamic and specialized code paths:

```python
@helion.kernel(static_shapes=False)
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
m, k = x.size()
k2, n = y.size()
out = torch.empty([m, n], device=x.device, dtype=x.dtype)
for tile_m, tile_n in hl.tile([m, n]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
for tile_k in hl.tile(k):
acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
out[tile_m, tile_n] = acc.to(x.dtype)
return out

# Dynamic call - all dimensions remain symbolic
m, k, n = 512, 384, 256  # example sizes; any runtime values work here
x_dyn = torch.randn([m, k], device="cuda", dtype=torch.float16)
y_dyn = torch.randn([k, n], device="cuda", dtype=torch.float16)
result = matmul(x_dyn, y_dyn)

# Specialized call - mark specific dimensions as compile-time constants
x_opt = torch.randn([64, 128], device="cuda", dtype=torch.float16)
y_opt = torch.randn([128, 56], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(x_opt, [0, -1]) # specialize dims 0 and -1 (M and K)
torch._dynamo.mark_static(y_opt, 1) # specialize dim 1 (N)
result = matmul(x_opt, y_opt) # generates code with 64, 128, 56 as constants
```

This pattern enables a **single kernel definition** to serve both:
- Fully dynamic fallback paths (for rare edge-case shapes)
- Optimized hot paths (with shape constants baked into generated code)

### Combining Both APIs

The two APIs combine as a **union** of specialized dimensions - use
`hl.specialize()` for dimensions that should always be specialized, and
`mark_static()` for additional per-call specialization:

```python
@helion.kernel(static_shapes=False)
def fn(x: torch.Tensor) -> torch.Tensor:
hl.specialize(x.size(0)) # dim 0 always specialized (internal)
out = torch.empty_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile] * 2
return out

# mark_static on dim 1 combines with hl.specialize on dim 0
x = torch.randn([320, 640], device="cuda")
torch._dynamo.mark_static(x, -1) # specialize dim 1 (external)
result = fn(x) # both 320 and 640 become constants
```

### Cache Behavior

Each unique combination of specialized dimension values creates a separate
cache entry:
- Unspecialized calls share one dynamic cache entry
- Calls with `mark_static()` create entries keyed by the specialized values
- Different specialized values (e.g., `[64, 128]` vs `[48, 96]`) create separate entries
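
As a rough illustration of these rules (a sketch that assumes the `matmul`
kernel defined above and a CUDA device; the sizes are arbitrary), the first two
calls below each create their own specialized cache entry, while the third
shares the single dynamic entry:

```python
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128, 256], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(a, 0)   # specialize M
result = matmul(a, b)             # entry keyed by M=64

a2 = torch.randn([48, 128], device="cuda", dtype=torch.float16)
torch._dynamo.mark_static(a2, 0)  # same marking, different value
result = matmul(a2, b)            # separate entry keyed by M=48

c = torch.randn([96, 128], device="cuda", dtype=torch.float16)
result = matmul(c, b)             # unmarked -> shared dynamic entry
```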

## Advanced Manual Deployment

Some teams prefer to skip all runtime selection, using Helion only as
10 changes: 10 additions & 0 deletions helion/_compiler/compile_environment.py
@@ -128,6 +128,16 @@ def __init__(
0 # Track number of loads in all device code for eviction policy tuning
)

def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
"""Substitute any specialized vars with their concrete values."""
if subs := {
s: sympy.Integer(self.shape_env.size_hint(s))
for s in expr.free_symbols & self.specialized_vars
}:
# pyrefly: ignore [bad-assignment]
expr = expr.xreplace(subs)
return expr

def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
from .device_function import contains_only_block_size_symbols

9 changes: 3 additions & 6 deletions helion/_compiler/device_function.py
@@ -374,7 +374,8 @@ def set_pid(self, pid: ProgramIDs) -> None:
self.pid = pid

def sympy_expr(self, expr: sympy.Expr) -> str:
expr = CompileEnvironment.current().shape_env.simplify(expr)
env = CompileEnvironment.current()
expr = env.specialize_expr(env.shape_env.simplify(expr))
if not expr.free_symbols:
return texpr(expr)
if expr in self.expr_to_var_info:
@@ -394,6 +395,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
replacements[sym] = sympy.Symbol(
self._lift_sympy_arg(sym), integer=True
)
# pyrefly: ignore [bad-argument-type]
return texpr(expr.xreplace(replacements))

def _lift_sympy_arg(self, expr: sympy.Expr) -> str:
@@ -615,11 +617,6 @@ def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
if isinstance(v, int):
if env.settings.static_shapes:
return StaticShape(v)
else:
# Check if all free symbols are specialized
syms = v._sympy_().free_symbols
if syms and syms <= env.specialized_vars:
return StaticShape(int(v))
return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")

def sorted_args(self) -> list[Argument]:
6 changes: 5 additions & 1 deletion helion/_compiler/host_function.py
@@ -191,14 +191,18 @@ def set_local_types(self, local_types: dict[str, TypeInfo]) -> None:
type_info.populate_symbol_origins(NameOrigin(name, fn))

def sympy_expr(self, expr: sympy.Expr) -> str:
expr = CompileEnvironment.current().shape_env.simplify(expr)
env = CompileEnvironment.current()
expr = env.specialize_expr(env.shape_env.simplify(expr))
if not expr.free_symbols:
return pexpr(expr)
if expr in self.expr_to_origin:
return self.expr_to_origin[expr].origin.host_str()
replacements = {}
for sym in sorted(expr.free_symbols, key=lambda x: x.name):
assert isinstance(sym, sympy.Symbol)
origin = self.expr_to_origin[sym].origin
replacements[sym] = sympy.Symbol(origin.host_str(), integer=True)
# pyrefly: ignore [bad-argument-type]
return pexpr(expr.xreplace(replacements))

def literal_expr(self, expr: object) -> str:
8 changes: 8 additions & 0 deletions helion/_testing.py
@@ -511,6 +511,14 @@ def assertNotIn(
if not self._in_ref_eager_mode:
super().assertNotIn(member, container, msg) # type: ignore[misc]

def assertIs(self, expr1: object, expr2: object, msg: str | None = None) -> None:
if not self._in_ref_eager_mode:
super().assertIs(expr1, expr2, msg) # type: ignore[misc]

def assertIsNot(self, expr1: object, expr2: object, msg: str | None = None) -> None:
if not self._in_ref_eager_mode:
super().assertIsNot(expr1, expr2, msg) # type: ignore[misc]

def assertTrueIfInNormalMode(self, condition: bool, msg: str | None = None) -> None:
if not self._in_ref_eager_mode:
self.assertTrue(condition, msg) # type: ignore[attr-defined]
21 changes: 21 additions & 0 deletions helion/runtime/kernel.py
@@ -403,6 +403,9 @@ def __init__(
constexpr_args[name] = arg
else:
self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))

self._apply_mark_static(args)

with (
_maybe_skip_dtype_check_in_meta_registrations(),
patch_inductor_lowerings(),
@@ -420,6 +423,20 @@ def __init__(
self.maybe_log_repro(log.warning, args, config=config)
raise

def _apply_mark_static(self, args: tuple[object, ...]) -> None:
"""
Apply torch._dynamo.mark_static() markings from input tensors.

This reads _dynamo_static_indices from each tensor argument and marks
the corresponding dimensions as specialized (constant) in the kernel.
"""
for arg, fake_arg in zip(args, self.fake_args, strict=True):
if isinstance(arg, torch.Tensor) and isinstance(fake_arg, torch.Tensor):
for dim in getattr(arg, "_dynamo_static_indices", ()):
size = fake_arg.size(dim)
if isinstance(size, torch.SymInt):
self.env.specialized_vars.update(size._sympy_().free_symbols)

@property
def settings(self) -> Settings:
"""
@@ -891,12 +908,14 @@ def kernel(
def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
# NOTE: If a machine has two different gpu types on the same machine,
# obj.device.type will incorrectly hit
static_indices = frozenset(getattr(obj, "_dynamo_static_indices", ()))
if fn.settings.static_shapes:
return (
obj.dtype,
obj.device.type,
(*obj.size(),),
(*obj.stride(),),
static_indices,
)
bucketed = tuple([min(s, 2) for s in obj.size()])
if fn.settings.index_dtype is None:
@@ -909,11 +928,13 @@ def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable:
obj.device.type,
bucketed,
needs_int64,
static_indices,
)
return (
obj.dtype,
obj.device.type,
bucketed,
static_indices,
)


16 changes: 0 additions & 16 deletions test/test_examples.expected
@@ -460,27 +460,11 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la
_RDIM_SIZE_2 = 64
# src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
_BLOCK_SIZE_0 = 1
# src[attention.py:N]: q = q_view[tile_b, tile_m, :]
_SHAPE_DIM = q_in.size(3)
_SHAPE_DIM_1 = q_in.size(3)
_SHAPE_DIM_2 = q_in.size(3)
# src[attention.py:N]: for tile_n in hl.tile(v_view.size(1)):
# src[attention.py:N]: k = k_view[tile_b, :, tile_n]
# src[attention.py:N]: qk = torch.bmm(q, k)
# src[attention.py:N-N]: ...
_BLOCK_SIZE_3 = 32
# src[attention.py:N]: k = k_view[tile_b, :, tile_n]
_SHAPE_DIM_3 = q_in.size(3)
_SHAPE_DIM_4 = q_in.size(3)
_SHAPE_DIM_5 = q_in.size(3)
# src[attention.py:N]: v = v_view[tile_b, tile_n, :]
_SHAPE_DIM_6 = q_in.size(3)
_SHAPE_DIM_7 = q_in.size(3)
_SHAPE_DIM_8 = q_in.size(3)
# src[attention.py:N]: out[tile_b, tile_m, :] = acc.to(out.dtype)
_SHAPE_DIM_9 = q_in.size(3)
_SHAPE_DIM_10 = q_in.size(3)
_SHAPE_DIM_11 = q_in.size(3)
# src[attention.py:N]: for tile_b, tile_m in hl.tile([q_view.size(0), m_dim]):
# src[attention.py:N]: m_i = hl.full([tile_b, tile_m], float("-inf"), dtype=torch.float32)
# src[attention.py:N]: l_i = torch.full_like(m_i, 1.0)