[discussion] Support other index dtypes for indexing, take, put, scatter, scatter_reduce and other indexing functions in addition to int64: int8, uint8, int16, int32, uint16, etc. (without casting copies/reallocations) #61819
I checked briefly but could not find another issue about this. If one already exists, please feel free to close this one.
I propose supporting index dtypes other than the current int64 for torch.scatter_ / gather / index_select and other similar indexing functions. We often already have indices stored as int16, int32, or even uint8 (for example, when they are cluster IDs and we want to aggregate by cluster), so converting them to int64 just inflates memory usage: the int64 copy takes 8 bytes per index, versus 1 byte for uint8 or 2 bytes for int16.
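To put a number on that overhead, here is a small sketch using Python's standard `array` module as a stand-in for tensor storage (an assumption for illustration; it is not PyTorch code). It compares the bytes needed to hold one million hypothetical cluster IDs as uint8 (typecode `"B"`) versus int64 (typecode `"q"`). The helper name `index_bytes` is hypothetical.

```python
from array import array

def index_bytes(values, typecode):
    """Bytes used by an index buffer stored with the given array typecode."""
    buf = array(typecode, values)
    return buf.itemsize * len(buf)

# Hypothetical data: 1,000,000 cluster IDs in the range [0, 200).
cluster_ids = [i % 200 for i in range(1_000_000)]

narrow = index_bytes(cluster_ids, "B")  # uint8 storage: 1 byte per ID
wide = index_bytes(cluster_ids, "q")    # int64 storage: 8 bytes per ID
print(narrow, wide)  # 1000000 8000000
```

Since the narrow original usually stays alive while the int64 copy is used, peak memory for the indices is roughly 9x the uint8 size rather than 1x.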
Even if the underlying code wants to work only with int64, could it not upcast the indices to int64 internally, element by element, during the loop/iteration? Upcasting integers is cheap, and it might be done efficiently without templating on the index dtype (or perhaps it could be JIT-compiled?).
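A minimal sketch of the per-element upcast idea, assuming a 1-D scatter-add with the same accumulation rule as `torch.Tensor.scatter_add_` along dim 0 (`out[index[i]] += src[i]`). This is plain Python, not PyTorch's actual kernel, and the function name `scatter_add_narrow` is hypothetical. The index buffer keeps its narrow typecode, and each element is widened only at the moment it is read, so no full int64 copy of the indices is allocated.

```python
from array import array

def scatter_add_narrow(out, index, src):
    """Accumulate src[i] into out[index[i]], reading index in its own dtype.

    Hypothetical helper: `index` may be any integer array (e.g. typecode "B"
    for uint8 or "h" for int16). Each element becomes a Python int as it is
    read, so no wide copy of the whole index buffer is created.
    """
    if len(index) != len(src):
        raise ValueError("index and src must have the same length")
    n = len(out)
    for i in range(len(index)):
        j = index[i]  # per-element upcast: narrow array item -> Python int
        if not 0 <= j < n:
            raise IndexError(f"index {j} out of range for size {n}")
        out[j] += src[i]
    return out

cluster_ids = array("B", [0, 2, 1, 2, 0])  # uint8 cluster IDs
values = array("d", [1.0, 2.0, 3.0, 4.0, 5.0])
sums = scatter_add_narrow(array("d", [0.0] * 3), cluster_ids, values)
print(list(sums))  # [6.0, 3.0, 6.0]
```

In a compiled kernel the equivalent step would be loading the narrow integer and widening it in a register before computing the offset, which is why the upcast itself should add little cost compared with materializing a separate int64 tensor.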