In [21]:
import jax
from jax import numpy as jnp

In [22]:
# A 1-D array holding the integers 0, 1, ..., 11 (the stop value is exclusive).
x = jnp.arange(0, 12)

In [23]:
# Display the vector built above: 12 consecutive integers starting at 0 (int32 here, see output).
x

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [24]:
# Total number of elements in the array (12 here).
x.size

12

In [25]:
# Shape as a tuple of per-axis lengths; a 1-D array gives a 1-tuple.
x.shape

(12,)

In [26]:
# View the same 12 elements as a 3x4 matrix; reshaping leaves the element count unchanged.
X = jnp.reshape(x, (3, 4))
X.size

12

In [27]:
# Display the 3x4 matrix: the values 0..11 laid out row by row.
X

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [28]:
# A rank-3 array of zeros with shape (2, 3, 4); the default dtype here is float32 (see output).
jnp.zeros((2,3,4))

Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

In [29]:
# Same shape as above, filled with ones instead of zeros.
jnp.ones((2,3,4))

Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)

In [30]:
# JAX random functions take an explicit key argument. Drawing with the same key
# (seeded with 0 here) reproduces the same values.
k = jax.random.PRNGKey(seed=0)
jax.random.normal(k, shape=(3, 4))

Array([[ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
       [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
       [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32)

In [31]:
# Build an array from nested Python lists (outer list = rows); all-int input infers int32 (see output).
jnp.array([[2,1,4,3],[1,2,3,4],[4,3,2,1]])

Array([[2, 1, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)

## Indexing and Slicing

In [32]:
# X[-1] is the last row; X[1:3] selects rows 1 and 2 (the stop index is exclusive).
X[-1], X[1:3]

(Array([ 8,  9, 10, 11], dtype=int32),
 Array([[ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=int32))

In [33]:
# Per the JAX docs, arrays are not updated in place: `.at[row, col].set(value)`
# returns an updated array, which is bound here to a new name.
X_new_1 = X.at[1,2].set(17)
X_new_1

Array([[ 0,  1,  2,  3],
       [ 4,  5, 17,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [34]:
# Set every element of rows 0 and 1 to 12. A trailing full slice is implicit,
# so `.at[:2]` covers all columns of those rows.
X_new_2 = X_new_1.at[:2].set(12)
X_new_2

Array([[12, 12, 12, 12],
       [12, 12, 12, 12],
       [ 8,  9, 10, 11]], dtype=int32)

## Operations

In [35]:
# Elementwise e**x. x holds int32 values; the result comes back as float32 (see output).
jnp.exp(x)

Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
       5.4598148e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
       2.9809580e+03, 8.1030840e+03, 2.2026467e+04, 5.9874145e+04],      dtype=float32)

In [36]:
# Elementwise arithmetic on two same-shape vectors: a float vector `u` and an
# int vector `v`; every result below comes back as float (see output).
# Fresh names are used instead of rebinding `x`/`y`. Rebinding `x` would clobber
# the 12-element `jnp.arange(12)` vector, so re-running earlier cells such as
# `x.reshape(3,4)` (would raise) or `jnp.exp(x)` would break or change silently.
u = jnp.array([1.0, 2, 4, 8])
v = jnp.array([2, 2, 2, 2])
u + v, u - v, u * v, u / v, u ** v

(Array([ 3.,  4.,  6., 10.], dtype=float32),
 Array([-1.,  0.,  2.,  6.], dtype=float32),
 Array([ 2.,  4.,  8., 16.], dtype=float32),
 Array([0.5, 1. , 2. , 4. ], dtype=float32),
 Array([ 1.,  4., 16., 64.], dtype=float32))

In [37]:
# Two 3x4 float matrices, then stacked along rows (axis 0) and along columns (axis 1).
# For 2-D inputs, vstack/hstack are the same as concatenate with axis=0/axis=1.
X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.array([[2.0, 1.0, 4.0, 3.0],
               [1.0, 2.0, 3.0, 4.0],
               [4.0, 3.0, 2.0, 1.0]])
jnp.vstack((X, Y)), jnp.hstack((X, Y))

(Array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 2.,  1.,  4.,  3.],
        [ 1.,  2.,  3.,  4.],
        [ 4.,  3.,  2.,  1.]], dtype=float32),
 Array([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
        [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
        [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]], dtype=float32))

In [38]:
# Elementwise comparison: a boolean array, True where X and Y hold equal values.
X == Y

Array([[False,  True, False,  True],
       [False, False, False, False],
       [False, False, False, False]], dtype=bool)

In [39]:
# Sum over all elements; the result is a 0-d Array (see output), not a Python float.
X.sum()

Array(66., dtype=float32)

## Broadcasting

In [40]:
# A column vector (shape (3, 1)) and a row vector (shape (1, 2)), built by
# inserting a length-1 axis with None; their sum broadcasts to shape (3, 2).
a = jnp.arange(3)[:, None]
b = jnp.arange(2)[None, :]
a, b

(Array([[0],
        [1],
        [2]], dtype=int32),
 Array([[0, 1]], dtype=int32))

In [41]:
# Broadcasting: a (3, 1) and b (1, 2) are both stretched to (3, 2) before the elementwise add.
a + b

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

## Saving Memory

In [42]:
# Remember the identity of the object currently bound to Y, so the next cells can
# check whether the assignment `Y = Y + X` reuses it.
before = id(Y)

In [43]:
# Rebind the name Y to a freshly computed sum. Not idempotent: every re-run adds X once more.
Y = jnp.add(Y, X)

In [44]:
# False: `Y + X` produced a new array object and `Y` was rebound to it; the old object was not reused.
# NOTE(review): per the JAX docs, JAX arrays are immutable, so there is no in-place `Y[:] = ...`
# form here; `.at[...].set(...)` likewise returns a new array.
id(Y) == before

False

## Conversion to Other Python Objects

In [45]:
# Round-trip: jax.device_get gives back a NumPy ndarray, and jax.device_put turns
# that NumPy array back into a JAX array (see the types in the output).
A = jax.device_get(X)
B = jax.device_put(A)

tuple(map(type, (A, B)))

(numpy.ndarray, jaxlib.xla_extension.ArrayImpl)

In [48]:
# Convert a one-element array to a plain Python scalar with .item().
# Named `scalar_arr` rather than reusing `a`. The broadcasting section binds `a` to
# a (3, 1) integer array, so rebinding it here would make re-running `a + b` there
# produce a different result.
scalar_arr = jnp.array([3.5])
scalar_arr, scalar_arr.item()

(Array([3.5], dtype=float32), 3.5)