Aten scatter converter #2664

Merged
merged 10 commits into from
May 23, 2024
19 changes: 19 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -691,6 +691,25 @@ def aten_ops_clamp(
)


@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)


@dynamo_tensorrt_converter(torch.ops.aten.select.int)
def aten_ops_select(
ctx: ConversionContext,
50 changes: 50 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -21,6 +21,7 @@
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
Collaborator Author: Will remove this.


_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -390,3 +391,52 @@ def index_select(
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

return gather_layer.get_output(0)


def scatter(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
index: Union[TRTTensor, np.ndarray, torch.Tensor],
src: Union[TRTTensor, int, float],
) -> TRTTensor:
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index.shape)
if not (isinstance(index, TRTTensor)):
if isinstance(index, torch.Tensor):
if index.dtype == torch.int64:
index = index.to(torch.int32)
elif isinstance(index, np.ndarray):
if index.dtype == np.int64:
index = index.astype(np.int32)
Collaborator: Since TRT supports int64 now, is this still necessary?

Collaborator: If not, enforce_tensor_types could be used to autocast any input to an ITensor.

Collaborator Author (apbose, May 14, 2024): Looks like the scatter layer in TRT still needs int32, since it complains with
ERROR:torch_tensorrt [TensorRT Conversion Context]:4: [SCATTER]-[aten_ops.scatter.value]-[scatter_value_scatter_layer]: Indices tensor must be Int32.
It remains a case where torch needs int64 while the scatter layer in TRT needs int32. Earlier this was handled in get_trt_tensor due to the TensorRT requirement, but not anymore.

As for enforce_tensor_types, it takes care of casting the input when it is a TRTTensor, torch.Tensor, or np.ndarray, provided promote is true. In this case we also want to cast int64 to int32. OK, so I will replace this with enforce_tensor_types and then use cast_trt_tensor.
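
A minimal sketch of the plan described in this thread, not the merged implementation: promote the index argument to an ITensor via enforce_tensor_types, then downcast int64 indices with the cast_trt_tensor helper already used elsewhere in this file. The decorator arguments and the trt.int32/trt.int64 dtype spellings below are assumptions.

# Sketch only: let the registry promote the index argument (position 2) to a
# TRTTensor instead of converting it by hand inside the impl.
@enforce_tensor_types(
    {
        0: (TRTTensor,),
        2: (TRTTensor,),
    }
)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter(ctx, target, args, kwargs, name):
    return impl.select.scatter(
        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
    )

# ...and inside impl.select.scatter, downcast the index for TRT's scatter layer,
# which only accepts Int32 indices (dtype spellings are assumed):
if isinstance(index, TRTTensor) and index.dtype == trt.int64:
    index = cast_trt_tensor(ctx, index, trt.int32, f"{name}_index_cast")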

index = get_trt_tensor(ctx, index, f"_index_tensor")
dim = get_positive_dim(dim, len(input_shape))
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
Collaborator: This can be removed, and should be addressed by #2796.

Collaborator Author: So after this PR is merged, the above would need to be addressed accordingly?

Collaborator: Yes, I believe the above would have to be addressed in a future PR which adds dynamic shape support for this converter. With #2796, this converter would be registered as static-only by default.
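
For illustration only, since the actual mechanism lands with #2796: one way a converter can be limited to static shapes today is a capability validator passed to the registration decorator. The validator body below is a hypothetical check, not the #2796 behavior, and the single-argument validator signature is assumed.

# Hypothetical sketch: only claim nodes whose self tensor has fully static
# (non-symbolic) dimensions, leaving dynamic-shape graphs to a fallback path.
def scatter_static_only(node: torch.fx.Node) -> bool:
    meta = node.args[0].meta.get("tensor_meta", None)
    return meta is not None and all(isinstance(d, int) for d in meta.shape)

@dynamo_tensorrt_converter(
    torch.ops.aten.scatter.value, capability_validator=scatter_static_only
)
def aten_ops_scatter(ctx, target, args, kwargs, name): ...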


src_tensor = src
# scatter.value
if isinstance(src, int) or isinstance(src, float):
src_tensor = get_trt_tensor(
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
)
Collaborator: It is preferred to use np.ones here to avoid FakeTensor issues.
src_tensor = cast_trt_tensor(
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
)
# scatter.src
elif not (isinstance(src, TRTTensor)):
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

scatter_layer = ctx.net.add_scatter(
input, index, src_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
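
A small sketch of the reviewer's np.ones suggestion above: build the constant fill with NumPy rather than torch.ones so that no real torch tensor is created under FakeTensor tracing. NumPy's default float64 is assumed acceptable here because the result is immediately cast to input.dtype by the existing cast_trt_tensor call.

# scatter.value branch inside scatter(), using a NumPy constant per the review
# suggestion instead of torch.ones.
if isinstance(src, (int, float)):
    src_tensor = get_trt_tensor(
        ctx, src * np.ones(index_shape_list), name + "_value_tensor"
    )
    src_tensor = cast_trt_tensor(
        ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
    )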
52 changes: 42 additions & 10 deletions tests/py/dynamo/conversion/harness.py
@@ -1,5 +1,4 @@
# type: ignore

import copy
import logging
Collaborator: Is this import used?

Collaborator Author: No, will remove.

import time
import unittest
@@ -50,16 +49,20 @@ def setUp(self):
def run_test(
self,
mod,
inputs,
fx_inputs,
trt_interpreter_inputs,
interpreter,
rtol,
atol,
check_dtype=True,
):
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_inputs.append(i.cuda())
cuda_fx_inputs = []
cuda_trt_inputs = []
for i in trt_interpreter_inputs:
cuda_trt_inputs.append(i.cuda())
for i in fx_inputs:
cuda_fx_inputs.append(i.cuda())

mod.eval()
start = time.perf_counter()
@@ -73,13 +76,13 @@ def run_test(
)

mod = mod.cuda()
ref_outputs = mod(*cuda_inputs)
ref_outputs = mod(*cuda_fx_inputs)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
outputs = trt_mod(*cuda_inputs)
outputs = trt_mod(*cuda_trt_inputs)
end_event.record()
torch.cuda.synchronize()
_LOGGER.info(
@@ -220,6 +223,7 @@ def run_test(
check_dtype=True,
use_dynamo_tracer=False,
enable_passes=False,
int32_reqd=False,
):
mod.eval()
mod = self.generate_graph(
@@ -237,6 +241,32 @@
debug=True,
)

num_inputs = len(inputs)
trt_inputs = inputs
dtype_to_change = []
if int32_reqd:
dtype_to_change = [torch.int64, torch.float64]
else:
dtype_to_change = [
torch.float64,
]
for num_input in range(num_inputs):
input = inputs[num_input]
if input.dtype in dtype_to_change:
dtype_32bit = (
torch.float32 if (input.dtype == torch.float64) else torch.int32
)
# should we modify graph here to insert clone nodes?
# ideally not required
trt_inputs = (
list(trt_inputs[:num_input])
+ [
input.to(dtype_32bit),
]
+ list(trt_inputs[num_input + 1 :])
)

trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
input_specs = [Input.from_tensor(i) for i in inputs]

output_dtypes = None
@@ -254,13 +284,15 @@

interp = TRTInterpreter(
mod,
input_specs,
trt_input_specs,
output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)

super().run_test(
mod,
inputs,
trt_inputs,
interp,
rtol,
atol,
@@ -335,4 +367,4 @@ def run_test_with_dynamic_shape(
# Since the lowering is based on optimal shape. We need to test with
# different shape(for ex. max shape) for testing dynamic shape
inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
super().run_test(mod, inputs_max, interp, rtol, atol)
super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
177 changes: 177 additions & 0 deletions tests/py/dynamo/conversion/test_scatter_aten.py
@@ -0,0 +1,177 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestScatterValueConverter(DispatchTestCase):
@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_constant_value",
0,
torch.tensor([[0, 1, 2, 0]]),
1,
),
(
"scatter_zero_dim_indexTwo_constant_value",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
(
"scatter_one_dim_indexOne_constant_value",
1,
torch.tensor([[0, 1, 2, 0]]),
1,
),
(
"scatter_one_dim_indexTwo_costant_value",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
]
)
def test_scatter_index_constant(self, _, dim, index, value):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.scatter.value(input, dim, index, value)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input]
self.run_test(TestModule(), inputs, int32_reqd=True)

@parameterized.expand(
[
("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
(
"scatter_zero_dim_indexTwo_value",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
(
"scatter_one_dim_indexTwo_value",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
1,
),
]
)
def test_scatter_index_input(self, _, dim, index, value):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.scatter.value(input, dim, index, value)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs, int32_reqd=True)


class TestScatterSrcConverter(DispatchTestCase):
@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_src",
0,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_zero_dim_indexTwo_src",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_constant_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_constant_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
]
)
def test_scatter_index_constant(self, _, dim, index, src):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.scatter.src(input, dim, index, src)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input]
scatter = TestModule()
self.run_test(TestModule(), inputs, int32_reqd=True)

@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_constant_src",
0,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_zero_dim_indexTwo_constant_src",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_constant_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_constant_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
),
]
)
def test_scatter_index_input(self, _, dim, index, src):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, index):
return torch.ops.aten.scatter.src(input, dim, index, src)

input = torch.zeros(3, 5, dtype=torch.int32)
inputs = [input, index]
self.run_test(TestModule(), inputs, int32_reqd=True)


if __name__ == "__main__":
run_tests()