### Understanding the `torch.gather` Functionality

Basic Functionality:

In [44]:
import torch

# torch.gather picks values from `input` along `dim` using `index`:
#
#   out[i][j] = input[index[i][j]][j]  # if dim == 0
#   out[i][j] = input[i][index[i][j]]  # if dim == 1

# gather needs an integer (int64) index, so build it as one directly
# instead of creating a float tensor and casting it at the call site.
indexes = torch.tensor([[1, 1],
                        [0, 0]], dtype=torch.long)
tensor = torch.tensor([[1., 1.],
                       [2., 2.]])  # size n_link x 2

# dim=0: out[i][j] = tensor[indexes[i][j]][j], so the two rows swap here.
torch.gather(tensor, 0, indexes)



tensor([[2., 2.],
        [1., 1.]])

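If the index tensor is smaller than the input tensor, the output simply takes the shape of the index tensor: columns that have no entry in the index are absent from the output. The index values themselves can still point at any entry of the input along the gather dimension.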



In [51]:
indexes = torch.Tensor([[1],
                        [0]])
tensor = torch.Tensor([[1,2],
                       [3,4]]) 

''' 
If doing lookup on dim = 0
 ___                                   ___
|                                         |
|  T[indexes[0,0],0] , T[indexes[0,1],1]  |
|                                         |
|  T[indexes[1,0],0] , T[indexes[1,1],1]  |
|___                                   ___|

The last column is not defined in the index, so it is ignored

'''

print(torch.gather(tensor, 0, indexes.long()))
print(torch.gather(tensor, 1, indexes.long()))

tensor([[3.],
        [1.]])
tensor([[2.],
        [3.]])


### For the model implementation

The input tensor (`base`) is the link hidden-state matrix.
The link-ID list is the flattened list of paths, where each path is given by the sequence of links that make it up
- e.g. [0,1,1,2,3,...]

We want to map the link state array to an array holding the state of every link along the paths: [s_0,s_1,s_1,s_2,s_3]

Below is how we accomplish that

In [52]:
# 4 x 6 matrix of zeros whose first column is filled with one value per row.
base = torch.zeros(4, 6)
base[:, 0] = torch.tensor([1.0, 2.0, 2.0, 4.0])

In [53]:
# 4 x 6 integer index: each row repeats a single id across all 6 columns.
link_ids = torch.tensor([[link] * 6 for link in (0, 3, 2, 1)])

In [54]:
# Along dim 0: out[i][j] = base[link_ids[i][j]][j]
base.gather(dim=0, index=link_ids)

tensor([[1., 0., 0., 0., 0., 0.],
        [4., 0., 0., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0.]])

### Index Put

In [66]:
# 3 x 6 matrix of zeros whose first column holds one value per row.
base = torch.zeros(3, 6)
base[:, 0] = torch.tensor([1., 2., 3.])
link_ids = torch.tensor([[0, 0, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1],
                         [2, 2, 2, 2, 2, 2],
                         [1, 1, 1, 1, 1, 1]])

# Row k of state_array is row link_ids[k] of base (every column of
# link_ids[k] holds the same id).
state_array = torch.gather(base, 0, link_ids)
print(state_array)

# Row k of state_array is written to path_array[p_id[k], s_id[k]].
# The index tensors are created as int64 directly (no float-then-cast).
p_id = torch.tensor([0, 0, 1, 1])
s_id = torch.tensor([0, 1, 0, 1])

# In-place write: path_array[p_id[k], s_id[k], :] = state_array[k, :]
path_array = torch.zeros(2, 2, 6)
path_array.index_put_((p_id, s_id), state_array)
# Print explicitly so the cell shows the filled array together with its shape.
print(path_array, path_array.shape)


tensor([[1., 0., 0., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0.],
        [3., 0., 0., 0., 0., 0.],
        [2., 0., 0., 0., 0., 0.]])
tensor([[[1., 0., 0., 0., 0., 0.],
         [2., 0., 0., 0., 0., 0.]],

        [[3., 0., 0., 0., 0., 0.],
         [2., 0., 0., 0., 0., 0.]]]) torch.Size([2, 2, 6])


In [65]:
# Show the two index tensors that were passed together to index_put_.
print([p_id, s_id])

[tensor([0, 0, 1, 1]), tensor([0, 1, 0, 1])]
