diff --git a/backends/cadence/utils/facto_util.py b/backends/cadence/utils/facto_util.py
index 385193776a3..001fc882685 100644
--- a/backends/cadence/utils/facto_util.py
+++ b/backends/cadence/utils/facto_util.py
@@ -189,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             if index == 0:  # condition
                 tensor_constraints = [
                     cp.Dtype.In(lambda deps: [torch.bool]),
-                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
-                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                    cp.Value.Ge(lambda deps, dtype, struct: 0),
+                    cp.Value.Le(lambda deps, dtype, struct: 1),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
                     max_size_constraint,
                 ]
             elif index == 1:  # input tensor(a)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
+                    ),
                     max_size_constraint,
                 ]
             else:  # input tensor(b)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Dtype.Eq(lambda deps: deps[1].dtype),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(
+                            fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
+                        )
+                    ),
                     max_size_constraint,
                 ]
         case "embedding.default":
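
Note on the hunk above: the tightened constraints make the `where.self` condition a pure 0/1 bool tensor, restrict both value inputs to `torch.float32`, and require the fuzzer-generated shapes to be mutually broadcastable (tensor `a` against the condition's shape, tensor `b` against the broadcast of the condition and `a`). `fn.broadcast_with` and `fn.broadcasted_shape` are FACTO helpers taken at face value from the patch; the snippet below is a minimal plain-PyTorch sketch, not part of the patch, of the shape relationship the new `cp.Size.In` constraints encode:

    import torch

    # Shapes satisfying the new constraints: cond, a, and b are mutually
    # broadcastable, cond is bool (values in {0, 1}), a and b are float32.
    cond = torch.tensor([[True], [False]])  # shape (2, 1)
    a = torch.full((2, 3), 2.0)             # broadcastable with cond's (2, 1)
    b = torch.full((1, 3), 7.0)             # broadcastable with broadcast(cond, a) == (2, 3)

    out = torch.where(cond, a, b)
    print(out.shape)  # torch.Size([2, 3])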