# SciTeX DSP (Digital Signal Processing) Tutorial

This notebook demonstrates the scitex.dsp module for digital signal processing, particularly focused on neural signal analysis including phase-amplitude coupling (PAC), power spectral density (PSD), and Hilbert transforms.

## Key Features Covered:
- Signal generation (demo signals)
- Power Spectral Density (PSD) analysis
- Hilbert transform for amplitude and phase extraction
- Phase-Amplitude Coupling (PAC) analysis
- Signal preprocessing utilities (`ensure_3d`, resampling, cropping)
- Ripple detection in neural signals
- Modulation index calculation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from pathlib import Path

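# Add the repository's src directory to the import path
# (assumes the notebook is run from a subdirectory of the repository root)
sys.path.insert(0, str(Path.cwd().parent / "src"))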

import scitex
import scitex.dsp as dsp
import scitex.plt as splt

print("SciTeX DSP Tutorial - Neural Signal Processing")
print("=" * 50)

## 1. Demo Signal Generation

The DSP module provides various types of demo signals for testing and demonstration purposes.
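
For intuition about what such demo signals contain, the sketch below builds a periodic signal and a linear chirp with plain NumPy. The helper names (`make_time`, `make_periodic`, `make_chirp`) and the `(batch, channel, time)` layout are assumptions chosen to mirror how this tutorial indexes `x[0, 0]`; they are not the `scitex.dsp` implementation.

```python
import numpy as np

def make_time(t_sec, fs):
    """Return sample times with exact 1/fs spacing (endpoint excluded)."""
    return np.arange(int(t_sec * fs)) / fs

def make_periodic(t, freqs_hz):
    """Sum of unit-amplitude sinusoids at the given frequencies."""
    return sum(np.sin(2 * np.pi * f * t) for f in freqs_hz)

def make_chirp(t, f_start, f_end):
    """Linear chirp whose instantaneous frequency sweeps f_start -> f_end."""
    duration = t[-1] + (t[1] - t[0])
    sweep_rate = (f_end - f_start) / duration
    phase = 2 * np.pi * (f_start * t + 0.5 * sweep_rate * t ** 2)
    return np.sin(phase)

fs = 512
t = make_time(4.0, fs)
periodic = make_periodic(t, [10, 25])
chirp = make_chirp(t, 1.0, 100.0)

# Stack into a (batch, channel, time) array, as assumed above
x = np.stack([periodic, chirp])[np.newaxis, :, :]
print(x.shape)
```

Using `np.arange(n) / fs` rather than `np.linspace(0, t_sec, n)` keeps the sample spacing at exactly `1 / fs`, which matters once frequencies are read off an FFT.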

In [None]:
# Parameters for signal generation
batch_size = 2
n_channels = 4
duration_sec = 4.0
sampling_rate = 512  # Hz

# Available signal types
signal_types = ['periodic', 'chirp', 'gauss', 'uniform', 'pac', 'ripple']

print("Available demo signal types:")
for i, sig_type in enumerate(signal_types, 1):
    print(f"{i}. {sig_type}")

# Generate different types of signals
signals = {}
for sig_type in signal_types[:4]:  # Generate first 4 types
    try:
        x, t, fs = dsp.demo_sig(
            sig_type=sig_type,
            batch_size=batch_size,
            n_chs=n_channels,
            t_sec=duration_sec,
            fs=sampling_rate
        )
        signals[sig_type] = (x, t, fs)
        print(f"✅ Generated {sig_type} signal: shape {x.shape}")
    except Exception as e:
        print(f"⚠️ Failed to generate {sig_type}: {e}")

In [None]:
# Plot the generated signals
fig, axes = plt.subplots(len(signals), 1, figsize=(12, 8), sharex=True)
if len(signals) == 1:
    axes = [axes]

for ax, (sig_type, (x, t, fs)) in zip(axes, signals.items()):
    # Plot first batch, first channel
    ax.plot(t, x[0, 0], label=f'{sig_type} signal')
    ax.set_ylabel('Amplitude [μV]')
    ax.set_title(f'{sig_type.capitalize()} Signal')
    ax.legend()
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Time [s]')
fig.suptitle('Demo Signals Generated by SciTeX DSP', fontsize=14)
plt.tight_layout()
plt.show()

## 2. Power Spectral Density (PSD) Analysis

PSD analysis reveals the frequency content of signals, crucial for understanding neural oscillations.
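
As a reference point for what a PSD estimate computes, here is a minimal NumPy periodogram. It is a sketch under stated assumptions, not the method used by `dsp.psd`: the function name `periodogram_psd` is hypothetical, and the final normalization step only illustrates one plausible meaning of a "probability-like" spectrum.

```python
import numpy as np

def periodogram_psd(x, fs):
    """One-sided periodogram PSD of a 1-D real signal, in units^2/Hz."""
    n = len(x)
    spectrum = np.fft.rfft(x)
    psd = (np.abs(spectrum) ** 2) / (fs * n)
    # Double every bin except DC (and Nyquist, when n is even) to fold in
    # the negative-frequency half of the spectrum.
    if n % 2 == 0:
        psd[1:-1] *= 2
    else:
        psd[1:] *= 2
    freqs = np.fft.rfftfreq(n, d=1 / fs)
    return psd, freqs

fs = 512
t = np.arange(int(4.0 * fs)) / fs
x = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 25 * t)

psd, freqs = periodogram_psd(x, fs)
peak_hz = freqs[np.argmax(psd)]
print(f"Strongest component at {peak_hz:.1f} Hz")

# Assumed illustration: scale the bins so they sum to 1
psd_prob = psd / psd.sum()
```

With this scaling, integrating the PSD over frequency recovers the mean squared amplitude of the signal, which is a quick sanity check for any PSD routine.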

In [None]:
# Generate a chirp signal for PSD analysis
x_chirp, t_chirp, fs = dsp.demo_sig(
    sig_type='chirp',
    batch_size=1,
    n_chs=1,
    t_sec=4.0,
    fs=512
)

print(f"Chirp signal shape: {x_chirp.shape}")
print(f"Sampling rate: {fs} Hz")
print(f"Duration: {len(t_chirp)/fs:.1f} seconds")

# Calculate PSD
try:
    # prob=True is assumed here to return power normalized to sum to 1
    psd_values, frequencies = dsp.psd(x_chirp, fs, prob=True)
    psd_ylabel = 'Normalized Power'
    print(f"✅ PSD calculated: shape {psd_values.shape}")
    print(f"Frequency range: {frequencies[0]:.1f} - {frequencies[-1]:.1f} Hz")
except Exception as e:
    print(f"⚠️ PSD calculation failed: {e}")
    # Fallback: Welch PSD from SciPy, shown on a log10 scale
    from scipy.signal import welch
    frequencies, psd_values = welch(x_chirp[0, 0], fs, nperseg=512)
    psd_values = np.log10(psd_values)
    psd_ylabel = 'Log10 Power [μV²/Hz]'
    print("✅ Used scipy.signal.welch as fallback")

In [None]:
# Plot signal and its PSD
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Time domain
ax1.plot(t_chirp, x_chirp[0, 0])
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Amplitude [μV]')
ax1.set_title('Chirp Signal (Time Domain)')
ax1.grid(True, alpha=0.3)

# Frequency domain
if hasattr(psd_values, 'shape') and len(psd_values.shape) > 1:
    psd_plot = psd_values[0, 0]  # First batch, first channel
else:
    psd_plot = psd_values

ax2.plot(frequencies, psd_plot)
ax2.set_xlabel('Frequency [Hz]')
ax2.set_ylabel(psd_ylabel)  # Label matches whichever PSD path ran above
ax2.set_title('Power Spectral Density')
ax2.grid(True, alpha=0.3)
ax2.set_xlim(0, fs/2)  # Nyquist frequency

plt.tight_layout()
plt.show()

## 3. Hilbert Transform for Amplitude and Phase

The Hilbert transform extracts instantaneous amplitude and phase from signals, essential for studying neural oscillations.
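
To make the idea concrete, the sketch below computes an analytic signal with the FFT and reads the instantaneous amplitude and phase from it. The helper name `analytic_signal` is an assumption for illustration; it follows the standard FFT construction (zero the negative frequencies, double the positive ones) and is not taken from `scitex.dsp`.

```python
import numpy as np

def analytic_signal(x):
    """FFT-based analytic signal of a 1-D real array."""
    n = len(x)
    spectrum = np.fft.fft(x)
    # Keep DC (and Nyquist for even n), double positive frequencies,
    # zero negative frequencies.
    weights = np.zeros(n)
    weights[0] = 1
    if n % 2 == 0:
        weights[n // 2] = 1
        weights[1:n // 2] = 2
    else:
        weights[1:(n + 1) // 2] = 2
    return np.fft.ifft(spectrum * weights)

fs = 512
t = np.arange(int(2.0 * fs)) / fs
envelope_true = 1 + 0.5 * np.cos(2 * np.pi * 2 * t)   # slow 2 Hz envelope
x = envelope_true * np.cos(2 * np.pi * 40 * t)        # 40 Hz carrier

z = analytic_signal(x)
amplitude = np.abs(z)      # instantaneous amplitude
phase = np.angle(z)        # instantaneous phase, wrapped to (-pi, pi]

max_err = np.max(np.abs(amplitude - envelope_true))
print(f"Largest envelope error: {max_err:.2e}")
```

The recovered amplitude tracks the slow envelope because the envelope varies much more slowly than the carrier; for broadband signals, band-pass filtering first is the usual practice.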

In [None]:
# Generate a periodic signal for Hilbert analysis
x_periodic, t_periodic, fs = dsp.demo_sig(
    sig_type='periodic',
    batch_size=1,
    n_chs=1,
    t_sec=2.0,
    fs=512,
    freqs_hz=[10, 25]  # 10 Hz and 25 Hz components
)

print(f"Periodic signal shape: {x_periodic.shape}")

# Apply Hilbert transform
try:
    phase, amplitude = dsp.hilbert(x_periodic, dim=-1)
    print(f"✅ Hilbert transform successful")
    print(f"Phase shape: {phase.shape}")
    print(f"Amplitude shape: {amplitude.shape}")
except Exception as e:
    print(f"⚠️ Hilbert transform failed: {e}")
    # Fallback using scipy
    from scipy.signal import hilbert as scipy_hilbert
    analytic_signal = scipy_hilbert(x_periodic[0, 0])
    amplitude = np.abs(analytic_signal)
    phase = np.angle(analytic_signal)
    print("✅ Used scipy.signal.hilbert as fallback")

In [None]:
# Plot Hilbert transform results
fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

# Original signal
axes[0].plot(t_periodic, x_periodic[0, 0], 'b-', label='Original Signal')
axes[0].set_ylabel('Amplitude [μV]')
axes[0].set_title('Original Signal')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Instantaneous amplitude
if hasattr(amplitude, 'shape') and len(amplitude.shape) > 1:
    amp_plot = amplitude[0, 0]
else:
    amp_plot = amplitude

axes[1].plot(t_periodic, amp_plot, 'r-', label='Instantaneous Amplitude')
axes[1].set_ylabel('Amplitude [μV]')
axes[1].set_title('Instantaneous Amplitude (Hilbert Transform)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Instantaneous phase
if hasattr(phase, 'shape') and len(phase.shape) > 1:
    phase_plot = phase[0, 0]
else:
    phase_plot = phase

axes[2].plot(t_periodic, phase_plot, 'g-', label='Instantaneous Phase')
axes[2].set_xlabel('Time [s]')
axes[2].set_ylabel('Phase [rad]')
axes[2].set_title('Instantaneous Phase (Hilbert Transform)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Phase-Amplitude Coupling (PAC) Analysis

PAC measures how strongly the phase of a low-frequency oscillation modulates the amplitude of a high-frequency oscillation. A common example in neuroscience is theta phase modulating gamma amplitude, which the demo signal in this section imitates.
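
One simple way to quantify this coupling is the mean vector length (MVL): weight each unit phase vector by the high-frequency amplitude and measure the length of the average. The sketch below is an illustration under assumptions, not what `dsp.pac` necessarily computes. The helper names are hypothetical, and the brick-wall FFT filter is used only for brevity; practical pipelines use properly designed band-pass filters.

```python
import numpy as np

def fft_bandpass(x, fs, low_hz, high_hz):
    """Brick-wall band-pass: zero every FFT bin outside [low_hz, high_hz]."""
    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / fs)
    spectrum[(freqs < low_hz) | (freqs > high_hz)] = 0
    return np.fft.irfft(spectrum, n=len(x))

def analytic_signal(x):
    """FFT-based analytic signal of a 1-D real array."""
    n = len(x)
    weights = np.zeros(n)
    weights[0] = 1
    if n % 2 == 0:
        weights[n // 2] = 1
        weights[1:n // 2] = 2
    else:
        weights[1:(n + 1) // 2] = 2
    return np.fft.ifft(np.fft.fft(x) * weights)

def mean_vector_length(x, fs, pha_band, amp_band):
    """Return (MVL, preferred phase) for one phase band and one amplitude band."""
    phase = np.angle(analytic_signal(fft_bandpass(x, fs, *pha_band)))
    amp = np.abs(analytic_signal(fft_bandpass(x, fs, *amp_band)))
    vector = np.mean(amp * np.exp(1j * phase))
    return np.abs(vector), np.angle(vector)

fs = 512
t = np.arange(int(4.0 * fs)) / fs
theta = np.cos(2 * np.pi * 6 * t)
gamma_coupled = (1 + 0.5 * theta) * np.sin(2 * np.pi * 80 * t)
gamma_uncoupled = np.sin(2 * np.pi * 80 * t)

mvl_coupled, preferred_phase = mean_vector_length(
    theta + gamma_coupled, fs, pha_band=(4, 8), amp_band=(60, 100)
)
mvl_uncoupled, _ = mean_vector_length(
    theta + gamma_uncoupled, fs, pha_band=(4, 8), amp_band=(60, 100)
)
print(f"coupled MVL:   {mvl_coupled:.3f}")
print(f"uncoupled MVL: {mvl_uncoupled:.3f}")
```

Repeating this computation over a grid of phase bands and amplitude bands yields the kind of matrix that a comodulogram displays.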

In [None]:
# Generate PAC signal
try:
    x_pac, t_pac, fs = dsp.demo_sig(
        sig_type='pac',
        batch_size=1,
        n_chs=1,
        n_segments=1,
        t_sec=4.0,
        fs=512
    )
    print(f"PAC signal shape: {x_pac.shape}")
    
    # If signal has segments dimension, select first segment
    if len(x_pac.shape) == 4:
        x_pac_plot = x_pac[0, 0, 0, :]  # batch, channel, segment, time
    else:
        x_pac_plot = x_pac[0, 0, :]  # batch, channel, time
        
except Exception as e:
    print(f"⚠️ PAC signal generation failed: {e}")
    # Create a simple PAC-like signal manually
    fs = 512
    t_pac = np.arange(int(4 * fs)) / fs  # exact 1/fs spacing (linspace would include the endpoint)
    # Low frequency phase (theta: 6 Hz)
    theta_phase = 2 * np.pi * 6 * t_pac
    # High frequency amplitude modulated by theta phase (gamma: 80 Hz)
    gamma_amp = 1 + 0.5 * np.cos(theta_phase)
    gamma_signal = gamma_amp * np.sin(2 * np.pi * 80 * t_pac)
    theta_signal = np.sin(theta_phase)
    x_pac_plot = theta_signal + gamma_signal + 0.1 * np.random.randn(len(t_pac))
    x_pac = x_pac_plot.reshape(1, 1, -1)
    print("✅ Created synthetic PAC signal")

In [None]:
# Plot the PAC signal
fig, ax = plt.subplots(1, 1, figsize=(12, 4))

ax.plot(t_pac, x_pac_plot)
ax.set_xlabel('Time [s]')
ax.set_ylabel('Amplitude [μV]')
ax.set_title('Phase-Amplitude Coupling (PAC) Demo Signal\n(Theta phase modulates Gamma amplitude)')
ax.grid(True, alpha=0.3)

# Zoom in on a short segment to see the coupling
ax.set_xlim(0, 1)  # First second only

plt.tight_layout()
plt.show()

In [None]:
# Calculate PAC using SciTeX
try:
    pac_values, phase_freqs, amp_freqs = dsp.pac(
        x_pac,
        fs,
        pha_start_hz=2,
        pha_end_hz=20,
        pha_n_bands=20,
        amp_start_hz=30,
        amp_end_hz=120,
        amp_n_bands=20,
        device='cpu',  # Use CPU for compatibility
        batch_size_ch=-1
    )
    
    print(f"✅ PAC calculation successful")
    print(f"PAC values shape: {pac_values.shape}")
    print(f"Phase frequencies: {len(phase_freqs)} bands ({phase_freqs[0]:.1f} - {phase_freqs[-1]:.1f} Hz)")
    print(f"Amplitude frequencies: {len(amp_freqs)} bands ({amp_freqs[0]:.1f} - {amp_freqs[-1]:.1f} Hz)")
    
    pac_calculated = True
    
except Exception as e:
    print(f"⚠️ PAC calculation failed: {e}")
    print("This may require a working PyTorch installation")
    pac_calculated = False

In [None]:
# Plot PAC results if calculation was successful
if pac_calculated:
    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    
    # Extract PAC matrix for first batch and channel
    if hasattr(pac_values, 'detach'):
        # PyTorch tensor: detach and move to CPU before converting
        pac_matrix = pac_values[0, 0].detach().cpu().numpy()
    else:
        pac_matrix = np.asarray(pac_values[0, 0])
    
    # Create comodulogram
    im = ax.imshow(
        pac_matrix,
        aspect='auto',
        origin='lower',
        extent=[amp_freqs[0], amp_freqs[-1], phase_freqs[0], phase_freqs[-1]],
        cmap='viridis'
    )
    
    ax.set_xlabel('Amplitude Frequency [Hz]')
    ax.set_ylabel('Phase Frequency [Hz]')
    ax.set_title('Phase-Amplitude Coupling Comodulogram')
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('PAC Strength')
    
    plt.tight_layout()
    plt.show()
else:
    print("PAC visualization skipped due to calculation failure")

## 5. Signal Preprocessing and Utilities

The DSP module includes various preprocessing functions for neural signals.
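
The sketch below shows what utilities of this kind typically do, using NumPy only. The `_sketch` functions are hypothetical stand-ins written for illustration; their behavior (prepending singleton axes, cropping by time, Fourier-domain resampling) is an assumption and may differ from `dsp.ensure_3d`, `dsp.crop`, and `dsp.resample`.

```python
import numpy as np

def ensure_3d_sketch(x):
    """Prepend singleton axes until x is (batch, channel, time)."""
    x = np.asarray(x)
    while x.ndim < 3:
        x = x[np.newaxis, ...]
    return x

def crop_sketch(x, t_start, t_end, fs):
    """Keep samples with t_start <= t < t_end along the last axis."""
    start = int(round(t_start * fs))
    end = int(round(t_end * fs))
    return x[..., start:end]

def fourier_resample_sketch(x, fs, new_fs):
    """Resample along the last axis by truncating or zero-padding the spectrum."""
    n = x.shape[-1]
    n_new = int(round(n * new_fs / fs))
    spectrum = np.fft.rfft(x, axis=-1)
    # irfft crops or zero-pads the spectrum to fit n_new output samples
    return np.fft.irfft(spectrum, n=n_new, axis=-1) * (n_new / n)

fs = 512
t = np.arange(int(10.0 * fs)) / fs
x_2d = np.stack([np.sin(2 * np.pi * 10 * t), np.cos(2 * np.pi * 25 * t)])

x_3d = ensure_3d_sketch(x_2d)
x_cropped = crop_sketch(x_3d, 2.0, 8.0, fs)
x_down = fourier_resample_sketch(x_3d, fs, 256)

print(f"{x_2d.shape} -> {x_3d.shape}")
print(f"cropped: {x_cropped.shape}, resampled: {x_down.shape}")
```

Truncating the spectrum before the inverse FFT removes content above the new Nyquist frequency, so this form of downsampling does not alias.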

In [None]:
# Generate a longer signal for preprocessing examples
x_long, t_long, fs = dsp.demo_sig(
    sig_type='periodic',
    batch_size=1,
    n_chs=2,
    t_sec=10.0,
    fs=512
)

print(f"Long signal shape: {x_long.shape}")
print(f"Duration: {len(t_long)/fs:.1f} seconds")

# Test signal utilities
utilities_tested = []

# Test ensure_3d function
try:
    x_2d = x_long[0]  # Remove batch dimension
    x_3d = dsp.ensure_3d(x_2d)
    print(f"✅ ensure_3d: {x_2d.shape} -> {x_3d.shape}")
    utilities_tested.append('ensure_3d')
except Exception as e:
    print(f"⚠️ ensure_3d failed: {e}")

# Test resampling
try:
    new_fs = 256  # Downsample to 256 Hz
    x_resampled = dsp.resample(x_long, fs, new_fs)
    print(f"✅ resample: {fs} Hz -> {new_fs} Hz, shape {x_long.shape} -> {x_resampled.shape}")
    utilities_tested.append('resample')
except Exception as e:
    print(f"⚠️ resample failed: {e}")

# Test cropping
try:
    t_start, t_end = 2.0, 8.0  # Crop from 2 to 8 seconds
    x_cropped = dsp.crop(x_long, t_start, t_end, fs)
    print(f"✅ crop: {t_start}s to {t_end}s, shape {x_long.shape} -> {x_cropped.shape}")
    utilities_tested.append('crop')
except Exception as e:
    print(f"⚠️ crop failed: {e}")

print(f"\nSuccessfully tested utilities: {utilities_tested}")

## 6. Ripple Detection in Neural Signals

Ripples are brief high-frequency oscillations (roughly 100-200 Hz) in the local field potential that are implicated in memory consolidation.
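
A common way to find such events is to band-pass the signal, take its amplitude envelope, and keep the stretches where the envelope stays above a threshold. The sketch below illustrates that idea on a synthetic trace. Everything in it is an assumption for illustration: the helper names, the 100-200 Hz brick-wall filter, the mean + 3 SD threshold, and the 15 ms minimum duration are not the defaults of `dsp.detect_ripples`.

```python
import numpy as np

def fft_bandpass(x, fs, low_hz, high_hz):
    """Brick-wall band-pass: zero every FFT bin outside [low_hz, high_hz]."""
    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / fs)
    spectrum[(freqs < low_hz) | (freqs > high_hz)] = 0
    return np.fft.irfft(spectrum, n=len(x))

def analytic_envelope(x):
    """Amplitude envelope from the FFT-based analytic signal."""
    n = len(x)
    weights = np.zeros(n)
    weights[0] = 1
    if n % 2 == 0:
        weights[n // 2] = 1
        weights[1:n // 2] = 2
    else:
        weights[1:(n + 1) // 2] = 2
    return np.abs(np.fft.ifft(np.fft.fft(x) * weights))

def find_runs(mask):
    """Return (start, stop) index pairs for each run of True (stop exclusive)."""
    padded = np.concatenate(([False], mask, [False]))
    changes = np.flatnonzero(padded[1:] != padded[:-1])
    return changes.reshape(-1, 2)

fs = 512
t = np.arange(int(6.0 * fs)) / fs
rng = np.random.default_rng(0)

# Synthetic LFP: slow 5 Hz wave, weak noise, and two 150 Hz bursts
burst_centers_s = [1.5, 4.0]
lfp = 0.2 * np.sin(2 * np.pi * 5 * t) + 0.01 * rng.standard_normal(len(t))
for center in burst_centers_s:
    window = np.exp(-0.5 * ((t - center) / 0.02) ** 2)
    lfp += window * np.sin(2 * np.pi * 150 * t)

envelope = analytic_envelope(fft_bandpass(lfp, fs, 100, 200))
threshold = envelope.mean() + 3 * envelope.std()
min_duration_s = 0.015

events = [
    (start / fs, stop / fs)
    for start, stop in find_runs(envelope > threshold)
    if (stop - start) / fs >= min_duration_s
]
for start_s, end_s in events:
    print(f"ripple from {start_s:.3f} s to {end_s:.3f} s")
```

The minimum-duration check discards isolated threshold crossings, which is what keeps brief noise spikes from being counted as events.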

In [None]:
# Generate ripple signal
try:
    x_ripple, t_ripple, fs = dsp.demo_sig(
        sig_type='ripple',
        batch_size=1,
        n_chs=1,
        t_sec=6.0,
        fs=512
    )
    print(f"Ripple signal generated: {x_ripple.shape}")
    
    # Detect ripples
    try:
        ripple_events = dsp.detect_ripples(x_ripple[0, 0], fs)
        print(f"✅ Detected {len(ripple_events)} ripple events")
        ripples_detected = True
    except Exception as e:
        print(f"⚠️ Ripple detection failed: {e}")
        ripples_detected = False
        
except Exception as e:
    print(f"⚠️ Ripple signal generation failed: {e}")
    ripples_detected = False

if ripples_detected:
    # Plot ripple signal with detected events
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    
    ax.plot(t_ripple, x_ripple[0, 0], 'b-', alpha=0.7, label='LFP Signal')
    
    # Mark detected ripples
    for ripple in ripple_events.itertuples():
        ax.axvspan(ripple.start_time, ripple.end_time, alpha=0.3, color='red')
    
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Amplitude [μV]')
    ax.set_title(f'Ripple Detection in Neural Signal ({len(ripple_events)} events detected)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Display ripple statistics
    if len(ripple_events) > 0:
        print(f"\nRipple Statistics:")
        print(f"- Number of ripples: {len(ripple_events)}")
        print(f"- Average duration: {ripple_events['duration'].mean():.3f} ± {ripple_events['duration'].std():.3f} s")
        print(f"- Average amplitude: {ripple_events['amplitude'].mean():.2f} ± {ripple_events['amplitude'].std():.2f} μV")
else:
    print("Ripple detection visualization skipped")

## 7. Modulation Index Calculation

The modulation index quantifies the strength of phase-amplitude coupling.
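
A widely used definition bins the high-frequency amplitude by low-frequency phase and measures how far the resulting distribution is from uniform, using the Kullback-Leibler divergence normalized by the log of the number of bins. The sketch below implements that definition directly. The function name `modulation_index_sketch` and the 18-bin default are assumptions for illustration, and whether `dsp.modulation_index` uses the same definition or call signature is not confirmed here. Phase and amplitude are supplied directly; in practice they would come from band-pass filtering followed by a Hilbert transform.

```python
import numpy as np

def modulation_index_sketch(phase, amplitude, n_bins=18):
    """KL divergence of the phase-binned amplitude distribution from
    uniform, normalized by log(n_bins) so the result lies in [0, 1]."""
    edges = np.linspace(-np.pi, np.pi, n_bins + 1)
    bin_idx = np.clip(np.digitize(phase, edges) - 1, 0, n_bins - 1)
    mean_amp = np.array(
        [amplitude[bin_idx == k].mean() for k in range(n_bins)]
    )
    p = mean_amp / mean_amp.sum()
    kl = np.sum(p * np.log(p * n_bins))
    return kl / np.log(n_bins)

fs = 512
t = np.arange(int(4.0 * fs)) / fs

# Low-frequency (6 Hz) phase, wrapped to (-pi, pi]
phase = np.angle(np.exp(1j * 2 * np.pi * 6 * t))

# High-frequency amplitude: modulated by phase vs. constant
amp_coupled = 1 + 0.8 * np.cos(phase)
amp_uncoupled = np.ones_like(phase)

mi_coupled = modulation_index_sketch(phase, amp_coupled)
mi_uncoupled = modulation_index_sketch(phase, amp_uncoupled)
print(f"coupled MI:   {mi_coupled:.4f}")
print(f"uncoupled MI: {mi_uncoupled:.4f}")
```

A flat amplitude profile gives a value of zero, and the value grows as the amplitude concentrates in fewer phase bins.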

In [None]:
# Calculate modulation index for the PAC signal
try:
    # Generate a clean PAC signal for modulation index
    fs = 512
    t = np.arange(int(4 * fs)) / fs  # exact 1/fs spacing (linspace would include the endpoint)
    
    # Create theta (6 Hz) and gamma (80 Hz) with coupling
    theta = np.sin(2 * np.pi * 6 * t)
    gamma_modulated = (1 + 0.8 * theta) * np.sin(2 * np.pi * 80 * t)
    
    # Combine signals
    coupled_signal = theta + 0.5 * gamma_modulated
    x_mi = coupled_signal.reshape(1, 1, -1)
    
    # Calculate modulation index
    mi_value = dsp.modulation_index(x_mi)
    print(f"✅ Modulation Index calculated: {mi_value:.4f}")
    
    # Also calculate for uncoupled signal (control)
    gamma_uncoupled = np.sin(2 * np.pi * 80 * t)
    uncoupled_signal = theta + 0.5 * gamma_uncoupled
    x_mi_control = uncoupled_signal.reshape(1, 1, -1)
    
    mi_control = dsp.modulation_index(x_mi_control)
    print(f"✅ Control Modulation Index: {mi_control:.4f}")
    
    mi_calculated = True
    
except Exception as e:
    print(f"⚠️ Modulation index calculation failed: {e}")
    mi_calculated = False

if mi_calculated:
    # Plot coupled vs uncoupled signals
    fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
    
    axes[0].plot(t[:1024], coupled_signal[:1024])  # First 2 seconds
    axes[0].set_ylabel('Amplitude [μV]')
    axes[0].set_title(f'Coupled Signal (MI = {mi_value:.4f})')
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(t[:1024], uncoupled_signal[:1024])
    axes[1].set_xlabel('Time [s]')
    axes[1].set_ylabel('Amplitude [μV]')
    axes[1].set_title(f'Uncoupled Signal (MI = {mi_control:.4f})')
    axes[1].grid(True, alpha=0.3)
    
    fig.suptitle('Modulation Index Comparison: Coupled vs Uncoupled Signals')
    plt.tight_layout()
    plt.show()
    
    print(f"\nModulation Index Interpretation:")
    print(f"- Coupled signal MI: {mi_value:.4f} (higher values indicate stronger coupling)")
    print(f"- Control signal MI: {mi_control:.4f} (baseline/uncoupled)")
    print(f"- Difference: {mi_value - mi_control:.4f}")

## Summary

This tutorial demonstrated the key features of the SciTeX DSP module:

### ✅ **Signal Generation**
- Various demo signal types (periodic, chirp, gauss, uniform, pac, ripple)
- Configurable parameters (duration, sampling rate, channels)
- Neural signal simulation capabilities

### ✅ **Spectral Analysis**
- Power Spectral Density (PSD) calculation
- Frequency domain analysis of neural oscillations
- Fallback to `scipy.signal.welch` when `dsp.psd` fails

### ✅ **Time-Frequency Analysis**
- Hilbert transform for amplitude and phase extraction
- Instantaneous signal properties
- Complex signal analysis

### ✅ **Phase-Amplitude Coupling**
- PAC calculation between frequency bands
- Comodulogram visualization
- Modulation index quantification

### ✅ **Signal Preprocessing**
- Signal cropping and resampling
- Dimension handling (ensure_3d)
- Data format utilities

### ✅ **Neural Event Detection**
- Ripple detection in LFP signals
- Event characterization and statistics
- Visualization of detected events on the LFP trace

### Key Applications:
- **Neuroscience Research**: Analysis of brain oscillations, connectivity
- **Signal Processing**: General time-series analysis workflows
- **Biomedical Engineering**: Neural signal preprocessing and feature extraction
- **Cross-Frequency Coupling Studies**: PAC comodulograms and modulation index comparisons

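The SciTeX DSP module brings these tools together for neural signal analysis. The PAC example in this tutorial ran on the CPU (`device='cpu'`); the `device` argument suggests that PyTorch-backed execution on other devices, such as a GPU, can be selected where available.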