# SciTeX Neural Network Components

This notebook demonstrates the neural network components provided by the `scitex.nn` module, which offers specialized layers and models for scientific computing, particularly for signal processing and neuroscience applications.

## 1. Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy import signal
import scitex as stx

# --- Reproducibility ---------------------------------------------------------
# A single seed value is handed to both the SciTeX helper and PyTorch.
SEED = 42
stx.repro.fix_seeds(SEED)
torch.manual_seed(SEED)

# --- Plot appearance ---------------------------------------------------------
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 12,
})

# --- Compute device ----------------------------------------------------------
# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# --- Environment report ------------------------------------------------------
print(
    f"Using device: {device}",
    f"SciTeX version: {stx.__version__}",
    f"PyTorch version: {torch.__version__}",
    sep="\n",
)

## 2. Signal Filtering Layers

In [None]:
# Generate a synthetic signal with multiple frequency components
fs = 1000      # Sampling frequency (Hz)
duration = 2   # Signal length (s)

# Sample instants spaced exactly 1 / fs apart. np.linspace(0, 2, 2 * fs) would
# include the endpoint and stretch the step to 2 / (2 * fs - 1), so the true
# sampling rate would no longer match the `fs` passed to Welch and the filters.
t = np.arange(duration * fs) / fs

# 10 Hz, 50 Hz and 100 Hz sinusoids plus Gaussian noise
signal_np = (np.sin(2 * np.pi * 10 * t) +
             0.5 * np.sin(2 * np.pi * 50 * t) +
             0.3 * np.sin(2 * np.pi * 100 * t) +
             0.2 * np.random.randn(len(t)))

# Convert to a torch tensor and add leading batch and channel axes
signal_torch = torch.tensor(signal_np, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
print(f"Signal shape: {signal_torch.shape} (batch, channels, time)")

# Number of samples shown in the time-domain preview (first 0.5 s)
n_preview = fs // 2

# Visualize original signal
plt.figure(figsize=(12, 8))

# Time domain
plt.subplot(2, 1, 1)
plt.plot(t[:n_preview], signal_np[:n_preview])
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Original Signal (First 0.5s)')
plt.grid(True, alpha=0.3)

# Frequency domain (Welch estimate on the raw numpy signal)
plt.subplot(2, 1, 2)
freqs, psd = signal.welch(signal_np, fs=fs, nperseg=256)
plt.semilogy(freqs, psd)
plt.xlabel('Frequency (Hz)')
plt.ylabel('PSD')
plt.title('Power Spectral Density')
plt.xlim([0, 200])
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Apply three SciTeX filter layers to the same input signal.
# Low-pass filter (keep only low frequencies)
low_pass = stx.nn.LowPassFilter(cutoff_freq=30, fs=fs, order=4)
signal_low = low_pass(signal_torch)

# Band-pass filter (keep specific frequency band)
band_pass = stx.nn.BandPassFilter(low_freq=40, high_freq=60, fs=fs, order=4)
signal_band = band_pass(signal_torch)

# High-pass filter (keep only high frequencies)
high_pass = stx.nn.HighPassFilter(cutoff_freq=80, fs=fs, order=4)
signal_high = high_pass(signal_torch)

# One row per signal: time domain on the left, Welch PSD on the right
signals = [
    (signal_torch, "Original"),
    (signal_low, "Low-pass (<30Hz)"),
    (signal_band, "Band-pass (40-60Hz)"),
    (signal_high, "High-pass (>80Hz)")
]
n_preview = 500  # samples shown in each time-domain panel

fig, axes = plt.subplots(len(signals), 2, figsize=(14, 12))

for (time_ax, freq_ax), (sig, title) in zip(axes, signals):
    # detach() first: Tensor.numpy() raises on a tensor that requires grad,
    # which is the case if a filter layer carries learnable parameters.
    sig_np = sig.detach().squeeze().cpu().numpy()

    # Time domain
    time_ax.plot(t[:n_preview], sig_np[:n_preview])
    time_ax.set_xlabel('Time (s)')
    time_ax.set_ylabel('Amplitude')
    time_ax.set_title(f'{title} - Time Domain')
    time_ax.grid(True, alpha=0.3)

    # Frequency domain
    freqs, psd = signal.welch(sig_np, fs=fs, nperseg=256)
    freq_ax.semilogy(freqs, psd)
    freq_ax.set_xlabel('Frequency (Hz)')
    freq_ax.set_ylabel('PSD')
    freq_ax.set_title(f'{title} - Frequency Domain')
    freq_ax.set_xlim([0, 200])
    freq_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Spectrogram and Time-Frequency Analysis

In [None]:
# Create a chirp signal (frequency increases over time)
chirp_duration = 2  # seconds

# Sample instants spaced exactly 1 / fs apart; np.linspace with its default
# endpoint=True would give a step of 2 / (2 * fs - 1) instead.
t_chirp = np.arange(chirp_duration * fs) / fs
f0, f1 = 10, 100  # Start and end frequencies (Hz)
chirp_signal = signal.chirp(t_chirp, f0, t_chirp[-1], f1, method='linear')
chirp_torch = torch.tensor(chirp_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Compute spectrogram using SciTeX
spectrogram_layer = stx.nn.Spectrogram(
    n_fft=256,
    hop_length=64,
    win_length=256,
    normalized=True
)

spec = spectrogram_layer(chirp_torch)
# Magnitude in dB; the small offset keeps log10 finite where the magnitude is 0
spec_db = 20 * torch.log10(spec.abs() + 1e-10)

# Visualize
plt.figure(figsize=(12, 8))

# Original signal
plt.subplot(2, 1, 1)
plt.plot(t_chirp, chirp_signal)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Chirp Signal (10-100 Hz)')
plt.grid(True, alpha=0.3)

# Spectrogram
plt.subplot(2, 1, 2)
# detach() guards against Tensor.numpy() raising if the layer output requires grad.
# NOTE(review): after squeeze() the array is assumed to be (freq_bins, time_frames).
spec_np = spec_db.detach().squeeze().cpu().numpy()
plt.imshow(spec_np, aspect='auto', origin='lower',
           extent=[0, chirp_duration, 0, fs/2], cmap='viridis')
plt.colorbar(label='Magnitude (dB)')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.title('Spectrogram')
plt.ylim([0, 150])

plt.tight_layout()
plt.show()

print(f"Spectrogram shape: {spec.shape} (batch, freq_bins, time_frames)")

## 4. Hilbert Transform and Analytic Signal

In [None]:
# Generate amplitude-modulated signal
carrier_freq = 100   # Hz
modulation_freq = 5  # Hz

# One second sampled at exactly fs; np.linspace(0, 1, fs) would include the
# endpoint and make the sample step 1 / (fs - 1) rather than 1 / fs.
t_am = np.arange(fs) / fs
carrier = np.sin(2 * np.pi * carrier_freq * t_am)
modulation = 1 + 0.5 * np.sin(2 * np.pi * modulation_freq * t_am)
am_signal = modulation * carrier

am_torch = torch.tensor(am_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Apply Hilbert transform
# NOTE(review): the output is presumably complex-valued, since abs() and
# angle() are taken from it below - confirm against stx.nn.Hilbert.
hilbert_layer = stx.nn.Hilbert()
analytic_signal = hilbert_layer(am_torch)

# Extract envelope (instantaneous amplitude)
envelope = torch.abs(analytic_signal)

# Extract instantaneous phase
phase = torch.angle(analytic_signal)

# Convert once for plotting; detach() avoids Tensor.numpy() raising on a
# tensor that requires grad.
envelope_np = envelope.detach().squeeze().cpu().numpy()
phase_np = phase.detach().squeeze().cpu().numpy()
n_preview = 200  # samples shown in the zoomed panels

# Visualize
plt.figure(figsize=(12, 10))

# Original signal and envelope
plt.subplot(3, 1, 1)
plt.plot(t_am[:n_preview], am_signal[:n_preview], 'b-', label='AM Signal', alpha=0.7)
plt.plot(t_am[:n_preview], envelope_np[:n_preview], 'r-', linewidth=2, label='Envelope')
plt.plot(t_am[:n_preview], -envelope_np[:n_preview], 'r-', linewidth=2)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Amplitude Modulated Signal with Envelope')
plt.legend()
plt.grid(True, alpha=0.3)

# Instantaneous phase
plt.subplot(3, 1, 2)
plt.plot(t_am[:n_preview], phase_np[:n_preview])
plt.xlabel('Time (s)')
plt.ylabel('Phase (rad)')
plt.title('Instantaneous Phase')
plt.grid(True, alpha=0.3)

# Modulation extraction
plt.subplot(3, 1, 3)
plt.plot(t_am, modulation, 'g-', linewidth=2, label='True Modulation')
plt.plot(t_am, envelope_np, 'r--', label='Extracted Envelope')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Modulation Extraction Comparison')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Phase-Amplitude Coupling (PAC)

In [None]:
# Generate synthetic signal with phase-amplitude coupling:
# the phase of a low-frequency rhythm modulates the amplitude of a faster one.
pac_duration = 5  # seconds

# Sample instants spaced exactly 1 / fs apart; np.linspace with its default
# endpoint=True would give a step of 5 / (5 * fs - 1) instead.
t_pac = np.arange(pac_duration * fs) / fs
phase_freq = 6  # Hz (theta band)
amp_freq = 50   # Hz (gamma band)

# Phase signal
phase_signal = np.sin(2 * np.pi * phase_freq * t_pac)

# Amplitude modulation based on phase (ranges from 1.0 to 2.0)
amp_modulation = 1 + 0.5 * (1 + phase_signal)

# High frequency signal with modulated amplitude
amp_signal = amp_modulation * np.sin(2 * np.pi * amp_freq * t_pac)

# Combined signal with a little Gaussian noise
pac_signal = phase_signal + amp_signal + 0.1 * np.random.randn(len(t_pac))
pac_torch = torch.tensor(pac_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Compute PAC
pac_layer = stx.nn.PAC(
    phase_freq_band=(4, 8),    # Theta band
    amp_freq_band=(40, 60),    # Gamma band
    fs=fs
)

pac_value = pac_layer(pac_torch)

n_preview = 1000  # samples shown in the time-domain panels (first second)

# Visualize
plt.figure(figsize=(14, 10))

# Original signal
plt.subplot(4, 1, 1)
plt.plot(t_pac[:n_preview], pac_signal[:n_preview])
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Signal with Phase-Amplitude Coupling')
plt.grid(True, alpha=0.3)

# Phase component
plt.subplot(4, 1, 2)
plt.plot(t_pac[:n_preview], phase_signal[:n_preview], 'b-', label=f'{phase_freq}Hz Phase')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Low Frequency Phase Component')
plt.legend()
plt.grid(True, alpha=0.3)

# Amplitude component
plt.subplot(4, 1, 3)
plt.plot(t_pac[:n_preview], amp_signal[:n_preview], 'r-', alpha=0.7, label=f'{amp_freq}Hz Amplitude')
plt.plot(t_pac[:n_preview], amp_modulation[:n_preview], 'g--', linewidth=2, label='Modulation Envelope')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('High Frequency Amplitude Component')
plt.legend()
plt.grid(True, alpha=0.3)

# Spectrogram (SciPy reference implementation on the numpy signal)
plt.subplot(4, 1, 4)
f, t_spec, Sxx = signal.spectrogram(pac_signal, fs=fs, nperseg=256, noverlap=200)
plt.pcolormesh(t_spec, f, 10 * np.log10(Sxx + 1e-10), cmap='viridis')
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (s)')
plt.title('Spectrogram showing PAC')
plt.ylim([0, 100])
plt.colorbar(label='Power (dB)')

plt.tight_layout()
plt.show()

# .item() requires the layer to return a single-element tensor for this input
print(f"Phase-Amplitude Coupling value: {pac_value.item():.4f}")

## 6. Specialized Dropout Layers

In [None]:
# Random multi-channel input: 1 trial x 8 channels x 1000 time points
n_channels = 8
n_timepoints = 1000
multi_channel_signal = torch.randn(1, n_channels, n_timepoints)

# Three masking strategies to compare:
#   - element-wise dropout from torch.nn
#   - SciTeX channel dropout (drops entire channels)
#   - SciTeX axis-wise dropout along axis 2 (the time axis)
standard_dropout = nn.Dropout(p=0.5)
channel_dropout = stx.nn.DropoutChannels(p=0.3)
axiswise_dropout = stx.nn.AxiswiseDropout(p=0.3, axis=2)

# Put every layer in training mode before applying it
for dropout_layer in (standard_dropout, channel_dropout, axiswise_dropout):
    dropout_layer.train()

# Each layer receives its own copy of the input
signal_standard = standard_dropout(multi_channel_signal.clone())
signal_channel = channel_dropout(multi_channel_signal.clone())
signal_axiswise = axiswise_dropout(multi_channel_signal.clone())

# One heat-map panel per variant, stacked vertically
panels = [
    (multi_channel_signal, 'Original Multi-channel Signal'),
    (signal_standard, 'Standard Dropout (p=0.5)'),
    (signal_channel, 'Channel Dropout (p=0.3) - Entire channels dropped'),
    (signal_axiswise, 'Axiswise Dropout (p=0.3, axis=time) - Time points dropped'),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(14, 12))

for ax, (panel_tensor, panel_title) in zip(axes, panels):
    image = ax.imshow(panel_tensor.squeeze().numpy(),
                      aspect='auto', cmap='viridis')
    ax.set_title(panel_title)
    ax.set_ylabel('Channel')
    plt.colorbar(image, ax=ax)

# Only the bottom panel carries the shared x-axis label
axes[-1].set_xlabel('Time')

plt.tight_layout()
plt.show()

# A channel counts as dropped when every one of its samples is exactly zero
channel_rows = signal_channel.squeeze()
dropped_channels = torch.all(channel_rows == 0, dim=1)
print(f"Dropped channels: {torch.where(dropped_channels)[0].tolist()}")

## 7. ResNet1D for Time Series

In [None]:
# Create ResNet1D model for time series classification
model = stx.nn.ResNet1D(
    in_channels=n_channels,
    num_classes=4,
    block_sizes=[2, 2, 2, 2],  # Number of residual blocks in each stage
    channels=[64, 128, 256, 512]  # Channels in each stage
)

# Generate synthetic data
batch_size = 16
seq_length = 1000
x = torch.randn(batch_size, n_channels, seq_length)

# Forward pass
output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nModel architecture:")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Visualize intermediate feature maps.
# Forward hooks append each hooked module's output to this list, in call order.
intermediate_outputs = []

def hook_fn(module, input, output):
    """Forward hook: store a detached CPU copy of the module's output."""
    intermediate_outputs.append(output.detach().cpu())

# Hooked modules and their panel titles, kept side by side so the two lists
# cannot drift apart. The original listed a sixth 'Global Avg Pool' title and
# sliced [:6], but only these five modules are hooked.
hooked_stages = [
    ('Input Conv', model.conv1),
    ('Stage 1', model.stage1),
    ('Stage 2', model.stage2),
    ('Stage 3', model.stage3),
    ('Stage 4', model.stage4),
]

hooks = [stage.register_forward_hook(hook_fn) for _, stage in hooked_stages]
try:
    # Single-sample forward pass; gradients are not needed for visualization
    with torch.no_grad():
        model(x[:1])
finally:
    # Always remove the hooks so they do not stay on `model` if the pass fails
    for hook in hooks:
        hook.remove()

# Visualize feature evolution
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for ax, feat, (title, _) in zip(axes, intermediate_outputs, hooked_stages):
    # Show at most the first 8 channels of the single sample
    feat_vis = feat[0, :min(8, feat.shape[1])].numpy()
    im = ax.imshow(feat_vis, aspect='auto', cmap='viridis')
    ax.set_title(f'{title}\nShape: {list(feat.shape)}')
    ax.set_ylabel('Channel')
    ax.set_xlabel('Time')
    plt.colorbar(im, ax=ax)

# The 2x3 grid has one more slot than there are feature maps; hide it
axes[-1].axis('off')

plt.tight_layout()
plt.show()

## 8. Spatial Attention for Multi-channel Data

In [None]:
# Create spatial attention module
spatial_attention = stx.nn.SpatialAttention(
    in_channels=n_channels,
    reduction_ratio=2
)

# Generate data with different importance across channels
important_channels = [1, 3, 5]  # Channels with signal
noise_channels = [0, 2, 4, 6, 7]  # Channels with mostly noise

# Low-amplitude Gaussian noise on every channel
data = torch.randn(batch_size, n_channels, seq_length) * 0.1

# Time axis for the injected sinusoids. It is built once outside the loop and
# given its own name so it does not overwrite the notebook-level `t`.
t_attn = torch.linspace(0, 1, seq_length)

# Add signal to important channels
for ch in important_channels:
    freq = 10 + ch * 5  # Different frequency for each channel
    data[:, ch, :] += torch.sin(2 * np.pi * freq * t_attn)

# Apply spatial attention
attended_data, attention_weights = spatial_attention(data, return_attention=True)

# Average the weights over the batch. detach() is required because
# Tensor.numpy() raises on a tensor that requires grad; `attended_data` below
# is already detached for the same reason.
avg_attention = attention_weights.detach().mean(dim=0).squeeze().cpu().numpy()

# Visualize attention weights
plt.figure(figsize=(12, 8))

# Original data (single sample)
plt.subplot(3, 1, 1)
plt.imshow(data[0].numpy(), aspect='auto', cmap='viridis')
plt.colorbar(label='Amplitude')
plt.ylabel('Channel')
plt.title('Original Multi-channel Data')

# Attention weights
plt.subplot(3, 1, 2)
plt.bar(range(n_channels), avg_attention)
plt.xlabel('Channel')
plt.ylabel('Attention Weight')
plt.title('Learned Spatial Attention Weights')
plt.grid(True, alpha=0.3)

# Label each bar according to how its channel was constructed
for ch in important_channels:
    plt.text(ch, avg_attention[ch] + 0.01, 'Signal', ha='center', fontsize=8)
for ch in noise_channels:
    plt.text(ch, avg_attention[ch] + 0.01, 'Noise', ha='center', fontsize=8)

# Attended data
plt.subplot(3, 1, 3)
plt.imshow(attended_data[0].detach().numpy(), aspect='auto', cmap='viridis')
plt.colorbar(label='Amplitude')
plt.ylabel('Channel')
plt.xlabel('Time')
plt.title('Data After Spatial Attention')

plt.tight_layout()
plt.show()

print(f"Attention weights shape: {attention_weights.shape}")
print(f"Channels with highest attention: {avg_attention.argsort()[-3:][::-1].tolist()}")

## 9. Power Spectral Density (PSD) Layer

In [None]:
# Create PSD layer
psd_layer = stx.nn.PSD(
    n_fft=256,
    hop_length=128,
    normalized=True
)

# Generate multi-band signal
fs = 1000
eeg_duration = 2  # seconds

# Sample instants spaced exactly 1 / fs apart; np.linspace(0, 2, 2 * fs) would
# include the endpoint and make the effective sampling rate differ from fs.
t = np.arange(eeg_duration * fs) / fs

# Different frequency bands
delta = 0.5 * np.sin(2 * np.pi * 2 * t)    # 2 Hz
theta = 0.7 * np.sin(2 * np.pi * 6 * t)    # 6 Hz
alpha = 1.0 * np.sin(2 * np.pi * 10 * t)   # 10 Hz
beta = 0.6 * np.sin(2 * np.pi * 25 * t)    # 25 Hz
gamma = 0.4 * np.sin(2 * np.pi * 50 * t)   # 50 Hz

eeg_signal = delta + theta + alpha + beta + gamma + 0.2 * np.random.randn(len(t))
eeg_torch = torch.tensor(eeg_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Compute PSD
psd = psd_layer(eeg_torch)

# Plain numpy copy for plotting and band masking; detach() avoids
# Tensor.numpy() raising if the layer output requires grad.
# NOTE(review): assumes squeeze() leaves a 1-D array of frequency bins.
psd_values = psd.detach().squeeze().cpu().numpy()

# Frequency bins
# NOTE(review): assumes the last axis spans 0 .. fs / 2 linearly - confirm
# against stx.nn.PSD.
freq_bins = np.linspace(0, fs / 2, psd.shape[-1])

# Visualize
plt.figure(figsize=(14, 10))

# Time series
plt.subplot(3, 1, 1)
plt.plot(t[:500], eeg_signal[:500])
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('EEG-like Signal with Multiple Frequency Bands')
plt.grid(True, alpha=0.3)

# PSD
plt.subplot(3, 1, 2)
psd_db = 10 * np.log10(psd_values + 1e-10)
plt.plot(freq_bins, psd_db)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power (dB)')
plt.title('Power Spectral Density')
plt.xlim([0, 100])
plt.grid(True, alpha=0.3)

# Annotate frequency bands: (name, lower edge in Hz, upper edge in Hz, colour)
bands = [
    ('Delta', 0.5, 4, 'blue'),
    ('Theta', 4, 8, 'green'),
    ('Alpha', 8, 13, 'orange'),
    ('Beta', 13, 30, 'red'),
    ('Gamma', 30, 100, 'purple')
]

for name, f_low, f_high, color in bands:
    plt.axvspan(f_low, f_high, alpha=0.2, color=color, label=name)

plt.legend()

# Band power: mean PSD over the bins inside each half-open band [f_low, f_high)
plt.subplot(3, 1, 3)
band_powers = []
band_names = []

for name, f_low, f_high, color in bands:
    mask = (freq_bins >= f_low) & (freq_bins < f_high)
    band_powers.append(float(psd_values[mask].mean()))
    band_names.append(name)

plt.bar(band_names, band_powers, color=[b[3] for b in bands])
plt.xlabel('Frequency Band')
plt.ylabel('Average Power')
plt.title('Power Distribution Across Frequency Bands')
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## 10. Integration Example: Multi-Modal Signal Processing Network

In [None]:
# Build a complete signal processing network
class SignalProcessingNetwork(nn.Module):
    """Classifier that combines temporal, spectrogram and PSD features.

    The input is band-pass filtered and re-weighted by spatial attention.
    Three feature vectors are then extracted (ResNet1D output, time- and
    channel-averaged spectrogram, channel-averaged PSD), concatenated, and
    passed to a small MLP classifier.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output classes.
        fs: Sampling frequency handed to the band-pass filter.
        n_fft: FFT size used by both the spectrogram and the PSD layer.
            The classifier input width is derived from it as
            ``n_fft // 2 + 1`` bins per spectral feature, which equals the
            129 the original hard-coded for ``n_fft=256``.
            NOTE(review): assumes both layers return one-sided spectra with
            that many bins - confirm before using a non-default value.
    """

    def __init__(self, n_channels, n_classes, fs=1000, n_fft=256):
        super().__init__()

        resnet_feature_dim = 128
        n_freq_bins = n_fft // 2 + 1

        # Preprocessing layers
        self.bandpass = stx.nn.BandPassFilter(low_freq=1, high_freq=100, fs=fs)
        self.spatial_attention = stx.nn.SpatialAttention(n_channels)

        # Feature extraction
        self.spectrogram = stx.nn.Spectrogram(n_fft=n_fft, hop_length=64)
        self.psd = stx.nn.PSD(n_fft=n_fft)

        # Channel augmentation during training
        self.channel_dropout = stx.nn.DropoutChannels(p=0.2)

        # Temporal feature extraction
        self.resnet = stx.nn.ResNet1D(
            in_channels=n_channels,
            num_classes=resnet_feature_dim,  # Feature dimension
            block_sizes=[1, 1, 1, 1],
            channels=[32, 64, 128, 256]
        )

        # Combine features for classification
        self.classifier = nn.Sequential(
            # ResNet + Spec + PSD features
            nn.Linear(resnet_feature_dim + 2 * n_freq_bins, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, n_classes)
        )

    @staticmethod
    def _average_over_channels(features, batch_size):
        """Average every axis between the batch axis and the last axis.

        The original code used an argument-less ``squeeze()`` followed by
        ``mean(dim=1)``. ``squeeze()`` also removes the batch axis when the
        batch size is 1 (and the channel axis when there is one channel), so
        the mean then ran over the wrong axis. Reshaping to
        ``(batch, -1, last)`` keeps the batch axis fixed whatever its size.
        """
        return features.reshape(batch_size, -1, features.shape[-1]).mean(dim=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor; the notebook passes (batch, channels, time).

        Returns:
            Tuple ``(output, attention)``: the classifier output and the
            attention tensor returned by the spatial attention layer.
        """
        batch_size = x.shape[0]

        # Preprocessing
        x = self.bandpass(x)
        x, attention = self.spatial_attention(x, return_attention=True)

        # Apply channel dropout during training
        if self.training:
            x = self.channel_dropout(x)

        # 1. Temporal features from ResNet
        temporal_features = self.resnet(x)

        # 2. Spectral features: average over time (last axis), then channels.
        # NOTE(review): assumes the spectrogram's last two axes are
        # (freq, time) - confirm against stx.nn.Spectrogram.
        spec = self.spectrogram(x)
        spec_features = self._average_over_channels(spec.mean(dim=-1), batch_size)

        # 3. PSD features averaged over channels.
        # NOTE(review): assumes the PSD's last axis is frequency - confirm.
        psd = self.psd(x)
        psd_features = self._average_over_channels(psd, batch_size)

        # Combine all features
        combined_features = torch.cat([
            temporal_features,
            spec_features,
            psd_features
        ], dim=-1)

        # Classification
        output = self.classifier(combined_features)

        return output, attention

# Instantiate the network for 8 input channels and 4 target classes
model = SignalProcessingNetwork(n_channels=8, n_classes=4, fs=1000)

# Random input: batch_size=4, channels=8, time=2000
test_data = torch.randn(4, 8, 2000)
output, attention = model(test_data)

# Shape and size report
n_parameters = sum(param.numel() for param in model.parameters())
print("Signal Processing Network Summary:")
print(f"Input shape: {test_data.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")
print(f"\nTotal parameters: {n_parameters:,}")

# Heat map of the batch-averaged attention, drawn as a single row
attention_avg = attention.mean(dim=0).squeeze().detach().numpy()
attention_row = attention_avg.reshape(1, -1)

fig, ax = plt.subplots(figsize=(10, 6))
heatmap = ax.imshow(attention_row, aspect='auto', cmap='hot')
fig.colorbar(heatmap, ax=ax, label='Attention Weight')
ax.set_xlabel('Channel')
ax.set_ylabel('Sample')
ax.set_title('Learned Spatial Attention Patterns')
fig.tight_layout()
plt.show()

## Summary

The `scitex.nn` module provides specialized neural network components for scientific computing:

1. **Signal Filtering**: Differentiable filters (low-pass, high-pass, band-pass) for preprocessing
2. **Time-Frequency Analysis**: Spectrogram and PSD layers for spectral feature extraction
3. **Advanced Signal Processing**: Hilbert transform, phase-amplitude coupling (PAC)
4. **Specialized Architectures**: ResNet1D for time series, spatial attention for multi-channel data
5. **Regularization**: Channel dropout and axis-wise dropout for robust training

These components are particularly useful for:
- EEG/MEG signal analysis
- Time series classification
- Biomedical signal processing
- Multi-channel sensor data analysis
- Real-time signal processing applications