Skip to content

Commit

Permalink
chore: move functions to organize code better
Browse files Browse the repository at this point in the history
  • Loading branch information
chohk88 committed May 22, 2024
1 parent c4965cd commit af31bce
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,40 @@ def aten_ops_tile(
)


def zero_output_validator(node: Node) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
)
return False
else:
return True


@dynamo_tensorrt_converter(
torch.ops.aten.as_strided.default,
capability_validator=zero_output_validator,
)
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.as_strided(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
size=args[1],
stride=args[2],
storage_offset=args_bounds_check(args, 3, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
{
Expand Down Expand Up @@ -2186,40 +2220,6 @@ def aten_ops_linear(
)


def zero_output_validator(node: Node) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
)
return False
else:
return True


@dynamo_tensorrt_converter(
torch.ops.aten.as_strided.default,
capability_validator=zero_output_validator,
)
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.slice.as_strided(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
size=args[1],
stride=args[2],
storage_offset=args_bounds_check(args, 3, None),
)


def avg_pool_param_validator(pool_node: Node) -> bool:
ceil_mode = args_bounds_check(pool_node.args, 4, False)
divisor_override = args_bounds_check(pool_node.args, 6)
Expand Down

0 comments on commit af31bce

Please sign in to comment.