# Advanced Indexing

Advanced indexing can be thought of as indexing a tensor with one or more tensors.

Assume we have a 3x3 matrix `M` and an index tensor `idx` that selects one column for each row.

In [None]:
import torch
M = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

idx = torch.tensor([2, 1, 1]) # targeting 3, 5, 8.

The following does not produce the expected result:

In [10]:
print(M[:, idx])

tensor([[3, 2, 2],
        [6, 5, 5],
        [9, 8, 8]])


What it actually does is select whole columns: the result satisfies `new_M[:, k] = M[:, idx[k]]`.
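
As a minimal sketch of the column-selection behaviour described above, the result of `M[:, idx]` can be rebuilt one column at a time and compared against `torch.index_select`. The names `picked`, `manual`, and `selected` are illustrative and not part of the original notebook.

```python
import torch

M = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
idx = torch.tensor([2, 1, 1])

# M[:, idx] keeps every row and picks columns idx[0], idx[1], idx[2] in order.
picked = M[:, idx]

# Rebuild the same result column by column: manual[:, k] = M[:, idx[k]].
manual = torch.stack([M[:, c] for c in idx.tolist()], dim=1)

# torch.index_select along dim=1 performs the same column selection.
selected = torch.index_select(M, dim=1, index=idx)

print(torch.equal(picked, manual))    # True
print(torch.equal(picked, selected))  # True
```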

One correct way to get the expected result is to pair each row index with its column index:

In [16]:
print(M[torch.tensor([0, 1, 2]), idx])

tensor([3, 5, 8])


Alternatively, use `torch.Tensor.gather` [2], which returns the same values with shape (3, 1):

In [21]:
M.gather(dim=1, index=idx.unsqueeze(-1))

tensor([[3],
        [5],
        [8]])
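
The two approaches above should agree once the trailing dimension from `gather` is squeezed away. The following sketch checks that, using `torch.arange(M.size(0))` in place of the hard-coded row indices; the variable names `rows`, `by_index`, and `by_gather` are illustrative.

```python
import torch

M = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
idx = torch.tensor([2, 1, 1])

# One row index per row, so this also works for matrices with more rows.
rows = torch.arange(M.size(0))

# Advanced indexing pairs rows[k] with idx[k]; the result has shape (3,).
by_index = M[rows, idx]

# gather needs an index with the same number of dims as M, hence unsqueeze;
# squeeze(-1) removes the extra dimension again.
by_gather = M.gather(dim=1, index=idx.unsqueeze(-1)).squeeze(-1)

print(by_index)                          # tensor([3, 5, 8])
print(torch.equal(by_index, by_gather))  # True
```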

To understand this introductory example, you need to understand *advanced indexing* [1].

## When indices are next to each other in dimensions
Given example data x:

In [3]:
x = torch.arange(120).reshape(2, 3, 4, 5)
print(x)

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])


Assume we want to use `[[0, 1, 0, 1], [2, 3, 2, 3]]` to index dim=2 (the third dimension), with one row of indices for each entry along the batch dimension (dim=0).

We need to construct advanced indices `idx0, idx1, idx2` for `dim=0, 1, 2` such that

`new_x[i, j, k, :] = x[ idx0[i, j, k], idx1[i, j, k], idx2[i, j, k], : ]`

Thus `idx0, idx1, idx2` should each have shape (2, 3, 4). If they do not, they are broadcast to a common shape during advanced indexing.
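
To illustrate the broadcasting remark, here is a small sketch that passes index tensors of shapes (2, 1, 1), (1, 3, 1), and (2, 1, 4) directly, without expanding them first. The names `b`, `c`, and `r` are illustrative only.

```python
import torch

x = torch.arange(120).reshape(2, 3, 4, 5)

b = torch.arange(2).reshape(2, 1, 1)   # indices for dim=0
c = torch.arange(3).reshape(1, 3, 1)   # indices for dim=1
r = torch.tensor([[0, 1, 0, 1], [2, 3, 2, 3]]).reshape(2, 1, 4)  # indices for dim=2

# The three index shapes broadcast to a common shape.
common = torch.broadcast_shapes(b.shape, c.shape, r.shape)
print(common)  # torch.Size([2, 3, 4])

# Indexing with the unexpanded tensors relies on that broadcasting.
out = x[b, c, r]
print(out.shape)  # torch.Size([2, 3, 4, 5])

# Explicitly expanding each index first gives the same result.
expanded = x[b.expand(common), c.expand(common), r.expand(common)]
print(torch.equal(out, expanded))  # True
```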

In [24]:
index_shape = x.shape[:3]

idx0 = torch.arange(2).reshape(-1, 1, 1).expand(*index_shape)
print(idx0)

idx1 = torch.arange(3).reshape(1, -1, 1).expand(*index_shape)
print(idx1)

idx2 = torch.tensor([[0, 1, 0, 1], [2, 3, 2, 3]]).unsqueeze(1).expand(*index_shape)
print(idx2)

tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])
tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]]])
tensor([[[0, 1, 0, 1],
         [0, 1, 0, 1],
         [0, 1, 0, 1]],

        [[2, 3, 2, 3],
         [2, 3, 2, 3],
         [2, 3, 2, 3]]])


In [25]:
print(x[idx0, idx1, idx2, :])

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49]]],


        [[[ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])
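
As a cross-check, the same selection can probably be expressed with `gather` along dim=2, assuming the index is expanded to the full shape of `x`. This is only a sketch; `rows`, `index`, `via_gather`, and `via_indexing` are illustrative names.

```python
import torch

x = torch.arange(120).reshape(2, 3, 4, 5)
rows = torch.tensor([[0, 1, 0, 1], [2, 3, 2, 3]])  # shape (2, 4)

# gather computes out[i, j, k, l] = x[i, j, index[i, j, k, l], l],
# so the index must cover all four dimensions of x.
index = rows[:, None, :, None].expand(2, 3, 4, 5)
via_gather = x.gather(dim=2, index=index)

# Advanced indexing with broadcastable indices for dim=0, 1, 2.
b = torch.arange(2).reshape(2, 1, 1)
c = torch.arange(3).reshape(1, 3, 1)
via_indexing = x[b, c, rows[:, None, :]]

print(via_gather.shape)                        # torch.Size([2, 3, 4, 5])
print(torch.equal(via_gather, via_indexing))   # True
```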


## When indices are NOT next to each other in dimensions
Advanced indices that are next to each other "merge": the dimensions they index are replaced, in place, by the dimensions of the broadcast index shape. This changes when the advanced indices are **not** next to each other:

In [26]:
shape = torch.tensor((2,3,4,5,6), dtype=torch.long)
tensor = torch.arange(shape.prod()).reshape(*shape.tolist())

idx = torch.randint(0, 1, (7, 8))  # high is exclusive, so idx is all zeros; only its shape matters here
print(tensor[:, idx].shape)
print(tensor[:, idx, idx].shape)
print(tensor[:, idx, :, idx].shape)

torch.Size([2, 7, 8, 4, 5, 6])
torch.Size([2, 7, 8, 5, 6])
torch.Size([7, 8, 2, 4, 6])


In the last case, the merged index dimensions could be placed either at the position of the first advanced index (dim=1) or at the position of the second (dim=3), which is ambiguous. NumPy and PyTorch resolve this by placing the merged dimensions at the front of the result.
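
If the front placement is not what you want, one possible workaround is to move the index dimensions afterwards with `movedim`. The sketch below assumes the desired layout keeps dim=0 first; `t`, `separated`, and `moved` are illustrative names.

```python
import torch

t = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
idx = torch.zeros(7, 8, dtype=torch.long)

# Non-adjacent advanced indices: the (7, 8) index dims go to the front.
separated = t[:, idx, :, idx]
print(separated.shape)  # torch.Size([7, 8, 2, 4, 6])

# Move the two index dims to positions 1 and 2, right after dim=0 of t.
moved = separated.movedim((0, 1), (1, 2))
print(moved.shape)  # torch.Size([2, 7, 8, 4, 6])
```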

To understand the last case, let us look at a concrete example:

In [27]:
x = torch.arange(120).reshape(2, 3, 4, 5)
i = torch.tensor([0, 1, 0, 1, 0, 1]) # shape of (6,)
j = torch.tensor([0, 2, 0, 2, 0, 2]) # shape of (6,)
# dims of x: (2, 3, 4, 5); i indexes dim=1 and j indexes dim=3
y = x[:, i, :, j]
print(y.shape)

torch.Size([6, 2, 4])


Here it selects

`y[k, :, :] = x[:, i[k], :, j[k]]`
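
The relation above can be checked slice by slice. This is a minimal sketch that loops over `k` and compares each `y[k]` with the corresponding slice of `x`; the `matches` list is an illustrative helper.

```python
import torch

x = torch.arange(120).reshape(2, 3, 4, 5)
i = torch.tensor([0, 1, 0, 1, 0, 1])  # shape (6,), indexes dim=1
j = torch.tensor([0, 2, 0, 2, 0, 2])  # shape (6,), indexes dim=3
y = x[:, i, :, j]

# For each k, y[k] should equal the (2, 4) slice x[:, i[k], :, j[k]].
matches = [
    torch.equal(y[k], x[:, int(i[k]), :, int(j[k])])
    for k in range(len(i))
]
print(all(matches))    # True
print(y[1].tolist())   # [[22, 27, 32, 37], [82, 87, 92, 97]]
```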

In [28]:
print(x)
print(y)

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])
tensor([[[ 0,  5, 10, 15],
         [60, 65, 70, 75]],

        [[22, 27, 32, 37],
         [82, 

## References
1. https://numpy.org/doc/stable/user/basics.indexing.html
2. https://docs.pytorch.org/docs/stable/generated/torch.gather.html