Labels: enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix), good first issue, module: scatter & gather ops, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🚀 Feature
Teach `torch.Tensor.scatter_` to handle `index.size(d) > src.size(d)`.
Motivation
Currently, `torch.Tensor.scatter_` requires `index.size(d) <= src.size(d)` for all dimensions `d`, unless `src` is a scalar value rather than a tensor. This constraint seems artificial.
```python
import math
import torch

device = 'cuda'  # fails if device == 'cpu' (cf. #61854)
dim = 1
shape_src = 3, 5
shape_idx = 3, 7
index = torch.empty(shape_idx, dtype=torch.long, device=device).random_(0, shape_src[dim])

# succeeds: the scalar-src overload has no size constraint
output = torch.zeros(shape_src, device=device)
output.scatter_(dim, index, 1.0, reduce='add')
print(output)

# desired usage
# output = torch.zeros(shape_src, device=device)
# values = torch.rand(shape_src, device=device)
# output.scatter_(dim, index, values, reduce='add')  # <-- failure
# print(output)
```
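For reference, the tensor-src form rejects an oversized index up front with a shape error. A minimal CPU repro sketch (the exact error message varies across PyTorch versions):

```python
import torch

dim = 1
shape_src = 3, 5
shape_idx = 3, 7
index = torch.randint(0, shape_src[dim], shape_idx)
output = torch.zeros(shape_src)
values = torch.rand(shape_src)

try:
    # index.size(1) == 7 > values.size(1) == 5, so the shape check fails
    output.scatter_(dim, index, values, reduce='add')
    raised = False
except RuntimeError:
    raised = True
print('shape check raised:', raised)
```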
Alternatives
An inefficient workaround is to chunk the index tensor:
```python
output = torch.zeros(shape_src, device=device)
values = torch.ones(shape_src, device=device)
chunks = math.ceil(index.shape[dim] / output.shape[dim])
for index_chunk in torch.chunk(index, chunks, dim):
    output.scatter_(dim, index_chunk, values, reduce='add')
print(output)
```
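A single-pass alternative is to flatten to 1-D and use `torch.Tensor.index_add_`, which places no size constraint on the index relative to the destination. This sketch assumes the arguably natural semantics where `values` has the same shape as `index` (one addend per index entry); the helper names `rows` and `linear` are illustrative only:

```python
import torch

device = 'cpu'
dim = 1
shape_src = 3, 5
shape_idx = 3, 7
index = torch.empty(shape_idx, dtype=torch.long, device=device).random_(0, shape_src[dim])
values = torch.rand(shape_idx, device=device)  # one value per index entry

output = torch.zeros(shape_src, device=device)
# convert each (row, index) pair into a linear offset into output.view(-1)
rows = torch.arange(shape_idx[0], device=device).unsqueeze(1).expand(shape_idx)
linear = rows * shape_src[dim] + index
# index_add_ accumulates values at the linear offsets, in place
output.view(-1).index_add_(0, linear.reshape(-1), values.reshape(-1))
print(output)
```

This avoids the Python-level chunking loop at the cost of materializing the linear index tensor.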