1 change: 1 addition & 0 deletions tools/experimental/torchfuzz/codegen.py
@@ -219,6 +219,7 @@ def __init__(self):
# Neural network operations
"torch.nn.functional.embedding",
"torch.nn.functional.linear",
"torch.nn.functional.scaled_dot_product_attention",
# Activation functions
"torch.nn.functional.relu",
"torch.nn.functional.leaky_relu",
2 changes: 2 additions & 0 deletions tools/experimental/torchfuzz/operators/__init__.py
@@ -23,6 +23,7 @@
LayerNormOperator,
LinearOperator,
ReLUOperator,
ScaledDotProductAttentionOperator,
SoftmaxOperator,
)
from torchfuzz.operators.registry import (
@@ -77,6 +78,7 @@
"EmbeddingOperator",
"LinearOperator",
"ReLUOperator",
"ScaledDotProductAttentionOperator",
"SoftmaxOperator",
"DropoutOperator",
"LayerNormOperator",
218 changes: 218 additions & 0 deletions tools/experimental/torchfuzz/operators/nn_functional.py
@@ -963,3 +963,221 @@ def codegen(

input_name = input_names[0]
return f"{output_name} = torch.nn.functional.silu({input_name})"


class ScaledDotProductAttentionOperator(Operator):
"""Operator for torch.nn.functional.scaled_dot_product_attention."""

def __init__(self):
super().__init__("torch.nn.functional.scaled_dot_product_attention")

@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.nn.functional.scaled_dot_product_attention"

def can_produce(self, output_spec: Spec) -> bool:
"""Scaled dot product attention can produce tensor outputs with floating point dtypes."""
if not isinstance(output_spec, TensorSpec):
return False
# SDPA needs at least 3 dimensions (batch, seq_len, embed_dim)
if len(output_spec.size) < 3:
return False
return is_float_dtype(output_spec.dtype)

def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for scaled_dot_product_attention.

SDPA requires:
- query: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len, head_dim)
- key: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len_kv, head_dim)
- value: (batch, seq_len, embed_dim) or (batch, num_heads, seq_len_kv, head_dim)
Output shape matches query shape.
"""
if not isinstance(output_spec, TensorSpec):
raise ValueError(
"ScaledDotProductAttentionOperator can only produce TensorSpec outputs"
)

if len(output_spec.size) < 3:
raise ValueError("SDPA output must have at least 3 dimensions")

# Query has the same shape as output
query_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)

# Key and value match the query shape; in practice the key/value
# sequence length can differ from the query's, but we keep them equal here
key_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
value_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)

return [query_spec, key_spec, value_spec]

def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for scaled_dot_product_attention operation."""
if len(input_names) != 3:
raise ValueError("SDPA requires exactly 3 inputs: query, key, value")

# Ensure dtype compatibility by converting all inputs to the expected output dtype
target_dtype = str(output_spec.dtype)
query_name, key_name, value_name = input_names
return f"{output_name} = torch.nn.functional.scaled_dot_product_attention({query_name}.to({target_dtype}), {key_name}.to({target_dtype}), {value_name}.to({target_dtype}))"


class MultiHeadAttentionForwardOperator(Operator):
"""Operator for torch.nn.functional.multi_head_attention_forward."""

def __init__(self):
super().__init__("torch.nn.functional.multi_head_attention_forward")

@property
def torch_op_name(self) -> Optional[str]:
"""Return the torch operation name."""
return "torch.nn.functional.multi_head_attention_forward"

def can_produce(self, output_spec: Spec) -> bool:
"""Multi-head attention forward can produce tensor outputs with floating point dtypes."""
if not isinstance(output_spec, TensorSpec):
return False
# MHA needs at least 3 dimensions (seq_len, batch, embed_dim)
if len(output_spec.size) < 3:
return False
# MHA cannot handle 0-sized dimensions (seq_len, batch, and embed_dim must all be > 0)
if any(dim == 0 for dim in output_spec.size):
return False
return is_float_dtype(output_spec.dtype)

def fuzz_inputs_specs(self, output_spec: Spec) -> list[Spec]:
"""Generate input specs for multi_head_attention_forward.

MHA requires:
- query, key, value: (seq_len, batch, embed_dim)
- in_proj_weight: (3*embed_dim, embed_dim) for combined QKV projection
- in_proj_bias: (3*embed_dim,) optional
- out_proj_weight: (embed_dim, embed_dim)
- out_proj_bias: (embed_dim,) optional

For simplicity, we use the combined in_proj_weight path and skip both
optional biases: exactly five specs are returned, in the fixed order
query, key, value, in_proj_weight, out_proj_weight. A fixed input count
and order keeps codegen unambiguous about which optional parameters are
present; codegen simply passes None for in_proj_bias and out_proj_bias.
"""
if not isinstance(output_spec, TensorSpec):
raise ValueError(
"MultiHeadAttentionForwardOperator can only produce TensorSpec outputs"
)

if len(output_spec.size) < 3:
raise ValueError("MHA output must have at least 3 dimensions")

# Output shape: (seq_len, batch, embed_dim)
seq_len, batch, embed_dim = output_spec.size[:3]

# Query, key, value have the same shape as output
query_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
key_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)
value_spec = TensorSpec(
size=output_spec.size, stride=output_spec.stride, dtype=output_spec.dtype
)

# in_proj_weight: (3*embed_dim, embed_dim)
in_proj_weight_spec = TensorSpec(
size=(3 * embed_dim, embed_dim),
stride=(embed_dim, 1),
dtype=output_spec.dtype,
)

# out_proj_weight: (embed_dim, embed_dim)
out_proj_weight_spec = TensorSpec(
size=(embed_dim, embed_dim),
stride=(embed_dim, 1),
dtype=output_spec.dtype,
)

# Always generate exactly the five tensors above (no optional biases);
# a fixed input count keeps codegen unambiguous about which optional
# parameters are present
specs = [
query_spec,
key_spec,
value_spec,
in_proj_weight_spec,
out_proj_weight_spec,
]

from typing import cast

return cast(list[Spec], specs)

def _calculate_stride(self, size):
"""Calculate stride for a given size."""
if not size:
return ()
stride = []
current_stride = 1
for dim_size in reversed(size):
stride.append(current_stride)
current_stride *= dim_size
return tuple(reversed(stride))

def codegen(
self, output_name: str, input_names: list[str], output_spec: Spec
) -> str:
"""Generate code for multi_head_attention_forward operation."""
if len(input_names) != 5:
raise ValueError(
"MHA requires exactly 5 inputs: query, key, value, in_proj_weight, out_proj_weight"
)

if not isinstance(output_spec, TensorSpec):
raise ValueError(
"MultiHeadAttentionForwardOperator can only produce TensorSpec outputs"
)

target_dtype = str(output_spec.dtype)
embed_dim = output_spec.size[-1]

# Determine number of heads (must divide embed_dim evenly)
# Common choices: 8, 4, 2, 1
possible_heads = [h for h in [8, 4, 2, 1] if embed_dim % h == 0]
num_heads = possible_heads[0] if possible_heads else 1

query_name = input_names[0]
key_name = input_names[1]
value_name = input_names[2]
in_proj_weight_name = input_names[3]
out_proj_weight_name = input_names[4]

# Build the function call without optional biases
code = f"""{output_name}, _ = torch.nn.functional.multi_head_attention_forward(
{query_name}.to({target_dtype}),
{key_name}.to({target_dtype}),
{value_name}.to({target_dtype}),
{embed_dim},
{num_heads},
{in_proj_weight_name}.to({target_dtype}),
None, # in_proj_bias
None, # bias_k
None, # bias_v
False, # add_zero_attn
0.0, # dropout_p (no dropout for testing)
{out_proj_weight_name}.to({target_dtype}),
None, # out_proj_bias
training=False, # Use eval mode for deterministic behavior
need_weights=False, # Don't compute attention weights for performance
)"""

return code
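
For a quick sanity check, here is a hand-expanded sketch (not part of the diff) of what the two codegen paths above emit. The tensor names t0 through t4 are hypothetical stand-ins for fuzzer-generated inputs, and the shapes assume a float32 output spec of (4, 2, 16), i.e. embed_dim=16, so num_heads resolves to 8:

import torch

# Hypothetical fuzzer-generated inputs.
t0 = torch.randn(4, 2, 16)   # query
t1 = torch.randn(4, 2, 16)   # key
t2 = torch.randn(4, 2, 16)   # value
t3 = torch.randn(48, 16)     # in_proj_weight:  (3*embed_dim, embed_dim)
t4 = torch.randn(16, 16)     # out_proj_weight: (embed_dim, embed_dim)

# Emitted by ScaledDotProductAttentionOperator.codegen (dtype casts included):
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    t0.to(torch.float32), t1.to(torch.float32), t2.to(torch.float32)
)

# Emitted by MultiHeadAttentionForwardOperator.codegen for embed_dim=16:
out_mha, _ = torch.nn.functional.multi_head_attention_forward(
    t0.to(torch.float32),
    t1.to(torch.float32),
    t2.to(torch.float32),
    16,                    # embed_dim_to_check
    8,                     # num_heads
    t3.to(torch.float32),  # in_proj_weight
    None,                  # in_proj_bias
    None,                  # bias_k
    None,                  # bias_v
    False,                 # add_zero_attn
    0.0,                   # dropout_p
    t4.to(torch.float32),  # out_proj_weight
    None,                  # out_proj_bias
    training=False,        # eval mode for deterministic behavior
    need_weights=False,    # skip attention-weight computation
)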
2 changes: 2 additions & 0 deletions tools/experimental/torchfuzz/operators/registry.py
@@ -32,6 +32,7 @@
LinearOperator,
ReLUOperator,
RMSNormOperator,
ScaledDotProductAttentionOperator,
SigmoidOperator,
SiLUOperator,
SoftmaxOperator,
@@ -101,6 +102,7 @@ def _register_default_operators(self):
# Neural network functional operators
self.register(EmbeddingOperator())
self.register(LinearOperator())
self.register(ScaledDotProductAttentionOperator())

# Activation functions
self.register(ReLUOperator())
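
Finally, a minimal sketch of how the newly registered operator might be exercised end to end. The import path for TensorSpec is an assumption (it is not visible in this diff); everything else uses names introduced or touched by the PR:

import torch

from torchfuzz.operators import ScaledDotProductAttentionOperator
# Assumed location of TensorSpec -- adjust to wherever torchfuzz defines it.
from torchfuzz.tensor_fuzzer import TensorSpec

op = ScaledDotProductAttentionOperator()

# Contiguous float32 spec of shape (2, 8, 16): batch, seq_len, embed_dim.
spec = TensorSpec(size=(2, 8, 16), stride=(128, 16, 1), dtype=torch.float32)

assert op.can_produce(spec)  # at least 3 dims and a floating-point dtype
query_spec, key_spec, value_spec = op.fuzz_inputs_specs(spec)  # all mirror the output spec

print(op.codegen("var_0", ["t0", "t1", "t2"], spec))
# var_0 = torch.nn.functional.scaled_dot_product_attention(t0.to(torch.float32), t1.to(torch.float32), t2.to(torch.float32))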