# Chapter 2 Prelimiaries

In [1]:
from d2l import torch as d2l

In [2]:
import torch

In [4]:
x = torch.arange(12, dtype=torch.float32)
x

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

In [5]:
x.numel()

12

In [7]:
x.shape

torch.Size([12])

In [8]:
X = x.reshape(3, 4)
X

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [12]:
Y = x.reshape(-1, 4)
Y

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [13]:
Z = x.reshape(3, -1)
Z

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [15]:
X = torch.zeros((2, 3, 4))
X

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]])

In [16]:
torch.ones((4, 3, 2))

tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]])

In [18]:
# Create a tensor with elements drawn from 
# a standard Gaussian(normal) distribution 
# with mean 0 and standard deviation 1.

torch.rand(3, 4)

tensor([[0.1008, 0.7066, 0.6447, 0.1720],
        [0.3240, 0.6484, 0.5930, 0.0292],
        [0.2713, 0.5787, 0.8519, 0.1225]])

In [20]:
torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

tensor([[2, 1, 4, 3],
        [1, 2, 3, 4],
        [4, 3, 2, 1]])

In [None]:
# Use conda forge to reinstall jax and jaxlib
!pip uninstall jax jaxlib
!conda install -c conda-forge jaxlib
!conda install -c conda-forge jax

## 2.1 Data Manipulation

In [1]:
import jax
from jax import numpy as jnp

In [2]:
x = jnp.arange(12)
x

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [3]:
x.size

12

In [5]:
x.shape

(12,)

In [7]:
X = x.reshape(3, 4)
X

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [8]:
jnp.zeros((2, 3, 4))

Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

In [9]:
jnp.ones((2, 3, 4))

Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)

In [10]:
# Any call of a random function in JAX requires a key to be
# specified, feeding the same key to a radom function will
# always result in the samle sample being generated.

jax.random.normal(jax.random.PRNGKey(0), (3, 4))

Array([[ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
       [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
       [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32)

In [13]:
jax.random.normal(jax.random.PRNGKey(100), (3, 4))

Array([[ 0.6344312 , -0.3569634 , -1.5672369 , -1.0250307 ],
       [ 0.3720784 ,  1.5012454 ,  1.2544656 ,  0.10508682],
       [ 2.6520667 , -0.9850591 , -0.8260392 , -1.600352  ]],      dtype=float32)

In [14]:
jnp.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

Array([[2, 1, 4, 3],
       [1, 2, 3, 4],
       [4, 3, 2, 1]], dtype=int32)

In [18]:
X, X[-1], X[1:3]

(Array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=int32),
 Array([ 8,  9, 10, 11], dtype=int32),
 Array([[ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=int32))

In [19]:
# JAX arrays are immutable. jax.numpy.ndarray.at index
# update operators create a new array with the corresponding
# modifications

X_new_1 = X.at[1, 2].set(17)
X_new_1

Array([[ 0,  1,  2,  3],
       [ 4,  5, 17,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [20]:
X_new_2 = X_new_1.at[:2, :].set(12)
X_new_2

Array([[12, 12, 12, 12],
       [12, 12, 12, 12],
       [ 8,  9, 10, 11]], dtype=int32)

In [21]:
jnp.exp(x)

Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
       5.4598152e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
       2.9809580e+03, 8.1030840e+03, 2.2026465e+04, 5.9874141e+04],      dtype=float32)

In [22]:
x = jnp.array([1.0, 2, 4, 8])
y = jnp.array([2, 2, 2, 2])

x + y, x - y, x * y, x / y, x ** y

(Array([ 3.,  4.,  6., 10.], dtype=float32),
 Array([-1.,  0.,  2.,  6.], dtype=float32),
 Array([ 2.,  4.,  8., 16.], dtype=float32),
 Array([0.5, 1. , 2. , 4. ], dtype=float32),
 Array([ 1.,  4., 16., 64.], dtype=float32))

In [23]:
X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

jnp.concatenate((X, Y), axis=0), jnp.concatenate((X, Y), axis=1)

(Array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [ 2.,  1.,  4.,  3.],
        [ 1.,  2.,  3.,  4.],
        [ 4.,  3.,  2.,  1.]], dtype=float32),
 Array([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
        [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
        [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]], dtype=float32))

In [24]:
X == Y

Array([[False,  True, False,  True],
       [False, False, False, False],
       [False, False, False, False]], dtype=bool)

In [25]:
X.sum()

Array(66., dtype=float32)

In [26]:
a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))
a, b

(Array([[0],
        [1],
        [2]], dtype=int32),
 Array([[0, 1]], dtype=int32))

In [27]:
a + b

Array([[0, 1],
       [1, 2],
       [2, 3]], dtype=int32)

In [28]:
before = id(Y)
Y = Y + X
id(Y) == before

False

In [29]:
A = jax.device_get(X)
B = jax.device_put(A)
type(A), type(B)

(numpy.ndarray, jaxlib.xla_extension.ArrayImpl)

In [30]:
a = jnp.array([3.5])
a, a.item(), float(a), int(a)

TypeError: Only scalar arrays can be converted to Python scalars; got arr.ndim=1

* The tensor class is the main interface for storing and manipulating data in deep learning libraries. Tensors provide a variety of functionalities including construction routines; indexing and slicing; basic mathematics operations; broadcasting; memory-efficient assignment; and conversion to and from other Python objects.

## 2.2 Data Preprocessing

In [31]:
import os