diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py
index 4f0ab13e162..b5c5683ab5d 100644
--- a/backends/cadence/utils/facto_util.py
+++ b/backends/cadence/utils/facto_util.py
@@ -15,6 +15,7 @@
 import torch
 from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
 from facto.inputgen.specs.model import ConstraintProducer as cp
+from facto.inputgen.utils.random_manager import seeded_random_manager as rm
 from facto.inputgen.variable.type import ScalarDtype
 from facto.specdb.db import SpecDictDB
 
@@ -26,6 +27,33 @@
 _shape_cache: dict[str, list[int]] = {}
 
 
+def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]:
+    """
+    Generate valid permutations using only positive dimension indices.
+    This is required for Cadence/Xtensa kernels that don't support negative indexing.
+
+    Args:
+        tensor: Input tensor to generate permutations for
+        length: Number of dimensions in the permutation (must equal tensor.dim())
+
+    Returns:
+        Set of valid permutation tuples containing only positive indices [0, rank-1]
+    """
+    if length > tensor.dim():
+        return set()
+
+    n = tensor.dim()
+    pool = list(range(n))
+
+    # Generate multiple valid permutations (only positive indices)
+    permutations: set[tuple[int, ...]] = set()
+    for _ in range(3):  # Generate 3 different permutations for diversity
+        perm = tuple(rm.get_random().sample(pool, length))
+        permutations.add(perm)
+
+    return permutations
+
+
 def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
     # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
     import random
@@ -161,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             if index == 0:  # condition
                 tensor_constraints = [
                     cp.Dtype.In(lambda deps: [torch.bool]),
-                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
-                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                    cp.Value.Ge(lambda deps, dtype, struct: 0),
+                    cp.Value.Le(lambda deps, dtype, struct: 1),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
                     max_size_constraint,
                 ]
             elif index == 1:  # input tensor(a)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
+                    ),
                     max_size_constraint,
                 ]
             else:  # input tensor(b)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Dtype.Eq(lambda deps: deps[1].dtype),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(
+                            fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
+                        )
+                    ),
                     max_size_constraint,
                 ]
         case "embedding.default":
@@ -248,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
+                    # Avoid NaN/Inf values that expose clamp NaN handling bugs
+                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
                 ]
             )
         case "rsqrt.default":
@@ -323,12 +344,15 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                 ]
             )
         case "constant_pad_nd.default":
-            tensor_constraints.extend(
-                [
-                    cp.Dtype.In(lambda deps: [torch.float32]),
-                    cp.Size.Le(lambda deps, r, d: 2**2),
-                ]
-            )
+            tensor_constraints = [
+                cp.Dtype.In(lambda deps: [torch.float32]),
+                cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                cp.Rank.Ge(lambda deps: 1),
+                cp.Rank.Le(lambda deps: 2),  # Reduced from 3 to 2 (max 2D tensors)
+                cp.Size.Ge(lambda deps, r, d: 1),
+                cp.Size.Le(lambda deps, r, d: 3),  # Max dimension size of 3
+            ]
         case "avg_pool2d.default":
             tensor_constraints.extend(
                 [
@@ -344,14 +368,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                 ]
             )
         case "div.Tensor":
-            tensor_constraints.extend(
-                [
-                    cp.Value.Ne(lambda deps, dtype, struct: 0),
-                    cp.Value.Le(lambda deps, dtype, struct: 2**3),
-                    cp.Size.Le(lambda deps, r, d: 2**3),
-                    cp.Rank.Le(lambda deps: 2**2),
-                ]
-            )
+            if index == 1:  # Only apply zero-prevention to divisor
+                tensor_constraints.extend(
+                    [
+                        cp.Value.Ne(
+                            lambda deps, dtype, struct: 0
+                        ),  # Prevent division by zero
+                        cp.Value.Le(lambda deps, dtype, struct: 2**3),
+                        cp.Size.Le(lambda deps, r, d: 2**3),
+                        cp.Rank.Le(lambda deps: 2**2),
+                    ]
+                )
+            else:
+                tensor_constraints.extend(
+                    [
+                        cp.Value.Le(lambda deps, dtype, struct: 2**3),
+                        cp.Size.Le(lambda deps, r, d: 2**3),
+                        cp.Rank.Le(lambda deps: 2**2),
+                    ]
+                )
         case "pow.Tensor_Scalar":
             tensor_constraints.extend(
                 [
@@ -405,6 +440,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                     cp.Size.Le(lambda deps, r, d: 2**2),
                 ]
             )
+        case "flip.default":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float32]),
+                ]
+            )
         case _:
             pass
     return tensor_constraints
@@ -418,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
             | "mul.Scalar"
             | "div.Scalar"
             | "constant_pad_nd.default"
+            | "clamp.default"
         ):
             return [ScalarDtype.int]
         case "full.default":
@@ -445,7 +487,32 @@ def facto_testcase_gen(  # noqa: C901
                         cp.Size.Le(lambda deps, r, d: 2**2),
                     ]
                 )
-            if in_spec.name == "max_val":  # hardtanh
+            # Special handling for clamp.default to ensure min < max with sufficient gap (at least 2) and never None
+            if op_name == "clamp.default":
+                if in_spec.name == "min":
+                    # min must always be provided (not None) and bounded, leave room for max
+                    spec.inspec[index].constraints.extend(
+                        [
+                            cp.Optional.Eq(lambda deps: False),  # Never None
+                            cp.Value.Ge(lambda deps, dtype: -(2**4)),
+                            cp.Value.Le(
+                                lambda deps, dtype: 2**4 - 2
+                            ),  # Leave room for max (at least 2 units)
+                        ]
+                    )
+                elif in_spec.name == "max":
+                    # max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded
+                    spec.inspec[index].deps = [0, 1]  # deps on input tensor and min
+                    spec.inspec[index].constraints.extend(
+                        [
+                            cp.Optional.Eq(lambda deps: False),  # Never None
+                            cp.Value.Ge(
+                                lambda deps, dtype: deps[1] + 2
+                            ),  # max >= min + 2 (sufficient gap)
+                            cp.Value.Le(lambda deps, dtype: 2**4),
+                        ]
+                    )
+            elif in_spec.name == "max_val":  # hardtanh
                 spec.inspec[index].deps = [0, 1]
                 spec.inspec[index].constraints.extend(
                     [cp.Value.Ge(lambda deps, _: deps[1])]
                 )
@@ -482,12 +549,32 @@ def facto_testcase_gen(  # noqa: C901
                 apply_tensor_contraints(op_name, index)
             )
         elif in_spec.type.is_dim_list():
-            spec.inspec[index].constraints.extend(
-                [
-                    cp.Length.Ge(lambda deps: 1),
-                    cp.Optional.Eq(lambda deps: False),
-                ]
-            )
+            # Special handling for permute_copy.default to ensure valid permutation
+            if op_name == "permute_copy.default":
+                spec.inspec[index].constraints.extend(
+                    [
+                        cp.Length.Ge(lambda deps: 1),
+                        cp.Length.Eq(
+                            lambda deps: deps[0].dim()
+                        ),  # Must be a complete permutation
+                        cp.Optional.Eq(lambda deps: False),
+                        # Generate valid permutations using only positive indices
+                        # Cadence/Xtensa hardware kernels do not support negative dimension indices
+                        cp.Value.Gen(
+                            lambda deps, length: (
+                                _positive_valid_dim_list(deps[0], length),
+                                fn.invalid_dim_list(deps[0], length),
+                            )
+                        ),
+                    ]
+                )
+            else:
+                spec.inspec[index].constraints.extend(
+                    [
+                        cp.Length.Ge(lambda deps: 1),
+                        cp.Optional.Eq(lambda deps: False),
+                    ]
+                )
         elif in_spec.type.is_bool():
             spec.inspec[index].constraints.extend(
                 [
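
For reference, a minimal standalone sketch of the sampling pattern that the new _positive_valid_dim_list helper relies on. It is not part of the patch: the stdlib random.Random below stands in for the seeded RNG the helper obtains via rm.get_random(), and the fixed rank is a made-up example value.

import random

# Stand-in for the seeded RNG the patch gets from rm.get_random(); seed chosen arbitrarily.
rng = random.Random(0)

rank = 3  # hypothetical tensor rank, e.g. a tensor of shape (2, 4, 8)
pool = list(range(rank))  # only non-negative dim indices: [0, 1, 2]

# Mirror of the helper's loop: draw a few full-length permutations and deduplicate.
perms: set[tuple[int, ...]] = set()
for _ in range(3):
    perms.add(tuple(rng.sample(pool, rank)))

print(perms)  # e.g. {(2, 0, 1), (1, 0, 2)}; every entry uses only indices 0..rank-1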