diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 1705dd06db..15a993668b 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
+def aten_ops_cdist_forward(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.cdist_forward(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        x1=args[0],
+        x2=args[1],
+        p=args[2],
+        compute_mode=args_bounds_check(args, 3, None),
+    )
+
+
 def avg_pool_param_validator(pool_node: Node) -> bool:
     ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)
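For reference, torch.cdist should lower to torch.ops.aten._cdist_forward.default under the dynamo frontend, so a model like the following would exercise the converter registered above. A minimal sketch, assuming the standard torch_tensorrt.compile entry point and a CUDA build:

    import torch
    import torch_tensorrt

    class PairwiseDist(torch.nn.Module):
        def forward(self, x1, x2):
            # Decomposes to aten._cdist_forward, now handled natively by the converter.
            return torch.cdist(x1, x2, p=2.0)

    inputs = [torch.randn(8, 10, 5).cuda(), torch.randn(8, 20, 5).cuda()]
    trt_mod = torch_tensorrt.compile(
        PairwiseDist().eval().cuda(), ir="dynamo", inputs=inputs
    )
    # Result shape: (8, 10, 20), one distance per (x1 vector, x2 vector) pair per batch.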
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 6db7ed667e..7d832ece91 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, List, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
@@ -21,6 +22,8 @@
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims
 
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
 
 def batch_norm(
     ctx: ConversionContext,
@@ -446,3 +449,201 @@ def pdist(
     )
     indices = np.triu_indices(shape[0], k=1)
     return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+
+
+def cdist_forward(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    x1: TRTTensor,
+    x2: TRTTensor,
+    p: float,
+    compute_mode: Optional[int],
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    """
+    Computes pairwise distances between the vectors in x1 and x2 using the p-norm. The last dimension
+    of x1 and x2 is treated as the feature dimension and must be identical for both inputs. The
+    second-to-last dimensions may differ; they give the number of vectors in each tensor. All dimensions
+    preceding the last two are treated as batch dimensions, and pairwise distances are computed
+    independently for each batch.
+
+    The output shape is obtained by broadcasting the batch dimensions of x1 and x2 and appending the
+    two vector counts: each batch dimension must either match between x1 and x2 or be 1 (broadcasting).
+
+    Args:
+        x1 (Tensor): input tensor of shape B x P x M.
+        x2 (Tensor): input tensor of shape B x R x M.
+        p (float): p value for the p-norm distance to calculate between each vector pair
+        compute_mode (int): Controls the computation method based on the size of the input sets:
+            - None or 0 ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix
+              multiplication to calculate Euclidean distance (p=2) if the number of vectors in
+              x1 or x2 exceeds 25 (P > 25 or R > 25).
+            - 1 ('use_mm_for_euclid_dist'): Always uses the matrix multiplication approach to
+              calculate Euclidean distance (p=2).
+            - 2 ('donot_use_mm_for_euclid_dist'): Never uses the matrix multiplication approach to
+              calculate Euclidean distance (p=2).
+
+    Example:
+    - If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both sharing the batch dimensions
+      [2, 3], the output shape will be [2, 3, 10, 20]. This computes distances in two batches of
+      three groups, each comparing 10 vectors from x1 with 20 vectors from x2.
+    - For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors,
+      each of 5 features), there are no batch dimensions to match, so the output shape is simply
+      [10, 20], comparing all vectors from x1 against all vectors from x2.
+
+    Note: The `compute_mode` parameter exists to optimize the Euclidean distance calculation (p=2),
+    which matters most for large inputs; it controls whether pairwise distances are computed via
+    matrix multiplication for speed.
+    """
+    if compute_mode is None:
+        compute_mode = 0
+
+    # Insert a broadcast dimension so that (..., P, M) and (..., R, M)
+    # become (..., P, 1, M) and (..., 1, R, M) respectively.
+    x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]]
+    x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:])
+
+    # Reshape x1 and x2 for broadcasting
+    x1_expanded = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape
+    )
+    x2_expanded = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape
+    )
+
+    diff = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded
+    )
+
+    if p == 0:
+        # Hamming-style distance: count the non-zero component differences.
+        diff_non_zero = impl.elementwise.ne(
+            ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0
+        )
+        diff_non_zero = cast_trt_tensor(
+            ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir
+        )
+        dist = impl.reduce.sum(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_sum",
+            diff_non_zero,
+            dim=-1,
+            keepdim=False,
+        )
+    elif p == 1:
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
+        dist = impl.reduce.sum(
+            ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False
+        )
+    elif p == 2:
+        if (
+            compute_mode == 0 and (x1.shape[-2] > 25 or x2.shape[-2] > 25)
+        ) or compute_mode == 1:
+            # Compute squared elements
+            x1_squared = impl.elementwise.pow(
+                ctx, target, source_ir, f"{name}_x1_squared", x1, 2
+            )
+            x2_squared = impl.elementwise.pow(
+                ctx, target, source_ir, f"{name}_x2_squared", x2, 2
+            )
+
+            # Sum squares along the last dimension
+            x1_sum_squared = impl.reduce.sum(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_x1_sum",
+                x1_squared,
+                dim=-1,
+                keepdim=True,
+            )
+            x2_sum_squared = impl.reduce.sum(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_x2_sum",
+                x2_squared,
+                dim=-1,
+                keepdim=True,
+            )
+
+            # Reshape sums for broadcasting
+            rank = len(x2.shape)
+            permute_shape = list(range(rank - 2)) + [rank - 1, rank - 2]
+            x1_sum_expanded = x1_sum_squared
+            x2_sum_expanded = impl.permutation.permute(
+                ctx, target, source_ir, f"{name}_permute", x2_sum_squared, permute_shape
+            )
+
+            # Compute dot product of x1 and transposed x2
+            x2_tr = impl.permutation.permute(
+                ctx, target, source_ir, f"{name}_permute_mm", x2, permute_shape
+            )
+            dot_product = impl.matmul.matrix_multiply(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_dot_product",
+                x1,
+                x2_tr,
+                input_matrix_op=trt.MatrixOperation.NONE,
+                other_matrix_op=trt.MatrixOperation.NONE,
+            )
+
+            # Combine results to get squared distances:
+            # ||x1 - x2||^2 = ||x1||^2 + ||x2||^2 - 2 * <x1, x2>
+            dist_squared = impl.elementwise.add(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_dist_squared_initial",
+                x1_sum_expanded,
+                x2_sum_expanded,
+            )
+            dist_squared = impl.elementwise.sub(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_dist_squared",
+                dist_squared,
+                impl.elementwise.mul(
+                    ctx, target, source_ir, f"{name}_dot_product_scaled", dot_product, 2
+                ),
+            )
+
+            # Compute the Euclidean distances
+            dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_dist", dist_squared)
+        else:
+            diff_squared = impl.elementwise.pow(
+                ctx, target, source_ir, f"{name}_diff_squared", diff, 2
+            )
+            dist_squared = impl.reduce.sum(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_dist_sq_sum",
+                diff_squared,
+                dim=-1,
+                keepdim=False,
+            )
+            dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared)
+    elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"):
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
+        pow_val = impl.elementwise.pow(
+            ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p
+        )
+        sum_val = impl.reduce.sum(
+            ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False
+        )
+        dist = impl.elementwise.pow(
+            ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p
+        )
+    elif p == float("inf"):
+        abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
+        dist = impl.reduce.max(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_max",
+            abs_val,
+            dim=-1,
+            keepdim=False,
+            return_indices=False,
+        )
+    return dist
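The p == 2 matmul branch above implements the standard expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b> using reduce, permute, and matmul layers. A standalone sketch (plain PyTorch, no TensorRT required) that checks this identity against torch.cdist:

    import torch

    x1 = torch.randn(10, 5)
    x2 = torch.randn(20, 5)

    direct = torch.cdist(x1, x2, p=2.0)

    # ||x1_i||^2 keeps shape (10, 1); ||x2_j||^2 is transposed to (1, 20) so the
    # sum broadcasts to the (10, 20) distance matrix, mirroring the converter.
    sq = (
        (x1 * x1).sum(-1, keepdim=True)
        + (x2 * x2).sum(-1, keepdim=True).transpose(-2, -1)
        - 2.0 * (x1 @ x2.transpose(-2, -1))
    )
    # Floating-point cancellation can leave tiny negatives; clamp before sqrt.
    expanded = sq.clamp_min(0.0).sqrt()

    assert torch.allclose(direct, expanded, atol=1e-5)

The expansion trades a little numerical accuracy for speed on large P and R, which is why mode 2 ('donot_use_mm_for_euclid_dist') keeps the direct elementwise path available.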
diff --git a/tests/py/dynamo/conversion/test_cdist_aten.py b/tests/py/dynamo/conversion/test_cdist_aten.py
new file mode 100644
index 0000000000..71628df510
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_cdist_aten.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestCdistConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("p_0", (4, 3, 4), 0, 0),
+            ("p>0_p<1_1", (10, 3, 5, 2, 6), 0.5, 1),
+            ("p>0_p<1_2", (10, 2, 15, 2, 7, 2), 0.5, 1),
+            ("p_1", (15, 10, 5), 1, None),
+            ("p>1_p<2", (19, 11, 5), 1.5, None),
+            ("small_p_2_mode_1", (6, 6, 5), 2.0, 1),
+            ("large_p_2_mode_0", (35, 35, 5), 2.0, 0),
+            ("p>2", (15, 10, 5), 2.99, None),
+            ("p_inf", (5, 15, 5), float("inf"), 0),
+        ]
+    )
+    def test_cdist_float_same_shape(self, name, shape, p, compute_mode):
+        class Cdist(nn.Module):
+            def forward(self, x1, x2):
+                return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
+
+        inputs = [torch.randn(shape), torch.randn(shape)]
+        self.run_test(
+            Cdist(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("p_0", (1, 5), (2, 3, 5), 0, 0),
+            ("p_1", (4, 5), (2, 3, 5), 1, None),
+            ("diff_shape_p_0", (2, 5, 4, 5), (2, 5, 8, 5), 0, 2),
+            ("diff_shape_p_1", (2, 4, 5), (2, 3, 5), 1, 1),
+            ("p>0_p<1", (2, 2, 4, 5), (2, 3, 5), 0.5, None),
+            ("p>1_p<2", (5, 2, 12, 5), (2, 3, 5), 1.5, 1),
+            ("p_2", (2, 2, 14, 5), (2, 3, 5), 2, 0),
+            ("p>2", (2, 2, 4, 5), (2, 10, 5), 2.99, 2),
+            ("p_inf", (2, 2, 3, 5), (2, 8, 5), float("inf"), None),
+        ]
+    )
+    def test_cdist_float_broadcast_and_diff_shape(
+        self, name, shape_1, shape_2, p, compute_mode
+    ):
+        class Cdist(nn.Module):
+            def forward(self, x1, x2):
+                return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
+
+        inputs = [torch.randn(shape_1), torch.randn(shape_2)]
+        self.run_test(
+            Cdist(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ("compute_mode_0", (15, 10, 5), (15, 35, 5), 2.0, 0),
+            ("compute_mode_1", (35, 35, 5), (35, 45, 5), 2.0, 0),
+            ("compute_mode_2", (15, 10, 5), (15, 35, 5), 2.0, 1),
+            ("compute_mode_3", (35, 35, 5), (35, 45, 5), 2.0, 2),
+            ("p_2_mm_shape_1", (2, 2, 14, 5), (3, 5), 2, 1),
+            ("p_2_mm_shape_2", (2, 2, 14, 5), (2, 3, 5), 2, 1),
+            ("p_2_mm_shape_3", (2, 2, 14, 5), (2, 2, 3, 5), 2, 1),
+        ]
+    )
+    def test_cdist_p_2_compute_mode(self, name, shape_1, shape_2, p, compute_mode):
+        class Cdist(nn.Module):
+            def forward(self, x1, x2):
+                return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
+
+        inputs = [torch.randn(shape_1), torch.randn(shape_2)]
+        self.run_test(Cdist(), inputs)
+
+    @parameterized.expand(
+        [
+            ("p_2_matmul", (50, 40, 30, 30), (50, 40, 35, 30), 2, 1),
+            ("p_2_elementwise_pow", (50, 40, 30, 50), (50, 40, 35, 50), 2, 2),
+        ]
+    )
+    def test_cdist_efficiency_p_2_compute_mode(
+        self, name, shape_1, shape_2, p, compute_mode
+    ):
+        class Cdist(nn.Module):
+            def forward(self, x1, x2):
+                return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)
+
+        inputs = [torch.randn(shape_1), torch.randn(shape_2)]
+        self.run_test(Cdist(), inputs)
+
+
+if __name__ == "__main__":
+    run_tests()
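As a shape reference for the broadcast cases above: the batch dimensions of the two inputs are broadcast together and the two vector counts form the trailing (P, R) of the output. A quick eager-mode sketch:

    import torch

    x1 = torch.randn(2, 2, 4, 5)  # batch dims (2, 2), P = 4, M = 5
    x2 = torch.randn(2, 3, 5)     # batch dims (2,),   R = 3, M = 5

    out = torch.ops.aten._cdist_forward.default(x1, x2, 0.5, None)
    print(out.shape)  # torch.Size([2, 2, 4, 3])

Assuming the repository's usual pytest setup, the new cases can be run with:

    python -m pytest tests/py/dynamo/conversion/test_cdist_aten.py -v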