# Deep Learning: U-Net Reconstruction

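This notebook demonstrates seismic trace reconstruction with a deep learning model: it loads a pre-configured U-Net denoising pipeline, runs it on noisy synthetic traces, and compares the output against the clean reference.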

**Prerequisites:**
- Python 3.10+
- PyTorch basics
- GPU recommended

**Estimated Runtime:** 10 minutes (CPU), 3 minutes (GPU)

**Topics Covered:**
- U-Net architecture overview
- Model loading and inference
- Batch processing

In [None]:
# Uncomment to install:
# !pip install promethium-seismic==1.0.0

In [None]:
import promethium
from promethium import (
    SeismicRecoveryPipeline,
    get_model,
    run_recovery,
    InferenceEngine,
    generate_synthetic_traces,
    add_noise,
    evaluate_reconstruction,
    set_seed,
    get_device,
)

import numpy as np
import torch
import matplotlib.pyplot as plt

print(f"Promethium version: {promethium.__version__}")
print(f"PyTorch version: {torch.__version__}")

set_seed(42)
device = get_device()
print(f"Using device: {device}")

## 1. U-Net Architecture Overview

The U-Net is an encoder-decoder architecture in which skip connections pass each encoder stage's feature maps directly to the decoder stage at the same resolution:

```
Input -> Encoder (downsample) -> Bottleneck -> Decoder (upsample) -> Output
           |___________________Skip Connections_________________|
```

**Key Properties:**
- Preserves spatial information via skip connections
- Captures multi-scale features
- Effective for denoising and interpolation
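
To make the diagram concrete, here is a minimal two-level U-Net sketch in plain PyTorch. It is an illustration only: `TinyUNet`, `conv_block`, and the channel widths are hypothetical names and choices, not the architecture behind Promethium's `unet_denoise_v1` preset, which this notebook does not show.

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two 3x3 convolutions with ReLU; padding=1 keeps the spatial size."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Illustrative two-level U-Net (hypothetical; not Promethium's model)."""

    def __init__(self, channels: int = 1, base: int = 8):
        super().__init__()
        self.enc1 = conv_block(channels, base)
        self.enc2 = conv_block(base, base * 2)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base * 2, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)  # input = upsampled + skip
        self.up1 = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)      # input = upsampled + skip
        self.head = nn.Conv2d(base, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        e1 = self.enc1(x)                    # full resolution
        e2 = self.enc2(self.pool(e1))        # 1/2 resolution
        b = self.bottleneck(self.pool(e2))   # 1/4 resolution
        # Skip connections: concatenate encoder features along the channel axis
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.head(d1)


model = TinyUNet()
x = torch.randn(1, 1, 64, 256)  # (batch, channel, traces, samples)
with torch.no_grad():
    y = model(x)
print(y.shape)
```

Because every pooling step halves both spatial dimensions, this sketch needs input sizes divisible by 4; the output then has the same shape as the input, which is what a denoising model requires.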

## 2. Create Test Data

In [None]:
# Generate clean and noisy data
clean_data, metadata = generate_synthetic_traces(
    n_traces=64,
    n_samples=256,
    sample_rate=250.0,
    seed=42
)

noisy_data = add_noise(clean_data, noise_level=0.3, seed=42)

print(f"Data shape: {clean_data.shape}")
print(f"Clean data range: [{clean_data.min():.3f}, {clean_data.max():.3f}]")
print(f"Noisy data range: [{noisy_data.min():.3f}, {noisy_data.max():.3f}]")

## 3. Load Pre-configured Pipeline

In [None]:
# View available presets
presets = SeismicRecoveryPipeline.list_presets()
print("Available presets:")
for preset in presets:
    print(f"  - {preset}")

In [None]:
# Load U-Net denoising pipeline
try:
    pipeline = SeismicRecoveryPipeline.from_preset('unet_denoise_v1')
    print(f"Loaded pipeline: {pipeline.model_name}")
    print(f"Device: {pipeline.device}")
    
    # Print model summary
    if hasattr(pipeline, 'model') and pipeline.model is not None:
        total_params = sum(p.numel() for p in pipeline.model.parameters())
        print(f"Total parameters: {total_params:,}")
except Exception as e:
    print(f"Pipeline loading note: {e}")
    print("Using fallback inference for demonstration.")
    pipeline = None

## 4. Run Inference

In [None]:
# Run recovery
if pipeline is not None:
    try:
        reconstructed = pipeline.run(noisy_data)
        print(f"Reconstructed shape: {reconstructed.shape}")
    except Exception as e:
        print(f"Inference note: {e}")
        reconstructed = None
else:
    reconstructed = None

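# Fallback: smooth each trace with a 1-D Gaussian filter so the rest of the
# notebook still runs. The metrics below then describe this filter, not the U-Net.
if reconstructed is None:
    print("Using demonstration fallback (Gaussian filter)")
    from scipy.ndimage import gaussian_filter1d
    # axis=-1 filters along the sample axis of every trace in one call
    reconstructed = gaussian_filter1d(noisy_data, sigma=2, axis=-1)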

print(f"Output shape: {reconstructed.shape}")

## 5. Evaluate Results

In [None]:
# Compute metrics
noisy_metrics = evaluate_reconstruction(clean_data, noisy_data)
recon_metrics = evaluate_reconstruction(clean_data, reconstructed)

print("Performance Comparison")
print("=" * 64)
print(f"{'Metric':>20} {'Noisy':>12} {'Reconstructed':>15} {'Improvement':>14}")
print("-" * 64)
for metric in ['snr', 'psnr', 'ssim', 'mse']:
    n_val = noisy_metrics[metric]
    r_val = recon_metrics[metric]
    # Higher is better for SNR/PSNR/SSIM and lower is better for MSE,
    # so flip the sign for MSE: a positive value always means "better".
    improvement = r_val - n_val if metric != 'mse' else n_val - r_val
    print(f"{metric.upper():>20} {n_val:>12.4f} {r_val:>15.4f} {improvement:>+14.4f}")
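
For readers unfamiliar with the metrics in the table, the sketch below shows common textbook definitions of MSE, SNR, and PSNR in NumPy. These are assumptions for illustration: the helper names are hypothetical, and `evaluate_reconstruction` may use different conventions (for example, a different peak value for PSNR). SSIM is omitted because it is more involved; scikit-image provides `skimage.metrics.structural_similarity` for it.

```python
import numpy as np


def mse(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Mean squared error between the reference and the estimate."""
    return float(np.mean((reference - estimate) ** 2))


def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Signal-to-noise ratio in dB: reference energy over residual energy."""
    signal_power = np.sum(reference ** 2)
    noise_power = np.sum((reference - estimate) ** 2)
    return float(10.0 * np.log10(signal_power / noise_power))


def psnr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Peak SNR in dB, taking the peak as max |reference| (an assumption)."""
    peak = np.max(np.abs(reference))
    return float(10.0 * np.log10(peak ** 2 / mse(reference, estimate)))


# Toy check: a unit-amplitude signal with a constant error of 0.1
reference = np.array([1.0, -1.0, 1.0, -1.0])
estimate = reference + 0.1

print(f"MSE:  {mse(reference, estimate):.4f}")
print(f"SNR:  {snr_db(reference, estimate):.2f} dB")
print(f"PSNR: {psnr_db(reference, estimate):.2f} dB")
```

In this toy case the error power is 1/100 of the signal power, so SNR and PSNR both come out at about 20 dB.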

In [None]:
# Visual comparison
trace_idx = 16
t = np.arange(metadata['n_samples']) / metadata['sample_rate']

fig, axes = plt.subplots(4, 1, figsize=(14, 10), sharex=True)

axes[0].plot(t, clean_data[trace_idx], 'b-', linewidth=0.8)
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Original Clean Signal')
axes[0].grid(True, alpha=0.3)

axes[1].plot(t, noisy_data[trace_idx], 'r-', linewidth=0.8)
axes[1].set_ylabel('Amplitude')
axes[1].set_title('Noisy Input')
axes[1].grid(True, alpha=0.3)

axes[2].plot(t, reconstructed[trace_idx], 'g-', linewidth=0.8)
axes[2].set_ylabel('Amplitude')
axes[2].set_title('Reconstructed Output')
axes[2].grid(True, alpha=0.3)

# Overlay
axes[3].plot(t, clean_data[trace_idx], 'b-', linewidth=0.8, label='Original', alpha=0.7)
axes[3].plot(t, reconstructed[trace_idx], 'g--', linewidth=0.8, label='Reconstructed', alpha=0.7)
axes[3].set_xlabel('Time (s)')
axes[3].set_ylabel('Amplitude')
axes[3].set_title('Comparison')
axes[3].legend()
axes[3].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Batch Processing with InferenceEngine

In [None]:
# For large datasets, use InferenceEngine with patch-based processing
print("InferenceEngine capabilities:")
print("  - Supports sliding window extraction")
print("  - Automatic patch blending")
print("  - GPU memory management")
print("  - Batch size optimization")

# Example configuration (illustrative only; it is not passed to InferenceEngine here)
inference_config = {
    'patch_size': (64, 64),
    'stride': (32, 32),
    'batch_size': 16,
    'blend_mode': 'linear',
}

print(f"\nExample config: {inference_config}")
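
The idea behind patch-based processing can be sketched in plain NumPy: slide a window over the section, process each patch, and blend the overlapping results with a weighted average. This is a minimal sketch under stated assumptions, not `InferenceEngine`'s implementation. The helpers `patch_starts`, `blend_weights`, and `process_in_patches` are hypothetical, and the triangular window is only one plausible reading of `'blend_mode': 'linear'`.

```python
import numpy as np


def patch_starts(length: int, patch: int, stride: int) -> list:
    """Start indices covering [0, length), with a last patch flush to the end."""
    if length < patch:
        raise ValueError("input must be at least as large as the patch")
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


def blend_weights(patch_size: tuple) -> np.ndarray:
    """Separable triangular window, strictly positive at every position."""
    ramps = [1.0 - np.abs(np.linspace(-1.0, 1.0, n + 2)[1:-1]) for n in patch_size]
    return np.outer(ramps[0], ramps[1])


def process_in_patches(data, fn, patch_size=(64, 64), stride=(32, 32)):
    """Apply fn to overlapping patches and blend them by weighted averaging."""
    ph, pw = patch_size
    weights = blend_weights(patch_size)
    out = np.zeros(data.shape, dtype=float)
    norm = np.zeros(data.shape, dtype=float)
    for i in patch_starts(data.shape[0], ph, stride[0]):
        for j in patch_starts(data.shape[1], pw, stride[1]):
            patch = fn(data[i:i + ph, j:j + pw])
            out[i:i + ph, j:j + pw] += weights * patch
            norm[i:i + ph, j:j + pw] += weights
    # Every sample is covered by at least one positive weight, so norm > 0
    return out / norm


data = np.random.default_rng(0).normal(size=(64, 256))
restored = process_in_patches(data, fn=lambda p: p)  # identity stands in for a model
print(np.allclose(restored, data))
```

With the identity function in place of a model, blending reproduces the input, which is a quick sanity check on the weights. A real engine would typically also stack patches into groups of `batch_size` before each model call; the sketch processes them one at a time for clarity.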

## 7. Summary

This notebook demonstrated:

1. **U-Net Architecture**: Encoder-decoder with skip connections
2. **Pipeline Loading**: Using `SeismicRecoveryPipeline.from_preset()`
3. **Inference**: Running recovery on noisy data
4. **Evaluation**: Quantitative metrics comparison
5. **Batch Processing**: InferenceEngine for large datasets

### Next Steps

- **06_gan_based_high_fidelity_reconstruction.ipynb**: GAN models
- **14_advanced_model_customization_and_training.ipynb**: Custom training