Aten scatter converter #2664

Merged: 10 commits, May 23, 2024
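For context, a brief eager-mode illustration of the two aten.scatter overloads this PR adds converters for; the shapes and values mirror the new test cases further down and are illustrative only.

import torch

# Reference behavior the TensorRT converter must reproduce.
inp = torch.zeros(3, 5, dtype=torch.int32)
index = torch.tensor([[0, 1, 2, 0]])  # torch requires int64 indices here

# scatter.value: write a scalar at the indexed positions along dim 0.
out_value = torch.ops.aten.scatter.value(inp, 0, index, 1)

# scatter.src: write elements taken from a source tensor instead of a scalar.
src = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)
out_src = torch.ops.aten.scatter.src(inp, 0, index, src)

print(out_value)  # ones scattered to (0,0), (1,1), (2,2), (0,3)
print(out_src)    # 1 -> (0,0), 2 -> (1,1), 3 -> (2,2), 4 -> (0,3)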
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,26 @@ def aten_ops_clamp(
)


@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
@enforce_tensor_types(
{
0: (TRTTensor,),
2: (TRTTensor,),
}
)
def aten_ops_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)


@dynamo_tensorrt_converter(torch.ops.aten.select.int)
def aten_ops_select(
ctx: ConversionContext,
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -390,3 +390,41 @@ def index_select(
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

return gather_layer.get_output(0)


def scatter(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
index: Union[TRTTensor, np.ndarray, torch.Tensor],
src: Union[TRTTensor, int, float],
) -> TRTTensor:
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index_shape)
if index.dtype == trt.int64:
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
Comment on lines +408 to +409

Collaborator: If the index must be trt.int32 and no other types (trt.float32, trt.float16, etc.) are acceptable, then it is fine to remove the if statement, since cast_trt_tensor will not insert a cast when the type is already int32:

if input_val.dtype != trt_dtype:

Collaborator (Author): The cast is still required for the trt.int64 case, which our test cases will hit because torch requires int64 index inputs. (A simplified sketch of this cast_trt_tensor guard follows the scatter implementation below.)

dim = get_positive_dim(dim, len(input_shape))
src_tensor = src
# scatter.value
if isinstance(src, int) or isinstance(src, float):
src_tensor = get_trt_tensor(
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
)
src_tensor = cast_trt_tensor(
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
)
# scatter.src
elif not (isinstance(src, TRTTensor)):
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

scatter_layer = ctx.net.add_scatter(
input, index, src_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
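For readers following the review thread above: a minimal sketch of the no-op guard in cast_trt_tensor that the reviewer quotes. This is a simplified, hypothetical helper, not the repo's implementation; the real cast_trt_tensor also resolves dtypes and names the inserted layer, and the layer it adds may differ by TensorRT version.

import tensorrt as trt

def cast_if_needed(ctx, input_val, trt_dtype, name):
    # Simplified illustration: when the tensor already has the requested dtype,
    # no layer is inserted and the tensor is returned unchanged, so the explicit
    # int64 -> int32 cast in scatter() above is a no-op for indices that are
    # already int32, while still converting the int64 indices torch produces.
    if input_val.dtype != trt_dtype:
        identity_layer = ctx.net.add_identity(input_val)
        identity_layer.set_output_type(0, trt_dtype)
        identity_layer.name = name
        return identity_layer.get_output(0)
    return input_val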
49 changes: 39 additions & 10 deletions tests/py/dynamo/conversion/harness.py
@@ -1,5 +1,3 @@
# type: ignore

import logging
import time
import unittest
@@ -50,16 +48,20 @@ def setUp(self):
def run_test(
self,
mod,
inputs,
fx_inputs,
trt_interpreter_inputs,
interpreter,
rtol,
atol,
check_dtype=True,
):
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_inputs.append(i.cuda())
cuda_fx_inputs = []
cuda_trt_inputs = []
for i in trt_interpreter_inputs:
cuda_trt_inputs.append(i.cuda())
for i in fx_inputs:
cuda_fx_inputs.append(i.cuda())

mod.eval()
start = time.perf_counter()
@@ -73,13 +75,13 @@ def run_test(
)

mod = mod.cuda()
ref_outputs = mod(*cuda_inputs)
ref_outputs = mod(*cuda_fx_inputs)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
outputs = trt_mod(*cuda_inputs)
outputs = trt_mod(*cuda_trt_inputs)
end_event.record()
torch.cuda.synchronize()
_LOGGER.info(
@@ -220,6 +222,7 @@ def run_test(
check_dtype=True,
use_dynamo_tracer=False,
enable_passes=False,
int32_reqd=False,
):
mod.eval()
mod = self.generate_graph(
@@ -237,6 +240,30 @@ def run_test(
debug=True,
)

num_inputs = len(inputs)
trt_inputs = inputs
dtype_to_change = []
if int32_reqd:
dtype_to_change = [torch.int64, torch.float64]
else:
dtype_to_change = [
torch.float64,
]
for num_input in range(num_inputs):
input = inputs[num_input]
if input.dtype in dtype_to_change:
dtype_32bit = (
torch.float32 if (input.dtype == torch.float64) else torch.int32
)
trt_inputs = (
list(trt_inputs[:num_input])
+ [
input.to(dtype_32bit),
]
+ list(trt_inputs[num_input + 1 :])
)

trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
input_specs = [Input.from_tensor(i) for i in inputs]

output_dtypes = None
@@ -254,13 +281,15 @@

interp = TRTInterpreter(
mod,
input_specs,
trt_input_specs,
output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)

super().run_test(
mod,
inputs,
trt_inputs,
interp,
rtol,
atol,
@@ -335,4 +364,4 @@ def run_test_with_dynamic_shape(
# Since the lowering is based on optimal shape. We need to test with
# different shape(for ex. max shape) for testing dynamic shape
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
super().run_test(mod, inputs_max, interp, rtol, atol)
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
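To summarize the harness change above in one place: TRT interpreter inputs are narrowed to 32-bit dtypes while the original inputs still drive the eager reference run. A standalone sketch of that rule follows, using an illustrative helper name that is not part of the repo.

import torch

def narrow_inputs_for_trt(inputs, int32_reqd=False):
    # float64 is always narrowed to float32; int64 is narrowed to int32 only when
    # the converter under test requires 32-bit indices (int32_reqd=True).
    dtypes_to_change = [torch.int64, torch.float64] if int32_reqd else [torch.float64]
    narrowed = []
    for t in inputs:
        if t.dtype in dtypes_to_change:
            target_dtype = torch.float32 if t.dtype == torch.float64 else torch.int32
            narrowed.append(t.to(target_dtype))
        else:
            narrowed.append(t)
    return narrowed

# Example: an int64 index tensor is only narrowed when int32_reqd is set.
idx = torch.tensor([[0, 1, 2, 0]])  # integer literals default to int64
print(narrow_inputs_for_trt([idx])[0].dtype)                   # torch.int64
print(narrow_inputs_for_trt([idx], int32_reqd=True)[0].dtype)  # torch.int32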
177 changes: 177 additions & 0 deletions tests/py/dynamo/conversion/test_scatter_aten.py
@@ -0,0 +1,177 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestScatterValueConverter(DispatchTestCase):
@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_constant_value",
0,
torch.tensor([[0, 1, 2, 0]]),
1,
),
(
"scatter_zero_dim_indexTwo_constant_value",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
(
"scatter_one_dim_indexOne_constant_value",
1,
torch.tensor([[0, 1, 2, 0]]),
1,
),
(
"scatter_one_dim_indexTwo_costant_value",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
]
)
def test_scatter_index_constant(self, _, dim, index, value):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.scatter.value(input, dim, index, value)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input]
self.run_test(TestModule(), inputs, int32_reqd=True)

@parameterized.expand(
[
("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
(
"scatter_zero_dim_indexTwo_value",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
(
"scatter_one_dim_indexTwo_value",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
]
)
def test_scatter_index_input(self, _, dim, index, value):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.scatter.value(input, dim, index, value)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs, int32_reqd=True)


class TestScatterSrcConverter(DispatchTestCase):
@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_src",
0,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_zero_dim_indexTwo_src",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_constant_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_constant_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
]
)
def test_scatter_index_constant(self, _, dim, index, src):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.scatter.src(input, dim, index, src)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input]
self.run_test(TestModule(), inputs, int32_reqd=True)

@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_constant_src",
0,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_zero_dim_indexTwo_constant_src",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_constant_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_constant_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
),
]
)
def test_scatter_index_input(self, _, dim, index, src):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.scatter.src(input, dim, index, src)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs, int32_reqd=True)


if __name__ == "__main__":
run_tests()