In [1]:
##################################################################################### 
# torch.gather example
######################################################################################

# https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather

# torch.gather is similar to np.choose (its closest NumPy analogue is np.take_along_axis)

In [2]:
######################################################################################
# 1-D gather: with dim=0 the result is out[i] = t[index[i]]
######################################################################################

import torch

# Source values: tensor([ 0, 10, 20, 30])
t = torch.arange(0, 40, step=10)

# Positions to read from `t`; an entry may repeat, and each must be < t.shape[dim].
index = torch.tensor([0, 1, 2, 3, 2, 1, 0])
dim = 0

# Function form of gather; the result has the same shape as `index`.
torch.gather(input=t, dim=dim, index=index)

tensor([ 0, 10, 20, 30, 20, 10,  0])

In [3]:
# Method form of the same lookup: Tensor.gather(dim, index).
t.gather(dim=dim, index=index)

tensor([ 0, 10, 20, 30, 20, 10,  0])

In [4]:
######################################################################################
# 2-D gather along dim=0: out[i][j] = t[index[i][j]][j]
######################################################################################

import torch

# Input of shape (3, 3); each value is 10*row + col, so the source position
# of every gathered element can be read directly from the output.
t = torch.tensor(
    [
        [0, 1, 2],
        [10, 11, 12],
        [20, 21, 22],
    ]
)

# Row selector for every output element: index[i][j] < t.shape[dim] = 3
index = torch.tensor(
    [
        [0, 0, 1],
        [1, 0, 1],
        [2, 2, 2],
    ]
)

dim = 0
t.gather(dim=dim, index=index)

tensor([[ 0,  1, 12],
        [10,  1, 12],
        [20, 21, 22]])

In [5]:
# Gather along dim=1 of the same `t`: out[i][j] = t[i][index[i][j]].
# Here `index` has fewer columns than `t`; the result takes the shape of `index`.
dim = 1
index = torch.tensor(
    [
        [0, 0],
        [1, 0],
        [2, 2],
    ]
)
t.gather(dim=dim, index=index)

tensor([[ 0,  0],
        [11, 10],
        [22, 22]])

In [6]:
######################################################################################
# 3D tensor gather example
######################################################################################

In [7]:
# Gather values along axis 0 (dim=0)

In [8]:
import torch

# With dim=0 the first subscript is looked up through the index:
#   out[i][j][k] = t[index0[i][j][k]][j][k]
dim = 0

# Input of shape (2, 3, 4)
t = torch.tensor(
    [
        [[5, 5, 5, 5], [7, 7, 3, 8], [7, 8, 9, 8]],
        [[5, 5, 5, 5], [1, 4, 8, 8], [1, 5, 1, 0]],
    ]
)

# Same shape as `t`; every entry satisfies index0[i][j][k] < t.shape[dim] = 2
index0 = torch.tensor(
    [
        [[1, 1, 1, 0], [0, 1, 0, 0], [0, 0, 1, 1]],
        [[1, 0, 1, 1], [0, 0, 1, 0], [1, 1, 0, 1]],
    ]
)

t.gather(dim=dim, index=index0)

tensor([[[5, 5, 5, 5],
         [7, 4, 3, 8],
         [7, 8, 1, 0]],

        [[5, 5, 5, 5],
         [7, 7, 8, 8],
         [1, 5, 9, 0]]])

In [9]:
# Gather values along axis 1 (dim=1)

In [10]:
import torch

# With dim=1 the middle subscript is looked up through the index:
#   out[i][j][k] = t[i][index1[i][j][k]][k]
dim = 1

# Input of shape (2, 3, 4)
t = torch.tensor(
    [
        [[1, 2, 0, 8], [1, 2, 6, 9], [1, 2, 9, 9]],
        [[5, 9, 7, 9], [2, 2, 2, 6], [3, 2, 9, 2]],
    ]
)

# Same shape as `t`; every entry satisfies index1[i][j][k] < t.shape[dim] = 3
index1 = torch.tensor(
    [
        [[1, 1, 0, 0], [2, 0, 0, 1], [0, 2, 2, 2]],
        [[0, 2, 0, 2], [2, 1, 2, 1], [2, 1, 0, 2]],
    ]
)

t.gather(dim=dim, index=index1)

tensor([[[1, 2, 0, 8],
         [1, 2, 0, 9],
         [1, 2, 9, 9]],

        [[5, 2, 7, 2],
         [3, 2, 9, 6],
         [3, 2, 7, 2]]])

In [11]:
# Gather values along axis 2 (dim=2)

In [12]:
import torch

dim = 2

# Seed the global RNG so that Restart & Run All regenerates the same random
# tensors every time; without it the displayed values change on each run.
SEED = 0
torch.manual_seed(SEED)

low = 0
high = 10
size = (2,3,4)
# Integers drawn from [low, high) with shape `size`.
t = torch.randint(low, high, size)

# for simplicity: constant first rows make the gathered result easy to check,
# since any index into them returns the same value
t[0,0,:] = 0
t[1,0,:] = 1
t

tensor([[[0, 0, 0, 0],
         [6, 6, 3, 0],
         [9, 3, 8, 1]],

        [[1, 1, 1, 1],
         [3, 6, 5, 7],
         [6, 9, 9, 3]]])

In [13]:
# Random indices for gathering along `dim`: every entry must satisfy
# 0 <= index2[i][j][k] < t.shape[dim], hence the exclusive upper bound below.
low = 0
high = t.shape[dim]
# Take the shape from `t` instead of repeating the literal (2, 3, 4), so the
# index still matches the input if the shape of `t` is changed.
size = tuple(t.shape)
index2 = torch.randint(low, high, size)
index2 # same shape as t

tensor([[[1, 1, 1, 3],
         [0, 0, 0, 2],
         [1, 2, 1, 1]],

        [[2, 0, 2, 0],
         [3, 0, 3, 2],
         [2, 3, 3, 2]]])

In [14]:
# Last-axis lookup (dim is 2 here): out[i][j][k] = t[i][j][index2[i][j][k]]
t.gather(dim=dim, index=index2)

tensor([[[0, 0, 0, 0],
         [6, 6, 6, 3],
         [3, 8, 3, 3]],

        [[1, 1, 1, 1],
         [7, 3, 7, 5],
         [9, 3, 3, 9]]])