Skip to content

Commit

Permalink
Changing lowering of select_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Dec 29, 2023
1 parent a0f6b07 commit 037fbcf
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,27 @@ def select_scatter_decomposition(
dim: int,
index: int,
) -> torch.Tensor:
input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
# input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
# check if the dim is less than shape
if input_tensor.shape[dim] < index:
raise AssertionError("The index should not be greater than dim")

# expanding the src_tensor to have the same dimension as input_tensor
src_tensor = torch.expand(torch.unsqueeze(src_tensor, dim), input_tensor.shape)
input_tensor_shape = input_tensor.shape
return torch.where(torch.eq((input_tensor_shape[dim]), index)), src_tensor, input_tensor)
# check if the dimension of the src tensor is same as slice tensor
select_tensor = torch.select(input_tensor, dim, index)
if select_tensor.shape != src_tensor.shape:
raise AssertionError(
"The slice tensor shape should be equal to the src tensor shape"
)

# make the index tensor
# input_tensor_shape = input_tensor.shape
# return torch.where(torch.eq((input_tensor_shape[dim]), index), src_tensor, input_tensor)

unbind_tensors = torch.unbind(input_tensor, dim)
unbind_tensors[index] = src_tensor
return torch.cat(unbind_tensors, dim)


def get_decompositions(
Expand Down

0 comments on commit 037fbcf

Please sign in to comment.