In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Load and preprocess the EEG data
def load_eeg_data(file_path):
    data = loadmat(file_path)
    signals = data['Signal']
    flashing = data['Flashing']
    stimulus_type = data['StimulusType']
    target_chars = data['TargetChar']
    
    return signals, flashing, stimulus_type, target_chars

# Create sentence mapping - associate each trial with its correct character
def create_sentence_mapping(target_chars, trials):
    sentence_map = {}
    for i in range(trials):
        char = ''.join([c for c in target_chars[0][i] if c.strip()])
        sentence_map[i] = char
    
    # Group trials by original sentences
    sentences = defaultdict(list)
    current_sentence = []
    current_word = ""
    
    for i in range(trials):
        char = sentence_map[i]
        if char == '_':  # Space character
            if current_word:
                current_sentence.append(current_word)
                current_word = ""
        elif char in ['.', '!', '?']:  # End of sentence
            if current_word:
                current_sentence.append(current_word)
                current_word = ""
            if current_sentence:
                sentence_key = ' '.join(current_sentence)
                sentences[sentence_key].extend(list(range(i-len(sentence_key)+1, i+1)))
                current_sentence = []
        else:
            current_word += char
    
    # Handle last sentence if exists
    if current_word:
        current_sentence.append(current_word)
    if current_sentence:
        sentence_key = ' '.join(current_sentence)
        sentences[sentence_key].extend(list(range(trials-len(sentence_key)+1, trials+1)))
    
    return sentence_map, sentences

# Extract features from EEG signals for each character in the sentence context
def extract_sentence_features(signals, flashing, stimulus_type, sentence_trials, window_duration_ms=650, sampling_rate=120):
    window_samples = round(sampling_rate * (window_duration_ms / 1000))
    features = []
    labels = []
    sentence_indices = []
    char_indices = []
    
    for sentence_idx, trials in enumerate(sentence_trials.values()):
        for char_idx, trial in enumerate(trials):
            trial_features = []
            trial_labels = []
            
            for sample in range(len(flashing[trial])):
                if (sample == 0) or (flashing[trial, sample-1] == 0 and flashing[trial, sample] == 1):
                    # Extract window following flash onset
                    lower_sample = sample
                    upper_sample = min(sample + window_samples, len(flashing[trial]))
                    window = signals[trial, lower_sample:upper_sample, :]
                    
                    # Add to feature list
                    trial_features.append(window)
                    
                    # Add label (1 for P300, 0 for no P300)
                    if stimulus_type[trial, sample] == 1:
                        trial_labels.append(1)
                    else:
                        trial_labels.append(0)
            
            features.extend(trial_features)
            labels.extend(trial_labels)
            sentence_indices.extend([sentence_idx] * len(trial_features))
            char_indices.extend([char_idx] * len(trial_features))
    
    return np.array(features), np.array(labels), np.array(sentence_indices), np.array(char_indices)

# Create sequential batches for training
def create_sequential_batches(features, labels, sentence_indices, char_indices, seq_length=5, batch_size=32):
    """
    Create batches with sequential character information
    seq_length: number of characters to include in each sequence
    """
    unique_sentences = np.unique(sentence_indices)
    X_batches = []
    y_batches = []
    
    for sent_idx in unique_sentences:
        sent_mask = (sentence_indices == sent_idx)
        sent_features = features[sent_mask]
        sent_labels = labels[sent_mask]
        sent_char_indices = char_indices[sent_mask]
        
        # Group by character
        unique_chars = np.unique(sent_char_indices)
        char_data = []
        
        for char_idx in unique_chars:
            char_mask = (sent_char_indices == char_idx)
            char_features = sent_features[char_mask]
            char_labels = sent_labels[char_mask]
            char_data.append((char_features, char_labels))
        
        # Create sequences
        for i in range(len(char_data) - seq_length + 1):
            seq_features = []
            seq_labels = []
            
            for j in range(seq_length):
                seq_features.append(char_data[i+j][0])
                seq_labels.append(char_data[i+j][1])
            
            # Add to batches
            X_batches.append(seq_features)
            y_batches.append(seq_labels)
    
    # Convert to appropriate format and create mini-batches
    X_batches = np.array(X_batches)
    y_batches = np.array(y_batches)
    
    # Shuffle and create mini-batches
    indices = np.arange(len(X_batches))
    np.random.shuffle(indices)
    
    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i:i+batch_size]
        yield X_batches[batch_indices], y_batches[batch_indices]

# Define a sequence-based ECD model
class SequentialECDModel(nn.Module):
    def __init__(self, input_channels=64, seq_length=5, hidden_dim=128):
        super(SequentialECDModel, self).__init__()
        
        # CNN layers for spatial feature extraction
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, input_channels), padding=(0, 0))
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(5, 1), padding=(2, 0))
        
        # Calculate flattened feature size
        self.feature_size = self._calculate_conv_output_size()
        
        # LSTM for sequence modeling
        self.lstm = nn.LSTM(
            input_size=self.feature_size,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3
        )
        
        # Output layers
        self.fc1 = nn.Linear(hidden_dim, 64)
        self.fc2 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()
        
    def _calculate_conv_output_size(self):
        # This function calculates the output size after convolution
        # Placeholder value - should be calculated based on actual dimensions
        return 2048
        
    def forward(self, x):
        # x shape: [batch_size, seq_length, time_points, channels]
        batch_size, seq_length = x.shape[0], x.shape[1]
        
        # Process each sequence element independently
        sequence_outputs = []
        
        for i in range(seq_length):
            # Get current sequence element
            x_i = x[:, i]  # [batch_size, time_points, channels]
            
            # Reshape for CNN
            x_i = x_i.unsqueeze(1)  # Add channel dimension [batch_size, 1, time_points, channels]
            
            # Apply CNNs
            x_i = F.relu(self.conv1(x_i))
            x_i = F.max_pool2d(x_i, kernel_size=(2, 1))
            x_i = F.relu(self.conv2(x_i))
            x_i = F.max_pool2d(x_i, kernel_size=(2, 1))
            
            # Flatten
            x_i = x_i.view(batch_size, -1)
            sequence_outputs.append(x_i)
        
        # Stack sequence outputs
        sequence = torch.stack(sequence_outputs, dim=1)  # [batch_size, seq_length, feature_size]
        
        # Apply LSTM
        lstm_out, _ = self.lstm(sequence)
        
        # Use final output for classification
        x = F.relu(self.fc1(lstm_out[:, -1, :]))
        x = self.fc2(x)
        x = self.sigmoid(x)
        
        return x.squeeze(-1)

# Main function to create sentence-based EEG dataset and train model
def create_sentence_eeg_dataset(train_file_path, test_file_path=None):
    # Load training data
    signals_train, flashing_train, stimulus_train, word_train = load_eeg_data(train_file_path)
    trials_train = len(word_train[0])
    
    # Create sentence mapping
    sentence_map, sentences = create_sentence_mapping(word_train, trials_train)
    
    # Extract features with sentence context
    features, labels, sentence_indices, char_indices = extract_sentence_features(
        signals_train, flashing_train, stimulus_train, sentences
    )
    
    # Normalize features
    for i in range(len(features)):
        features[i] = (features[i] - np.mean(features[i], axis=0)) / (np.std(features[i], axis=0) + 1e-8)
    
    # Train-validation split
    X_train, X_val, y_train, y_val, sent_idx_train, sent_idx_val, char_idx_train, char_idx_val = train_test_split(
        features, labels, sentence_indices, char_indices, test_size=0.2, random_state=42
    )
    
    print(f"Features shape: {features.shape}")
    print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")
    print(f"Number of sentences: {len(sentences)}")
    
    return {
        'X_train': X_train,
        'y_train': y_train,
        'sent_idx_train': sent_idx_train,
        'char_idx_train': char_idx_train,
        'X_val': X_val,
        'y_val': y_val,
        'sent_idx_val': sent_idx_val,
        'char_idx_val': char_idx_val,
        'sentences': sentences,
        'sentence_map': sentence_map
    }

# Training loop
def train_sequential_ecd(dataset, epochs=10, batch_size=32, seq_length=5, learning_rate=0.001):
    # Initialize model
    model = SequentialECDModel(input_channels=64, seq_length=seq_length)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    
    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        batch_count = 0
        
        # Create sequential batches
        batch_generator = create_sequential_batches(
            dataset['X_train'], dataset['y_train'], 
            dataset['sent_idx_train'], dataset['char_idx_train'],
            seq_length=seq_length, batch_size=batch_size
        )
        
        for X_batch, y_batch in batch_generator:
            # Convert to PyTorch tensors
            X_batch = torch.tensor(X_batch, dtype=torch.float32)
            y_batch = torch.tensor(y_batch, dtype=torch.float32)
            
            # Forward pass
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            batch_count += 1
        
        # Print epoch results
        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_batch_generator = create_sequential_batches(
                dataset['X_val'], dataset['y_val'],
                dataset['sent_idx_val'], dataset['char_idx_val'],
                seq_length=seq_length, batch_size=batch_size
            )
            
            val_loss = 0
            val_count = 0
            for X_val, y_val in val_batch_generator:
                X_val = torch.tensor(X_val, dtype=torch.float32)
                y_val = torch.tensor(y_val, dtype=torch.float32)
                
                outputs = model(X_val)
                val_loss += criterion(outputs, y_val).item()
                val_count += 1
            
            avg_val_loss = val_loss / val_count
            print(f"Validation Loss: {avg_val_loss:.4f}")
    
    return model


# Set paths
train_file_path = '../data/Contributor_I_Train.mat'

# Create dataset
dataset = create_sentence_eeg_dataset(train_file_path)

print(dataset)
# Train model
# model = train_sequential_ecd(dataset, epochs=10, batch_size=16, seq_length=5)

# Save model
# torch.save(model.state_dict(), "../model/sequential_ecd_model.pth")
print("Model saved successfully.")

IndexError: index 85 is out of bounds for axis 0 with size 85