In [1]:
from jax import numpy as jnp

## Scalars

In [2]:
# Scalars are represented as rank-0 arrays; the usual arithmetic operators
# (+, *, /, **) apply to them directly.
x, y = jnp.array(3.0), jnp.array(2.0)

x + y, x * y, x / y, x ** y

(Array(5., dtype=float32, weak_type=True),
 Array(6., dtype=float32, weak_type=True),
 Array(1.5, dtype=float32, weak_type=True),
 Array(9., dtype=float32, weak_type=True))

## Vectors

In [3]:
# A vector is a 1-D array; arange(3) holds the integers 0, 1, 2.
x = jnp.arange(3)
x

Array([0, 1, 2], dtype=int32)

In [4]:
# Indexing is zero-based: position 2 is the third (last) element of x.
x[2]

Array(2, dtype=int32)

In [7]:
# Length, total element count, and shape of the vector x.
len(x), x.size, x.shape

(3, 3, (3,))

## Matrices

In [8]:
# Lay out the integers 0..5 as a matrix with 3 rows and 2 columns.
A = jnp.reshape(jnp.arange(6), (3, 2))

In [9]:
# Display the matrix A.
A

Array([[0, 1],
       [2, 3],
       [4, 5]], dtype=int32)

In [10]:
# Transpose: rows and columns are swapped.
A.T

Array([[0, 2, 4],
       [1, 3, 5]], dtype=int32)

In [11]:
# Symmetric matrix: A[i, j] equals A[j, i] for every entry, so the
# element-wise comparison with the transpose is True everywhere.
A = jnp.array([
    [1, 2, 3],
    [2, 0, 4],
    [3, 4, 5],
])
A == A.T

Array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]], dtype=bool)

## Tensors - order>2

In [12]:
# An order-3 tensor: the 24 integers 0..23 arranged with shape (2, 3, 4).
jnp.arange(24).reshape(2,3,4)

Array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]], dtype=int32)

## Basic Properties of Tensor Arithmetic

In [13]:
# 2x3 float32 matrix holding 0..5. B is bound to the very same array object
# (plain assignment makes no copy), so A + B doubles every entry.
A = jnp.reshape(jnp.arange(6, dtype=jnp.float32), (2, 3))
B = A
A, A + B

(Array([[0., 1., 2.],
        [3., 4., 5.]], dtype=float32),
 Array([[ 0.,  2.,  4.],
        [ 6.,  8., 10.]], dtype=float32))

In [14]:
# "*" multiplies element by element (Hadamard product); it is not a matrix product.
A * B

Array([[ 0.,  1.,  4.],
       [ 9., 16., 25.]], dtype=float32)

In [15]:
# Adding or multiplying by a Python scalar acts on every element and leaves
# the (2, 3, 4) shape of the tensor unchanged.
a = 2
X = jnp.reshape(jnp.arange(24), (2, 3, 4))
a + X, (a * X).shape

(Array([[[ 2,  3,  4,  5],
         [ 6,  7,  8,  9],
         [10, 11, 12, 13]],
 
        [[14, 15, 16, 17],
         [18, 19, 20, 21],
         [22, 23, 24, 25]]], dtype=int32),
 (2, 3, 4))

## Tensor reduction

In [16]:
# Float vector [0., 1., 2.] shown next to the sum of its elements.
x = jnp.arange(3, dtype=jnp.float32)
x, jnp.sum(x)

(Array([0., 1., 2.], dtype=float32), Array(3., dtype=float32))

In [17]:
# With no axis argument, sum() adds up every element and returns a scalar.
A.shape, A.sum()

((2, 3), Array(15., dtype=float32))

In [18]:
# Summing along axis 0 removes that axis: shape (2, 3) -> (3,).
A.shape, A.sum(axis=0).shape

((2, 3), (3,))

In [19]:
# Summing along axis 1 removes that axis: shape (2, 3) -> (2,).
A.shape, A.sum(axis=1).shape

((2, 3), (2,))

In [20]:
# Reducing over both axes at once gives the same value as the full sum.
A.sum(axis=(0, 1)) == A.sum()

Array(True, dtype=bool)

In [21]:
# The mean equals the total sum divided by the number of elements.
A.mean(), A.sum() / A.size

(Array(2.5, dtype=float32), Array(2.5, dtype=float32))

In [22]:
# Mean along axis 0 equals the axis-0 sum divided by the length of that axis.
A.mean(axis=0), A.sum(axis=0) / A.shape[0]

(Array([1.5, 2.5, 3.5], dtype=float32), Array([1.5, 2.5, 3.5], dtype=float32))

## Non-reduction Sum

In [23]:
# keepdims=True retains the summed axis as a length-1 dimension, so the
# row sums come back as a column of shape (rows, 1) rather than a flat vector.
sum_A = jnp.sum(A, axis=1, keepdims=True)
sum_A, sum_A.shape

(Array([[ 3.],
        [12.]], dtype=float32),
 (2, 1))

In [24]:
# sum_A has shape (2, 1), so it broadcasts across the columns of A:
# every row is divided by its own row sum.
A / sum_A

Array([[0.        , 0.33333334, 0.6666667 ],
       [0.25      , 0.33333334, 0.4166667 ]], dtype=float32)

In [25]:
# Cumulative sum along axis 0 (accumulating row by row), shown next to A itself.
A.cumsum(axis=0), A

(Array([[0., 1., 2.],
        [3., 5., 7.]], dtype=float32),
 Array([[0., 1., 2.],
        [3., 4., 5.]], dtype=float32))

## Dot Products

In [26]:
# Dot product of x with a float32 vector of three ones.
y = jnp.ones((3,), dtype=jnp.float32)
x, y, jnp.dot(x, y)

(Array([0., 1., 2.], dtype=float32),
 Array([1., 1., 1.], dtype=float32),
 Array(3., dtype=float32))

In [27]:
# The same dot product written as an element-wise multiply followed by a sum.
jnp.sum(x * y)

Array(3., dtype=float32)

## Matrix-Vector Products

In [28]:
# A (2, 3) matrix times a length-3 vector gives a length-2 vector.
A.shape, x.shape, jnp.matmul(A, x)

((2, 3), (3,), Array([ 5., 14.], dtype=float32))

## Matrix-Matrix Multiplication

In [29]:
# 3x4 matrix of ones used as the right-hand operand of the matrix product.
# The dtype is pinned to float32, like the other float arrays in this
# notebook, so the product stays float32 even when JAX's 64-bit mode
# (jax_enable_x64) is switched on.
B = jnp.ones((3, 4), dtype=jnp.float32)

In [30]:
# Matrix-matrix product: (2, 3) times (3, 4) gives a (2, 4) result.
jnp.matmul(A, B)

Array([[ 3.,  3.,  3.,  3.],
       [12., 12., 12., 12.]], dtype=float32)

## Norms

In [31]:
# Euclidean (L2) norm computed two ways: the library routine and the
# explicit square root of the sum of squares.
u = jnp.array([3.0, -4.0])
jnp.linalg.norm(u), jnp.sqrt(jnp.sum(u * u))

(Array(5., dtype=float32), Array(5., dtype=float32))

In [32]:
# L1 norm: the sum of absolute values, computed two equivalent ways.
jnp.linalg.norm(u, ord=1), jnp.abs(u).sum()

(Array(7., dtype=float32), Array(7., dtype=float32))

In [34]:
# Norm of a 4x9 matrix of ones. For a 2-D input the default is the Frobenius
# norm: sqrt of the sum of squared entries = sqrt(36) = 6.
jnp.linalg.norm(jnp.ones((4,9)))

Array(6., dtype=float32)