# Tensor Indexing

In [1]:
import torch

In [2]:
# Seed the global RNG so Restart & Run All reproduces every random tensor
# (and therefore every displayed output) in this notebook.
SEED = 0
torch.manual_seed(SEED)

batch_size = 5
features = 10
x = torch.rand((batch_size, features))  # uniform samples in [0, 1)
x

tensor([[2.5771e-01, 7.2878e-01, 2.5868e-01, 7.4583e-01, 2.5402e-01, 6.3390e-01,
         4.9045e-01, 1.8012e-01, 1.0019e-01, 1.3178e-01],
        [6.3346e-01, 3.4503e-01, 4.1611e-01, 4.9231e-01, 7.2861e-01, 9.0911e-02,
         3.5250e-04, 4.8047e-01, 8.6511e-01, 8.1382e-01],
        [7.9724e-01, 5.2502e-01, 6.6859e-01, 8.5960e-01, 8.2475e-01, 4.5662e-01,
         1.3679e-01, 4.6364e-02, 9.6118e-01, 7.0214e-01],
        [6.5021e-01, 5.8216e-01, 6.6198e-01, 5.4765e-01, 4.5686e-01, 3.2887e-01,
         1.3587e-01, 3.5749e-01, 7.7175e-01, 2.6793e-01],
        [8.5976e-01, 4.8549e-01, 5.7066e-01, 2.4329e-01, 6.3985e-01, 3.0036e-01,
         8.0435e-01, 7.0461e-01, 9.3837e-01, 2.9636e-01]])

In [3]:
x[0, :]  # first row: integer index on dim 0, full slice on dim 1

tensor([0.2577, 0.7288, 0.2587, 0.7458, 0.2540, 0.6339, 0.4904, 0.1801, 0.1002,
        0.1318])

In [4]:
x[..., 0]  # first column; `...` stands for all leading dims (here just the rows)

tensor([0.2577, 0.6335, 0.7972, 0.6502, 0.8598])

In [5]:
x[2, :5]  # third row, first five features (start index 0 is implicit)

tensor([0.7972, 0.5250, 0.6686, 0.8596, 0.8248])

In [6]:
# Index assignment writes in place. Work on a copy so `x` (and the outputs
# already shown for it above) is not silently changed when this cell runs.
x_updated = x.clone()
x_updated[0, 0] = 100
x_updated

tensor([[1.0000e+02, 7.2878e-01, 2.5868e-01, 7.4583e-01, 2.5402e-01, 6.3390e-01,
         4.9045e-01, 1.8012e-01, 1.0019e-01, 1.3178e-01],
        [6.3346e-01, 3.4503e-01, 4.1611e-01, 4.9231e-01, 7.2861e-01, 9.0911e-02,
         3.5250e-04, 4.8047e-01, 8.6511e-01, 8.1382e-01],
        [7.9724e-01, 5.2502e-01, 6.6859e-01, 8.5960e-01, 8.2475e-01, 4.5662e-01,
         1.3679e-01, 4.6364e-02, 9.6118e-01, 7.0214e-01],
        [6.5021e-01, 5.8216e-01, 6.6198e-01, 5.4765e-01, 4.5686e-01, 3.2887e-01,
         1.3587e-01, 3.5749e-01, 7.7175e-01, 2.6793e-01],
        [8.5976e-01, 4.8549e-01, 5.7066e-01, 2.4329e-01, 6.3985e-01, 3.0036e-01,
         8.0435e-01, 7.0461e-01, 9.3837e-01, 2.9636e-01]])

### Fancy indexing (selecting with lists or tensors of indices)

In [7]:
x = torch.arange(0, 10)  # 0, 1, ..., 9
indices = [2, 5, 8]      # a plain Python list of positions is a valid index
x[indices]

tensor([2, 5, 8])

In [8]:
x = torch.rand(3, 5)  # 3 rows x 5 columns of uniform [0, 1) samples
x

tensor([[0.9523, 0.9659, 0.9590, 0.1427, 0.4035],
        [0.7013, 0.6542, 0.6871, 0.6641, 0.2054],
        [0.6900, 0.6252, 0.3419, 0.5304, 0.0620]])

In [9]:
# Row and column index tensors are paired elementwise: picks x[1, 4] then x[0, 0].
rows, cols = torch.tensor([1, 0]), torch.tensor([4, 0])
x[rows, cols]

tensor([0.2054, 0.9523])

### More advanced indexing: boolean masks

In [10]:
x = torch.arange(0, 10)
x[torch.logical_or(x < 2, x > 8)]  # keep elements outside [2, 8]

tensor([0, 1, 9])

In [11]:
x[torch.logical_and(x > 2, x < 8)]  # keep elements strictly between 2 and 8

tensor([3, 4, 5, 6, 7])

In [12]:
x[x % 2 == 0]  # `%` on a tensor is torch.remainder; keeps the even values

tensor([0, 2, 4, 6, 8])

### Useful operations

In [13]:
torch.where(x <= 5, x * 2, x)  # double values up to 5, leave larger ones unchanged

tensor([ 0,  2,  4,  6,  8, 10,  6,  7,  8,  9])

In [14]:
torch.unique(torch.tensor([0, 0, 1, 2, 2, 3, 4]))  # sorted distinct values

tensor([0, 1, 2, 3, 4])

In [15]:
x.dim()  # number of dimensions (ndimension() is an alias of dim())

1

In [16]:
x.nelement()  # total element count (alias of numel())

10