173 changes: 130 additions & 43 deletions backends/cadence/utils/facto_util.py
@@ -15,6 +15,7 @@
import torch
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import ConstraintProducer as cp
from facto.inputgen.utils.random_manager import seeded_random_manager as rm
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB

@@ -26,6 +27,33 @@
_shape_cache: dict[str, list[int]] = {}


def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]:
"""
Generate valid permutations using only positive dimension indices.
This is required for Cadence/Xtensa kernels that don't support negative indexing.

Args:
tensor: Input tensor to generate permutations for
length: Number of dimensions in the permutation (must equal tensor.dim())

Returns:
Set of valid permutation tuples containing only positive indices [0, rank-1]
"""
if length > tensor.dim():
return set()

n = tensor.dim()
pool = list(range(n))

# Generate multiple valid permutations (non-negative indices only)
permutations: set[tuple[int, ...]] = set()
for _ in range(3): # sample up to 3 permutations; duplicates collapse in the set
perm = tuple(rm.get_random().sample(pool, length))
permutations.add(perm)

return permutations
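
A usage sketch (editor's note, not part of the diff): the helper above can be exercised standalone, assuming FACTO's `seeded_random_manager.get_random()` returns a standard `random.Random` instance:

```python
import random
from itertools import permutations as all_perms

# Stand-in for rm.get_random() from facto's seeded_random_manager (assumption).
_rng = random.Random(0)

def positive_valid_dim_list(rank: int, length: int) -> set[tuple[int, ...]]:
    """Mirror of _positive_valid_dim_list, parameterized by rank instead of a tensor."""
    if length > rank:
        return set()
    pool = list(range(rank))
    perms: set[tuple[int, ...]] = set()
    for _ in range(3):  # up to 3 samples; duplicates collapse in the set
        perms.add(tuple(_rng.sample(pool, length)))
    return perms

# For a rank-3 tensor, every sample is a full permutation of (0, 1, 2):
for perm in positive_valid_dim_list(3, 3):
    assert sorted(perm) == [0, 1, 2]        # non-negative indices only
    assert perm in set(all_perms(range(3)))
```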


def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
# Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
import random
@@ -161,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
if index == 0: # condition
tensor_constraints = [
cp.Dtype.In(lambda deps: [torch.bool]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Value.Ge(lambda deps, dtype, struct: 0),
cp.Value.Le(lambda deps, dtype, struct: 1),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
max_size_constraint,
]
elif index == 1: # input tensor(a)
tensor_constraints = [
cp.Dtype.In(
lambda deps: [
torch.int8,
torch.int16,
torch.uint8,
torch.uint16,
torch.int32,
torch.float32,
]
),
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.In(
lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
),
max_size_constraint,
]
else: # input tensor(b)
tensor_constraints = [
cp.Dtype.In(
lambda deps: [
torch.int8,
torch.int16,
torch.uint8,
torch.uint16,
torch.int32,
torch.float32,
]
),
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Dtype.Eq(lambda deps: deps[1].dtype),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.In(
lambda deps, r, d: fn.broadcast_with(
fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
)
),
max_size_constraint,
]
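
To make the shape relationship concrete, here is a sketch of what these `where.self` constraints encode, using `torch.broadcast_shapes` as a stand-in for FACTO's `fn.broadcasted_shape`/`fn.broadcast_with` helpers (their exact semantics are assumed here to be standard broadcasting):

```python
import torch

cond = torch.randint(0, 2, (2, 1, 3), dtype=torch.bool)  # condition: bool, values in {0, 1}
a = torch.randn(1, 4, 3)                                  # input a: float32, broadcastable with cond
target = torch.broadcast_shapes(cond.shape, a.shape)      # (2, 4, 3)
b = torch.randn(4, 3)                                     # input b: float32, same dtype as a
assert torch.broadcast_shapes(target, b.shape) == target  # b broadcasts with broadcast(cond, a)
assert torch.where(cond, a, b).shape == target
```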
case "embedding.default":
@@ -248,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
# Avoid NaN/Inf values that expose clamp NaN handling bugs
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
]
)
case "rsqrt.default":
@@ -323,12 +344,15 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
]
)
case "constant_pad_nd.default":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
tensor_constraints = [
cp.Dtype.In(lambda deps: [torch.float32]),
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
cp.Value.Le(lambda deps, dtype, struct: 2**4),
cp.Rank.Ge(lambda deps: 1),
cp.Rank.Le(lambda deps: 2), # Reduced from 3 to 2 (max 2D tensors)
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.Le(lambda deps, r, d: 3), # Max dimension size of 3
]
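
A quick worked check of what the new `constant_pad_nd` bounds allow (illustrative; `F.pad` in constant mode lowers to `constant_pad_nd`): rank at most 2 and every dimension at most 3, i.e. at most 9 elements before padding.

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 3)                  # largest shape the constraints permit
y = F.pad(x, (1, 1, 1, 1), value=0.0)  # constant padding on both dims
assert y.shape == (5, 5)
```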
case "avg_pool2d.default":
tensor_constraints.extend(
[
Expand All @@ -344,14 +368,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
]
)
case "div.Tensor":
tensor_constraints.extend(
[
cp.Value.Ne(lambda deps, dtype, struct: 0),
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
if index == 1: # Only apply zero-prevention to divisor
tensor_constraints.extend(
[
cp.Value.Ne(
lambda deps, dtype, struct: 0
), # Prevent division by zero
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
else:
tensor_constraints.extend(
[
cp.Value.Le(lambda deps, dtype, struct: 2**3),
cp.Size.Le(lambda deps, r, d: 2**3),
cp.Rank.Le(lambda deps: 2**2),
]
)
case "pow.Tensor_Scalar":
tensor_constraints.extend(
[
@@ -405,6 +440,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
case "flip.default":
tensor_constraints.extend(
[
cp.Dtype.In(lambda deps: [torch.float32]),
]
)
case _:
pass
return tensor_constraints
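
For context, a minimal sketch of how these constraint lists are consumed, following FACTO's documented generator pattern (the `gen()` iteration signature is an assumption based on the imports at the top of this file):

```python
import torch
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.specdb.db import SpecDictDB

spec = SpecDictDB["add.Tensor"]
# Lists such as the one returned by apply_tensor_contraints are attached to
# spec.inspec[index].constraints before generation, as facto_testcase_gen does below.
for posargs, inkwargs, outargs in ArgumentTupleGenerator(spec).gen():
    torch.ops.aten.add.Tensor(*posargs, **inkwargs, **outargs)
```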
@@ -418,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
| "mul.Scalar"
| "div.Scalar"
| "constant_pad_nd.default"
| "clamp.default"
):
return [ScalarDtype.int]
case "full.default":
@@ -445,7 +487,32 @@ def facto_testcase_gen( # noqa: C901
cp.Size.Le(lambda deps, r, d: 2**2),
]
)
if in_spec.name == "max_val": # hardtanh
# Special handling for clamp.default: min and max must never be None, with max >= min + 2
if op_name == "clamp.default":
if in_spec.name == "min":
# min must always be provided (not None) and bounded, leaving room for max
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(lambda deps, dtype: -(2**4)),
cp.Value.Le(
lambda deps, dtype: 2**4 - 2
), # Leave room for max (at least 2 units)
]
)
elif in_spec.name == "max":
# max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded
spec.inspec[index].deps = [0, 1] # deps on input tensor and min
spec.inspec[index].constraints.extend(
[
cp.Optional.Eq(lambda deps: False), # Never None
cp.Value.Ge(
lambda deps, dtype: deps[1] + 2
), # max >= min + 2 (sufficient gap)
cp.Value.Le(lambda deps, dtype: 2**4),
]
)
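A worked check of the ranges these constraints produce (editor's sketch): min falls in [-16, 14], and for every such min a valid max in [min + 2, 16] exists:

```python
MIN_LO, MIN_HI, GAP, MAX_HI = -(2**4), 2**4 - 2, 2, 2**4

for mn in range(MIN_LO, MIN_HI + 1):  # min ∈ [-16, 14]
    assert mn + GAP <= MAX_HI         # max ∈ [min + 2, 16] is never empty
```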
elif in_spec.name == "max_val": # hardtanh
spec.inspec[index].deps = [0, 1]
spec.inspec[index].constraints.extend(
[cp.Value.Ge(lambda deps, _: deps[1])]
@@ -482,12 +549,32 @@ def facto_testcase_gen( # noqa: C901
apply_tensor_contraints(op_name, index)
)
elif in_spec.type.is_dim_list():
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Optional.Eq(lambda deps: False),
]
)
# Special handling for permute_copy.default to ensure a valid permutation
if op_name == "permute_copy.default":
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Length.Eq(
lambda deps: deps[0].dim()
), # Must be a complete permutation
cp.Optional.Eq(lambda deps: False),
# Generate valid permutations using only non-negative indices
# Cadence/Xtensa hardware kernels do not support negative dimension indices
cp.Value.Gen(
lambda deps, length: (
_positive_valid_dim_list(deps[0], length),
fn.invalid_dim_list(deps[0], length),
)
),
]
)
else:
spec.inspec[index].constraints.extend(
[
cp.Length.Ge(lambda deps: 1),
cp.Optional.Eq(lambda deps: False),
]
)
elif in_spec.type.is_bool():
spec.inspec[index].constraints.extend(
[