diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index 5b204e99fcb..9c543bcec21 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -81,7 +81,25 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Size.Ge(lambda deps, r, d: 1), max_size_constraint, ] - else: + elif index == 1: # input tensor(a) + tensor_constraints = [ + cp.Dtype.In( + lambda deps: [ + torch.int8, + torch.int16, + torch.uint8, + torch.uint16, + torch.int32, + torch.float32, + ] + ), + cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), + cp.Value.Le(lambda deps, dtype, struct: 2**4), + cp.Rank.Ge(lambda deps: 1), + cp.Size.Ge(lambda deps, r, d: 1), + max_size_constraint, + ] + else: # input tensor(b) tensor_constraints = [ cp.Dtype.In( lambda deps: [ @@ -93,6 +111,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: torch.float32, ] ), + cp.Dtype.Eq(lambda deps: deps[1].dtype), cp.Value.Ge(lambda deps, dtype, struct: -(2**4)), cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1),