# Fourier Neural Operator - Exercise

This notebook contains exercises for implementing and experimenting with the Fourier Neural Operator (FNO).

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

## Exercise 1: Understanding Fourier Transform

### Task

Implement a function that computes the derivative of a periodic function using the Fourier transform.

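**Theory**: For a periodic function sampled on a uniform grid that covers exactly one period (endpoint excluded), differentiation in Fourier space is multiplication by $2\pi i k$. Here $k$ is the ordinary frequency (cycles per unit length) returned by `fftfreq`: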

$$
\widehat{\frac{du}{dx}}(k) = 2\pi i k \cdot \hat{u}(k)
$$

**Steps**:
1. Apply FFT to $u(x)$
2. Multiply by $2\pi i k$
3. Apply inverse FFT

### Your Code

In [None]:
def fourier_derivative(u, dx):
    """
    Differentiate a periodic 1-D signal in Fourier space.

    Args:
        u: Function values (n_points,), sampled on a uniform grid that
           should cover exactly one period (endpoint excluded)
        dx: Grid spacing

    Returns:
        du_dx: Derivative (n_points,), same length as u
    """
    n_points = len(u)

    # TODO: Implement this function
    # Hint 1: u_hat = torch.fft.fft(u) goes to frequency space;
    #         torch.fft.ifft(...) comes back.
    # Hint 2: The matching frequencies are
    #         k = torch.fft.fftfreq(n_points, d=dx)   (cycles per unit length)
    # Hint 3: Differentiation is multiplication by 2*pi*1j*k in frequency space.
    # When done, return the result as `du_dx`.
    raise NotImplementedError("Implement fourier_derivative")

In [None]:
# Test your implementation
# The FFT treats its input as exactly one period of a periodic signal, so the
# grid must cover [0, 2*pi) WITHOUT the endpoint: including x = 2*pi would
# duplicate x = 0, break periodicity, and produce large errors even for a
# correct implementation. float64 keeps round-off well below the 1e-5 bar.
n_grid = 100
x = torch.arange(n_grid, dtype=torch.float64) * (2 * np.pi / n_grid)
u = torch.sin(x)  # u(x) = sin(x)
du_true = torch.cos(x)  # du/dx = cos(x)

dx = (x[1] - x[0]).item()  # plain Python float for torch.fft.fftfreq(d=...)
du_computed = fourier_derivative(u, dx)
# Accept either a complex result (straight from ifft) or an already-real one.
du_real = du_computed.real if du_computed.is_complex() else du_computed

plt.figure(figsize=(10, 4))
plt.plot(x, du_true, 'k-', label='True derivative', linewidth=2)
plt.plot(x, du_real, 'r--', label='Fourier derivative', linewidth=2)
plt.xlabel('x')
plt.ylabel('du/dx')
plt.legend()
plt.grid(True, alpha=0.3)
plt.title('Derivative of sin(x)')
plt.show()

error = torch.abs(du_true - du_real).mean().item()
print(f"Mean absolute error: {error:.6f}")
print(f"Test passed: {error < 1e-5}")

## Exercise 2: Implement Spectral Convolution Layer

### Task

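Complete the `SpectralConv1d` class. The forward pass should:
1. Apply the real FFT (`torch.fft.rfft`) to the input
2. Multiply the lowest `modes` frequencies by learned complex weights, and set all higher frequencies to zero
3. Apply the inverse real FFT, passing the original length so that the output has the same number of points as the input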

In [None]:
class SpectralConv1d(nn.Module):
    """
    1-D Fourier layer: FFT -> learned per-frequency channel mixing -> inverse FFT.

    Only the lowest ``modes`` frequencies carry learned weights.
    """

    def __init__(self, in_channels, out_channels, modes):
        super().__init__()
        self.in_channels, self.out_channels, self.modes = in_channels, out_channels, modes

        # Real and imaginary parts live in a trailing axis of size 2 so that
        # torch.view_as_complex can view them as one complex tensor of shape
        # (in_channels, out_channels, modes). Uniform values in
        # [0, 1 / (in_channels * out_channels)) keep the initial output small.
        init_scale = 1 / (in_channels * out_channels)
        raw = torch.rand(in_channels, out_channels, modes, 2)
        self.weights = nn.Parameter(init_scale * raw)

    def forward(self, x):
        """
        Args:
            x: (batch, in_channels, n_points)

        Returns:
            (batch, out_channels, n_points) once implemented
        """
        batch_size = x.shape[0]

        # TODO: Implement the forward pass
        # Step 1: x_ft = torch.fft.rfft(x)  -> (batch, in_channels, n_points // 2 + 1)
        # Step 2: w = torch.view_as_complex(self.weights)  -> (in_channels, out_channels, modes)
        # Step 3: out_ft = torch.zeros(batch_size, self.out_channels, n_points // 2 + 1,
        #                              dtype=torch.cfloat, device=x.device)
        #         out_ft[..., :self.modes] = torch.einsum("bix,iox->box",
        #                                                 x_ft[..., :self.modes], w)
        #         Higher frequencies stay zero (this is the mode truncation).
        # Step 4: torch.fft.irfft(out_ft, n=x.size(-1)); passing `n` keeps odd
        #         lengths from shrinking by one point.
        raise NotImplementedError("Implement SpectralConv1d.forward")

# Test your implementation on both an even and an odd grid size.
# torch.fft.irfft without an explicit `n=` returns 2 * (n_points // 2) points,
# so only the odd length catches a forgotten `n=x.size(-1)`.
layer = SpectralConv1d(in_channels=4, out_channels=4, modes=8)
for n_grid in (32, 33):
    x = torch.randn(2, 4, n_grid)
    y = layer(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
    print(f"Test passed: {y.shape == x.shape}")

## Exercise 3: Build a Simple FNO

### Task

Construct a simple FNO with 2 Fourier layers to learn the antiderivative operator:

$$
G: f \mapsto u \text{ where } \frac{du}{dx} = f
$$

Given $f(x)$, predict $u(x) = \int_0^x f(s) ds$.

In [None]:
# Generate training data
def generate_antiderivative_data(n_samples, n_points=128, seed=None):
    """
    Generate (f, u) pairs for the antiderivative operator u(x) = int_0^x f(s) ds.

    Each f is a random sum of 2-5 phase-shifted cosines with integer
    wavenumbers k in [1, 4]. Every sample is therefore 2*pi-periodic, and its
    antiderivative is known in closed form.

    Args:
        n_samples: Number of (f, u) pairs to draw.
        n_points: Number of uniform grid points on [0, 2*pi).
        seed: Optional integer seed for a private np.random.RandomState, which
            makes the draw reproducible. None (the default) draws from the
            global np.random state.

    Returns:
        f_all: (n_samples, n_points) array of input functions.
        u_all: (n_samples, n_points) array of antiderivatives, with u(0) = 0.
        x: (n_points,) grid. The endpoint 2*pi is excluded because, for a
            periodic function, it is the same point as 0. FFT-based layers
            assume the samples cover exactly one period.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    x = np.linspace(0, 2 * np.pi, n_points, endpoint=False)

    f_all = np.zeros((n_samples, n_points))
    u_all = np.zeros((n_samples, n_points))

    for sample in range(n_samples):
        # Random function: sum of phase-shifted cosines
        n_modes = rng.randint(2, 6)
        for _ in range(n_modes):
            k = rng.randint(1, 5)
            amp = rng.randn()
            phase = rng.uniform(0, 2 * np.pi)

            f_all[sample] += amp * np.cos(k * x + phase)
            # d/dx[(amp/k) sin(kx + phase)] = amp cos(kx + phase). Subtracting
            # the value at x = 0 pins u(0) = 0, so u = int_0^x f(s) ds exactly.
            u_all[sample] += (amp / k) * (np.sin(k * x + phase) - np.sin(phase))

    return f_all, u_all, x

# Dataset sizes and grid resolution used by the rest of the notebook.
n_train, n_test, n_points = 500, 100, 128

f_train, u_train, x = generate_antiderivative_data(n_samples=n_train, n_points=n_points)
# The grid depends only on n_points, so the test split's copy is discarded.
f_test, u_test, _ = generate_antiderivative_data(n_samples=n_test, n_points=n_points)

print(f"Training: {f_train.shape}")
print(f"Test: {f_test.shape}")

In [None]:
# Visualize examples: inputs f on the top row, their antiderivatives u below.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for col in range(3):
    ax_f, ax_u = axes[:, col]

    ax_f.plot(x, f_train[col])
    ax_f.set_title(f'f(x) - Example {col+1}')
    ax_f.set_xlabel('x')
    ax_f.set_ylabel('f(x)')
    ax_f.grid(True, alpha=0.3)

    ax_u.plot(x, u_train[col])
    ax_u.set_title(r'$u(x) = \int f(x)dx$')
    ax_u.set_xlabel('x')
    ax_u.set_ylabel('u(x)')
    ax_u.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
class SimpleFNO(nn.Module):
    """
    Simple FNO for learning the antiderivative operator.

    Expected I/O once implemented:
        input  (batch, n_points, 2): last axis is (x, f(x)), as built by prepare_data
        output (batch, n_points):    must match the shape of y_train / y_test
    """

    def __init__(self, modes, width):
        super().__init__()

        # TODO: Build your FNO from these pieces
        #   lifting:     nn.Linear(2, width)                  (x, f(x)) -> width channels
        #   Fourier:     two SpectralConv1d(width, width, modes) layers
        #   local:       two nn.Conv1d(width, width, 1) layers
        #   projection:  nn.Linear(width, 1)                  width channels -> u(x)
        # Keep the paired layers in nn.ModuleList so their parameters register.
        raise NotImplementedError("Implement SimpleFNO")

    def forward(self, x):
        # TODO: Implement forward pass
        #   1. Lift (batch, n, 2) -> (batch, n, width), then permute to
        #      (batch, width, n): SpectralConv1d and Conv1d are channels-first.
        #   2. For each Fourier layer: x = F.gelu(fourier(x) + local(x)).
        #   3. Permute back to (batch, n, width), project to (batch, n, 1), and
        #      squeeze the last axis. The output must be exactly (batch, n);
        #      F.mse_loss broadcasts mismatched shapes instead of rejecting them.
        raise NotImplementedError("Implement SimpleFNO.forward")

# Build the model on the selected device and report its parameter count.
model = SimpleFNO(width=32, modes=16).to(device)
print(f"Parameters: {sum(param.numel() for param in model.parameters()):,}")

In [None]:
# Prepare data
def prepare_data(f, x):
    """
    Stack grid coordinates and function values into FNO input features.

    Args:
        f: Function samples, either (n_samples, n_points) or a single
           function of shape (n_points,). A single function is treated as a
           batch of one, which is what a per-function resolution test needs.
        x: Grid coordinates, shape (n_points,).

    Returns:
        float32 tensor of shape (n_samples, n_points, 2) whose last axis is
        (x, f(x)).

    Raises:
        ValueError: if f does not hold n_samples * len(x) values.
    """
    f = np.asarray(f)
    x = np.asarray(x)
    if f.ndim == 1:
        f = f[np.newaxis]  # single function -> batch of one

    n_samples = f.shape[0]
    n_points = x.shape[0]
    if f.size != n_samples * n_points:
        raise ValueError(
            f"f has shape {f.shape} but the grid has {n_points} points; "
            f"expected ({n_samples}, {n_points})"
        )

    # broadcast_to gives every sample the same grid without copying it n times.
    grid = np.broadcast_to(x.reshape(1, n_points, 1), (n_samples, n_points, 1))
    values = f.reshape(n_samples, n_points, 1)
    data = np.concatenate([grid, values], axis=-1)
    return torch.as_tensor(data, dtype=torch.float32)

# Inputs are (x, f(x)) feature pairs; targets are the antiderivative values.
X_train = prepare_data(f_train, x).to(device)
X_test = prepare_data(f_test, x).to(device)

y_train = torch.as_tensor(u_train, dtype=torch.float32).to(device)
y_test = torch.as_tensor(u_test, dtype=torch.float32).to(device)

print(f"X_train: {X_train.shape}")
print(f"y_train: {y_train.shape}")

In [None]:
# Training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 500
batch_size = 20

train_losses = []
test_losses = []


def _shape_checked_mse(pred, target):
    """
    Return F.mse_loss(pred, target), refusing to broadcast mismatched shapes.

    F.mse_loss only warns on a shape mismatch and then broadcasts. That either
    fails with an unrelated-looking size error (e.g. a (batch, n, 1)
    prediction) or silently averages over the wrong pairs (e.g. a
    (batch, 1, n) prediction against a (batch, n) target). Fail early with an
    actionable message instead.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"Model output shape {tuple(pred.shape)} does not match target shape "
            f"{tuple(target.shape)}; the model should return (batch, n_points)."
        )
    return F.mse_loss(pred, target)


n_train_samples = len(X_train)

for epoch in tqdm(range(n_epochs)):
    model.train()
    perm = torch.randperm(n_train_samples)
    train_loss = 0.0

    for i in range(0, n_train_samples, batch_size):
        idx = perm[i:i + batch_size]
        X_batch = X_train[idx]
        y_batch = y_train[idx]

        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = _shape_checked_mse(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean stays exact if the last batch is short.
        train_loss += loss.item() * len(X_batch)

    train_loss /= n_train_samples
    train_losses.append(train_loss)

    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
        test_loss = _shape_checked_mse(y_pred, y_test).item()
        test_losses.append(test_loss)

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}: train_loss={train_loss:.6f}, test_loss={test_loss:.6f}")

In [None]:
# Plot results: per-epoch train/test MSE on a logarithmic axis.
plt.figure(figsize=(10, 4))
for history, label in ((train_losses, 'Train'), (test_losses, 'Test')):
    plt.semilogy(history, label=label)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training History')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Evaluate: for the first three test samples, show the input, prediction
# vs. truth, and the pointwise absolute error.
model.eval()
with torch.no_grad():
    y_pred_test = model(X_test).cpu().numpy()

y_test_np = y_test.cpu().numpy()

fig, axes = plt.subplots(3, 3, figsize=(15, 12))

for col in range(3):
    ax_in, ax_pred, ax_err = axes[:, col]

    ax_in.plot(x, f_test[col])
    ax_in.set_title(f'Input f(x) - {col+1}')
    ax_in.grid(True, alpha=0.3)

    ax_pred.plot(x, y_test_np[col], 'k-', label='True', linewidth=2)
    ax_pred.plot(x, y_pred_test[col], 'r--', label='FNO', linewidth=2)
    ax_pred.set_title(r'Antiderivative $\int f dx$')
    ax_pred.legend()
    ax_pred.grid(True, alpha=0.3)

    error = np.abs(y_test_np[col] - y_pred_test[col])
    ax_err.plot(x, error)
    ax_err.set_title(f'Error (MAE={np.mean(error):.4f})')
    ax_err.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Exercise 4: Resolution Invariance Test

### Task

Test if your trained FNO can handle different resolutions:
1. Train on $n=128$ points
2. Evaluate on $n=256$ points
3. Evaluate on $n=64$ points

Compare the errors.

In [None]:
# TODO: Implement resolution test
# Steps:
# 1. Generate test data at n=256 and n=64
# 2. Interpolate f(x) to these resolutions
# 3. Compute true antiderivative
# 4. Run model predictions
# 5. Compare MAE for different resolutions

def test_resolution(model, f, x_original, n_points_new):
    """Test model at different resolution"""
    # TODO: Implement this
    raise NotImplementedError("Implement test_resolution")

# Test at different resolutions; training used n_points = 128 only.
resolutions = [64, 128, 256]
errors = []

for n in resolutions:
    error = test_resolution(model, f_test[0], x, n)
    errors.append(error)
    print(f"Resolution {n}: MAE = {error:.6f}")

# Plot errors vs resolution: a resolution-invariant model gives a flat curve.
plt.figure(figsize=(8, 5))
plt.plot(resolutions, errors, 'o-', markersize=8, linewidth=2)
plt.xlabel('Resolution (n_points)')
plt.ylabel('MAE')
plt.title('Resolution Invariance Test')
plt.grid(True, alpha=0.3)
plt.show()

## Exercise 5: Effect of Fourier Modes

### Task

Study how the number of Fourier modes affects performance:
1. Train FNOs with `modes = [4, 8, 16, 32]`
2. Compare test error and number of parameters
3. Plot error vs modes

**Question**: What happens when `modes` is too small? Too large?

In [None]:
# TODO: Implement mode comparison
# For each value of `modes`, train a fresh model (same width, epochs and data
# for a fair comparison) and record:
#   - test error
#   - number of parameters
#   - training time
# Note: `modes` must not exceed n_points // 2 + 1, the length of the rfft output.

mode_values = [2 ** p for p in range(2, 6)]  # [4, 8, 16, 32]
results = []

for modes in mode_values:
    # TODO: Train and evaluate model, then append its metrics to `results`
    pass

# Plot results
# TODO: Create comparison plots (error vs modes, parameters vs modes)

## Exercise 6: Compare FNO with MLP

### Task

Build an MLP baseline that maps $(x, f(x)) \to u(x)$ point-wise.

Compare:
1. Test error
2. Number of parameters
3. Resolution invariance

**Question**: Why does FNO have better resolution invariance?

In [None]:
class MLPBaseline(nn.Module):
    """
    Point-wise MLP baseline.

    The same small network is applied at every grid point independently, so
    it never sees f at other locations.
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        # TODO: Build MLP, e.g. with nn.Sequential:
        #   nn.Linear(2, hidden_size) -> activation
        #   -> nn.Linear(hidden_size, hidden_size) -> activation
        #   -> nn.Linear(hidden_size, 1)
        # Per point: input (x, f(x)) with 2 features, output u(x) with 1 feature.
        raise NotImplementedError("Implement MLPBaseline")

    def forward(self, x):
        # TODO: Implement forward
        # When fed prepare_data output, x is (batch, n_points, 2). nn.Linear acts
        # on the last axis only, so each point is processed independently.
        # Squeeze the trailing size-1 axis so the output is (batch, n_points).
        raise NotImplementedError("Implement MLPBaseline.forward")


# TODO: Train and compare with FNO

## Bonus Exercise: 2D Problem

### Task

Implement a 2D Poisson solver:

$$
-\Delta u = f, \quad u|_{\partial D} = 0
$$

Use a 2D FNO to learn the mapping $f \to u$.

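**Hint**: You'll need to write a `SpectralConv2d` layer (the 2D analogue of `SpectralConv1d`; it is not provided) using `torch.fft.rfft2` and `torch.fft.irfft2`.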

In [None]:
# TODO: Implement 2D Poisson solver with FNO
# 1. Generate data: random source f, solve for u
# 2. Build 2D FNO
# 3. Train and evaluate
# 4. Test resolution invariance in 2D

## Summary Questions

Answer these after completing the exercises:

1. **Why does FNO use FFT instead of learning convolution kernels directly?**

2. **What is the computational complexity of one FNO layer?**
   - With FFT: ?
   - Without FFT (direct convolution): ?

3. **Why is FNO resolution-invariant but CNN is not?**

4. **What happens if you set `modes` = `n_points`? Why is truncation useful?**

5. **Compare FNO with DeepONet**:
   - When would you use FNO?
   - When would you use DeepONet?

6. **Can FNO handle non-uniform grids? Why or why not?**