In [5]:
# Sentence Context Integration for EEG-based BCI

import numpy as np
import matplotlib.pyplot as plt
import warnings
from sklearn.preprocessing import scale
from scipy.io import loadmat
from scipy import signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict

# Ignore warnings
warnings.filterwarnings('ignore')

def load_and_preprocess_data(file_path, apply_filtering=True):
    """
    Load and preprocess EEG data
    """
    # Load data
    data = loadmat(file_path)
    signals = data['Signal']
    flashing = data['Flashing']
    stimulus_type = data['StimulusType']
    target_chars = data['TargetChar']
    
    # Get data dimensions
    trials = len(target_chars[0])
    sampling_frequency = 240
    
    if apply_filtering:
        # Apply 4th-order Butterworth bandpass filter (0.1-20 Hz)
        b, a = signal.butter(4, [0.1 / sampling_frequency, 20 / sampling_frequency], 'bandpass')
        for trial in range(trials):
            signals[trial, :, :] = signal.filtfilt(b, a, signals[trial, :, :], axis=0)
    
    # Downsample from 240Hz to 120Hz
    down_sampling_frequency = 120
    scale_factor = round(sampling_frequency / down_sampling_frequency)
    
    signals = signals[:, 0:-1:scale_factor, :]
    flashing = flashing[:, 0:-1:scale_factor]
    stimulus_type = stimulus_type[:, 0:-1:scale_factor]
    
    return signals, flashing, stimulus_type, target_chars, trials, down_sampling_frequency

def extract_sentences_from_char_sequence(target_chars):
    """
    Extract sentences from character sequence
    """
    chars_sequence = []
    for i in range(len(target_chars[0])):
        char = ''.join([c for c in target_chars[0][i] if c.strip()])
        chars_sequence.append(char)
    
    # Convert to string for easier visualization
    full_text = ''.join(chars_sequence)
    
    # Split into sentences based on special characters or patterns
    sentences = []
    current_sentence = ""
    
    for char in chars_sequence:
        if char in ['_']:  # Space character
            current_sentence += " "
        elif char in ['.', '!', '?']:  # End of sentence markers
            current_sentence += char
            sentences.append(current_sentence.strip())
            current_sentence = ""
        else:
            current_sentence += char
    
    # Add last sentence if not empty
    if current_sentence.strip():
        sentences.append(current_sentence.strip())
    
    return chars_sequence, sentences

def create_char_to_sentence_mapping(chars_sequence, sentences):
    """
    Create mapping from character index to sentence index and position
    """
    char_to_sentence = {}
    char_idx = 0
    
    for sent_idx, sentence in enumerate(sentences):
        # Remove spaces for matching with original char sequence
        clean_sentence = sentence.replace(" ", "")
        
        for pos_in_sent, _ in enumerate(clean_sentence):
            if char_idx < len(chars_sequence):
                char_to_sentence[char_idx] = {
                    'sentence_idx': sent_idx,
                    'position_in_sentence': pos_in_sent,
                    'character': chars_sequence[char_idx],
                    'full_sentence': sentence
                }
                char_idx += 1
    
    return char_to_sentence

def extract_trial_features(signals, flashing, stimulus_type, trial_idx, 
                          window_duration=650, sampling_rate=120):
    """
    Extract features and labels for a single trial
    """
    # Calculate window samples
    window_samples = round(sampling_rate * (window_duration / 1000))
    samples_per_trial = len(flashing[trial_idx])
    
    trial_features = []
    trial_labels = []
    trial_stimulus_codes = []
    flash_indices = []
    
    for sample in range(samples_per_trial):
        if (sample == 0) or (flashing[trial_idx, sample - 1] == 0 and flashing[trial_idx, sample] == 1):
            # This is a flash onset
            lower_sample = sample
            upper_sample = min(sample + window_samples, samples_per_trial)
            
            # Extract window
            window = signals[trial_idx, lower_sample:upper_sample, :]
            
            # Check if window is complete (some might be cut off at the end)
            if window.shape[0] == window_samples:
                trial_features.append(window)
                flash_indices.append(sample)
                
                # Add label (1 for P300, 0 for no P300)
                if stimulus_type[trial_idx, sample] == 1:
                    trial_labels.append(1)
                else:
                    trial_labels.append(0)
    
    return np.array(trial_features), np.array(trial_labels), flash_indices

def create_sentence_based_dataset(signals, flashing, stimulus_type, target_chars, trials, char_to_sentence):
    """
    Create a dataset with sentence context
    """
    # Dataset structure
    sentence_data = defaultdict(lambda: defaultdict(dict))
    
    # Extract features for each trial/character
    for trial in range(trials):
        # Extract features for this trial
        trial_features, trial_labels, _ = extract_trial_features(
            signals, flashing, stimulus_type, trial)
        
        # Normalize features
        for i in range(len(trial_features)):
            trial_features[i] = scale(trial_features[i], axis=0)
        
        # Get sentence context for this character
        if trial in char_to_sentence:
            sent_idx = char_to_sentence[trial]['sentence_idx'] 
            pos = char_to_sentence[trial]['position_in_sentence']
            char = char_to_sentence[trial]['character']
            
            # Store in structured format
            sentence_data[sent_idx][pos] = {
                'character': char,
                'features': trial_features,
                'labels': trial_labels,
                'trial_idx': trial
            }
    
    return sentence_data

def create_contextual_sequences(sentence_data, context_size=3):
    """
    Create sequences with context_size characters before and after each target
    """
    X_sequences = []
    y_sequences = []
    contexts = []
    
    for sent_idx, sent_data in sentence_data.items():
        # Get all positions in this sentence
        positions = sorted(list(sent_data.keys()))
        
        for i, pos in enumerate(positions):
            # Get context positions
            context_positions = []
            for c in range(-context_size, context_size + 1):
                if c != 0 and 0 <= i + c < len(positions):
                    context_positions.append(positions[i + c])
            
            target_features = sent_data[pos]['features']
            target_labels = sent_data[pos]['labels']
            target_char = sent_data[pos]['character']
            
            # Get context features
            context_chars = []
            for ctx_pos in context_positions:
                ctx_char = sent_data[ctx_pos]['character']
                context_chars.append(ctx_char)
            
            # Create context string
            context_str = ''.join(context_chars)
            
            # Add to dataset
            X_sequences.append(target_features)
            y_sequences.append(target_labels)
            contexts.append({
                'target_char': target_char,
                'context': context_str,
                'sentence_idx': sent_idx,
                'position': pos
            })
    
    return X_sequences, y_sequences, contexts

def create_sequence_batches(X_sequences, y_sequences, contexts, batch_size=32):
    """
    Create batches for training
    """
    indices = np.arange(len(X_sequences))
    np.random.shuffle(indices)
    
    for start_idx in range(0, len(indices), batch_size):
        batch_indices = indices[start_idx:start_idx + batch_size]
        
        batch_X = [X_sequences[i] for i in batch_indices]
        batch_y = [y_sequences[i] for i in batch_indices]
        batch_contexts = [contexts[i] for i in batch_indices]
        
        # Convert to torch tensors
        batch_X = [torch.tensor(x, dtype=torch.float32) for x in batch_X]
        batch_y = [torch.tensor(y, dtype=torch.float32) for y in batch_y]
        
        yield batch_X, batch_y, batch_contexts

class ContextualP300Classifier(nn.Module):
    """
    Neural network model for P300 classification with contextual information
    """
    def __init__(self, input_channels=64, input_time_points=78, 
                 filter_size=10, hidden_size=128, dropout_rate=0.5):
        super(ContextualP300Classifier, self).__init__()
        
        # Spatial filtering - reduce 64 channels to smaller representation
        self.spatial_filter = nn.Conv2d(1, filter_size, (input_channels, 1))
        
        # Temporal convolution
        self.temporal_conv = nn.Conv2d(1, filter_size, (1, 10))
        
        # Feature size after convolutions
        self.feature_size = filter_size * (input_time_points - 10 + 1)
        
        # Fully connected layers
        self.fc1 = nn.Linear(self.feature_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        # Input shape: [batch_size, time_points, channels]
        batch_size = x.size(0)
        
        # Spatial filtering
        # Reshape for spatial convolution
        x_spatial = x.unsqueeze(1)  # [batch_size, 1, time_points, channels]
        x_spatial = self.spatial_filter(x_spatial)  # [batch_size, filter_size, time_points, 1]
        x_spatial = x_spatial.squeeze(-1)  # [batch_size, filter_size, time_points]
        
        # Temporal convolution
        # Reshape for temporal convolution
        x_temporal = x.permute(0, 2, 1).unsqueeze(1)  # [batch_size, 1, channels, time_points]
        x_temporal = self.temporal_conv(x_temporal)  # [batch_size, filter_size, channels, time_points-filter+1]
        x_temporal = x_temporal.squeeze(2)  # [batch_size, filter_size, time_points-filter+1]
        
        # Combine features
        x_spatial_flat = x_spatial.reshape(batch_size, -1)
        x_temporal_flat = x_temporal.reshape(batch_size, -1)
        
        # Use spatial features for now (can be extended to combine both)
        x = x_spatial_flat
        
        # Apply fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        
        return x.squeeze(1)


def visualize_sentence_structure(sentence_data):
    """
    Visualize the structure of sentence data
    """
    fig, ax = plt.subplots(figsize=(15, 8))
    
    y_positions = []
    x_positions = []
    labels = []
    
    for sent_idx, sent_data in sentence_data.items():
        for pos, char_data in sent_data.items():
            y_positions.append(sent_idx)
            x_positions.append(pos)
            labels.append(char_data['character'])
    
    ax.scatter(x_positions, y_positions, c='blue', s=100)
    
    # Add character labels
    for i, txt in enumerate(labels):
        ax.annotate(txt, (x_positions[i], y_positions[i]), fontsize=12, 
                   ha='center', va='center', color='white')
    
    ax.set_xlabel('Position in Sentence')
    ax.set_ylabel('Sentence Index')
    ax.set_title('Character Distribution in Sentences')
    ax.grid(True)
    
    return fig

def visualize_p300_responses(features, labels, char_to_sentence, trial_idx, n_channels=4):
    """
    Visualize P300 responses for a specific trial
    """
    # Get a subset of channels to visualize
    channels_to_plot = [0, 10, 20, 30]  # Example channels
    
    if trial_idx in char_to_sentence:
        sent_info = char_to_sentence[trial_idx]
        title = f"Trial {trial_idx}: '{sent_info['character']}' (Position {sent_info['position_in_sentence']} in Sentence {sent_info['sentence_idx']})"
    else:
        title = f"Trial {trial_idx}"
    
    fig, axes = plt.subplots(n_channels, 1, figsize=(15, 10), sharex=True)
    
    # Split by P300 presence
    p300_features = features[labels == 1]
    non_p300_features = features[labels == 0]
    
    for i, channel in enumerate(channels_to_plot):
        # Plot P300 responses
        if len(p300_features) > 0:
            p300_mean = np.mean(p300_features[:, :, channel], axis=0)
            axes[i].plot(p300_mean, 'r-', linewidth=2, label='P300')
        
        # Plot non-P300 responses
        if len(non_p300_features) > 0:
            non_p300_mean = np.mean(non_p300_features[:, :, channel], axis=0)
            axes[i].plot(non_p300_mean, 'b-', linewidth=2, label='Non-P300')
        
        axes[i].set_ylabel(f'Channel {channel}')
        axes[i].legend()
        axes[i].grid(True)
    
    axes[-1].set_xlabel('Time (samples)')
    fig.suptitle(title)
    plt.tight_layout()
    
    return fig

def integrate_sentence_context(contributor_selected="I"):
    """
    Main function to integrate sentence context into EEG data
    """
    # File paths
    contributor_train_file_path = f'../data/Contributor_{contributor_selected}_Train.mat'
    
    print(f"Processing data for Contributor {contributor_selected}...")
    
    # Load and preprocess data
    signals, flashing, stimulus_type, target_chars, trials, sampling_frequency = load_and_preprocess_data(
        contributor_train_file_path)
    
    print(f"Data loaded. Trials: {trials}, Sampling rate: {sampling_frequency}Hz")
    
    # Extract sentences from character sequence
    chars_sequence, sentences = extract_sentences_from_char_sequence(target_chars)
    print(f"Extracted {len(sentences)} sentences from {len(chars_sequence)} characters")
    
    # Create character to sentence mapping
    char_to_sentence = create_char_to_sentence_mapping(chars_sequence, sentences)
    print(f"Created mapping for {len(char_to_sentence)} characters")
    
    # Print a few sentences
    print("\nSample sentences:")
    for i, sentence in enumerate(sentences[:3]):
        print(f"Sentence {i}: {sentence}")
    
    # Create sentence-based dataset
    sentence_data = create_sentence_based_dataset(
        signals, flashing, stimulus_type, target_chars, trials, char_to_sentence)
    print(f"\nCreated sentence-based dataset with {len(sentence_data)} sentences")
    
    # Visualize sentence structure
    fig = visualize_sentence_structure(sentence_data)
    plt.savefig(f"../output/sentence_structure_{contributor_selected}.png")
    plt.close(fig)
    
    # Create contextual sequences
    X_sequences, y_sequences, contexts = create_contextual_sequences(sentence_data, context_size=2)
    print(f"\nCreated {len(X_sequences)} contextual sequences with context size 2")
    
    # Print some context examples
    print("\nSample contexts:")
    for i in range(3):
        ctx = contexts[i]
        print(f"Target: '{ctx['target_char']}', Context: '{ctx['context']}', " +
              f"Sentence: {ctx['sentence_idx']}, Position: {ctx['position']}")
    
    
    
    # Visualize P300 responses for a few trials
    for trial_idx in range(3):
        # Extract features for this trial
        trial_features, trial_labels, _ = extract_trial_features(
            signals, flashing, stimulus_type, trial_idx)
        
        fig = visualize_p300_responses(trial_features, trial_labels, char_to_sentence, trial_idx)
        plt.savefig(f"../output/p300_trial_{trial_idx}_{contributor_selected}.png")
        plt.close(fig)
    
    print("\nSentence context integration complete.")
    return {
        'sentences': sentences,
        'char_to_sentence': char_to_sentence,
        'sentence_data': sentence_data,
        'X_sequences': X_sequences,
        'y_sequences': y_sequences,
        'contexts': contexts
    }

if __name__ == "__main__":
    # Execute the integration
    results = integrate_sentence_context("I")
    
    # Access components
    sentence_data = results['sentence_data']
    
    print(f"\nSuccess! Model and sentence data ready for use.")

Processing data for Contributor I...
Data loaded. Trials: 85, Sampling rate: 120Hz
Extracted 1 sentences from 85 characters
Created mapping for 84 characters

Sample sentences:
Sentence 0: EAEVQTDOJG8RBRGONCEDHCTUIDBPUHMEM6OUXOCFOUKWA4VJEFRZROLHYNQDW EKTLBWXEPOUIKZERYOOTHQI

Created sentence-based dataset with 1 sentences

Created 84 contextual sequences with context size 2

Sample contexts:
Target: 'E', Context: 'AE', Sentence: 0, Position: 0
Target: 'A', Context: 'EEV', Sentence: 0, Position: 1
Target: 'E', Context: 'EAVQ', Sentence: 0, Position: 2

Training model...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (180x15360 and 1104x128)