# Thin Plate Spline Demo

A short demonstration of thin plate spline (TPS) smoothing on noisy 1D data, followed by interpolation of scattered 2D data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RBFInterpolator, griddata

from geospatial_neural_adapter.cpp_extensions import thin_plate_spline

### 1D Example

In [None]:
# Sample a noisy sine wave on a uniform grid over [0, 1].
num_locations = 100
x_1D = np.linspace(0, 1, num_locations)
true_y = np.sin(2 * np.pi * x_1D)

# Use a dedicated, seeded generator so the noisy data (and the smoothing plot
# built from it) is reproducible across runs, without touching NumPy's global
# RNG state that the 2D example seeds separately.
rng = np.random.default_rng(0)
noise = rng.normal(0, 0.2, num_locations)
y_values = true_y + noise

# Penalty matrix K from the project's C++ extension; the next cell smooths via
# (I + λ·K) g = y. Presumably K is the TPS roughness penalty evaluated at the
# x_1D sites — TODO confirm against thin_plate_spline.smoothing_penalty_matrix.
smoothing_penalty_matrix = thin_plate_spline.smoothing_penalty_matrix(x_1D)

In [None]:
# Smoothing strengths to compare. With λ = 0 the system below is I·g = y,
# so the "smoothed" curve reproduces the noisy data exactly.
lambda_values = [0, 1e-2, 1e8]

plt.figure(figsize=(10, 6))
plt.plot(x_1D, y_values, 'o-', label="Noisy Data (y)", markersize=8, linewidth=2, alpha=0.7)
plt.plot(x_1D, true_y, 'k--', label="True Function", linewidth=3)

# The identity does not depend on λ, so build it once rather than per iteration.
identity_matrix = np.eye(num_locations)
for lambda_smoothing in lambda_values:
    # Solve (I + λK) g = y. Assuming K is symmetric, this is the stationarity
    # condition of min_g ||y - g||² + λ·gᵀKg (penalized least squares).
    g_values = np.linalg.solve(identity_matrix + lambda_smoothing * smoothing_penalty_matrix, y_values)

    plt.plot(x_1D, g_values, label=f"Smoothed (λ = {lambda_smoothing})", linewidth=2)

plt.title("Effect of Smoothing (g) for Different λ with Noisy Data")
plt.xlabel("X")
plt.ylabel("Values")
plt.legend()
plt.grid(True)
plt.show()

### 2D Example (scattered surface z = f(x, y))

In [None]:
# Generate sample data: a smooth surface plus Gaussian noise at random sites.
np.random.seed(42)
n_points = 20

x = np.random.uniform(0, 10, n_points)
y = np.random.uniform(0, 10, n_points)
z = 2 * np.sin(x/2) + 1.5 * np.cos(y/3) + 0.3 * np.random.normal(0, 1, n_points)

print(f"Generated {n_points} data points")

# Create interpolation grid
xi = np.linspace(0, 10, 50)
yi = np.linspace(0, 10, 50)
xi_grid, yi_grid = np.meshgrid(xi, yi)

# Fit an actual thin plate spline. griddata(method="cubic") is a piecewise
# Clough-Tocher interpolant on a Delaunay triangulation, not a TPS, and it
# yields NaN outside the convex hull of the samples (e.g. the grid corners),
# which leaves holes in the contour and surface plots.
# smoothing=0 passes exactly through the samples; raise it to trade fidelity
# for smoothness, analogous to λ in the 1D example.
tps_smoothing = 0.0
tps_interpolator = RBFInterpolator(
    np.column_stack([x, y]), z, kernel="thin_plate_spline", smoothing=tps_smoothing
)
grid_points = np.column_stack([xi_grid.ravel(), yi_grid.ravel()])
zi = tps_interpolator(grid_points).reshape(xi_grid.shape)

In [None]:
# Plot the interpolated surface as filled contours with the data sites overlaid.
plt.figure(figsize=(10, 8))
contour_set = plt.contourf(xi_grid, yi_grid, zi, levels=20, cmap="viridis")
plt.scatter(x, y, c="red", s=50, label="Data points")
# Pass the contour set explicitly: pyplot.scatter makes the scatter collection
# the "current image", so a bare plt.colorbar() would be built from the red
# dots (which carry no data array) instead of the interpolated Z levels.
plt.colorbar(contour_set, label="Interpolated Z")
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Thin Plate Spline Interpolation")
plt.legend()
plt.show()

In [None]:
# 3D surface plot: raw samples (left) and the interpolated surface with the
# samples overlaid in red (right).
fig = plt.figure(figsize=(12, 5))

ax1 = fig.add_subplot(121, projection="3d")
ax1.scatter(x, y, z, c=z, cmap="viridis", s=50)
ax1.set_title("Original Data (3D)")

ax2 = fig.add_subplot(122, projection="3d")
surf = ax2.plot_surface(xi_grid, yi_grid, zi, cmap="viridis", alpha=0.8)
ax2.scatter(x, y, z, c="red", s=30)
ax2.set_title("Interpolated Surface (3D)")

# Label the axes on both panels, consistent with the 2D contour plot.
ax1.set_xlabel("X")
ax1.set_ylabel("Y")
ax1.set_zlabel("Z")
ax2.set_xlabel("X")
ax2.set_ylabel("Y")
ax2.set_zlabel("Z")

plt.tight_layout()
plt.show()