# SWR CNN Autoencoder & Classification Pipeline

This notebook runs the complete SWR classification pipeline by importing and executing the main functions from the scripts in the `../cnn_autoencoder` directory.

**Workflow:**
1.  **Setup Paths:** Configure paths and set the main recording directory for inputs/outputs.
2.  **Step 1:** Generate spectrograms and extract biological features.
3.  **Step 2:** Train the CNN autoencoder (architecture selected by `ARCH`: `resnet`, `vae`, or `attention`) on the spectrograms.
4.  **Step 3:** Cluster events using a combination of autoencoder and biological features.
5.  **Step 4:** Evaluate and visualize the final clusters.

In [None]:
import os
import sys

# --- 1. DEFINE PATHS & PARAMETERS ---

# !! IMPORTANT: This is the user-specified path for all outputs (and inputs).
# The working directory is switched to it below.
recording_path = r"F:\Spikeinterface_practice\s4_rec"

# This is the root of the project (pfr_neurophys_data_analysis).
# Assumes this notebook is started from pfr_neurophys_data_analysis/notebooks/.
# Resolve it only once per kernel session: the chdir below changes
# os.getcwd(), so recomputing it when this cell is re-run would yield the
# parent of the recording directory instead of the project root.
if 'project_root' not in globals():
    project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Add the project root to the Python path to allow imports from cnn_autoencoder.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Change the current working directory to the recording path so that files
# written with relative paths (models, plots, .pkl) end up there.
try:
    os.chdir(recording_path)
    print(f"Changed working directory to: {os.getcwd()}")
except FileNotFoundError:
    # Best effort: report the problem and leave the working directory as it
    # was, so the user can correct `recording_path` and re-run this cell.
    print(f"ERROR: Recording path not found: {recording_path}")
    print("Please update the 'recording_path' variable in this cell.")

print(f"Project root added to sys.path: {project_root}")

# --- 2. DEFINE PIPELINE PARAMETERS ---
# These can be adjusted as needed; they are passed to the training and
# clustering steps further down.
ARCH = 'resnet'      # 'resnet', 'vae', or 'attention'
LATENT_DIM = 128     # Latent dimension for the autoencoder
EPOCHS = 15          # Number of epochs for training

In [None]:
# Import your existing modules
from open_ephys_loader import fast_openephys_dat_lfp
from spike_analysis import SpikeAnalysis, loader, process_spike_data, load_processed_spike_data
from swr_detection.swr_hmm_detection import SWRHMMParams, SWRHMMDetector
from swr_detection.pipeline import find_region_channels, build_region_lfp
from swr_detection.swr_spectral_features import batch_compute_spectral_features

# Import new feature extraction module
from feature_extraction import batch_extract_features, validate_biological_features
"""
Performs SWR detection, computes spectrograms, AND extracts biological features.
This creates a richer dataset for clustering.
"""
print("="*80)
print("IMPROVED SWR DETECTION WITH COMPREHENSIVE FEATURE EXTRACTION")
print("="*80)

# --- Configuration ---
# NOTE(review): the data paths in this cell are on drive D:, while the setup
# cell's `recording_path` is on drive F: -- confirm which drive holds the data.
dat_path = r"D:\Spikeinterface_practice\s4_rec\ephys.dat"
num_channels = 43
# Label -> channel number handed to the LFP loader.
# Presumably the label prefix (CA1/RTC/PFC) is what find_region_channels
# groups on -- TODO confirm.
selected_channels = {
    'CA1_tet1': 17, 'CA1_tet2': 21, 'RTC_tet1': 14, 'PFC_tet1': 0, 'PFC_tet2': 5
}
fs_in = 30000.0    # passed to the loader as `sampling_frequency`
fs_out = 1000.0    # passed to the loader as `target_sampling_frequency`
output_dir = "all_spectrograms"  # not used anywhere in this cell

# --- Load LFP and Spike Data ---
print("\n--- Loading Data ---")
try:
    # NOTE: this rebinds `loader`, shadowing the `loader` name imported from
    # spike_analysis at the top of this cell. The name is kept because the
    # rest of the cell refers to it.
    loader = fast_openephys_dat_lfp(
        filepath=dat_path,
        num_channels=num_channels,
        tetrode_groups={},
        selected_channels=selected_channels,
        sampling_frequency=fs_in,
        target_sampling_frequency=fs_out,
        return_mode="loader",
    )
    fs = float(loader.sampling_frequency)
    t_lfp = loader.time_vector()
    print(f"✓ LFP duration: {loader.duration:.2f}s at {fs:.1f} Hz")

    # Load spike data. 'units.npy' is produced once by process_spike_data and
    # reused on later runs.
    npy_path = r'D:\Spikeinterface_practice\s4_rec\phyMS5'
    save_path = r'D:\Spikeinterface_practice\s4_rec'
    units_file = os.path.join(save_path, 'units.npy')
    if not os.path.exists(units_file):
        print("Processing spike data...")
        process_spike_data(npy_path, save_path, samp_freq=30000)

    processed_spike_data = load_processed_spike_data(units_file)

    spike_analysis = SpikeAnalysis(
        processed_data=processed_spike_data,
        sampling_rate=30000,
        duration=loader.duration
    )

    # Presumably maps a tetrode/group id to its brain region -- TODO confirm
    # against SpikeAnalysis.assign_brain_regions.
    region_mapping = {7: 'CA1', 8: 'CA1', 6: 'RTC', 2: 'PFC', 3: 'PFC'}
    spike_analysis.assign_brain_regions(region_mapping)

    # Only the CA1 entries of the per-region results are used below.
    mua_by_region = spike_analysis.compute_mua_all_regions(t_lfp=t_lfp, kernel_width=0.01)
    mua_vec = mua_by_region['CA1']

    region_channels = find_region_channels(list(loader.selected_channels.keys()))
    region_lfp = build_region_lfp(loader, region_channels)
    lfp_array = region_lfp['CA1']

    print(f"✓ LFP shape: {lfp_array.shape}")
    print(f"✓ MUA vector length: {len(mua_vec)}")

except FileNotFoundError as e:
    print(f"\nERROR: Data files not found: {e}")
    print(f"Attempted to load LFP from: {dat_path}")
    print("Cannot proceed without data. Exiting.")
    # A bare `return` is a SyntaxError outside a function and prevented this
    # whole cell from running. Re-raise instead so execution stops here and
    # the original traceback is kept.
    raise

# --- SWR Detection ---
print("\n--- Detecting SWRs ---")

# Ripple threshold multiplier, kept under its own name so it is easy to tune.
ripple_th = 2.75

# Detector settings, grouped by the part of the parameter set they belong to.
# All groups are expanded into one SWRHMMParams call below.
_ripple_settings = dict(
    ripple_band=(125, 250),
    threshold_multiplier=ripple_th,
    use_smoothing=True,
    smoothing_sigma=0.01,
    normalization_method='zscore',
)
_event_settings = dict(
    min_duration=0.025,
    max_duration=0.4,
    min_event_separation=0.07,
    merge_interval=0.07,
    trace_window=1.0,
)
_classification_settings = dict(
    adaptive_classification=True,
    dbscan_eps=0.15,
)
_mua_settings = dict(
    mua_threshold_multiplier=2.5,
    mua_min_duration=0.03,
    enable_mua=True,
)
_hmm_settings = dict(
    use_hmm_edge_detection=False,
    hmm_margin=0.1,
    use_global_hmm=False,
    global_hmm_fraction=0.1,
    hmm_states_ripple=2,
    hmm_states_mua=2,
)
_hysteresis_settings = dict(
    use_hysteresis=True,
    hysteresis_low_multiplier=0.75,
    hysteresis_confirmation_window=0.07,
)

params = SWRHMMParams(
    **_ripple_settings,
    **_event_settings,
    **_classification_settings,
    **_mua_settings,
    **_hmm_settings,
    **_hysteresis_settings,
)

# The detector is built from the CA1 LFP and CA1 MUA prepared by the loading
# section of this cell.
detector = SWRHMMDetector(lfp_data=lfp_array, fs=fs, mua_data=mua_vec, params=params)

# Detect on channel index 0 without averaging, then classify the events.
detector.detect_events(channels=[0], average_mode=False)
detector.classify_events_improved()

print(f"✓ Found {len(detector.swr_events)} events")

# --- Compute Spectrograms ---
print("\n--- Computing Spectrograms ---")

# Tag every detected event with the spectrogram method before the batch run.
for event in detector.swr_events:
    event['spec_method'] = 'cwt'

# First row of the CA1 LFP is the signal handed to the spectral computation.
lfp_channel = region_lfp['CA1'][0]

# Keyword options for batch_compute_spectral_features, kept together so they
# can be reviewed and tuned in one place.
_spectral_options = {
    'use_optimized_cwt': True,
    'n_workers': 20,
    'verbose': True,
    'target_freq_bins': 150,
    'n_bins': 100,
    'smoothing_sigma': 1.0,
    'pre_ms': 250,
    'post_ms': 250,
}
n_computed = batch_compute_spectral_features(detector, lfp_channel, fs, **_spectral_options)

print(f"✓ Successfully computed {n_computed} spectrograms")

## Step 1: Generate Spectrograms and Extract Features

This step runs `generate_spectrograms.py`.
- Detects SWR events from the raw data.
- Generates spectrogram images for each event.
- Extracts biological features (duration, frequency, power, etc.).

**Output files (saved to recording_path):**
- `all_spectrograms/` (directory)
- `detected_events.pkl`
- `biological_features.pkl`

In [None]:
print("\n" + "="*80)
print("STEP 1: GENERATING SPECTROGRAMS AND FEATURES...")
print("="*80)

# Local import: only this cell's error reporting needs it.
import traceback

# The import and the run are handled separately so that an ImportError raised
# while the step is running is not misreported as a missing script.
try:
    # Imports from pfr_neurophys_data_analysis/cnn_autoencoder/generate_spectrograms.py
    from cnn_autoencoder.generate_spectrograms import generate_spectrograms_and_features
except ImportError as import_err:
    print("ERROR: Could not import 'generate_spectrograms_and_features'.")
    print(f"Ensure 'generate_spectrograms.py' is in {os.path.join(project_root, 'cnn_autoencoder')}")
    # The script's own dependencies can raise ImportError too; show the real cause.
    print(f"Underlying import error: {import_err}")
else:
    try:
        generate_spectrograms_and_features()
    except Exception as e:
        # Best effort: report the failure and keep the notebook session alive,
        # but include the traceback so the failure can be located.
        print(f"An error occurred during Step 1: {e}")
        traceback.print_exc()
    else:
        print("\n" + "-"*80)
        print("STEP 1 COMPLETE")
        print("-"*80)

## Step 2: Train Autoencoder

This step runs `train_autoencoder.py`.
- Loads the spectrograms from `all_spectrograms/`.
- Trains the CNN autoencoder selected by `ARCH` (`resnet`, `vae`, or `attention`).

**Output files (saved to recording_path):**
- `full_model_resnet.pkl` (or similar, based on arch)
- `encoder_model_resnet.pkl`
- `training_history_resnet.png`

In [None]:
print("\n" + "="*80)
print("STEP 2: TRAINING AUTOENCODER...")
print("="*80)

# Local import: only this cell's error reporting needs it.
import traceback

# The import and the run are handled separately so that an ImportError raised
# while the step is running is not misreported as a missing script.
try:
    # Imports from pfr_neurophys_data_analysis/cnn_autoencoder/train_autoencoder.py
    from cnn_autoencoder.train_autoencoder import train_autoencoder_improved
except ImportError as import_err:
    print("ERROR: Could not import 'train_autoencoder_improved'.")
    print(f"Ensure 'train_autoencoder.py' is in {os.path.join(project_root, 'cnn_autoencoder')}")
    # The script's own dependencies can raise ImportError too; show the real cause.
    print(f"Underlying import error: {import_err}")
else:
    try:
        train_autoencoder_improved(
            arch=ARCH,
            latent_dim=LATENT_DIM,
            epochs=EPOCHS,
            lr=None,      # Use 'None' to trigger auto learning rate finder
            beta=1.0      # Beta parameter for VAE (ignored for ResNet)
        )
    except Exception as e:
        # Best effort: report the failure and keep the notebook session alive,
        # but include the traceback so the failure can be located.
        print(f"An error occurred during Step 2: {e}")
        traceback.print_exc()
    else:
        print("\n" + "-"*80)
        print("STEP 2 COMPLETE")
        print("-"*80)

## Step 3: Cluster Events

This step runs `cluster_events.py`.
- Loads the trained encoder (`encoder_model_...pkl`).
- Loads the biological features (`biological_features.pkl`).
- Generates latent features from the autoencoder.
- Combines features and performs clustering to find optimal k.

**Output files (saved to recording_path):**
- `events_with_clusters_combined.pkl`
- `clustering_info_combined.pkl`
- `clustering_metrics_plot.png`
- `dendrogram.png`

In [None]:
print("\n" + "="*80)
print("STEP 3: CLUSTERING EVENTS...")
print("="*80)

# Local import: only this cell's error reporting needs it.
import traceback

# The import and the run are handled separately so that an ImportError raised
# while the step is running is not misreported as a missing script.
try:
    # Imports from pfr_neurophys_data_analysis/cnn_autoencoder/cluster_events.py
    from cnn_autoencoder.cluster_events import cluster_events_improved
except ImportError as import_err:
    print("ERROR: Could not import 'cluster_events_improved'.")
    print(f"Ensure 'cluster_events.py' is in {os.path.join(project_root, 'cnn_autoencoder')}")
    # The script's own dependencies can raise ImportError too; show the real cause.
    print(f"Underlying import error: {import_err}")
else:
    try:
        cluster_events_improved(
            arch=ARCH,
            latent_dim=LATENT_DIM,
            k_range=range(2, 13),  # Test k from 2 to 12
            ae_weight=0.7,
            bio_weight=0.3,
            use_combined_features=True
        )
    except Exception as e:
        # Best effort: report the failure and keep the notebook session alive,
        # but include the traceback so the failure can be located.
        print(f"An error occurred during Step 3: {e}")
        traceback.print_exc()
    else:
        print("\n" + "-"*80)
        print("STEP 3 COMPLETE")
        print("-"*80)

## Step 4: Evaluate and Validate Clusters

This step runs `evaluate_clusters.py`.
- Loads the clustered events (`events_with_clusters_combined.pkl`).
- Performs detailed biological validation.
- Generates summary plots and visualizations.

**Output files (saved to recording_path):**
- `feature_space_combined.png`
- `cluster_validation_report_combined.txt`
- `cluster_..._summary.png` (and other plots)

In [None]:
print("\n" + "="*80)
print("STEP 4: EVALUATING CLUSTERS...")
print("="*80)

# Local import: only this cell's error reporting needs it.
import traceback

# The import and the run are handled separately so that an ImportError raised
# while the step is running is not misreported as a missing script.
try:
    # Imports from pfr_neurophys_data_analysis/cnn_autoencoder/evaluate_clusters.py
    from cnn_autoencoder.evaluate_clusters import evaluate_clusters_improved
except ImportError as import_err:
    print("ERROR: Could not import 'evaluate_clusters_improved'.")
    print(f"Ensure 'evaluate_clusters.py' is in {os.path.join(project_root, 'cnn_autoencoder')}")
    # The script's own dependencies can raise ImportError too; show the real cause.
    print(f"Underlying import error: {import_err}")
else:
    try:
        # Use 'combined' since we clustered with combined features in Step 3
        evaluate_clusters_improved(feature_type='combined')
    except Exception as e:
        # Best effort: report the failure and keep the notebook session alive,
        # but include the traceback so the failure can be located.
        print(f"An error occurred during Step 4: {e}")
        traceback.print_exc()
    else:
        print("\n" + "="*80)
        print("PIPELINE FINISHED!")
        print(f"All outputs saved to: {os.getcwd()}")
        print("="*80)