# Matrix Completion and Compressive Sensing

This notebook demonstrates classical algorithmic approaches to seismic data recovery.

**Prerequisites:**
- Python 3.10+
- Linear algebra fundamentals

**Estimated Runtime:** 10 minutes

**Topics Covered:**
- Missing trace interpolation
- Matrix completion via SVD
- Linear interpolation baseline (compressive-sensing reconstruction is not demonstrated in this notebook)

In [None]:
# Uncomment to install:
# !pip install promethium-seismic==1.0.0

In [None]:
import promethium
from promethium import (
    generate_synthetic_traces,
    evaluate_reconstruction,
    set_seed,
)
from promethium.utils.synthetic import create_missing_traces

import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg

# Record which promethium build produced the results below.
installed_version = promethium.__version__
print(f"Promethium version: {installed_version}")

# Fix the random state for reproducibility. set_seed is promethium's helper;
# which generators it seeds is not visible from this notebook.
set_seed(42)

## 1. Create Dataset with Missing Traces

In [None]:
# Ground-truth gather that every reconstruction is scored against.
clean_data, metadata = generate_synthetic_traces(
    n_traces=100,
    n_samples=500,
    sample_rate=250.0,
    seed=42,
)

# Remove 30% of the traces at random. In this notebook `mask` is read as
# 1 = trace present, 0 = trace missing.
corrupted_data, mask = create_missing_traces(
    clean_data,
    missing_ratio=0.3,
    pattern='random',
    seed=42,
)

total_traces = clean_data.shape[0]
n_missing = int(np.sum(mask == 0))
n_available = int(np.sum(mask))
missing_pct = 100 * n_missing / total_traces

print(f"Total traces: {total_traces}")
print(f"Missing traces: {n_missing} ({missing_pct:.1f}%)")
print(f"Available traces: {n_available}")

In [None]:
# Show the complete gather next to the gappy one. Both panels share one
# symmetric colour scale, clipped at the 99th percentile of |clean_data|.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
clip = np.percentile(np.abs(clean_data), 99)

panels = [
    (clean_data, 'Original Complete Data'),
    (corrupted_data, 'Data with Missing Traces'),
]
for ax, (gather, title) in zip(axes, panels):
    ax.imshow(gather.T, aspect='auto', cmap='seismic', vmin=-clip, vmax=clip)
    ax.set_xlabel('Trace')
    ax.set_ylabel('Sample')
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 2. SVD-Based Matrix Completion

In [None]:
def svd_matrix_completion(data, mask, rank=10, max_iter=100, tol=1e-5, init="mean"):
    """
    Complete missing entries by iterative rank-truncated SVD imputation.

    Each iteration replaces the current estimate with its best rank-``rank``
    approximation (hard truncation of the singular values). It then puts the
    observed values back, so observed entries are returned exactly as given.

    Args:
        data: Input data with missing values, shape (n_traces, n_samples).
            Values at missing positions are never used, so they may be 0 or NaN.
        mask: Binary mask (1=present, 0=missing). It takes one of two forms:
            - one flag per trace (``mask.size == n_traces``, e.g. shape
              (n_traces,));
            - one flag per entry (same shape as ``data``).
        rank: Target rank for low-rank approximation.
        max_iter: Maximum iterations.
        tol: Convergence tolerance on the relative Frobenius-norm change
            between successive estimates.
        init: Starting value for missing entries.
            - ``"mean"`` (default): the mean of the observed entries in the
              same column (time sample).
            - ``"zeros"``: 0.

    Returns:
        Completed array with the same shape as ``data``.

    Raises:
        ValueError: If ``mask`` matches neither layout, or ``init`` is unknown.

    Note:
        A rank-truncated SVD maps an all-zero row to an all-zero row. With
        ``init="zeros"``, whole missing traces therefore stay zero forever.
        The low-rank model alone also does not determine a fully missing
        trace, so for trace-wise gaps the result depends on ``init``.
    """
    data = np.asarray(data)
    mask = np.asarray(mask)

    # Expand the mask to a boolean array shaped like `data`.
    if mask.shape == data.shape:
        observed = mask != 0
    elif mask.size == data.shape[0]:
        observed = np.broadcast_to((mask.ravel() != 0)[:, None], data.shape)
    else:
        raise ValueError(
            f"mask of shape {mask.shape} matches neither data {data.shape} "
            f"nor its {data.shape[0]} traces"
        )

    # Initial estimate. Missing positions are replaced up front so that
    # NaN/garbage there never reaches the SVD (scipy rejects non-finite input).
    if init == "mean":
        n_observed = observed.sum(axis=0)
        column_sum = np.where(observed, data, 0).sum(axis=0)
        # A column with no observed entry falls back to 0 instead of 0/0.
        column_mean = column_sum / np.maximum(n_observed, 1)
        X = np.where(observed, data, column_mean)
    elif init == "zeros":
        X = np.where(observed, data, 0)
    else:
        raise ValueError(f"init must be 'mean' or 'zeros', got {init!r}")

    for i in range(max_iter):
        # X is rebound below, never mutated, so no copy is needed.
        X_old = X

        # Best rank-`rank` approximation: keep only the leading singular triplets.
        U, s, Vt = linalg.svd(X, full_matrices=False)
        k = min(rank, s.size)
        X_approx = (U[:, :k] * s[:k]) @ Vt[:k]

        # Fill missing values with the approximation; keep observed values.
        X = np.where(observed, data, X_approx)

        # Relative change between iterates; the epsilon guards an all-zero X_old.
        change = np.linalg.norm(X - X_old) / (np.linalg.norm(X_old) + 1e-10)
        if change < tol:
            print(f"Converged at iteration {i+1}")
            break

    return X

# Rank-15 low-rank model, capped at 50 refinement passes.
print("Running SVD matrix completion...")
recovered_svd = svd_matrix_completion(
    corrupted_data,
    mask,
    rank=15,
    max_iter=50,
)
completed_shape = recovered_svd.shape
print(f"Completed shape: {completed_shape}")

## 3. Linear Interpolation Baseline

In [None]:
def linear_interpolation(data, mask):
    """
    Fill missing traces by linear interpolation across the trace axis.

    Each time sample (column) is interpolated independently with ``np.interp``,
    using only the present traces. Traces before the first present trace, or
    after the last one, get that edge trace's value: ``np.interp`` holds the
    end values constant instead of extrapolating.

    Args:
        data: Array of shape (n_traces, n_samples). Values in missing traces
            are never read.
        mask: Per-trace flags (non-zero = present, 0 = missing), with one
            element per trace, e.g. shape (n_traces,).

    Returns:
        New array with the same shape and dtype as ``data``. Present traces
        keep their values.

    Raises:
        ValueError: If ``mask`` does not have one entry per trace, or if no
            trace is present.
    """
    data = np.asarray(data)
    flags = np.asarray(mask).ravel()
    if flags.size != data.shape[0]:
        raise ValueError(
            f"mask has {flags.size} entries but data has {data.shape[0]} traces"
        )
    present_idx = np.flatnonzero(flags)
    if present_idx.size == 0:
        raise ValueError("linear_interpolation needs at least one present trace")

    trace_positions = np.arange(data.shape[0])
    present_values = data[present_idx]
    result = data.copy()
    # Iterate columns through transposed views: writing into `column`
    # writes straight into `result`.
    for column, known in zip(result.T, present_values.T):
        column[:] = np.interp(trace_positions, present_idx, known)
    return result

# Baseline: interpolate each time sample across the surviving traces.
print("Running linear interpolation...")
recovered_linear = linear_interpolation(corrupted_data, mask)
completed_shape = recovered_linear.shape
print(f"Completed shape: {completed_shape}")

## 4. Compare Methods

In [None]:
# Score every reconstruction against the clean reference gather.
methods = dict([
    ('Linear Interpolation', recovered_linear),
    ('SVD Completion', recovered_svd),
])

banner = "=" * 60
print("Method Comparison")
print(banner)
for method_name, reconstruction in methods.items():
    scores = evaluate_reconstruction(clean_data, reconstruction)
    print(f"\n{method_name}:")
    print(f"  SNR:  {scores['snr']:.2f} dB")
    print(f"  PSNR: {scores['psnr']:.2f} dB")
    print(f"  SSIM: {scores['ssim']:.4f}")
    print(f"  MSE:  {scores['mse']:.6f}")

In [None]:
# 2x2 grid: reference and input on the top row, the two reconstructions on
# the bottom row, all drawn with the same clipped symmetric colour scale.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
clip = np.percentile(np.abs(clean_data), 99)

panel_specs = [
    ((0, 0), clean_data, 'Original'),
    ((0, 1), corrupted_data, 'Missing Traces'),
    ((1, 0), recovered_linear, 'Linear Interpolation'),
    ((1, 1), recovered_svd, 'SVD Completion'),
]
for position, gather, title in panel_specs:
    panel = axes[position]
    panel.imshow(gather.T, aspect='auto', cmap='seismic', vmin=-clip, vmax=clip)
    panel.set_title(title)

for panel in axes.flat:
    panel.set_xlabel('Trace')
    panel.set_ylabel('Sample')

plt.tight_layout()
plt.show()

## 5. Summary

This notebook demonstrated:

1. **Missing Trace Creation**: Removing 30% of the traces with a random pattern
2. **SVD Matrix Completion**: Low-rank approximation approach
3. **Linear Interpolation**: Simple baseline method
4. **Comparison**: Quantitative and visual evaluation

### Next Steps

- **05_deep_learning_unet_reconstruction.ipynb**: Neural network approaches
- **06_gan_based_high_fidelity_reconstruction.ipynb**: GAN-based recovery