Aten scatter converter #2664
Changes from 9 commits: 628fab7 · 6fbc0ec · bfd3498 · 870b79f · 812114e · 6afbb8d · 19049bc · 0449068 · 0318b33 · 487906d
Converter implementation changes:
```diff
@@ -21,6 +21,7 @@
     set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

 _LOGGER: logging.Logger = logging.getLogger(__name__)
```
@@ -390,3 +391,52 @@ def index_select(

```python
    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

    return gather_layer.get_output(0)


def scatter(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: Union[TRTTensor, np.ndarray, torch.Tensor],
    src: Union[TRTTensor, int, float],
) -> TRTTensor:
    input_shape = input.shape
    index_shape = index.shape
    index_shape_list = list(index.shape)
    if not (isinstance(index, TRTTensor)):
        if isinstance(index, torch.Tensor):
            if index.dtype == torch.int64:
                index = index.to(torch.int32)
        elif isinstance(index, np.ndarray):
            if index.dtype == np.int64:
                index = index.astype(np.int32)
```
Review thread on the int64 handling above:
- Since TRT supports …
- If not, the …
- Looks like in the scatter converter, the layer in TRT still needs int32, since it complains of … As for …
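A small standalone illustration of the narrowing discussed above (assuming, per the thread, that TRT's scatter layer rejects int64 indices):

```python
import numpy as np
import torch

# torch.tensor(...) built from Python ints defaults to int64; the thread
# above concludes TRT's scatter layer still wants int32 indices, hence
# the downcast before the index is turned into an ITensor.
index = torch.tensor([[0, 1, 2, 0]])
assert index.dtype == torch.int64
index = index.to(torch.int32)

# The same narrowing for a numpy index, mirroring the converter code above.
np_index = np.array([[0, 1, 2, 0]], dtype=np.int64)
np_index = np_index.astype(np.int32)
```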
```python
    index = get_trt_tensor(ctx, index, f"_index_tensor")
    dim = get_positive_dim(dim, len(input_shape))
    dynamic_shape = has_dynamic_shape(input.shape)
    if dynamic_shape:
        # Check whether the scatter target dim is a dynamic shape dim
        assert input.shape[dim] != -1, "Can't scatter on dynamic shape dimension!"
```
Review thread on the dynamic shape check:
- This can be removed, and should be addressed by #2796.
- So after this PR is merged, the above would need to be addressed accordingly?
- Yes, I believe the above would have to be addressed in a future PR which adds dynamic shape support for this converter. With #2796, this converter would be registered as static-only by default.
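For context, a hypothetical sketch of what a static-only registration could look like. `dynamo_tensorrt_converter` and `capability_validator` exist in the dynamo converter registry, but the exact mechanism #2796 introduces, the import paths, and the validator logic here are all assumptions:

```python
# Hypothetical sketch only: gate the scatter converter on fully static
# input shapes. Not the actual code from #2796.
import torch
from torch.fx.node import Node

# Import paths may differ across torch_tensorrt versions.
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter


def scatter_has_static_shapes(node: Node) -> bool:
    # Accept the node only when the self tensor has no symbolic (dynamic) dims.
    meta = node.args[0].meta.get("tensor_meta", None)
    return meta is not None and all(isinstance(d, int) for d in meta.shape)


@dynamo_tensorrt_converter(
    torch.ops.aten.scatter.src, capability_validator=scatter_has_static_shapes
)
def aten_ops_scatter_src(ctx, target, args, kwargs, name):
    # Dispatch to the scatter converter defined above.
    return scatter(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3])
```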
```python
    src_tensor = src
    # scatter.value
    if isinstance(src, int) or isinstance(src, float):
        src_tensor = get_trt_tensor(
            ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
        )
```

Review comment on the line above: It is preferred to use …

```python
        src_tensor = cast_trt_tensor(
            ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
        )
    # scatter.src
    elif not (isinstance(src, TRTTensor)):
        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

    scatter_layer = ctx.net.add_scatter(
        input, index, src_tensor, trt.ScatterMode.ELEMENT
    )
    scatter_layer.axis = dim
    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
    out = scatter_layer.get_output(0)
    return out
```
Test harness changes:
```diff
@@ -1,5 +1,4 @@
 # type: ignore

 import copy
 import logging
 import time
 import unittest
```

Review thread on the imports:
- Is this import used?
- No, will remove.
```diff
@@ -50,16 +49,20 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
-            cuda_inputs = []
-            for i in inputs:
-                cuda_inputs.append(i.cuda())
+            cuda_fx_inputs = []
+            cuda_trt_inputs = []
+            for i in trt_interpreter_inputs:
+                cuda_trt_inputs.append(i.cuda())
+            for i in fx_inputs:
+                cuda_fx_inputs.append(i.cuda())

             mod.eval()
             start = time.perf_counter()
```
```diff
@@ -73,13 +76,13 @@ def run_test(
             )

             mod = mod.cuda()
-            ref_outputs = mod(*cuda_inputs)
+            ref_outputs = mod(*cuda_fx_inputs)

             torch.cuda.synchronize()
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
             start_event.record()
-            outputs = trt_mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_trt_inputs)
             end_event.record()
             torch.cuda.synchronize()
             _LOGGER.info(
```
```diff
@@ -220,6 +223,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
```
```diff
@@ -237,6 +241,32 @@ def run_test(
             debug=True,
         )

+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in dtype_to_change:
+                dtype_32bit = (
+                    torch.float32 if (input.dtype == torch.float64) else torch.int32
+                )
+                # should we modify graph here to insert clone nodes?
+                # ideally not required
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
+        trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
         input_specs = [Input.from_tensor(i) for i in inputs]

         output_dtypes = None
```
```diff
@@ -254,13 +284,15 @@ def run_test(

         interp = TRTInterpreter(
             mod,
-            input_specs,
+            trt_input_specs,
             output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )

         super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
```
```diff
@@ -335,4 +367,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on the optimal shape, we need to test
         # with a different shape (for ex. max shape) for dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
```
New test file for the scatter converter:
@@ -0,0 +1,177 @@

```python
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestScatterValueConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_constant_value",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                1,
            ),
            (
                "scatter_zero_dim_indexTwo_constant_value",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
            (
                "scatter_one_dim_indexOne_constant_value",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                1,
            ),
            (
                "scatter_one_dim_indexTwo_constant_value",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
        ]
    )
    def test_scatter_index_constant(self, _, dim, index, value):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.scatter.value(input, dim, index, value)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input]
        self.run_test(TestModule(), inputs, int32_reqd=True)

    @parameterized.expand(
        [
            ("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
            (
                "scatter_zero_dim_indexTwo_value",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
            ("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
            (
                "scatter_one_dim_indexTwo_value",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                1,
            ),
        ]
    )
    def test_scatter_index_input(self, _, dim, index, value):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input, index):
                return torch.ops.aten.scatter.value(input, dim, index, value)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input, index]
        self.run_test(TestModule(), inputs, int32_reqd=True)


class TestScatterSrcConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_src",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_zero_dim_indexTwo_src",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
        ]
    )
    def test_scatter_index_constant(self, _, dim, index, src):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.scatter.src(input, dim, index, src)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input]
        scatter = TestModule()
        self.run_test(TestModule(), inputs, int32_reqd=True)

    @parameterized.expand(
        [
            (
                "scatter_zero_dim_indexOne_constant_src",
                0,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
            ),
            (
                "scatter_zero_dim_indexTwo_constant_src",
                0,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexOne_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0]]),
                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
            ),
            (
                "scatter_one_dim_indexTwo_constant_src",
                1,
                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
            ),
        ]
    )
    def test_scatter_index_input(self, _, dim, index, src):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input, index):
                return torch.ops.aten.scatter.src(input, dim, index, src)

        input = torch.zeros(3, 5, dtype=torch.int32)
        inputs = [input, index]
        self.run_test(TestModule(), inputs, int32_reqd=True)


if __name__ == "__main__":
    run_tests()
```
Review comment: Will remove this.