# B2B Operator Learning: Basis-to-Basis Transformations

**Learning Objectives:**
- Master the B2B (Basis-to-Basis) framework for operator learning
- Learn operators as explicit transformation matrices
- Implement the same examples as DeepONet for direct comparison
- Understand zero-shot and few-shot operator learning
- Compare B2B with DeepONet on performance and interpretability

**Examples covered (same as DeepONet):**
1. Derivative operator
2. Poisson equation solver  
3. 1D nonlinear Darcy flow

---

## The B2B Framework

**Core idea:** Decompose operator learning into three steps:

1. **Encode source:** $f \xrightarrow{E_1} c_f \in \mathbb{R}^{n_1}$
2. **Transform:** $c_f \xrightarrow{A} c_g \in \mathbb{R}^{n_2}$
3. **Decode target:** $c_g \xrightarrow{D_2} g$

The operator $\mathcal{G}$ is represented as: $\mathcal{G}[f] \approx D_2(A \cdot E_1(f))$
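For reference, the transformation step is a ridge-regularized linear least-squares fit in coefficient space. It uses the same normal equations that `B2BOperator.learn_transformation` solves later in this notebook. Stack the encoded training pairs as rows, $X = [E_1(f_1); \dots; E_1(f_N)] \in \mathbb{R}^{N \times n_1}$ and $Y = [E_2(g_1); \dots; E_2(g_N)] \in \mathbb{R}^{N \times n_2}$, where $E_2$ is the target encoder and is needed only during training:

$$A = \arg\min_{A \in \mathbb{R}^{n_2 \times n_1}} \|Y - X A^\top\|_F^2 + \lambda \|A\|_F^2 \quad\Longrightarrow\quad (X^\top X + \lambda I)\, A^\top = X^\top Y$$

For $\lambda > 0$ the matrix $X^\top X + \lambda I$ is positive definite, so the system is solvable even when $N < n_1$.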

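**Key advantage:** The transformation matrix $A$ is explicit, so you can inspect the learned operator directly through its entries, rank, and singular values. The entries are expressed in the learned bases, however, so they are only as interpretable as those bases.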
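The following minimal NumPy sketch makes the three steps concrete. It fixes the bases by hand instead of learning them, which is an assumption made only for illustration. The source basis is the monomials $[1, x, x^2, x^3]$ and the target basis is $[1, x, x^2]$. $E_1$ is a least-squares projection, and $A$ is the exact derivative matrix. In these particular bases, the entries of $A$ are the familiar factors $1, 2, 3$ of differentiation.

```python
import numpy as np

# Shared sensor grid (same domain and resolution as Example 1).
x_grid = np.linspace(-2.0, 2.0, 100)

# Hand-picked bases (illustrative assumption; B2B learns its bases instead):
# source = [1, x, x^2, x^3], target = [1, x, x^2], evaluated on the grid.
Phi_src = np.stack([x_grid**k for k in range(4)], axis=1)   # (100, 4)
Phi_tgt = np.stack([x_grid**k for k in range(3)], axis=1)   # (100, 3)

def encode(samples, basis):
    """E: least-squares coefficients of each sampled function (rows of `samples`)."""
    coeffs, *_ = np.linalg.lstsq(basis, samples.T, rcond=None)
    return coeffs.T                                          # (batch, n_basis)

def decode(coeffs, basis):
    """D: evaluate sum_k c_k * phi_k(x) on the grid."""
    return coeffs @ basis.T                                  # (batch, n_points)

# d/dx in monomial coefficients: (c0, c1, c2, c3) -> (c1, 2*c2, 3*c3).
A_demo = np.array([[0.0, 1.0, 0.0, 0.0],
                   [0.0, 0.0, 2.0, 0.0],
                   [0.0, 0.0, 0.0, 3.0]])                    # (n2, n1) = (3, 4)

f = (0.5 * x_grid**3 - x_grid**2 + 0.3 * x_grid + 1.0)[None, :]  # one cubic, (1, 100)
c_f = encode(f, Phi_src)                                     # step 1: encode source
c_g = c_f @ A_demo.T                                         # step 2: transform (row-vector form)
df_pred = decode(c_g, Phi_tgt)                               # step 3: decode target
df_true = 1.5 * x_grid**2 - 2.0 * x_grid + 0.3
print("max abs error:", np.max(np.abs(df_pred - df_true)))
```

Coefficients are stored as rows, as in the notebook. This is why the transform is written `c_f @ A_demo.T`, the same convention `B2BOperator.apply` uses.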

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
from scipy.stats import multivariate_normal
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else 
                     "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: Function Encoder Architecture

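First, we need function encoders that learn representations of the source and target function spaces. Each `FunctionEncoder` contains two networks. The encoder maps the `sensor_dim` sampled values of a function to `n_basis` coefficients $c_k$. The decoder maps a point $x$ to the values of the `n_basis` basis functions $\phi_k(x)$. A function is reconstructed as $f(x) \approx \sum_k c_k \phi_k(x)$, so decoding is linear in the coefficients even though both networks are nonlinear.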
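The MLP encoder is one way to obtain coefficients. When the basis values at the sensor points are available, you can instead compute coefficients by least squares against those values. This hedged sketch shows that alternative; it is not what `FunctionEncoder.encode` does. The basis matrix is random and purely hypothetical. The sketch only illustrates the shape bookkeeping and the fact that reconstruction is linear in the coefficients.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points, n_basis = 100, 4

# Hypothetical stand-in for the decoder's basis values phi_k(x_j) at the sensors.
basis_values = rng.standard_normal((n_points, n_basis))     # (n_points, n_basis)

# Two functions built from known coefficients: f(x_j) = sum_k c_k * phi_k(x_j).
true_coeffs = np.array([[1.0, -0.5, 0.25, 2.0],
                        [0.0, 3.0, -1.0, 0.5]])             # (batch, n_basis)
samples = true_coeffs @ basis_values.T                      # (batch, n_points)

# Least-squares "encoder": argmin_c ||Phi c - f||^2, solved for the whole batch at once.
ls_coeffs = np.linalg.lstsq(basis_values, samples.T, rcond=None)[0].T

# Reconstruction, analogous to the einsum in FunctionEncoder.reconstruct but with
# one basis matrix shared by every function in the batch.
recon = np.einsum('bn,pn->bp', ls_coeffs, basis_values)
```

The trade-off: the least-squares route needs the basis values at the sensors and a linear solve per batch. The MLP encoder replaces this with a single forward pass. However, its coefficients are only trained to reconstruct well, not guaranteed to equal the least-squares projection.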

In [None]:
class FunctionEncoder(nn.Module):
    """Function encoder for learning basis representations"""
    
    def __init__(self, sensor_dim, n_basis, hidden_dim=64):
        super().__init__()
        self.sensor_dim = sensor_dim
        self.n_basis = n_basis
        
        # Encoder: maps function samples to coefficients
        self.encoder = nn.Sequential(
            nn.Linear(sensor_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_basis)
        )
        
        # Decoder: generates basis functions at query points
        self.decoder = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, n_basis)
        )
    
    def encode(self, function_samples):
        """Extract coefficients from function samples"""
        return self.encoder(function_samples)
    
    def decode_basis(self, x):
        """Get basis function values at points x"""
        return self.decoder(x)
    
    def reconstruct(self, coefficients, x):
        """Reconstruct function from coefficients"""
        if x.dim() == 2:
            x = x.unsqueeze(0)
        
        batch_size, n_points, _ = x.shape
        basis_values = self.decoder(x.reshape(-1, 1))
        basis_values = basis_values.view(batch_size, n_points, self.n_basis)
        
        if coefficients.dim() == 1:
            coefficients = coefficients.unsqueeze(0)
        
        return torch.einsum('bn,bpn->bp', coefficients, basis_values)
    
    def forward(self, function_samples, query_points):
        coeffs = self.encode(function_samples)
        return self.reconstruct(coeffs, query_points)


class B2BOperator:
    """B2B Operator Learning Framework"""
    
    def __init__(self, source_encoder, target_encoder):
        self.source_encoder = source_encoder
        self.target_encoder = target_encoder
        self.transformation_matrix = None
    
    def learn_transformation(self, source_functions, target_functions, regularization=1e-6):
        """Learn transformation matrix A using least squares"""
        
        self.source_encoder.eval()
        self.target_encoder.eval()
        
        with torch.no_grad():
            # Encode all functions
            source_coeffs = self.source_encoder.encode(source_functions)
            target_coeffs = self.target_encoder.encode(target_functions)
        
        # Ridge regression: minimize ||Y - X @ A.T||^2 + regularization * ||A||^2
        # via the normal equations (X^T X + regularization * I) A^T = X^T Y
        X = source_coeffs.cpu()
        Y = target_coeffs.cpu()
        
        # Regularized least squares
        XtX = X.T @ X + regularization * torch.eye(X.shape[1])
        XtY = X.T @ Y
        A = torch.linalg.solve(XtX, XtY).T
        
        self.transformation_matrix = A.to(device)
        
        # Training fit error, measured in target-coefficient space (not on the function grid)
        Y_pred = X @ A.T
        mse = F.mse_loss(Y_pred, Y).item()
        
        return A, mse
    
    def apply(self, source_function, query_points):
        """Apply the learned operator"""
        if self.transformation_matrix is None:
            raise ValueError("Transformation matrix not learned yet")
        
        self.source_encoder.eval()
        self.target_encoder.eval()
        
        with torch.no_grad():
            # Encode source
            source_coeffs = self.source_encoder.encode(source_function)
            
            # Transform
            target_coeffs = source_coeffs @ self.transformation_matrix.T
            
            # Decode
            return self.target_encoder.reconstruct(target_coeffs, query_points)


print("B2B Framework initialized")
print("Components: Encoder → Transformation → Decoder")

## Part 2: Helper Functions for Training

In [None]:
def train_encoder(encoder, functions, x_points, n_epochs=500, lr=1e-3, name="Encoder"):
    """Train a function encoder to reconstruct functions"""
    
    optimizer = optim.Adam(encoder.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=0.5)
    
    losses = []
    encoder.train()
    
    x_tensor = torch.tensor(x_points, dtype=torch.float32).unsqueeze(-1).to(device)
    
    pbar = tqdm(range(n_epochs), desc=f"Training {name}")
    for epoch in pbar:
        # Random mini-batch of up to 32 functions (sampled without replacement)
        idx = np.random.choice(len(functions), min(32, len(functions)), replace=False)
        batch_functions = torch.tensor(functions[idx], dtype=torch.float32).to(device)
        
        # Prepare query points
        batch_x = x_tensor.unsqueeze(0).repeat(len(idx), 1, 1)
        
        # Forward pass
        pred = encoder(batch_functions, batch_x)
        loss = F.mse_loss(pred, batch_functions)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        scheduler.step(loss.item())  # plateau scheduler tracks the (noisy) mini-batch loss
        
        if epoch % 50 == 0:
            pbar.set_postfix({'Loss': f'{loss.item():.6f}'})
    
    return losses


def visualize_basis_functions(encoder, x_range=(-2, 2), name="Encoder"):
    """Visualize learned basis functions"""
    
    x = torch.linspace(x_range[0], x_range[1], 200).unsqueeze(-1).to(device)
    
    with torch.no_grad():
        basis = encoder.decode_basis(x).cpu().numpy()
    
    x = x.cpu().numpy().squeeze()
    
    plt.figure(figsize=(10, 4))
    for i in range(min(basis.shape[1], 10)):
        plt.plot(x, basis[:, i], linewidth=2, alpha=0.7, label=f'φ_{i+1}')
    
    plt.title(f'{name} Basis Functions')
    plt.xlabel('x')
    plt.grid(True, alpha=0.3)
    if basis.shape[1] <= 10:
        plt.legend(ncol=2)
    plt.show()


print("Helper functions defined")

## Example 1: The Derivative Operator

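As in the DeepONet examples, learn $\mathcal{D}[u] = \frac{du}{dx}$ for cubic polynomials $u(x) = ax^3 + bx^2 + cx + d$, with $a, b, c, d$ drawn independently from $\mathcal{N}(0, 0.5^2)$. The operator maps the 4-dimensional space of cubics onto the 3-dimensional space of quadratics. This is why the source and target encoders below use 4 and 3 basis functions.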
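Here is a quick sanity check on those dimensions. It assumes the same data-generating recipe as `generate_polynomial_data`, re-implemented in vectorized form with NumPy's `default_rng`, so the individual samples differ from the notebook's. The numerical rank of the sampled cubics should be 4, and that of their derivatives 3.

```python
import numpy as np

rng = np.random.default_rng(42)
x_chk = np.linspace(-2, 2, 100)
abcd = rng.standard_normal((500, 4)) * 0.5            # rows of (a, b, c, d)

# Columns evaluate x^3, x^2, x, 1 and their derivatives on the grid.
V = np.stack([x_chk**3, x_chk**2, x_chk, np.ones_like(x_chk)], axis=1)
dV = np.stack([3 * x_chk**2, 2 * x_chk, np.ones_like(x_chk), np.zeros_like(x_chk)], axis=1)
F_chk = abcd @ V.T                                    # (500, 100) sampled cubics
dF_chk = abcd @ dV.T                                  # (500, 100) their derivatives

def numerical_rank(data, rel_tol=1e-8):
    """Count singular values above rel_tol times the largest one."""
    s = np.linalg.svd(data, compute_uv=False)
    return int(np.sum(s > rel_tol * s[0]))

print(numerical_rank(F_chk), numerical_rank(dF_chk))
```

Applied to other datasets, the same check gives a data-driven starting point for `n_basis`.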

In [None]:
# Generate polynomial data (same as DeepONet)
def generate_polynomial_data(num_functions=2000, num_points=100, x_range=(-2, 2)):
    """Generate cubic polynomials and their derivatives"""
    np.random.seed(42)
    
    coeffs = np.random.randn(num_functions, 4) * 0.5
    x = np.linspace(x_range[0], x_range[1], num_points)
    
    functions = np.zeros((num_functions, num_points))
    derivatives = np.zeros((num_functions, num_points))
    
    for i in range(num_functions):
        a, b, c, d = coeffs[i]
        functions[i] = a * x**3 + b * x**2 + c * x + d
        derivatives[i] = 3 * a * x**2 + 2 * b * x + c
    
    return coeffs, x, functions, derivatives

print("=== DERIVATIVE OPERATOR EXAMPLE ===")
coeffs, x, functions, derivatives = generate_polynomial_data()

# Split data
n_train = 1600
train_functions = functions[:n_train]
train_derivatives = derivatives[:n_train]
test_functions = functions[n_train:]
test_derivatives = derivatives[n_train:]

print(f"Data: {n_train} training, {len(test_functions)} test functions")
print(f"Domain: x ∈ [{x[0]:.1f}, {x[-1]:.1f}]")

# Visualize samples
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
    ax = axes[i]
    ax.plot(x, functions[i], 'b-', linewidth=2, label='f(x)')
    ax.plot(x, derivatives[i], 'r-', linewidth=2, label="f'(x)")
    ax.set_title(f'Sample {i+1}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.tight_layout()
plt.show()

### Train Function Encoders

In [None]:
print("\nTraining function encoders...")

# Create encoders
derivative_source_encoder = FunctionEncoder(sensor_dim=100, n_basis=4, hidden_dim=32).to(device)
derivative_target_encoder = FunctionEncoder(sensor_dim=100, n_basis=3, hidden_dim=32).to(device)

# Train source encoder (cubic polynomials)
print("\n1. Source encoder (cubic space):")
source_losses = train_encoder(derivative_source_encoder, train_functions, x, 
                             n_epochs=300, name="Cubic Encoder")

# Train target encoder (quadratic polynomials)
print("\n2. Target encoder (quadratic space):")
target_losses = train_encoder(derivative_target_encoder, train_derivatives, x, 
                             n_epochs=300, name="Quadratic Encoder")

# Visualize basis functions
visualize_basis_functions(derivative_source_encoder, name="Source (Cubic)")
visualize_basis_functions(derivative_target_encoder, name="Target (Quadratic)")

print(f"\nFinal losses - Source: {source_losses[-1]:.6f}, Target: {target_losses[-1]:.6f}")

### Learn and Apply the Derivative Operator

In [None]:
# Create B2B operator
derivative_b2b = B2BOperator(derivative_source_encoder, derivative_target_encoder)

# Learn transformation matrix
print("Learning transformation matrix...")
train_source_tensor = torch.tensor(train_functions, dtype=torch.float32).to(device)
train_target_tensor = torch.tensor(train_derivatives, dtype=torch.float32).to(device)

A_derivative, fit_error = derivative_b2b.learn_transformation(
    train_source_tensor, train_target_tensor
)

print(f"Transformation matrix shape: {A_derivative.shape}")
print(f"Fitting error: {fit_error:.6f}")

# Visualize transformation matrix
plt.figure(figsize=(6, 5))
plt.imshow(A_derivative.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(label='Weight')
plt.title('Derivative Operator Transformation Matrix')
plt.xlabel('Source Basis (Cubic)')
plt.ylabel('Target Basis (Quadratic)')

# Annotate values
for i in range(A_derivative.shape[0]):
    for j in range(A_derivative.shape[1]):
        plt.text(j, i, f'{A_derivative[i,j].item():.2f}', 
                ha='center', va='center', fontsize=10)

plt.tight_layout()
plt.show()

### Test the Derivative Operator

In [None]:
# Test on unseen functions
n_test_vis = 6
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

test_errors = []
x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)

for i in range(n_test_vis):
    test_idx = i * 10
    
    # Apply B2B operator
    source_func = torch.tensor(test_functions[test_idx:test_idx+1], dtype=torch.float32).to(device)
    pred_derivative = derivative_b2b.apply(source_func, x_tensor).squeeze().cpu().numpy()
    
    true_derivative = test_derivatives[test_idx]
    
    # Compute error
    mse = np.mean((pred_derivative - true_derivative)**2)
    test_errors.append(mse)
    
    # Plot
    ax = axes[i]
    ax.plot(x, test_functions[test_idx], 'b-', linewidth=2, alpha=0.7, label='f(x)')
    ax.plot(x, true_derivative, 'g-', linewidth=2, label="True f'(x)")
    ax.plot(x, pred_derivative, 'r--', linewidth=2, label="B2B f'(x)")
    
    ax.set_title(f'Test {i+1}: MSE = {mse:.6f}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.suptitle('B2B Derivative Operator Results', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Average test MSE over {len(test_errors)} plotted samples: {np.mean(test_errors):.6f} ± {np.std(test_errors):.6f}")

## Example 2: Poisson Equation Solver

Learn the solution operator for the Poisson equation:
$$-\nabla^2 u = f \text{ in } \Omega, \quad u = 0 \text{ on } \partial\Omega$$

In 1D: $-\frac{d^2u}{dx^2} = f(x)$ with $u(0) = u(1) = 0$
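The data generator below discretizes $-u'' = f$ with the standard second-order three-point stencil on the interior nodes. This hedged sketch checks that discretization against the exact solution $u(x) = \sin(\pi x)/\pi^2$ for $f(x) = \sin(\pi x)$. It uses `np.linalg.solve` instead of forming an explicit inverse, a small deviation from the generator.

```python
import numpy as np

n_points = 100
x_fd = np.linspace(0.0, 1.0, n_points)
dx = x_fd[1] - x_fd[0]

# Interior matrix for -d^2/dx^2 with u(0) = u(1) = 0 (same stencil as A_fd in the generator).
m = n_points - 2
A_lap = (2.0 * np.eye(m) - np.eye(m, k=1) - np.eye(m, k=-1)) / dx**2

f_fd = np.sin(np.pi * x_fd)                      # source with a known exact solution
u_fd = np.zeros(n_points)                        # boundary values stay at zero
u_fd[1:-1] = np.linalg.solve(A_lap, f_fd[1:-1])

u_exact = np.sin(np.pi * x_fd) / np.pi**2        # solves -u'' = sin(pi x), u(0) = u(1) = 0
max_err = np.max(np.abs(u_fd - u_exact))
print(f"max error: {max_err:.2e}")
```

The error should be on the order of $\Delta x^2$ for this stencil. Inverting the matrix once, as the generator does, is reasonable because the same matrix is reused for every sample.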

In [None]:
def generate_poisson_data(n_samples=1000, n_points=100):
    """Generate Poisson equation data"""
    
    x = np.linspace(0, 1, n_points)
    dx = x[1] - x[0]
    
    # Create finite difference matrix for -d²/dx²
    main_diag = 2 * np.ones(n_points - 2) / dx**2
    off_diag = -np.ones(n_points - 3) / dx**2
    A_fd = diags([off_diag, main_diag, off_diag], [-1, 0, 1]).toarray()
    A_inv = np.linalg.inv(A_fd)
    
    sources = []
    solutions = []
    
    np.random.seed(42)
    
    for i in range(n_samples):
        # Generate random source function (combination of sines)
        f = np.zeros(n_points)
        n_modes = np.random.randint(2, 6)
        for k in range(n_modes):
            mode = np.random.randint(1, 10)
            amplitude = np.random.randn()
            phase = np.random.rand() * 2 * np.pi
            f += amplitude * np.sin(mode * np.pi * x + phase)
        
        # Solve Poisson equation
        u = np.zeros(n_points)
        u[1:-1] = A_inv @ f[1:-1]
        
        sources.append(f)
        solutions.append(u)
    
    return np.array(sources), np.array(solutions), x


print("\n=== POISSON EQUATION EXAMPLE ===")
poisson_sources, poisson_solutions, x_poisson = generate_poisson_data(n_samples=1500)

# Split data
n_train_poisson = 1200
train_sources_poisson = poisson_sources[:n_train_poisson]
train_solutions_poisson = poisson_solutions[:n_train_poisson]
test_sources_poisson = poisson_sources[n_train_poisson:]
test_solutions_poisson = poisson_solutions[n_train_poisson:]

print(f"Data: {n_train_poisson} training, {len(test_sources_poisson)} test")

# Visualize samples
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
    ax = axes[i]
    ax.plot(x_poisson, poisson_sources[i], 'r-', linewidth=2, label='f(x)')
    ax.plot(x_poisson, poisson_solutions[i], 'b-', linewidth=2, label='u(x)')
    ax.set_title(f'Sample {i+1}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.suptitle('Poisson Equation: Source → Solution')
plt.tight_layout()
plt.show()

### Train Encoders for Poisson

In [None]:
print("\nTraining Poisson encoders...")

# Create encoders with more basis functions for this problem
poisson_source_encoder = FunctionEncoder(sensor_dim=100, n_basis=15, hidden_dim=64).to(device)
poisson_solution_encoder = FunctionEncoder(sensor_dim=100, n_basis=15, hidden_dim=64).to(device)

# Train encoders
print("1. Source encoder (f space):")
poisson_source_losses = train_encoder(poisson_source_encoder, train_sources_poisson, x_poisson,
                                     n_epochs=400, name="Source Encoder")

print("\n2. Solution encoder (u space):")
poisson_solution_losses = train_encoder(poisson_solution_encoder, train_solutions_poisson, x_poisson,
                                       n_epochs=400, name="Solution Encoder")

print(f"\nFinal losses - Source: {poisson_source_losses[-1]:.6f}, Solution: {poisson_solution_losses[-1]:.6f}")

### Learn and Apply Poisson Solver

In [None]:
# Create B2B operator for Poisson
poisson_b2b = B2BOperator(poisson_source_encoder, poisson_solution_encoder)

# Learn transformation
print("Learning Poisson transformation matrix...")
train_source_poisson_tensor = torch.tensor(train_sources_poisson, dtype=torch.float32).to(device)
train_solution_poisson_tensor = torch.tensor(train_solutions_poisson, dtype=torch.float32).to(device)

A_poisson, fit_error_poisson = poisson_b2b.learn_transformation(
    train_source_poisson_tensor, train_solution_poisson_tensor
)

print(f"Transformation matrix shape: {A_poisson.shape}")
print(f"Fitting error: {fit_error_poisson:.6f}")

# Visualize transformation matrix
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Matrix heatmap
ax = axes[0]
im = ax.imshow(A_poisson.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(im, ax=ax)
ax.set_title('Poisson Solver Transformation Matrix')
ax.set_xlabel('Source Basis')
ax.set_ylabel('Solution Basis')

# Singular values
ax = axes[1]
U, S, Vt = torch.linalg.svd(A_poisson.cpu())
ax.bar(range(len(S)), S.numpy())
ax.set_title('Singular Values')
ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Test Poisson Solver

In [None]:
# Test Poisson solver
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

poisson_test_errors = []
x_poisson_tensor = torch.tensor(x_poisson, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)

for i in range(6):
    test_idx = i * 10
    
    # Apply B2B operator
    source_func = torch.tensor(test_sources_poisson[test_idx:test_idx+1], dtype=torch.float32).to(device)
    pred_solution = poisson_b2b.apply(source_func, x_poisson_tensor).squeeze().cpu().numpy()
    
    true_solution = test_solutions_poisson[test_idx]
    
    # Compute error
    mse = np.mean((pred_solution - true_solution)**2)
    rel_error = np.sqrt(mse) / np.sqrt(np.mean(true_solution**2) + 1e-8)
    poisson_test_errors.append(rel_error)
    
    # Plot
    ax = axes[i]
    ax.plot(x_poisson, test_sources_poisson[test_idx], 'r-', linewidth=2, alpha=0.7, label='f(x)')
    ax.plot(x_poisson, true_solution, 'b-', linewidth=2, label='True u(x)')
    ax.plot(x_poisson, pred_solution, 'g--', linewidth=2, label='B2B u(x)')
    
    ax.set_title(f'Test {i+1}: Rel. Error = {rel_error:.4f}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.suptitle('B2B Poisson Solver Results', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Average relative error over {len(poisson_test_errors)} plotted samples: {np.mean(poisson_test_errors):.4f} ± {np.std(poisson_test_errors):.4f}")

## Example 3: 1D Nonlinear Darcy Flow

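As in the DeepONet examples, solve the 1D nonlinear Darcy equation with solution-dependent permeability:

$$-\frac{d}{dx}\left(\kappa(u)\,\frac{du}{dx}\right) = f(x), \quad x \in (0, 1), \quad u(0) = u(1) = 0, \quad \kappa(u) = 0.2 + u^2$$

The source $f$ is sampled from a zero-mean Gaussian process with a squared-exponential kernel (length scale $0.04$). Each solution is computed by damped fixed-point iteration on a 40-point finite-difference grid.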
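The following compact NumPy sketch implements the damped fixed-point (Picard) scheme used in `solve_darcy`. Each step freezes $\kappa$ at the current iterate, solves the resulting linear system, and averages the new and old iterates with weight 0.5. The sketch differs from the notebook in two ways: it uses dense `np.linalg.solve` instead of a sparse solver, and it stops early once the update falls below a tolerance. The constant source $f = 1$ is chosen only so the result is easy to check.

```python
import numpy as np

def kappa(s):
    return 0.2 + s**2                      # same permeability as generate_darcy_data

def stiffness(s, dx):
    """Interior matrix of -(kappa(s) s')', with kappa frozen at the current iterate."""
    k = kappa(s)
    main = (k[1:-1] + k[:-2]) / dx**2      # n-2 entries: kappa_{i-1} + kappa_i
    off = -k[1:-2] / dx**2                 # n-3 entries: couples nodes i and i+1
    return np.diag(main) + np.diag(off, 1) + np.diag(off, -1)

def solve_darcy_picard(f, x, n_iter=200, damping=0.5, tol=1e-12):
    """Damped Picard iteration with zero boundary values (early stopping is an addition)."""
    dx = x[1] - x[0]
    s = np.zeros_like(x)
    for _ in range(n_iter):
        s_new = np.zeros_like(x)
        s_new[1:-1] = np.linalg.solve(stiffness(s, dx), f[1:-1])
        s_next = damping * s_new + (1.0 - damping) * s
        if np.max(np.abs(s_next - s)) < tol:
            return s_next
        s = s_next
    return s

x_dc = np.linspace(0.0, 1.0, 40)
f_dc = np.ones_like(x_dc)                  # simple source so the result is easy to check
s_dc = solve_darcy_picard(f_dc, x_dc)

# Residual of the discrete nonlinear system A(s) s = f on the interior nodes.
residual = stiffness(s_dc, x_dc[1] - x_dc[0]) @ s_dc[1:-1] - f_dc[1:-1]
print(f"max residual: {np.max(np.abs(residual)):.2e}")
```

Computing this residual is a cheap way to confirm that the fixed 100 iterations in `solve_darcy` were enough for a given sample.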

In [None]:
def generate_darcy_data(n_funcs=1000, n_points=40):
    """Generate 1D nonlinear Darcy flow data"""
    
    def permeability(s):
        return 0.2 + s**2
    
    # Gaussian process for source function
    x = np.linspace(0, 1, n_points)
    l, sigma = 0.04, 1.0
    K = sigma**2 * np.exp(-0.5 * (x[:, None] - x[None, :])**2 / l**2)
    K += 1e-6 * np.eye(n_points)
    
    def solve_darcy(u_func):
        dx = x[1] - x[0]
        s = np.zeros(n_points)
        
        for _ in range(100):  # Fixed point iteration
            kappa = permeability(s)
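            # Interior rows i = 1..n-2 with kappa frozen at nodes: diagonal is kappa_{i-1} + kappa_i
            main_diag = (kappa[1:-1] + kappa[:-2]) / dx**2   # n_points-2 entries
            upper_diag = -kappa[1:-2] / dx**2                 # n_points-3 entries
            lower_diag = -kappa[1:-2] / dx**2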
            
            A = diags([lower_diag, main_diag, upper_diag], [-1, 0, 1], 
                     shape=(n_points-2, n_points-2))
            
            s_interior = spsolve(A, u_func[1:-1])
            s_new = np.zeros(n_points)
            s_new[1:-1] = s_interior
            s = 0.5 * s_new + 0.5 * s
        
        return s
    
    # Generate dataset
    np.random.seed(42)
    U, S = [], []
    
    print("Generating Darcy dataset...")
    for i in tqdm(range(n_funcs), desc="Solving PDEs"):
        u = multivariate_normal.rvs(mean=np.zeros(n_points), cov=K)
        s = solve_darcy(u)
        U.append(u)
        S.append(s)
    
    return np.array(U), np.array(S), x


print("\n=== 1D NONLINEAR DARCY EXAMPLE ===")
darcy_sources, darcy_solutions, x_darcy = generate_darcy_data(n_funcs=1000)

# Split data
n_train_darcy = 800
train_sources_darcy = darcy_sources[:n_train_darcy]
train_solutions_darcy = darcy_solutions[:n_train_darcy]
test_sources_darcy = darcy_sources[n_train_darcy:]
test_solutions_darcy = darcy_solutions[n_train_darcy:]

print(f"\nData: {n_train_darcy} training, {len(test_sources_darcy)} test")

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 3))
for i in range(3):
    ax = axes[i]
    ax.plot(x_darcy, darcy_sources[i], 'g-', linewidth=2, label='f(x)')
    ax.plot(x_darcy, darcy_solutions[i], 'b-', linewidth=2, label='u(x)')
    ax.set_title(f'Sample {i+1}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.suptitle('1D Nonlinear Darcy Flow')
plt.tight_layout()
plt.show()

### Train Encoders for Darcy

In [None]:
print("\nTraining Darcy encoders...")

# Create encoders
darcy_source_encoder = FunctionEncoder(sensor_dim=40, n_basis=20, hidden_dim=128).to(device)
darcy_solution_encoder = FunctionEncoder(sensor_dim=40, n_basis=20, hidden_dim=128).to(device)

# Train
print("1. Source encoder:")
darcy_source_losses = train_encoder(darcy_source_encoder, train_sources_darcy, x_darcy,
                                   n_epochs=500, lr=0.001, name="Darcy Source")

print("\n2. Solution encoder:")
darcy_solution_losses = train_encoder(darcy_solution_encoder, train_solutions_darcy, x_darcy,
                                     n_epochs=500, lr=0.001, name="Darcy Solution")

print(f"\nFinal losses - Source: {darcy_source_losses[-1]:.6f}, Solution: {darcy_solution_losses[-1]:.6f}")

### Learn and Apply Darcy Operator

In [None]:
# Create B2B operator for Darcy
darcy_b2b = B2BOperator(darcy_source_encoder, darcy_solution_encoder)

# Learn transformation
print("Learning Darcy transformation matrix...")
train_source_darcy_tensor = torch.tensor(train_sources_darcy, dtype=torch.float32).to(device)
train_solution_darcy_tensor = torch.tensor(train_solutions_darcy, dtype=torch.float32).to(device)

A_darcy, fit_error_darcy = darcy_b2b.learn_transformation(
    train_source_darcy_tensor, train_solution_darcy_tensor
)

print(f"Transformation matrix shape: {A_darcy.shape}")
print(f"Fitting error: {fit_error_darcy:.6f}")

# Analyze transformation
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

ax = axes[0]
im = ax.imshow(A_darcy.cpu().numpy(), cmap='RdBu_r', aspect='auto')
plt.colorbar(im, ax=ax)
ax.set_title('Darcy Operator Transformation Matrix')
ax.set_xlabel('Source Basis')
ax.set_ylabel('Solution Basis')

ax = axes[1]
U, S, Vt = torch.linalg.svd(A_darcy.cpu())
ax.bar(range(len(S)), S.numpy())
ax.set_title('Singular Values')
ax.set_xlabel('Index')
ax.set_ylabel('Value')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nMatrix rank: {torch.linalg.matrix_rank(A_darcy).item()}")
print(f"Condition number: {torch.linalg.cond(A_darcy).item():.2f}")

### Test Darcy Operator

In [None]:
# Test Darcy operator
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()

darcy_test_errors = []
x_darcy_tensor = torch.tensor(x_darcy, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)

for i in range(6):
    test_idx = i * 5
    
    # Apply B2B
    source_func = torch.tensor(test_sources_darcy[test_idx:test_idx+1], dtype=torch.float32).to(device)
    pred_solution = darcy_b2b.apply(source_func, x_darcy_tensor).squeeze().cpu().numpy()
    
    true_solution = test_solutions_darcy[test_idx]
    
    # Error
    mse = np.mean((pred_solution - true_solution)**2)
    rel_error = np.sqrt(mse) / np.sqrt(np.mean(true_solution**2) + 1e-8)
    darcy_test_errors.append(rel_error)
    
    # Plot
    ax = axes[i]
    ax.plot(x_darcy, test_sources_darcy[test_idx], 'g-', linewidth=2, alpha=0.7, label='f(x)')
    ax.plot(x_darcy, true_solution, 'b-', linewidth=2, label='True u(x)')
    ax.plot(x_darcy, pred_solution, 'r--', linewidth=2, label='B2B u(x)')
    
    ax.set_title(f'Test {i+1}: Rel. Error = {rel_error:.4f}')
    ax.grid(True, alpha=0.3)
    if i == 0:
        ax.legend()
    ax.set_xlabel('x')

plt.suptitle('B2B Darcy Operator Results', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Average relative error over {len(darcy_test_errors)} plotted samples: {np.mean(darcy_test_errors):.4f} ± {np.std(darcy_test_errors):.4f}")

## Zero-Shot and Few-Shot Learning

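One of B2B's practical advantages: once the encoders are trained, learning a new operator between the same spaces only requires fitting $A$. This is a small regularized least-squares problem with $n_2 \times n_1$ unknowns. The experiment below refits $A$ from increasing numbers of (source, target) pairs and reports the relative test error. The encoders were already trained on the full training sets, so this measures the sample efficiency of the transformation step only. Zero-shot use is not demonstrated here.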
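Why can a handful of samples be enough? This hedged NumPy sketch uses synthetic coefficients in place of encoder outputs. With noiseless pairs generated by a hypothetical `A_true`, the ridge solution recovers $A$ once the number of samples reaches $n_1$, the source basis size, because only then does $X$ have full column rank. By the same argument, the Poisson and Darcy operators ($n_1 = 15$ and $n_1 = 20$) would need more pairs than the derivative ($n_1 = 4$) before $A$ is pinned down without relying on the regularizer.

```python
import numpy as np

def fit_transformation(X, Y, reg=1e-6):
    """Ridge fit of A in Y approx X @ A.T, via (X^T X + reg I) A^T = X^T Y."""
    n1 = X.shape[1]
    return np.linalg.solve(X.T @ X + reg * np.eye(n1), X.T @ Y).T

rng = np.random.default_rng(0)
n1, n2 = 4, 3
A_true = rng.standard_normal((n2, n1))          # hypothetical ground-truth map

errors = {}
for n_samples in [2, 4, 10, 50]:
    X = rng.standard_normal((n_samples, n1))    # stand-in source coefficients
    Y = X @ A_true.T                            # noiseless target coefficients
    A_hat = fit_transformation(X, Y)
    errors[n_samples] = np.max(np.abs(A_hat - A_true))
    print(f"{n_samples:3d} samples: max |A_hat - A_true| = {errors[n_samples]:.2e}")
```

With fewer than $n_1$ samples, the regularizer picks one of many consistent solutions, so the error stays large. Real encoder coefficients are noisy, so in practice more than $n_1$ pairs are typically needed.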

In [None]:
def test_few_shot_learning(source_encoder, target_encoder, 
                          train_sources, train_targets,
                          test_sources, test_targets, 
                          x_points, operator_name="Operator"):
    """Test operator learning with varying amounts of data"""
    
    sample_sizes = [5, 10, 25, 50, 100, 200, 400]
    errors_by_size = []
    
    x_tensor = torch.tensor(x_points, dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)
    
    for n_samples in sample_sizes:
        if n_samples > len(train_sources):
            continue
            
        # Learn with limited data
        b2b = B2BOperator(source_encoder, target_encoder)
        
        train_source_subset = torch.tensor(train_sources[:n_samples], dtype=torch.float32).to(device)
        train_target_subset = torch.tensor(train_targets[:n_samples], dtype=torch.float32).to(device)
        
        b2b.learn_transformation(train_source_subset, train_target_subset)
        
        # Test
        test_errors = []
        for i in range(min(50, len(test_sources))):
            source = torch.tensor(test_sources[i:i+1], dtype=torch.float32).to(device)
            pred = b2b.apply(source, x_tensor).squeeze().cpu().numpy()
            true = test_targets[i]
            
            mse = np.mean((pred - true)**2)
            rel_error = np.sqrt(mse) / (np.sqrt(np.mean(true**2)) + 1e-8)
            test_errors.append(rel_error)
        
        avg_error = np.mean(test_errors)
        errors_by_size.append((n_samples, avg_error))
        print(f"{n_samples:3d} samples: {avg_error:.6f}")
    
    return errors_by_size


print("\n=== FEW-SHOT LEARNING EXPERIMENTS ===")

print("\nDerivative Operator:")
deriv_few_shot = test_few_shot_learning(
    derivative_source_encoder, derivative_target_encoder,
    train_functions, train_derivatives,
    test_functions, test_derivatives,
    x, "Derivative"
)

print("\nPoisson Solver:")
poisson_few_shot = test_few_shot_learning(
    poisson_source_encoder, poisson_solution_encoder,
    train_sources_poisson, train_solutions_poisson,
    test_sources_poisson, test_solutions_poisson,
    x_poisson, "Poisson"
)

### Visualize Few-Shot Learning

In [None]:
# Plot few-shot learning results
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Derivative
ax = axes[0]
sizes, errors = zip(*deriv_few_shot)
ax.plot(sizes, errors, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Number of Training Samples')
ax.set_ylabel('Test Relative Error')
ax.set_title('Derivative Operator - Few-Shot Learning')
ax.set_xscale('log')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
ax.axhline(y=0.01, color='r', linestyle='--', alpha=0.5, label='1% error')
ax.legend()

# Poisson
ax = axes[1]
sizes, errors = zip(*poisson_few_shot)
ax.plot(sizes, errors, 'o-', linewidth=2, markersize=8, color='orange')
ax.set_xlabel('Number of Training Samples')
ax.set_ylabel('Test Relative Error')
ax.set_title('Poisson Solver - Few-Shot Learning')
ax.set_xscale('log')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
ax.axhline(y=0.01, color='r', linestyle='--', alpha=0.5, label='1% error')
ax.legend()

plt.tight_layout()
plt.show()

print("\nKey insight: once the encoders are trained, A can be refit from relatively few paired samples.")
print("Caveat: the encoders saw the full training sets, so these curves measure only the data needed to fit A.")

## Comparison Summary: B2B vs DeepONet

In [None]:
# Create comparison table
comparison_data = [
    ['Aspect', 'B2B Framework', 'DeepONet'],
    ['Architecture', 'Encoder → Transform → Decoder', 'Branch-Trunk'],
    ['Operator Representation', 'Explicit matrix A', 'Implicit in weights'],
    ['Interpretability', 'Higher (explicit A, basis-dependent)', 'Medium (basis visible)'],
    ['Few-shot learning', 'Strong if encoders pretrained (refit A only)', 'Limited'],
    ['Transfer learning', 'Natural (reuse encoders)', 'Less direct (fine-tune network)'],
    ['Training', 'Two-stage', 'End-to-end'],
    ['Best for', 'Linear/weakly nonlinear', 'Highly nonlinear']
]

# Display as formatted table
fig, ax = plt.subplots(figsize=(12, 6))
ax.axis('tight')
ax.axis('off')

table = ax.table(cellText=comparison_data[1:], colLabels=comparison_data[0],
                cellLoc='left', loc='center', colWidths=[0.25, 0.375, 0.375])

table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1, 2)

# Style header
for i in range(3):
    table[(0, i)].set_facecolor('#40466e')
    table[(0, i)].set_text_props(weight='bold', color='white')

# Alternate row colors
for i in range(1, len(comparison_data)):
    for j in range(3):
        if i % 2 == 0:
            table[(i, j)].set_facecolor('#f0f0f0')

plt.title('B2B vs DeepONet Comparison', fontsize=14, fontweight='bold', pad=20)
plt.show()

print("\n" + "="*60)
print("PERFORMANCE SUMMARY")
print("="*60)
print(f"\nDerivative Operator:")
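print(f"  B2B test MSE ({len(test_errors)} plotted samples): {np.mean(test_errors):.6f}")
print(f"  Relative error, A fit on 10 samples: {dict(deriv_few_shot)[10]:.6f}")

print(f"\nPoisson Solver:")
print(f"  B2B test relative error ({len(poisson_test_errors)} plotted samples): {np.mean(poisson_test_errors):.4f}")
print(f"  Relative error, A fit on 25 samples: {dict(poisson_few_shot)[25]:.4f}")

print(f"\nDarcy Flow:")
print(f"  B2B test relative error ({len(darcy_test_errors)} plotted samples): {np.mean(darcy_test_errors):.4f}")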

## Summary and Key Takeaways

### What We've Learned

1. **B2B Framework** decomposes operator learning into:
   - Function encoding (learning basis representations)
   - Transformation learning (explicit matrix)
   - Function decoding (reconstruction from coefficients)

2. **Three Examples** (same as DeepONet):
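   - **Derivative operator:** A natural fit for B2B: a linear map from the 4-dimensional cubic space onto the 3-dimensional quadratic space
   - **Poisson solver:** A linear operator (the inverse of $-d^2/dx^2$ with zero boundary values), so a single matrix $A$ suffices in principle, given adequate bases
   - **Darcy flow:** The transformation $A$ is still linear, so the nonlinearity must be absorbed by the nonlinear (MLP) source encoder

3. **Key Advantages:**
   - **Interpretability:** The transformation matrix can be inspected directly (entries, rank, singular values), though its entries depend on the learned bases
   - **Few-shot learning:** With encoders already trained, $A$ can be refit from relatively few paired samples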
   - **Transfer learning:** Reuse encoders for related operators
   - **Modularity:** Separate concerns enable flexibility
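A caveat on reading $A$, shown here as a hedged sketch: the same operator has different matrices in different bases. Suppose the source and target coefficients are re-expressed as $P c_f$ and $Q c_g$ for invertible $P$ and $Q$. Then the matrix becomes $Q A P^{-1}$. Learned bases are only determined up to such changes, so a learned derivative matrix need not resemble the monomial one. Here $P$ and $Q$ are random orthogonal matrices, a hypothetical choice that keeps the inverse well conditioned.

```python
import numpy as np

rng = np.random.default_rng(1)

# d/dx in monomial coordinates: cubic (c0..c3) -> quadratic (c0..c2).
A_mono = np.array([[0.0, 1.0, 0.0, 0.0],
                   [0.0, 0.0, 2.0, 0.0],
                   [0.0, 0.0, 0.0, 3.0]])

# Hypothetical changes of coordinates (random orthogonal, hence invertible).
P, _ = np.linalg.qr(rng.standard_normal((4, 4)))   # source: c_f' = P c_f
Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))   # target: c_g' = Q c_g

A_new = Q @ A_mono @ np.linalg.inv(P)              # same operator, new coordinates

c_f = rng.standard_normal(4)                       # coefficients of some cubic
c_g = A_mono @ c_f                                 # its derivative, monomial coordinates
c_g_new = A_new @ (P @ c_f)                        # computed entirely in new coordinates
print(np.round(A_new, 2))
```

Basis-independent quantities such as rank and singular values remain meaningful, because orthogonal $P$ and $Q$ leave them unchanged. Individual entries of a learned $A$ are less so.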

### When to Use B2B

✅ **Ideal for:**
- Linear or approximately linear operators
- Multiple related operators on same spaces
- Limited training data scenarios
- Need for interpretable operator representations
- Transfer learning applications

❌ **Consider alternatives when:**
- Operators are highly nonlinear
- Single operator with abundant data
- End-to-end optimization preferred

### Practical Guidelines

1. **Basis size selection:**
   - Start with expected dimensionality
   - Increase if reconstruction error is high
   - Use cross-validation for optimal size

2. **Encoder training:**
   - Ensure good reconstruction before operator learning
   - Pre-train on diverse function samples
   - Share encoders across related problems

3. **Transformation learning:**
   - Use regularization for stability
   - Check singular values for conditioning
   - Visualize matrix for insights
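For the basis-size guideline, here is a hedged sketch that uses a linear stand-in for the encoder. The candidate basis is the top-$k$ principal directions (right singular vectors) of the training snapshots, validated on held-out functions. The smallest $k$ whose relative validation error falls below a tolerance is a reasonable starting point for `n_basis`; a trained nonlinear encoder may need more or fewer. The synthetic data (random combinations of five sine modes) and the tolerance are assumptions.

```python
import numpy as np

def choose_n_basis(train, val, max_k=20, tol=1e-3):
    """Smallest k whose top-k SVD basis of `train` reconstructs `val` within `tol`."""
    _, _, Vt = np.linalg.svd(train, full_matrices=False)
    rel_err = np.inf
    for k in range(1, min(max_k, Vt.shape[0]) + 1):
        V = Vt[:k].T                                  # (n_points, k), orthonormal columns
        recon = val @ V @ V.T                         # orthogonal projection onto span(V)
        rel_err = np.linalg.norm(val - recon) / np.linalg.norm(val)
        if rel_err < tol:
            return k, rel_err
    return None, rel_err                              # tolerance never reached

rng = np.random.default_rng(0)
x_grid = np.linspace(0.0, 1.0, 100)
modes = np.stack([np.sin(k * np.pi * x_grid) for k in range(1, 6)], axis=1)  # (100, 5)
train = rng.standard_normal((200, 5)) @ modes.T      # training snapshots
val = rng.standard_normal((50, 5)) @ modes.T         # held-out snapshots

k, err = choose_n_basis(train, val)
print(k, err)
```

Because the validation functions are held out, this doubles as a simple cross-validation check on the chosen size.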

---

**The Big Picture:** B2B provides an interpretable, sample-efficient alternative to DeepONet, especially powerful for linear operators and transfer learning scenarios!