From 3ab76d734217546ee9f486d0504e9adaf7664a86 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 9 May 2024 13:33:03 +0900 Subject: [PATCH 1/5] feat: support aten._cdist_forward converter --- .../dynamo/conversion/aten_ops_converters.py | 21 ++++ .../dynamo/conversion/impl/linear.py | 112 +++++++++++++++++- tests/py/dynamo/conversion/test_cdist_aten.py | 64 ++++++++++ 3 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_cdist_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1705dd06db..b5712365d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2186,6 +2186,27 @@ def aten_ops_linear( ) +@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default) +@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward) +def aten_ops_cdist_forward( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.linear.cdist_forward( + ctx, + target, + SourceIR.ATEN, + name, + x1=args[0], + x2=args[1], + p=args[2], + compute_mode=args_bounds_check(args, 3, None), + ) + + def avg_pool_param_validator(pool_node: Node) -> bool: ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py index 69ef73964d..5e23ea0180 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/linear.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np import tensorrt as trt @@ -6,7 +6,11 @@ from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + SourceIR, + cast_trt_tensor, + get_trt_tensor, +) from torch_tensorrt.fx.types import TRTTensor @@ -52,3 +56,107 @@ def linear( out = impl.elementwise.add(ctx, target, source_ir, name, out, bias) return out + + +def cdist_forward( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + x1: TRTTensor, + x2: TRTTensor, + p: float, + compute_mode: Optional[int] = None, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + """ + Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension + of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting + the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances + are computed for each matching set in these dimensions. + + The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are + merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions + (except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting). + + Example: + - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20]. + This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2. + - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), + since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. + """ + x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]] + x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:]) + + # Reshape x1 and x2 for broadcasting + x1_expanded = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape + ) + x2_expanded = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape + ) + + diff = impl.elementwise.sub( + ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded + ) + + if p == 0: + diff_non_zero = impl.elementwise.ne( + ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0 + ) + diff_non_zero = cast_trt_tensor( + ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir + ) + dist = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_sum", + diff_non_zero, + dim=-1, + keepdim=False, + ) + elif p == 1: + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + dist = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False + ) + elif p == 2: + diff_squared = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_diff_squared", diff, 2 + ) + dist_squared = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_dist_sq_sum", + diff_squared, + dim=-1, + keepdim=False, + ) + dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared) + elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"): + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + pow_val = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p + ) + sum_val = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False + ) + dist = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p + ) + elif p == float("inf"): + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + dist = impl.reduce.max( + ctx, + target, + source_ir, + f"{name}_max", + abs_val, + dim=-1, + keepdim=False, + return_indices=False, + ) + else: + raise NotImplementedError(f"Currently, p={p} is not implemented.") + return dist diff --git a/tests/py/dynamo/conversion/test_cdist_aten.py b/tests/py/dynamo/conversion/test_cdist_aten.py new file mode 100644 index 0000000000..51660600cc --- /dev/null +++ b/tests/py/dynamo/conversion/test_cdist_aten.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestCdistConverter(DispatchTestCase): + @parameterized.expand( + [ + ((4, 3, 4), 0), + ((10, 3, 5, 2, 6), 0.5), + ((15, 10, 5), 0), + ((15, 10, 5), 0.5), + ((15, 10, 5), 1.0), + ((15, 10, 5), 1.5), + ((15, 10, 5), 2.0), + ((15, 10, 5), 2.99), + ((15, 10, 5), float("inf")), + ] + ) + def test_cdist_float(self, shape, p): + class Cdist(nn.Module): + def forward(self, x1, x2): + print("x1 : ", x1) + print("x2 : ", x2) + + return torch.ops.aten._cdist_forward.default(x1, x2, p, None) + + inputs = [torch.randn(shape), torch.randn(shape)] + self.run_test( + Cdist(), + inputs, + ) + + @parameterized.expand( + [ + ((1, 5), (2, 3, 5), 1), + ((4, 5), (2, 3, 5), 1), + ((2, 4, 5), (2, 3, 5), 1), + ((2, 2, 4, 5), (2, 3, 5), 0), + ((2, 2, 4, 5), (2, 3, 5), 0.5), + ((2, 2, 4, 5), (2, 3, 5), 1), + ((2, 2, 4, 5), (2, 3, 5), 1.5), + ((2, 2, 4, 5), (2, 3, 5), 2), + ((2, 2, 4, 5), (2, 3, 5), 2.99), + ((2, 2, 4, 5), (2, 3, 5), float("inf")), + ] + ) + def test_cdist_broadcast_float(self, shape_1, shape_2, p): + class Cdist(nn.Module): + def forward(self, x1, x2): + return torch.ops.aten._cdist_forward.default(x1, x2, p, None) + + inputs = [torch.randn(shape_1), torch.randn(shape_2)] + self.run_test( + Cdist(), + inputs, + ) + + +if __name__ == "__main__": + run_tests() From b563763ff74a55dd63761fc39e7dd8f1768c004a Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 9 May 2024 13:33:03 +0900 Subject: [PATCH 2/5] feat: cdist compute_mode compatibility --- .../dynamo/conversion/aten_ops_converters.py | 5 +- .../dynamo/conversion/impl/linear.py | 112 +---------------- .../conversion/impl/normalization/ops.py | 119 ++++++++++++++++++ tests/py/dynamo/conversion/test_cdist_aten.py | 66 ++++++---- 4 files changed, 163 insertions(+), 139 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index b5712365d1..48287b0be0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2187,7 +2187,6 @@ def aten_ops_linear( @dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default) -@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward) def aten_ops_cdist_forward( ctx: ConversionContext, target: Target, @@ -2195,7 +2194,7 @@ def aten_ops_cdist_forward( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.linear.cdist_forward( + return impl.normalization.cdist_forward( ctx, target, SourceIR.ATEN, @@ -2203,7 +2202,7 @@ def aten_ops_cdist_forward( x1=args[0], x2=args[1], p=args[2], - compute_mode=args_bounds_check(args, 3, None), + compute_mode=args[3], ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py index 5e23ea0180..69ef73964d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/linear.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import Optional, Union import numpy as np import tensorrt as trt @@ -6,11 +6,7 @@ from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, - cast_trt_tensor, - get_trt_tensor, -) +from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor from torch_tensorrt.fx.types import TRTTensor @@ -56,107 +52,3 @@ def linear( out = impl.elementwise.add(ctx, target, source_ir, name, out, bias) return out - - -def cdist_forward( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - x1: TRTTensor, - x2: TRTTensor, - p: float, - compute_mode: Optional[int] = None, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - """ - Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension - of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting - the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances - are computed for each matching set in these dimensions. - - The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are - merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions - (except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting). - - Example: - - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20]. - This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2. - - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), - since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. - """ - x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]] - x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:]) - - # Reshape x1 and x2 for broadcasting - x1_expanded = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape - ) - x2_expanded = impl.shuffle.reshape( - ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape - ) - - diff = impl.elementwise.sub( - ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded - ) - - if p == 0: - diff_non_zero = impl.elementwise.ne( - ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0 - ) - diff_non_zero = cast_trt_tensor( - ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir - ) - dist = impl.reduce.sum( - ctx, - target, - source_ir, - f"{name}_sum", - diff_non_zero, - dim=-1, - keepdim=False, - ) - elif p == 1: - abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) - dist = impl.reduce.sum( - ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False - ) - elif p == 2: - diff_squared = impl.elementwise.pow( - ctx, target, source_ir, f"{name}_diff_squared", diff, 2 - ) - dist_squared = impl.reduce.sum( - ctx, - target, - source_ir, - f"{name}_dist_sq_sum", - diff_squared, - dim=-1, - keepdim=False, - ) - dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared) - elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"): - abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) - pow_val = impl.elementwise.pow( - ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p - ) - sum_val = impl.reduce.sum( - ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False - ) - dist = impl.elementwise.pow( - ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p - ) - elif p == float("inf"): - abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) - dist = impl.reduce.max( - ctx, - target, - source_ir, - f"{name}_max", - abs_val, - dim=-1, - keepdim=False, - return_indices=False, - ) - else: - raise NotImplementedError(f"Currently, p={p} is not implemented.") - return dist diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 6db7ed667e..2a9849f1b5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -1,3 +1,4 @@ +import logging from typing import Any, List, Optional, Sequence, Tuple, Union, cast import numpy as np @@ -21,6 +22,8 @@ from torch_tensorrt.fx.types import TRTTensor from torch_tensorrt.fx.utils import get_dynamic_dims +_LOGGER: logging.Logger = logging.getLogger(__name__) + def batch_norm( ctx: ConversionContext, @@ -446,3 +449,119 @@ def pdist( ) indices = np.triu_indices(shape[0], k=1) return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices) + + +def cdist_forward( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + x1: TRTTensor, + x2: TRTTensor, + p: float, + compute_mode: int, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + """ + Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension + of x1 and x2 as feature dimensions, which must be identical for both inputs. The second-to-last dimensions can differ, reflecting + the number of vectors in each tensor. The dimensions preceding the last are considered as batch dimensions, and pairwise distances + are computed for each matching set in these dimensions. + + The output tensor's shape is derived by matching the batch dimensions of x1 and x2, where the mismatched batch dimensions are + merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions + (except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting). + + Example: + - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20]. + This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2. + - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), + since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. + + Note: The `compute_mode` parameter is accepted for compatibility with PyTorch's cdist function signature, + but it does not influence the computational path in this implementation. All modes lead to the same computational logic. + """ + if compute_mode is None: + compute_mode = 0 + + x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]] + x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:]) + + # Reshape x1 and x2 for broadcasting + x1_expanded = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape + ) + x2_expanded = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape + ) + + diff = impl.elementwise.sub( + ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded + ) + + if p == 0: + diff_non_zero = impl.elementwise.ne( + ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0 + ) + diff_non_zero = cast_trt_tensor( + ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir + ) + dist = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_sum", + diff_non_zero, + dim=-1, + keepdim=False, + ) + elif p == 1: + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + dist = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False + ) + elif p == 2: + if ( + compute_mode == 0 and (x1.shape[-2] > 25 or x2.shape[-2] > 25) + ) or compute_mode == 1: + _LOGGER.warning( + "compute_mode to use matrix multiplication for euclid distance calculation is not utilized in the current implementation." + ) + diff_squared = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_diff_squared", diff, 2 + ) + dist_squared = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_dist_sq_sum", + diff_squared, + dim=-1, + keepdim=False, + ) + dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared) + elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"): + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + pow_val = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p + ) + sum_val = impl.reduce.sum( + ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False + ) + dist = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p + ) + elif p == float("inf"): + abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) + dist = impl.reduce.max( + ctx, + target, + source_ir, + f"{name}_max", + abs_val, + dim=-1, + keepdim=False, + return_indices=False, + ) + else: + raise NotImplementedError(f"Currently, p={p} is not implemented.") + return dist diff --git a/tests/py/dynamo/conversion/test_cdist_aten.py b/tests/py/dynamo/conversion/test_cdist_aten.py index 51660600cc..1380d5089b 100644 --- a/tests/py/dynamo/conversion/test_cdist_aten.py +++ b/tests/py/dynamo/conversion/test_cdist_aten.py @@ -9,24 +9,21 @@ class TestCdistConverter(DispatchTestCase): @parameterized.expand( [ - ((4, 3, 4), 0), - ((10, 3, 5, 2, 6), 0.5), - ((15, 10, 5), 0), - ((15, 10, 5), 0.5), - ((15, 10, 5), 1.0), - ((15, 10, 5), 1.5), - ((15, 10, 5), 2.0), - ((15, 10, 5), 2.99), - ((15, 10, 5), float("inf")), + ("p_0", (4, 3, 4), 0, 0), + ("p>0_p<1_1", (10, 3, 5, 2, 6), 0.5, 1), + ("p>0_p<1_2", (10, 2, 15, 2, 7, 2), 0.5, 1), + ("p_1", (15, 10, 5), 1, None), + ("p>1_p<2", (19, 11, 5), 1.5, None), + ("small_p_2_mode_1", (6, 6, 5), 2.0, 1), + ("large_p_2_mode_0", (35, 35, 5), 2.0, 0), + ("p>2", (15, 10, 5), 2.99, None), + ("p_inf", (5, 15, 5), float("inf"), 0), ] ) - def test_cdist_float(self, shape, p): + def test_cdist_float_same_shape(self, name, shape, p, compute_mode): class Cdist(nn.Module): def forward(self, x1, x2): - print("x1 : ", x1) - print("x2 : ", x2) - - return torch.ops.aten._cdist_forward.default(x1, x2, p, None) + return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode) inputs = [torch.randn(shape), torch.randn(shape)] self.run_test( @@ -36,22 +33,23 @@ def forward(self, x1, x2): @parameterized.expand( [ - ((1, 5), (2, 3, 5), 1), - ((4, 5), (2, 3, 5), 1), - ((2, 4, 5), (2, 3, 5), 1), - ((2, 2, 4, 5), (2, 3, 5), 0), - ((2, 2, 4, 5), (2, 3, 5), 0.5), - ((2, 2, 4, 5), (2, 3, 5), 1), - ((2, 2, 4, 5), (2, 3, 5), 1.5), - ((2, 2, 4, 5), (2, 3, 5), 2), - ((2, 2, 4, 5), (2, 3, 5), 2.99), - ((2, 2, 4, 5), (2, 3, 5), float("inf")), + ("p_0", (1, 5), (2, 3, 5), 0, 0), + ("p_1", (4, 5), (2, 3, 5), 1, None), + ("diff_shape_p_0", (2, 5, 4, 5), (2, 5, 8, 5), 0, 2), + ("diff_shape_p_1", (2, 4, 5), (2, 3, 5), 1, 1), + ("p>0_p<1", (2, 2, 4, 5), (2, 3, 5), 0.5, None), + ("p>1_p<2", (5, 2, 12, 5), (2, 3, 5), 1.5, 1), + ("p_2", (2, 2, 14, 5), (2, 3, 5), 2, 0), + ("p>2", (2, 2, 4, 5), (2, 10, 5), 2.99, 2), + ("p_inf", (2, 2, 3, 5), (2, 8, 5), float("inf"), None), ] ) - def test_cdist_broadcast_float(self, shape_1, shape_2, p): + def test_cdist_float_broadcast_and_diff_shape( + self, name, shape_1, shape_2, p, compute_mode + ): class Cdist(nn.Module): def forward(self, x1, x2): - return torch.ops.aten._cdist_forward.default(x1, x2, p, None) + return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode) inputs = [torch.randn(shape_1), torch.randn(shape_2)] self.run_test( @@ -59,6 +57,22 @@ def forward(self, x1, x2): inputs, ) + @parameterized.expand( + [ + ("compute_mode_0", (15, 10, 5), (15, 35, 5), 2.0, 0), + ("compute_mode_1", (35, 35, 5), (35, 45, 5), 2.0, 0), + ("compute_mode_2", (15, 10, 5), (15, 35, 5), 2.0, 1), + ("compute_mode_3", (35, 35, 5), (35, 45, 5), 2.0, 2), + ] + ) + def test_cdist_p_2_compute_mode(self, name, shape_1, shape_2, p, compute_mode): + class Cdist(nn.Module): + def forward(self, x1, x2): + return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode) + + inputs = [torch.randn(shape_1), torch.randn(shape_2)] + self.run_test(Cdist(), inputs) + if __name__ == "__main__": run_tests() From cc21afee6d39b943285d65c502998b9243c7fc75 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 9 May 2024 13:33:03 +0900 Subject: [PATCH 3/5] feat: add matrix multiply for p2 norm (efficiency) --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../conversion/impl/normalization/ops.py | 121 +++++++++++++++--- tests/py/dynamo/conversion/test_cdist_aten.py | 19 +++ 3 files changed, 121 insertions(+), 21 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 48287b0be0..15a993668b 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2202,7 +2202,7 @@ def aten_ops_cdist_forward( x1=args[0], x2=args[1], p=args[2], - compute_mode=args[3], + compute_mode=args_bounds_check(args, 3, None), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 2a9849f1b5..de952c4015 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -459,7 +459,7 @@ def cdist_forward( x1: TRTTensor, x2: TRTTensor, p: float, - compute_mode: int, + compute_mode: Optional[int], ) -> Union[TRTTensor, Sequence[TRTTensor]]: """ Computes pairwise distances between sets of vectors in tensors x1 and x2 using the p-norm. The function treats the last dimension @@ -471,14 +471,26 @@ def cdist_forward( merged, and the resulting shape reflects the computed distances for each pair of vectors. It's crucial that the batch dimensions (except for the size of sets of vectors to compare) of x1 and x2 either match or one of them is 1 (broadcasting). + Args: + x1 (Tensor): input tensor of shape B x P x M. + x2 (Tensor): input tensor of shape B x R x M. + p (float): p value for the p-norm distance to calculate between each vector pair + compute_mode (int): Controls the computation method based on the size of the input sets: + - None ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix multiplication to calculate + Euclidean distance (p=2) if either the number of vectors in x1 or x2 exceeds 25 (P > 25 or R > 25). + - 1 ('use_mm_for_euclid_dist'): Always use matrix multiplication approach to calculate + euclidean distance (p = 2) + - 2 ('donot_use_mm_for_euclid_dist'): Never use matrix multiplication approach to calculate + euclidean distance (p = 2) + Example: - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20]. This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2. - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. - Note: The `compute_mode` parameter is accepted for compatibility with PyTorch's cdist function signature, - but it does not influence the computational path in this implementation. All modes lead to the same computational logic. + Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, especially useful when working with large datasets. + This parameter allows you to control how the distances are computed, with different modes available to leverage matrix multiplication for speed improvements. """ if compute_mode is None: compute_mode = 0 @@ -523,22 +535,93 @@ def cdist_forward( if ( compute_mode == 0 and (x1.shape[-2] > 25 or x2.shape[-2] > 25) ) or compute_mode == 1: - _LOGGER.warning( - "compute_mode to use matrix multiplication for euclid distance calculation is not utilized in the current implementation." + # Compute squared elements + x1_squared = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_x1_squared", x1, 2 ) - diff_squared = impl.elementwise.pow( - ctx, target, source_ir, f"{name}_diff_squared", diff, 2 - ) - dist_squared = impl.reduce.sum( - ctx, - target, - source_ir, - f"{name}_dist_sq_sum", - diff_squared, - dim=-1, - keepdim=False, - ) - dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared) + x2_squared = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_x2_squared", x2, 2 + ) + + # Sum squares along the last dimension + x1_sum_squared = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_x1_sum", + x1_squared, + dim=-1, + keepdim=True, + ) + x2_sum_squared = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_x2_sum", + x2_squared, + dim=-1, + keepdim=True, + ) + + # Reshape sums for broadcasting + rank = len(x2.shape) + permute_shape = list(range(rank - 2)) + [rank - 1, rank - 2] + x1_sum_expanded = x1_sum_squared + x2_sum_expanded = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute", x2_sum_squared, permute_shape + ) + + # Compute dot product of x1 and transposed x2 + x2_tr = impl.permutation.permute( + ctx, target, source_ir, f"{name}_permute_mm", x2, permute_shape + ) + dot_product = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + f"{name}_dot_product", + x1, + x2_tr, + input_matrix_op=trt.MatrixOperation.NONE, + other_matrix_op=trt.MatrixOperation.NONE, + ) + + # Combine results to get squared distances + dist_squared = impl.elementwise.add( + ctx, + target, + source_ir, + f"{name}_dist_squared_initial", + x1_sum_expanded, + x2_sum_expanded, + ) + dist_squared = impl.elementwise.sub( + ctx, + target, + source_ir, + f"{name}_dist_squared", + dist_squared, + impl.elementwise.mul( + ctx, target, source_ir, f"{name}_dot_product_scaled", dot_product, 2 + ), + ) + + # Compute the Euclidean distances + dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_dist", dist_squared) + else: + diff_squared = impl.elementwise.pow( + ctx, target, source_ir, f"{name}_diff_squared", diff, 2 + ) + dist_squared = impl.reduce.sum( + ctx, + target, + source_ir, + f"{name}_dist_sq_sum", + diff_squared, + dim=-1, + keepdim=False, + ) + dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared) elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"): abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff) pow_val = impl.elementwise.pow( @@ -562,6 +645,4 @@ def cdist_forward( keepdim=False, return_indices=False, ) - else: - raise NotImplementedError(f"Currently, p={p} is not implemented.") return dist diff --git a/tests/py/dynamo/conversion/test_cdist_aten.py b/tests/py/dynamo/conversion/test_cdist_aten.py index 1380d5089b..7397c057b7 100644 --- a/tests/py/dynamo/conversion/test_cdist_aten.py +++ b/tests/py/dynamo/conversion/test_cdist_aten.py @@ -63,6 +63,9 @@ def forward(self, x1, x2): ("compute_mode_1", (35, 35, 5), (35, 45, 5), 2.0, 0), ("compute_mode_2", (15, 10, 5), (15, 35, 5), 2.0, 1), ("compute_mode_3", (35, 35, 5), (35, 45, 5), 2.0, 2), + ("p_2_mm_shape_1", (2, 2, 14, 5), (3, 5), 2, 1), + ("p_2_mm_shape_2", (2, 2, 14, 5), (2, 3, 5), 2, 1), + ("p_2_mm_shape_3", (2, 2, 14, 5), (2, 2, 3, 5), 2, 1), ] ) def test_cdist_p_2_compute_mode(self, name, shape_1, shape_2, p, compute_mode): @@ -73,6 +76,22 @@ def forward(self, x1, x2): inputs = [torch.randn(shape_1), torch.randn(shape_2)] self.run_test(Cdist(), inputs) + @parameterized.expand( + [ + ("p_2_matmul", (150, 100, 50, 50), (150, 100, 30, 50), 2, 1), + ("p_2_elementwise_pow", (150, 100, 50, 50), (150, 100, 30, 50), 2, 2), + ] + ) + def test_cdist_efficiency_p_2_compute_mode( + self, name, shape_1, shape_2, p, compute_mode + ): + class Cdist(nn.Module): + def forward(self, x1, x2): + return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode) + + inputs = [torch.randn(shape_1), torch.randn(shape_2)] + self.run_test(Cdist(), inputs) + if __name__ == "__main__": run_tests() From 9f901f60ce5a8f411ece49a40b5e9da51511532d Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 9 May 2024 13:33:03 +0900 Subject: [PATCH 4/5] chore: minor linting --- .../dynamo/conversion/impl/normalization/ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index de952c4015..7d832ece91 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -489,8 +489,9 @@ def cdist_forward( - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features), since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2. - Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, especially useful when working with large datasets. - This parameter allows you to control how the distances are computed, with different modes available to leverage matrix multiplication for speed improvements. + Note: The `compute_mode` parameter is designed to optimize the performance of the Euclidean distance calculation, + especially useful when working with large datasets. This parameter allows you to control how the distances are computed, + with different modes available to leverage matrix multiplication for speed improvements. """ if compute_mode is None: compute_mode = 0 From b6a9e48410400c370712a964042fbc215a88b2c5 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Sat, 11 May 2024 00:05:36 +0900 Subject: [PATCH 5/5] chore: smaller dataset size for matrix mul test --- tests/py/dynamo/conversion/test_cdist_aten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/py/dynamo/conversion/test_cdist_aten.py b/tests/py/dynamo/conversion/test_cdist_aten.py index 7397c057b7..71628df510 100644 --- a/tests/py/dynamo/conversion/test_cdist_aten.py +++ b/tests/py/dynamo/conversion/test_cdist_aten.py @@ -78,8 +78,8 @@ def forward(self, x1, x2): @parameterized.expand( [ - ("p_2_matmul", (150, 100, 50, 50), (150, 100, 30, 50), 2, 1), - ("p_2_elementwise_pow", (150, 100, 50, 50), (150, 100, 30, 50), 2, 2), + ("p_2_matmul", (50, 40, 30, 30), (50, 40, 35, 30), 2, 1), + ("p_2_elementwise_pow", (50, 40, 30, 50), (50, 40, 35, 50), 2, 2), ] ) def test_cdist_efficiency_p_2_compute_mode(