diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index 7a7afbac128..fd056cd08cc 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -23,23 +23,49 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: - tensor_constraints = [ - cp.Dtype.In( - lambda deps: [ - torch.int8, - torch.int16, - torch.uint8, - torch.uint16, - torch.float32, - ] - ), - cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), - cp.Value.Le(lambda deps, dtype, struct: 2**4), - cp.Rank.Ge(lambda deps: 1), - cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), - cp.Rank.Le(lambda deps: 2**3), - ] + tensor_constraints = ( + [ + cp.Dtype.In( + lambda deps: [ + torch.int8, + torch.int16, + torch.uint8, + torch.uint16, + torch.int32, + torch.float32, + ] + ), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 2**9), + cp.Rank.Le(lambda deps: 2**3), + ] + if op_name + not in ( + "slice_copy.Tensor", + "add.Scalar", + "sub.Scalar", + "mul.Scalar", + "div.Tensor", + "neg.default", + ) + else [ + cp.Dtype.In( + lambda deps: [ + torch.int32, + torch.float32, + ] + ), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + cp.Size.Le(lambda deps, r, d: 2**9), + cp.Rank.Le(lambda deps: 2**3), + ] + ) match op_name: case "where.self": @@ -60,6 +86,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: torch.int16, torch.uint8, torch.uint16, + torch.int32, torch.float32, ] ), @@ -143,6 +170,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: tensor_constraints.extend( [ cp.Value.Ne(lambda deps, dtype, struct: 0), + cp.Value.Le(lambda deps, dtype, struct: 2**3), + cp.Size.Le(lambda deps, r, d: 2**3), + cp.Rank.Le(lambda deps: 2**2), ] ) case "div.Tensor_mode" | "minimum.default":