# SciTeX Neural Network Module Tutorial

This notebook demonstrates the specialized neural network layers and components in SciTeX, particularly focused on signal processing and neuroscience applications.

## 1. Setup and Imports

In [None]:
import scitex as stx
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import warnings
# Silence all warnings for the rest of the session to keep tutorial output short.
warnings.filterwarnings('ignore')

# Seed both random number generators so that "Restart Kernel -> Run All"
# reproduces the same noise, inputs and initial weights in every cell below.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Signal Processing Layers

### 2.1 Frequency Filters

In [None]:
# Generate sample signal with multiple frequency components
fs = 1000  # Sampling frequency (Hz)
# One second of samples spaced exactly 1/fs apart. np.linspace(0, 1, fs) would
# include the endpoint and give a spacing of 1/(fs - 1), which no longer matches
# the sample_rate passed to the filters below.
t = np.arange(fs) / fs
signal_clean = np.sin(2 * np.pi * 10 * t)  # 10 Hz signal
noise = 0.5 * np.sin(2 * np.pi * 100 * t)  # 100 Hz noise
signal_noisy = signal_clean + noise

# Convert to torch tensor
x = torch.tensor(signal_noisy, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
print(f"Input shape: {x.shape}  # [batch, channels, time]")

# Apply different filters
# Low-pass filter to remove high frequency noise
lowpass = stx.nn.LowPassFilter(cutoff_freq=30, sample_rate=fs, filter_length=51)
x_lowpass = lowpass(x)

# Band-pass filter for specific frequency range
bandpass = stx.nn.BandPassFilter(low_freq=8, high_freq=12, sample_rate=fs, filter_length=51)
x_bandpass = bandpass(x)

# High-pass filter
highpass = stx.nn.HighPassFilter(cutoff_freq=50, sample_rate=fs, filter_length=51)
x_highpass = highpass(x)

# Plot results: one panel per trace, all sharing the same labelling code.
panels = [
    (signal_noisy, 'Original Noisy Signal (10 Hz + 100 Hz)'),
    (x_lowpass, 'Low-Pass Filtered (< 30 Hz)'),
    (x_bandpass, 'Band-Pass Filtered (8-12 Hz)'),
    (x_highpass, 'High-Pass Filtered (> 50 Hz)'),
]
fig, axes = plt.subplots(len(panels), 1, figsize=(12, 10))

for ax, (trace, title) in zip(axes, panels):
    if torch.is_tensor(trace):
        # detach() so the conversion also works if the layer output tracks gradients
        trace = trace.detach().squeeze().numpy()
    ax.plot(t, trace)
    ax.set_title(title)
    ax.set_ylabel('Amplitude')
axes[-1].set_xlabel('Time (s)')

plt.tight_layout()
plt.show()

### 2.2 Spectral Analysis Layers

In [None]:
# Generate multi-channel signal
n_channels = 4
n_samples = 2000
fs = 500
duration = n_samples / fs  # seconds
# Sample instants spaced exactly 1/fs apart (linspace with the endpoint included
# would stretch the spacing to duration / (n_samples - 1)).
t = np.arange(n_samples) / fs

# Create signals with different frequency content: 10, 15, 20, 25 Hz plus noise
channel_freqs = [10 + i * 5 for i in range(n_channels)]
signals = [
    np.sin(2 * np.pi * freq * t) + 0.3 * np.random.randn(n_samples)
    for freq in channel_freqs
]

x = torch.tensor(np.array(signals), dtype=torch.float32).unsqueeze(0)
print(f"Multi-channel input shape: {x.shape}  # [batch, channels, time]")

# Compute spectrogram
spectrogram_layer = stx.nn.Spectrogram(
    n_fft=256,
    hop_length=64,
    power=2.0,
    normalized=True
)
spec = spectrogram_layer(x)
print(f"Spectrogram shape: {spec.shape}  # [batch, channels, freq, time]")

# Compute PSD (Power Spectral Density)
# The FFT length is named once so the layer and the frequency axis cannot drift apart.
psd_n_fft = 512
psd_layer = stx.nn.PSD(n_fft=psd_n_fft, fs=fs)
psd = psd_layer(x)
print(f"PSD shape: {psd.shape}  # [batch, channels, freq_bins]")

# Visualize
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Plot spectrograms for first two channels
for i in range(2):
    ax = axes[0, i]
    im = ax.imshow(spec[0, i].detach().numpy(), aspect='auto', origin='lower',
                   extent=[0, duration, 0, fs/2])
    ax.set_title(f'Channel {i+1} Spectrogram')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    plt.colorbar(im, ax=ax)

# Plot PSD for all channels
# NOTE(review): assumes the PSD layer returns the psd_n_fft // 2 + 1 one-sided bins - confirm.
freqs = np.fft.rfftfreq(psd_n_fft, 1/fs)
psd_ax = axes[1, 0]
for i, freq in enumerate(channel_freqs):
    psd_ax.semilogy(freqs, psd[0, i].detach().numpy(), label=f'Ch {i+1} ({freq} Hz)')
psd_ax.set_xlabel('Frequency (Hz)')
psd_ax.set_ylabel('Power')
psd_ax.set_title('Power Spectral Density')
psd_ax.legend()
psd_ax.grid(True)
psd_ax.set_xlim([0, 50])

# Hide unused subplot
axes[1, 1].set_visible(False)

plt.tight_layout()
plt.show()

## 3. Advanced Signal Processing

### 3.1 Hilbert Transform

In [None]:
# Generate amplitude-modulated signal
fs = 1000
duration = 2  # seconds
# Samples spaced exactly 1/fs apart; linspace(0, 2, 2*fs) would include the
# endpoint and make the effective sampling rate slightly lower than fs.
t = np.arange(duration * fs) / fs
carrier_freq = 50  # Hz
modulation_freq = 5  # Hz

# Create AM signal
carrier = np.sin(2 * np.pi * carrier_freq * t)
modulation = 1 + 0.5 * np.sin(2 * np.pi * modulation_freq * t)
am_signal = modulation * carrier

x = torch.tensor(am_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Apply Hilbert transform
hilbert_layer = stx.nn.Hilbert()
analytic_signal = hilbert_layer(x)

# Extract amplitude envelope and instantaneous phase
# detach() so the numpy conversion also works if the layer output tracks gradients
amplitude = torch.abs(analytic_signal).detach().squeeze().numpy()
phase = torch.angle(analytic_signal).detach().squeeze().numpy()

# Plot results (first 500 samples only, so individual carrier cycles stay visible)
n_show = 500
t_show = t[:n_show]
fig, axes = plt.subplots(3, 1, figsize=(12, 8))

axes[0].plot(t_show, am_signal[:n_show])
axes[0].set_title('Original AM Signal')
axes[0].set_ylabel('Amplitude')

axes[1].plot(t_show, amplitude[:n_show], 'r', linewidth=2)
axes[1].plot(t_show, am_signal[:n_show], 'b', alpha=0.5)
axes[1].set_title('Amplitude Envelope (red) vs Original Signal (blue)')
axes[1].set_ylabel('Amplitude')

axes[2].plot(t_show, phase[:n_show])
axes[2].set_title('Instantaneous Phase')
axes[2].set_ylabel('Phase (radians)')
axes[2].set_xlabel('Time (s)')

plt.tight_layout()
plt.show()

### 3.2 Phase-Amplitude Coupling (PAC)

In [None]:
# Generate synthetic signal with PAC
fs = 1000
duration = 5
# Samples spaced exactly 1/fs apart; linspace(0, duration, duration * fs) would
# include the endpoint and no longer match the fs handed to the PAC layer and filters.
t = np.arange(duration * fs) / fs

# Low frequency phase signal (theta: 6 Hz)
phase_freq = 6
phase_signal = np.sin(2 * np.pi * phase_freq * t)

# High frequency amplitude signal (gamma: 60 Hz)
amp_freq = 60
# Modulate gamma amplitude by theta phase.
# phase_signal lies in [-1, 1], so this envelope ranges from 1.0 to 1.5.
modulation = 1 + 0.5 * (phase_signal + 1) / 2
amp_signal = modulation * np.sin(2 * np.pi * amp_freq * t)

# Combine signals
combined_signal = phase_signal + 0.3 * amp_signal + 0.1 * np.random.randn(len(t))

# Convert to tensor
x = torch.tensor(combined_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Compute PAC
pac_layer = stx.nn.PAC(
    phase_freq_band=(4, 8),    # Theta band
    amp_freq_band=(50, 70),    # Gamma band
    fs=fs,
    method='tort'              # Tort's modulation index
)

pac_value = pac_layer(x)
print(f"PAC Modulation Index: {pac_value.item():.4f}")

# Visualize PAC
fig, axes = plt.subplots(3, 1, figsize=(12, 8))

# Time window for visualization
t_win = slice(0, fs)  # First second

# Original signal
axes[0].plot(t[t_win], combined_signal[t_win])
axes[0].set_title('Combined Signal with PAC')
axes[0].set_ylabel('Amplitude')

# Phase signal (filtered)
phase_filter = stx.nn.BandPassFilter(4, 8, fs, filter_length=101)
# detach() so the numpy conversion also works if the filter output tracks gradients
phase_filtered = phase_filter(x).detach().squeeze().numpy()
axes[1].plot(t[t_win], phase_filtered[t_win], 'b')
axes[1].set_title('Theta Phase (4-8 Hz)')
axes[1].set_ylabel('Amplitude')

# Amplitude signal (filtered)
amp_filter = stx.nn.BandPassFilter(50, 70, fs, filter_length=101)
amp_filtered = amp_filter(x).detach().squeeze().numpy()
# Envelope = magnitude of the analytic signal (scipy.signal.hilbert)
amp_envelope = np.abs(signal.hilbert(amp_filtered))
axes[2].plot(t[t_win], amp_filtered[t_win], 'r', alpha=0.5)
axes[2].plot(t[t_win], amp_envelope[t_win], 'r', linewidth=2)
axes[2].set_title('Gamma Amplitude (50-70 Hz) and Envelope')
axes[2].set_ylabel('Amplitude')
axes[2].set_xlabel('Time (s)')

plt.tight_layout()
plt.show()

## 4. Wavelet Analysis

In [None]:
# Generate signal with time-varying frequency
fs = 1000
duration = 2  # seconds
# Samples spaced exactly 1/fs apart; linspace(0, 2, 2*fs) would include the
# endpoint and no longer match the fs handed to the wavelet layer.
t = np.arange(duration * fs) / fs

# Chirp signal (frequency increases over time)
f0, f1 = 10, 50  # Start and end frequencies
chirp = signal.chirp(t, f0, t[-1], f1, method='linear')

# Add transient burst: 50 samples of a 30 Hz sinusoid starting at burst_time
burst_time = 1.0
burst_len = 50
burst_idx = int(burst_time * fs)
burst = np.zeros_like(t)
burst[burst_idx:burst_idx + burst_len] = 0.5 * np.sin(2 * np.pi * 30 * t[:burst_len])

test_signal = chirp + burst + 0.1 * np.random.randn(len(t))

x = torch.tensor(test_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

# Apply wavelet transform
wavelet_layer = stx.nn.Wavelet(
    wavelet='morlet',
    scales=np.logspace(0, 2, 50),  # Scales from 1 to 100
    fs=fs
)

cwt_coeffs = wavelet_layer(x)
print(f"CWT shape: {cwt_coeffs.shape}  # [batch, channels, scales, time]")

# Visualize
fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Original signal
axes[0].plot(t, test_signal)
axes[0].set_title('Signal with Linear Chirp and Transient Burst')
axes[0].set_ylabel('Amplitude')
axes[0].axvline(burst_time, color='r', linestyle='--', alpha=0.5, label='Burst')
axes[0].legend()

# Wavelet scalogram
magnitude = np.abs(cwt_coeffs[0, 0].detach().numpy())
scales = wavelet_layer.scales.detach().numpy()
frequencies = fs / (2 * scales)  # Approximate frequency for each scale
# Time coordinates are taken from the coefficient array itself, in case the
# layer's output length differs from len(t).
t_cwt = np.linspace(t[0], t[-1], magnitude.shape[-1])

# The scales are log-spaced, so the rows are NOT evenly spaced in frequency.
# imshow with a linear `extent` would put the wrong frequency labels on the
# rows; pcolormesh places every row at its own frequency.
# NOTE(review): assumes one row of cwt_coeffs per entry of wavelet_layer.scales - confirm.
im = axes[1].pcolormesh(
    t_cwt,
    frequencies,
    magnitude,
    shading='auto',
    cmap='hot'
)
axes[1].set_title('Continuous Wavelet Transform (Scalogram)')
axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_ylim([5, 80])
plt.colorbar(im, ax=axes[1], label='Magnitude')

plt.tight_layout()
plt.show()

## 5. Specialized Neural Network Architectures

### 5.1 ResNet1D for Time Series

In [None]:
# Dimensions of the example batch; the channel count is shared with the model.
batch_size = 16
n_channels = 8
seq_length = 1000

# Create ResNet1D for time series classification
resnet1d = stx.nn.ResNet1D(
    in_channels=n_channels,            # Number of input channels
    num_classes=4,                     # Number of output classes
    block_sizes=[2, 2, 2, 2],          # Number of blocks in each layer
    hidden_sizes=[64, 128, 256, 512],  # Hidden dimensions
    kernel_size=3
)

# Example input (drawn after the model is built, so RNG consumption order is unchanged)
x = torch.randn(batch_size, n_channels, seq_length)

# Forward pass
output = resnet1d(x)

# Report shapes and the total number of parameters
total_params = sum(param.numel() for param in resnet1d.parameters())
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("\nModel architecture:")
print(f"Total parameters: {total_params:,}")

### 5.2 BNet - Brain Network Architecture

In [None]:
# Dimensions shared by the BNet configuration and the example input
n_eeg_channels = 64
n_timepoints = 1000

# BNet configuration for EEG/MEG data
config = stx.nn.BNet_config(
    n_channels=n_eeg_channels,    # EEG channels
    n_timepoints=n_timepoints,    # Time points
    n_classes=5,                  # Classification classes
    conv_channels=[32, 64, 128],
    kernel_sizes=[5, 5, 3],
    dropout_rate=0.5
)

# Create BNet model
bnet = stx.nn.BNet(config)

# Example EEG-like input, [batch, channels, time]
x = torch.randn(8, n_eeg_channels, n_timepoints)
output = bnet(x)

# Summarise the configuration and the resulting output shape
print("BNet Configuration:")
print(f"  Input: {config.n_channels} channels × {config.n_timepoints} timepoints")
print(f"  Output: {config.n_classes} classes")
print(f"  Architecture: {config.conv_channels}")
print(f"\nOutput shape: {output.shape}")

## 6. Attention Mechanisms

In [None]:
# Dimensions of the multi-channel example data
batch_size = 8
n_channels = 32
seq_length = 500

# Spatial Attention for multi-channel data
spatial_attention = stx.nn.SpatialAttention(
    n_channels=n_channels,
    reduction_ratio=4
)

# Create multi-channel time series data
x = torch.randn(batch_size, n_channels, seq_length)

# Apply spatial attention
x_attended, attention_weights = spatial_attention(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {x_attended.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

# Visualize attention weights for first sample as a single-row heat map
first_sample_weights = attention_weights[0].detach().numpy().reshape(1, -1)
fig, ax = plt.subplots(figsize=(10, 4))
im = ax.imshow(first_sample_weights, aspect='auto', cmap='hot')
ax.set_xlabel('Channel')
ax.set_title('Spatial Attention Weights')
ax.set_yticks([])
plt.colorbar(im, ax=ax)
plt.show()

# Show which channels get most attention, averaged over the batch
avg_attention = attention_weights.mean(dim=0).squeeze()
ranking = torch.argsort(avg_attention, descending=True)
top_channels = ranking[:5]
print(f"\nTop 5 attended channels: {top_channels.tolist()}")
print(f"Their attention weights: {avg_attention[top_channels].tolist()}")

## 7. Data Augmentation Layers

In [None]:
# Demonstrate various augmentation layers

# 1. Channel Dropout
channel_dropout = stx.nn.DropoutChannels(p=0.2)

# 2. Channel Gain Changer
gain_changer = stx.nn.ChannelGainChanger(gain_range=(0.8, 1.2))

# 3. Frequency Gain Changer
freq_gain = stx.nn.FreqGainChanger(
    freq_bands=[(8, 12), (13, 30), (30, 100)],  # Alpha, Beta, Gamma
    gain_range=(0.9, 1.1),
    fs=1000
)

# 4. Channel Swapping
channel_swap = stx.nn.SwapChannels(p=0.3)

# Create sample data
x = torch.randn(4, 8, 1000)  # [batch, channels, time]

# Apply augmentations
print("Data Augmentation Examples:")
print(f"Original shape: {x.shape}")

# Apply channel dropout
x_dropped = channel_dropout(x)
# A channel counts as dropped only if every sample is zero. Testing the channel
# sum against zero relies on exact floating-point cancellation instead.
dropped_channels = (x_dropped[0] == 0).all(dim=1).sum().item()
print(f"\nChannel Dropout: {dropped_channels} channels dropped")

# Apply gain changes
x_gain = gain_changer(x)
# Least-squares estimate of the per-channel gain g in x_gain = g * x.
# Averaging the sample-wise ratio x_gain / x is dominated by samples where x is
# close to zero and gives unstable numbers.
gain_factors = (x_gain[0] * x[0]).sum(dim=1) / (x[0] ** 2).sum(dim=1)
print(f"\nChannel gains applied: {gain_factors[:4].tolist()}")

# Apply frequency-specific gain
x_freq = freq_gain(x)
print(f"\nFrequency gain applied to alpha, beta, gamma bands")

## 8. Custom Layer Development

In [None]:
# Example: Create a custom layer for computing running statistics
class RunningStatsLayer(nn.Module):
    """Computes running mean and std over a sliding window.

    The layer has no learnable parameters. For an input of shape
    [batch, channels, time] and the default ``dim=-1`` the output has shape
    [batch, 2 * channels, time - window_size + 1]: the first half of the
    channel axis holds the windowed means, the second half the windowed
    (unbiased) standard deviations.

    Args:
        window_size: Number of consecutive samples in each window.
        dim: Axis to slide the window along. Negative values count from the
            end. It must resolve to an axis after the channel axis (index 2 or
            higher), because the statistics are concatenated on axis 1.
    """

    def __init__(self, window_size=100, dim=-1):
        super().__init__()
        self.window_size = window_size
        self.dim = dim

    def forward(self, x):
        """Return windowed mean and std of ``x`` concatenated along axis 1.

        Raises:
            ValueError: If ``dim`` does not resolve to an axis >= 2 of ``x``.
                Earlier versions silently returned ``None`` for any ``dim``
                other than -1.
        """
        # Resolve a negative axis against the actual rank of the input.
        axis = self.dim if self.dim >= 0 else self.dim + x.dim()
        if not 2 <= axis < x.dim():
            raise ValueError(
                f"dim={self.dim} is not valid for input with {x.dim()} dimensions; "
                "it must refer to an axis after [batch, channels]"
            )

        # unfold replaces `axis` by the number of windows and appends a trailing
        # axis of length window_size that holds the samples of each window.
        windows = x.unfold(dimension=axis, size=self.window_size, step=1)

        # Reduce over the trailing window axis.
        mean = windows.mean(dim=-1)
        std = windows.std(dim=-1)

        # Combine mean and std as new channels
        return torch.cat([mean, std], dim=1)
        
# Try the custom layer on a small random batch
window = 50
stats_layer = RunningStatsLayer(window_size=window)
x = torch.randn(4, 3, 200)  # [batch, channels, time]

output = stats_layer(x)

# Report shapes; the output stacks the running mean and std along the channel axis
for message in (
    f"Input shape: {x.shape}",
    f"Output shape: {output.shape}  # [batch, 2*channels, time-window+1]",
    "\nOutput contains running mean and std for each channel",
):
    print(message)

## 9. Building Complete Neural Network Pipelines

In [None]:
# Build a complete signal processing neural network
class SignalProcessingNet(nn.Module):
    """Complete pipeline for EEG/biosignal classification.

    Stages, in the order applied by ``forward``:
    band-pass filter -> spatial attention -> two Conv1d/BatchNorm/ELU blocks
    (dropout after the first) -> global average pooling -> linear classifier.

    Args:
        n_channels: Number of input channels; used for the spatial attention
            layer and as ``in_channels`` of the first convolution.
        n_classes: Number of outputs of the final linear layer.
        fs: Passed as the third positional argument to ``stx.nn.BandPassFilter``
            (presumably the sampling rate in Hz, matching the ``sample_rate``
            keyword used earlier in this notebook - confirm).
    """
    
    def __init__(self, n_channels=32, n_classes=4, fs=250):
        super().__init__()
        
        # Preprocessing layers
        # Band-pass arguments are (1, 40, fs); presumably a 1-40 Hz pass band - confirm.
        self.bandpass = stx.nn.BandPassFilter(1, 40, fs, filter_length=51)
        self.spatial_filter = stx.nn.SpatialAttention(n_channels, reduction_ratio=4)
        
        # Feature extraction
        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=25, stride=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.activation = nn.ELU()  # shared by both convolutional blocks
        self.dropout = nn.Dropout(0.5)
        
        # Second convolutional block (a plain Conv1d + BatchNorm1d; no attention
        # mechanism is involved here)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=15, stride=2)
        self.bn2 = nn.BatchNorm1d(128)
        
        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(128, n_classes)
        
    def forward(self, x):
        """Run the full pipeline.

        Args:
            x: Input of shape [batch, channels, time].

        Returns:
            A tuple ``(x, attention)``: the output of the final linear layer
            (one row of ``n_classes`` raw scores per sample; no softmax is
            applied here) and the second value returned by the spatial
            attention layer.
        """
        # Input shape: [batch, channels, time]
        
        # Preprocessing
        x = self.bandpass(x)
        # The attention layer returns a pair; the second element is handed back to the caller.
        x, attention = self.spatial_filter(x)
        
        # Feature extraction
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.dropout(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)
        
        # Classification
        # Pool the time axis down to length 1, then drop it: [batch, 128]
        x = self.global_pool(x).squeeze(-1)
        x = self.classifier(x)
        
        return x, attention

# Create and test the network
model = SignalProcessingNet(n_channels=32, n_classes=4, fs=250)
x = torch.randn(8, 32, 1000)  # 4 seconds at 250 Hz

# Switch to evaluation mode and disable gradient tracking for the forward pass.
# In training mode Dropout(0.5) randomly zeroes activations and BatchNorm uses
# per-batch statistics, so the predictions printed below would change from one
# call to the next.
model.eval()
with torch.no_grad():
    output, attention = model(x)
    probs = torch.softmax(output, dim=1)
    predictions = torch.argmax(probs, dim=1)

print(f"Model output shape: {output.shape}")
print(f"Attention weights shape: {attention.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

# Report predicted class and its softmax probability for each sample
print(f"\nPredictions: {predictions.tolist()}")
print(f"Confidence: {probs.max(dim=1).values.tolist()}")

## 10. Summary and Best Practices

### Key Takeaways

1. **Signal Processing Integration**: SciTeX seamlessly integrates signal processing with deep learning
2. **Domain-Specific Layers**: Specialized layers for neuroscience and biosignal analysis
3. **Differentiable Operations**: All operations are differentiable for end-to-end training
4. **Modular Design**: Easy to combine layers for custom architectures

### Best Practices

1. **Preprocessing in the Network**:
   ```python
   # Include preprocessing as network layers
   self.preprocess = nn.Sequential(
       stx.nn.BandPassFilter(1, 40, fs),
       stx.nn.ChannelGainChanger(gain_range=(0.9, 1.1))
   )
   ```

2. **Attention for Channel Selection**:
   ```python
   # Use spatial attention to learn important channels
   self.attention = stx.nn.SpatialAttention(n_channels)
   ```

3. **Multi-Scale Analysis**:
   ```python
   # Combine different frequency bands
   self.multi_scale = nn.ModuleList([
       stx.nn.BandPassFilter(low, high, fs)
       for low, high in [(1,4), (4,8), (8,12), (13,30)]
   ])
   ```

4. **Data Augmentation**:
   ```python
   # Apply augmentation during training only
   if self.training:
       x = self.channel_dropout(x)
       x = self.freq_gain_changer(x)
   ```

In [None]:
# Closing banner followed by suggested follow-up exercises, one line per print
closing_lines = (
    "\nNeural Network module tutorial completed!",
    "\nNext steps:",
    "1. Experiment with different filter parameters for your signals",
    "2. Combine multiple layers for complex architectures",
    "3. Use PAC and other neuroscience-specific analyses",
    "4. Integrate with PyTorch training loops for end-to-end learning",
)
for line in closing_lines:
    print(line)