In [1]:
import jax.numpy as jnp

print("Matrix Multiplication")
print("=" * 60)

# Basic matrix multiplication
# A has 3 rows and 2 columns
A = jnp.array([[1, 2],
               [3, 4],
               [5, 6]])  # Shape (3, 2)

# B has 2 rows and 3 columns
B = jnp.array([[7, 8, 9],
               [10, 11, 12]])  # Shape (2, 3)

# Using @ operator (recommended for matrix multiplication)
# Entry C[i, j] is the dot product of row i of A with column j of B.
# Inner dimensions must match: (3, 2) @ (2, 3) -> (3, 3)
C = A @ B
print(f"A (3, 2):\n{A}\n")
print(f"B (2, 3):\n{B}\n")
print(f"A @ B (3, 3):\n{C}")
print(f"\nShape check: {A.shape} @ {B.shape} = {C.shape}")

Matrix Multiplication
A (3, 2):
[[1 2]
 [3 4]
 [5 6]]

B (2, 3):
[[ 7  8  9]
 [10 11 12]]

A @ B (3, 3):
[[ 27  30  33]
 [ 61  68  75]
 [ 95 106 117]]

Shape check: (3, 2) @ (2, 3) = (3, 3)
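As a quick sanity check on the output above, this sketch recomputes one entry by hand. It also compares `@` with `jnp.matmul` and `jnp.dot`, which agree for 2-D inputs. For higher-rank inputs `jnp.dot` and `jnp.matmul` follow different rules, so `@` is the safer default.

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])   # (3, 2)
B = jnp.array([[7, 8, 9], [10, 11, 12]])  # (2, 3)
C = A @ B

# C[i, j] is the dot product of row i of A with column j of B:
# C[0, 0] = 1*7 + 2*10 = 27
entry = A[0] @ B[:, 0]
print(C[0, 0], entry)  # 27 27

# For 2-D inputs, jnp.matmul and jnp.dot compute the same product as @
print(jnp.array_equal(C, jnp.matmul(A, B)), jnp.array_equal(C, jnp.dot(A, B)))  # True True
```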


# JAX Deep Learning Tutorial: From Zero to Hero

**A Comprehensive Introduction to Tensors, Operations, and Neural Networks with JAX**

---

## Table of Contents

1. [Introduction to JAX](#1-introduction-to-jax)
2. [Tensor Fundamentals](#2-tensor-fundamentals)
3. [Tensor Operations](#3-tensor-operations)
4. [Linear Algebra for Deep Learning](#4-linear-algebra-for-deep-learning)
5. [Einstein Summation (einsum)](#5-einstein-summation-einsum)
6. [Automatic Differentiation](#6-automatic-differentiation)
7. [JIT Compilation](#7-jit-compilation)
8. [Vectorization with vmap](#8-vectorization-with-vmap)
9. [Building Neural Networks from Scratch](#9-building-neural-networks-from-scratch)
10. [Common Deep Learning Patterns](#10-common-deep-learning-patterns)
11. [Practical Examples](#11-practical-examples)

---

## 1. Introduction to JAX

### What is JAX?

**JAX** is a high-performance numerical computing library developed by Google that combines:

- **NumPy-like API**: Familiar syntax for array operations
- **Automatic Differentiation**: `grad()` for computing gradients
- **JIT Compilation**: `jit()` for XLA-accelerated code
- **Vectorization**: `vmap()` for automatic batching
- **Parallelization**: `pmap()` for multi-device computation (newer code often uses `jax.jit` with sharding instead)

### Why JAX for Deep Learning?

| Feature | NumPy | PyTorch | TensorFlow | JAX |
|---------|-------|---------|------------|-----|
| NumPy-like API | ✓ | ~ | ~ | ✓ |
| Auto-differentiation | ✗ | ✓ | ✓ | ✓ |
| JIT compilation | ✗ | ✓ | ✓ | ✓ |
| Functional paradigm | ~ | ✗ | ✗ | ✓ |
| Built around pure functions | ✗ | ✗ | ✗ | ✓ |
| Composable transforms | ✗ | ~ | ~ | ✓ |

Legend: ✓ = core feature, ~ = partial support, ✗ = not supported.

### JAX Philosophy: Functional Programming

JAX is built around **pure functions**, i.e. functions that:
1. Always produce the same output for the same input
2. Have no side effects (they don't modify external state)

Because a pure function's result depends only on its arguments, JAX can trace it and safely apply **function transformations** such as `grad`, `jit`, and `vmap`.
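A small sketch of why purity matters (the function names and the `trace_log` list are illustrative). Under `jax.jit`, Python side effects run only while JAX traces the function. A later call with the same input shape and dtype reuses the compiled version and silently skips them.

```python
import jax
import jax.numpy as jnp

trace_log = []  # external state that the impure function mutates

@jax.jit
def impure_add_one(x):
    trace_log.append("traced")  # side effect: runs only while JAX traces the function
    return x + 1

def pure_add_one(x):
    return x + 1  # output depends only on the input

impure_add_one(jnp.array(1.0))
impure_add_one(jnp.array(2.0))  # same shape/dtype, so the compiled version is reused
print(len(trace_log))  # 1: the side effect did not run on the second call

print(jax.jit(pure_add_one)(jnp.array(2.0)))  # 3.0, the same result on every call
```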

In [2]:
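# Install JAX (Colab usually ships with JAX preinstalled; run this to install or upgrade it)
# CPU-only build (jaxlib is installed automatically as a dependency of jax):
!pip install -q -U jax

# For GPU (uncomment if using a GPU runtime in Colab; the quotes stop the shell from expanding the brackets):
# !pip install -q -U "jax[cuda12]"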

In [3]:
# Core imports
import jax
import jax.numpy as jnp  # JAX's version of NumPy, usually imported as jnp
from jax import grad, jit, vmap, random  # Key JAX transformations
import numpy as np  # Standard NumPy, often used for data loading/preprocessing
import matplotlib.pyplot as plt  # For plotting visualizations

# Check JAX version and devices
# This helps ensure we are running on the expected hardware (CPU/GPU/TPU)
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

JAX version: 0.7.2
Available devices: [CpuDevice(id=0)]
Default backend: cpu


---

## 2. Tensor Fundamentals

In JAX, tensors are represented as **arrays** (`jax.Array` objects), which behave much like NumPy's `ndarray`. They are the fundamental data structure for all computations.

### 2.1 What is a Tensor?

A **tensor** is a multi-dimensional array of numbers:

```
Scalar (0D):    5                           shape: ()
Vector (1D):    [1, 2, 3]                   shape: (3,)
Matrix (2D):    [[1, 2], [3, 4]]            shape: (2, 2)
3D Tensor:      [[[1,2],[3,4]],[[5,6],[7,8]]] shape: (2, 2, 2)
```

### Visual Representation

```
0D (Scalar)     1D (Vector)      2D (Matrix)         3D (Tensor)
                                                      
     5          [1, 2, 3]        [[1, 2, 3],         [[[1, 2],
                                  [4, 5, 6]]           [3, 4]],
                                                       [[5, 6],
                                                        [7, 8]]]
```

### 2.2 Creating Arrays

In [4]:
# ============================================
# METHOD 1: From Python lists
# ============================================

# Scalar (0-dimensional)
scalar = jnp.array(5)
print(f"Scalar: {scalar}")
print(f"  Shape: {scalar.shape}")
print(f"  Dimensions: {scalar.ndim}")
print()

Scalar: 5
  Shape: ()
  Dimensions: 0



In [5]:
# Vector (1-dimensional)
vector = jnp.array([1, 2, 3, 4, 5])
print(f"Vector: {vector}")
print(f"  Shape: {vector.shape}")
print(f"  Dimensions: {vector.ndim}")
print()

Vector: [1 2 3 4 5]
  Shape: (5,)
  Dimensions: 1



In [6]:
# Matrix (2-dimensional)
matrix = jnp.array([[1, 2, 3],
                    [4, 5, 6]])
print(f"Matrix:\n{matrix}")
print(f"  Shape: {matrix.shape}  (rows, cols)")
print(f"  Dimensions: {matrix.ndim}")
print()

Matrix:
[[1 2 3]
 [4 5 6]]
  Shape: (2, 3)  (rows, cols)
  Dimensions: 2



In [7]:
# 3D Tensor
tensor_3d = jnp.array([[[1, 2], [3, 4]],
                       [[5, 6], [7, 8]]])
print(f"3D Tensor:\n{tensor_3d}")
print(f"  Shape: {tensor_3d.shape}  (depth, rows, cols)")
print(f"  Dimensions: {tensor_3d.ndim}")

3D Tensor:
[[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
  Shape: (2, 2, 2)  (depth, rows, cols)
  Dimensions: 3


In [8]:
# ============================================
# METHOD 2: Using creation functions
# ============================================

print("=" * 50)
print("Array Creation Functions")
print("=" * 50)

# Zeros and Ones
zeros = jnp.zeros((3, 4))
ones = jnp.ones((2, 3))
print(f"\nzeros((3, 4)):\n{zeros}")
print(f"\nones((2, 3)):\n{ones}")

Array Creation Functions

zeros((3, 4)):
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]

ones((2, 3)):
[[1. 1. 1.]
 [1. 1. 1.]]


In [9]:
# Full (fill with specific value)
full = jnp.full((2, 3), fill_value=7)
print(f"full((2, 3), 7):\n{full}")

full((2, 3), 7):
[[7 7 7]
 [7 7 7]]


In [10]:
# Identity matrix
eye = jnp.eye(4)
print(f"\neye(4) - Identity matrix:\n{eye}")


eye(4) - Identity matrix:
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


In [11]:
# Ranges
arange = jnp.arange(0, 10, 2)  # start, stop, step
linspace = jnp.linspace(0, 1, 5)  # start, stop, num_points
print(f"\narange(0, 10, 2): {arange}")
print(f"linspace(0, 1, 5): {linspace}")


arange(0, 10, 2): [0 2 4 6 8]
linspace(0, 1, 5): [0.   0.25 0.5  0.75 1.  ]


In [12]:
# ============================================
# METHOD 3: Random arrays (JAX style)
# ============================================

print("\n" + "=" * 50)
print("Random Arrays (JAX uses explicit PRNG keys!)")
print("=" * 50)

# JAX requires explicit random keys for reproducibility.
# Unlike NumPy, there is no global random state (making it 'pure').
key = random.PRNGKey(42)  # Initialize with a seed (e.g., 42)

# To generate new random numbers, we must split the key.
# 'key' is the source, 'subkey' is used for the actual operation.
key, subkey = random.split(key)  # Split key for each use
uniform = random.uniform(subkey, shape=(2, 3))  # Uniform samples in [0, 1)
print(f"\nUniform [0, 1):\n{uniform}")


Random Arrays (JAX uses explicit PRNG keys!)

Uniform [0, 1):
[[0.72766423 0.78786755 0.18169427]
 [0.26263022 0.11072934 0.20263076]]


In [13]:
# Normal distribution (mean=0, std=1)
key, subkey = random.split(key)
normal = random.normal(subkey, shape=(2, 3))
print(f"Normal (mean=0, std=1):\n{normal}")

Normal (mean=0, std=1):
[[-0.21089035 -1.3627948  -0.04500385]
 [-1.1536394   1.9141139  -0.47701314]]


In [14]:
# Random integers
key, subkey = random.split(key)
randint = random.randint(subkey, shape=(3, 3), minval=0, maxval=10)
print(f"\nRandom integers [0, 10):\n{randint}")


Random integers [0, 10):
[[0 2 0]
 [0 8 2]
 [1 2 4]]


### 2.3 Understanding JAX's Random Number Generation

Unlike NumPy/PyTorch, JAX uses **explicit PRNG (Pseudo-Random Number Generator) keys**:

```python
import numpy as np
from jax import random

# NumPy style (implicit global state: NOT pure)
np.random.seed(42)
x = np.random.randn(3)  # Modifies hidden global state

# JAX style (explicit state: pure functions)
key = random.PRNGKey(42)
key, subkey = random.split(key)  # Keep `key` for future splits, consume `subkey` now
x = random.normal(subkey, (3,))  # No hidden state!
```

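**Why?** Because the key is passed in explicitly, a random function's output depends only on its arguments. That keeps JAX functions **pure**, which `jit`, `vmap`, and `grad` rely on, and makes every draw reproducible from its key.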

In [15]:
# Demonstration: Same key = same random numbers (reproducible!)
key1 = random.PRNGKey(0)
key2 = random.PRNGKey(0)

print("Same key produces same numbers:")
print(f"  Key 1: {random.normal(key1, (3,))}")
print(f"  Key 2: {random.normal(key2, (3,))}")

# Different keys = different numbers
key3 = random.PRNGKey(1)
print(f"  Key 3 (different seed): {random.normal(key3, (3,))}")

Same key produces same numbers:
  Key 1: [ 1.6226422   2.0252647  -0.43359444]
  Key 2: [ 1.6226422   2.0252647  -0.43359444]
  Key 3 (different seed): [-0.15443718  0.08470728 -0.13598049]
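In practice you usually need several independent keys at once, for example one per parameter tensor. The sketch below uses `random.split(key, n)` to get them in one call (the names `k_w`, `k_b`, `k_noise` are illustrative). Recent JAX releases also provide `random.key(seed)`, which returns a typed key and is used the same way as `PRNGKey`.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k_w, k_b, k_noise = random.split(key, 3)  # three independent subkeys in one call

w = random.normal(k_w, (2,))
b = random.normal(k_b, (2,))
noise = random.normal(k_noise, (2,))

# Reusing a subkey reproduces its numbers exactly; distinct subkeys give distinct draws
print(jnp.array_equal(random.normal(k_w, (2,)), w))  # True
print(jnp.array_equal(w, b))                          # False
```

The rule of thumb: never feed the same key to two different random calls unless you want identical numbers.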


### 2.4 Data Types (dtypes)

In [16]:
print("=" * 50)
print("Common Data Types in JAX")
print("=" * 50)

# Integer types
int32_arr = jnp.array([1, 2, 3], dtype=jnp.int32)
int64_arr = jnp.array([1, 2, 3], dtype=jnp.int64)  # Downcast to int32 (with a warning) unless 64-bit mode is enabled

# Float types (most common for deep learning)
float16_arr = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float16)  # Half precision
float32_arr = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)  # Single precision (DEFAULT)
float64_arr = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)  # Double precision; downcast to float32 (with a warning) unless 64-bit mode is enabled

# Boolean
bool_arr = jnp.array([True, False, True], dtype=jnp.bool_)

print(f"\nint32:   {int32_arr} -> dtype: {int32_arr.dtype}")
print(f"int64:   {int64_arr} -> dtype: {int64_arr.dtype}")
print(f"float16: {float16_arr} -> dtype: {float16_arr.dtype}")
print(f"float32: {float32_arr} -> dtype: {float32_arr.dtype}")
print(f"float64: {float64_arr} -> dtype: {float64_arr.dtype}")
print(f"bool:    {bool_arr} -> dtype: {bool_arr.dtype}")

Common Data Types in JAX

int32:   [1 2 3] -> dtype: int32
int64:   [1 2 3] -> dtype: int32
float16: [1. 2. 3.] -> dtype: float16
float32: [1. 2. 3.] -> dtype: float32
float64: [1. 2. 3.] -> dtype: float32
bool:    [ True False  True] -> dtype: bool


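The `int64 -> int32` and `float64 -> float32` results above happen because JAX disables 64-bit types by default. The documented switch is the `jax_enable_x64` configuration option (or the `JAX_ENABLE_X64` environment variable). It is global: set it at startup, in a fresh session, because it also changes the default dtypes of every array created afterwards. Most deep-learning code stays in float32, so treat this as an opt-in for numerically sensitive work.

```python
import jax

# Global setting: do this at startup, before creating any arrays
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)
i = jnp.array([1, 2, 3], dtype=jnp.int64)
print(x.dtype, i.dtype)             # float64 int64
print(jnp.array([1.0, 2.0]).dtype)  # float64: the default float type changes too
```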


In [17]:
# Type conversion
original = jnp.array([1.7, 2.3, 3.9])
to_int = original.astype(jnp.int32)  # Truncates toward zero (1.7 -> 1), no rounding
to_float16 = original.astype(jnp.float16)

print(f"\nType conversion:")
print(f"  Original (float32): {original}")
print(f"  To int32: {to_int}")
print(f"  To float16: {to_float16}")


Type conversion:
  Original (float32): [1.7 2.3 3.9]
  To int32: [1 2 3]
  To float16: [1.7 2.3 3.9]


### 2.5 Array Attributes

In [18]:
# Create a sample array
arr = jnp.array([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]], dtype=jnp.float32)

print("Array Attributes")
print("=" * 40)
print(f"Array:\n{arr}\n")

# .shape tells you the size of each dimension (e.g. 3 rows, 4 columns)
print(f"Shape:      {arr.shape}      # (rows, cols)")
# .ndim tells you how many axes/dimensions there are (2 for a matrix)
print(f"Dimensions: {arr.ndim}           # Number of axes")
# .size is the total count of numbers in the array
print(f"Size:       {arr.size}          # Total elements")
# .dtype tells you if it's float, int, etc.
print(f"Dtype:      {arr.dtype}   # Data type")
# .devices() tells you if it's on CPU, GPU, or TPU
print(f"Device:     {arr.devices()}  # Where data lives")

Array Attributes
Array:
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]]

Shape:      (3, 4)      # (rows, cols)
Dimensions: 2           # Number of axes
Size:       12          # Total elements
Dtype:      float32   # Data type
Device:     {CpuDevice(id=0)}  # Where data lives


### 2.6 Indexing and Slicing

In [19]:
arr = jnp.array([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]])

print("Original array:")
print(arr)
print()

# Basic indexing
print("=" * 40)
print("Basic Indexing")
print("=" * 40)
print(f"arr[0, 0] = {arr[0, 0]}      # First element")
print(f"arr[1, 2] = {arr[1, 2]}      # Row 1, Col 2")
print(f"arr[-1, -1] = {arr[-1, -1]}   # Last element")

Original array:
[[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]]

Basic Indexing
arr[0, 0] = 1      # First element
arr[1, 2] = 7      # Row 1, Col 2
arr[-1, -1] = 12   # Last element


In [20]:
# Slicing
print("\n" + "=" * 40)
print("Slicing [start:stop:step]")
print("=" * 40)
print(f"arr[0, :] = {arr[0, :]}     # First row (all cols)")
print(f"arr[:, 0] = {arr[:, 0]}        # First column (all rows)")
print(f"arr[0:2, 1:3] =\n{arr[0:2, 1:3]}    # Submatrix")
print(f"arr[::2, :] =\n{arr[::2, :]}     # Every 2nd row")


Slicing [start:stop:step]
arr[0, :] = [1 2 3 4]     # First row (all cols)
arr[:, 0] = [1 5 9]        # First column (all rows)
arr[0:2, 1:3] =
[[2 3]
 [6 7]]    # Submatrix
arr[::2, :] =
[[ 1  2  3  4]
 [ 9 10 11 12]]     # Every 2nd row


In [21]:
# Boolean indexing
print("\n" + "=" * 40)
print("Boolean Indexing")
print("=" * 40)
mask = arr > 6
print(f"Mask (arr > 6):\n{mask}")
print(f"arr[arr > 6] = {arr[mask]}  # Elements > 6")


Boolean Indexing
Mask (arr > 6):
[[False False False False]
 [False False  True  True]
 [ True  True  True  True]]
arr[arr > 6] = [ 7  8  9 10 11 12]  # Elements > 6
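One JAX-specific caveat: `arr[arr > 6]` returns an array whose length depends on the data, so it raises an error inside `jax.jit`, where every output shape must be known at compile time. A common workaround, sketched below with an illustrative function name, is `jnp.where`. It keeps the original shape and fills the unselected positions with a placeholder value.

```python
import jax
import jax.numpy as jnp

arr = jnp.array([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]])

@jax.jit
def keep_large(x):
    # Fixed (3, 4) output shape: entries > 6 are kept, the rest become 0
    return jnp.where(x > 6, x, 0)

print(keep_large(arr))
# [[ 0  0  0  0]
#  [ 0  0  7  8]
#  [ 9 10 11 12]]

# Reductions over the selected elements work the same way
print(jnp.sum(jnp.where(arr > 6, arr, 0)))  # 57
```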


### 2.7 JAX Arrays are Immutable!

**Important difference from NumPy**: JAX arrays cannot be modified in-place.

In [22]:
arr = jnp.array([1, 2, 3, 4, 5])

# This would work in NumPy but FAILS in JAX:
# arr[0] = 100  # TypeError! JAX arrays are immutable (read-only once created).

# Instead, use .at[].set() to create a NEW array with the updated value.
# The original 'arr' remains unchanged (functional programming style).
arr_modified = arr.at[0].set(100)

print(f"Original: {arr}")       # Unchanged!
print(f"Modified: {arr_modified}")  # New array with change

Original: [1 2 3 4 5]
Modified: [100   2   3   4   5]


In [23]:
# More .at[] operations
arr = jnp.array([1, 2, 3, 4, 5])

print("Functional updates with .at[]")
print(f"Original:           {arr}")
print(f".at[0].set(100):    {arr.at[0].set(100)}")
print(f".at[1:3].set(0):    {arr.at[1:3].set(0)}")
print(f".at[2].add(10):     {arr.at[2].add(10)}")
print(f".at[2].multiply(2): {arr.at[2].multiply(2)}")

Functional updates with .at[]
Original:           [1 2 3 4 5]
.at[0].set(100):    [100   2   3   4   5]
.at[1:3].set(0):    [1 0 0 4 5]
.at[2].add(10):     [ 1  2 13  4  5]
.at[2].multiply(2): [1 2 6 4 5]
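Two more `.at[]` behaviours worth knowing, shown in a small sketch. First, `.add` with repeated indices accumulates every update (a scatter-add), which is handy for counting or histogram-style code. Second, `.min` and `.max` combine the new value with the existing one. Although `.at[]` returns a new array, inside `jit`-compiled code XLA can often perform the update in place.

```python
import jax.numpy as jnp

counts = jnp.zeros(4, dtype=jnp.int32)
idx = jnp.array([0, 2, 2, 3, 2])  # index 2 appears three times

# .add accumulates over repeated indices (a scatter-add)
counts = counts.at[idx].add(1)
print(counts)  # [1 0 3 1]

# .min / .max combine the update with the existing value
arr = jnp.array([1, 2, 3, 4, 5])
print(arr.at[0].max(10))  # [10  2  3  4  5]
print(arr.at[4].min(0))   # [1 2 3 4 0]
```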


---

## 3. Tensor Operations

### 3.1 Element-wise Operations

In [24]:
a = jnp.array([1, 2, 3, 4], dtype=jnp.float32)
b = jnp.array([5, 6, 7, 8], dtype=jnp.float32)

print("Element-wise Arithmetic")
print("=" * 40)
print(f"a = {a}")
print(f"b = {b}")
print()

# These operations happen element-by-element (e.g., 1+5, 2+6, ...)
print(f"a + b = {a + b}     # Addition")
print(f"a - b = {a - b}  # Subtraction")
print(f"a * b = {a * b}  # Multiplication")
print(f"a / b = {a / b}  # Division")
print(f"a ** 2 = {a ** 2}   # Power (square each element)")
print(f"a % 2 = {a % 2}     # Modulo (remainder of division)")

Element-wise Arithmetic
a = [1. 2. 3. 4.]
b = [5. 6. 7. 8.]

a + b = [ 6.  8. 10. 12.]     # Addition
a - b = [-4. -4. -4. -4.]  # Subtraction
a * b = [ 5. 12. 21. 32.]  # Multiplication
a / b = [0.2        0.33333334 0.42857143 0.5       ]  # Division
a ** 2 = [ 1.  4.  9. 16.]   # Power (square each element)
a % 2 = [1. 0. 1. 0.]     # Modulo (remainder of division)


In [25]:
# Mathematical functions
x = jnp.array([0, jnp.pi/6, jnp.pi/4, jnp.pi/3, jnp.pi/2])

print("\nMathematical Functions")
print("=" * 40)
print(f"x (radians): {x}")
print(f"sin(x): {jnp.sin(x)}")
print(f"cos(x): {jnp.cos(x)}")
print(f"exp([0, 1, 2]): {jnp.exp(jnp.array([0, 1, 2]))}")
print(f"log([1, e, e^2]): {jnp.log(jnp.array([1, jnp.e, jnp.e**2]))}")
print(f"sqrt([1, 4, 9, 16]): {jnp.sqrt(jnp.array([1, 4, 9, 16]))}")


Mathematical Functions
x (radians): [0.        0.5235988 0.7853982 1.0471976 1.5707964]
sin(x): [0.         0.5        0.70710677 0.86602545 1.        ]
cos(x): [ 1.0000000e+00  8.6602539e-01  7.0710677e-01  4.9999997e-01
 -4.3711388e-08]
exp([0, 1, 2]): [1.        2.7182817 7.389056 ]
log([1, e, e^2]): [0. 1. 2.]
sqrt([1, 4, 9, 16]): [1. 2. 3. 4.]


### 3.2 Broadcasting

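Broadcasting lets element-wise operations combine arrays of different but compatible shapes: size-1 and missing dimensions are virtually stretched to match, so you never have to tile the data by hand.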

In [26]:
print("Broadcasting Rules")
print("=" * 50)
print("""
1. If arrays have different ndim, prepend 1s to the smaller shape
2. Arrays are compatible when dimensions are equal OR one is 1
3. The result shape is the maximum along each dimension
""")

# Scalar + Array
# The scalar (10) is automatically "stretched" (broadcasted) to match the shape of 'arr'.
arr = jnp.array([1, 2, 3])
print(f"\nScalar + Array:")
print(f"  [1, 2, 3] + 10 = {arr + 10}")
print(f"  Shapes: (3,) + () -> (3,)")

Broadcasting Rules

1. If arrays have different ndim, prepend 1s to the smaller shape
2. Arrays are compatible when dimensions are equal OR one is 1
3. The result shape is the maximum along each dimension


Scalar + Array:
  [1, 2, 3] + 10 = [11 12 13]
  Shapes: (3,) + () -> (3,)


In [27]:
# Vector + Matrix (row-wise)
matrix = jnp.array([[1, 2, 3],
                    [4, 5, 6]])
row = jnp.array([10, 20, 30])

print(f"\nVector + Matrix (row-wise):")
print(f"  Matrix (2, 3):\n{matrix}")
print(f"  Row (3,): {row}")
print(f"  Result:\n{matrix + row}")
print(f"  Shapes: (2, 3) + (3,) -> (2, 3)")


Vector + Matrix (row-wise):
  Matrix (2, 3):
[[1 2 3]
 [4 5 6]]
  Row (3,): [10 20 30]
  Result:
[[11 22 33]
 [14 25 36]]
  Shapes: (2, 3) + (3,) -> (2, 3)


In [28]:
# Column broadcast: the column must have shape (2, 1), not (2,)
col = jnp.array([[100], [200]])  # Shape (2, 1)

print(f"\nColumn broadcast:")
print(f"  Matrix (2, 3):\n{matrix}")
print(f"  Column (2, 1):\n{col}")
print(f"  Result:\n{matrix + col}")
print(f"  Shapes: (2, 3) + (2, 1) -> (2, 3)")


Column broadcast:
  Matrix (2, 3):
[[1 2 3]
 [4 5 6]]
  Column (2, 1):
[[100]
 [200]]
  Result:
[[101 102 103]
 [204 205 206]]
  Shapes: (2, 3) + (2, 1) -> (2, 3)
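The `(2, 1)` shape is what makes the column example work. A plain `(2,)` vector is aligned with the *last* axis (size 3), so the shapes clash. The sketch below shows the failure and the usual fix, `col[:, None]`. It also shows `jnp.broadcast_shapes`, which applies the same rules to shapes alone. JAX may report the clash as a `TypeError` or `ValueError` depending on the version, so both are caught.

```python
import jax.numpy as jnp

matrix = jnp.array([[1, 2, 3],
                    [4, 5, 6]])   # (2, 3)
col = jnp.array([100, 200])       # (2,): aligns with the LAST axis (size 3) -> mismatch

failed = False
try:
    matrix + col
except (TypeError, ValueError) as err:
    failed = True
    print("Broadcast failed:", type(err).__name__)

# Insert a trailing axis so the vector becomes a (2, 1) column
result = matrix + col[:, None]
print(result)
# [[101 102 103]
#  [204 205 206]]

# Same rules, applied to shapes only
print(jnp.broadcast_shapes((2, 3), (2, 1)))  # (2, 3)
```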


In [29]:
# Outer product via broadcasting: (3,) * (2, 1) -> (2, 3), same as jnp.outer(b.ravel(), a)
a = jnp.array([1, 2, 3])          # Shape (3,)
b = jnp.array([[10], [20]])        # Shape (2, 1)

print(f"\nOuter-like operation via broadcasting:")
print(f"  a (3,): {a}")
print(f"  b (2, 1):\n{b}")
print(f"  a * b (2, 3):\n{a * b}")


Outer-like operation via broadcasting:
  a (3,): [1 2 3]
  b (2, 1):
[[10]
 [20]]
  a * b (2, 3):
[[10 20 30]
 [20 40 60]]


### 3.3 Reduction Operations

In [30]:
arr = jnp.array([[1, 2, 3],
                 [4, 5, 6]])

print("Reduction Operations")
print("=" * 50)
print(f"Array:\n{arr}\n")

# Global reductions
print("Global (all elements):")
print(f"  sum:  {jnp.sum(arr)}")
print(f"  mean: {jnp.mean(arr)}")
print(f"  min:  {jnp.min(arr)}")
print(f"  max:  {jnp.max(arr)}")
print(f"  std:  {jnp.std(arr):.4f}")  # Population std (ddof=0), as in NumPy
print(f"  prod: {jnp.prod(arr)}")

Reduction Operations
Array:
[[1 2 3]
 [4 5 6]]

Global (all elements):
  sum:  21
  mean: 3.5
  min:  1
  max:  6
  std:  1.7078
  prod: 720


In [31]:
# Axis-wise reductions
print("\nAlong axis=0 (collapse rows -> one value per column):")
print(f"  sum: {jnp.sum(arr, axis=0)}")
print(f"  mean: {jnp.mean(arr, axis=0)}")

print("\nAlong axis=1 (collapse columns -> one value per row):")
print(f"  sum: {jnp.sum(arr, axis=1)}")
print(f"  mean: {jnp.mean(arr, axis=1)}")


Along axis=0 (collapse rows -> one value per column):
  sum: [5 7 9]
  mean: [2.5 3.5 4.5]

Along axis=1 (collapse columns -> one value per row):
  sum: [ 6 15]
  mean: [2. 5.]


In [32]:
# Keep dimensions (useful for broadcasting)
print("\nKeep dimensions (keepdims=True):")
print(f"  sum(axis=1): {jnp.sum(arr, axis=1)} shape: {jnp.sum(arr, axis=1).shape}")
print(f"  sum(axis=1, keepdims=True):\n{jnp.sum(arr, axis=1, keepdims=True)} shape: {jnp.sum(arr, axis=1, keepdims=True).shape}")


Keep dimensions (keepdims=True):
  sum(axis=1): [ 6 15] shape: (2,)
  sum(axis=1, keepdims=True):
[[ 6]
 [15]] shape: (2, 1)
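The typical reason to keep dimensions is to divide a matrix by its own row statistics. With `keepdims=True` the row sums have shape `(2, 1)` and broadcast across each row. Without it they would have shape `(2,)`, which does not broadcast against `(2, 3)`. A minimal sketch of row normalization:

```python
import jax.numpy as jnp

arr = jnp.array([[1., 2., 3.],
                 [4., 5., 6.]])

row_sums = jnp.sum(arr, axis=1, keepdims=True)  # (2, 1) instead of (2,)
normalized = arr / row_sums                      # each row divided by its own sum
print(normalized.sum(axis=1))                    # every row now sums to 1 (up to float rounding)
```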


### 3.4 Reshaping Operations

In [33]:
arr = jnp.arange(12)  # [0, 1, 2, ..., 11]

print("Reshaping Operations")
print("=" * 50)
print(f"Original: {arr}, shape: {arr.shape}\n")

# Reshape
reshaped = arr.reshape(3, 4)
print(f"reshape(3, 4):\n{reshaped}")

reshaped2 = arr.reshape(2, 2, 3)
print(f"\nreshape(2, 2, 3):\n{reshaped2}")

Reshaping Operations
Original: [ 0  1  2  3  4  5  6  7  8  9 10 11], shape: (12,)

reshape(3, 4):
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

reshape(2, 2, 3):
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]


In [34]:
# Using -1 to infer dimension
auto_reshape = arr.reshape(3, -1)  # -1 means "figure it out"
print(f"\nreshape(3, -1) -> shape {auto_reshape.shape}:\n{auto_reshape}")


reshape(3, -1) -> shape (3, 4):
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
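A common deep-learning use of `-1` is flattening each example in a batch before a dense layer, while keeping the batch axis intact. A sketch with a hypothetical batch of 32 grayscale 28x28 images:

```python
import jax.numpy as jnp

# Hypothetical batch of 32 grayscale 28x28 images
images = jnp.zeros((32, 28, 28))

# Keep the batch axis and let -1 infer the flattened size (28 * 28 = 784)
flat = images.reshape(images.shape[0], -1)
print(flat.shape)  # (32, 784)
```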


In [35]:
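# Flatten: both methods return a new 1-D array.
# Since JAX arrays are immutable, NumPy's view-vs-copy difference between them does not matter here.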
matrix = jnp.array([[1, 2, 3], [4, 5, 6]])
print(f"\nFlatten:")
print(f"  Original:\n{matrix}")
print(f"  flatten(): {matrix.flatten()}")
print(f"  ravel():   {matrix.ravel()}")


Flatten:
  Original:
[[1 2 3]
 [4 5 6]]
  flatten(): [1 2 3 4 5 6]
  ravel():   [1 2 3 4 5 6]


In [36]:
# Transpose
print(f"\nTranspose:")
print(f"  Original (2, 3):\n{matrix}")
print(f"  Transposed (3, 2):\n{matrix.T}")


Transpose:
  Original (2, 3):
[[1 2 3]
 [4 5 6]]
  Transposed (3, 2):
[[1 4]
 [2 5]
 [3 6]]
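For arrays with more than two axes, `.T` reverses *all* axes, which is rarely what you want. `jnp.transpose` with an explicit axis order, or `jnp.swapaxes`, gives precise control. A sketch converting a hypothetical image batch from channels-last to channels-first layout (height 32 and width 24 are chosen only so the axes are easy to tell apart):

```python
import jax.numpy as jnp

# Hypothetical image batch in (batch, height, width, channels) layout
x = jnp.zeros((8, 32, 24, 3))

nchw = jnp.transpose(x, (0, 3, 1, 2))  # -> (batch, channels, height, width)
print(nchw.shape)                      # (8, 3, 32, 24)

print(x.T.shape)                       # (3, 24, 32, 8): .T reverses ALL axes
print(jnp.swapaxes(x, 1, 2).shape)     # (8, 24, 32, 3): swaps height and width only
```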


In [37]:
# Expand dimensions (add axis)
vec = jnp.array([1, 2, 3])
print(f"\nExpand dimensions:")
print(f"  Original: {vec}, shape: {vec.shape}")
print(f"  expand_dims(axis=0): {jnp.expand_dims(vec, axis=0)}, shape: {jnp.expand_dims(vec, axis=0).shape}")
print(f"  expand_dims(axis=1):\n{jnp.expand_dims(vec, axis=1)}, shape: {jnp.expand_dims(vec, axis=1).shape}")


Expand dimensions:
  Original: [1 2 3], shape: (3,)
  expand_dims(axis=0): [[1 2 3]], shape: (1, 3)
  expand_dims(axis=1):
[[1]
 [2]
 [3]], shape: (3, 1)
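The same result is often written with `None` in an index, which inserts a size-1 axis at that position. This shorthand appears throughout NumPy-style code:

```python
import jax.numpy as jnp

vec = jnp.array([1, 2, 3])
print(vec[None, :].shape)         # (1, 3), same as expand_dims(axis=0)
print(vec[:, None].shape)         # (3, 1), same as expand_dims(axis=1)
print(vec[:, jnp.newaxis].shape)  # (3, 1): jnp.newaxis is an alias for None
```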


In [38]:
# Squeeze (remove dimensions of size 1)
arr_with_ones = jnp.array([[[1, 2, 3]]])  # Shape (1, 1, 3)
print(f"\nSqueeze:")
print(f"  Original shape: {arr_with_ones.shape}")
print(f"  Squeezed shape: {jnp.squeeze(arr_with_ones).shape}")


Squeeze:
  Original shape: (1, 1, 3)
  Squeezed shape: (3,)
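Without an `axis` argument, `squeeze` removes *every* size-1 axis. That can silently drop a batch axis when the batch size happens to be 1. A sketch of the difference:

```python
import jax.numpy as jnp

batch = jnp.zeros((1, 1, 3))  # e.g. a batch of size 1 holding one (1, 3) row

print(jnp.squeeze(batch).shape)          # (3,): the batch axis is removed too
print(jnp.squeeze(batch, axis=1).shape)  # (1, 3): only axis 1 is removed
```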


### 3.5 Concatenation and Stacking

In [39]:
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

print("Concatenation and Stacking")
print("=" * 50)
print(f"a:\n{a}")
print(f"b:\n{b}")

# Concatenate (join along existing axis)
print(f"\nconcatenate along axis=0 (vertically):\n{jnp.concatenate([a, b], axis=0)}")
print(f"\nconcatenate along axis=1 (horizontally):\n{jnp.concatenate([a, b], axis=1)}")

Concatenation and Stacking
a:
[[1 2]
 [3 4]]
b:
[[5 6]
 [7 8]]

concatenate along axis=0 (vertically):
[[1 2]
 [3 4]
 [5 6]
 [7 8]]

concatenate along axis=1 (horizontally):
[[1 2 5 6]
 [3 4 7 8]]


In [40]:
# Stack (join along NEW axis)
stacked = jnp.stack([a, b], axis=0)
print(f"\nstack along axis=0 (creates new dimension):")
print(f"  Result shape: {stacked.shape}")
print(f"  Result:\n{stacked}")


stack along axis=0 (creates new dimension):
  Result shape: (2, 2, 2)
  Result:
[[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
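A frequent pattern is turning a Python list of per-example arrays into one batched array. `stack` requires all inputs to have identical shapes and adds a new axis. `concatenate` requires matching shapes on every axis except the one being joined. A sketch with three hypothetical feature vectors:

```python
import jax.numpy as jnp

# Three per-example feature vectors, each of shape (2,)
examples = [jnp.array([1., 2.]), jnp.array([3., 4.]), jnp.array([5., 6.])]

batch = jnp.stack(examples)              # new leading axis  -> (3, 2)
columns = jnp.stack(examples, axis=-1)   # new trailing axis -> (2, 3)
joined = jnp.concatenate(examples)       # existing axis     -> (6,)
print(batch.shape, columns.shape, joined.shape)
```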


In [41]:
# Convenient shortcuts
print(f"\nvstack (vertical):")
print(jnp.vstack([a, b]))

print(f"\nhstack (horizontal):")
print(jnp.hstack([a, b]))


vstack (vertical):
[[1 2]
 [3 4]
 [5 6]
 [7 8]]

hstack (horizontal):
[[1 2 5 6]
 [3 4 7 8]]
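The shortcuts behave differently on 1-D inputs. `vstack` treats each vector as a row and produces a 2-D array, while `hstack` simply joins the vectors end to end:

```python
import jax.numpy as jnp

u = jnp.array([1, 2, 3])
v = jnp.array([4, 5, 6])
print(jnp.vstack([u, v]).shape)  # (2, 3): 1-D inputs are treated as rows
print(jnp.hstack([u, v]).shape)  # (6,): 1-D inputs are joined end to end
```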


---

## 4. Linear Algebra for Deep Learning

Linear algebra is the foundation of deep learning. Let's explore the essential operations.

### 4.1 Matrix Multiplication

Matrix multiplication is THE core operation in neural networks.

```
Forward pass: output = input @ weights + bias
```

**Shapes Matter!**
```
(batch, in_features) @ (in_features, out_features) = (batch, out_features)
```