### SingLEM: Feature Extraction Demo

This notebook provides a hands-on guide to using the pretrained SingLEM model for feature extraction. We will cover two main examples:
1. A simple dummy data example: How to load the model and pass a random tensor through it to understand the input and output shapes.
2. A real EEG data example: A more realistic workflow showing how to load a raw EEG file, preprocess it, and extract features for all channels and trials.

### 1. Setup and Model Loading
First, we'll set up the necessary paths, import the model architecture, and load the pretrained weights. The model will be automatically moved to a GPU if one is available.

In [10]:
import os
import sys
import mne
import torch

# --- Setup Paths ---
# This assumes the notebook is in the 'examples' directory
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
# Add the 'SingLEM' package directory to the Python path
singlem_path = os.path.join(project_root, 'SingLEM')
sys.path.append(singlem_path)

# --- Import Model and Define Weights Path ---
from model import EEGEncoder, Config
weights_path = os.path.join(project_root, 'weights', 'singlem_pretrained.pt')

# --- Initialize Model and Load Weights ---
print("Loading SingLEM model...")
config = Config()
config.mask_prob = 0.0  # Set mask probability to 0 for feature extraction
device = 'cuda' if torch.cuda.is_available() else 'cpu'
feature_extractor = EEGEncoder(config)

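# weights_only=True restricts unpickling to tensors and plain containers (a state dict)
encoder_state = torch.load(weights_path, map_location=device, weights_only=True)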
feature_extractor.load_state_dict(encoder_state)
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()    # Set the model to evaluation mode
print(f"Model loaded successfully on '{device}'.")

Loading SingLEM model...
Model loaded successfully on 'cuda'.




### 2. Dummy Data Example
Let's start by creating a random tensor to represent a batch of single-channel EEG sequences and passing it through the model. This is the quickest way to verify that the model is working and to understand its expected input and output shapes.

The expected input shape for SingLEM is (batch_size, num_tokens, samples_per_token).
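
As a minimal sketch of how a continuous single-channel recording maps onto this shape, the NumPy snippet below cuts a 10-second signal sampled at 128 Hz into one-second tokens. The helper name `to_token_sequence` and the non-overlapping split are assumptions made for illustration; they are not part of the SingLEM code base.

```python
import numpy as np

def to_token_sequence(signal, samples_per_token=128):
    """Cut a 1-D signal into non-overlapping tokens (hypothetical helper).

    Returns an array of shape (1, num_tokens, samples_per_token). Trailing
    samples that do not fill a whole token are dropped.
    """
    signal = np.asarray(signal, dtype=np.float32)
    num_tokens = signal.shape[0] // samples_per_token
    trimmed = signal[: num_tokens * samples_per_token]
    return trimmed.reshape(1, num_tokens, samples_per_token)

# 10 seconds of synthetic single-channel data at 128 Hz, plus 5 leftover samples
rng = np.random.default_rng(0)
signal = rng.standard_normal(10 * 128 + 5)
sequence = to_token_sequence(signal)
print(sequence.shape)  # (1, 10, 128)
```

The resulting array has the same layout as the dummy tensor used in the next cell and could be converted with `torch.from_numpy` before being passed to the model.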

In [None]:
# Create a dummy tensor: a batch of 1 sequence made of 10 one-second tokens (128 samples each)
# Shape: (batch_size, num_tokens, samples_per_token)
dummy_eeg_sequence = torch.randn(1, 10, 128, device=device)

# Extract features:
with torch.no_grad():
    features, _, _ = feature_extractor(dummy_eeg_sequence)

print(f"Input shape: {dummy_eeg_sequence.shape}")
print(f"Output feature shape: {features.shape}")

# The output shape is (batch_size, num_tokens, feature_dimension)

Input shape: torch.Size([1, 10, 128])
Output feature shape: torch.Size([1, 10, 16])


### 3. Real EEG Data Example
Now, let's walk through a more realistic example using a sample EEG file. We'll use the MNE-Python library, a standard tool for EEG analysis, to load and preprocess the data.

Our goal is to take a multi-channel EEG recording and extract SingLEM features for every channel independently.

Note: You will need to have an EEG file (e.g., EEG.gdf) in the same directory as this notebook for this example to run.
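
The channel-independent idea can be checked on synthetic data before touching a real file. The sketch below folds the channel dimension into the batch dimension, applies a stand-in for the encoder, and unfolds the result again. The array sizes and the `fake_features` stand-in are assumptions chosen for illustration; the stand-in is not the SingLEM model.

```python
import numpy as np

num_tokens, num_channels, samples_per_token, feature_dim = 4, 3, 128, 16

# Synthetic stand-in for tokenized EEG: every sample stores its channel index
data = np.zeros((num_tokens, num_channels, samples_per_token), dtype=np.float32)
for ch in range(num_channels):
    data[:, ch, :] = ch

# Fold channels into the batch dimension: (num_channels * num_tokens, 1, samples)
batched = data.transpose(1, 0, 2).reshape(-1, 1, samples_per_token)

# Stand-in for the encoder (hypothetical): one feature vector per token
fake_features = np.repeat(batched.mean(axis=-1, keepdims=True), feature_dim, axis=-1)

# Unfold back to (num_channels, num_tokens, feature_dim)
per_channel = fake_features.reshape(num_channels, num_tokens, feature_dim)
print(per_channel.shape)  # (3, 4, 16)
```

Because the transpose happens before the reshape, all tokens of one channel stay contiguous in the batch, so `per_channel[ch]` only ever contains values derived from channel `ch`.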

In [None]:
# --- 1. Load and Preprocess EEG Data ---
# This example uses a GDF file, but MNE supports many formats (EDF, BDF, BrainVision, etc.)
try:
    raw = mne.io.read_raw_gdf('EEG.gdf', preload=True, verbose=False)
except FileNotFoundError:
    print("Sample file 'EEG.gdf' not found. Skipping real data example.")
    raw = None
if raw is not None:
    print("Original raw data info:")
    print(raw.info)

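    # Apply minimal preprocessing
    # Note: These steps should match the preprocessing used during SingLEM's pretraining and be adapted to your own data
    raw.drop_channels(ch_names=['EOG1', 'EOG2', 'EOG3', 'EMGg', 'EMGd'])  # Drop EOG and EMG channels
    raw.notch_filter(50, verbose=False)   # Notch filter at 50 Hz (power-line noise)
    raw.filter(0.5, 50, verbose=False)    # Band-pass filter, 0.5-50 Hz
    raw.resample(128, verbose=False)      # Resample to 128 Hz (128 samples per one-second token)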

    print("\nPreprocessed raw data info:")
    print(raw.info)

    # --- 2. Tokenize the Data ---
    # The simplest way to tokenize the data is MNE's make_fixed_length_epochs function.
    # We create 1-second epochs (tokens). The overlap is given in seconds, so 0.25 means a 25% overlap (stride of 0.75 s)
    epochs = mne.make_fixed_length_epochs(raw, duration=1.0, overlap=0.25, verbose=False)

    # Get data as a NumPy array: (num_tokens, num_channels, samples_per_token)
    data = epochs.get_data()
    print(f"\nData tokenized into shape: {data.shape}")

    # --- 3. Reshape for SingLEM ---
    # SingLEM processes each channel independently. We need to reshape the data
    # so that the channel and token/trial dimensions are combined into the batch dimension.
    
    # First, swap axes to (num_channels, num_tokens, samples_per_token)
    data_transposed = data.transpose(1, 0, 2)
    
    # We will process all tokens from all channels in one large batch.
    # New shape: (num_channels * num_tokens, 1, samples_per_token)
    # The '1' represents a sequence length of a single token for this simple case.
    # For longer sequences, you would group tokens first.
    batched_input = data_transposed.reshape(-1, 1, data.shape[-1])
    
    print(f"Data reshaped for model input: {batched_input.shape}")

    # --- 4. Extract Features ---
    # Convert to a PyTorch tensor and pass through the model
    input_tensor = torch.tensor(batched_input, dtype=torch.float32, device=device)
    
    with torch.no_grad():
        features, _, _ = feature_extractor(input_tensor)
        
    print(f"\nExtracted features shape: {features.shape}")
    
    # --- 5. Reshape Features Back ---
    # The output can be reshaped back to separate the channel and token dimensions
    # Shape: (num_channels, num_tokens, feature_dimension)
    num_channels = data.shape[1]
    features_reshaped = features.reshape(num_channels, -1, features.shape[-1])
    print(f"Features reshaped to (channels, tokens, feature_dim): {features_reshaped.shape}")

Original raw data info:
<Info | 8 non-empty values
 bads: []
 ch_names: Fz, FCz, Cz, CPz, Pz, C1, C3, C5, C2, C4, C6, EOG1, EOG2, EOG3, ...
 chs: 32 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 256.0 Hz
 meas_date: unspecified
 nchan: 32
 projs: []
 sfreq: 512.0 Hz
 subject_info: <subject_info | his_id: X, last_name: >
>

Preprocessed raw data info:
<Info | 8 non-empty values
 bads: []
 ch_names: Fz, FCz, Cz, CPz, Pz, C1, C3, C5, C2, C4, C6, F4, FC2, FC4, FC6, ...
 chs: 27 EEG
 custom_ref_applied: False
 highpass: 0.5 Hz
 lowpass: 50.0 Hz
 meas_date: unspecified
 nchan: 27
 projs: []
 sfreq: 128.0 Hz
 subject_info: <subject_info | his_id: X, last_name: >
>
Using data from preloaded Raw for 249 events and 128 original time points ...
0 bad epochs dropped

Data tokenized into shape: (249, 27, 128)
Data reshaped for model input: (6723, 1, 128)

Extracted features shape: torch.Size([6723, 1, 16])
Features reshaped to (channels, tokens, feature_dim): torch.Size([27, 249, 16])
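
The cell above uses a sequence length of one token and notes that longer sequences require grouping tokens first. A minimal sketch of that grouping, using NumPy only, is shown below. The helper `group_tokens` and the choice of `seq_len=10` (taken from the dummy example) are assumptions; which sequence lengths the pretrained model accepts is not confirmed here.

```python
import numpy as np

def group_tokens(data, seq_len):
    """Group consecutive tokens of each channel into sequences (hypothetical helper).

    data: array of shape (num_tokens, num_channels, samples_per_token)
    returns: array of shape (num_channels * num_sequences, seq_len, samples_per_token)
    """
    num_tokens, num_channels, samples_per_token = data.shape
    num_sequences = num_tokens // seq_len
    # Drop trailing tokens that do not fill a complete sequence
    trimmed = data[: num_sequences * seq_len]
    per_channel = trimmed.transpose(1, 0, 2)  # (channels, tokens, samples)
    return per_channel.reshape(num_channels * num_sequences, seq_len, samples_per_token)

# Same shape as the tokenized data above: 249 tokens, 27 channels, 128 samples
data = np.random.default_rng(0).standard_normal((249, 27, 128)).astype(np.float32)
sequences = group_tokens(data, seq_len=10)
print(sequences.shape)  # (648, 10, 128)
```

With 249 tokens and `seq_len=10`, each channel yields 24 full sequences and the last 9 tokens are discarded. The output features can then be reshaped to `(num_channels, num_sequences, seq_len, feature_dim)`. Keep in mind that tokens created with an overlap share samples with their neighbours.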
