scatter adding test cases for scatter.value and scatter.src
apbose committed May 2, 2024
1 parent a513478 commit ae08f41
Showing 4 changed files with 216 additions and 57 deletions.
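For context: aten.scatter.value writes a single scalar into the input at the positions selected by index along dim, while aten.scatter.src takes the values to write from a src tensor at the same index positions. A minimal eager-mode sketch of the two overloads the new tests target (the tensors below are illustrative, not taken from the test file):

import torch

inp = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.arange(1.0, 5.0).unsqueeze(0)

out_src = torch.scatter(inp, 0, index, src)      # scatter.src: values copied from a tensor
out_value = torch.scatter(inp, 0, index, 7.0)    # scatter.value: a scalar broadcast to every indexed position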
16 changes: 2 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -700,7 +700,7 @@ def aten_ops_scatter_value(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter_value(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)


@@ -713,19 +713,7 @@ def aten_ops_scatter_src(
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.scatter_src(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
)


def aten_ops_select(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.select(
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
)
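
The change in this file is forwarding args[3] to the impl calls. The aten schemas are scatter.value(Tensor self, int dim, Tensor index, Scalar value) and scatter.src(Tensor self, int dim, Tensor index, Tensor src), so the fourth positional argument carries the value or src operand. A quick eager check of the four-argument overloads (values are illustrative only):

import torch

x = torch.zeros(3, 5)
idx = torch.tensor([[0, 1, 2, 0]])

torch.ops.aten.scatter.value(x, 0, idx, 7.0)               # self, dim, index, value
torch.ops.aten.scatter.src(x, 0, idx, torch.ones(1, 4))    # self, dim, index, src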


48 changes: 34 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -21,6 +21,7 @@
set_layer_name,
)
from torch_tensorrt.fx.types import Shape, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -398,8 +399,8 @@ def scatter_value(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: Shape,
index: Shape,
dim: int,
index: Union[TRTTensor, np.ndarray, torch.Tensor],
value: float,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
@@ -409,26 +410,34 @@
)
input_shape = input.shape
index_shape = index.shape
index_shape_list = list(index.shape)
if not (isinstance(index, TRTTensor)):
index = get_trt_tensor(ctx, index, f"_index_tensor")
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
ranks = len(input_shape)
dim = get_positive_dim(cast(int, dim), ranks)
dim = get_positive_dim(dim, len(input_shape))
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

input_dims = len(input.shape)
input_dims = len(input_shape)
for i in range(0, input_dims):
if index[i] >= input.shape[i]:
if i != dim and (index_shape[i] >= input.shape[i]):
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
f"cannot have index size greater than the input size along dimension {dim}"
)
value_tensor = value * torch.ones(index.shape)

value_tensor = get_trt_tensor(
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
)
value_tensor = cast_trt_tensor(
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
)
scatter_layer = ctx.net.add_scatter(
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
input, index, value_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.set_axis(dim)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
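
Roughly, the scatter_value path builds an updates tensor by broadcasting the scalar to the index shape, casts it to the input dtype, and adds a TensorRT scatter layer in element mode. An eager-mode sketch of the equivalent computation (shapes assumed, for illustration only):

import torch

inp = torch.zeros(3, 4)
index = torch.tensor([[0, 1], [2, 0]])
value = 5.0

updates = torch.full(index.shape, value, dtype=inp.dtype)   # plays the role of value_tensor
out = torch.scatter(inp, 1, index, updates)                 # element-wise scatter along dim=1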
@@ -452,6 +461,8 @@ def scatter_src(
input_shape = input.shape
index_shape = index.shape
src_shape = src.shape
if not (isinstance(index, TRTTensor)):
index = get_trt_tensor(ctx, index, f"_index_tensor")
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
if len(index_shape) != len(src_shape):
@@ -465,14 +476,23 @@
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

for i in range(0, input_dims):
if index[i] >= input.shape[i]:
if i != dim and (index_shape[i] >= input.shape[i]):
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
f"cannot have index size greater than the input size along dimension {dim}"
)
input_dtype = input.dtype
# required for cases where src is a constant
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
if input_dtype != src_dtype:
raise RuntimeError(f"The type of input and src should be made")
src_tensor = src
if not (isinstance(src, TRTTensor)):
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")

scatter_layer = ctx.net.add_scatter(
input, index, src, trt.tensorrt.ScatterModekELEMENT
input, index, src_tensor, trt.ScatterMode.ELEMENT
)
scatter_layer.set_axis(dim)
scatter_layer.axis = dim
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
return out
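
The new dtype check in scatter_src mirrors a constraint eager PyTorch enforces as well: the source tensor must have the same dtype as the input, and the converter raises rather than casting. A quick illustration, not part of the diff:

import torch

inp = torch.zeros(3, 4, dtype=torch.float32)
idx = torch.tensor([[0, 1], [2, 0]])
src = torch.ones(2, 2, dtype=torch.float32)   # must match inp.dtype

out = torch.scatter(inp, 1, idx, src)
# torch.scatter(inp, 1, idx, src.to(torch.int64)) would fail with a dtype-mismatch error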
47 changes: 37 additions & 10 deletions tests/py/dynamo/conversion/harness.py
@@ -1,5 +1,4 @@
# type: ignore

import copy
import logging
import time
import unittest
@@ -14,6 +13,9 @@
# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

@@ -50,16 +52,20 @@ def setUp(self):
def run_test(
self,
mod,
inputs,
fx_inputs,
trt_interpreter_inputs,
interpreter,
rtol,
atol,
check_dtype=True,
):
with torch.no_grad():
cuda_inputs = []
for i in inputs:
cuda_inputs.append(i.cuda())
cuda_fx_inputs = []
cuda_trt_inputs = []
for i in trt_interpreter_inputs:
cuda_trt_inputs.append(i.cuda())
for i in fx_inputs:
cuda_fx_inputs.append(i.cuda())

mod.eval()
start = time.perf_counter()
@@ -73,13 +79,13 @@ def run_test(
)

mod = mod.cuda()
ref_outputs = mod(*cuda_inputs)
ref_outputs = mod(*cuda_fx_inputs)

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
outputs = trt_mod(*cuda_inputs)
outputs = trt_mod(*cuda_trt_inputs)
end_event.record()
torch.cuda.synchronize()
_LOGGER.info(
@@ -237,6 +243,25 @@ def run_test(
debug=True,
)

num_inputs = len(inputs)
trt_inputs = inputs
for num_input in range(num_inputs):
input = inputs[num_input]
if input.dtype in (torch.int64, torch.float64):
dtype_32bit = (
torch.int32 if (input.dtype == torch.int64) else torch.int64
)
# should we modify graph here to insert clone nodes?
# ideally not required
trt_inputs = (
list(trt_inputs[:num_input])
+ [
input.to(dtype_32bit),
]
+ list(trt_inputs[num_input + 1 :])
)

trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
input_specs = [Input.from_tensor(i) for i in inputs]

output_dtypes = None
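
The new loop in the harness narrows 64-bit inputs before they are handed to the TRT interpreter, while the reference PyTorch run keeps the originals. A standalone sketch of the intended narrowing (narrow_64bit_inputs is a hypothetical helper, not part of the harness; note that for a float64 input the hunk above selects torch.int64, which reads like a typo, so the sketch assumes float64 should map to torch.float32):

import torch

def narrow_64bit_inputs(inputs):
    # Narrow 64-bit tensors to 32-bit equivalents for the TRT interpreter run.
    narrowed = []
    for t in inputs:
        if t.dtype == torch.int64:
            narrowed.append(t.to(torch.int32))
        elif t.dtype == torch.float64:
            narrowed.append(t.to(torch.float32))   # assumed intent; the diff maps float64 to int64
        else:
            narrowed.append(t)
    return narrowed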
@@ -245,7 +270,7 @@
mod,
input_specs,
compilation_settings.device,
truncate_double=compilation_settings.truncate_double,
truncate_long_and_double=compilation_settings.truncate_long_and_double,
)

_LOGGER.debug(f"Compilation settings: {compilation_settings}")
Expand All @@ -254,13 +279,15 @@ def run_test(

interp = TRTInterpreter(
mod,
input_specs,
trt_input_specs,
output_dtypes=output_dtypes,
compilation_settings=compilation_settings,
)

super().run_test(
mod,
inputs,
trt_inputs,
interp,
rtol,
atol,
