# Tangent: Source-to-Source Automatic Differentiation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pedronahum/tangent/blob/master/notebooks/tangent_tutorial.ipynb)

Welcome to this comprehensive tutorial on **Tangent**, a library for automatic differentiation that works by transforming Python source code directly. This makes gradients readable, debuggable, and efficient!

## What Makes Tangent Special?

- **Source-to-Source**: Transforms Python code directly, making gradients human-readable
- **Multi-Backend**: Works with NumPy, TensorFlow, and JAX
- **Debuggable**: Generated gradient code can be inspected and stepped through
- **Efficient**: No tape overhead at runtime

## Tutorial Contents

1. Installation & Setup
2. Basic Concepts
3. NumPy Integration
4. TensorFlow 2.x Integration
5. JAX Integration
6. Advanced Features
7. Visualization & Debugging
8. Real-World Examples

## 1. Installation & Setup

First, let's install Tangent and the numerical computing libraries we'll use.

In [None]:
# Install Tangent from GitHub and dependencies
!pip install git+https://github.com/pedronahum/tangent.git numpy matplotlib

# Install optional backends (choose what you need)
!pip install jax jaxlib  # For JAX support
!pip install tensorflow  # For TensorFlow support

In [None]:
# Import core libraries
import tangent
import numpy as np
import matplotlib.pyplot as plt

# Configure matplotlib for better plots
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

print("✓ Tangent imported successfully!")

## 2. Basic Concepts

### What is Automatic Differentiation?

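Automatic differentiation (autodiff) computes derivatives by applying the chain rule to the elementary operations a program actually executes. It differs from:
- **Numerical differentiation**: (f(x+h) - f(x))/h (approximate; the step size h trades truncation error against floating-point round-off)
- **Symbolic differentiation**: Full symbolic manipulation of expressions (the resulting expressions can grow very large)

Autodiff is:
- **Exact**: No truncation error; derivatives are accurate up to floating-point rounding
- **Efficient**: In reverse mode, the full gradient of a scalar function costs a small constant multiple of one evaluation of the function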
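
To see the "exact vs. approximate" distinction in numbers, here is a minimal sketch using dual numbers. Dual numbers are a *forward-mode* autodiff technique, not the reverse-mode source transformation Tangent uses. The `Dual` class and `sin` wrapper below are hypothetical helpers written only for this illustration. The sketch compares the autodiff result with a forward difference at several step sizes.

```python
import math


class Dual:
    """Number val + der*eps with eps**2 == 0; `der` carries the derivative."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule carried along with the value
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def sin(x):
    """sin that also propagates a derivative when given a Dual."""
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), math.cos(x.val) * x.der)
    return math.sin(x)


def f(x):
    return x * sin(x) + 3 * x


def forward_mode_derivative(f, x):
    # Seed the input's derivative with 1.0 and read it off the output
    return f(Dual(x, 1.0)).der


def forward_difference(f, x, h):
    return (f(x + h) - f(x)) / h


x0 = 1.3
exact = math.sin(x0) + x0 * math.cos(x0) + 3  # hand-derived f'(x)
print("dual-number error:", abs(forward_mode_derivative(f, x0) - exact))
for h in (1e-2, 1e-5, 1e-8, 1e-11):
    print(f"h={h:.0e}  finite-difference error={abs(forward_difference(f, x0, h) - exact):.1e}")
```

The finite-difference error first shrinks with h and then grows again as round-off takes over. The dual-number result has no step size to tune.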

### How Tangent Works

Tangent transforms your Python function's source code into a new function that computes gradients:

```python
def f(x):
    return x * x
```

becomes (conceptually):

```python
def df_dx(x):
    return 2 * x
```
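
Generated reverse-mode code typically has two parts. A forward pass recomputes the intermediates. A backward pass then walks the statements in reverse and accumulates an "adjoint" (the derivative of the output with respect to each variable). The sketch below is hand-written in that style for a slightly longer function. It is not Tangent's literal output, and the names `bout`, `ba`, `bb`, `bx` are illustrative.

```python
import math


def g(x):
    a = x * x
    b = math.sin(a)
    return b + x


def dg_dx(x, bout=1.0):
    """Hand-written reverse-mode gradient of g, in the style of generated code."""
    # Forward pass: recompute the intermediates the backward pass needs
    a = x * x
    b = math.sin(a)
    # Backward pass: visit statements in reverse, propagating adjoints
    bb = bout                   # out = b + x  -> adjoint of b
    bx = bout                   #              -> first contribution to adjoint of x
    ba = bb * math.cos(a)       # b = sin(a)
    bx = bx + ba * x + ba * x   # a = x * x: both operands are x
    return bx


print(dg_dx(0.7))  # equals cos(0.49) * 1.4 + 1 up to rounding
```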

Let's see this in action!

In [None]:
# Define a simple function
def square(x):
    return x * x

# Compute its gradient
dsquare = tangent.grad(square)

# Test it
x_val = 3.0
gradient = dsquare(x_val)

print(f"f(x) = x²")
print(f"f({x_val}) = {square(x_val)}")
print(f"f'({x_val}) = {gradient}")
print(f"Expected: {2 * x_val} ✓" if abs(gradient - 2 * x_val) < 1e-5 else "Error!")

### Inspecting Generated Code

One of Tangent's superpowers is that you can actually **see** the generated gradient code!
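
Python's standard `ast` module can illustrate the source-to-source idea. The sketch below is a toy, not Tangent's implementation. It performs forward-mode symbolic differentiation of a single `return` expression built from `+`, `-` and `*`, whereas Tangent handles control flow and arrays and generates reverse-mode code. The helper names `d`, `grad_source` and `make_grad` are hypothetical, and `ast.unparse` requires Python 3.9 or newer.

```python
import ast
import copy


def d(node, var):
    """Return a new AST for d(node)/d(var). Handles constants, names, +, - and *."""
    if isinstance(node, ast.Constant):
        return ast.Constant(0.0)
    if isinstance(node, ast.Name):
        return ast.Constant(1.0 if node.id == var else 0.0)
    if isinstance(node, ast.BinOp):
        dl, dr = d(node.left, var), d(node.right, var)
        if isinstance(node.op, (ast.Add, ast.Sub)):
            return ast.BinOp(dl, type(node.op)(), dr)
        if isinstance(node.op, ast.Mult):
            # Product rule: (u * v)' = u' * v + u * v'
            return ast.BinOp(
                ast.BinOp(dl, ast.Mult(), copy.deepcopy(node.right)),
                ast.Add(),
                ast.BinOp(copy.deepcopy(node.left), ast.Mult(), dr),
            )
    raise NotImplementedError(ast.dump(node))


def grad_source(src):
    """Rewrite `def f(x): return <expr>` into the source of `def df(x): ...`."""
    module = ast.parse(src)
    fn = module.body[0]
    var = fn.args.args[0].arg
    ret = fn.body[-1]
    assert isinstance(ret, ast.Return), "sketch only handles a final return"
    fn.name = "d" + fn.name
    fn.body = [ast.Return(d(ret.value, var))]
    ast.fix_missing_locations(module)
    return ast.unparse(module)


def make_grad(src):
    """Compile the generated source and return (function, source text)."""
    code = grad_source(src)
    namespace = {}
    exec(compile(code, "<generated>", "exec"), namespace)
    return namespace[ast.parse(code).body[0].name], code


src = """
def polynomial(x):
    return 3.0 * x * x + 2.0 * x + 1.0
"""
dpolynomial_toy, generated = make_grad(src)
print(generated)             # unsimplified, but readable, derivative code
print(dpolynomial_toy(2.0))  # 6 * 2 + 2
```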

In [None]:
# Let's inspect the generated gradient code
import inspect

def polynomial(x):
    """f(x) = 3x² + 2x + 1"""
    return 3.0 * x * x + 2.0 * x + 1.0

# Generate gradient function
dpolynomial = tangent.grad(polynomial)

# Show the generated code
print("Original function:")
print(inspect.getsource(polynomial))
print("\nGenerated gradient code:")
print(inspect.getsource(dpolynomial))

### Visualizing Functions and Gradients

Let's create a helper function to visualize functions alongside their gradients.

In [None]:
def plot_function_and_gradient(f, df, x_range, title="Function and Gradient"):
    """Plot a function and its gradient side by side."""
    x = np.linspace(*x_range, 200)
    y = np.array([f(xi) for xi in x])
    dy = np.array([df(xi) for xi in x])
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot function
    ax1.plot(x, y, 'b-', linewidth=2, label='f(x)')
    ax1.set_xlabel('x', fontsize=12)
    ax1.set_ylabel('f(x)', fontsize=12)
    ax1.set_title('Function', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    
    # Plot gradient
    ax2.plot(x, dy, 'r-', linewidth=2, label="f'(x)")
    ax2.set_xlabel('x', fontsize=12)
    ax2.set_ylabel("f'(x)", fontsize=12)
    ax2.set_title('Gradient', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    fig.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

# Test with our polynomial
plot_function_and_gradient(
    polynomial, 
    dpolynomial, 
    (-3, 3),
    title="Polynomial: f(x) = 3x² + 2x + 1"
)

## 3. NumPy Integration

Tangent works seamlessly with NumPy arrays and operations!

### Vector Operations

In [None]:
def vector_norm_squared(x):
    """Compute ||x||² = sum of x²"""
    return np.sum(x * x)

# Gradient of ||x||² is 2x
dvector_norm_squared = tangent.grad(vector_norm_squared)

x = np.array([1.0, 2.0, 3.0])
gradient = dvector_norm_squared(x)

print(f"x = {x}")
print(f"||x||² = {vector_norm_squared(x)}")
print(f"∇||x||² = {gradient}")
print(f"Expected: {2 * x}")
print(f"Match: {np.allclose(gradient, 2 * x)} ✓")

### Matrix Operations
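
The expected value in the next cell follows from writing the sum as a product with a vector of ones, $\mathbf{1}$:

$$
f(x) = \mathbf{1}^\top A x = \sum_i \sum_j A_{ij}\, x_j
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x_j} = \sum_i A_{ij},
\qquad
\nabla_x f = A^\top \mathbf{1}.
$$

The gradient is therefore the vector of column sums, `np.sum(A, axis=0)`. The matrix `A` used below happens to be symmetric, so its row sums are identical and this check would not catch a row/column mix-up.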

In [None]:
def matrix_vector_sum(x):
    """Compute sum(A @ x) where A is a fixed matrix"""
    A = np.array([[2.0, 1.0, 0.5],
                  [1.0, 3.0, 0.7],
                  [0.5, 0.7, 4.0]])
    return np.sum(np.dot(A, x))

# Gradient w.r.t. x
df_dx = tangent.grad(matrix_vector_sum)

x = np.array([1.0, 2.0, 3.0])
gradient = df_dx(x)

A = np.array([[2.0, 1.0, 0.5],
              [1.0, 3.0, 0.7],
              [0.5, 0.7, 4.0]])

print(f"x = {x}")
print(f"A = \n{A}")
print(f"\n∇_x sum(Ax) = {gradient}")
print(f"Expected (sum of columns): {np.sum(A, axis=0)}")
print(f"Match: {np.allclose(gradient, np.sum(A, axis=0))} ✓")

### Element-wise Operations
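
The expected gradient in the next cell comes from differentiating the logistic sigmoid with the chain rule:

$$
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\sigma'(x) = \frac{e^{-x}}{\left(1 + e^{-x}\right)^2} = \sigma(x)\,\bigl(1 - \sigma(x)\bigr).
$$

The loss sums $\sigma(x_i)$ over independent elements, so each gradient entry is simply $\partial f / \partial x_i = \sigma'(x_i)$.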

In [None]:
def sigmoid_loss(x):
    """Sum of sigmoid function: sum(1 / (1 + exp(-x)))"""
    return np.sum(1.0 / (1.0 + np.exp(-x)))

dsigmoid_loss = tangent.grad(sigmoid_loss)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
gradient = dsigmoid_loss(x)

# Expected: sigmoid(x) * (1 - sigmoid(x))
sigmoid_x = 1.0 / (1.0 + np.exp(-x))
expected = sigmoid_x * (1.0 - sigmoid_x)

print(f"x = {x}")
print(f"∇(sigmoid sum) = {gradient}")
print(f"Expected: {expected}")
print(f"Match: {np.allclose(gradient, expected)} ✓")

## 4. TensorFlow 2.x Integration

Tangent works with TensorFlow 2.x in eager execution mode!

In [None]:
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")

### Basic TensorFlow Operations

In [None]:
def tf_quadratic(x):
    """TensorFlow quadratic: 2x² + 3x + 1"""
    return 2.0 * x * x + 3.0 * x + 1.0

dtf_quadratic = tangent.grad(tf_quadratic)

x_tf = tf.constant(2.0)
gradient = dtf_quadratic(x_tf)

print(f"f(x) = 2x² + 3x + 1")
print(f"f(2) = {tf_quadratic(x_tf).numpy()}")
print(f"f'(2) = {gradient.numpy()}")
print(f"Expected f'(2) = 4*2 + 3 = {4*2 + 3}")
print(f"Match: {abs(gradient.numpy() - 11) < 1e-5} ✓")

### Neural Network Layer Example
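
For a single input row, the gradient with respect to `W` has a closed form. With $z = xW + b$ and $f = \sum_j \tanh(z_j)$, the entries are $\partial f / \partial W_{ij} = x_i\,(1 - \tanh^2 z_j)$, which is an outer product. The NumPy sketch below (the helper names are hypothetical) computes that formula and checks it with central differences. It uses the same `x`, `W`, `b` as the TensorFlow cell, so its output should match Tangent's result up to float32 rounding, assuming the generated gradient follows the same math.

```python
import numpy as np


def layer_grad_W(x, W, b):
    """Closed-form d/dW of sum(tanh(x @ W + b)) for a single input row x."""
    z = x @ W + b                               # shape (n_out,)
    return np.outer(x, 1.0 - np.tanh(z) ** 2)   # shape (n_in, n_out)


def finite_diff_W(x, W, b, eps=1e-6):
    """Central-difference estimate of the same gradient, one entry at a time."""
    f = lambda W_: np.sum(np.tanh(x @ W_ + b))
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        g[idx] = (f(Wp) - f(Wm)) / (2 * eps)
    return g


x = np.array([1.0, 2.0, 3.0])
W = np.array([[0.5, 0.3], [0.2, 0.7], [0.1, 0.4]])
b = np.array([0.1, 0.2])
print(layer_grad_W(x, W, b))
```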

In [None]:
def simple_layer(x, W, b):
    """Simple neural network layer: sum(tanh(Wx + b))"""
    linear = tf.matmul(tf.reshape(x, [1, -1]), W) + b
    activation = tf.tanh(linear)
    return tf.reduce_sum(activation)

# Compute gradient w.r.t. weights W
dlayer_dW = tangent.grad(simple_layer, wrt=(1,))

x = tf.constant([1.0, 2.0, 3.0])
W = tf.constant([[0.5, 0.3], [0.2, 0.7], [0.1, 0.4]])
b = tf.constant([0.1, 0.2])

gradient = dlayer_dW(x, W, b)

print(f"Input shape: {x.shape}")
print(f"Weight shape: {W.shape}")
print(f"Gradient shape: {gradient.shape}")
print(f"\nGradient w.r.t. W:\n{gradient.numpy()}")

## 5. JAX Integration

Tangent also supports JAX, Google's high-performance numerical computing library!

In [None]:
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")

### JAX Array Operations

In [None]:
def jax_polynomial(x):
    """JAX polynomial: x³ - 2x² + 3x - 1"""
    return x**3 - 2*x**2 + 3*x - 1

djax_polynomial = tangent.grad(jax_polynomial)

x_jax = jnp.array(2.0)
gradient = djax_polynomial(x_jax)

# Expected: 3x² - 4x + 3
expected = 3 * x_jax**2 - 4 * x_jax + 3

print(f"f(x) = x³ - 2x² + 3x - 1")
print(f"f(2) = {jax_polynomial(x_jax)}")
print(f"f'(2) = {gradient}")
print(f"Expected: {expected}")
print(f"Match: {jnp.allclose(gradient, expected)} ✓")

### JAX Neural Network Activations
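
By the chain rule, each element contributes $\frac{d}{dx}\,\mathrm{relu}(x^2 - 1) = 2x$ where $x^2 - 1 > 0$ and $0$ where it is negative. The pure-Python sketch below spells that out for the inputs used in the next cell. At $x = \pm 1$ the ReLU sits exactly on its kink. The choice of 0 there is an assumption of this sketch, and Tangent's rule may pick a different value.

```python
def relu_sq_grad(xs):
    """Analytic gradient of sum(relu(x*x - 1)), element by element."""
    # Strict '> 0' means the kink (x*x == 1) gets gradient 0: an assumed convention.
    return [2.0 * x if x * x - 1.0 > 0 else 0.0 for x in xs]


print(relu_sq_grad([-2.0, -1.0, 0.0, 1.0, 2.0]))  # [-4.0, 0.0, 0.0, 0.0, 4.0]
```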

In [None]:
def jax_relu_network(x):
    """Simple ReLU network"""
    return jnp.sum(jax.nn.relu(x * x - 1.0))

djax_relu_network = tangent.grad(jax_relu_network)

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
gradient = djax_relu_network(x)

print(f"x = {x}")
print(f"f(x) = sum(relu(x² - 1))")
print(f"f(x) = {jax_relu_network(x)}")
print(f"\nGradient: {gradient}")
print(f"\nNote: Gradient is zero where x² - 1 < 0 (ReLU inactive); at x = ±1 the ReLU sits at its kink, so the value there depends on the subgradient convention")

### Comparing JAX Operations
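
For reference, these are the analytic derivatives the gradient plots should reproduce. Each function in the next cell sums over a one-element array, so the plotted gradient is just the scalar derivative.

$$
\begin{aligned}
\sigma'(x) &= \sigma(x)\,\bigl(1 - \sigma(x)\bigr) \\
\tanh'(x) &= 1 - \tanh^2(x) \\
\operatorname{ReLU}'(x) &= \begin{cases} 1 & x > 0 \\ 0 & x < 0 \end{cases} \\
\operatorname{ELU}'(x) &= \begin{cases} 1 & x > 0 \\ \alpha\, e^{x} & x \le 0 \end{cases}
\end{aligned}
$$

Here $\alpha = 1$, the default of `jax.nn.elu`. ReLU's derivative is undefined at exactly $x = 0$.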

In [None]:
# Compare activation functions and their gradients
import matplotlib.pyplot as plt

# Define activation functions as regular functions (not lambdas)
def sigmoid_sum(x):
    return jnp.sum(jax.nn.sigmoid(x))

def tanh_sum(x):
    return jnp.sum(jnp.tanh(x))

def relu_sum(x):
    return jnp.sum(jax.nn.relu(x))

def elu_sum(x):
    return jnp.sum(jax.nn.elu(x))

# Create visualization
x = np.linspace(-3, 3, 200)

activations = {
    'Sigmoid': (sigmoid_sum, 'blue'),
    'Tanh': (tanh_sum, 'green'),
    'ReLU': (relu_sum, 'red'),
    'ELU': (elu_sum, 'orange'),
}

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

for idx, (name, (func, color)) in enumerate(activations.items()):
    dfunc = tangent.grad(func)
    
    # Compute values
    y = np.array([float(func(jnp.array([xi]))) for xi in x])
    dy = np.array([float(dfunc(jnp.array([xi]))[0]) for xi in x])  # gradient has shape (1,)
    
    # Plot activation
    axes[0, idx].plot(x, y, color=color, linewidth=2)
    axes[0, idx].set_title(f'{name}', fontweight='bold')
    axes[0, idx].grid(True, alpha=0.3)
    axes[0, idx].set_ylabel('f(x)')
    
    # Plot gradient
    axes[1, idx].plot(x, dy, color=color, linewidth=2, linestyle='--')
    axes[1, idx].set_title(f'{name} Gradient', fontweight='bold')
    axes[1, idx].grid(True, alpha=0.3)
    axes[1, idx].set_xlabel('x')
    axes[1, idx].set_ylabel("f'(x)")

plt.suptitle('JAX Activation Functions and Their Gradients', 
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

## 6. Advanced Features

### Multiple Gradients

You can compute gradients with respect to multiple arguments!

In [None]:
def bivariate(x, y):
    """f(x, y) = x²y + xy²"""
    return x * x * y + x * y * y

# Compute gradients w.r.t. both x and y
dbivariate = tangent.grad(bivariate, wrt=(0, 1))

x, y = 2.0, 3.0
grad_x, grad_y = dbivariate(x, y)

# Expected:
# ∂f/∂x = 2xy + y²
# ∂f/∂y = x² + 2xy
expected_grad_x = 2 * x * y + y * y
expected_grad_y = x * x + 2 * x * y

print(f"f(x, y) = x²y + xy²")
print(f"f({x}, {y}) = {bivariate(x, y)}")
print(f"\n∂f/∂x = {grad_x}, expected = {expected_grad_x}")
print(f"∂f/∂y = {grad_y}, expected = {expected_grad_y}")
print(f"\nMatch: {abs(grad_x - expected_grad_x) < 1e-5 and abs(grad_y - expected_grad_y) < 1e-5} ✓")

### Visualizing Multivariate Gradients

Let's visualize the gradient field of a 2D function.
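
Differentiating $(1-x)^2 + 100(y-x^2)^2$ by hand gives $\partial f/\partial x = -2(1-x) - 400x(y-x^2)$ and $\partial f/\partial y = 200(y-x^2)$. Both vanish at the minimum $(1, 1)$. The stdlib sketch below (hypothetical helper name) encodes those formulas, which gives a reference to compare against Tangent's output. It also shows that gradient magnitudes on this grid span several orders of magnitude, which is presumably why the quiver plot needs a large `scale`.

```python
def rosenbrock_grad(x, y):
    """Analytic gradient of (1 - x)**2 + 100 * (y - x**2)**2."""
    r = y - x * x
    dfdx = -2.0 * (1.0 - x) - 400.0 * x * r   # chain rule through (y - x^2)
    dfdy = 200.0 * r
    return dfdx, dfdy


print(rosenbrock_grad(1.0, 1.0))    # both components are zero at the minimum
print(rosenbrock_grad(-2.0, -1.0))  # (-4006.0, -1000.0): a corner of the plotted grid
```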

In [None]:
def rosenbrock(x, y):
    """Rosenbrock function: (1-x)² + 100(y-x²)²"""
    return (1 - x)**2 + 100 * (y - x**2)**2

drosenbrock = tangent.grad(rosenbrock, wrt=(0, 1))

# Create grid
x = np.linspace(-2, 2, 50)
y = np.linspace(-1, 3, 50)
X, Y = np.meshgrid(x, y)

# Compute function values
Z = np.zeros_like(X)
for i in range(len(x)):
    for j in range(len(y)):
        Z[j, i] = rosenbrock(X[j, i], Y[j, i])

# Compute gradients for quiver plot
U = np.zeros_like(X)
V = np.zeros_like(Y)
for i in range(0, len(x), 3):  # Subsample for clarity
    for j in range(0, len(y), 3):
        grad_x, grad_y = drosenbrock(X[j, i], Y[j, i])
        U[j, i] = -grad_x  # Negative for gradient descent direction
        V[j, i] = -grad_y

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Contour plot
contour = ax1.contour(X, Y, Z, levels=20, cmap='viridis')
ax1.clabel(contour, inline=True, fontsize=8)
ax1.plot(1, 1, 'r*', markersize=15, label='Minimum (1, 1)')
ax1.set_xlabel('x', fontsize=12)
ax1.set_ylabel('y', fontsize=12)
ax1.set_title('Rosenbrock Function Contours', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Gradient field
ax2.contour(X, Y, Z, levels=20, cmap='viridis', alpha=0.3)
skip = (slice(None, None, 3), slice(None, None, 3))
ax2.quiver(X[skip], Y[skip], U[skip], V[skip], 
           color='red', alpha=0.6, scale=500, width=0.003)
ax2.plot(1, 1, 'r*', markersize=15, label='Minimum (1, 1)')
ax2.set_xlabel('x', fontsize=12)
ax2.set_ylabel('y', fontsize=12)
ax2.set_title('Gradient Descent Direction', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.suptitle('Rosenbrock Function: f(x,y) = (1-x)² + 100(y-x²)²', 
             fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

### Preserving Results

Sometimes you want both the function value AND its gradient. Use `preserve_result=True`!
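
The benefit of keeping the result is that computing a gradient already runs the function's forward computation, so returning its value avoids a second call. The hand-written NumPy sketch below shows the same idea for `expensive_function`: the intermediates `exp(x)` and `sin(x)` are computed once and reused for both outputs. The name `value_and_grad_expensive` is hypothetical. The `(gradient, value)` return order mirrors the next cell; it is not taken from Tangent's documentation.

```python
import numpy as np


def value_and_grad_expensive(x):
    """Hand-written value and gradient of sum(exp(x) * sin(x))."""
    e = np.exp(x)          # forward intermediates, computed once
    s = np.sin(x)
    value = np.sum(e * s)
    # d/dx [exp(x) sin(x)] = exp(x) sin(x) + exp(x) cos(x); reuses e and s
    grad = e * s + e * np.cos(x)
    return grad, value


g, v = value_and_grad_expensive(np.array([0.0, 1.0, 2.0]))
print(g, v)
```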

In [None]:
def expensive_function(x):
    """A function where we want both value and gradient"""
    return np.sum(np.exp(x) * np.sin(x))

# Get both gradient and result
dexpensive = tangent.grad(expensive_function, preserve_result=True)

x = np.array([0.0, 1.0, 2.0])
gradient, result = dexpensive(x)

print(f"Input: {x}")
print(f"Function value: {result}")
print(f"Gradient: {gradient}")
print(f"\nVerify function value: {expensive_function(x)}")

## 7. Visualization & Debugging

### Gradient Checking

Let's verify our gradients against numerical differentiation.
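
The checker below uses a central difference. Its truncation error shrinks like $\epsilon^2$, compared with $\epsilon$ for a one-sided difference. Making $\epsilon$ too small, though, eventually lets floating-point round-off dominate. The stdlib sketch compares the two schemes on $e^x$, whose derivative is known exactly. It uses a relative error so that large and small gradients are judged on the same scale. The helper names are illustrative.

```python
import math


def forward_diff(f, x, eps):
    return (f(x + eps) - f(x)) / eps


def central_diff(f, x, eps):
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def relative_error(a, b, tiny=1e-12):
    """Scale-aware error; `tiny` guards against division by zero."""
    return abs(a - b) / max(abs(a), abs(b), tiny)


f = math.exp
x0 = 1.0
exact = math.exp(x0)  # d/dx exp(x) = exp(x)

results = {}
for eps in (1e-2, 1e-4, 1e-6):
    fwd = relative_error(forward_diff(f, x0, eps), exact)
    cen = relative_error(central_diff(f, x0, eps), exact)
    results[eps] = (fwd, cen)
    print(f"eps={eps:.0e}  forward={fwd:.1e}  central={cen:.1e}")
```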

In [None]:
def numerical_gradient(f, x, eps=1e-7):
    """Compute gradient using finite differences."""
    grad = np.zeros_like(x)
    for i in range(len(x)):
        x_plus = x.copy()
        x_minus = x.copy()
        x_plus[i] += eps
        x_minus[i] -= eps
        grad[i] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return grad

def complex_function(x):
    """A more complex function to test"""
    return np.sum(x**3 - 2*x**2 + np.exp(x/10))

dcomplex = tangent.grad(complex_function)

x = np.array([1.0, 2.0, 3.0])
tangent_grad = dcomplex(x)
numerical_grad = numerical_gradient(complex_function, x)

print("Gradient Checking")
print("="*50)
print(f"Input: {x}")
print(f"\nTangent gradient:   {tangent_grad}")
print(f"Numerical gradient: {numerical_grad}")
print(f"\nDifference: {np.abs(tangent_grad - numerical_grad)}")
print(f"Max error: {np.max(np.abs(tangent_grad - numerical_grad)):.2e}")
print(f"\nGradients match: {np.allclose(tangent_grad, numerical_grad, atol=1e-5)} ✓")

### Performance Comparison

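Let's compare Tangent's performance across different backends. Timings depend on hardware, library versions, and array size; with a 100-element input, per-call overhead dominates, so treat these numbers as rough.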
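
Python's standard `timeit` module handles repetition and uses a high-resolution clock. Taking the minimum over several repeats is a common way to reduce noise from other processes. The sketch below wraps that in a hypothetical `time_per_call` helper. It times a plain-Python gradient as a stand-in, so no backend is needed. For JAX outputs, calling `.block_until_ready()` on the result inside the timed call keeps asynchronous dispatch from making the gradient look faster than it is.

```python
import timeit


def time_per_call(fn, *args, number=1000, repeat=5):
    """Best-of-`repeat` seconds per call of fn(*args)."""
    timer = timeit.Timer(lambda: fn(*args))
    timer.timeit(number=10)  # warm-up calls, not measured
    return min(timer.repeat(repeat=repeat, number=number)) / number


def poly_grad(x):
    # Analytic gradient of sum(x**3 - 2*x**2 + x); a stand-in for illustration
    return [3 * v * v - 4 * v + 1 for v in x]


print(f"{time_per_call(poly_grad, [0.5] * 100) * 1e6:.2f} µs per call")
```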

In [None]:
import time

def benchmark_gradient(name, grad_func, x, n_iterations=1000):
    """Benchmark gradient computation."""
    # Warmup
    for _ in range(10):
        _ = grad_func(x)
    
    # Benchmark
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(n_iterations):
        _ = grad_func(x)
    elapsed = time.perf_counter() - start
    
    return elapsed / n_iterations

# Define same function for each backend (using simple polynomial)
def numpy_func(x):
    return np.sum(x**3 - 2*x**2 + x)

def tf_func(x):
    return tf.reduce_sum(x**3 - 2*x**2 + x)

def jax_func(x):
    return jnp.sum(x**3 - 2*x**2 + x)

# Create gradient functions
dnumpy = tangent.grad(numpy_func)
dtf = tangent.grad(tf_func)
djax = tangent.grad(jax_func)

# Benchmark
x_np = np.random.randn(100)
x_tf = tf.constant(x_np)
x_jax = jnp.array(x_np)

results = {
    'NumPy': benchmark_gradient('NumPy', dnumpy, x_np),
    'TensorFlow': benchmark_gradient('TensorFlow', dtf, x_tf),
    'JAX': benchmark_gradient('JAX', djax, x_jax),
}

# Plot results
fig, ax = plt.subplots(figsize=(10, 6))
backends = list(results.keys())
times = [results[b] * 1000 for b in backends]  # Convert to milliseconds

bars = ax.bar(backends, times, color=['blue', 'orange', 'green'], alpha=0.7)
ax.set_ylabel('Time per gradient (ms)', fontsize=12)
ax.set_title('Tangent Performance Across Backends', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, t_ms in zip(bars, times):  # don't shadow the `time` module
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{t_ms:.3f} ms',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

print("\nBenchmark Results:")
print("="*50)
for backend, time_val in results.items():
    print(f"{backend:12s}: {time_val*1000:.3f} ms per gradient")

## 8. Real-World Examples

### Example 1: Linear Regression with Gradient Descent
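
The MSE gradients have a simple closed form. With residual $r = wX + b - y$, they are $\partial L/\partial w = 2\,\mathrm{mean}(rX)$ and $\partial L/\partial b = 2\,\mathrm{mean}(r)$. The NumPy sketch below (hypothetical helper `mse_grads`, its own synthetic data) runs the same descent with those formulas. It is a reference for what Tangent's generated gradients should compute, not a replacement for them.

```python
import numpy as np


def mse_grads(w, b, X, y):
    """Analytic gradients of mean((w*X + b - y)**2) w.r.t. w and b."""
    residual = w * X + b - y
    grad_w = 2.0 * np.mean(residual * X)
    grad_b = 2.0 * np.mean(residual)
    return grad_w, grad_b


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 1))
y = 3 * X + 2 + 0.5 * rng.normal(size=(100, 1))

w, b = 0.0, 0.0
for _ in range(200):
    gw, gb = mse_grads(w, b, X, y)
    w, b = w - 0.1 * gw, b - 0.1 * gb
print(f"w = {w:.3f}, b = {b:.3f}")  # close to the true 3 and 2
```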

In [None]:
# Generate synthetic data
np.random.seed(42)
X_data = np.random.randn(100, 1)
y_data = 3 * X_data + 2 + np.random.randn(100, 1) * 0.5

def mse_loss(w, b, X, y):
    """Mean squared error loss."""
    predictions = w * X + b
    return np.mean((predictions - y) ** 2)

# Compute gradients
dmse_dw = tangent.grad(mse_loss, wrt=(0,))
dmse_db = tangent.grad(mse_loss, wrt=(1,))

# Gradient descent
w, b = 0.0, 0.0
learning_rate = 0.1
n_epochs = 50
loss_history = []

for epoch in range(n_epochs):
    # Compute gradients
    grad_w = dmse_dw(w, b, X_data, y_data)
    grad_b = dmse_db(w, b, X_data, y_data)
    
    # Update parameters
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b
    
    # Track loss
    loss = mse_loss(w, b, X_data, y_data)
    loss_history.append(loss)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss = {loss:.4f}, w = {w:.4f}, b = {b:.4f}")

print(f"\nFinal parameters: w = {w:.4f}, b = {b:.4f}")
print(f"True parameters:  w = 3.0000, b = 2.0000")

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Data and fit
ax1.scatter(X_data, y_data, alpha=0.5, label='Data')
x_line = np.linspace(X_data.min(), X_data.max(), 100)
ax1.plot(x_line, w * x_line + b, 'r-', linewidth=2, label=f'Fit: y = {w:.2f}x + {b:.2f}')
ax1.set_xlabel('x', fontsize=12)
ax1.set_ylabel('y', fontsize=12)
ax1.set_title('Linear Regression Result', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Loss curve
ax2.plot(loss_history, 'b-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('MSE Loss', fontsize=12)
ax2.set_title('Training Loss', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.suptitle('Linear Regression with Tangent', fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

### Example 2: Logistic Regression
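
Two facts make logistic regression easy to check by hand. First, the cross-entropy can be written without taking the log of a probability near 0 or 1, because $\log\sigma(z) = -\log(1 + e^{-z})$ and $\log(1 - \sigma(z)) = -\log(1 + e^{z})$; `np.logaddexp(0, t)` computes $\log(1 + e^t)$ stably. Second, the gradient with respect to the logit is just $\sigma(z) - y$. The NumPy sketch below uses both. The helper name is hypothetical, and the sketch is not how the Tangent-differentiated loss in the next cell is written.

```python
import numpy as np


def logistic_loss_and_grad(w1, w2, b, x1, x2, y):
    """Stable binary cross-entropy and its analytic gradient (sketch)."""
    z = w1 * x1 + w2 * x2 + b
    # -log(sigmoid(z)) = logaddexp(0, -z);  -log(1 - sigmoid(z)) = logaddexp(0, z)
    loss = np.mean(y * np.logaddexp(0.0, -z) + (1 - y) * np.logaddexp(0.0, z))
    p = 1.0 / (1.0 + np.exp(-z))
    r = p - y  # d(per-sample loss)/dz
    return loss, (np.mean(r * x1), np.mean(r * x2), np.mean(r))


rng = np.random.default_rng(1)
x1 = rng.normal(size=200)
x2 = rng.normal(size=200)
y = (x1 + x2 > 0).astype(float)

params = np.zeros(3)
for _ in range(100):
    loss, grads = logistic_loss_and_grad(*params, x1, x2, y)
    params -= 0.1 * np.array(grads)
print(f"loss = {loss:.4f}, params = {params}")
```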

In [None]:
# Generate binary classification data
np.random.seed(42)
n_samples = 200
X_class = np.random.randn(n_samples, 2)
y_class = (X_class[:, 0] + X_class[:, 1] > 0).astype(float).reshape(-1, 1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_loss_vec(w1, w2, b, x1_vec, x2_vec, y_vec):
    """Binary cross-entropy loss for logistic regression.
    
    Takes pre-extracted column vectors to avoid array indexing issues with Tangent.
    Uses simple operations that Tangent can handle.
    
    Args:
        w1, w2, b: Model parameters (scalars)
        x1_vec, x2_vec: Feature columns (1D arrays)
        y_vec: Labels (1D array)
    """
    # Compute linear combination
    z = w1 * x1_vec + w2 * x2_vec + b
    
    # Apply sigmoid
    predictions = sigmoid(z)
    
    # Add a small epsilon inside each log for numerical stability
    # (simpler than np.clip for Tangent)
    eps = 1e-10
    
    # Binary cross-entropy
    return -np.mean(y_vec * np.log(predictions + eps) + (1 - y_vec) * np.log(1 - predictions + eps))

# Extract column vectors once (outside the loss function)
x1_data = X_class[:, 0]
x2_data = X_class[:, 1]
y_data = y_class[:, 0]

# Compute gradients separately for each parameter
dloss_dw1 = tangent.grad(logistic_loss_vec, wrt=(0,))
dloss_dw2 = tangent.grad(logistic_loss_vec, wrt=(1,))
dloss_db = tangent.grad(logistic_loss_vec, wrt=(2,))

# Training
w1, w2, b = 0.0, 0.0, 0.0
learning_rate = 0.1
n_epochs = 100
loss_history = []

for epoch in range(n_epochs):
    # Compute gradients
    grad_w1 = dloss_dw1(w1, w2, b, x1_data, x2_data, y_data)
    grad_w2 = dloss_dw2(w1, w2, b, x1_data, x2_data, y_data)
    grad_b = dloss_db(w1, w2, b, x1_data, x2_data, y_data)
    
    # Update parameters
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
    b -= learning_rate * grad_b
    
    # Track loss
    loss = logistic_loss_vec(w1, w2, b, x1_data, x2_data, y_data)
    loss_history.append(loss)

print(f"Final parameters: w1 = {w1:.4f}, w2 = {w2:.4f}, b = {b:.4f}")

# Visualize decision boundary
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Decision boundary
x1_min, x1_max = X_class[:, 0].min() - 1, X_class[:, 0].max() + 1
x2_min, x2_max = X_class[:, 1].min() - 1, X_class[:, 1].max() + 1
xx1, xx2 = np.meshgrid(np.linspace(x1_min, x1_max, 100),
                       np.linspace(x2_min, x2_max, 100))
Z = sigmoid(w1 * xx1 + w2 * xx2 + b)

ax1.contourf(xx1, xx2, Z, levels=20, cmap='RdBu', alpha=0.6)
ax1.scatter(X_class[y_class.flatten() == 0, 0], 
           X_class[y_class.flatten() == 0, 1], 
           c='blue', marker='o', label='Class 0', alpha=0.7)
ax1.scatter(X_class[y_class.flatten() == 1, 0], 
           X_class[y_class.flatten() == 1, 1], 
           c='red', marker='s', label='Class 1', alpha=0.7)
ax1.set_xlabel('Feature 1', fontsize=12)
ax1.set_ylabel('Feature 2', fontsize=12)
ax1.set_title('Logistic Regression Decision Boundary', fontsize=14, fontweight='bold')
ax1.legend()

# Loss curve
ax2.plot(loss_history, 'b-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Cross-Entropy Loss', fontsize=12)
ax2.set_title('Training Loss', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.suptitle('Logistic Regression with Tangent', fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

### Example 3: Neural Network Training (JAX)
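
The training loop in the next cell calls four separate gradient functions, so each step repeats the forward pass four times. A single backward pass can produce all four gradients at once. Section 6 shows `wrt=(0, 1)` returning a tuple, so `wrt=(0, 1, 2, 3)` may do the same here, though that is untested in this notebook. The NumPy sketch below writes that backward pass by hand for the same architecture, which is useful as a reference check. It omits the `jnp.clip` from the cell, so it matches the cell's loss only where the clip is inactive. The name `nn_loss_and_grads` is hypothetical.

```python
import numpy as np


def nn_loss_and_grads(W1, b1, W2, b2, X, y):
    """Two-layer ReLU/sigmoid network: unclipped BCE loss and hand-derived gradients."""
    n = X.shape[0]
    pre = X @ W1 + b1                    # (n, hidden)
    h = np.maximum(pre, 0.0)             # ReLU
    logits = h @ W2 + b2                 # (n, 1)
    out = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    loss = -np.mean(y * np.log(out) + (1 - y) * np.log(1 - out))
    # Backward pass, visiting the forward statements in reverse
    dlogits = (out - y) / n              # sigmoid + BCE simplifies to (out - y)
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(axis=0, keepdims=True)
    dh = dlogits @ W2.T
    dpre = dh * (pre > 0)                # ReLU passes gradient only where active
    dW1 = X.T @ dpre
    db1 = dpre.sum(axis=0, keepdims=True)
    return loss, (dW1, db1, dW2, db2)


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
y = (X[:, :1] * X[:, 1:] > 0).astype(float)
W1 = 0.5 * rng.normal(size=(2, 8))
b1 = np.zeros((1, 8))
W2 = 0.5 * rng.normal(size=(8, 1))
b2 = np.zeros((1, 1))
loss, grads = nn_loss_and_grads(W1, b1, W2, b2, X, y)
print(loss, [g.shape for g in grads])
```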

In [None]:
# Simple 2-layer neural network with JAX
def neural_network(W1, b1, W2, b2, X, y):
    """Two-layer neural network."""
    # Forward pass
    hidden = jax.nn.relu(jnp.dot(X, W1) + b1)
    output = jax.nn.sigmoid(jnp.dot(hidden, W2) + b2)
    
    # Binary cross-entropy loss
    output = jnp.clip(output, 1e-10, 1 - 1e-10)
    loss = -jnp.mean(y * jnp.log(output) + (1 - y) * jnp.log(1 - output))
    return loss

# Generate spiral dataset
def make_spiral_data(n_points=100, noise=0.2):
    n = n_points // 2
    theta = np.linspace(0, 4 * np.pi, n)
    
    # Class 0
    r0 = theta + np.random.randn(n) * noise
    x0 = r0 * np.cos(theta)
    y0 = r0 * np.sin(theta)
    
    # Class 1
    r1 = theta + np.random.randn(n) * noise
    x1 = -r1 * np.cos(theta)
    y1 = -r1 * np.sin(theta)
    
    X = np.vstack([np.column_stack([x0, y0]), np.column_stack([x1, y1])])
    y = np.hstack([np.zeros(n), np.ones(n)]).reshape(-1, 1)
    
    return jnp.array(X), jnp.array(y)

# Data
np.random.seed(42)
X_spiral, y_spiral = make_spiral_data(n_points=200)

# Initialize parameters
input_dim, hidden_dim, output_dim = 2, 10, 1
W1 = jnp.array(np.random.randn(input_dim, hidden_dim) * 0.1)
b1 = jnp.zeros((1, hidden_dim))
W2 = jnp.array(np.random.randn(hidden_dim, output_dim) * 0.1)
b2 = jnp.zeros((1, output_dim))

# Compute gradients
dnn_dW1 = tangent.grad(neural_network, wrt=(0,))
dnn_db1 = tangent.grad(neural_network, wrt=(1,))
dnn_dW2 = tangent.grad(neural_network, wrt=(2,))
dnn_db2 = tangent.grad(neural_network, wrt=(3,))

# Training
learning_rate = 0.5
n_epochs = 500
loss_history = []

print("Training neural network...")
for epoch in range(n_epochs):
    # Compute gradients
    grad_W1 = dnn_dW1(W1, b1, W2, b2, X_spiral, y_spiral)
    grad_b1 = dnn_db1(W1, b1, W2, b2, X_spiral, y_spiral)
    grad_W2 = dnn_dW2(W1, b1, W2, b2, X_spiral, y_spiral)
    grad_b2 = dnn_db2(W1, b1, W2, b2, X_spiral, y_spiral)
    
    # Update
    W1 = W1 - learning_rate * grad_W1
    b1 = b1 - learning_rate * grad_b1
    W2 = W2 - learning_rate * grad_W2
    b2 = b2 - learning_rate * grad_b2
    
    loss = neural_network(W1, b1, W2, b2, X_spiral, y_spiral)
    loss_history.append(float(loss))
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: loss = {loss:.4f}")

print("Training complete!")

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Decision boundary
x_min, x_max = X_spiral[:, 0].min() - 1, X_spiral[:, 0].max() + 1
y_min, y_max = X_spiral[:, 1].min() - 1, X_spiral[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                     np.linspace(y_min, y_max, 200))
grid = jnp.column_stack([xx.ravel(), yy.ravel()])

# Compute predictions
hidden = jax.nn.relu(jnp.dot(grid, W1) + b1)
predictions = jax.nn.sigmoid(jnp.dot(hidden, W2) + b2)
Z = predictions.reshape(xx.shape)

ax1.contourf(xx, yy, Z, levels=20, cmap='RdBu', alpha=0.6)
ax1.scatter(X_spiral[y_spiral.flatten() == 0, 0], 
           X_spiral[y_spiral.flatten() == 0, 1], 
           c='blue', marker='o', label='Class 0', alpha=0.7, edgecolors='black')
ax1.scatter(X_spiral[y_spiral.flatten() == 1, 0], 
           X_spiral[y_spiral.flatten() == 1, 1], 
           c='red', marker='s', label='Class 1', alpha=0.7, edgecolors='black')
ax1.set_xlabel('Feature 1', fontsize=12)
ax1.set_ylabel('Feature 2', fontsize=12)
ax1.set_title('Neural Network Decision Boundary', fontsize=14, fontweight='bold')
ax1.legend()

# Loss curve
ax2.plot(loss_history, 'b-', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Cross-Entropy Loss', fontsize=12)
ax2.set_title('Training Loss', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.suptitle('Neural Network with JAX & Tangent', fontsize=16, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

## Conclusion

In this tutorial, we've covered:

1. **Basic Concepts**: How Tangent transforms source code to compute gradients
2. **Multi-Backend Support**: Using Tangent with NumPy, TensorFlow, and JAX
3. **Advanced Features**: Multiple gradients and result preservation
4. **Visualization**: Plotting functions, gradients, and decision boundaries
5. **Real-World Examples**: Linear regression, logistic regression, and neural networks

### Key Takeaways

- **Readable**: Generated gradient code is Python you can read and debug
- **Efficient**: No tape overhead; the generated gradient is plain Python code that runs directly
- **Flexible**: Works with multiple backends seamlessly
- **Educational**: Perfect for learning how autodiff works under the hood

### Next Steps

- Try Tangent on your own functions
- Explore the [documentation](https://github.com/google/tangent)
- Implement custom gradient rules for your operations
- Compare with other autodiff libraries

### Resources

- GitHub: https://github.com/google/tangent
- Paper: [Automatic differentiation in ML: Where we are and where we should be going](https://arxiv.org/abs/1810.11530)
- Tutorial: [Source-to-Source Differentiation](https://github.com/google/tangent/blob/master/docs/walkthrough.md)

Happy differentiating! 🎉