Skip to content

Commit

Permalink
Linting fix
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed May 2, 2024
1 parent a7ef253 commit a513478
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
32 changes: 15 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,25 +409,25 @@ def scatter_value(
)
input_shape = input.shape
index_shape = index.shape
if (len(input_shape) != len(index_shape)):
raise RuntimeError(
f"The no of dimensions of input and index should be equal"
)
if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
ranks = len(input_shape)
dim = get_positive_dim(cast(int, dim), ranks)
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

input_dims = len(input.shape)
for i in range(0, input_dims):
if index[i] >= input.shape[i]:
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
)
value_tensor = value * torch.ones(index.shape)
scatter_layer = ctx.net.add_scatter(input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT)
scatter_layer = ctx.net.add_scatter(
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
)
scatter_layer.set_axis(dim)
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
Expand All @@ -452,28 +452,26 @@ def scatter_src(
input_shape = input.shape
index_shape = index.shape
src_shape = src.shape
if (len(input_shape) != len(index_shape)):
raise RuntimeError(
f"The no of dimensions of input and index should be equal"
)
if (len(index_shape) != len(src_shape)):
raise RuntimeError(
f"The no of dimensions of src and index should be equal"
)

if len(input_shape) != len(index_shape):
raise RuntimeError(f"The no of dimensions of input and index should be equal")
if len(index_shape) != len(src_shape):
raise RuntimeError(f"The no of dimensions of src and index should be equal")

input_dims = len(input_shape)
dim = get_positive_dim(cast(int, dim), input_dims)
dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"

for i in range(0, input_dims):
if index[i] >= input.shape[i]:
raise RuntimeError(
f"cannot have index greater than the dimension length! {input.shape[dim]}"
)
scatter_layer = ctx.net.add_scatter(input, index, src, trt.tensorrt.ScatterModekELEMENT)
scatter_layer = ctx.net.add_scatter(
input, index, src, trt.tensorrt.ScatterModekELEMENT
)
scatter_layer.set_axis(dim)
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
out = scatter_layer.get_output(0)
Expand Down
11 changes: 4 additions & 7 deletions tests/py/dynamo/conversion/test_scatter_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self):
def forward(self, input, src):
return torch.ops.aten.scatter.value(input, dim, index, value)

input = [torch.zeros(3, 5, dtype = torch.int32)]
input = [torch.zeros(3, 5, dtype=torch.int32)]
self.run_test(
TestModule(),
input,
Expand All @@ -46,14 +46,11 @@ def __init__(self):

def forward(self, input, src):
return torch.ops.aten.scatter.src(input, dim, index, src)
src = [torch.arange(1, 11).reshape((2,5))]
input = torch.zeros(3, 5, dtype = src.dtype)

src = [torch.arange(1, 11).reshape((2, 5))]
input = torch.zeros(3, 5, dtype=src.dtype)
inputs = [input, src]
self.run_test(
TestModule(),
inputs,
)



0 comments on commit a513478

Please sign in to comment.