diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py index fd056cd08cc..5b204e99fcb 100644 --- a/backends/cadence/utils/facto_util.py +++ b/backends/cadence/utils/facto_util.py @@ -23,6 +23,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: + # Cap every dimension at max(1, floor(3999 ** (1 / r))) so that the product of the r dimension sizes stays < 4000 + max_size_constraint = cp.Size.Le(lambda deps, r, d: max(1, int((3999) ** (1 / r)))) + tensor_constraints = ( [ cp.Dtype.In( @@ -39,7 +42,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, cp.Rank.Le(lambda deps: 2**3), ] if op_name @@ -62,7 +65,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, cp.Rank.Le(lambda deps: 2**3), ] ) @@ -76,7 +79,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, ] else: tensor_constraints = [ @@ -94,7 +97,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, ] case "embedding.default": tensor_constraints = [ @@ -104,7 +107,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]: cp.Value.Le(lambda deps, dtype, struct: 2**4), cp.Rank.Ge(lambda deps: 1), cp.Size.Ge(lambda deps, r, d: 1), - cp.Size.Le(lambda deps, r, d: 2**9), + max_size_constraint, ] case "sigmoid.default": tensor_constraints.extend(