feat: support aten._cdist_forward converter #2726

Merged · 5 commits · May 18, 2024
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2186,6 +2186,26 @@ def aten_ops_linear(
)


@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
def aten_ops_cdist_forward(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.normalization.cdist_forward(
ctx,
target,
SourceIR.ATEN,
name,
x1=args[0],
x2=args[1],
p=args[2],
compute_mode=args_bounds_check(args, 3, None),
)


def avg_pool_param_validator(pool_node: Node) -> bool:
ceil_mode = args_bounds_check(pool_node.args, 4, False)
divisor_override = args_bounds_check(pool_node.args, 6)
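As a rough usage sketch (not part of this diff), the converter above should be hit whenever `torch.cdist` lowers to `aten._cdist_forward` during Dynamo tracing. The module name and the exact compile settings below are illustrative assumptions rather than anything taken from this PR:

import torch
import torch_tensorrt

class CdistModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x1, x2):
        # torch.cdist is expected to decompose to aten._cdist_forward under Dynamo
        return torch.cdist(x1, x2, p=2.0)

x1 = torch.randn(8, 30, 5).cuda()
x2 = torch.randn(8, 40, 5).cuda()

trt_module = torch_tensorrt.compile(
    CdistModule().eval().cuda(),
    ir="dynamo",
    inputs=[x1, x2],
    min_block_size=1,  # let even this one-op graph go to TensorRT
)
print(trt_module(x1, x2).shape)  # expected: torch.Size([8, 30, 40])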
201 changes: 201 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -1,3 +1,4 @@
import logging
from typing import Any, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
@@ -21,6 +22,8 @@
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

_LOGGER: logging.Logger = logging.getLogger(__name__)


def batch_norm(
ctx: ConversionContext,
@@ -446,3 +449,201 @@ def pdist(
)
indices = np.triu_indices(shape[0], k=1)
return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)


def cdist_forward(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
x1: TRTTensor,
x2: TRTTensor,
p: float,
compute_mode: Optional[int],
) -> Union[TRTTensor, Sequence[TRTTensor]]:
"""
Computes pairwise distances between the sets of vectors in x1 and x2 using the p-norm. The last dimension of x1 and x2
is treated as the feature dimension and must be identical for both inputs. The second-to-last dimensions may differ and
give the number of vectors in each tensor. All dimensions before the last two are batch dimensions, and pairwise
distances are computed independently for each matching batch.

The output shape is obtained by broadcasting the batch dimensions of x1 and x2 and appending one entry per vector pair:
for inputs of shape [..., P, M] and [..., R, M] the result has shape [..., P, R]. Each batch dimension of x1 and x2
must therefore either match or be 1 so that standard broadcasting applies.

Args:
x1 (Tensor): input tensor of shape B x P x M.
x2 (Tensor): input tensor of shape B x R x M.
p (float): p value for the p-norm distance to calculate between each vector pair
compute_mode (int): Controls how the Euclidean (p = 2) distance is computed, based on the size of the input sets:
- None ('use_mm_for_euclid_dist_if_necessary'): Default mode. Uses matrix multiplication to calculate the
Euclidean distance if either input holds more than 25 vectors (P > 25 or R > 25).
- 1 ('use_mm_for_euclid_dist'): Always use the matrix multiplication approach to calculate the
Euclidean distance.
- 2 ('donot_use_mm_for_euclid_dist'): Never use the matrix multiplication approach to calculate the
Euclidean distance.

Example:
- If x1.shape = [2, 3, 10, 5] and x2.shape = [2, 3, 20, 5], both having the same batch dimensions [2, 3], the output shape will be [2, 3, 10, 20].
This represents computing distances in two batches of three groups, each comparing 10 vectors from x1 with 20 vectors from x2.
- For x1.shape = [10, 5] (10 vectors, each of 5 features) and x2.shape = [20, 5] (20 vectors, each of 5 features),
since there are no batch dimensions to match, the output shape is simply [10, 20], comparing all vectors from x1 against all vectors from x2.

Note: The `compute_mode` parameter only affects the Euclidean (p = 2) case. It controls whether distances are
computed via matrix multiplication, which is typically faster on large inputs at the cost of some numerical
precision.
"""
if compute_mode is None:
compute_mode = 0

x1_expand_shape = list(x1.shape[:-1]) + [1, x1.shape[-1]]
x2_expand_shape = list(x2.shape[:-2]) + [1] + list(x2.shape[-2:])

# Insert singleton dims so x1 -> [..., P, 1, M] and x2 -> [..., 1, R, M];
# the subtraction below then broadcasts to [..., P, R, M]
x1_expanded = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_x1_expand", x1, x1_expand_shape
)
x2_expanded = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_x2_expand", x2, x2_expand_shape
)

diff = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_diff", x1_expanded, x2_expanded
)

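# p == 0 counts the coordinates where x1 and x2 differ (a Hamming-style count), which is what torch.cdist returns for p = 0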
if p == 0:
diff_non_zero = impl.elementwise.ne(
ctx, target, source_ir, f"{name}_diff_non_zero", diff, 0
)
diff_non_zero = cast_trt_tensor(
ctx, diff_non_zero, torch.float32, f"{name}_cast", target, source_ir
)
dist = impl.reduce.sum(
ctx,
target,
source_ir,
f"{name}_sum",
diff_non_zero,
dim=-1,
keepdim=False,
)
elif p == 1:
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
dist = impl.reduce.sum(
ctx, target, source_ir, f"{name}_sum", abs_val, dim=-1, keepdim=False
)
elif p == 2:
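# compute_mode == 1 always takes the matrix-multiplication path; the default (0) takes it only when either input holds more than 25 vectors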
if (
compute_mode == 0 and (x1.shape[-2] > 25 or x2.shape[-2] > 25)
) or compute_mode == 1:
# Compute squared elements
x1_squared = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_x1_squared", x1, 2
)
x2_squared = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_x2_squared", x2, 2
)

# Sum squares along the last dimension
x1_sum_squared = impl.reduce.sum(
ctx,
target,
source_ir,
f"{name}_x1_sum",
x1_squared,
dim=-1,
keepdim=True,
)
x2_sum_squared = impl.reduce.sum(
ctx,
target,
source_ir,
f"{name}_x2_sum",
x2_squared,
dim=-1,
keepdim=True,
)

# Reshape sums for broadcasting
rank = len(x2.shape)
permute_shape = list(range(rank - 2)) + [rank - 1, rank - 2]
x1_sum_expanded = x1_sum_squared
x2_sum_expanded = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute", x2_sum_squared, permute_shape
)

# Compute dot product of x1 and transposed x2
x2_tr = impl.permutation.permute(
ctx, target, source_ir, f"{name}_permute_mm", x2, permute_shape
)
dot_product = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
f"{name}_dot_product",
x1,
x2_tr,
input_matrix_op=trt.MatrixOperation.NONE,
other_matrix_op=trt.MatrixOperation.NONE,
)

# Combine results to get squared distances
dist_squared = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_dist_squared_initial",
x1_sum_expanded,
x2_sum_expanded,
)
dist_squared = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_dist_squared",
dist_squared,
impl.elementwise.mul(
ctx, target, source_ir, f"{name}_dot_product_scaled", dot_product, 2
),
)

# Compute the Euclidean distances
dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_dist", dist_squared)
else:
diff_squared = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_diff_squared", diff, 2
)
dist_squared = impl.reduce.sum(
ctx,
target,
source_ir,
f"{name}_dist_sq_sum",
diff_squared,
dim=-1,
keepdim=False,
)
dist = impl.unary.sqrt(ctx, target, source_ir, f"{name}_sqrt", dist_squared)
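# General p-norm: dist = (sum_i |x1_i - x2_i| ** p) ** (1 / p)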
elif 0 < p < 1 or 1 < p < 2 or 2 < p < float("inf"):
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
pow_val = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_pow_val_1", abs_val, p
)
sum_val = impl.reduce.sum(
ctx, target, source_ir, f"{name}_sum", pow_val, dim=-1, keepdim=False
)
dist = impl.elementwise.pow(
ctx, target, source_ir, f"{name}_pow_val_2", sum_val, 1 / p
)
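# p == inf is the Chebyshev distance: max_i |x1_i - x2_i|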
elif p == float("inf"):
abs_val = impl.unary.abs(ctx, target, source_ir, f"{name}_abs_val", diff)
dist = impl.reduce.max(
ctx,
target,
source_ir,
f"{name}_max",
abs_val,
dim=-1,
keepdim=False,
return_indices=False,
)
return dist
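The p == 2 matrix-multiplication path above relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b). A small standalone PyTorch sketch (for illustration only, not part of the converter) checking that identity against torch.cdist:

import torch

x1 = torch.randn(2, 3, 10, 5)  # [..., P, M]
x2 = torch.randn(2, 3, 20, 5)  # [..., R, M]

x1_sq = (x1**2).sum(-1, keepdim=True)                    # [..., P, 1]
x2_sq = (x2**2).sum(-1, keepdim=True).transpose(-1, -2)  # [..., 1, R]
dot = x1 @ x2.transpose(-1, -2)                          # [..., P, R]
# clamp_min guards against tiny negative values from floating-point cancellation;
# the converter itself does not clamp
dist = (x1_sq + x2_sq - 2 * dot).clamp_min(0).sqrt()

print(torch.allclose(dist, torch.cdist(x1, x2, p=2.0), atol=1e-5))  # expect True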
97 changes: 97 additions & 0 deletions tests/py/dynamo/conversion/test_cdist_aten.py
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestCdistConverter(DispatchTestCase):
@parameterized.expand(
[
("p_0", (4, 3, 4), 0, 0),
("p>0_p<1_1", (10, 3, 5, 2, 6), 0.5, 1),
("p>0_p<1_2", (10, 2, 15, 2, 7, 2), 0.5, 1),
("p_1", (15, 10, 5), 1, None),
("p>1_p<2", (19, 11, 5), 1.5, None),
("small_p_2_mode_1", (6, 6, 5), 2.0, 1),
("large_p_2_mode_0", (35, 35, 5), 2.0, 0),
("p>2", (15, 10, 5), 2.99, None),
("p_inf", (5, 15, 5), float("inf"), 0),
]
)
def test_cdist_float_same_shape(self, name, shape, p, compute_mode):
class Cdist(nn.Module):
def forward(self, x1, x2):
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)

inputs = [torch.randn(shape), torch.randn(shape)]
self.run_test(
Cdist(),
inputs,
)

@parameterized.expand(
[
("p_0", (1, 5), (2, 3, 5), 0, 0),
("p_1", (4, 5), (2, 3, 5), 1, None),
("diff_shape_p_0", (2, 5, 4, 5), (2, 5, 8, 5), 0, 2),
("diff_shape_p_1", (2, 4, 5), (2, 3, 5), 1, 1),
("p>0_p<1", (2, 2, 4, 5), (2, 3, 5), 0.5, None),
("p>1_p<2", (5, 2, 12, 5), (2, 3, 5), 1.5, 1),
("p_2", (2, 2, 14, 5), (2, 3, 5), 2, 0),
("p>2", (2, 2, 4, 5), (2, 10, 5), 2.99, 2),
("p_inf", (2, 2, 3, 5), (2, 8, 5), float("inf"), None),
]
)
def test_cdist_float_broadcast_and_diff_shape(
self, name, shape_1, shape_2, p, compute_mode
):
class Cdist(nn.Module):
def forward(self, x1, x2):
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)

inputs = [torch.randn(shape_1), torch.randn(shape_2)]
self.run_test(
Cdist(),
inputs,
)

@parameterized.expand(
[
("compute_mode_0", (15, 10, 5), (15, 35, 5), 2.0, 0),
("compute_mode_1", (35, 35, 5), (35, 45, 5), 2.0, 0),
("compute_mode_2", (15, 10, 5), (15, 35, 5), 2.0, 1),
("compute_mode_3", (35, 35, 5), (35, 45, 5), 2.0, 2),
("p_2_mm_shape_1", (2, 2, 14, 5), (3, 5), 2, 1),
("p_2_mm_shape_2", (2, 2, 14, 5), (2, 3, 5), 2, 1),
("p_2_mm_shape_3", (2, 2, 14, 5), (2, 2, 3, 5), 2, 1),
]
)
def test_cdist_p_2_compute_mode(self, name, shape_1, shape_2, p, compute_mode):
class Cdist(nn.Module):
def forward(self, x1, x2):
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)

inputs = [torch.randn(shape_1), torch.randn(shape_2)]
self.run_test(Cdist(), inputs)

@parameterized.expand(
[
("p_2_matmul", (150, 100, 50, 50), (150, 100, 30, 50), 2, 1),
("p_2_elementwise_pow", (150, 100, 50, 50), (150, 100, 30, 50), 2, 2),
]
)
def test_cdist_efficiency_p_2_compute_mode(
self, name, shape_1, shape_2, p, compute_mode
):
class Cdist(nn.Module):
def forward(self, x1, x2):
return torch.ops.aten._cdist_forward.default(x1, x2, p, compute_mode)

inputs = [torch.randn(shape_1), torch.randn(shape_2)]
self.run_test(Cdist(), inputs)


if __name__ == "__main__":
run_tests()