Skip to content

Commit

Permalink
addressing review comments and changing test names
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Apr 26, 2024
1 parent 855265d commit aea9d94
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 94 deletions.
23 changes: 8 additions & 15 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,28 +699,21 @@ def aten_ops_clamp(
)


@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter_value(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.scatter.value to TRT via the shared select implementation.

    Positional args are (input, dim, index, value); they are forwarded
    unchanged to impl.select.scatter_value.
    """
    input_t, dim, index, value = args[0], args[1], args[2], args[3]
    return impl.select.scatter_value(
        ctx, target, SourceIR.ATEN, name, input_t, dim, index, value
    )

@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert both aten.scatter.src and aten.scatter.value to TensorRT.

    Positional args are (input, dim, index, src_or_value). The unified
    impl.select.scatter handles the scalar-value vs. tensor-src distinction
    internally, so a single converter is registered for both overloads.
    """
    return impl.select.scatter(
        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
    )

Expand Down
86 changes: 12 additions & 74 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,100 +395,38 @@ def index_select(
return gather_layer.get_output(0)


def scatter(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: Union[TRTTensor, np.ndarray, torch.Tensor],
    src: Union[TRTTensor, int, float],
) -> TRTTensor:
    """Scatter ``src`` into ``input`` along ``dim`` at positions in ``index``.

    Unifies aten.scatter.value (scalar ``src``) and aten.scatter.src
    (tensor ``src``): a scalar is broadcast to a tensor of ``index``'s shape
    and cast to ``input``'s dtype before the TensorRT scatter layer is added.

    Raises:
        RuntimeError: if ``input`` is not a TRTTensor, the ranks of
            ``input`` and ``index`` differ, or an ``index`` extent exceeds
            the corresponding ``input`` extent on a non-scatter dimension.
    """
    if not isinstance(input, TRTTensor):
        raise RuntimeError(
            f"scatter_tensor received input {input} that is not part "
            "of the TensorRT region!"
        )
    input_shape = input.shape
    index_shape = index.shape
    # Capture the shape before index may be replaced by a TRT tensor below.
    index_shape_list = list(index_shape)
    if not isinstance(index, TRTTensor):
        index = get_trt_tensor(ctx, index, f"_index_tensor")
    if len(input_shape) != len(index_shape):
        raise RuntimeError("The number of dimensions of input and index should be equal")
    dim = get_positive_dim(dim, len(input_shape))
    dynamic_shape = has_dynamic_shape(input.shape)
    if dynamic_shape:
        # Scattering along a dynamic (-1) dimension is not supported.
        assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

    input_dims = len(input_shape)
    for i in range(input_dims):
        if i != dim and (index_shape[i] >= input.shape[i]):
            raise RuntimeError(
                f"cannot have index size greater than the input size along dimension {dim}"
            )

    src_tensor = src
    # scatter.value: broadcast the scalar to index's shape, then cast so the
    # scatter layer's updates match input's dtype.
    if isinstance(src, (int, float)):
        src_tensor = get_trt_tensor(
            ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
        )
        src_tensor = cast_trt_tensor(
            ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
        )
    # scatter.src: constant (numpy/torch) sources still need TRT conversion.
    elif not isinstance(src, TRTTensor):
        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

    scatter_layer = ctx.net.add_scatter(
        input, index, src_tensor, trt.ScatterMode.ELEMENT
    )
    scatter_layer.axis = dim
    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
    return scatter_layer.get_output(0)


def scatter_src(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Shape,
index: Shape,
src: TRTTensor,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"scatter_tensor received input {input} that is not part "
"of the TensorRT region!"
)
input_shape = input.shape
index_shape = index.shape
src_shape = src.shape
if not (isinstance(index, TRTTensor)):
index = get_trt_tensor(ctx, index, f"_index_tensor")
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
if len(index_shape) != len(src_shape):
raise RuntimeError(f"The no of dimensions of src and index should be equal")

input_dims = len(input_shape)
dim = get_positive_dim(cast(int, dim), input_dims)
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

for i in range(0, input_dims):
if i != dim and (index_shape[i] >= input.shape[i]):
raise RuntimeError(
f"cannot have index size greater than the input size along dimension {dim}"
)
input_dtype = input.dtype
# required for cases where src is a constant
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
if input_dtype != src_dtype:
raise RuntimeError(f"The type of input and src should be made")
src_tensor = src
if not (isinstance(src, TRTTensor)):
# scatter.value
if isinstance(src, int) or isinstance(src, float):
src_tensor = get_trt_tensor(
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
)
src_tensor = cast_trt_tensor(
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
)
# scatter.src
elif not (isinstance(src, TRTTensor)):
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

scatter_layer = ctx.net.add_scatter(
Expand Down
11 changes: 6 additions & 5 deletions tests/py/dynamo/conversion/test_scatter_aten.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from harness import DispatchTestCase
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestScatterValueConverter(DispatchTestCase):
@parameterized.expand(
Expand Down Expand Up @@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase):
@parameterized.expand(
[
(
"scatter_zero_dim_indexOne_constant_src",
"scatter_zero_dim_indexOne_src",
0,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
),
(
"scatter_zero_dim_indexTwo_constant_src",
"scatter_zero_dim_indexTwo_src",
0,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
),
(
"scatter_one_dim_indexOne_constant_src",
"scatter_one_dim_indexOne_src",
1,
torch.tensor([[0, 1, 2, 0]]),
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
),
(
"scatter_one_dim_indexTwo_constant_src",
"scatter_one_dim_indexTwo_src",
1,
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
Expand Down

0 comments on commit aea9d94

Please sign in to comment.