In [2]:
import torch

# 24 consecutive integers laid out as 4 blocks x 3 rows x 2 columns.
t = torch.arange(1, 25).reshape(4, 3, 2)
t

tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],

        [[ 7,  8],
         [ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16],
         [17, 18]],

        [[19, 20],
         [21, 22],
         [23, 24]]])

In [3]:
# t.size() returns the same torch.Size as t.shape.
t.size()

torch.Size([4, 3, 2])

In [None]:
# This means axis 0 has 4 elements,
# axis 1 has 3 elements,
# and axis 2 has 2 elements.

# torch.Size([(length of axis 0), (length of axis 1), (length of axis 2)])

In [None]:
# In this 3D tensor, axis 2 is the 'simplest' or 'least layered' axis:
# it runs parallel to the line through the elements with values 1 and 2.

# Axis 1 runs parallel to the line through the values 1, 3 and 5.

# Axis 0 is ALWAYS the 'most complicated' or 'most layered' (outermost) axis.
# Here it runs parallel to the line through the values 1, 7, 13 and 19.

## index_select - 2D

In [5]:
# Let's try an index_select() example:

# torch.tensor (lowercase) is the recommended factory for building a tensor
# from data; torch.Tensor is the legacy constructor and silently turned the
# int literals into floats. Float literals keep the same default floating
# dtype, so the displayed output is unchanged.
A = torch.tensor([[1., 2.],
                  [3., 4.]])

index = torch.tensor([0])

output = torch.index_select(A, 1, index)

output

# Axis specified is axis 1.

# A is a 2D matrix, so axis 1 is the 'least layered' axis,
# which means that it is the axis in parallel
# with the line formed by elements [1] and [2].

# Within those lines, elements with index 0 are selected.

tensor([[1.],
        [3.]])

In [6]:
# Same selection, now along axis 0 via the Tensor-method spelling:
# index 0 keeps only the first row of A.
output = A.index_select(0, index)
output

tensor([[1., 2.]])

## index_select - 3D

In [8]:
# 3D example: the values 1..8 laid out as 2 blocks x 2 rows x 2 columns.
A = torch.arange(1, 9)
A = A.reshape(2, 2, 2)

# Re-declare the selection index here so the 3D section does not silently
# depend on the `index` created back in the 2D section.
index = torch.tensor([0])

A

tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

In [12]:
# Pick index 0 along axis 2, using the Tensor-method form of index_select.
output = A.index_select(2, index)
output

# Axis 2 runs parallel to the line through elements [1] and [2];
# it is the 'least layered' axis, so each innermost row keeps only its
# first entry.

tensor([[[1],
         [3]],

        [[5],
         [7]]])

In [13]:
# Pick index 0 along axis 1.
output = A.index_select(1, index)
output

# Axis 1 runs parallel to the line through elements [1] and [3];
# selecting index 0 keeps only the first row of each block.

tensor([[[1, 2]],

        [[5, 6]]])

In [14]:
# Pick index 0 along axis 0.
output = A.index_select(0, index)
output

# Axis 0 runs parallel to the line through elements [1] and [5];
# it is the 'most layered' axis, so index 0 keeps only the first block.

tensor([[[1, 2],
         [3, 4]]])

## Unsqueeze 

In [24]:
# A 2x2 tensor holding the values 1..4; show its shape, then its contents.
t = torch.arange(1, 5).reshape(2, 2)
print(t.shape, t, sep="\n")

torch.Size([2, 2])
tensor([[1, 2],
        [3, 4]])


In [25]:
output = torch.unsqueeze(t, 0)
print(output.shape, output, sep="\n")

# A new axis of length 1 is inserted at position 0, so it becomes the
# outermost ('most layered') axis: the shape goes from (2, 2) to (1, 2, 2).

torch.Size([1, 2, 2])
tensor([[[1, 2],
         [3, 4]]])


In [26]:
output = torch.unsqueeze(t, 1)
print(output.shape, output, sep="\n")

# Here the length-1 axis is inserted at position 1: (2, 2) becomes (2, 1, 2).

torch.Size([2, 1, 2])
tensor([[[1, 2]],

        [[3, 4]]])


## Scientific convention

In [27]:
# Two rows, three columns, filled with the values 1..6.
t = torch.arange(1, 7).reshape(2, 3)
t

tensor([[1, 2, 3],
        [4, 5, 6]])

In [30]:
# The scientific convention is to describe a 2D matrix as an
# '(number of rows) by (number of cols)' matrix.

# In this case, t is a 2 by 3 matrix.

# And the 'shape' of t in PyTorch lists the axes in that same (rows, cols)
# order; t.size() returns the same torch.Size as t.shape:
t.size()

torch.Size([2, 3])