In [1]:
import torch

## One-dimensional example

To focus on the basic idea, we will use a one-dimensional example.

In [2]:
data_input = torch.tensor([10,20,30,40])
scatter_index = torch.tensor([3,2,1,0])

In [4]:
output = torch.zeros(5, dtype=data_input.dtype)  # one more slot than inputs; index 4 is never written

Each element of the input tensor has an index (position); e.g. 10 is at index 0.
Looking up that same index in scatter_index gives the value 3.
That 3 is the position in the output where the input value 10 is placed.
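The mapping above amounts to output[scatter_index[i]] = data_input[i] for every i. The sketch below writes that rule as an explicit loop and checks it against `scatter_`; `manual`, `builtin`, and `indexed` are just illustrative names. It also shows that, for a 1D tensor whose indices are all distinct, plain index assignment gives the same result.

```python
import torch

data_input = torch.tensor([10, 20, 30, 40])
scatter_index = torch.tensor([3, 2, 1, 0])

# Explicit loop: the value at position i goes to position scatter_index[i]
manual = torch.zeros(5, dtype=data_input.dtype)
for i in range(data_input.size(0)):
    manual[int(scatter_index[i])] = data_input[i]

# Built-in scatter along dim 0, applied in place to a fresh zero tensor
builtin = torch.zeros(5, dtype=data_input.dtype).scatter_(0, scatter_index, data_input)

# With distinct indices in 1D, advanced-index assignment does the same thing
indexed = torch.zeros(5, dtype=data_input.dtype)
indexed[scatter_index] = data_input

print(manual.tolist())  # [40, 30, 20, 10, 0]
```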

In [5]:
output.scatter_(0, scatter_index, data_input)

tensor([40, 30, 20, 10,  0])

## Two-dimensional example

Now it becomes more challenging to wrap your head around.

One way to think about it: scatter takes each input value, keeps its index in every dimension except the one given by dim, and replaces the coordinate in that dimension with the matching value from the index tensor. Example: the input has value 4 at index [1,3,2], and the index tensor holds 0 at that same position [1,3,2]. Scattering with dim=1 then places 4 at index [1,0,2] in the output.
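To make the [1,3,2] to [1,0,2] example concrete, here is a minimal sketch. The (2, 4, 3) shape and the index tensor are assumptions chosen for illustration: the index reverses positions along dim 1 (index[i, j, k] = 3 - j), so index[1, 3, 2] is 0 and no two values collide. The helper `scatter_dim1_reference` is hypothetical; it spells out the dim=1 rule out[i][index[i][j][k]][k] = src[i][j][k] as a loop.

```python
import torch

# A (2, 4, 3) source that is zero except for the value 4 at [1, 3, 2]
src = torch.zeros(2, 4, 3, dtype=torch.long)
src[1, 3, 2] = 4

# index[i, j, k] = 3 - j, so index[1, 3, 2] == 0; each (i, k) column is a
# permutation of 0..3, so no two source values target the same slot
index = torch.tensor([3, 2, 1, 0]).view(1, 4, 1).repeat(2, 1, 3)

out = torch.zeros_like(src).scatter(1, index, src)
print(out[1, 0, 2].item())  # 4: only the dim-1 coordinate changed (3 -> 0)

def scatter_dim1_reference(src, index):
    """Hypothetical helper: out[i][index[i][j][k]][k] = src[i][j][k]."""
    out = torch.zeros_like(src)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            for k in range(index.size(2)):
                out[i, int(index[i, j, k]), k] = src[i, j, k]
    return out

# Compare the loop with the built-in on a source where every value is distinct
values = torch.arange(24).reshape(2, 4, 3)
same = torch.equal(scatter_dim1_reference(values, index),
                   torch.zeros_like(values).scatter(1, index, values))
print(same)  # True
```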

In [189]:
src_eg = torch.arange(1,11).reshape(2,5)
idx_eg = torch.tensor([[1,0,2,0]])

In [190]:
src_eg

tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])

In [193]:
idx_eg

tensor([[1, 0, 2, 0]])

In [191]:
out_eg = torch.zeros(3,5, dtype=src_eg.dtype)

We are scattering along dim 0, so each source value moves to a new row but keeps its column.

In [194]:
out_eg.scatter(0,idx_eg, src_eg)

tensor([[0, 2, 0, 4, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 3, 0, 0]])
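For dim=0 on a 2D tensor the rule is out[idx_eg[i][j]][j] = src_eg[i][j], looping only over positions that exist in the index. Because idx_eg has shape (1, 4), only src_eg[0, :4] (the values 1 to 4) is read; 5 and the whole second row are never used. The loop below is a sketch that reproduces the output above (`manual` is an illustrative name):

```python
import torch

src_eg = torch.arange(1, 11).reshape(2, 5)
idx_eg = torch.tensor([[1, 0, 2, 0]])

# Loop over the index tensor's shape, not the source's shape
manual = torch.zeros(3, 5, dtype=src_eg.dtype)
for i in range(idx_eg.size(0)):
    for j in range(idx_eg.size(1)):
        # dim=0: the row comes from the index, the column j is kept
        manual[int(idx_eg[i, j]), j] = src_eg[i, j]

print(manual.tolist())  # [[0, 2, 0, 4, 0], [1, 0, 0, 0, 0], [0, 0, 3, 0, 0]]
```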

In [199]:
idx_eg2 = torch.tensor([[1,0,3,2]])

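Since we are now scattering along dim 1, each value keeps its row and only its column changes. Note that scatter (without the trailing underscore) returns a new tensor, so out_eg is still all zeros and the dim 0 result above does not carry over.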

In [201]:
out_eg.scatter(1,idx_eg2, src_eg)

tensor([[2, 1, 4, 3, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])

Since the index has only one row, only elements from the first row of the source are used, and because it has four columns, only the first four of them (1 to 4). They keep their row index (0) and are assigned new columns.

If we add a second row to the index, it pairs with the second row of the source. Those values stay in the second row of the output (rows are not the dimension we scatter along), but their order within the row is rearranged according to the index values.

In [206]:
idx_eg3 = torch.tensor([[1,0,3,2],[2,4,3,1]])

In [207]:
out_eg.scatter(1,idx_eg3, src_eg)

tensor([[2, 1, 4, 3, 0],
        [0, 9, 6, 8, 7],
        [0, 0, 0, 0, 0]])
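Because no row of idx_eg3 repeats a column, this scatter can be undone with `torch.gather` using the same index along the same dimension: gather reads scattered[i][idx_eg3[i][j]], which is exactly where scatter wrote src_eg[i][j]. This inverse relationship only holds when the index has no duplicates within a row. A small check, with the tensors recreated so the snippet runs on its own:

```python
import torch

src_eg = torch.arange(1, 11).reshape(2, 5)
idx_eg3 = torch.tensor([[1, 0, 3, 2], [2, 4, 3, 1]])

# dim=1: row i stays row i, column j moves to column idx_eg3[i, j]
scattered = torch.zeros(3, 5, dtype=src_eg.dtype).scatter(1, idx_eg3, src_eg)

# gather along dim 1 reads back scattered[i, idx_eg3[i, j]]
recovered = torch.gather(scattered, 1, idx_eg3)

print(recovered.tolist())  # [[1, 2, 3, 4], [6, 7, 8, 9]]
```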

Here is one more two-dimensional example, this time scattering into a 4x4 output along each dimension in turn.

In [135]:
data_input = torch.tensor([[10,20,30,40],[100,200,300,400]])
scatter_index = torch.tensor([[0,1,2,3]])

In [136]:
output = torch.tensor([[11,12,13,14],[21,22,23,24],[31,32,33,34],[41,42,43,44]],dtype=data_input.dtype)

In [137]:
output

tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [138]:
data_input

tensor([[ 10,  20,  30,  40],
        [100, 200, 300, 400]])

In [139]:
scatter_index

tensor([[0, 1, 2, 3]])

In [140]:
output = torch.zeros((4,4),dtype=data_input.dtype)

In [141]:
print(f'Input value at [0,0]: {data_input[0,0]}')
print()
print(f'Index value at [0,0]: {scatter_index[0,0]}')
print()
print('Therefore 10 needs to be placed at position/idx 0 along whatever dimension we are working on')
print('Dimension 0')

Input value at [0,0]: 10

Index value at [0,0]: 0

Therefore 10 needs to be placed at position/idx 0 along whatever dimension we are working on
Dimension 0


In [142]:
scatter_index.size(0)<=data_input.size(0)

True
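The size comparison above is one of several requirements. According to the PyTorch documentation for `scatter_`, `self`, `index`, and `src` must have the same number of dimensions, `index.size(d) <= src.size(d)` for every `d`, `index.size(d) <= self.size(d)` for every `d != dim`, and index values must lie in `[0, self.size(dim))`. The helper below is a sketch that checks all of these at once; `scatter_shapes_ok` is a hypothetical name, not a PyTorch function.

```python
import torch

def scatter_shapes_ok(out, dim, index, src):
    """Hypothetical helper: checks the shape rules from the scatter_ docs."""
    if not (out.dim() == index.dim() == src.dim()):
        return False
    for d in range(index.dim()):
        if index.size(d) > src.size(d):
            return False
        if d != dim and index.size(d) > out.size(d):
            return False
    # Index values must be valid positions along `dim` of the output
    return bool(((index >= 0) & (index < out.size(dim))).all())

data_input = torch.tensor([[10, 20, 30, 40], [100, 200, 300, 400]])
scatter_index = torch.tensor([[0, 1, 2, 3]])
output = torch.zeros((4, 4), dtype=data_input.dtype)

print(scatter_shapes_ok(output, 0, scatter_index, data_input))  # True
print(scatter_shapes_ok(output, 1, scatter_index, data_input))  # True
# A 3-row index is taller than the 2-row source, so it fails
print(scatter_shapes_ok(output, 0, torch.zeros(3, 4, dtype=torch.long), data_input))  # False
```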

In [143]:
output[0,:]

tensor([0, 0, 0, 0])

In [144]:
output.scatter_(0, scatter_index, data_input)

tensor([[10,  0,  0,  0],
        [ 0, 20,  0,  0],
        [ 0,  0, 30,  0],
        [ 0,  0,  0, 40]])
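Only the first row of data_input (10 to 40) shows up, because scatter_index has a single row; 100 to 400 are never read. Giving the index a second row brings the second source row in. The sketch below uses a hypothetical second index row [3, 2, 1, 0], which sends 100 to 400 onto the anti-diagonal without colliding with the diagonal:

```python
import torch

data_input = torch.tensor([[10, 20, 30, 40], [100, 200, 300, 400]])
# Row 0 as in the notebook; row 1 is a hypothetical addition
two_row_index = torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]])

# dim=0: data_input[i, j] lands at row two_row_index[i, j], column j
result = torch.zeros((4, 4), dtype=data_input.dtype).scatter(0, two_row_index, data_input)
print(result.tolist())
# [[10, 0, 0, 400], [0, 20, 300, 0], [0, 200, 30, 0], [100, 0, 0, 40]]
```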

In [145]:
scatter_index.size(1)<=data_input.size(1)

True

In [146]:
output.scatter_(1, scatter_index, data_input)

tensor([[10, 20, 30, 40],
        [ 0, 20,  0,  0],
        [ 0,  0, 30,  0],
        [ 0,  0,  0, 40]])
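Rows 1 to 3 still hold the diagonal because `scatter_` (with the trailing underscore) modifies `output` in place, so this dim 1 call only overwrote row 0 of the previous cell's result. The sketch below compares a dim 1 scatter into a fresh zero tensor with the notebook's chained sequence (`fresh` and `chained` are illustrative names):

```python
import torch

data_input = torch.tensor([[10, 20, 30, 40], [100, 200, 300, 400]])
scatter_index = torch.tensor([[0, 1, 2, 3]])

# Fresh output: only row 0 receives values, at columns scatter_index[0, j]
fresh = torch.zeros((4, 4), dtype=data_input.dtype).scatter_(1, scatter_index, data_input)

# The notebook's sequence: dim 0 first, then dim 1 on the same tensor
chained = torch.zeros((4, 4), dtype=data_input.dtype)
chained.scatter_(0, scatter_index, data_input)
chained.scatter_(1, scatter_index, data_input)

print(fresh.tolist())    # [[10, 20, 30, 40], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
print(chained.tolist())  # [[10, 20, 30, 40], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 40]]
```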