# EEG Model Activation Map Visualization

This notebook loads the trained CNN-LSTM model and generates activation maps to visualize:
1. **Gradient-based Saliency Maps** - Which input features influence the prediction most
2. **CNN Feature Maps** - What the convolutional layers are learning
3. **Average Activation Maps** - Mean CNN activation per layer, a rough spatial overview (not a class-weighted CAM)
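4. **LSTM Output Magnitudes** - Per-step norm of the LSTM outputs as an attention proxy (the model has no attention layer); the LSTM sequence runs along the pooled electrode axis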

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pickle
import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Every figure and CSV below is written into activation_maps/. savefig/to_csv
# do not create missing directories, so make sure it exists up front.
Path('activation_maps').mkdir(parents=True, exist_ok=True)

## 1. Load Model Architecture and Weights

In [None]:
# Load model architecture
class EEG_CNN_LSTM_HPO(nn.Module):
    """Two-stage CNN feature extractor followed by a stacked LSTM and a classifier.

    Input is a single-channel 2-D map of shape (batch, 1, height, width). After
    the conv/pool stages the tensor (batch, C, H', W') is permuted so that the
    pooled width axis W' becomes the LSTM sequence axis and C * H' becomes the
    per-step feature vector.

    Args:
        cnn_kernels_1, cnn_kernels_2: Number of filters in each conv layer.
        cnn_kernel_size_1, cnn_kernel_size_2: Square kernel sizes; padding is
            ``kernel_size // 2`` so odd kernels keep the spatial size.
        cnn_dropout: Dropout applied after the second pooling stage.
        cnn_dense: Width of the projection applied to every sequence step.
        lstm_hidden_size, lstm_layers: LSTM configuration.
        lstm_dense: Width of the dense layer after the LSTM.
        dropout: Inter-layer LSTM dropout (only used when lstm_layers > 1).
        num_classes: Number of output logits.
        input_shape: (height, width) of the input map, used to size the dense
            projection. Defaults to (77, 19), the size previously hard-coded;
            it must match the data the loaded weights were trained on.
    """

    def __init__(self,
                cnn_kernels_1=32,
                cnn_kernel_size_1=3,
                cnn_kernels_2=64,
                cnn_kernel_size_2=3,
                cnn_dropout=0.3,
                cnn_dense=64,
                lstm_hidden_size=64,
                lstm_layers=2,
                lstm_dense=64,
                dropout=0.3,
                num_classes=2,
                input_shape=(77, 19)):
        super().__init__()

        # CNN feature extractor
        pad1 = cnn_kernel_size_1 // 2
        self.conv1 = nn.Conv2d(1, int(cnn_kernels_1), kernel_size=cnn_kernel_size_1, padding=pad1)
        self.pool1 = nn.AvgPool2d(2)
        pad2 = cnn_kernel_size_2 // 2
        self.conv2 = nn.Conv2d(int(cnn_kernels_1), int(cnn_kernels_2), kernel_size=cnn_kernel_size_2, padding=pad2)
        self.pool2 = nn.AvgPool2d(2)
        self.cnn_dropout = nn.Dropout(cnn_dropout)

        # Slots filled by forward(save_activations=True). These are plain
        # attributes (not registered buffers), so they are not in state_dict.
        self.conv1_activations = None
        self.conv2_activations = None
        self.lstm_outputs = None

        # Probe the CNN with a dummy input to size the dense projection.
        self.input_shape = tuple(input_shape)
        with torch.no_grad():
            dummy = torch.zeros(1, 1, *self.input_shape)
            out = self._forward_cnn(dummy)
            self.seq_len = out.size(-1)                   # pooled width
            self.feature_dim = out.size(1) * out.size(2)  # channels * pooled height

        self.cnn_dense = nn.Linear(self.feature_dim, int(cnn_dense))

        # LSTM
        self.lstm = nn.LSTM(
            input_size=int(cnn_dense),
            hidden_size=int(lstm_hidden_size),
            num_layers=int(lstm_layers),
            batch_first=True,
            dropout=dropout if lstm_layers > 1 else 0.0
        )

        self.lstm_dense = nn.Linear(int(lstm_hidden_size), int(lstm_dense))

        # Classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(int(lstm_dense), num_classes)
        )

    def _forward_cnn(self, x, save_activations=False):
        """Run conv1 -> pool -> conv2 -> pool -> dropout.

        When ``save_activations`` is true, the post-ReLU (pre-pool) outputs of
        both conv layers are stored detached from the autograd graph.
        """
        x = F.relu(self.conv1(x))
        if save_activations:
            self.conv1_activations = x.detach().clone()
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        if save_activations:
            self.conv2_activations = x.detach().clone()
        x = self.pool2(x)
        return self.cnn_dropout(x)

    def forward(self, x, save_activations=False):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Tensor of shape (batch, 1, height, width).
            save_activations: If true, also store conv activations and the
                full LSTM output sequence (detached) on the module.
        """
        x = self._forward_cnn(x, save_activations=save_activations)

        # (B, C, H, W) -> (B, W, C, H) -> (B, W, C*H): width is the sequence axis
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(x.size(0), x.size(1), -1)
        x = F.relu(self.cnn_dense(x))

        lstm_out, (h_n, _) = self.lstm(x)
        if save_activations:
            self.lstm_outputs = lstm_out.detach().clone()
        # Last layer's final hidden state summarizes the sequence
        out = h_n[-1]

        out = F.relu(self.lstm_dense(out))
        return self.classifier(out)

In [None]:
# Load the trained model.
# Best hyper-parameters from the HPO run. batch_size, learning_rate and
# optimizer were training-time settings and are not needed to rebuild the net.
params = {
    'batch_size': 80,
    'cnn_dense': 256,
    'cnn_dropout': np.float64(0.38218620920862145),
    'cnn_kernel_size_1': 5,
    'cnn_kernel_size_2': 5,
    'cnn_kernels_1': 48,
    'cnn_kernels_2': 32,
    'learning_rate': np.float64(0.0017576118123159641),
    'lstm_dense': 32,
    'lstm_hidden_size': 128,
    'lstm_layers': 3,
    'optimizer': 'rmsprop',
}

architecture_keys = (
    'cnn_kernels_1', 'cnn_kernel_size_1', 'cnn_kernels_2', 'cnn_dense',
    'lstm_hidden_size', 'lstm_layers', 'lstm_dense',
)
model_kwargs = {key: params[key] for key in architecture_keys}
# cnn_dropout doubles as the LSTM inter-layer dropout (a simple shared param)
shared_dropout = float(params['cnn_dropout'])
model_kwargs.update(cnn_dropout=shared_dropout, dropout=shared_dropout, num_classes=2)

model = EEG_CNN_LSTM_HPO(**model_kwargs).to(device)
state_dict = torch.load('exports/eeg_cnn_lstm_hpo.pth', map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("✓ Model loaded successfully")

## 2. Load Test Data and Scaler

In [None]:
# Load EEGScaler
class EEGScaler:
    """Standardize EEG data per electrode (channel).

    Mean and standard deviation are pooled over the first two axes, so for
    input shaped (n, freq, electrodes, ...) each electrode gets its own
    statistics. ``keepdims=True`` keeps them broadcastable against arrays of
    the same rank as the one passed to ``fit``.

    The notebook restores a fitted instance with ``pickle``, so the attribute
    names ``mean_`` and ``scale_`` must stay stable.
    """

    def __init__(self):
        self.mean_ = None
        self.scale_ = None

    def _check_fitted(self):
        """Raise a clear error instead of failing on ``None`` arithmetic."""
        if getattr(self, 'mean_', None) is None or getattr(self, 'scale_', None) is None:
            raise RuntimeError("EEGScaler is not fitted yet; call fit() first.")

    def fit(self, X):
        """Compute per-electrode mean and scale from ``X``; returns ``self``."""
        self.mean_ = X.mean(axis=(0, 1), keepdims=True)
        # Epsilon avoids division by zero for constant (e.g. flat) electrodes
        self.scale_ = X.std(axis=(0, 1), keepdims=True) + 1e-8
        return self

    def transform(self, X):
        """Return ``(X - mean_) / scale_``."""
        self._check_fitted()
        return (X - self.mean_) / self.scale_

    def fit_transform(self, X):
        """Fit on ``X`` and return the standardized ``X``."""
        return self.fit(X).transform(X)

    def inverse_transform(self, X_scaled):
        """Undo ``transform``: ``X_scaled * scale_ + mean_``."""
        self._check_fitted()
        return X_scaled * self.scale_ + self.mean_

In [None]:
# Pick one recording to explain; edit the path to inspect another subject.
sample_file = 'extracted_eegs/v107.csv'  # Change this to test different subjects
sample_df = pd.read_csv(sample_file)

summary_items = (
    ("Loaded sample", sample_file),
    ("Shape", sample_df.shape),
    ("Electrodes", sample_df.columns.tolist()),
)
for label, value in summary_items:
    print(f"{label}: {value}")

## 3. Preprocess Sample Data

In [None]:
# Convert raw EEG to frequency domain representation (matching training data format)
from scipy.signal import welch

SAMPLE_RATE = 128  # Hz

def eeg_to_frequency_features(df, sfreq=SAMPLE_RATE, nperseg=256, max_freq=None):
    """
    Convert time-domain EEG to a Welch power-spectral-density representation.

    Each column of ``df`` is treated as one electrode; Welch's method is run
    along the rows (time) of every column. The result is wrapped in a single
    leading "window" axis so it can go through the batched pipeline.

    Args:
        df: DataFrame with one column per electrode and one row per sample.
        sfreq: Sampling rate in Hz.
        nperseg: Welch segment length. With the defaults (128 Hz, 256) the
            resolution is 0.5 Hz and ``nperseg // 2 + 1 = 129`` bins span 0-64 Hz.
        max_freq: If given, keep only frequency bins <= ``max_freq`` Hz.

    Returns:
        psd_matrix: array of shape (1, n_frequencies, n_electrodes).
        freqs: 1-D array of the n_frequencies bin frequencies in Hz.
        electrodes: list of column names, in column order.

    NOTE(review): the model builds its layers for a 77 x 19 input, while the
    defaults here give 129 frequency rows. At 0.5 Hz resolution,
    ``max_freq=38`` yields exactly 77 rows. TODO: confirm against the
    training preprocessing before relying on it.
    """
    electrodes = df.columns.tolist()

    # One vectorized Welch call over all electrodes: (n_electrodes, n_samples)
    signals = df.to_numpy().T
    freqs, psd = welch(signals, sfreq, nperseg=nperseg, axis=-1)  # psd: (n_electrodes, n_freqs)

    if max_freq is not None:
        keep = freqs <= max_freq
        freqs, psd = freqs[keep], psd[:, keep]

    # (n_freqs, n_electrodes) with a leading single-window axis
    psd_matrix = psd.T[np.newaxis, ...]
    return psd_matrix, freqs, electrodes

# Turn the raw recording into its (1, freq, electrodes) spectral representation
X_sample, freqs, electrodes = eeg_to_frequency_features(sample_df)
f_lo, f_hi = freqs[0], freqs[-1]
print(f"Frequency features shape: {X_sample.shape}")
print(f"Frequency range: {f_lo:.2f} - {f_hi:.2f} Hz")

In [None]:
# Load scaler and normalize.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('saved_scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

# Add a trailing channel axis and scale. Bind a new name instead of rebinding
# X_sample so that re-running this cell does not keep stacking extra axes.
X_sample_4d = X_sample[..., np.newaxis]  # (1, freq, electrodes, 1)
X_sample_scaled = scaler.transform(X_sample_4d)

# (1, freq, electrodes, 1) -> (1, 1, freq, electrodes), as Conv2d expects
X_tensor = torch.tensor(X_sample_scaled, dtype=torch.float32).permute(0, 3, 1, 2)
X_tensor = X_tensor.to(device)

print(f"Model input shape: {X_tensor.shape}")

## 4. Make Prediction

In [None]:
# Run the model once on the prepared sample and report class probabilities.
# Label order assumed by this notebook: index 0 = ADHD, index 1 = Control.
class_names = ['ADHD', 'Control']

with torch.no_grad():
    output = model(X_tensor)
    probabilities = F.softmax(output, dim=1)

predicted_class = probabilities.argmax(dim=1).item()
confidence = probabilities[0, predicted_class].item()

print(f"\nPrediction: {class_names[predicted_class]}")
print(f"Confidence: {confidence:.2%}")
print(f"\nClass probabilities:")
for name, prob in zip(class_names, probabilities[0].tolist()):
    print(f"  {name}: {prob:.2%}")

## 5. Generate Gradient-Based Saliency Map

This shows which input features (frequency × electrode combinations) most influence the prediction.

In [None]:
def compute_saliency_map(model, input_tensor, target_class):
    """
    Compute a gradient-based saliency map.

    Returns |d logit[0, target_class] / d input| as a NumPy array with all
    size-1 axes squeezed out (for a (1, 1, F, E) input this is (F, E)).
    Higher values indicate inputs with more influence on the prediction.

    The caller's ``input_tensor`` is not modified: gradients are taken with
    respect to a detached copy, so repeated calls give identical results
    instead of accumulating into a shared ``.grad``.
    """
    model.eval()
    x = input_tensor.detach().clone().requires_grad_(True)

    # cuDNN only implements RNN (LSTM) backward in training mode. Disable it
    # for this pass so the model can stay in eval mode (no dropout) on GPU.
    with torch.backends.cudnn.flags(enabled=False):
        output = model(x)
        model.zero_grad()
        output[0, target_class].backward()

    return x.grad.abs().squeeze().cpu().numpy()

# Explain the class the model actually predicted for this sample
saliency_map = compute_saliency_map(model, X_tensor, target_class=predicted_class)
print(f"Saliency map shape: {saliency_map.shape}")

In [None]:
# Visualize the saliency map, restricted to 0-40 Hz for readability.
# Rows of saliency_map follow freqs (ascending), so cropping is a prefix slice.
freq_mask = freqs <= 40
freqs_plot = freqs[freq_mask]
n_freq_rows = len(freqs_plot)
saliency_plot = saliency_map[:n_freq_rows, :]

fig, ax = plt.subplots(figsize=(14, 8))
im = ax.imshow(saliency_plot, aspect='auto', cmap='hot', interpolation='bilinear')
ax.set_xlabel('Electrodes', fontsize=12)
ax.set_ylabel('Frequency (Hz)', fontsize=12)
title = (f'Saliency Map - {class_names[predicted_class]} Prediction\n'
         '(Brighter = More Important for Classification)')
ax.set_title(title, fontsize=14, fontweight='bold')

# One tick per electrode column
ax.set_xticks(range(len(electrodes)))
ax.set_xticklabels(electrodes, rotation=45, ha='right')

# Ten evenly spaced row ticks labelled with their frequency
y_ticks = np.linspace(0, n_freq_rows - 1, 10, dtype=int)
ax.set_yticks(y_ticks)
ax.set_yticklabels([f'{freqs_plot[i]:.1f}' for i in y_ticks])

# Shade the classic EEG bands; span limits are row indices, not Hz
bands = [
    (0.5, 4, 'Delta', 'purple'),
    (4, 8, 'Theta', 'blue'),
    (8, 13, 'Alpha', 'green'),
    (13, 30, 'Beta', 'orange')
]
for low, high, name, color in bands:
    start_row = np.abs(freqs_plot - low).argmin()
    end_row = np.abs(freqs_plot - high).argmin()
    ax.axhspan(start_row, end_row, alpha=0.1, color=color, label=name)

plt.colorbar(im, ax=ax, label='Gradient Magnitude')
ax.legend(loc='upper right', fontsize=10)
plt.tight_layout()
plt.savefig('activation_maps/saliency_map.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saliency map saved to activation_maps/saliency_map.png")

## 6. Visualize CNN Feature Maps

This shows what patterns the convolutional layers are detecting.

In [None]:
# Forward pass with activation saving (no gradients needed here)
with torch.no_grad():
    _ = model(X_tensor, save_activations=True)

# Take the single batch element explicitly. squeeze() would also drop any
# other size-1 axis (e.g. a pooled map of width 1) and change the layout.
conv1_act = model.conv1_activations[0].detach().cpu().numpy()  # (filters, H, W)
conv2_act = model.conv2_activations[0].detach().cpu().numpy()  # (filters, H, W)

print(f"Conv1 activations shape: {conv1_act.shape}")
print(f"Conv2 activations shape: {conv2_act.shape}")

In [None]:
# Grid of the first (up to) 16 conv1 filter responses; unused panels are blank
n_filters = min(16, conv1_act.shape[0])
fig, axes = plt.subplots(4, 4, figsize=(16, 12))
fig.suptitle('Conv1 Feature Maps (First 16 Filters)', fontsize=16, fontweight='bold')

for filter_idx, panel in enumerate(axes.flat):
    if filter_idx >= n_filters:
        panel.axis('off')
        continue
    fmap = panel.imshow(conv1_act[filter_idx], aspect='auto', cmap='viridis')
    panel.set_title(f'Filter {filter_idx+1}', fontsize=10)
    panel.axis('off')
    plt.colorbar(fmap, ax=panel, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig('activation_maps/conv1_feature_maps.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Conv1 feature maps saved")

In [None]:
# Grid of the first (up to) 16 conv2 filter responses; unused panels are blank
n_filters = min(16, conv2_act.shape[0])
fig, axes = plt.subplots(4, 4, figsize=(16, 12))
fig.suptitle('Conv2 Feature Maps (First 16 Filters)', fontsize=16, fontweight='bold')

for filter_idx, panel in enumerate(axes.flat):
    if filter_idx >= n_filters:
        panel.axis('off')
        continue
    fmap = panel.imshow(conv2_act[filter_idx], aspect='auto', cmap='viridis')
    panel.set_title(f'Filter {filter_idx+1}', fontsize=10)
    panel.axis('off')
    plt.colorbar(fmap, ax=panel, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig('activation_maps/conv2_feature_maps.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Conv2 feature maps saved")

## 7. Visualize Average Feature Map per Layer

In [None]:
# Mean over the filter axis gives one overall activation map per conv layer
conv1_avg = conv1_act.mean(axis=0)
conv2_avg = conv2_act.mean(axis=0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
panels = ((ax1, conv1_avg, 'Conv1'), (ax2, conv2_avg, 'Conv2'))
for panel, avg_map, layer in panels:
    heat = panel.imshow(avg_map, aspect='auto', cmap='plasma')
    panel.set_title(f'{layer} Average Activation', fontsize=14, fontweight='bold')
    panel.set_xlabel('Width (Electrode dimension)', fontsize=12)
    panel.set_ylabel('Height (Frequency dimension)', fontsize=12)
    plt.colorbar(heat, ax=panel, label='Activation')

plt.tight_layout()
plt.savefig('activation_maps/average_activations.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Average activation maps saved")

## 8. LSTM Temporal Attention Analysis

Visualize the per-step magnitude of the LSTM outputs. Note that the LSTM sequence runs along the pooled electrode (width) axis of the CNN output, not along frequency bands.

In [None]:
# Per-step LSTM outputs captured by the save_activations pass above.
# Index the batch element instead of squeeze(), which would also collapse a
# sequence of length 1 and break the per-step norm below.
lstm_out = model.lstm_outputs[0].detach().cpu().numpy()  # (seq_len, hidden_size)
print(f"LSTM outputs shape: {lstm_out.shape}")

# Attention proxy: L2 norm of each step's output, scaled so the maximum is 1.
# The model has no attention layer; this is only a magnitude heuristic.
temporal_attention = np.linalg.norm(lstm_out, axis=1)
peak = temporal_attention.max()
if peak > 0:  # avoid NaNs if every step's output is exactly zero
    temporal_attention = temporal_attention / peak

print(f"Temporal attention shape: {temporal_attention.shape}")

In [None]:
# Bar chart of the normalized per-step LSTM output magnitude
step_positions = np.arange(temporal_attention.shape[0])

fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(step_positions, temporal_attention, color='steelblue', alpha=0.7, edgecolor='navy')
ax.set_xlabel('Temporal Position (Electrode-related sequence)', fontsize=12)
ax.set_ylabel('Attention Weight (Normalized)', fontsize=12)
chart_title = ('LSTM Temporal Attention Weights\n'
               '(Higher values = More important for classification)')
ax.set_title(chart_title, fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('activation_maps/lstm_temporal_attention.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ LSTM temporal attention saved")

## 9. Feature Importance by Electrode

In [None]:
# Per-electrode importance: total saliency over all frequency rows, scaled to max 1
raw_importance = saliency_map.sum(axis=0)
electrode_importance = raw_importance / raw_importance.max()

# Ascending order so the most important electrodes are drawn at the top of barh
importance_df = (
    pd.DataFrame({'Electrode': electrodes, 'Importance': electrode_importance})
    .sort_values('Importance', ascending=True)
)

fig, ax = plt.subplots(figsize=(10, 8))
bars = ax.barh(importance_df['Electrode'], importance_df['Importance'],
               color='coral', edgecolor='darkred', alpha=0.8)

# Because of the ascending sort, the last five bars are the top five
for top_bar in bars[-5:]:
    top_bar.set_color('crimson')
    top_bar.set_alpha(1.0)

ax.set_xlabel('Normalized Importance', fontsize=12)
ax.set_ylabel('Electrode', fontsize=12)
chart_title = (f'Electrode Importance for {class_names[predicted_class]} Classification\n'
               '(Top 5 highlighted in red)')
ax.set_title(chart_title, fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig('activation_maps/electrode_importance.png', dpi=300, bbox_inches='tight')
plt.show()

top_five = importance_df.tail(5)
print("\nTop 5 Most Important Electrodes:")
print(top_five.to_string(index=False))
print("\n✓ Electrode importance saved")

## 10. Feature Importance by Frequency Band

In [None]:
# Define frequency bands: (low Hz inclusive, high Hz exclusive, name)
bands = [
    (0.5, 4, 'Delta'),
    (4, 8, 'Theta'),
    (8, 13, 'Alpha'),
    (13, 30, 'Beta'),
    (30, 40, 'Gamma')
]
# Bar color per band, keyed by name so colors stay attached to the right band
band_colors = {'Delta': 'purple', 'Theta': 'blue', 'Alpha': 'green',
               'Beta': 'orange', 'Gamma': 'red'}

# Compute importance for each band. NOTE: this is a *sum* over frequency bins,
# so wider bands (e.g. Beta, 13-30 Hz) accumulate more bins than narrow ones
# (e.g. Delta); divide by mask.sum() if a per-bin average is wanted instead.
band_importance = []
for low, high, name in bands:
    mask = (freqs >= low) & (freqs < high)
    if mask.any():
        band_importance.append({'Band': name, 'Frequency': f'{low}-{high} Hz',
                                'Importance': saliency_map[mask, :].sum()})

if not band_importance:
    raise ValueError(
        f"No frequency bins fall inside the defined bands "
        f"(freqs span {freqs[0]:.2f}-{freqs[-1]:.2f} Hz)"
    )

band_df = pd.DataFrame(band_importance)
band_df['Importance'] = band_df['Importance'] / band_df['Importance'].max()

# One color per row actually present (bands without bins were skipped above)
colors = [band_colors[band] for band in band_df['Band']]

# Plot
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(band_df['Band'], band_df['Importance'], color=colors, alpha=0.7, edgecolor='black')

ax.set_ylabel('Normalized Importance', fontsize=12)
ax.set_xlabel('Frequency Band', fontsize=12)
ax.set_title(f'Frequency Band Importance for {class_names[predicted_class]} Classification', 
             fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{height:.2f}',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('activation_maps/band_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nFrequency Band Importance:")
print(band_df.to_string(index=False))
print("\n✓ Band importance saved")

## 11. Create Summary Visualization

In [None]:
# Comprehensive summary figure on a 3x3 grid: the saliency map fills the
# top-left 2x2 block, rankings sit on the right, and the bottom row holds the
# prediction text plus the two average conv activation maps.
fig = plt.figure(figsize=(20, 12))
grid = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Panel 1: saliency map (0-40 Hz crop computed earlier)
ax1 = fig.add_subplot(grid[0:2, 0:2])
im1 = ax1.imshow(saliency_plot, aspect='auto', cmap='hot', interpolation='bilinear')
ax1.set_xlabel('Electrodes', fontsize=10)
ax1.set_ylabel('Frequency (Hz)', fontsize=10)
ax1.set_title('Saliency Map', fontsize=12, fontweight='bold')
ax1.set_xticks(range(len(electrodes)))
ax1.set_xticklabels(electrodes, rotation=45, ha='right', fontsize=8)
summary_ticks = np.linspace(0, len(freqs_plot) - 1, 8, dtype=int)
ax1.set_yticks(summary_ticks)
ax1.set_yticklabels([f'{freqs_plot[i]:.1f}' for i in summary_ticks], fontsize=8)
plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)

# Panel 2: five most important electrodes
ax2 = fig.add_subplot(grid[0, 2])
top_5_electrodes = importance_df.tail(5)
ax2.barh(top_5_electrodes['Electrode'], top_5_electrodes['Importance'],
         color='crimson', alpha=0.8)
ax2.set_xlabel('Importance', fontsize=10)
ax2.set_title('Top 5 Electrodes', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='x')

# Panel 3: frequency-band importance
ax3 = fig.add_subplot(grid[1, 2])
ax3.bar(band_df['Band'], band_df['Importance'], color=colors, alpha=0.7, edgecolor='black')
ax3.set_ylabel('Importance', fontsize=10)
ax3.set_title('Frequency Bands', fontsize=12, fontweight='bold')
ax3.tick_params(axis='x', rotation=45)
ax3.grid(True, alpha=0.3, axis='y')

# Panel 4: prediction text box
ax4 = fig.add_subplot(grid[2, 0])
ax4.axis('off')
info_lines = [
    "PREDICTION SUMMARY",
    "",
    f"Sample: {sample_file.split('/')[-1]}",
    f"Predicted Class: {class_names[predicted_class]}",
    f"Confidence: {confidence:.2%}",
    "",
    "Class Probabilities:",
    f"  ADHD: {probabilities[0, 0].item():.2%}",
    f"  Control: {probabilities[0, 1].item():.2%}",
]
info_text = "\n".join(info_lines) + "\n"
ax4.text(0.1, 0.5, info_text, fontsize=11, verticalalignment='center',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
         family='monospace')

# Panels 5-6: average conv activations
for col, avg_map, layer in ((1, conv1_avg, 'Conv1'), (2, conv2_avg, 'Conv2')):
    panel = fig.add_subplot(grid[2, col])
    heat = panel.imshow(avg_map, aspect='auto', cmap='plasma')
    panel.set_title(f'{layer} Avg Activation', fontsize=12, fontweight='bold')
    panel.set_xlabel('Width', fontsize=10)
    panel.set_ylabel('Height', fontsize=10)
    plt.colorbar(heat, ax=panel, fraction=0.046, pad=0.04)

fig.suptitle('EEG Model Activation Analysis - Complete Summary', 
             fontsize=16, fontweight='bold', y=0.995)

plt.savefig('activation_maps/complete_summary.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Complete summary visualization saved to activation_maps/complete_summary.png")

## 12. Export Results to CSV

In [None]:
# Save numerical results (one row for the analyzed sample)
results = {
    'sample': sample_file.split('/')[-1],
    'predicted_class': class_names[predicted_class],
    'confidence': confidence,
    'adhd_probability': probabilities[0, 0].item(),
    'control_probability': probabilities[0, 1].item(),
}

# Top electrodes ranked 1 = most important. importance_df is sorted in
# ascending order, so its last five rows are reversed before ranking.
top5 = importance_df.tail(5).iloc[::-1]
for rank, (electrode, importance) in enumerate(zip(top5['Electrode'], top5['Importance']), start=1):
    results[f'top_electrode_{rank}'] = electrode
    results[f'top_electrode_{rank}_importance'] = importance

# Band importance (normalized so the strongest band is 1.0)
for band, importance in zip(band_df['Band'], band_df['Importance']):
    results[f'{band.lower()}_importance'] = importance

results_df = pd.DataFrame([results])
results_df.to_csv('activation_maps/activation_analysis_results.csv', index=False)

print("\n✓ Results exported to activation_maps/activation_analysis_results.csv")
print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)
print("\nGenerated files:")
print("  - saliency_map.png")
print("  - conv1_feature_maps.png")
print("  - conv2_feature_maps.png")
print("  - average_activations.png")
print("  - lstm_temporal_attention.png")
print("  - electrode_importance.png")
print("  - band_importance.png")
print("  - complete_summary.png")
print("  - activation_analysis_results.csv")

## 13. Batch Process Multiple Samples (Optional)

In [None]:
# Batch-process several recordings. Disabled by default: set RUN_BATCH = True
# to run it. The work lives in a function so it does not overwrite this
# notebook's single-sample globals (X_tensor, freqs, electrodes, ...).
RUN_BATCH = False


def run_batch_analysis(sample_dir='extracted_eegs', limit=5, max_plot_freq=40,
                       out_dir='activation_maps'):
    """Predict and save a saliency heat-map for the first ``limit`` CSV files.

    Uses notebook globals defined in earlier cells: ``model``, ``scaler``,
    ``device``, ``class_names``, ``eeg_to_frequency_features`` and
    ``compute_saliency_map``.

    Args:
        sample_dir: Directory scanned (sorted) for ``*.csv`` recordings.
        limit: Maximum number of files to process.
        max_plot_freq: Upper frequency (Hz) of the saved saliency images.
        out_dir: Output directory for images and ``batch_results.csv``.

    Returns:
        DataFrame with one row per processed file
        (sample, predicted_class, confidence).
    """
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    sample_files = sorted(Path(sample_dir).glob('*.csv'))[:limit]
    all_results = []

    for sample_path in sample_files:
        print(f"\nProcessing {sample_path.name}...")

        # Load and preprocess exactly like the single-sample cells
        df = pd.read_csv(sample_path)
        X, sample_freqs, _ = eeg_to_frequency_features(df)
        X = scaler.transform(X[..., np.newaxis])
        x = torch.tensor(X, dtype=torch.float32).permute(0, 3, 1, 2).to(device)

        # Predict
        with torch.no_grad():
            probs = F.softmax(model(x), dim=1)
        pred = torch.argmax(probs, dim=1).item()
        conf = probs[0, pred].item()

        saliency = compute_saliency_map(model, x, pred)

        all_results.append({
            'sample': sample_path.name,
            'predicted_class': class_names[pred],
            'confidence': conf,
        })

        # Crop rows using this sample's own frequency axis
        n_rows = int((sample_freqs <= max_plot_freq).sum())
        fig = plt.figure(figsize=(14, 8))
        plt.imshow(saliency[:n_rows, :], aspect='auto', cmap='hot')
        plt.title(f'{sample_path.stem} - {class_names[pred]} ({conf:.2%})')
        plt.colorbar(label='Gradient Magnitude')
        fig.savefig(out_path / f'{sample_path.stem}_saliency.png', dpi=150)
        plt.close(fig)

    batch_df = pd.DataFrame(all_results)
    batch_df.to_csv(out_path / 'batch_results.csv', index=False)
    print(f"\n✓ Processed {len(all_results)} samples")
    return batch_df


if RUN_BATCH:
    batch_df = run_batch_analysis()