# JAX Tutorial: Understanding `vmap` and `pmap`

## Introduction to `vmap` (Vectorizing Map)

`vmap` is a JAX transformation that allows you to automatically vectorize functions. This means you can write a function that operates on a single data point, and `vmap` will transform it into a function that operates efficiently on a batch of data points without you needing to manually add batch dimensions or loops.

**Key benefits of `vmap`:**
- **Simplicity:** Write code for single examples, and `vmap` handles batching.
- **Efficiency:** Leverages JAX's XLA compilation for optimized batch operations.
- **Flexibility:** Works with complex functions and PyTrees.

In [1]:
import jax
import jax.numpy as jnp
import time

# Ensure JAX is using the CPU for consistent behavior in this example notebook
# You can change this to 'gpu' or 'tpu' if available and desired.
jax.config.update('jax_platform_name', 'cpu')

print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available JAX devices: {jax.devices()}")
print(f"Number of local devices: {jax.local_device_count()}")



JAX version: 0.6.1
JAX backend: cpu
Available JAX devices: [CpuDevice(id=0)]
Number of local devices: 1


In [2]:
# Example: A simple scalar function
def scalar_func(x):
  # print(f"scalar_func called with x: {x}") # For demonstration of tracing
  return x * x + 2

# Let's try it with a single value
print(f"scalar_func(3.0): {scalar_func(jnp.array(3.0))}")

# Now, let's vectorize it with vmap
batched_func_vmap = jax.vmap(scalar_func)

inputs = jnp.array([1.0, 2.0, 3.0, 4.0])
outputs = batched_func_vmap(inputs)

print(f"\nInputs: {inputs}")
print(f"Outputs from vmapped function: {outputs}")
# Note: The print inside scalar_func (if uncommented) will only show one traced call with an abstract value for vmap.

scalar_func(3.0): 11.0

Inputs: [1. 2. 3. 4.]
Outputs from vmapped function: [ 3.  6. 11. 18.]
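
The tracing note in the cell above can be checked directly. The sketch below is illustrative only: `trace_log` is a hypothetical helper list (not part of JAX) that records each Python-level call. Under `vmap`, the function body should run once for the whole batch rather than once per element.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records every Python-level call to scalar_func

def scalar_func(x):
    # Python side effect: executes while JAX traces the function, not once per element
    trace_log.append(type(x).__name__)
    return x * x + 2

outputs = jax.vmap(scalar_func)(jnp.array([1.0, 2.0, 3.0, 4.0]))

print(f"Python-level calls: {len(trace_log)}")
print(f"Outputs: {outputs}")
```

A single recorded call for four inputs indicates that `vmap` evaluated the function once on a batched tracer instead of looping in Python.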


## `vmap` with `in_axes` and `out_axes`

The `in_axes` and `out_axes` arguments of `vmap` give you fine-grained control over how vectorization happens:

- **`in_axes`**: Specifies which axis of each input argument is mapped over.
    - An integer indicates the axis to map over for that argument.
    - `None` means the argument is broadcast (not mapped over).
    - A single integer applies to every argument; a tuple/list must have one entry per positional argument. The default is `0`.

- **`out_axes`**: Specifies where the mapped axis should appear in the output.
    - By default, it's `0`.
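
Since `vmap` works with PyTrees, an `in_axes` entry can itself be a PyTree that mirrors the structure of its argument. The following is a minimal sketch under assumed names (`affine`, `params`, and `per_example` are made up for illustration), showing the tuple form and the nested dict form side by side.

```python
import jax
import jax.numpy as jnp

def affine(params, x):
    # Single-example function: params["w"] and x both have shape (3,)
    return jnp.dot(params["w"], x) + params["b"]

params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}
xs = jnp.ones((4, 3))  # batch of 4 examples along axis 0

# Tuple form: broadcast the whole params dict, map over axis 0 of xs
shared = jax.vmap(affine, in_axes=(None, 0))(params, xs)

# PyTree form: share w across the batch, but map over a per-example bias
per_example = {"w": params["w"], "b": jnp.arange(4.0)}
mixed = jax.vmap(affine, in_axes=({"w": None, "b": 0}, 0))(per_example, xs)

print(shared.shape, mixed.shape)
```

Both results have shape `(4,)`: one scalar per example, with the batch axis placed at position 0 by the default `out_axes`.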

In [3]:
# Example function with two arguments
def power_scale(x, y_scalar):
  # x is a vector, y_scalar is a scalar to be broadcasted
  # print(f"power_scale called with x: {x}, y_scalar: {y_scalar}") # For demonstration
  return x ** y_scalar

# Map over the first argument (x), broadcast the second (y_scalar)
vmapped_power_scale = jax.vmap(power_scale, in_axes=(0, None))

xs = jnp.array([[1., 2.], [3., 4.]]) # Batch of 2, each item is a 2-element vector
y = jnp.array(3.0)

result = vmapped_power_scale(xs, y)
print(f"\nInput xs:\n{xs}")
print(f"Input y: {y}")
print(f"Result of vmapped_power_scale (in_axes=(0, None)):\n{result}")
print(f"Result shape: {result.shape}")

# Example: Mapping over different axes
def matrix_vector_product(matrix, vector):
    # matrix is (N, M), vector is (M,)
    return jnp.dot(matrix, vector)

# Batch of matrices (axis 0), batch of vectors (axis 0)
batched_mvp = jax.vmap(matrix_vector_product, in_axes=(0, 0))

matrices = jnp.stack([jnp.arange(1,5).reshape(2,2), jnp.arange(5,9).reshape(2,2)])
vectors = jnp.stack([jnp.array([1.,2.]), jnp.array([3.,4.])])

print(f"\nBatched matrices:\n{matrices}")
print(f"Batched vectors:\n{vectors}")
output_mvp = batched_mvp(matrices, vectors)
print(f"Output of batched MVP (in_axes=(0,0)):\n{output_mvp}")
print(f"Output shape: {output_mvp.shape}")


Input xs:
[[1. 2.]
 [3. 4.]]
Input y: 3.0
Result of vmapped_power_scale (in_axes=(0, None)):
[[ 1.  8.]
 [27. 64.]]
Result shape: (2, 2)

Batched matrices:
[[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
Batched vectors:
[[1. 2.]
 [3. 4.]]
Output of batched MVP (in_axes=(0,0)):
[[ 5. 11.]
 [39. 53.]]
Output shape: (2, 2)


In [4]:
# Example using out_axes
def simple_add(x, y):
    return x + y

# Standard vmap, output batch axis is 0
vmapped_add_default_out = jax.vmap(simple_add)
xs_add = jnp.array([1, 2, 3])
ys_add = jnp.array([10, 20, 30])
result_default_out = vmapped_add_default_out(xs_add, ys_add)
print(f"Result with default out_axes=0: {result_default_out}, shape: {result_default_out.shape}")

def identity_func(x):
    return x # returns the input as is

inputs_2d = jnp.array([[1,2],[3,4],[5,6]]) # Shape (3, 2)

# Map over axis 0, place the mapped axis at dimension 0 in output (default)
vmapped_identity_out0 = jax.vmap(identity_func, in_axes=0, out_axes=0)
output_0 = vmapped_identity_out0(inputs_2d)
print(f"\nInput (3,2):\n{inputs_2d}")
print(f"Output with out_axes=0:\n{output_0}\nShape: {output_0.shape}")

# Map over axis 0, place the mapped axis at dimension 1 in output
vmapped_identity_out1 = jax.vmap(identity_func, in_axes=0, out_axes=1)
output_1 = vmapped_identity_out1(inputs_2d)
print(f"Output with out_axes=1:\n{output_1}\nShape: {output_1.shape}")

# Map over axis 1, place the mapped axis at dimension 0 in output
vmapped_identity_in1_out0 = jax.vmap(identity_func, in_axes=1, out_axes=0)
output_in1_out0 = vmapped_identity_in1_out0(inputs_2d)
print(f"Output with in_axes=1, out_axes=0:\n{output_in1_out0}\nShape: {output_in1_out0.shape}")

Result with default out_axes=0: [11 22 33], shape: (3,)

Input (3,2):
[[1 2]
 [3 4]
 [5 6]]
Output with out_axes=0:
[[1 2]
 [3 4]
 [5 6]]
Shape: (3, 2)
Output with out_axes=1:
[[1 3 5]
 [2 4 6]]
Shape: (2, 3)
Output with in_axes=1, out_axes=0:
[[1 3 5]
 [2 4 6]]
Shape: (2, 3)


In [1]:
import os

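# Force the CPU backend and split the host into 8 virtual devices so pmap can be demonstrated.
# Both variables must be set before JAX initializes (simplest: before importing it),
# which is why this part of the notebook starts from a fresh kernel.
os.environ["JAX_PLATFORMS"] = "cpu"  # This forces JAX to use CPU only
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"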

import jax
import jax.numpy as jnp
import time

## Introduction to `pmap` (Parallel Map)

`pmap` is JAX's transformation for data parallelism. It compiles a function to run in parallel across multiple devices (e.g., GPUs or TPU cores). This is a form of SPMD (Single Program, Multiple Data) programming.

**Key aspects of `pmap`:**
- **Device Parallelism:** Executes the same function on different data shards on different devices.
- **Data Sharding:** Input data must be explicitly sharded (split) along the mapped axis (the first axis by default). The size of that axis must not exceed the number of available devices, and it usually equals it.
- **Collective Operations:** Supports operations that communicate across devices (e.g., `jax.lax.psum` for sum-reduction).
- **Requires Multiple Devices:** To see true parallelism, you need to run on hardware with multiple JAX-addressable devices.
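
The sharding requirement usually comes down to a reshape. Below is a small sketch, assuming the batch size divides evenly by the device count; `shard_batch` is a hypothetical helper written for this example, not a JAX function.

```python
import jax
import jax.numpy as jnp

def shard_batch(batch, num_devices):
    """Hypothetical helper: reshape (batch, ...) to (num_devices, batch // num_devices, ...)."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError(
            f"batch size {batch.shape[0]} is not divisible by {num_devices} devices"
        )
    return batch.reshape((num_devices, -1) + batch.shape[1:])

n_dev = jax.local_device_count()
batch = jnp.arange(n_dev * 2 * 3, dtype=jnp.float32).reshape((n_dev * 2, 3))

sharded = shard_batch(batch, n_dev)            # shape (n_dev, 2, 3)
doubled = jax.pmap(lambda x: x * 2)(sharded)   # each device doubles its (2, 3) shard
restored = doubled.reshape(batch.shape)        # undo the sharding layout

print(sharded.shape, restored.shape)
```

The code adapts to however many devices are present, so it also runs (without real parallelism) when only one device is available.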

In [2]:
# Get the number of available devices
num_devices = jax.local_device_count()
print(f"Number of JAX local devices: {num_devices}")

# If only 1 device, pmap will act like a map on that single device.
# For true parallelism, you need num_devices > 1.

def simple_device_computation(x):
  # This function will run on each device with its shard of x
  # Note: a print placed here runs once while pmap traces the function and shows an
  # abstract tracer, not the concrete per-device values.
  return x * 2

# Parallelize the function with pmap
pmapped_computation = jax.pmap(simple_device_computation)

# Create data. The first dimension must be equal to the number of devices.
# If num_devices is 1, this will be shape (1, ...)
if num_devices > 0:
    data_to_shard = jnp.arange(num_devices * 4).reshape((num_devices, 4))
    print(f"\nData to shard (shape {data_to_shard.shape}):\n{data_to_shard}")

    # Apply the pmapped function
    # Each device gets one row of 'data_to_shard'
    result_pmap = pmapped_computation(data_to_shard)
    result_pmap.block_until_ready() # Ensure computation finishes for accurate inspection
    print(f"Result from pmapped function (shape {result_pmap.shape}):\n{result_pmap}")
    print("Note: Each device processed its own slice of the input.")
else:
    print("\nNo JAX devices found to demonstrate pmap. This usually means JAX isn't properly initialized or no backend is available.")

# Example: pmap with a scalar input (broadcasted)
def add_scalar_on_device(x, s):
    return x + s

# pmap with in_axes=(0, None) to map over x and broadcast s
pmapped_add_scalar = jax.pmap(add_scalar_on_device, in_axes=(0, None))

if num_devices > 0:
    data_sharded = jnp.arange(num_devices * 2).reshape((num_devices, 2))
    scalar_val = jnp.array(100.0)
    print(f"\nData sharded for add_scalar:\n{data_sharded}")
    print(f"Scalar value: {scalar_val}")
    result_add_scalar = pmapped_add_scalar(data_sharded, scalar_val)
    result_add_scalar.block_until_ready()
    print(f"Result from pmapped_add_scalar:\n{result_add_scalar}")

Number of JAX local devices: 8

Data to shard (shape (8, 4)):
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]
 [24 25 26 27]
 [28 29 30 31]]
Result from pmapped function (shape (8, 4)):
[[ 0  2  4  6]
 [ 8 10 12 14]
 [16 18 20 22]
 [24 26 28 30]
 [32 34 36 38]
 [40 42 44 46]
 [48 50 52 54]
 [56 58 60 62]]
Note: Each device processed its own slice of the input.

Data sharded for add_scalar:
[[ 0  1]
 [ 2  3]
 [ 4  5]
 [ 6  7]
 [ 8  9]
 [10 11]
 [12 13]
 [14 15]]
Scalar value: 100.0
Result from pmapped_add_scalar:
[[100. 101.]
 [102. 103.]
 [104. 105.]
 [106. 107.]
 [108. 109.]
 [110. 111.]
 [112. 113.]
 [114. 115.]]


## `pmap` with `axis_name` for Collective Operations

`pmap` can define a named axis using the `axis_name` argument. This name can then be used in collective operations like `jax.lax.psum`, `jax.lax.pmean`, `jax.lax.all_gather`, etc., to communicate or aggregate results across devices participating in the `pmap`.

This is fundamental for many distributed algorithms, like training large neural networks.
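
As an illustration of a collective other than `psum`, the sketch below uses `jax.lax.pmean` to subtract the cross-device mean from each device's value. The function name `center_across_devices` and the axis name `"devices"` are choices made for this example.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def center_across_devices(x):
    # Mean of x over the named axis; every device receives the same mean
    mean = jax.lax.pmean(x, axis_name="devices")
    return x - mean

centered = jax.pmap(center_across_devices, axis_name="devices")(
    jnp.arange(n_dev, dtype=jnp.float32)  # one scalar per device
)

print(f"Centered values: {centered}")
print(f"Sum after centering: {jnp.sum(centered)}")
```

The same pattern (compute locally, average with `pmean`) is how per-device gradients are commonly averaged in data-parallel training.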

In [3]:
# Define a function that performs a sum reduction across devices
def sum_across_devices(x):
  # 'devices' is the named axis for pmap
  # jax.lax.psum will sum the value of 'x' from all devices involved in this pmap
  return jax.lax.psum(x, axis_name='devices')

# pmap this function, naming the mapped axis 'devices'
pmapped_sum = jax.pmap(sum_across_devices, axis_name='devices')

if num_devices > 0:
    # Create data such that each device gets a different scalar value
    # For example, if num_devices=4, devices get [0, 1, 2, 3]
    per_device_values = jnp.arange(num_devices, dtype=jnp.float32)
    print(f"\nPer-device values: {per_device_values}")

    # When pmapped_sum is called, each device will have one of these values for 'x'.
    # jax.lax.psum will sum them all up, and each device will receive the total sum.
    total_sum_on_all_devices = pmapped_sum(per_device_values)
    total_sum_on_all_devices.block_until_ready()

    print(f"Total sum (available on all devices): {total_sum_on_all_devices}")
    print(f"Expected sum: {jnp.sum(per_device_values)}")

    # Example: Normalize data across devices
    def normalize_globally(x):
        local_sum = jnp.sum(x)
        global_sum = jax.lax.psum(local_sum, axis_name='devices')
        # Add a small epsilon to prevent division by zero if global_sum is zero
        return x / (global_sum + 1e-8)
    
    pmapped_normalize = jax.pmap(normalize_globally, axis_name='devices')
    
    # Each device has a small vector
    data_for_norm = jnp.arange(num_devices * 2, dtype=jnp.float32).reshape((num_devices, 2)) + 1.0
    print(f"\nData for global normalization (shape {data_for_norm.shape}):\n{data_for_norm}")
    
    normalized_data = pmapped_normalize(data_for_norm)
    normalized_data.block_until_ready()
    print(f"Globally normalized data (shape {normalized_data.shape}):\n{normalized_data}")
    
    # Verify: pmap returns the per-device shards stacked into one array, so the sum
    # over every element of normalized_data should be close to 1.0, while a single
    # shard only holds its own fraction of the total.
    print(f"Sum of one shard of normalized data: {jnp.sum(normalized_data[0])}")
    print(f"Total sum of all normalized data: {jnp.sum(normalized_data)}")
else:
    print("\nSkipping pmap collective examples as num_devices is 0.")


Per-device values: [0. 1. 2. 3. 4. 5. 6. 7.]
Total sum (available on all devices): [28. 28. 28. 28. 28. 28. 28. 28.]
Expected sum: 28.0

Data for global normalization (shape (8, 2)):
[[ 1.  2.]
 [ 3.  4.]
 [ 5.  6.]
 [ 7.  8.]
 [ 9. 10.]
 [11. 12.]
 [13. 14.]
 [15. 16.]]
Globally normalized data (shape (8, 2)):
[[0.00735294 0.01470588]
 [0.02205882 0.02941176]
 [0.03676471 0.04411765]
 [0.05147059 0.05882353]
 [0.06617647 0.07352941]
 [0.08088236 0.0882353 ]
 [0.09558824 0.10294118]
 [0.11029412 0.11764706]]
Sum of one shard of normalized data: 0.022058824077248573
Total sum of all normalized data: 1.0


## Performance Comparison: `vmap` vs. `pmap` vs. Sequential

Let's compare the performance of applying a function to a batch of data using:
1.  A standard Python loop applying a JITted JAX function to each item.
2.  `jax.vmap` to vectorize the function.
3.  `jax.pmap` to parallelize the function across available devices.

We'll use a slightly more compute-intensive function for this benchmark. For accurate timing in JAX, especially with asynchronous dispatch, it's crucial to call `.block_until_ready()` on the results before stopping the timer.
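
The timing pattern described here (warm up once, then block on the result before reading the clock) can be wrapped in a small helper. This is a sketch only: `time_call` and `small_task` are hypothetical names introduced for illustration, and taking the best of several repeats is one reasonable choice among others.

```python
import time
import jax
import jax.numpy as jnp

def time_call(fn, *args, repeats=5):
    """Hypothetical helper: best wall-clock time of fn(*args) after one warm-up call."""
    jax.block_until_ready(fn(*args))  # warm-up: triggers tracing and compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the asynchronous result
        best = min(best, time.perf_counter() - start)
    return best

@jax.jit
def small_task(m):
    # Stand-in workload for this example: one matmul plus a non-linearity
    return jnp.sum(jnp.tanh(m @ m))

elapsed = time_call(small_task, jnp.ones((64, 64)))
print(f"Best of 5: {elapsed:.6f} s")
```

Without the blocking call, the timer would mostly measure how long it takes to dispatch the work rather than how long the work takes.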

In [4]:
key = jax.random.PRNGKey(0)
matrix_size = 256 # Size of square matrices for the benchmark task
batch_size = 64   # Number of matrices to process

# Define a JITted function for our benchmark task
@jax.jit
def benchmark_task(matrix):
    # Perform some operations that are reasonably compute-intensive
    res = matrix
    for _ in range(3):
        res = jnp.dot(res, matrix) # Matrix multiplication
        res = jnp.tanh(res)       # Element-wise non-linearity
    return jnp.sum(res) # Return a scalar summary

# Generate a batch of random matrices
key, *subkeys = jax.random.split(key, batch_size + 1)
batched_matrices = jnp.stack([jax.random.normal(subkeys[i], (matrix_size, matrix_size)) for i in range(batch_size)])
print(f"Generated batched_matrices with shape: {batched_matrices.shape}")

# --- 1. Sequential execution (Python loop over JITted function) ---
print("\nTiming sequential execution...")
# Warm-up call so that one-time compilation is excluded from the measurement
benchmark_task(batched_matrices[0]).block_until_ready()
start_time_seq = time.time()
results_seq = []
for i in range(batch_size):
    results_seq.append(benchmark_task(batched_matrices[i]))
# Stack results and block for timing
jnp.stack(results_seq).block_until_ready()
end_time_seq = time.time()
time_seq = end_time_seq - start_time_seq
print(f"Sequential execution time: {time_seq:.4f} seconds")

# --- 2. `vmap` execution ---
print("\nTiming vmap execution...")
vmapped_task = jax.vmap(benchmark_task)
# Warm-up call so that one-time compilation is excluded from the measurement
vmapped_task(batched_matrices).block_until_ready()
start_time_vmap = time.time()
results_vmap = vmapped_task(batched_matrices)
results_vmap.block_until_ready() # Important for accurate timing
end_time_vmap = time.time()
time_vmap = end_time_vmap - start_time_vmap
print(f"vmap execution time: {time_vmap:.4f} seconds")

# --- 3. `pmap` execution ---
print("\nTiming pmap execution...")
if num_devices > 0:
    # pmap requires the leading dimension to be equal to num_devices
    # We need to ensure batch_size is a multiple of num_devices for this simple setup
    if batch_size % num_devices == 0:
        sharded_batch_size = batch_size // num_devices
        sharded_matrices = batched_matrices.reshape((num_devices, sharded_batch_size, matrix_size, matrix_size))
        print(f"Reshaped matrices for pmap to shape: {sharded_matrices.shape}")

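        # Each device receives a (sharded_batch_size, N, N) shard, so vmap over the shard inside pmap
        pmapped_task = jax.pmap(jax.vmap(benchmark_task))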
        
        # Warm-up pmap (compilation can take time on first run)
        warmup_result = pmapped_task(sharded_matrices)
        warmup_result.block_until_ready()

        start_time_pmap = time.time()
        results_pmap = pmapped_task(sharded_matrices)
        results_pmap.block_until_ready() # Important for accurate timing
        end_time_pmap = time.time()
        time_pmap = end_time_pmap - start_time_pmap
        print(f"pmap execution time: {time_pmap:.4f} seconds (on {num_devices} devices)")
        # results_pmap will have shape (num_devices, sharded_batch_size)
        # print(f"pmap result shape: {results_pmap.shape}")
    else:
        print(f"Skipping pmap timing: batch_size ({batch_size}) is not divisible by num_devices ({num_devices}).")
        time_pmap = float('inf') # Or some indicator that it wasn't run
else:
    print("Skipping pmap timing: No JAX devices found or num_devices is 0.")
    time_pmap = float('inf')

print("\n--- Timing Summary ---")
print(f"Sequential: {time_seq:.4f} s")
print(f"vmap:       {time_vmap:.4f} s")
if time_vmap > 0:
    print(f"vmap is {time_seq/time_vmap:.2f}x faster than sequential")
if time_pmap != float('inf'):
    print(f"pmap:       {time_pmap:.4f} s (on {num_devices} devices)")
    if time_pmap > 0:
        print(f"pmap is {time_vmap/time_pmap:.2f}x faster than vmap (approx, highly dependent on setup)")
else:
    print("pmap:       Not run or N/A")

Generated batched_matrices with shape: (64, 256, 256)

Timing sequential execution...
Sequential execution time: 0.2050 seconds

Timing vmap execution...
vmap execution time: 0.1385 seconds

Timing pmap execution...
Reshaped matrices for pmap to shape: (8, 8, 256, 256)



### Discussion of Timing Results

You should typically observe the following:
- **`vmap` vs. Sequential:** `vmap` is generally significantly faster than a Python loop calling a JITted function for each item. This is because `vmap` allows JAX to see the entire batch operation at once, enabling XLA to perform aggressive optimizations and parallelize computations on the available hardware (even on a single CPU or GPU core through SIMD instructions).
- **`pmap` vs. `vmap`:**
    - **On a single device (or if `num_devices` is 1):** `pmap` might be slightly slower than `vmap` due to a small overhead. It essentially behaves like `map` in this scenario.
    - **On multiple devices (e.g., multiple GPUs or TPU cores):** `pmap` can be faster than `vmap` if the workload is large enough to benefit from distribution and the communication overhead is managed. The speedup depends on the task, data size, and the efficiency of inter-device communication.
    - **On CPUs:** By default JAX exposes the host CPU as a single device; the `--xla_force_host_platform_device_count` flag used above splits it into several virtual devices. `pmap` might then show some speedup over `vmap` if the task is parallelizable and the gains from parallelism outweigh the overhead of coordinating multiple devices. However, XLA's CPU backend can already use multiple threads within a single device, so for many CPU-bound tasks that `vmap` handles well, `pmap` might not offer substantial additional benefits and could even add overhead.

**Important Considerations:**
- **Compilation Time:** The first run of a JITted function (including those transformed by `vmap` or `pmap`) includes compilation time. For fair benchmarking, either perform a warm-up run or use tools like `%timeit` (in a notebook) which handle this.
- **Workload Size:** The benefits of `vmap` and `pmap` are more pronounced for larger batches and more computationally intensive tasks.
- **Data Sharding for `pmap`:** `pmap` requires data to be explicitly sharded. The size of the mapped (leading) axis of the input arrays must not exceed the number of devices and typically equals it. This might require reshaping your data.
- **Hardware:** The actual performance gains from `pmap` are highly dependent on the underlying hardware (number and type of GPUs/TPUs) and the nature of the computation.
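
When each device should process several examples, a common pattern is to shard the batch and apply `vmap` inside `pmap`. The sketch below assumes a made-up single-example function `per_matrix` and a per-device batch of 2; it checks that the sharded result matches a plain `vmap` over the full batch.

```python
import jax
import jax.numpy as jnp

def per_matrix(m):
    # Single-example function (illustrative): one (4, 4) matrix in, one scalar out
    return jnp.sum(jnp.tanh(m @ m))

n_dev = jax.local_device_count()
per_device = 2
batch = jax.random.normal(jax.random.PRNGKey(0), (n_dev * per_device, 4, 4))

# Single-device reference: vectorize over the whole batch
ref = jax.vmap(per_matrix)(batch)

# Multi-device: shard the batch, then vectorize within each device's shard
sharded = batch.reshape((n_dev, per_device, 4, 4))
out = jax.pmap(jax.vmap(per_matrix))(sharded)  # shape (n_dev, per_device)

print(out.shape)
print(jnp.allclose(out.reshape(-1), ref, atol=1e-4))
```

Applying `pmap` alone to the sharded array would hand each device a 3-D shard, which a function written for a single 2-D matrix does not expect; the inner `vmap` handles that extra axis.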

## Key Differences: `vmap` vs. `pmap`

| Feature           | `vmap`                                               | `pmap`                                            |
|-------------------|------------------------------------------------------|---------------------------------------------------|
| **Purpose**       | Automatic vectorization (batching)                   | Parallel execution across devices                 |
| **Execution**     | Typically on a single device (conceptually)          | Across multiple devices (SPMD)                    |
| **Data Handling** | Operates on an axis of an array                      | Requires data to be sharded across devices        |
| **Communication** | None by default; collectives over the mapped axis are possible when `axis_name` is set | Supports collective operations (e.g., `psum`) |
| **Use Case**      | Applying an operation to a batch of data efficiently | Scaling computation to multiple GPUs/TPUs         |
| **`in_axes`**     | Flexible control over which args are mapped          | Specifies which args are sharded (or broadcasted) |
| **`axis_name`**   | Optional; names the mapped axis for collectives      | Used to name the mapped axis for collectives      |

**When to use which?**
- Use **`vmap`** when you have a function written for a single example and want to apply it to a batch of examples efficiently on one device. It's great for adding batch dimensions implicitly.
- Use **`pmap`** when you want to distribute the computation of a function across multiple physical devices (like multiple GPUs or TPU cores) to speed up processing or handle larger datasets that don't fit on a single device. This requires thinking about data sharding and potentially inter-device communication.
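
One nuance to the comparison above: `jax.vmap` also accepts an `axis_name`, so collectives such as `jax.lax.psum` can reduce over a vectorized axis on a single device. A minimal sketch, with the function name `normalize` and the axis name `"batch"` chosen for this example:

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # x is one element of the batch; psum adds x up over the named vmap axis
    return x / jax.lax.psum(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
normalized = jax.vmap(normalize, axis_name="batch")(xs)

print(normalized)
print(jnp.sum(normalized))
```

This makes it possible to develop code that relies on collectives with `vmap` on one device before moving it to `pmap` on several.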

## Conclusion & Further Resources

`vmap` and `pmap` are powerful tools in the JAX ecosystem for writing clean, efficient, and scalable code.
- `vmap` simplifies batching.
- `pmap` enables multi-device parallelism.

Understanding how and when to use them is key to leveraging JAX effectively for high-performance numerical computing and machine learning.

**Further Reading:**
- JAX Documentation on `vmap`: [https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html)
- JAX Documentation on `pmap`: [https://jax.readthedocs.io/en/latest/jax-101/04-parallelization.html](https://jax.readthedocs.io/en/latest/jax-101/04-parallelization.html)
- JAX Sharp Bits - Common Gotchas: [https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) (often covers aspects of transformations)