
torch.scatter_: support index.size(d) > src.size(d) #63265

@ekhahniii

Description


🚀 Feature

Teach torch.Tensor.scatter_ to handle index.size(d) > src.size(d).

Motivation

Currently, torch.Tensor.scatter_ requires index.size(d) <= src.size(d) for every dimension d, unless src is passed as a scalar value rather than a tensor. This constraint seems artificial.

import math
import torch

device = 'cuda'  # fails if device == 'cpu' (cf. #61854)

dim = 1
shape_src = 3, 5
shape_idx = 3, 7

index = torch.empty(shape_idx, dtype=torch.long, device=device).random_(0, shape_src[dim])

# succeeds
output = torch.zeros(shape_src, device=device)
output.scatter_(dim, index, 1.0, reduce='add')
print(output)

# desired usage
# output = torch.zeros(shape_src, device=device)
# values = torch.rand(shape_src, device=device)
# output.scatter_(dim, index, values, reduce='add')   # <-- failure
# print(output)

Alternatives

An inefficient workaround is to chunk the index tensor along dim so that each chunk satisfies the size constraint:

output = torch.zeros(shape_src, device=device)
values = torch.ones(shape_src, device=device)
chunks = math.ceil(index.shape[dim] / output.shape[dim])
for index_chunk in torch.chunk(index, chunks, dim):
    output.scatter_(dim, index_chunk, values, reduce='add')
print(output)
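When the per-slot values can be materialized at index's shape up front (rather than reusing a smaller values tensor, as the feature request asks), today's scatter_add_ already accepts the larger index, since its constraint only requires index.size(d) <= src.size(d) with src shaped like index. A sketch of that route, mirroring the shapes above (the values tensor here is hypothetical, not from the issue):

```python
import torch

torch.manual_seed(0)

dim = 1
shape_src = 3, 5   # output shape
shape_idx = 3, 7   # index shape, larger than output along dim

index = torch.randint(0, shape_src[dim], shape_idx)

# One value per index entry, so src matches index's shape and the
# index.size(d) <= src.size(d) constraint is trivially satisfied.
values = torch.rand(shape_idx)

output = torch.zeros(shape_src)
output.scatter_add_(dim, index, values)
```

This sidesteps the request rather than implements it: it only helps when the values can be generated per index entry, not when a fixed (3, 5) values tensor must be reused across a (3, 7) index.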


Labels

enhancement (not as big of a feature, but technically not a bug; should be easy to fix), good first issue, module: scatter & gather ops, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
