# Spatial Basis Denoising Demo

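A minimal demo of denoising 1-D spatial data with `SpatialBasisLearner`: we synthesize signals as random mixtures of two fixed spatial patterns, add Gaussian noise, fit the learner to the noisy signals only, and compare its reconstructions against the clean ground truth.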

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

from geospatial_neural_adapter.models.spatial_basis_learner import (
    SpatialBasisLearner, 
    train_spatial_basis
)

# Reaching this line means every import above resolved without error.
print("✅", "Imports successful!", sep=" ")

## Generate Noisy Data

In [None]:
# Seed the global torch RNG so the synthetic data (and anything random that
# runs afterwards, e.g. model initialization) is reproducible across runs.
torch.manual_seed(0)

# Create clean spatial patterns on an evenly spaced 1-D grid
n_locations = 50
locations = torch.linspace(-3, 3, n_locations)

# Two ground-truth spatial patterns
pattern1 = torch.exp(-(locations - 0.5) ** 2)  # Gaussian bump centred at 0.5
pattern2 = torch.sin(locations * 2)            # Sine

# Each clean sample is a random linear mix of the two patterns, with mixing
# weights drawn from N(0, 2^2). Vectorized as
# (n_samples, 2) @ (2, n_locations) -> (n_samples, n_locations).
n_samples = 100
weights = torch.randn(n_samples, 2) * 2
clean_data = weights @ torch.stack([pattern1, pattern2])

# Add i.i.d. Gaussian noise with standard deviation noise_std
noise_std = 0.3
noisy_data = clean_data + torch.randn_like(clean_data) * noise_std

print(f"Clean data: {clean_data.shape}")
print(f"Noisy data: {noisy_data.shape}")

## Denoise with Spatial Basis Learner

In [None]:
# Initialize and train.
# latent_dim=2 matches the number of ground-truth patterns (pattern1, pattern2)
# that were mixed to generate clean_data.
spatial_learner = SpatialBasisLearner(n_locations, latent_dim=2)

# Fit on the noisy observations only; clean_data is never passed in, so any
# agreement with it in the next cell reflects genuine denoising.
# The return value is ignored: spatial_learner is presumably updated in place
# (it is used directly afterwards) — confirm against train_spatial_basis.
train_spatial_basis(
    model=spatial_learner,
    targets=noisy_data,
    epochs=3000,
    verbose=False,  # presumably suppresses training progress output — confirm
    tau1=.01,  # NOTE(review): tau1/tau2 semantics are defined by train_spatial_basis
    tau2=.01,  # (likely regularization weights) — confirm before tuning
)

# Denoise: run the trained model on the same noisy inputs it was fit to.
# no_grad() disables autograd tracking, so denoised_data carries no gradient
# history and can be used directly in arithmetic/plotting later.
with torch.no_grad():
    denoised_data = spatial_learner(noisy_data)

print("✅ Denoising complete!")

## Compare Results

In [None]:
# Quantify denoising quality against the clean ground truth.
# Per-sample MSEs (one value per row) are computed once and reused for both
# the summary numbers and the scatter plot in the last panel.
mse_noise = torch.mean((noisy_data - clean_data) ** 2, dim=1)
mse_denoised = torch.mean((denoised_data - clean_data) ** 2, dim=1)

# Every sample has the same number of locations, so averaging the per-sample
# MSEs gives the overall MSE.
noise_mse = mse_noise.mean()
denoised_mse = mse_denoised.mean()
# Percentage reduction in error; positive means denoising helped.
improvement = (noise_mse - denoised_mse) / noise_mse * 100

print(f"Noise MSE: {noise_mse:.4f}")
print(f"Denoised MSE: {denoised_mse:.4f}")
print(f"Improvement: {improvement:.1f}%")

# Plot results: one row of four panels.
fig, (ax_sample, ax_true, ax_learned, ax_mse) = plt.subplots(1, 4, figsize=(15, 4))

# Panel 1: clean vs. noisy vs. denoised for the first sample.
ax_sample.plot(locations, clean_data[0], 'g-', label='Clean', linewidth=2)
ax_sample.plot(locations, noisy_data[0], 'r.', label='Noisy', alpha=0.6)
ax_sample.plot(locations, denoised_data[0], 'b--', label='Denoised', linewidth=2)
ax_sample.set_title('Sample Comparison')

# Panel 2: the ground-truth patterns used to generate the data.
ax_true.plot(locations, pattern1, 'g-', label='Gaussian', linewidth=2)
ax_true.plot(locations, pattern2, 'r-', label='Sine', linewidth=2)
ax_true.set_title('True Patterns')

# Panel 3: the learned basis, one curve per column.
# NOTE(review): assumes spatial_learner.basis is (n_locations, latent_dim) —
# confirm. Unless the model constrains it, the learned columns may differ from
# the true patterns by sign, scale, or mixing, so compare shapes loosely.
# .cpu() is a no-op for CPU tensors and keeps plotting working if training
# placed the model on another device.
basis = spatial_learner.basis.detach().cpu()
for i in range(basis.shape[1]):
    ax_learned.plot(locations, basis[:, i], '--', label=f'Learned {i+1}', alpha=0.8)
ax_learned.set_title('Learned Basis')

# Panel 4: per-sample MSE before vs. after; points below the red y = x line
# are samples that denoising improved.
ax_mse.scatter(mse_noise, mse_denoised, alpha=0.6)
# Span the reference line over the larger of the two MSE ranges so samples
# that got worse (above the line) are never plotted past its end.
mse_max = max(mse_noise.max().item(), mse_denoised.max().item())
ax_mse.plot([0, mse_max], [0, mse_max], 'r--')
ax_mse.set_xlabel('Noise MSE')
ax_mse.set_ylabel('Denoised MSE')
ax_mse.set_title('MSE Comparison')

# Shared styling: legends on the line-plot panels, light grid on all panels.
for ax in (ax_sample, ax_true, ax_learned):
    ax.legend()
for ax in (ax_sample, ax_true, ax_learned, ax_mse):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()