[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/natrask/ENM5320-2026/blob/main/NewMaterial/Lecture01_Jan15/pytorch_review.ipynb)

# PyTorch Fundamentals

In this course we will develop numerical solvers for differential equations in PyTorch, so that we can use Torch's automatic differentiation engine and integrate advanced architectures like transformers into standard scientific computing tasks. To ensure that you have sufficient background to keep up with assignments, I've put together this Jupyter notebook, which summarizes the PyTorch basics you will need to be comfortable with in order to keep up with the class.

During lecture today, skim this notebook and give an honest self-assessment, and talk to Prof. Trask after class if you feel like this is out of reach for you. With some proactive catch-up it will be possible, but **please do not assume that you can vibe code your way out of this**. When we start developing PDE solvers in Torch we will be doing something very unusual that was not in GPT's training data and you will need to have control of the fine-grained code machinations to get something that works and complete course assignments.

### Notebook outline

I have structured this notebook into two pieces:
- **Piece 1:** Basic PyTorch. You should feel comfortable at this level with PyTorch if you're going to be able to keep up. This is basic backprop, forward/backward pass, and training.
- **Piece 2:** Advanced topics. This is the material that we will learn in the class so you can get a feel for what you will learn. We will teach this by doing - throughout the semester we will have some coding labs where we will write code together. Don't stress if this looks complicated.

The point of this notebook is not for you to go through and run everything/memorize syntax. It is to help you assess the level of Python programming you will need to complete the course.

**WARNING:** In my experience right now, many graduate students are completely reliant on LLMs to write code. With the current capabilities of LLMs, you can do that for Piece 1, **but not Piece 2**. This is because hundreds of thousands of people are training MLPs and transformers, but probably <20 in the world are integrating them into FEM and so an LLM will not just regurgitate the right answer. I am drawing your attention to this now so that you can budget the appropriate amount of time for this course to do some moderately serious code development.

### Required packages

We will mostly just use vanilla PyTorch, but the `einops` package will be useful for manipulating matrices and tensors. If you're running this in Colab, uncomment the following to pip install the package.

In [None]:
# Install required packages (uncomment if running in Colab)
# !pip install einops

# **Piece 1.** PyTorch Review

This notebook provides a comprehensive review of PyTorch fundamentals for ENM5320. We'll cover:
- Tensor creation and operations
- Automatic differentiation
- Building neural networks
- Training loops
- GPU operations

### 1. Import PyTorch and Check Version

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")

### 2. Tensor Creation and Basic Operations

Tensors are the fundamental data structure in PyTorch, similar to NumPy arrays but with GPU acceleration support.

In [None]:
# Create tensors from different sources
a = torch.zeros(3, 4)
print("Zeros tensor:\n", a)

b = torch.ones(2, 3)
print("\nOnes tensor:\n", b)

c = torch.rand(2, 3)  # Uniform distribution [0, 1)
print("\nRandom tensor (uniform):\n", c)

d = torch.randn(2, 3)  # Standard normal distribution
print("\nRandom tensor (normal):\n", d)

# From Python list
e = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("\nTensor from list:\n", e)

# From NumPy array
f_np = np.array([[7, 8], [9, 10]])
f = torch.from_numpy(f_np)
print("\nTensor from NumPy:\n", f)
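One detail worth checking for yourself: `torch.from_numpy` shares memory with the source array, while `torch.tensor` makes a copy. The sketch below uses illustrative variable names (`arr`, `shared`, `copied`) that are not part of the notebook, and also shows how dtypes are inferred under PyTorch's default settings.

```python
import numpy as np
import torch

arr = np.array([[7, 8], [9, 10]])   # integer NumPy array
shared = torch.from_numpy(arr)      # shares memory with arr
copied = torch.tensor(arr)          # independent copy of the data

arr[0, 0] = -1                      # modify the NumPy array in place
print(shared[0, 0].item())          # reflects the in-place change
print(copied[0, 0].item())          # keeps the original value

# dtype is inferred from the data (assuming the default dtype is unchanged)
ints = torch.tensor([1, 2, 3])          # Python ints   -> torch.int64
floats = torch.tensor([1.0, 2.0, 3.0])  # Python floats -> torch.float32
print(ints.dtype, floats.dtype)

# NumPy floats default to float64, and from_numpy preserves that
doubles = torch.from_numpy(np.array([1.0, 2.0]))
print(doubles.dtype)
```

This matters later in the course: mixing `float64` arrays from NumPy with `float32` model parameters is a common source of dtype errors.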

In [None]:
# Basic operations
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

# Element-wise operations
print("Addition:", x + y)
print("\nSubtraction:", x - y)
print("\nElement-wise multiplication:", x * y)
print("\nElement-wise division:", x / y)

# Matrix multiplication
print("\nMatrix multiplication (x @ y.T):\n", x @ y.T)
print("\nMatrix multiplication (torch.mm):\n", torch.mm(x, y.T))

# Other useful operations
print("\nMean:", x.mean())
print("Sum:", x.sum())
print("Max:", x.max())
print("Min:", x.min())
print("Standard deviation:", x.std())

### 3. Tensor Indexing and Slicing

Access and modify tensor elements using NumPy-like indexing.

In [None]:
# Create a tensor for indexing examples
tensor = torch.arange(0, 24).reshape(4, 6)
print("Original tensor:\n", tensor)

# Access single element
print("\nElement at [1, 2]:", tensor[1, 2].item())

# Slice rows
print("\nFirst two rows:\n", tensor[:2, :])

# Slice columns
print("\nFirst three columns:\n", tensor[:, :3])

# Advanced indexing
print("\nRows 1 and 3:\n", tensor[[1, 3], :])

# Boolean indexing
mask = tensor > 10
print("\nElements > 10:\n", tensor[mask])

# Modify elements
tensor[0, 0] = 100
print("\nAfter modification:\n", tensor)

### 4. Tensor Reshaping and Broadcasting

Reshape tensors and understand broadcasting rules for operations on different-sized tensors.

In [None]:
# Reshaping tensors
x = torch.arange(12)
print("Original shape:", x.shape)
print(x)

# Using view (requires contiguous memory)
x_view = x.view(3, 4)
print("\nReshaped with view (3, 4):\n", x_view)

# Using reshape (more flexible)
x_reshaped = x.reshape(2, 6)
print("\nReshaped with reshape (2, 6):\n", x_reshaped)

# Transpose
x_t = x_view.T
print("\nTransposed:\n", x_t)

# Flatten
x_flat = x_view.flatten()
print("\nFlattened:", x_flat)

# Squeeze and unsqueeze
x_squeezed = torch.ones(1, 3, 1, 4)
print("\nOriginal shape:", x_squeezed.shape)
print("After squeeze:", x_squeezed.squeeze().shape)

x_unsqueezed = torch.ones(3, 4)
print("\nOriginal shape:", x_unsqueezed.shape)
print("After unsqueeze(0):", x_unsqueezed.unsqueeze(0).shape)
print("After unsqueeze(1):", x_unsqueezed.unsqueeze(1).shape)
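The "requires contiguous memory" remark on `view` is easiest to see on a transposed tensor. A minimal sketch, with illustrative names (`xt`, `flat_copy`, `flat_view`):

```python
import torch

x = torch.arange(12).view(3, 4)
xt = x.T                        # transpose returns a non-contiguous view
print(xt.is_contiguous())

view_failed = False
try:
    xt.view(12)                 # view cannot reinterpret non-contiguous memory
except RuntimeError:
    view_failed = True
print("view raised RuntimeError:", view_failed)

flat_copy = xt.reshape(12)              # reshape silently falls back to a copy
flat_view = xt.contiguous().view(12)    # explicit copy first, then view
print(torch.equal(flat_copy, flat_view))
```

When in doubt, `reshape` is the safer default; `view` is useful when you want a guarantee that no copy is made.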

In [None]:
# Broadcasting examples
a = torch.tensor([[1, 2, 3], [4, 5, 6]])  # Shape: (2, 3)
b = torch.tensor([10, 20, 30])            # Shape: (3,)

# Broadcasting: b is automatically expanded to (2, 3)
c = a + b
print("Broadcasting addition:\n", c)

# Broadcasting with different dimensions
x = torch.ones(3, 4)      # Shape: (3, 4)
y = torch.ones(4)         # Shape: (4,)
z = x + y                 # y is broadcasted to (3, 4)
print("\nBroadcasting (3,4) + (4,):\n", z)

# Broadcasting a column (3, 1) against a row (4,)
m = torch.tensor([[1], [2], [3]])  # Shape: (3, 1)
n = torch.tensor([10, 20, 30, 40])  # Shape: (4,)
result = m + n  # Broadcasted to (3, 4)
print("\nBroadcasting (3,1) + (4,) to (3,4):\n", result)
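The broadcasting rule itself is short: shapes are aligned from the trailing dimension, and each pair of sizes must match or one of them must be 1. The sketch below checks this with `torch.broadcast_shapes` and shows a pitfall that often appears in loss computations (the names `pred` and `target` are illustrative).

```python
import torch

# Shapes are aligned from the right; sizes must be equal or 1.
shape_a = torch.broadcast_shapes((2, 3), (3,))
shape_b = torch.broadcast_shapes((3, 1), (4,))
print(shape_a, shape_b)

# Incompatible shapes raise an error rather than broadcasting.
incompatible = False
try:
    torch.ones(2, 3) + torch.ones(2)
except RuntimeError:
    incompatible = True
print("(2, 3) + (2,) fails:", incompatible)

# Pitfall: (n,) minus (n, 1) silently produces an (n, n) matrix.
pred = torch.zeros(5)
target = torch.zeros(5, 1)
diff = pred - target
print(diff.shape)
```

If a loss value looks suspiciously large, printing the shape of the residual is a quick first check.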

### 5. Automatic Differentiation with Autograd

PyTorch's autograd package provides automatic differentiation for all operations on tensors.

In [None]:
# Basic autograd example
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

# Forward pass
z = x**2 + 2*y**3
print(f"z = {z.item()}")

# Backward pass - compute gradients
z.backward()

print(f"dz/dx = {x.grad.item()}")  # Should be 2*x = 4
print(f"dz/dy = {y.grad.item()}")  # Should be 6*y^2 = 54

In [None]:
# More complex example with vectors
x = torch.randn(5, requires_grad=True)
print("x:", x)

# Compute a scalar output
y = (x**2).sum()
print("y:", y.item())

# Compute gradients
y.backward()
print("Gradient (dy/dx):", x.grad)
print("Expected (2*x):", 2*x.detach())

# Zero gradients for next computation
x.grad.zero_()
print("\nAfter zeroing gradients:", x.grad)
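Two autograd behaviours are worth seeing once before we use them in PDE solvers: gradients accumulate across `backward()` calls, and `torch.autograd.grad` with `create_graph=True` lets you differentiate a derivative again. A minimal sketch, with an illustrative function `u(x) = x^3`:

```python
import torch

# Gradients accumulate across backward() calls unless they are zeroed.
x = torch.tensor(2.0, requires_grad=True)
(x**2).backward()
(x**2).backward()
accumulated = x.grad.item()     # 2*x + 2*x
print(accumulated)

# Second derivative of u(x) = x^3 via two calls to torch.autograd.grad.
x = torch.tensor(2.0, requires_grad=True)
u = x**3
(du_dx,) = torch.autograd.grad(u, x, create_graph=True)   # 3*x^2, still differentiable
(d2u_dx2,) = torch.autograd.grad(du_dx, x)                # 6*x
print(du_dx.item(), d2u_dx2.item())
```

The second pattern is how derivatives of a network output with respect to its input are typically obtained when a differential operator appears in the loss.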

In [None]:
# Gradient control: detach and no_grad
x = torch.tensor(2.0, requires_grad=True)

# Detach creates a new tensor that shares storage but doesn't track gradients
y = x**2
z = y.detach()
print(f"y requires_grad: {y.requires_grad}")
print(f"z requires_grad: {z.requires_grad}")

# Context manager to disable gradient tracking
with torch.no_grad():
    w = x**3
    print(f"w requires_grad: {w.requires_grad}")

# This is useful during inference to save memory
print("\nUseful for inference where we don't need gradients!")

### 6. Building a Simple Neural Network

Use `nn.Module` to define neural network architectures.

In [None]:
# Define a simple feedforward neural network
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Define forward pass
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create an instance of the network
model = SimpleNN(input_size=10, hidden_size=20, output_size=2)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

In [None]:
# Test forward pass
batch_size = 5
input_data = torch.randn(batch_size, 10)
output = model(input_data)
print(f"Input shape: {input_data.shape}")
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")

In [None]:
# Alternative: using nn.Sequential for simple architectures
model_sequential = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 2)
)
print(model_sequential)

# Test
output_seq = model_sequential(input_data)
print(f"\nSequential output shape: {output_seq.shape}")

### 7. Loss Functions and Optimization

PyTorch provides various loss functions and optimizers for training neural networks.

In [None]:
# Common loss functions
# 1. Mean Squared Error (MSE) - for regression
predictions = torch.tensor([1.5, 2.3, 3.1])
targets = torch.tensor([1.0, 2.0, 3.0])
mse_loss = nn.MSELoss()
loss = mse_loss(predictions, targets)
print(f"MSE Loss: {loss.item()}")

# 2. Cross Entropy Loss - for classification
ce_loss = nn.CrossEntropyLoss()
# Predictions: logits for 3 classes, batch size 2
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
# Targets: class indices
class_targets = torch.tensor([0, 1])
loss = ce_loss(logits, class_targets)
print(f"Cross Entropy Loss: {loss.item()}")

# 3. Binary Cross Entropy - for binary classification
bce_loss = nn.BCEWithLogitsLoss()
binary_logits = torch.tensor([0.5, -1.0, 2.0])
binary_targets = torch.tensor([1.0, 0.0, 1.0])
loss = bce_loss(binary_logits, binary_targets)
print(f"Binary Cross Entropy Loss: {loss.item()}")
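To make clear what `nn.CrossEntropyLoss` computes, the sketch below reproduces it by hand from `log_softmax` for the same logits and targets as above. It uses the functional API (`torch.nn.functional`), which is an alternative entry point to the same loss.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]])
targets = torch.tensor([0, 1])

# Cross entropy = mean negative log-probability of the correct class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), targets].mean()

builtin = F.cross_entropy(logits, targets)
print(manual.item(), builtin.item())
```

Note that the loss expects raw logits; applying a softmax before passing them in would apply the normalization twice.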

In [None]:
# Optimizers
model = SimpleNN(10, 20, 2)

# 1. Stochastic Gradient Descent (SGD)
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
print("SGD optimizer:", optimizer_sgd)

# 2. Adam - adaptive learning rate
optimizer_adam = optim.Adam(model.parameters(), lr=0.001)
print("\nAdam optimizer:", optimizer_adam)

# 3. RMSprop
optimizer_rmsprop = optim.RMSprop(model.parameters(), lr=0.001)
print("\nRMSprop optimizer:", optimizer_rmsprop)

# Inspect the optimizer's state dict (top-level keys: 'state' and 'param_groups')
print("\nOptimizer state dict keys:", optimizer_adam.state_dict().keys())
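Parameter groups let different parts of a model use different hyperparameters. A small sketch, assuming a two-layer `nn.Sequential` model and arbitrary learning rates chosen only for illustration:

```python
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# One parameter group per layer, each with its own learning rate
optimizer = optim.Adam([
    {"params": net[0].parameters(), "lr": 1e-3},
    {"params": net[2].parameters(), "lr": 1e-4},
])

group_sizes = []
for i, group in enumerate(optimizer.param_groups):
    n_params = sum(p.numel() for p in group["params"])
    group_sizes.append(n_params)
    print(f"group {i}: lr={group['lr']}, parameters={n_params}")
```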

### 8. Training Loop Example

A complete example showing the typical training loop structure.

In [None]:
# Create synthetic data for training
torch.manual_seed(42)
X_train = torch.randn(100, 10)
y_train = torch.randint(0, 2, (100,))

# Initialize model, loss, and optimizer
model = SimpleNN(input_size=10, hidden_size=20, output_size=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
num_epochs = 50
losses = []

for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    
    # Backward pass and optimization
    optimizer.zero_grad()  # Clear gradients
    loss.backward()        # Compute gradients
    optimizer.step()       # Update parameters
    
    # Store loss
    losses.append(loss.item())
    
    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("\nTraining complete!")
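The loop above uses the full dataset in every step. In practice you will usually iterate over mini-batches; a minimal sketch with `TensorDataset` and `DataLoader` follows. The batch size, epoch count, and the `nn.Sequential` stand-in for `SimpleNN` are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

epoch_losses = []
for epoch in range(5):
    running = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()            # clear gradients from the previous batch
        loss = criterion(net(xb), yb)    # forward pass on one mini-batch
        loss.backward()                  # compute gradients
        optimizer.step()                 # update parameters
        running += loss.item() * xb.shape[0]
    epoch_losses.append(running / len(X))   # sample-weighted mean loss per epoch

print(epoch_losses)
```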

In [None]:
# Plot training loss
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.grid(True)
plt.show()

# Evaluate on training data (normally you'd use a separate test set)
model.eval()  # Set to evaluation mode
with torch.no_grad():
    outputs = model(X_train)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == y_train).sum().item() / len(y_train)
    print(f'\nTraining Accuracy: {accuracy * 100:.2f}%')

### 9. GPU Operations

Move tensors and models to GPU for faster computation.

In [None]:
# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Memory reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
else:
    print("No GPU available, using CPU")

In [None]:
# Move tensors to GPU
x_cpu = torch.randn(3, 4)
print(f"Tensor on CPU: {x_cpu.device}")

# Method 1: Using .to()
x_gpu = x_cpu.to(device)
print(f"Tensor moved to: {x_gpu.device}")

# Method 2: Using .cuda() (if available)
if torch.cuda.is_available():
    x_cuda = x_cpu.cuda()
    print(f"Tensor on GPU: {x_cuda.device}")
    
    # Move back to CPU
    x_back = x_cuda.cpu()
    print(f"Tensor back on CPU: {x_back.device}")

# Create tensor directly on device
y = torch.randn(3, 4, device=device)
print(f"Tensor created on: {y.device}")

In [None]:
# Move model to GPU
model = SimpleNN(10, 20, 2)
model = model.to(device)
print(f"Model device: {next(model.parameters()).device}")

# Training with GPU
X_train_gpu = torch.randn(100, 10, device=device)
y_train_gpu = torch.randint(0, 2, (100,), device=device)

# Verify all on same device
print(f"Input device: {X_train_gpu.device}")
print(f"Target device: {y_train_gpu.device}")
print(f"Model device: {next(model.parameters()).device}")

# Forward pass on GPU
outputs = model(X_train_gpu)
print(f"Output device: {outputs.device}")
print(f"Output shape: {outputs.shape}")

## Summary

This notebook covered the essential PyTorch concepts:
- **Tensor operations**: Creation, indexing, slicing, reshaping, and broadcasting
- **Autograd**: Automatic differentiation for gradient computation
- **Neural networks**: Building models with `nn.Module` and `nn.Sequential`
- **Training**: Loss functions, optimizers, and the training loop
- **GPU acceleration**: Moving tensors and models to GPU for faster computation

For more information, visit the [PyTorch documentation](https://pytorch.org/docs/stable/index.html).

# **Piece 2.** Advanced topics for scientific computing in PyTorch

## Using einops for Tensor Manipulation

The `einops` library provides a cleaner, more readable way to perform tensor operations like rearranging, reducing, and repeating. It uses Einstein notation-inspired syntax.

In [None]:
from einops import rearrange, reduce, repeat
import torch

# Create a sample tensor: batch of images
# Shape: (batch, channels, height, width)
images = torch.randn(4, 3, 28, 28)
print(f"Original shape: {images.shape}")

### 1. Rearrange - Reshaping and Transposing

The `rearrange` function is perfect for reshaping, transposing, and reordering dimensions with readable syntax.

In [None]:
# Example 1: Transpose dimensions
# Swap batch and channel dimensions
transposed = rearrange(images, 'b c h w -> c b h w')
print(f"After transpose: {transposed.shape}")

# Example 2: Flatten spatial dimensions
# Convert (batch, channel, height, width) to (batch, channel, pixels)
flattened = rearrange(images, 'b c h w -> b c (h w)')
print(f"Flattened spatial: {flattened.shape}")

# Example 3: Flatten completely
# Convert to (batch, features)
flat = rearrange(images, 'b c h w -> b (c h w)')
print(f"Completely flat: {flat.shape}")

# Example 4: Split dimensions
# Split height and width into patches
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=7, p2=7)
print(f"Image patches (7x7): {patches.shape}")  # 4 batches, 16 patches, 147 features per patch

### 2. Reduce - Aggregation Operations

The `reduce` function performs reductions (mean, max, min, sum) along specified dimensions.

In [None]:
# Example 1: Global average pooling
# Average over spatial dimensions
pooled = reduce(images, 'b c h w -> b c', 'mean')
print(f"Global average pooled: {pooled.shape}")

# Example 2: Average over channels
channel_avg = reduce(images, 'b c h w -> b h w', 'mean')
print(f"Channel-averaged: {channel_avg.shape}")

# Example 3: Max pooling over spatial dimensions
max_pooled = reduce(images, 'b c h w -> b c', 'max')
print(f"Max pooled: {max_pooled.shape}")

# Example 4: Sum with spatial reduction
summed = reduce(images, 'b c h w -> b c', 'sum')
print(f"Spatial sum: {summed.shape}")

# Example 5: Combining with rearrangement - patch-wise mean
patch_mean = reduce(images, 'b c (h p1) (w p2) -> b c h w', 'mean', p1=7, p2=7)
print(f"Patch-wise mean: {patch_mean.shape}")

### 3. Repeat - Replicating Tensors

The `repeat` function duplicates data along specified dimensions.

In [None]:
# Create a smaller tensor for demonstration
vector = torch.randn(3)
print(f"Original vector: {vector.shape}")

# Example 1: Repeat along new dimension
repeated = repeat(vector, 'c -> b c', b=4)
print(f"Repeated along batch: {repeated.shape}")
print(repeated)

# Example 2: Create a matrix from vector
matrix = repeat(vector, 'c -> c h', h=5)
print(f"\nRepeated to matrix: {matrix.shape}")

# Example 3: Tile a small image
small_image = torch.randn(1, 3, 8, 8)
tiled = repeat(small_image, 'b c h w -> b c (h tile1) (w tile2)', tile1=3, tile2=3)
print(f"\nTiled image: {tiled.shape}")

# Example 4: Broadcasting-like operation
# Repeat a bias vector for each batch element and spatial location
bias = torch.randn(3)
broadcasted = repeat(bias, 'c -> b c h w', b=4, h=28, w=28)
print(f"Broadcasted bias: {broadcasted.shape}")

### 4. Practical Examples

Combining einops operations for common deep learning tasks.

In [None]:
# Example 1: Vision Transformer (ViT) - Image to patches
image = torch.randn(1, 3, 224, 224)
patch_size = 16
patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
print(f"ViT patches: {patches.shape}")  # (1, 196, 768) - 196 patches of 768 dims

# Example 2: Batch matrix multiplication with einops
batch_size, seq_len, dim = 4, 10, 64
queries = torch.randn(batch_size, seq_len, dim)
keys = torch.randn(batch_size, seq_len, dim)

# Attention scores: transpose keys so that (b, n, d) @ (b, d, n) -> (b, n, n)
q_rearranged = rearrange(queries, 'b n d -> b n d')  # identity pattern; queries already have the right layout
k_rearranged = rearrange(keys, 'b n d -> b d n')
attention_scores = torch.bmm(q_rearranged, k_rearranged)
print(f"Attention scores: {attention_scores.shape}")

# Example 3: Converting between different data formats
# NCHW (PyTorch) to NHWC (TensorFlow)
nchw_tensor = torch.randn(8, 3, 32, 32)
nhwc_tensor = rearrange(nchw_tensor, 'b c h w -> b h w c')
print(f"NCHW to NHWC: {nchw_tensor.shape} -> {nhwc_tensor.shape}")

# Example 4: Spatial Average Pooling (alternative to nn.AdaptiveAvgPool2d)
feature_map = torch.randn(4, 256, 7, 7)
pooled_features = reduce(feature_map, 'b c h w -> b c', 'mean')
print(f"Spatial pooling: {feature_map.shape} -> {pooled_features.shape}")

### 5. Why Use einops?

**Advantages:**
- **Readability**: Operations are self-documenting with named dimensions
- **Fewer bugs**: Explicit dimension names reduce indexing errors
- **Composability**: Complex operations become simple one-liners
- **Framework agnostic**: Works with PyTorch, TensorFlow, JAX, and NumPy

**Comparison:**
```python
# Traditional PyTorch
x = x.permute(0, 2, 3, 1).reshape(batch, -1, channels)

# With einops
x = rearrange(x, 'b c h w -> b (h w) c')
```

The einops version clearly shows what transformation is being performed!

### 6. Tensor Contraction with einops vs einsum

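Tensor contraction is a generalization of matrix multiplication. `torch.einsum` is convenient, but it can be slower than calling a matmul kernel directly, and it is a poor fit for sparse tensors. Rearranging with `einops` and then calling an optimized matrix multiplication is often faster, though the size of the gap depends on your PyTorch version and hardware, so measure before relying on it.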

In [None]:
import time

# Setup: Create sample tensors for benchmarking
batch_size, seq_len, dim = 32, 128, 512
A = torch.randn(batch_size, seq_len, dim)
B = torch.randn(batch_size, dim, seq_len)

print(f"Tensor shapes: A={A.shape}, B={B.shape}")
print(f"Goal: Batch matrix multiplication -> ({batch_size}, {seq_len}, {seq_len})")

#### Method 1: Using torch.einsum

In [None]:
# Using einsum notation
# 'bik,bkj->bij' means: batch, i, k @ batch, k, j -> batch, i, j

# Warmup
for _ in range(10):
    _ = torch.einsum('bik,bkj->bij', A, B)

# Benchmark
start = time.time()
for _ in range(100):
    result_einsum = torch.einsum('bik,bkj->bij', A, B)
einsum_time = (time.time() - start) / 100

print(f"Result shape: {result_einsum.shape}")
print(f"Average time (einsum): {einsum_time*1000:.3f} ms")

#### Method 2: Using torch.bmm (Batch Matrix Multiply)

In [None]:
# Using optimized batch matrix multiplication
# torch.bmm is highly optimized and calls BLAS/cuBLAS directly

# Warmup
for _ in range(10):
    _ = torch.bmm(A, B)

# Benchmark
start = time.time()
for _ in range(100):
    result_bmm = torch.bmm(A, B)
bmm_time = (time.time() - start) / 100

print(f"Result shape: {result_bmm.shape}")
print(f"Average time (bmm): {bmm_time*1000:.3f} ms")
print(f"Speedup: {einsum_time/bmm_time:.2f}x faster")

# Verify results are the same
print(f"\nResults match: {torch.allclose(result_einsum, result_bmm)}")

#### Method 3: Using einops + matmul (@ operator)

In [None]:
# Using einops for readability with @ operator for speed
# The @ operator calls the same optimized backends as torch.bmm

# Warmup
for _ in range(10):
    _ = A @ B

# Benchmark
start = time.time()
for _ in range(100):
    result_matmul = A @ B
matmul_time = (time.time() - start) / 100

print(f"Result shape: {result_matmul.shape}")
print(f"Average time (@): {matmul_time*1000:.3f} ms")
print(f"Speedup vs einsum: {einsum_time/matmul_time:.2f}x faster")

# For more complex operations, einops makes it readable
# Example: tensor contraction with rearrangement
C = torch.randn(batch_size, dim, seq_len, 64)
# Contract over dim: flatten (j, k) so the product is a plain batched matmul,
# (b, i, d) @ (b, d, j*k) -> (b, i, j*k), then unflatten to (batch, seq_len, seq_len, 64)
result_complex = A @ rearrange(C, 'b d j k -> b d (j k)')
result_complex = rearrange(result_complex, 'b i (j k) -> b i j k', k=64)
print(f"\nComplex contraction result: {result_complex.shape}")

#### Why is einsum slower?

**Key reasons:**

1. **Generic implementation**: `einsum` handles arbitrary tensor contractions, so it can't always optimize for specific patterns like matrix multiplication

2. **Memory layout**: `einsum` may need to transpose/reshape tensors internally, creating temporary copies

3. **BLAS optimization**: `torch.bmm` and `@` directly call highly optimized BLAS (CPU) or cuBLAS (GPU) libraries that use:
   - Cache-optimized memory access patterns
   - SIMD vectorization
   - Specialized hardware instructions (e.g., tensor cores on NVIDIA GPUs)

4. **Sparse tensors**: `einsum` is **particularly slow** for sparse tensors because:
   - It often materializes dense intermediate results
   - Doesn't leverage sparse-specific algorithms (CSR/CSC matrix multiplication)
   - Can't skip zero elements efficiently

**Recommendation:**
- Use `einsum` for quick prototyping and unusual contractions
- For production code, rearrange with `einops` then use `@`, `torch.bmm`, or `torch.matmul`
- For sparse tensors, **always** use `torch.sparse.mm` or `@` instead of `einsum`
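Hand-rolled `time.time()` loops are noisy, and on a GPU they can miss asynchronous kernel execution. One alternative is `torch.utils.benchmark.Timer`, which handles warmup and synchronization. The sketch below uses small tensor sizes and a repeat count of 50, both chosen arbitrarily; absolute numbers will vary by machine and PyTorch version.

```python
import torch
import torch.utils.benchmark as benchmark

A = torch.randn(8, 64, 128)
B = torch.randn(8, 128, 64)
env = {"torch": torch, "A": A, "B": B}

timers = {
    "einsum": benchmark.Timer(stmt="torch.einsum('bik,bkj->bij', A, B)", globals=env),
    "bmm": benchmark.Timer(stmt="torch.bmm(A, B)", globals=env),
}

# timeit(n) runs the statement n times and returns a Measurement object
results = {name: timer.timeit(50) for name, timer in timers.items()}
for name, measurement in results.items():
    print(f"{name}: {measurement.median * 1e3:.3f} ms per call")

# Both routes should compute the same product up to rounding
same = torch.allclose(torch.einsum('bik,bkj->bij', A, B), torch.bmm(A, B), atol=1e-4)
print("Results match:", same)
```

Treat any measured speedup as specific to the shapes and hardware you tested.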

#### Sparse Tensor Example

In [None]:
# Create a sparse matrix (e.g., adjacency matrix for a graph)
size = 1000
sparsity = 0.01  # 1% non-zero elements

# Create sparse tensor in COO format
indices = torch.randint(0, size, (2, int(size * size * sparsity)))
values = torch.randn(int(size * size * sparsity))
sparse_matrix = torch.sparse_coo_tensor(indices, values, (size, size))

# Dense vector
dense_vec = torch.randn(size, 100)

print(f"Sparse matrix: {size}x{size} with {sparsity*100}% non-zero")
print(f"Dense vector: {dense_vec.shape}")

# Method 1: Using sparse @ (FAST - uses sparse algorithms)
start = time.time()
result_sparse = sparse_matrix @ dense_vec
sparse_time = time.time() - start
print(f"\nSparse @ time: {sparse_time*1000:.3f} ms")

# Method 2: If we convert to dense and use einsum (SLOW - wastes memory)
dense_matrix = sparse_matrix.to_dense()
start = time.time()
result_dense = torch.einsum('ij,jk->ik', dense_matrix, dense_vec)
dense_time = time.time() - start
print(f"Dense einsum time: {dense_time*1000:.3f} ms")
print(f"Speedup: {dense_time/sparse_time:.2f}x faster with sparse operations")

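# Memory comparison (COO stores one float32 value plus two int64 indices per entry)
sparse_bytes = sparse_matrix._values().nelement() * 4 + sparse_matrix._indices().nelement() * 8
dense_bytes = size * size * 4
print(f"\nMemory usage:")
print(f"Sparse: ~{sparse_bytes / 1e6:.2f} MB")
print(f"Dense: ~{dense_bytes / 1e6:.2f} MB")
print(f"Ratio: {dense_bytes / sparse_bytes:.1f}x more memory for dense")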

#### Best Practices Summary

**For dense tensors:**
1. Use `einops.rearrange` to make dimensions clear
2. Use `@`, `torch.bmm`, or `torch.matmul` for actual computation
3. Avoid `einsum` in performance-critical code

**For sparse tensors:**
1. **Never** use `einsum` - it will materialize dense intermediates
2. Use `@` operator or `torch.sparse.mm` 
3. Keep computations in sparse format as long as possible
4. Only convert to dense if absolutely necessary

**Example pattern:**
```python
# Good: readable + fast
x = rearrange(x, 'b h w c -> b (h w) c')
result = x @ weight  # Uses optimized BLAS

# Bad: slow, especially for large tensors
result = torch.einsum('bhwc,cd->bhwd', x, weight)
```

### 7. Backpropagation Through Sparse Linear Solves

When solving linear systems $Ax = b$ in neural networks or PDE solvers, we often need gradients. PyTorch doesn't natively support autograd through `torch.linalg.solve` for sparse matrices, so we need custom implementations.

#### The Math: Implicit Function Theorem

For $Ax = b$, we want $\frac{\partial x}{\partial \theta}$ where $\theta$ are parameters in $A$ or $b$.

Using the implicit function theorem:
- Forward: Solve $Ax = b$ for $x$
- Backward: Given $\bar{x}$ (gradient from upstream), compute:
  - $\frac{\partial L}{\partial b} = A^{-T} \bar{x}$ (solve $A^T \lambda = \bar{x}$)
  - $\frac{\partial L}{\partial A} = -\lambda x^T$ where $\lambda$ is from above

**Key insight**: We never form $A^{-1}$ explicitly. We solve two linear systems (forward and backward).
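A short derivation of the two backward formulas, written as a sketch in the notation above with $\bar{x} = \partial L / \partial x$. The differentials $dA$, $db$, $dx$ are introduced here for illustration.

$$
\begin{aligned}
A x = b \;&\Rightarrow\; dA\,x + A\,dx = db \;\Rightarrow\; dx = A^{-1}\left(db - dA\,x\right), \\
dL = \bar{x}^T dx &= \left(A^{-T}\bar{x}\right)^T \left(db - dA\,x\right) = \lambda^T db - \lambda^T dA\,x, \qquad \lambda = A^{-T}\bar{x}.
\end{aligned}
$$

Reading off the coefficients of $db$ and $dA$ gives $\frac{\partial L}{\partial b} = \lambda$ and $\frac{\partial L}{\partial A} = -\lambda x^T$, since $\lambda^T dA\,x = \sum_{ij} \lambda_i x_j \, dA_{ij}$.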

#### Implementation: Custom Autograd Function

In [None]:
class SparseSolve(torch.autograd.Function):
    """
    Custom autograd function for solving Ax = b with sparse A.
    Implements efficient backward pass using implicit differentiation.
    """
    
    @staticmethod
    def forward(ctx, A, b):
        """
        Solve Ax = b for x.
        
        Args:
            A: Sparse matrix (n x n)
            b: Right-hand side (n,) or (n, k)
        
        Returns:
            x: Solution to Ax = b
        """
        # For sparse matrices, we typically use iterative solvers
        # or convert to dense for small problems
        
        # Option 1: Dense solve (for moderate-sized systems)
        if A.is_sparse:
            A_dense = A.to_dense()
        else:
            A_dense = A
            
        x = torch.linalg.solve(A_dense, b)
        
        # Save for backward pass
        ctx.save_for_backward(A_dense, x)
        
        return x
    
    @staticmethod
    def backward(ctx, grad_x):
        """
        Compute gradients using implicit differentiation.
        
        Given grad_x (∂L/∂x), compute:
        - grad_A: ∂L/∂A
        - grad_b: ∂L/∂b
        """
        A, x = ctx.saved_tensors
        
        # Solve A^T λ = grad_x for λ
        # This is the adjoint equation
        lambda_ = torch.linalg.solve(A.T, grad_x)
        
        # Gradient w.r.t. b: ∂L/∂b = λ
        grad_b = lambda_
        
        # Gradient w.r.t. A: ∂L/∂A = -λ x^T
        if x.dim() == 1:
            grad_A = -torch.outer(lambda_, x)
        else:
            grad_A = -torch.mm(lambda_, x.T)
        
        return grad_A, grad_b

# Create a wrapper function for convenience
def sparse_solve(A, b):
    """Solve Ax = b with autograd support."""
    return SparseSolve.apply(A, b)

#### Example: Gradient Check

In [None]:
# Create a simple test case
n = 5
torch.manual_seed(42)

# Create a symmetric positive definite matrix (ensures unique solution)
A_base = torch.randn(n, n, requires_grad=True)
A = A_base @ A_base.T + torch.eye(n)  # SPD matrix
A.retain_grad()  # A is a non-leaf tensor, so keep its .grad for inspection below

# Right-hand side
b = torch.randn(n, requires_grad=True)

# Solve using our custom function
x = sparse_solve(A, b)
print(f"Solution x: {x}")

# Create a scalar loss (e.g., L2 norm of solution)
loss = (x ** 2).sum()
print(f"Loss: {loss.item():.4f}")

# Backward pass
loss.backward()

print(f"\nGradients computed:")
print(f"grad_A shape: {A.grad.shape}")
print(f"grad_b shape: {b.grad.shape}")
print(f"\ngrad_b: {b.grad}")

# Verify with numerical gradients (finite differences)
print("\n--- Gradient Check ---")
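# float32 round-off dominates for very small steps; the loss is quadratic in b,
# so central differences have no truncation error and a larger step is safe
eps = 1e-3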

# Check gradient w.r.t. b
b_test = b.detach().clone().requires_grad_(True)
A_test = A.detach()
x_test = sparse_solve(A_test, b_test)
loss_test = (x_test ** 2).sum()
loss_test.backward()

# Numerical gradient for b[0]
b_plus = b.detach().clone()
b_plus[0] += eps
x_plus = sparse_solve(A_test, b_plus)
loss_plus = (x_plus ** 2).sum()

b_minus = b.detach().clone()
b_minus[0] -= eps
x_minus = sparse_solve(A_test, b_minus)
loss_minus = (x_minus ** 2).sum()

numerical_grad_b0 = (loss_plus - loss_minus) / (2 * eps)
analytical_grad_b0 = b_test.grad[0]

print(f"Numerical grad_b[0]: {numerical_grad_b0.item():.6f}")
print(f"Analytical grad_b[0]: {analytical_grad_b0.item():.6f}")
print(f"Relative error: {abs(numerical_grad_b0 - analytical_grad_b0) / abs(numerical_grad_b0):.2e}")
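Checking a single component by finite differences is a start. Two stronger checks are to compare against PyTorch's built-in autograd for the dense solve and to run `torch.autograd.gradcheck` in double precision. The sketch below restates the adjoint rule in a compact, dense-only class; the name `AdjointSolve` is hypothetical and handles only a 1-D right-hand side.

```python
import torch

class AdjointSolve(torch.autograd.Function):
    """Compact dense restatement of the adjoint rule (illustrative only)."""

    @staticmethod
    def forward(ctx, A, b):
        x = torch.linalg.solve(A, b)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        A, x = ctx.saved_tensors
        lam = torch.linalg.solve(A.T, grad_x)   # adjoint solve: A^T lam = grad_x
        return -torch.outer(lam, x), lam        # (grad_A, grad_b)

torch.manual_seed(0)
n = 4
M = torch.randn(n, n, dtype=torch.float64)
A = (M @ M.T + n * torch.eye(n, dtype=torch.float64)).requires_grad_(True)
b = torch.randn(n, dtype=torch.float64, requires_grad=True)

# Reference gradients from PyTorch's own differentiable dense solve
ref_A, ref_b = torch.autograd.grad((torch.linalg.solve(A, b) ** 2).sum(), (A, b))
# Gradients from the custom adjoint rule
adj_A, adj_b = torch.autograd.grad((AdjointSolve.apply(A, b) ** 2).sum(), (A, b))
print(torch.allclose(ref_A, adj_A), torch.allclose(ref_b, adj_b))

# gradcheck compares the backward pass with finite differences (needs float64)
passed = torch.autograd.gradcheck(AdjointSolve.apply, (A, b))
print("gradcheck passed:", passed)
```

Double precision is used here because finite-difference checks in `float32` are dominated by round-off.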

#### Optimization: Using Iterative Solvers

For large sparse systems, we should use iterative solvers (CG, GMRES, BiCGSTAB) instead of direct solvers. Here's an example with Conjugate Gradient:

In [None]:
def conjugate_gradient(A, b, x0=None, tol=1e-5, max_iter=100):
    """
    Solve Ax = b using Conjugate Gradient (for SPD matrices).
    
    Args:
        A: Sparse or dense SPD matrix
        b: Right-hand side
        x0: Initial guess (if None, use zeros)
        tol: Convergence tolerance
        max_iter: Maximum iterations
    
    Returns:
        x: Solution
    """
    n = b.shape[0]
    if x0 is None:
        x = torch.zeros_like(b)
    else:
        x = x0.clone()
    
    # Initial residual
    r = b - A @ x
    p = r.clone()
    rsold = torch.dot(r, r)
    
    for i in range(max_iter):
        Ap = A @ p
        alpha = rsold / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = torch.dot(r, r)
        
        if torch.sqrt(rsnew) < tol:
            print(f"CG converged in {i+1} iterations")
            break
        
        beta = rsnew / rsold
        p = r + beta * p
        rsold = rsnew
    
    return x

class SparseSolveCG(torch.autograd.Function):
    """
    Sparse solve using Conjugate Gradient with autograd support.
    More efficient for large sparse systems.
    """
    
    @staticmethod
    def forward(ctx, A, b, tol=1e-5):
        x = conjugate_gradient(A, b, tol=tol)
        ctx.save_for_backward(A, x)
        ctx.tol = tol
        return x
    
    @staticmethod
    def backward(ctx, grad_x):
        A, x = ctx.saved_tensors
        
        # Solve A^T λ = grad_x using CG
        # For SPD matrices, A^T = A, so we can reuse CG
        lambda_ = conjugate_gradient(A.T, grad_x, tol=ctx.tol)
        
        grad_b = lambda_
        grad_A = -torch.outer(lambda_, x)
        
        return grad_A, grad_b, None

# Wrapper function
def sparse_solve_cg(A, b, tol=1e-5):
    return SparseSolveCG.apply(A, b, tol)

In [None]:
# Test CG solver with larger system
n = 50
torch.manual_seed(123)

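# Create a sparse SPD matrix (e.g., from a discretized Poisson equation)
# Create tridiagonal matrix: -1, 2, -1 pattern
# Build the entries directly as leaf tensors. Writing 2 * torch.ones(n, requires_grad=True)
# would make `diag` a non-leaf result, so diag.grad would stay None after backward().
diag = torch.full((n,), 2.0, requires_grad=True)
off_diag = torch.full((n - 1,), -1.0, requires_grad=True)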
A_sparse = (torch.diag(diag) + 
            torch.diag(off_diag, diagonal=1) + 
            torch.diag(off_diag, diagonal=-1))

b_sparse = torch.randn(n, requires_grad=True)

# Solve using CG
print("Solving with Conjugate Gradient:")
x_cg = sparse_solve_cg(A_sparse, b_sparse, tol=1e-6)

# Verify solution
residual = torch.norm(A_sparse @ x_cg - b_sparse)
print(f"Residual: {residual.item():.2e}")

# Test gradients
loss = (x_cg ** 2).sum()
loss.backward()

print(f"\nGradients computed successfully")
print(f"grad_diag shape: {diag.grad.shape}")
print(f"grad_b norm: {torch.norm(b_sparse.grad).item():.4f}")

#### Key Considerations for PDE Solvers

**Memory efficiency:**
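- Direct solve: $O(n^2)$ memory if $A$ is factored densely (or $A^{-1}$ is formed explicitly)
- Iterative solve: $O(n)$ memory for a sparse $A$ with $O(1)$ nonzeros per row, plus a few work vectors
- Gradient with respect to $A$: `torch.outer(lambda_, x)` in `SparseSolveCG.backward` is a dense $n \times n$ tensor, so for large systems restrict it to the sparsity pattern of $A$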

**Computational cost:**
- Forward: One solve (direct or iterative)
- Backward: One additional solve for adjoint equation
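
The single extra solve in the backward pass follows from differentiating $Ax = b$. A short derivation, writing $\mathcal{L}$ for the scalar loss, $g = \partial \mathcal{L} / \partial x$ for the incoming gradient (`grad_x` in the code) and $\lambda$ for the adjoint variable (`lambda_`); the symbols $\mathcal{L}$ and $g$ are notation chosen here for illustration:

$$
\begin{aligned}
dA\,x + A\,dx &= db \quad\Rightarrow\quad dx = A^{-1}\,(db - dA\,x), \\
d\mathcal{L} = g^T dx &= \lambda^T (db - dA\,x), \qquad \text{where } A^T \lambda = g, \\
\frac{\partial \mathcal{L}}{\partial b} &= \lambda, \qquad \frac{\partial \mathcal{L}}{\partial A} = -\lambda\, x^T .
\end{aligned}
$$

The last line matches `grad_b = lambda_` and `grad_A = -torch.outer(lambda_, x)` in `SparseSolveCG.backward`.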

**Stability:**
- Ensure $A$ is well-conditioned (use preconditioning if needed)
- For non-SPD matrices, use GMRES or BiCGSTAB instead of CG
- Monitor convergence in both forward and backward passes

**Advanced techniques:**
1. **Checkpointing**: For very large systems, don't save $A$ and $x$ in `ctx`. Recompute them in backward pass.
2. **Preconditioning**: Apply preconditioner $M^{-1}$ to speed up iterative solvers
3. **Matrix-free**: Implement $A$ as a function (e.g., finite difference operator) instead of storing explicitly
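
The preconditioning and matrix-free ideas above can be sketched together as follows. This is a minimal illustration, not the notebook's solver: `apply_poisson_1d` and `pcg` are hypothetical helper names, float64 is used for a tight tolerance, and the Jacobi preconditioner is only a placeholder (for this constant-coefficient stencil it is a pure rescaling and does not reduce the iteration count).

```python
import torch

def apply_poisson_1d(u, h):
    """Matrix-free action of the (-1, 2, -1) / h^2 stencil with zero Dirichlet BCs."""
    Au = 2.0 * u
    Au[1:] -= u[:-1]    # contribution from the left neighbour
    Au[:-1] -= u[1:]    # contribution from the right neighbour
    return Au / h**2

def pcg(apply_A, b, apply_Minv, tol=1e-8, max_iter=500):
    """Preconditioned CG where A and M^{-1} are given as callables."""
    x = torch.zeros_like(b)
    r = b.clone()
    z = apply_Minv(r)
    p = z.clone()
    rz_old = torch.dot(r, z)
    for k in range(max_iter):
        Ap = apply_A(p)
        alpha = rz_old / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        if torch.linalg.norm(r) < tol:
            return x, k + 1
        z = apply_Minv(r)
        rz_new = torch.dot(r, z)
        p = z + (rz_new / rz_old) * p
        rz_old = rz_new
    return x, max_iter

n = 50
h = 1.0 / (n + 1)
f_rhs = torch.ones(n, dtype=torch.float64)

# Jacobi preconditioner: divide by the diagonal entry 2 / h^2
u_sol, iters = pcg(lambda v: apply_poisson_1d(v, h), f_rhs, lambda r: r * (h**2 / 2.0))
residual = torch.linalg.norm(apply_poisson_1d(u_sol, h) - f_rhs)
print(f"PCG iterations: {iters}, residual: {residual.item():.2e}")
```

Because the solver only ever calls `apply_A`, the $n \times n$ matrix is never stored; swapping in a different operator or preconditioner means passing a different callable.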

**Common use cases:**
- Poisson equation solvers in physics-informed neural networks (PINNs)
- Implicit time-stepping in neural ODEs
- Optimization problems with PDE constraints
- Inverse problems (parameter estimation from observations)

#### Example: Simple Poisson Equation

Let's solve $-\nabla^2 u = f$ on $[0,1]$ with $u(0) = u(1) = 0$ and backprop through it.

In [None]:
# Discretize Poisson equation using finite differences
def poisson_matrix_1d(n, h):
    """
    Create 1D Poisson matrix: -u'' = f
    Discretized as: (u[i-1] - 2*u[i] + u[i+1]) / h^2 = -f[i]
    """
    diag = (2.0 / h**2) * torch.ones(n)
    off_diag = (-1.0 / h**2) * torch.ones(n-1)
    A = torch.diag(diag) + torch.diag(off_diag, 1) + torch.diag(off_diag, -1)
    return A

# Problem setup
n = 30  # Interior points
L = 1.0  # Domain length
h = L / (n + 1)  # Grid spacing
x = torch.linspace(h, L - h, n)

# Create learnable forcing function (parameterized)
# f(x) = a * sin(pi * x) where a is learnable
a = torch.tensor(10.0, requires_grad=True)
f = a * torch.sin(torch.pi * x)

# Create Poisson matrix
A_poisson = poisson_matrix_1d(n, h)

# Solve: -u'' = f
print("Solving Poisson equation...")
u = sparse_solve(A_poisson, f)

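# Define loss: want solution close to target
# Note: on this grid sin(2*pi*x) is orthogonal to sin(pi*x), and u is always a
# multiple of sin(pi*x), so the best achievable fit is a = 0 with loss near 0.5
u_target = torch.sin(2 * torch.pi * x)  # Target solution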
loss_pde = ((u - u_target) ** 2).mean()

print(f"Initial loss: {loss_pde.item():.6f}")
print(f"Parameter a: {a.item():.4f}")

# Optimize the parameter
optimizer = torch.optim.Adam([a], lr=0.5)

for epoch in range(50):
    optimizer.zero_grad()
    
    # Re-solve with updated parameter
    f = a * torch.sin(torch.pi * x)
    u = sparse_solve(A_poisson, f)
    
    # Compute loss
    loss_pde = ((u - u_target) ** 2).mean()
    
    # Backprop through the solve!
    loss_pde.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss_pde.item():.6f}, a = {a.item():.4f}")

print(f"\nFinal parameter: a = {a.item():.4f}")
print(f"Final loss: {loss_pde.item():.6f}")
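
Before backpropagating through a discretized PDE it is worth checking the discretization itself. The sketch below solves $-u'' = \sin(\pi x)$, whose exact solution is $u = \sin(\pi x)/\pi^2$, on two grids and compares the errors. It is self-contained, so it uses `torch.linalg.solve` as a stand-in for the notebook's `sparse_solve`, and float64 so that discretization error dominates round-off; `poisson_error` is a hypothetical helper name.

```python
import math
import torch

def poisson_error(n):
    """Max error of the finite-difference solve of -u'' = sin(pi x) on n interior points."""
    h = 1.0 / (n + 1)
    x = torch.linspace(h, 1.0 - h, n, dtype=torch.float64)
    off = -torch.ones(n - 1, dtype=torch.float64)
    A = (torch.diag(2.0 * torch.ones(n, dtype=torch.float64))
         + torch.diag(off, 1) + torch.diag(off, -1)) / h**2
    f = torch.sin(math.pi * x)
    u = torch.linalg.solve(A, f)
    u_exact = torch.sin(math.pi * x) / math.pi**2
    return (u - u_exact).abs().max().item()

err_coarse = poisson_error(15)   # h = 1/16
err_fine = poisson_error(31)     # h = 1/32
ratio = err_coarse / err_fine
print(f"error (h=1/16): {err_coarse:.3e}")
print(f"error (h=1/32): {err_fine:.3e}")
print(f"ratio: {ratio:.2f}  (about 4 is expected for a second-order scheme)")
```

Halving $h$ should cut the error by roughly a factor of four, which is the signature of the second-order central difference used in `poisson_matrix_1d`.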

# Advanced Differentiation: functorch

PyTorch's `functorch` library (now part of `torch.func`) provides JAX-like functional transformations including `vmap` (vectorization) and `jacrev`/`jacfwd` (Jacobian computation).

## vmap - Vectorized Mapping

`vmap` automatically vectorizes a function over batch dimensions, eliminating manual loops and enabling better performance.

In [None]:
from torch.func import vmap

# Example function that operates on a single input
def compute_norm(x):
    """Compute L2 norm of a vector."""
    return torch.sqrt((x ** 2).sum())

# Test with a single vector
single_vec = torch.randn(5)
print(f"Single vector norm: {compute_norm(single_vec).item():.4f}")

# Without vmap: manual loop over batch
batch_vecs = torch.randn(10, 5)  # 10 vectors of dimension 5
norms_loop = torch.stack([compute_norm(v) for v in batch_vecs])
print(f"\nNorms (loop): {norms_loop[:3]}...")

# With vmap: automatic vectorization
norms_vmap = vmap(compute_norm)(batch_vecs)
print(f"Norms (vmap): {norms_vmap[:3]}...")

# Verify they're the same
print(f"\nResults match: {torch.allclose(norms_loop, norms_vmap)}")

### Why use vmap?

1. **Cleaner code**: Write functions for single examples, vmap handles batching
2. **Better performance**: Can leverage vectorized operations and avoid Python loops
3. **Composability**: Can combine with other `torch.func` transforms (`grad`, `jacrev`, etc.)

In [None]:
# Example: Matrix-vector products for a batch
def matvec(A, x):
    """Single matrix-vector product."""
    return A @ x

# Batch of matrices and vectors
batch_size = 8
matrices = torch.randn(batch_size, 4, 4)
vectors = torch.randn(batch_size, 4)

# vmap over the batch dimension (dim 0)
batched_matvec = vmap(matvec)
result = batched_matvec(matrices, vectors)
print(f"Batched result shape: {result.shape}")

# Verify against manual batch matmul
manual_result = torch.bmm(matrices, vectors.unsqueeze(-1)).squeeze(-1)
print(f"Results match: {torch.allclose(result, manual_result)}")

### Advanced: Nested vmap and Custom Dimensions

In [None]:
# Nested vmap for 2D batching
def dot_product(x, y):
    return (x * y).sum()

# Create 2D batch: (batch1, batch2, features)
x_2d = torch.randn(3, 4, 5)
y_2d = torch.randn(3, 4, 5)

# Apply vmap twice to handle both batch dimensions
batched_dot = vmap(vmap(dot_product))
result_2d = batched_dot(x_2d, y_2d)
print(f"2D batched result shape: {result_2d.shape}")

# Custom in_dims: specify which dimension to vmap over
# Say we have data with shape (features, batch) instead of (batch, features)
x_transposed = torch.randn(5, 10)  # features first
y_transposed = torch.randn(5, 10)

# vmap over dimension 1 (the batch dimension)
batched_dot_custom = vmap(dot_product, in_dims=(1, 1))
result_custom = batched_dot_custom(x_transposed, y_transposed)
print(f"Custom dim result shape: {result_custom.shape}")
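
A common variation in PDE codes is one shared operator applied to many right-hand sides. A minimal sketch, assuming a hypothetical shared matrix `A_shared`: passing `None` in `in_dims` tells `vmap` not to map over that argument, and `out_dims` controls where the batch dimension lands in the output.

```python
import torch
from torch.func import vmap

def matvec(A, x):
    """Single matrix-vector product."""
    return A @ x

A_shared = torch.randn(4, 4)    # one operator reused for every sample
rhs_batch = torch.randn(8, 4)   # 8 vectors, batch dimension first

# None: do not map over A (reuse it); 0: map over the first dim of rhs_batch
out = vmap(matvec, in_dims=(None, 0))(A_shared, rhs_batch)
print(f"Batch-first output shape: {out.shape}")

# Same data stored one vector per column; keep the batch in dim 1 of the output too
rhs_cols = rhs_batch.T          # shape (4, 8)
out_cols = vmap(matvec, in_dims=(None, 1), out_dims=1)(A_shared, rhs_cols)
print(f"Column-batch output shape: {out_cols.shape}")

print(f"Matches A @ X: {torch.allclose(out_cols, A_shared @ rhs_cols, atol=1e-5)}")
```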

## jacrev - Jacobian via Reverse Mode

`jacrev` computes the Jacobian matrix using reverse-mode autodiff. Best when **inputs >> outputs** (wide Jacobian), since it needs one backward pass per output.

In [None]:
from torch.func import jacrev, jacfwd

# Define a function f: R^n -> R^m
def f(x):
    """Example: f(x) = [x1^2 + x2, x1 * x2, sin(x1)]"""
    return torch.stack([
        x[0]**2 + x[1],
        x[0] * x[1],
        torch.sin(x[0])
    ])

# Input point
x = torch.tensor([1.0, 2.0], requires_grad=True)
print(f"Input: {x}")
print(f"Output: {f(x)}")

# Compute Jacobian using reverse mode
jacobian_rev = jacrev(f)(x)
print(f"\nJacobian (jacrev):\n{jacobian_rev}")
print(f"Shape: {jacobian_rev.shape}  # (outputs, inputs)")

# Manual verification of first row: d(x1^2 + x2)/dx
# df1/dx1 = 2*x1 = 2*1 = 2
# df1/dx2 = 1
print(f"\nExpected first row: [2.0, 1.0]")
print(f"Computed first row: {jacobian_rev[0]}")

### jacrev vs jacfwd: When to Use Which?

For f: R^n → R^m the Jacobian has shape (m × n).

**jacrev (reverse mode):**
- Best when **m << n** (few outputs, many inputs)
- One backward pass per output
- Complexity: O(m) backward passes

**jacfwd (forward mode):**
- Best when **n << m** (few inputs, many outputs)
- One forward pass per input
- Complexity: O(n) forward passes

**Rule of thumb:** Use jacrev for "wide" Jacobians (m < n), jacfwd for "tall" Jacobians (m > n).
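
The pass counts can be made concrete by assembling a Jacobian by hand from the two primitives in `torch.func`: `vjp` (reverse mode) returns one row per call and `jvp` (forward mode) returns one column per call. This is a sketch for intuition only; `f_demo` is a hypothetical example map, and `jacrev`/`jacfwd` batch these calls with `vmap` rather than looping in Python.

```python
import torch
from torch.func import jacfwd, jacrev, jvp, vjp

def f_demo(x):
    """Hypothetical example map R^2 -> R^3."""
    return torch.stack([x[0]**2 + x[1], x[0] * x[1], torch.sin(x[0])])

x0 = torch.tensor([1.0, 2.0])
m, n = 3, 2   # number of outputs, number of inputs

# Reverse mode: one vector-Jacobian product per OUTPUT yields one ROW each
_, vjp_fn = vjp(f_demo, x0)
J_rows = torch.stack([vjp_fn(e)[0] for e in torch.eye(m)])             # (m, n)

# Forward mode: one Jacobian-vector product per INPUT yields one COLUMN each
J_cols = torch.stack(
    [jvp(f_demo, (x0,), (e,))[1] for e in torch.eye(n)], dim=1)        # (m, n)

print(f"m = {m} vjp calls match jacrev: {torch.allclose(J_rows, jacrev(f_demo)(x0))}")
print(f"n = {n} jvp calls match jacfwd: {torch.allclose(J_cols, jacfwd(f_demo)(x0))}")
```

Here reverse mode needs 3 calls and forward mode only 2, so forward mode is the cheaper choice for this (tall) Jacobian.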

In [None]:
# Compare jacrev vs jacfwd
def g(x):
    """Function R^10 -> R^3 (wide Jacobian)"""
    return torch.stack([
        x.sum(),
        (x**2).sum(),
        (x**3).sum()
    ])

x_wide = torch.randn(10)

# jacrev is the better fit here: 3 outputs means only 3 backward passes
jac_rev = jacrev(g)(x_wide)
print(f"jacrev result shape: {jac_rev.shape}  # (3, 10)")

# jacfwd works too, but needs one forward pass per input (10 here)
jac_fwd = jacfwd(g)(x_wide)
print(f"jacfwd result shape: {jac_fwd.shape}  # (3, 10)")

print(f"\nResults match: {torch.allclose(jac_fwd, jac_rev)}")

# Tall Jacobian example: R^3 -> R^10
def h(x):
    """Function R^3 -> R^10 (tall Jacobian)"""
    # Each output depends on inputs in different ways
    return torch.cat([x, x**2, x**3, x.sum().unsqueeze(0)])

x_tall = torch.randn(3)

# jacfwd is the better fit here: 3 inputs means only 3 forward passes
jac_tall_fwd = jacfwd(h)(x_tall)
print(f"\njacfwd for tall Jacobian: {jac_tall_fwd.shape}  # (10, 3)")

### Computing Hessians

Combine `jacrev` with `grad` to compute Hessian matrices (second derivatives).

In [None]:
from torch.func import grad

# Scalar function for Hessian computation
def scalar_func(x):
    """f(x) = x1^4 + x2^3 + x1*x2"""
    return x[0]**4 + x[1]**3 + x[0]*x[1]

x_hess = torch.tensor([1.0, 2.0])

# Hessian = Jacobian of gradient
# Method 1: jacrev(grad(f))
hessian = jacrev(grad(scalar_func))(x_hess)
print(f"Hessian:\n{hessian}")
print(f"Shape: {hessian.shape}")

# Verify: the Hessian of a twice continuously differentiable scalar function is symmetric
print(f"\nHessian is symmetric: {torch.allclose(hessian, hessian.T)}")

# Manual computation for verification:
# f = x1^4 + x2^3 + x1*x2
# df/dx1 = 4*x1^3 + x2
# df/dx2 = 3*x2^2 + x1
# d²f/dx1² = 12*x1^2 = 12*1 = 12
# d²f/dx2² = 6*x2 = 6*2 = 12
# d²f/dx1dx2 = 1
print(f"\nExpected diagonal: [12, 12]")
print(f"Computed diagonal: {torch.diag(hessian)}")
print(f"\nExpected off-diagonal: 1")
print(f"Computed off-diagonal: {hessian[0, 1].item()}")

## Composing vmap and jacrev

The real power comes from composing these transformations. Example: computing per-sample Jacobians in a batch.

In [None]:
# Example: Per-sample Jacobians and per-sample gradients for a small model
# Per-sample gradients are useful for per-example gradient clipping, differential privacy, etc.

def simple_model(params, x):
    """Simple linear model: y = W @ x + b"""
    W, b = params
    return W @ x + b

# Model parameters
W = torch.randn(3, 5)
b = torch.randn(3)
params = (W, b)

# Batch of inputs
batch_x = torch.randn(10, 5)  # 10 samples, 5 features each

# Compute Jacobian w.r.t. input for each sample in batch
# vmap over samples, jacrev over input dimensions
per_sample_jacobians = vmap(lambda x: jacrev(lambda inp: simple_model(params, inp))(x))(batch_x)
print(f"Per-sample Jacobians shape: {per_sample_jacobians.shape}")
print(f"(batch_size, output_dim, input_dim): {tuple(per_sample_jacobians.shape)}")

# Each element [i] is the Jacobian for sample i
print(f"\nFirst sample Jacobian:\n{per_sample_jacobians[0]}")

# Use case: Compute per-sample parameter gradients
def loss_fn(params, x, y):
    """MSE loss for single example"""
    pred = simple_model(params, x)
    return ((pred - y) ** 2).sum()

# Target outputs
batch_y = torch.randn(10, 3)

# Compute gradient w.r.t. params for each sample
def compute_grad(x, y):
    return grad(lambda p: loss_fn(p, x, y))(params)

# vmap over the batch
per_sample_grads = vmap(compute_grad)(batch_x, batch_y)
print(f"\nPer-sample gradient shapes:")
print(f"  W gradients: {per_sample_grads[0].shape}")  # (10, 3, 5)
print(f"  b gradients: {per_sample_grads[1].shape}")  # (10, 3)

# Averaging over samples recovers the usual gradient of the batch-mean loss
avg_grad_W = per_sample_grads[0].mean(dim=0)
avg_grad_b = per_sample_grads[1].mean(dim=0)
print(f"\nAverage gradient shapes:")
print(f"  W: {avg_grad_W.shape}")
print(f"  b: {avg_grad_b.shape}")
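
The claim that averaged per-sample gradients equal the ordinary batch gradient can be checked directly. A self-contained sketch, assuming the same linear model and squared-error loss; `sample_loss` is a hypothetical name, and float64 is used so the comparison is tight. It also shows the more compact `vmap(grad(...), in_dims=(None, 0, 0))` form, where `None` shares the parameters across samples.

```python
import torch
from torch.func import grad, vmap

torch.manual_seed(0)
W = torch.randn(3, 5, dtype=torch.float64)
b = torch.randn(3, dtype=torch.float64)
batch_x = torch.randn(10, 5, dtype=torch.float64)
batch_y = torch.randn(10, 3, dtype=torch.float64)

def sample_loss(params, x, y):
    """Squared error for a single example."""
    W_, b_ = params
    return ((W_ @ x + b_ - y) ** 2).sum()

# in_dims=(None, 0, 0): share params, map over the samples in x and y
per_sample = vmap(grad(sample_loss), in_dims=(None, 0, 0))((W, b), batch_x, batch_y)

# Reference: ordinary autograd on the batch-mean loss
W_ref = W.clone().requires_grad_()
b_ref = b.clone().requires_grad_()
mean_loss = ((batch_x @ W_ref.T + b_ref - batch_y) ** 2).sum(dim=1).mean()
gW, gb = torch.autograd.grad(mean_loss, (W_ref, b_ref))

print(f"W gradients match: {torch.allclose(per_sample[0].mean(dim=0), gW)}")
print(f"b gradients match: {torch.allclose(per_sample[1].mean(dim=0), gb)}")
```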

## PDE Applications: Computing Spatial Derivatives

For PDE solvers, we often need spatial derivatives. `jacrev` and `vmap` make this efficient.

In [None]:
# Example: Computing Laplacian for physics-informed neural networks

class PINN(nn.Module):
    """Physics-Informed Neural Network"""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 20),
            nn.Tanh(),
            nn.Linear(20, 20),
            nn.Tanh(),
            nn.Linear(20, 1)
        )
    
    def forward(self, x):
        """x: (batch, 2) -> u: (batch, 1)"""
        return self.net(x)

pinn = PINN()

# Define function to compute Laplacian: ∇²u = ∂²u/∂x² + ∂²u/∂y²
def compute_laplacian(model, point):
    """
    Compute Laplacian of model output at a single point.
    point: (2,) tensor [x, y]
    """
    # First derivatives
    def grad_u(p):
        return grad(lambda x: model(x.unsqueeze(0)).squeeze())(p)
    
    # Jacobian of gradient = Hessian
    hessian = jacrev(grad_u)(point)
    
    # Laplacian is trace of Hessian
    laplacian = hessian.trace()
    return laplacian

# Test point
test_point = torch.tensor([0.5, 0.5])
laplacian_value = compute_laplacian(pinn, test_point)
print(f"Laplacian at {test_point.tolist()}: {laplacian_value.item():.4f}")

# Vectorize over batch of points using vmap
points_batch = torch.rand(100, 2)  # 100 random points in [0,1]²

# Compute Laplacian for all points efficiently
laplacians_batch = vmap(lambda p: compute_laplacian(pinn, p))(points_batch)
print(f"\nBatch Laplacians shape: {laplacians_batch.shape}")
print(f"Mean Laplacian: {laplacians_batch.mean().item():.4f}")
print(f"Std Laplacian: {laplacians_batch.std().item():.4f}")

## Performance Tips

**vmap:**
- Does not run a hidden Python loop: the batch dimension is pushed down into each underlying tensor operation, so the whole batch is processed by batched kernels
- Can sometimes be slower than manual batching for simple operations
- Shines when combining with other transforms or complex functions

**jacrev/jacfwd:**
- Choose based on Jacobian shape (tall vs wide)
- For Hessians: `jacfwd(jacrev(f))` often faster than `jacrev(jacrev(f))`
- Can combine with `vmap` for batched Jacobians
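
The Hessian compositions mentioned above all return the same matrix and differ only in cost. A small sketch, assuming a hypothetical test function `scalar_demo`; `torch.func.hessian` is the built-in convenience wrapper, which uses the forward-over-reverse strategy. No timings are shown since they depend on problem size and hardware.

```python
import torch
from torch.func import grad, hessian, jacfwd, jacrev

def scalar_demo(x):
    """Hypothetical test function: f(x) = x1^4 + x2^3 + x1*x2."""
    return x[0]**4 + x[1]**3 + x[0] * x[1]

x0 = torch.tensor([1.0, 2.0])

H_fwd_rev = jacfwd(jacrev(scalar_demo))(x0)   # forward-over-reverse
H_rev_rev = jacrev(jacrev(scalar_demo))(x0)   # reverse-over-reverse
H_rev_grad = jacrev(grad(scalar_demo))(x0)    # same idea, using grad for the inner pass
H_builtin = hessian(scalar_demo)(x0)          # convenience wrapper in torch.func

print(f"Hessian:\n{H_fwd_rev}")
print(f"All variants agree: "
      f"{torch.allclose(H_fwd_rev, H_rev_rev) and torch.allclose(H_fwd_rev, H_rev_grad) and torch.allclose(H_fwd_rev, H_builtin)}")
```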

**Memory considerations:**
- Jacobians can be large: (m × n) for f: R^n → R^m
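- For PINNs, the Laplacian needs only the diagonal of the Hessian (its trace); `compute_laplacian` above still builds the full Hessian, which is fine for 2 or 3 input dimensions but wasteful in high dimension
- Wrap evaluation-only code (validation, plotting) in `torch.no_grad()` so that no autograd graph is kept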
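
One way to get the Laplacian without storing an $n \times n$ Hessian is to accumulate its diagonal from Hessian-vector products. The sketch below is an illustration under stated assumptions: `laplacian_hvp` and `f_demo` are hypothetical names, and the approach still costs $n$ Hessian-vector products, so it saves memory rather than passes.

```python
import torch
from torch.func import grad, hessian, jvp

def laplacian_hvp(f, x):
    """Trace of the Hessian of scalar f at x, using n Hessian-vector products."""
    grad_f = grad(f)
    total = torch.zeros((), dtype=x.dtype)
    for i, e in enumerate(torch.eye(x.numel(), dtype=x.dtype)):
        # jvp of the gradient in direction e_i gives the Hessian column H[:, i]
        _, hv = jvp(grad_f, (x,), (e,))
        total = total + hv[i]   # keep only the diagonal entry H[i, i]
    return total

def f_demo(x):
    """Hypothetical test function with Laplacian 6 * sum(x)."""
    return (x**3).sum() + x[0] * x[1]

x0 = torch.tensor([0.5, -1.0, 2.0])
lap = laplacian_hvp(f_demo, x0)
lap_ref = hessian(f_demo)(x0).trace()   # reference via the full Hessian

print(f"Laplacian via HVPs:         {lap.item():.4f}")
print(f"Laplacian via full Hessian: {lap_ref.item():.4f}")
```

For this test function the analytic value is $6(0.5 - 1 + 2) = 9$, which both routes should reproduce.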