
[discussion] Support other index dtypes for indexing, take, put, scatter, scatter_reduce and other indexing functions in addition to int64: int8, uint8, int16, int32, uint16 etc (without casting copies/reallocations) #61819

@vadimkantorov

I checked briefly but could not find an existing issue about this. If one already exists, please feel free to close this one.

I propose supporting index dtypes other than the current int64 for torch.scatter_ / gather / index_select and other similar indexing functions. Often we already have int16, int32 or even uint8 indices (for example, when they are cluster IDs and we want to aggregate by cluster), so converting them to int64 only inflates memory use.
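To make the memory cost concrete, here is a minimal sketch of the cast that is typically needed today before aggregating by cluster. The sizes, the variable names and the `nbytes` helper are made up for illustration; only the int64 cast followed by `scatter_add_` reflects the workaround described above.

```python
import torch

# Hypothetical example data: 1,000,000 values, each assigned to one of 200 clusters.
num_items, num_clusters = 1_000_000, 200
cluster_ids = (torch.arange(num_items) % num_clusters).to(torch.uint8)
values = torch.ones(num_items)


def nbytes(t):
    """Storage size of the tensor's elements, in bytes (illustrative helper)."""
    return t.element_size() * t.numel()


# Workaround: materialize an int64 copy of the indices before scattering.
index64 = cluster_ids.to(torch.int64)
sums = torch.zeros(num_clusters).scatter_add_(0, index64, values)

compact_bytes = nbytes(cluster_ids)   # 1 byte per index
inflated_bytes = nbytes(index64)      # 8 bytes per index
print(compact_bytes, inflated_bytes)
```

In this sketch the temporary int64 index tensor is eight times the size of the uint8 original, and it exists only to satisfy the dtype requirement.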

Even if the underlying code only wants to deal with int64, could it not upcast the indices internally during the loop/iteration? Upcasting integers is cheap and could perhaps be done efficiently without templating (or the kernels could perhaps be JIT-compiled).
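Until something like this exists inside the kernels, a user-level approximation of "upcast during iteration" is to cast the indices chunk by chunk, so that only a small int64 temporary is alive at any time. This is a sketch under assumptions, not the proposed implementation: `scatter_add_chunked` and `chunk_size` are hypothetical names, and the helper only handles the 1-D case.

```python
import torch


def scatter_add_chunked(out, index, src, chunk_size=65536):
    """Hypothetical helper: 1-D scatter-add that upcasts indices chunk by chunk.

    Only `chunk_size` indices are held as int64 at once, instead of a full
    int64 copy of `index`.
    """
    for start in range(0, index.numel(), chunk_size):
        stop = start + chunk_size
        idx64 = index[start:stop].to(torch.int64)  # small temporary
        out.scatter_add_(0, idx64, src[start:stop])
    return out


# Usage with compact int16 cluster IDs (made-up data).
index = (torch.arange(200_000) % 50).to(torch.int16)
src = torch.ones(200_000)
out = scatter_add_chunked(torch.zeros(50), index, src)
```

The trade-off is extra Python-level loop overhead and several kernel launches in exchange for a bounded temporary; a native implementation could avoid both by widening each index as it is read.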

cc @nairbv @mruberry @mikaylagawarecki

Metadata

    Labels

    actionable
    enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix)
    module: advanced indexing (Related to x[i] = y, index functions)
    module: scatter & gather ops
    module: type promotion (Related to semantics of type promotion)
    triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
