NumPy arrays can be indexed by other NumPy arrays or by Python lists. This is called fancy indexing. The sequence passed as the index in square brackets contains the indices of the elements to select. Have a look:

In [3]:
import numpy as np

# First let's create the array we will be working on:
A = np.linspace(2, 3, 11)
A

array([2. , 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3. ])

In [4]:
# Now let's select the elements at indices 3, 4 and 8 using another numpy array:
A[np.array([3, 4, 8])]

array([2.3, 2.4, 2.8])
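
The index array does not have to be one-dimensional. In the small sketch below, the 2x2 index array `idx` is our own illustrative choice. The result takes the shape of the index array, not the shape of `A`, and each entry is looked up independently:

```python
import numpy as np

A = np.linspace(2, 3, 11)

# A 2x2 index array (illustrative): the result takes the shape of the
# index array, not the shape of A
idx = np.array([[0, 1],
                [9, 10]])
result = A[idx]

print(result.shape)  # (2, 2)
# Each entry is looked up independently: result[i, j] == A[idx[i, j]]
print(result[1, 0] == A[9])  # True
```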

In [5]:
# We can define the array beforehand:
arr = np.array([3, 4, 8])
A[arr]

array([2.3, 2.4, 2.8])

In [6]:
# Now let's try a Python list:
lst = [3, 4, 8]
A[lst]

array([2.3, 2.4, 2.8])
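
Fancy indexing is more flexible than slicing: the indices may come in any order and may repeat. It also differs from slicing in how memory is handled. The sketch below (variable names are illustrative) uses `np.shares_memory` to show that a fancy-indexed result is a copy while a slice is a view. It also shows that assigning through a fancy index still writes into the original array:

```python
import numpy as np

A = np.linspace(2, 3, 11)

# Indices may appear in any order and may repeat
picked = A[[8, 3, 3]]
print(picked)  # [2.8 2.3 2.3]

# Fancy indexing returns a copy: modifying it leaves A untouched
picked[0] = 100.0
print(np.shares_memory(picked, A))  # False

# A basic slice, by contrast, is a view of A's data
view = A[3:5]
print(np.shares_memory(view, A))  # True

# Assigning through a fancy index writes into A itself
A[[0, 10]] = 0.0
print(A[0], A[10])  # 0.0 0.0
```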

Now let's look at an example with a multidimensional array. First, we create one from a function:

In [12]:
# np.fromfunction calls f with the row and column indices of every element
def f(x, y):
    return (x + 1) ** (y + 1)
B = np.fromfunction(f, (5, 5), dtype=int)
B

array([[   1,    1,    1,    1,    1],
       [   2,    4,    8,   16,   32],
       [   3,    9,   27,   81,  243],
       [   4,   16,   64,  256, 1024],
       [   5,   25,  125,  625, 3125]], dtype=int32)
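
`np.fromfunction` does not call `f` once per element. It calls it a single time and passes arrays holding the row and column index of every position, so `B[i, j]` equals `(i + 1) ** (j + 1)`. The sketch below checks this against `np.indices` and against an equivalent broadcasting expression. The names `bases` and `exponents` are our own:

```python
import numpy as np

def f(x, y):
    return (x + 1) ** (y + 1)

B = np.fromfunction(f, (5, 5), dtype=int)

# np.fromfunction passes two 5x5 arrays of row and column indices to f
rows, cols = np.indices((5, 5))
print(np.array_equal(B, (rows + 1) ** (cols + 1)))  # True

# The same table via broadcasting: a column vector of bases
# raised to a row vector of exponents
bases = np.arange(1, 6).reshape(5, 1)
exponents = np.arange(1, 6)
print(np.array_equal(B, bases ** exponents))  # True
```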

In [20]:
# Now let's create a 1-dimensional array that we will use to index the B array:
arr = np.array([2, 4])
arr

array([2, 4])

In [18]:
# Now let's index the B array with arr. This selects rows 2 and 4 (the third and fifth rows, since indexing starts at 0).
B[arr]

array([[   3,    9,   27,   81,  243],
       [   5,   25,  125,  625, 3125]], dtype=int32)

In [19]:
# And now let's select the columns 2 and 4:
B[:, arr]

array([[   1,    1],
       [   8,   32],
       [  27,  243],
       [  64, 1024],
       [ 125, 3125]], dtype=int32)
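
Selecting whole rows or columns with an index array amounts to picking entries along one axis. As a cross-check, `np.take` with its `axis` argument should return the same arrays. The shapes show that the index array replaces the axis it indexes:

```python
import numpy as np

def f(x, y):
    return (x + 1) ** (y + 1)

B = np.fromfunction(f, (5, 5), dtype=int)
arr = np.array([2, 4])

rows = B[arr]      # rows 2 and 4 -> shape (2, 5)
cols = B[:, arr]   # columns 2 and 4 -> shape (5, 2)

# np.take selects along the given axis and gives the same results
print(np.array_equal(rows, np.take(B, arr, axis=0)))  # True
print(np.array_equal(cols, np.take(B, arr, axis=1)))  # True
print(rows.shape, cols.shape)  # (2, 5) (5, 2)
```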

In [21]:
# Index arrays for both axes are paired element by element, so this selects the elements at (2, 2) and (4, 4):
B[arr, arr]

array([  27, 3125], dtype=int32)
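
Because the two index arrays are paired, `B[arr, arr]` does not return the 2x2 submatrix at rows 2, 4 and columns 2, 4. To get every row/column combination, `np.ix_` builds an open mesh from the index arrays. Inserting a new axis into the row index with `np.newaxis` and relying on broadcasting is equivalent:

```python
import numpy as np

def f(x, y):
    return (x + 1) ** (y + 1)

B = np.fromfunction(f, (5, 5), dtype=int)
arr = np.array([2, 4])

# Paired: picks B[2, 2] and B[4, 4] only
paired = B[arr, arr]
print(paired)  # [  27 3125]

# All combinations: the 2x2 submatrix at rows 2, 4 and columns 2, 4
sub = B[np.ix_(arr, arr)]
print(sub)
# [[  27  243]
#  [ 125 3125]]

# Same result via broadcasting a column of row indices against the column indices
print(np.array_equal(sub, B[arr[:, np.newaxis], arr]))  # True
```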

EXERCISE

Create a 1-dimensional array of the numbers from 0 to 100 (inclusive) that are divisible by 5, using the arange function. Then use another NumPy array to select the first, third and last elements.

SOLUTION

In [28]:
X = np.arange(0, 101, 5)
print(X)
arr = np.array([0, 2, -1])
print(X[arr])

[  0   5  10  15  20  25  30  35  40  45  50  55  60  65  70  75  80  85
  90  95 100]
[  0  10 100]
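
The solution uses -1 for the last element. Negative entries in an index array count from the end, just as in ordinary indexing, so -1 and `X.size - 1` select the same element. The sketch below checks this and shows that an index outside the array's range raises `IndexError`. The variable names are illustrative:

```python
import numpy as np

X = np.arange(0, 101, 5)  # 21 elements: 0, 5, ..., 100

# -1 selects the same element as X.size - 1
first_third_last = X[np.array([0, 2, -1])]
same = X[np.array([0, 2, X.size - 1])]
print(np.array_equal(first_third_last, same))  # True

# Any index outside -X.size .. X.size - 1 raises IndexError
try:
    X[np.array([0, 21])]
    out_of_range_raised = False
except IndexError:
    out_of_range_raised = True
print(out_of_range_raised)  # True
```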
