# GAN-Based High-Fidelity Reconstruction

Generative Adversarial Networks for seismic data enhancement.

**Prerequisites:** Deep learning basics, notebook 05

**Estimated Runtime:** 10 minutes

In [None]:
# !pip install promethium-seismic==1.0.1

In [None]:
import promethium
from promethium import (
    generate_synthetic_traces,
    add_noise,
    evaluate_reconstruction,
    set_seed,
    get_device,
)

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Make every random draw below (weight init, synthetic data) reproducible,
# then pick the compute device and report the library version.
SEED = 42
set_seed(SEED)
device = get_device()
banner = f"Promethium {promethium.__version__} | Device: {device}"
print(banner)

## 1. GAN Architecture Overview

GANs consist of two networks:
- **Generator**: Produces reconstructed signals from degraded input
- **Discriminator**: Distinguishes real from generated signals

Training alternates:
1. Train D to classify real vs fake
2. Train G to fool D

In [None]:
class Generator(nn.Module):
    """Fully convolutional 1-D generator.

    Four stride-1 convolutions whose padding equals ``kernel // 2``, so the
    output has the same length and channel count (``in_ch``) as the input.
    Hidden layers use ReLU; the final Tanh bounds the output to [-1, 1].

    Args:
        in_ch: Number of input (and output) channels.
        hidden: Width of the first/last hidden layer; the middle layer is
            twice as wide.
    """

    def __init__(self, in_ch=1, hidden=64):
        super().__init__()
        # (in_channels, out_channels, kernel_size) for each convolution.
        conv_specs = [
            (in_ch, hidden, 7),
            (hidden, hidden * 2, 5),
            (hidden * 2, hidden, 5),
            (hidden, in_ch, 7),
        ]
        last = len(conv_specs) - 1
        layers = []
        for position, (c_in, c_out, kernel) in enumerate(conv_specs):
            # padding = kernel // 2 keeps the sequence length unchanged.
            layers.append(nn.Conv1d(c_in, c_out, kernel, padding=kernel // 2))
            layers.append(nn.Tanh() if position == last else nn.ReLU())
        # Keep the attribute name and the Sequential indices so state_dict
        # keys (net.0.weight, net.2.weight, ...) stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, in_ch, length) tensor to a same-shaped tensor."""
        return self.net(x)

class Discriminator(nn.Module):
    """1-D convolutional classifier producing one score in (0, 1) per input.

    Two stride-2 convolutions with LeakyReLU(0.2) extract features; global
    average pooling removes the length axis, so any input length is accepted.
    A linear layer followed by Sigmoid then yields a tensor of shape
    (batch, 1).

    Args:
        in_ch: Number of input channels.
        hidden: Width of the first conv layer; the second is twice as wide.
    """

    def __init__(self, in_ch=1, hidden=64):
        super().__init__()
        wide = hidden * 2
        features = [
            nn.Conv1d(in_ch, hidden, 7, stride=2, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(hidden, wide, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
        ]
        # Pool over length -> (batch, wide) -> single score squashed to (0, 1).
        head = [
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(wide, 1),
            nn.Sigmoid(),
        ]
        # Same attribute name and module order as before, so state_dict keys
        # are unchanged.
        self.net = nn.Sequential(*features, *head)

    def forward(self, x):
        """Score a (batch, in_ch, length) tensor; returns shape (batch, 1)."""
        return self.net(x)

# Build both networks on the selected device (G first, so seeded weight
# initialisation happens in the same order), then report their sizes.
G = Generator().to(device)
D = Discriminator().to(device)
for label, module in (("Generator", G), ("Discriminator", D)):
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{label} params: {n_params:,}")

## 2. Data Preparation

In [None]:
clean, meta = generate_synthetic_traces(n_traces=200, n_samples=256, seed=42)
noisy = add_noise(clean, noise_level=0.4, seed=42)

# Normalize to [-1, 1] for the generator's tanh output.
# Both arrays are divided by ONE shared scale factor. Normalizing each by its
# own maximum would put input and target on different amplitude scales, and
# every reconstruction (and the MSE comparison below) would be penalized for
# a global gain mismatch that has nothing to do with denoising.
scale = max(np.abs(clean).max(), np.abs(noisy).max()) + 1e-8
clean_norm = clean / scale
noisy_norm = noisy / scale

# Add a channel axis for Conv1d: (n_traces, n_samples) -> (n_traces, 1, n_samples).
X = torch.tensor(noisy_norm, dtype=torch.float32).unsqueeze(1).to(device)
Y = torch.tensor(clean_norm, dtype=torch.float32).unsqueeze(1).to(device)

print(f"Data shape: {X.shape}")

## 3. GAN vs Gaussian-Filter Baseline (Conceptual)

In [None]:
# For demonstration: a single forward pass through the UNTRAINED generator.
# Its weights are random, so it is not expected to beat the baseline; this
# cell only illustrates the comparison workflow.
G.eval()
with torch.no_grad():
    generated = G(X)

# Drop only the channel axis: (n_traces, 1, n_samples) -> (n_traces, n_samples).
# A bare .squeeze() would also drop the batch axis when there is one trace.
gen_np = generated.squeeze(1).cpu().numpy()

# Simple non-learned baseline: Gaussian smoothing along the sample axis,
# applied to all traces in one call (equivalent to filtering each row).
from scipy.ndimage import gaussian_filter1d
baseline = gaussian_filter1d(noisy_norm, sigma=2, axis=-1)

print("Method Comparison (untrained GAN for demo):")
print(f"Baseline MSE: {np.mean((clean_norm - baseline)**2):.6f}")
print(f"GAN MSE: {np.mean((clean_norm - gen_np)**2):.6f}")

In [None]:
# Visual comparison: one trace shown as clean, noisy, and both reconstructions,
# laid out row-major in a 2x2 grid.
idx = 10
panels = [
    (clean_norm, 'b-', 'Original'),
    (noisy_norm, 'r-', 'Noisy'),
    (baseline, 'g-', 'Gaussian Filter'),
    (gen_np, 'm-', 'GAN (untrained)'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
for ax, (signal, fmt, title) in zip(axes.flat, panels):
    ax.plot(signal[idx], fmt, lw=0.8)
    ax.set_title(title)
    ax.set_xlabel('Sample')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

This notebook covered:
1. GAN architecture for signal reconstruction
2. Generator and Discriminator design
3. A framework for comparing reconstruction methods (an untrained generator vs. a Gaussian-filter baseline)

For full training, see notebook 14.