In [1]:
# Dataloader for ZuCo 1.0 data set (SR)

In [16]:
import os
import numpy as np
from scipy.io import loadmat
import torch
from transformer_lens import HookedTransformer
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from tqdm import tqdm
from sklearn.multioutput import MultiOutputRegressor

import os
import pickle
import numpy as np
from pathlib import Path


class ZucoDataLoader:
    def __init__(self, data_dir='../zuco_data/zuco1.0/task1-SR/Matlab files'):
        self.data_dir = data_dir
        self.subject_files = self._get_subject_files()
        
    def _get_subject_files(self):
        subject_files = {}
        for file_name in os.listdir(self.data_dir):
            if file_name.endswith(".mat"):
                subject_id = file_name.split('.')[0]
                subject_files[subject_id] = os.path.join(self.data_dir, file_name)
        return subject_files
    
    def get_subject_ids(self):
        return list(self.subject_files.keys())
    
    def load_subject_data(self, subject_id):
        file_path = self.subject_files[subject_id]
        print(f"Loading data from {file_path}")
        data = loadmat(file_path, squeeze_me=True, struct_as_record=False)
        return data
    
    def extract_word_level_data(self, subject_id):
        """Extract word-level EEG data with sentence context"""
        data = self.load_subject_data(subject_id)
        sentences = data['sentenceData']
        
        # Store word-level data
        word_data = []
        
        for sent_idx, sentence in enumerate(sentences):
            try:
                # Check if words is iterable
                if not hasattr(sentence, 'word'):
                    print(f"Sentence {sent_idx} has no word attribute")
                    continue
                    
                words = sentence.word
                
                # Handle case where words is not iterable (e.g., a float)
                if not hasattr(words, '__iter__'):
                    print(f"Sentence {sent_idx} words is not iterable: {type(words)}")
                    continue
                    
                sentence_text = sentence.content if hasattr(sentence, 'content') else ""
                
                for word_idx, word in enumerate(words):
                    # Extract EEG features
                    eeg_features = {}
                    word_text = word.content if hasattr(word, 'content') else ""
                    
                    # Extract each frequency band
                    for feature in ['FFD', 'TRT', 'GD', 'GPT']:
                        for band in ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']:
                            feature_name = feature + band
                            if hasattr(word, feature_name):
                                eeg_features[feature_name] = getattr(word, feature_name)
                    
                    word_data.append({
                        'word': word_text,
                        'word_idx': word_idx,
                        'sentence_id': sent_idx,
                        'sentence': sentence_text,
                        'eeg_features': eeg_features
                    })
            except (AttributeError, IndexError, TypeError) as e:
                print(f"Error processing sentence {sent_idx}: {e}")
                continue
        
        return word_data

class EmbeddingGenerator:
    def __init__(self, model_name='gpt2-medium'):
        """Initialize with TransformerLens"""
        print(f"Loading model {model_name}...")
        self.model = HookedTransformer.from_pretrained(model_name)
        self.model.eval()
    
    def extract_embeddings_sliding_window(self, word_data):
        """
        Generate contextual embeddings using sliding window approach (Goldstein method)
        """
        # Group by sentence
        sentences = {}
        for word in word_data:
            sent_id = word['sentence_id']
            if sent_id not in sentences:
                sentences[sent_id] = {
                    'text': word['sentence'],
                    'words': []
                }
            sentences[sent_id]['words'].append(word)
        
        # Process each sentence
        embeddings = []
        
        for sent_id, sent_info in tqdm(sentences.items(), desc="Extracting embeddings"):
            sent_text = sent_info['text']
            words = sent_info['words']
            
            # Sort words by position
            words.sort(key=lambda x: x['word_idx'])
            
            # Process each word with its preceding context
            for i, word in enumerate(words):
                # Build context window (all words up to and including current)
                word_tokens = [w['word'] for w in words[:i+1]]
                context = " ".join(word_tokens)
                
                # Get activations for this context
                _, cache = self.model.run_with_cache(context)
                
                # Extract final layer activation for last token
                # This follows Goldstein's methodology
                final_layer_activations = cache['blocks.23.hook_resid_post'][0]
                word_embedding = final_layer_activations[-1].detach().cpu().numpy()
                
                embeddings.append({
                    'word': word['word'],
                    'sentence_id': sent_id,
                    'word_idx': word['word_idx'],
                    'embedding': word_embedding,
                    'eeg_features': word['eeg_features']
                })
        
        return embeddings

class BrainEmbeddingMapper:
    def __init__(self):
        """Linear mapper between embeddings and EEG"""
        self.models = {}
    
    def train_mapper(self, embeddings, feature_name='FFD_t1', n_splits=5):
        """Train a linear mapping between embeddings and EEG features"""
        # First, check dimensions of the feature
        sample_shapes = {}
        for item in embeddings:
            if feature_name in item['eeg_features']:
                feature = item['eeg_features'][feature_name]
                if hasattr(feature, 'shape'):
                    shape = feature.shape
                    if shape not in sample_shapes:
                        sample_shapes[shape] = 0
                    sample_shapes[shape] += 1
        
        print(f"Found {len(sample_shapes)} different shapes for {feature_name}")
        for shape, count in sample_shapes.items():
            print(f"  Shape {shape}: {count} samples")
        
        # Choose most common shape with non-zero dimensions
        valid_shapes = {shape: count for shape, count in sample_shapes.items() 
                    if shape and shape[0] > 0}
        
        if not valid_shapes:
            print(f"No valid shapes found for feature {feature_name}")
            return None
        
        target_shape = max(valid_shapes.items(), key=lambda x: x[1])[0]
        print(f"Using shape {target_shape} for training")
        
        # Filter to samples with consistent dimensions
        valid_embeddings = []
        valid_features = []
        
        for item in embeddings:
            if feature_name in item['eeg_features']:
                feature = item['eeg_features'][feature_name]
                if hasattr(feature, 'shape') and feature.shape == target_shape:
                    if not np.isnan(feature).any():
                        valid_embeddings.append(item['embedding'])
                        valid_features.append(feature)
        
        if len(valid_embeddings) < 10:  # Minimum samples for training
            print(f"Not enough valid samples after filtering")
            return None
        
        print(f"Training with {len(valid_embeddings)} samples")
        
        # Add regularization to handle ill-conditioned matrices
        alpha = 10.0  # Increase regularization strength
        
        # Convert to numpy arrays
        X = np.array(valid_embeddings)
        y = np.array(valid_features)
        
        # Cross-validation
        kf = KFold(n_splits=min(n_splits, len(X)), shuffle=True, random_state=42)
        results = []
        
        for train_idx, test_idx in kf.split(X):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]
            
            # Train model
            model = Ridge(alpha=alpha)
            model.fit(X_train, y_train)
            
            # Predict
            y_pred = model.predict(X_test)
            
            # Calculate correlation for each electrode
            correlations = []
            for i in range(y_test.shape[1]):
                if np.std(y_test[:, i]) > 0 and np.std(y_pred[:, i]) > 0:
                    corr = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]
                    correlations.append(corr)
            
            results.append({
                'correlations': correlations,
                'mean_correlation': np.mean(correlations)
            })
        
        self.models[feature_name] = {
            'model': model,
            'results': results
        }
        
        return results
    
    def train_multifeature_mapper(self, embeddings, features=None, n_splits=5):
        """Train a linear mapping with multiple features at once"""
        # Get all available features if none specified
        if not features:
            all_features = set()
            for item in embeddings:
                all_features.update(item['eeg_features'].keys())
            features = sorted(list(all_features))
        
        print(f"Training with {len(features)} features")
        
        # Find most common electrode count across features
        feature_shapes = {}
        for feature in features:
            shapes = {}
            for item in embeddings:
                if feature in item['eeg_features']:
                    arr = item['eeg_features'][feature]
                    if hasattr(arr, 'shape') and len(arr.shape) > 0:
                        shape = arr.shape
                        if shape not in shapes:
                            shapes[shape] = 0
                        shapes[shape] += 1
            
            if shapes:
                feature_shapes[feature] = max(shapes.items(), key=lambda x: x[1])[0]
        
        if not feature_shapes:
            print("No valid features found")
            return None
        
        # Filter to features with same electrode count
        valid_features = []
        target_shape = (105,)  # Standard electrode count
        
        for feature, shape in feature_shapes.items():
            if shape == target_shape:
                valid_features.append(feature)
        
        if not valid_features:
            print("No features with consistent electrode count")
            return None
        
        print(f"Using {len(valid_features)} features with {target_shape[0]} electrodes")
        
        # Collect valid samples
        valid_data = []
        
        for item in embeddings:
            sample = {
                'embedding': item['embedding'],
                'targets': {}
            }
            
            has_valid_data = False
            for feature in valid_features:
                if feature in item['eeg_features']:
                    arr = item['eeg_features'][feature]
                    if hasattr(arr, 'shape') and arr.shape == target_shape:
                        if not np.isnan(arr).any():
                            sample['targets'][feature] = arr
                            has_valid_data = True
            
            if has_valid_data:
                valid_data.append(sample)
        
        print(f"Found {len(valid_data)} samples with valid data")
        
        if len(valid_data) < 100:
            print("Not enough valid samples for training")
            return None
        
        # Create model for each feature
        results = {}
        
        for feature in valid_features:
            # Get samples with this feature
            feature_samples = []
            feature_targets = []
            
            for sample in valid_data:
                if feature in sample['targets']:
                    feature_samples.append(sample['embedding'])
                    feature_targets.append(sample['targets'][feature])
            
            if len(feature_samples) < 100:
                print(f"Skipping feature {feature}: not enough samples")
                continue
            
            X = np.array(feature_samples)
            y = np.array(feature_targets)
            
            print(f"Training model for {feature} with {len(X)} samples")
            
            # Cross-validation
            kf = KFold(n_splits=min(n_splits, len(X)), shuffle=True, random_state=42)
            feature_results = []
            
            for train_idx, test_idx in kf.split(X):
                X_train, X_test = X[train_idx], X[test_idx]
                y_train, y_test = y[train_idx], y[test_idx]
                
                # Train model with increased regularization
                model = Ridge(alpha=50.0)
                model.fit(X_train, y_train)
                
                # Predict
                y_pred = model.predict(X_test)
                
                # Calculate correlation for each electrode
                correlations = []
                for i in range(y_test.shape[1]):
                    if np.std(y_test[:, i]) > 0 and np.std(y_pred[:, i]) > 0:
                        corr = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]
                        correlations.append(corr)
                
                feature_results.append({
                    'correlations': correlations,
                    'mean_correlation': np.mean(correlations)
                })
            
            results[feature] = {
                'model': model,
                'results': feature_results
            }
        
        # Store models
        self.models.update(results)
        
        return results


    def train_all_features(self, embeddings, n_splits=5):
        """Train linear mappings for all available EEG features"""
        # Find all features that appear in the data
        all_features = set()
        for item in embeddings:
            all_features.update(item['eeg_features'].keys())
        
        print(f"Found {len(all_features)} features in the data")
        
        # Train a model for each feature
        results_by_feature = {}
        
        for feature_name in all_features:
            print(f"\nTraining model for {feature_name}")
            results = self.train_mapper(embeddings, feature_name=feature_name, n_splits=n_splits)
            
            if results:
                mean_corr = np.mean([fold['mean_correlation'] for fold in results])
                print(f"Mean correlation: {mean_corr:.4f}")
                results_by_feature[feature_name] = results
        
        return results_by_feature
    
    
    def extract_steering_vector(self, feature_name='FFD_t1', method='weighted', threshold=0.1):
        """
        Extract a steering vector using different methods:
        - 'weighted': Weight electrodes by correlation strength
        - 'top_n': Use only top N electrodes 
        - 'threshold': Use electrodes with correlation above threshold
        """
        if feature_name not in self.models:
            print(f"No model trained for feature {feature_name}")
            return None
            
        feature_data = self.models[feature_name]
        
        # Handle both single-feature and multi-feature formats
        if isinstance(feature_data, dict) and 'model' in feature_data:
            model = feature_data['model']
            results = feature_data['results']
        else:
            model = feature_data  # Original format
            results = self.models[feature_name]['results']
        
        weights = model.coef_.T  # [embedding_dim, n_electrodes]
        
        # Calculate correlation strength per electrode
        correlation_means = []
        for result in results:
            correlation_means.append(np.array(result['correlations']))
        electrode_correlations = np.mean(np.stack(correlation_means), axis=0)
        
        # Select electrodes based on method
        if method == 'weighted':
            # Weight each electrode by its correlation strength
            electrode_weights = np.abs(electrode_correlations)
            electrode_weights = electrode_weights / np.sum(electrode_weights)
            steering_vector = np.zeros(weights.shape[0])
            
            for i, weight in enumerate(electrode_weights):
                if not np.isnan(weight):
                    steering_vector += weight * weights[:, i]
                    
        elif method == 'top_n':
            n_electrodes = 10  # Default to top 10
            # Get top N electrodes by absolute correlation
            top_indices = np.argsort(np.abs(electrode_correlations))[-n_electrodes:]
            steering_vector = np.mean(weights[:, top_indices], axis=1)
            
        elif method == 'threshold':
            # Use electrodes above correlation threshold
            mask = np.abs(electrode_correlations) > threshold
            if not np.any(mask):
                print(f"No electrodes above threshold {threshold}")
                return None
            steering_vector = np.mean(weights[:, mask], axis=1)
        
        else:
            raise ValueError(f"Unknown method: {method}")
        
        # Normalize
        steering_vector = steering_vector / np.linalg.norm(steering_vector)
        return steering_vector  
    
    def extract_combined_steering_vector(self, method='weighted', threshold=0.1):
        """
        Extract a steering vector that combines information across multiple features,
        weighting each feature by its overall prediction performance.
        """
        # Find all available features
        available_features = [f for f in self.models.keys()]
        if not available_features:
            print("No models trained for any features")
            return None
            
        print(f"Combining steering vectors from features: {available_features}")
        
        # Get individual steering vectors for each feature
        feature_vectors = {}
        feature_scores = {}
        
        for feature in available_features:
            # Extract steering vector using existing method
            vector = self.extract_steering_vector(
                feature_name=feature,
                method=method,
                threshold=threshold
            )
            
            if vector is not None:
                feature_vectors[feature] = vector
                
                # Get average correlation score for this feature
                feature_data = self.models[feature]
                if isinstance(feature_data, dict) and 'results' in feature_data:
                    mean_corr = np.mean([np.mean(r['correlations']) for r in feature_data['results']])
                else:
                    mean_corr = np.mean([np.mean(r['correlations']) for r in feature_data['results']])
                
                feature_scores[feature] = mean_corr
        
        if not feature_vectors:
            print("No valid steering vectors extracted")
            return None
        
        # Weight features by their scores
        total_score = sum(feature_scores.values())
        weights = {f: score/total_score for f, score in feature_scores.items()}
        
        # Combine vectors (they should all have the same dimensionality)
        dim = len(next(iter(feature_vectors.values())))
        combined_vector = np.zeros(dim)
        
        for feature, vector in feature_vectors.items():
            combined_vector += weights[feature] * vector
        
        # Normalize
        combined_vector = combined_vector / np.linalg.norm(combined_vector)
        
        return combined_vector
    
    

In [3]:


# 1. Initialize the data loader
zuco_loader = ZucoDataLoader(data_dir='../zuco_data/zuco1.0/task1-SR/Matlab files')

# 2. Extract and save word-level data for all subjects
all_subjects_data = {}
subject_data_path = Path('saved_data/subject_word_data.pkl')
subject_data_path.parent.mkdir(parents=True, exist_ok=True)

if subject_data_path.exists():
    print(f"Loading subject data from {subject_data_path}")
    with open(subject_data_path, 'rb') as f:
        all_subjects_data = pickle.load(f)
else:
    for subject_id in zuco_loader.get_subject_ids():
        word_data = zuco_loader.extract_word_level_data(subject_id)
        all_subjects_data[subject_id] = word_data
        print(f"Extracted {len(word_data)} words from subject {subject_id}")
        
    # Save the results
    with open(subject_data_path, 'wb') as f:
        pickle.dump(all_subjects_data, f)
    print(f"Saved subject data to {subject_data_path}")

# 3. Generate embeddings with checkpoints
embeddings_path = Path('saved_data/embeddings.pkl')

if embeddings_path.exists():
    print(f"Loading embeddings from {embeddings_path}")
    with open(embeddings_path, 'rb') as f:
        sentence_embeddings = pickle.load(f)
else:
    # Gather unique sentences across all subjects
    unique_sentences = {}
    for subject_id, word_data in all_subjects_data.items():
        for word in word_data:
            sent_id = word['sentence_id']
            if sent_id not in unique_sentences:
                unique_sentences[sent_id] = word['sentence']
    
    # Initialize embeddings generator
    embedding_gen = EmbeddingGenerator(model_name='gpt2-medium')
    sentence_embeddings = {}
    
    # Generate embeddings with checkpoints
    checkpoint_path = Path('saved_data/embeddings_checkpoint.pkl')
    
    try:
        # If checkpoint exists, load it
        if checkpoint_path.exists():
            with open(checkpoint_path, 'rb') as f:
                sentence_embeddings = pickle.load(f)
            print(f"Loaded checkpoint with {len(sentence_embeddings)} sentences")
        
        # Process remaining sentences
        remaining_sentences = {k: v for k, v in unique_sentences.items() 
                               if k not in sentence_embeddings}
        
        for i, (sent_id, sentence) in enumerate(remaining_sentences.items()):
            # Create dummy word data structure for the embeddings function
            words = sentence.split()
            dummy_words = [{'word': word, 'word_idx': i, 'sentence_id': sent_id, 
                           'sentence': sentence, 'eeg_features': {}} 
                           for i, word in enumerate(words)]
            
            embeddings = embedding_gen.extract_embeddings_sliding_window(dummy_words)
            sentence_embeddings[sent_id] = {e['word_idx']: e['embedding'] for e in embeddings}
            
            # Save checkpoint every 50 sentences
            if (i + 1) % 50 == 0:
                with open(checkpoint_path, 'wb') as f:
                    pickle.dump(sentence_embeddings, f)
                print(f"Saved checkpoint after {i+1}/{len(remaining_sentences)} sentences")
                
        # Save final embeddings
        with open(embeddings_path, 'wb') as f:
            pickle.dump(sentence_embeddings, f)
        print(f"Saved embeddings to {embeddings_path}")
        
        # Remove checkpoint file
        if checkpoint_path.exists():
            os.remove(checkpoint_path)
            
    except Exception as e:
        # Save checkpoint on error
        print(f"Error during embedding generation: {e}")
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(sentence_embeddings, f)
        print(f"Saved checkpoint to {checkpoint_path}")
        raise e


Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZKB_SR.mat
Extracted 7129 words from subject resultsZKB_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZDM_SR.mat
Extracted 7129 words from subject resultsZDM_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZJN_SR.mat
Extracted 7129 words from subject resultsZJN_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZAB_SR.mat
Extracted 7129 words from subject resultsZAB_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZKH_SR.mat
Extracted 7129 words from subject resultsZKH_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZMG_SR.mat
Extracted 7129 words from subject resultsZMG_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZGW_SR.mat
Extracted 7129 words from subject resultsZGW_SR
Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZKW_SR.mat
Extracted 7129 words from subjec

Extracting embeddings:   0%|          | 0/1 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.97s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.49s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.72s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.51it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.47s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.68s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.67s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.32s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.19s/it]
Extracting embeddings: 100%|████████

Saved checkpoint after 50/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.02it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.10s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.18s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.23it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.19s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.37it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.32s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.53it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.05it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.58s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.53s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.80s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s]
Extracting embeddings: 100%|██████

Saved checkpoint after 100/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.00s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.12it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.62s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.96s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.36s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.54s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.53s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.67s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.02s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.07s/it]
Extracting embeddings: 100%|██████

Saved checkpoint after 150/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.78s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:04<00:00,  4.60s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.11s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.32it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.14s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.63s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.34s/it]
Extracting embeddings: 100%|██████

Saved checkpoint after 200/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.08it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.28s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.75s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.97s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.52s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.30s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]
Extracting embeddings: 100%|██████

Saved checkpoint after 250/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.64s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.92s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.15s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.60s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.52s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.31s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.55s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.09s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.48s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.90s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.59s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.07s/it]
Extracting embeddings: 100%|██████

Saved checkpoint after 300/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.06it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.49s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.17s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.77s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.38s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.58s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.44s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.50s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.91s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.98s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.29s/it]
Extracting embeddings: 100%|██████

Saved checkpoint after 350/400 sentences


Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.69it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.66it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.14s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:02<00:00,  2.81s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.32s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.78s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.61s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:03<00:00,  3.93s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it]
Extracting embeddings: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
Extracting embeddings: 100%|██████

Saved checkpoint after 400/400 sentences
Saved embeddings to saved_data/embeddings.pkl





# linear mapping

In [4]:

# 4. Linear mapping with save points
results_path = Path('saved_data/mapping_results.pkl')

if results_path.exists():
    print(f"Loading mapping results from {results_path}")
    with open(results_path, 'rb') as f:
        results_by_subject = pickle.load(f)
else:
    mapper = BrainEmbeddingMapper()
    results_by_subject = {}
    
    for subject_id, word_data in all_subjects_data.items():
        subject_result_path = Path(f'saved_data/subject_{subject_id}_results.pkl')
        
        if subject_result_path.exists():
            with open(subject_result_path, 'rb') as f:
                results = pickle.load(f)
            results_by_subject[subject_id] = results
            print(f"Loaded results for subject {subject_id}")
            continue
            
        # Combine word data with embeddings
        combined_data = []
        for word in word_data:
            sent_id = word['sentence_id']
            word_idx = word['word_idx']
            
            if sent_id in sentence_embeddings and word_idx in sentence_embeddings[sent_id]:
                combined_data.append({
                    'word': word['word'],
                    'sentence_id': sent_id,
                    'word_idx': word_idx,
                    'embedding': sentence_embeddings[sent_id][word_idx],
                    'eeg_features': word['eeg_features']
                })
        
        # Train linear mapping for this subject
        # results = mapper.train_mapper(combined_data, feature_name='FFD_t1')
        results = mapper.train_multifeature_mapper(combined_data, n_splits=5)
        results_by_subject[subject_id] = results
        
        # Save subject results
        with open(subject_result_path, 'wb') as f:
            pickle.dump(results, f)
        print(f"Saved results for subject {subject_id}")
        
        # Print results for this subject
        if results:
            # New structure is a dictionary by feature
            feature_means = []
            for feature, feature_results in results.items():
                feature_mean = np.mean([fold['mean_correlation'] for fold in feature_results['results']])
                feature_means.append(feature_mean)
                print(f"Feature {feature}: Mean correlation = {feature_mean:.4f}")
            
            overall_mean = np.mean(feature_means)
            print(f"Subject {subject_id}: Overall mean correlation = {overall_mean:.4f}")
    
    # Save all results
    with open(results_path, 'wb') as f:
        pickle.dump(results_by_subject, f)
    print(f"Saved all mapping results to {results_path}")

Training with 32 features
Using 32 features with 105 electrodes
Found 4829 samples with valid data
Training model for FFD_a1 with 4828 samples
Training model for FFD_a2 with 4828 samples
Training model for FFD_b1 with 4828 samples
Training model for FFD_b2 with 4828 samples
Training model for FFD_g1 with 4828 samples
Training model for FFD_g2 with 4828 samples
Training model for FFD_t1 with 4828 samples
Training model for FFD_t2 with 4828 samples
Training model for GD_a1 with 4829 samples
Training model for GD_a2 with 4829 samples
Training model for GD_b1 with 4829 samples
Training model for GD_b2 with 4829 samples
Training model for GD_g1 with 4829 samples
Training model for GD_g2 with 4829 samples
Training model for GD_t1 with 4829 samples
Training model for GD_t2 with 4829 samples
Training model for GPT_a1 with 4829 samples
Training model for GPT_a2 with 4829 samples
Training model for GPT_b1 with 4829 samples
Training model for GPT_b2 with 4829 samples
Training model for GPT_g1 wit

  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_a2 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b1 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b2 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g1 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g2 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t1 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t2 with 5646 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a1 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a2 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b1 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b2 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g1 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g2 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t1 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t2 with 5648 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_a1 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_a2 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b1 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b2 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g1 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g2 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t1 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t2 with 5652 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a1 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a2 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b1 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b2 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g1 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g2 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t1 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t2 with 5650 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Saved results for subject resultsZJN_SR
Feature FFD_a1: Mean correlation = 0.1108
Feature FFD_a2: Mean correlation = 0.1188
Feature FFD_b1: Mean correlation = 0.1173
Feature FFD_b2: Mean correlation = 0.1513
Feature FFD_g1: Mean correlation = 0.1874
Feature FFD_g2: Mean correlation = 0.1914
Feature FFD_t1: Mean correlation = 0.0926
Feature FFD_t2: Mean correlation = 0.1000
Feature GD_a1: Mean correlation = 0.1154
Feature GD_a2: Mean correlation = 0.1246
Feature GD_b1: Mean correlation = 0.1240
Feature GD_b2: Mean correlation = 0.1577
Feature GD_g1: Mean correlation = 0.1961
Feature GD_g2: Mean correlation = 0.1976
Feature GD_t1: Mean correlation = 0.0977
Feature GD_t2: Mean correlation = 0.0986
Feature GPT_a1: Mean correlation = 0.1215
Feature GPT_a2: Mean correlation = 0.1362
Feature GPT_b1: Mean correlation = 0.1452
Feature GPT_b2: Mean correlation = 0.1672
Feature GPT_g1: Mean correlation = 0.2083
Feature GPT_g2: Mean correlation = 0.2200
Feature GPT_t1: Mean correlation = 0.1170
Fe

  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a2 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b1 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b2 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g1 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g2 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t1 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t2 with 5095 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Saved results for subject resultsZAB_SR
Feature FFD_a1: Mean correlation = 0.1122
Feature FFD_a2: Mean correlation = 0.1166
Feature FFD_b1: Mean correlation = 0.1232
Feature FFD_b2: Mean correlation = 0.1701
Feature FFD_g1: Mean correlation = 0.1953
Feature FFD_g2: Mean correlation = 0.1970
Feature FFD_t1: Mean correlation = 0.0882
Feature FFD_t2: Mean correlation = 0.1064
Feature GD_a1: Mean correlation = 0.1167
Feature GD_a2: Mean correlation = 0.1232
Feature GD_b1: Mean correlation = 0.1289
Feature GD_b2: Mean correlation = 0.1801
Feature GD_g1: Mean correlation = 0.1994
Feature GD_g2: Mean correlation = 0.2033
Feature GD_t1: Mean correlation = 0.0944
Feature GD_t2: Mean correlation = 0.1096
Feature GPT_a1: Mean correlation = 0.1228
Feature GPT_a2: Mean correlation = 0.1269
Feature GPT_b1: Mean correlation = 0.1410
Feature GPT_b2: Mean correlation = 0.1936
Feature GPT_g1: Mean correlation = 0.2175
Feature GPT_g2: Mean correlation = 0.2144
Feature GPT_t1: Mean correlation = 0.1088
Fe

  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_a2 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b1 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b2 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g1 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g2 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t1 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t2 with 5240 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t2 with 5241 samples
Training model for GPT_a1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_a2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t1 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t2 with 5241 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Saved results for subject resultsZKW_SR
Feature FFD_a1: Mean correlation = 0.0580
Feature FFD_a2: Mean correlation = 0.0759
Feature FFD_b1: Mean correlation = 0.0849
Feature FFD_b2: Mean correlation = 0.0917
Feature FFD_g1: Mean correlation = 0.0998
Feature FFD_g2: Mean correlation = 0.1104
Feature FFD_t1: Mean correlation = 0.0481
Feature FFD_t2: Mean correlation = 0.0628
Feature GD_a1: Mean correlation = 0.0624
Feature GD_a2: Mean correlation = 0.0812
Feature GD_b1: Mean correlation = 0.0883
Feature GD_b2: Mean correlation = 0.0980
Feature GD_g1: Mean correlation = 0.1097
Feature GD_g2: Mean correlation = 0.1234
Feature GD_t1: Mean correlation = 0.0543
Feature GD_t2: Mean correlation = 0.0655
Feature GPT_a1: Mean correlation = 0.0710
Feature GPT_a2: Mean correlation = 0.0898
Feature GPT_b1: Mean correlation = 0.1104
Feature GPT_b2: Mean correlation = 0.1095
Feature GPT_g1: Mean correlation = 0.1264
Feature GPT_g2: Mean correlation = 0.1315
Feature GPT_t1: Mean correlation = 0.0694
Fe

  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_a2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_b2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_g2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for FFD_t2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_a2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_b2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_g2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GD_t2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_a1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_a2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_b2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_g2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for GPT_t2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_a2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_b2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_g2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t1 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Training model for TRT_t2 with 5476 samples


  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
  return linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T


Saved results for subject resultsZJM_SR
Feature FFD_a1: Mean correlation = 0.1000
Feature FFD_a2: Mean correlation = 0.1039
Feature FFD_b1: Mean correlation = 0.1308
Feature FFD_b2: Mean correlation = 0.1742
Feature FFD_g1: Mean correlation = 0.2003
Feature FFD_g2: Mean correlation = 0.2087
Feature FFD_t1: Mean correlation = 0.0929
Feature FFD_t2: Mean correlation = 0.0903
Feature GD_a1: Mean correlation = 0.1041
Feature GD_a2: Mean correlation = 0.1068
Feature GD_b1: Mean correlation = 0.1379
Feature GD_b2: Mean correlation = 0.1831
Feature GD_g1: Mean correlation = 0.2047
Feature GD_g2: Mean correlation = 0.2138
Feature GD_t1: Mean correlation = 0.1019
Feature GD_t2: Mean correlation = 0.0956
Feature GPT_a1: Mean correlation = 0.1149
Feature GPT_a2: Mean correlation = 0.1247
Feature GPT_b1: Mean correlation = 0.1511
Feature GPT_b2: Mean correlation = 0.1954
Feature GPT_g1: Mean correlation = 0.2187
Feature GPT_g2: Mean correlation = 0.2275
Feature GPT_t1: Mean correlation = 0.1177
Fe

# VISUALIZE (by Subject)

In [5]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path

# # Assuming results_by_subject has been loaded
# results_path = Path('saved_data/mapping_results.pkl')
# with open(results_path, 'rb') as f:
#     results_by_subject = pickle.load(f)

# # 1. Bar plot of mean correlation by subject
# plt.figure(figsize=(10, 6))
# subject_means = {}
# for subject_id, results in results_by_subject.items():
#     if results:  # Check if results exist
#         mean_corr = np.mean([fold['mean_correlation'] for fold in results])
#         subject_means[subject_id] = mean_corr

# # Sort by correlation value
# sorted_subjects = sorted(subject_means.items(), key=lambda x: x[1], reverse=True)
# subjects = [s[0] for s in sorted_subjects]
# means = [s[1] for s in sorted_subjects]

# plt.bar(subjects, means)
# plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
# plt.xlabel('Subject ID')
# plt.ylabel('Mean Correlation')
# plt.title('Mean Correlation by Subject')
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.savefig('subject_correlations.png')
# plt.show()

# # 2. Distribution of correlations across electrodes
# plt.figure(figsize=(12, 6))
# all_correlations = []
# all_subjects = []

# for subject_id, results in results_by_subject.items():
#     if results:
#         for fold in results:
#             all_correlations.extend(fold['correlations'])
#             all_subjects.extend([subject_id] * len(fold['correlations']))

# # Create dataframe for seaborn
# import pandas as pd
# corr_df = pd.DataFrame({
#     'Correlation': all_correlations,
#     'Subject': all_subjects
# })

# sns.violinplot(x='Subject', y='Correlation', data=corr_df)
# plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
# plt.title('Distribution of Electrode Correlations by Subject')
# plt.xticks(rotation=45)
# plt.tight_layout()
# plt.savefig('electrode_correlations.png')
# plt.show()

# # 3. Heatmap of correlations for one subject
# # Choose the subject with highest mean correlation
# best_subject = sorted_subjects[0][0]
# best_results = results_by_subject[best_subject]

# # Aggregate correlations across folds
# electrode_means = np.zeros(105)  # Assuming 105 electrodes
# for fold in best_results:
#     for i, corr in enumerate(fold['correlations']):
#         electrode_means[i] += corr
# electrode_means /= len(best_results)

# # Reshape to approximate head layout (simplified)
# # Adjust these dimensions based on actual electrode layout
# reshaped_corrs = electrode_means.reshape(15, 7)

# plt.figure(figsize=(8, 12))
# sns.heatmap(reshaped_corrs, cmap='RdBu_r', center=0, 
#             vmin=-0.3, vmax=0.3)
# plt.title(f'Electrode Correlation Heatmap - Subject {best_subject}')
# plt.tight_layout()
# plt.savefig('electrode_heatmap.png')
# plt.show()

# # 4. Cross-subject consistency
# # Create a matrix of correlations for each electrode across subjects
# all_subject_ids = list(results_by_subject.keys())
# electrode_by_subject = np.zeros((len(all_subject_ids), 105))

# for i, subject_id in enumerate(all_subject_ids):
#     results = results_by_subject[subject_id]
#     if results:
#         # Average across folds
#         fold_means = np.zeros(105)
#         for fold in results:
#             for j, corr in enumerate(fold['correlations']):
#                 fold_means[j] += corr
#         fold_means /= len(results)
#         electrode_by_subject[i] = fold_means

# # Calculate correlation between subjects
# subject_correlation = np.corrcoef(electrode_by_subject)

# plt.figure(figsize=(10, 8))
# sns.heatmap(subject_correlation, annot=True, cmap='coolwarm', 
#             xticklabels=all_subject_ids, yticklabels=all_subject_ids)
# plt.title('Cross-subject Consistency')
# plt.tight_layout()
# plt.savefig('cross_subject_consistency.png')
# plt.show()

In [42]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from pathlib import Path

##@title PLOTTING PARAMETERS (defaults)

# Set the default font size for plots
text_size = 20

figure_width = 4 # for single column
page_width = 8 # for two column

plt.rcParams.update({
    'font.size': text_size,        # Base font size
    'axes.titlesize': text_size,   # Title
    'axes.labelsize': text_size,   # Axis labels
    'xtick.labelsize': text_size,  # X tick labels
    'ytick.labelsize': text_size,  # Y tick labels
    'legend.fontsize': text_size,  # Legend
    'figure.titlesize': text_size,  # Figure title
    'lines.linewidth': 2.0,      # Line width
    'lines.markersize': 3.0,     # Marker size
})
# Axis style (get rid of top and right)
plt.rcParams['axes.spines.top'] = False # remove the top line
plt.rcParams['axes.spines.right'] = False # remove the right line

# No Grid
plt.rcParams['axes.grid'] = False

plt.rcParams['lines.markersize'] = 10 


def visualize_results(results_by_subject):
    """Visualize results from our mapping approach"""
    
    # Create output directory
    output_dir = Path('visualizations')
    output_dir.mkdir(exist_ok=True)
    
    # Prepare data
    data = []
    
    for subject_id, subject_results in results_by_subject.items():
        # Handle the case where results is a list of fold results (original mapper)
        if isinstance(subject_results, list):
            feature = 'FFD_t1'  # Default feature name
            for fold_idx, fold in enumerate(subject_results):
                mean_corr = fold['mean_correlation']
                for i, corr in enumerate(fold['correlations']):
                    data.append({
                        'Subject': subject_id,
                        'Feature': feature,
                        'Fold': fold_idx,
                        'Electrode': i,
                        'Correlation': corr
                    })
        # Handle dictionary of features (multi-feature mapper)
        elif isinstance(subject_results, dict):
            for feature, feature_data in subject_results.items():
                if 'results' in feature_data:
                    for fold_idx, fold in enumerate(feature_data['results']):
                        mean_corr = fold['mean_correlation']
                        for i, corr in enumerate(fold['correlations']):
                            data.append({
                                'Subject': subject_id,
                                'Feature': feature,
                                'Fold': fold_idx,
                                'Electrode': i,
                                'Correlation': corr
                            })
    
    # Convert to DataFrame
    df = pd.DataFrame(data)
    
    # Subject performance
    plt.figure(figsize=(12, 6))
    subject_means = df.groupby('Subject')['Correlation'].mean().sort_values(ascending=False)
    
    plt.bar(subject_means.index, subject_means.values)
    plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
    plt.xlabel('Subject ID')
    plt.ylabel('Mean Correlation')
    plt.title('Mean Correlation by Subject')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_dir / 'subject_correlations.png')
    plt.close()
    
    # If multiple features are present
    if len(df['Feature'].unique()) > 1:
        # Feature performance
        plt.figure(figsize=(12, 6))
        feature_means = df.groupby('Feature')['Correlation'].mean().sort_values(ascending=False)
        
        plt.bar(feature_means.index, feature_means.values)
        plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
        plt.xlabel('EEG Feature')
        plt.ylabel('Mean Correlation')
        plt.title('Mean Correlation by EEG Feature')
        plt.xticks(rotation=90)
        plt.tight_layout()
        plt.savefig(output_dir / 'feature_correlations.png')
        plt.close()
    
    # Distribution of correlations across electrodes
    plt.figure(figsize=(12, 6))
    sns.violinplot(x='Subject', y='Correlation', data=df)
    plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
    plt.title('Distribution of Electrode Correlations by Subject')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(output_dir / 'electrode_correlations.png')
    plt.close()
    
    # Cross-subject consistency
    plt.figure(figsize=(10, 8))
    subjects = sorted(df['Subject'].unique())
    corr_matrix = np.zeros((len(subjects), len(subjects)))
    np.fill_diagonal(corr_matrix, 1.0)  # Self-correlation = 1
    
    for i, subj1 in enumerate(subjects):
        for j, subj2 in enumerate(subjects):
            if i < j:  # Only calculate once for each pair
                s1 = df[df['Subject'] == subj1].groupby('Electrode')['Correlation'].mean()
                s2 = df[df['Subject'] == subj2].groupby('Electrode')['Correlation'].mean()
                
                # Find common electrodes
                common_elec = list(set(s1.index) & set(s2.index))
                if common_elec:
                    corr = np.corrcoef(s1[common_elec], s2[common_elec])[0, 1]
                    corr_matrix[i, j] = corr
                    corr_matrix[j, i] = corr  # Matrix is symmetric
    
    sns.heatmap(corr_matrix, xticklabels=subjects, yticklabels=subjects, 
               cmap='coolwarm', vmin=-1, vmax=1, annot=True, fmt='.2f')
    plt.title('Cross-subject Consistency')
    plt.tight_layout()
    plt.savefig(output_dir / 'subject_consistency.png')
    plt.close()
    
    return df

In [43]:
viz_data = visualize_results(results_by_subject)



# steering (with eeg features)

In [None]:
def apply_steering_vector(embedding_model, text, steering_vector, scale=1.0):
    # Check if steering vector is None
    if steering_vector is None:
        print("Warning: Steering vector is None. Proceeding without steering.")
        # Just run the model normally without hooks
        return embedding_model(text)
    
    # Define the hook function
    def steering_hook(acts, hook):
        # Apply to final layer activations
        if acts.shape[-1] == len(steering_vector):
            # Project activations onto steering direction and amplify
            projection = torch.matmul(
                acts, 
                torch.tensor(steering_vector, dtype=acts.dtype, device=acts.device)
            )
            
            # Apply steering by adding scaled projection
            return acts + scale * projection.unsqueeze(-1) * torch.tensor(
                steering_vector, dtype=acts.dtype, device=acts.device
            )
        return acts
    
    # Run model with hook
    output = embedding_model.run_with_hooks(
        text,
        fwd_hooks=[("blocks.23.hook_resid_post", steering_hook)]
    )
    
    return output


def generate_with_steering(model, prompt, steering_vector, scale=1.0, max_new_tokens=20):
    """Generate text with a steering vector applied during generation"""
    
    # Define the hook function
    def steering_hook(acts, hook):
        # Apply to final layer activations
        if acts.shape[-1] == len(steering_vector):
            # Project activations onto steering direction and amplify
            projection = torch.matmul(
                acts, 
                torch.tensor(steering_vector, dtype=acts.dtype, device=acts.device)
            )
            
            # Apply steering by adding scaled projection
            return acts + scale * projection.unsqueeze(-1) * torch.tensor(
                steering_vector, dtype=acts.dtype, device=acts.device
            )
        return acts
    
    # Run model with hook for generation
    output = model.generate(
        prompt, 
        max_new_tokens=max_new_tokens,
        fwd_hooks=[("blocks.23.hook_resid_post", steering_hook)]
    )
    
    return output


In [27]:
def test_steering_vector(mapper, results_by_subject, embedding_gen):
    # Check the models dictionary
    if hasattr(mapper, 'models') and isinstance(mapper.models, dict):
        print(f"Available features: {list(mapper.models.keys())}")
        
        # Try to extract steering vector for a standard feature
        feature_name = "FFD_t1"  # Standard feature to try
        print(f"Attempting to extract steering vector for feature: {feature_name}")
        
        # Extract steering vector using the existing method
        steering_vector = mapper.extract_steering_vector(
            feature_name=feature_name,
            method='weighted',
            threshold=0.1
        )
        
        if steering_vector is None:
            print(f"Failed to extract steering vector for feature {feature_name}")
            return None
        
        # Test the steering vector
        print(f"Testing with: 'The experiment results indicated that'")
        sentence = "The experiment results indicated that"
        
        # Generate without steering for comparison
        original_output = embedding_gen.model.generate(sentence, max_new_tokens=20)
        print(f"Original: {original_output}")
        
        return steering_vector
    else:
        print("Mapper does not have expected 'models' attribute")
        return None
    

def test_steering_vectors(mapper, embedding_gen):
    """Test steering vectors on various prompts"""
    
    # Get a feature for steering
    feature_name = "FFD_t1"  # First fixation duration, theta band
    
    # Extract steering vector
    steering_vector = mapper.extract_steering_vector(
        feature_name=feature_name,
        method='weighted',
        threshold=0.1
    )
    
    if steering_vector is None:
        print(f"Failed to extract steering vector for feature {feature_name}")
        return
    
    # Test prompts
    test_prompts = [
        "The experiment results indicated that",
        "The brain activity during reading showed",
        "Analysis of the EEG data revealed"
    ]
    
    # Test with different steering scales
    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        
        # Without steering
        print("No steering:")
        text = generate_with_steering(
            embedding_gen.model, prompt, steering_vector, scale=0.0)
        print(text)
        
        # With different steering scales
        for scale in [0.5, 1.0, 2.0]:
            print(f"\nSteering with scale={scale}:")
            text = generate_with_steering(
                embedding_gen.model, prompt, steering_vector, scale=scale)
            print(text)
        
        print("\n" + "-"*50)
    
def compare_feature_steering(mapper, embedding_gen):
    """Compare steering vectors from different EEG features"""
    
    # List of features to compare
    features = [
        "FFD_t1",  # First fixation duration, theta band
        "FFD_a1",  # First fixation duration, alpha band
        "FFD_g1",  # First fixation duration, gamma band
        "TRT_t1",  # Total reading time, theta band
    ]
    
    # Extract steering vectors
    vectors = {}
    for feature in features:
        vector = mapper.extract_steering_vector(
            feature_name=feature,
            method='weighted',
            threshold=0.1
        )
        if vector is not None:
            vectors[feature] = vector
    
    if not vectors:
        print("Failed to extract any steering vectors")
        return
    
    # Test prompt
    prompt = "The neural activity recorded in this experiment"
    
    # Generate without steering
    print(f"Prompt: {prompt}")
    print("\nNo steering:")
    text = generate_with_steering(
        embedding_gen.model, prompt, next(iter(vectors.values())), scale=0.0)
    print(text)
    
    # Compare different features with the same scale
    scale = 1.0
    for feature, vector in vectors.items():
        print(f"\nSteering with {feature} (scale={scale}):")
        text = generate_with_steering(
            embedding_gen.model, prompt, vector, scale=scale)
        print(text)

def generate_with_steering(model, prompt, steering_vector, scale=1.0, max_new_tokens=20):
    """Generate text with a steering vector applied during generation"""
    
    # Define the hook function
    def steering_hook(acts, hook):
        # Apply to final layer activations
        if acts.shape[-1] == len(steering_vector):
            # Project activations onto steering direction and amplify
            projection = torch.matmul(
                acts, 
                torch.tensor(steering_vector, dtype=acts.dtype, device=acts.device)
            )
            
            # Apply steering by adding scaled projection
            return acts + scale * projection.unsqueeze(-1) * torch.tensor(
                steering_vector, dtype=acts.dtype, device=acts.device
            )
        return acts
    
    # TransformerLens uses a context manager for hooks during generation
    with model.hooks(fwd_hooks=[("blocks.23.hook_resid_post", steering_hook)]):
        output = model.generate(prompt, max_new_tokens=max_new_tokens)
    
    return output

In [28]:
test_steering_vectors(mapper, embedding_gen)


Prompt: The experiment results indicated that
No steering:


100%|██████████| 20/20 [00:03<00:00,  5.93it/s]


The experiment results indicated that training increased acute rewarding effects (i.e., desensitization); thus, training activations

Steering with scale=0.5:


100%|██████████| 20/20 [00:01<00:00, 13.04it/s]


The experiment results indicated that horizontal frequency-restricted survival differed according to whether cellular swelling was induced by 2D laser filtering via N

Steering with scale=1.0:


100%|██████████| 20/20 [00:01<00:00, 12.70it/s]


The experiment results indicated that 75 million teens admitted to trouble with alcohol, drugs or alcohol use were 15 to 18 years old.

Steering with scale=2.0:


100%|██████████| 20/20 [00:01<00:00, 13.08it/s]


The experiment results indicated that young women who wear short dreadlocks remain attractive even if they are the same height and weight as their

--------------------------------------------------

Prompt: The brain activity during reading showed
No steering:


100%|██████████| 20/20 [00:01<00:00, 12.48it/s]


The brain activity during reading showed a temporal pattern of decreased brain activity during intact mothers with no expressive behavior or as children. A further

Steering with scale=0.5:


100%|██████████| 20/20 [00:01<00:00, 12.56it/s]


The brain activity during reading showed signs of improvement upon leaving the location of the septum and emerged brighter than the area when watching

Steering with scale=1.0:


100%|██████████| 20/20 [00:01<00:00, 12.97it/s]


The brain activity during reading showed that the first pages of each e-book was considered more intense than when reading separate volumes of the

Steering with scale=2.0:


100%|██████████| 20/20 [00:01<00:00, 12.84it/s]


The brain activity during reading showed clear peak in the dorsolateral prefrontal cortex (DLPFC), indicate increased activation, compared to

--------------------------------------------------

Prompt: Analysis of the EEG data revealed
No steering:


100%|██████████| 20/20 [00:01<00:00, 13.17it/s]


Analysis of the EEG data revealed that the EEG had remarkably consistent before pulses and after pulses. All the fields were fully diurnal,

Steering with scale=0.5:


100%|██████████| 20/20 [00:01<00:00, 13.04it/s]


Analysis of the EEG data revealed that the authors again gave rise to a dolphin eating something after the feast. This was an indication of

Steering with scale=1.0:


100%|██████████| 20/20 [00:01<00:00, 13.22it/s]


Analysis of the EEG data revealed changes that appeared to be significant and related to the projectors. Importantly, spectroscopic analysis

Steering with scale=2.0:


100%|██████████| 20/20 [00:01<00:00, 12.72it/s]

Analysis of the EEG data revealed a tension between the activity in the PCRM and the subthalamic nucleus (STN), which

--------------------------------------------------





# sentiment steering

In [38]:
def get_direct_sentiment_vector(embedding_gen):
    """Extract sentiment vector directly from language model"""
    # Define positive and negative examples
    positive_examples = ["happy", "joy", "excellent", "wonderful", "fantastic"]
    negative_examples = ["sad", "terrible", "awful", "horrible", "disappointing"]
    
    # Get embeddings
    pos_embeds = []
    for word in positive_examples:
        _, cache = embedding_gen.model.run_with_cache(word)
        embed = cache['blocks.23.hook_resid_post'][0, -1].detach().cpu().numpy()
        pos_embeds.append(embed)
    
    neg_embeds = []
    for word in negative_examples:
        _, cache = embedding_gen.model.run_with_cache(word)
        embed = cache['blocks.23.hook_resid_post'][0, -1].detach().cpu().numpy()
        neg_embeds.append(embed)
    
    # Average embeddings
    pos_centroid = np.mean(pos_embeds, axis=0)
    neg_centroid = np.mean(neg_embeds, axis=0)
    
    # Sentiment vector
    sentiment_vector = pos_centroid - neg_centroid
    sentiment_vector = sentiment_vector / np.linalg.norm(sentiment_vector)
    
    return sentiment_vector

def test_sentiment_steering(embedding_gen):
    # Get sentiment vector
    sentiment_vector = get_direct_sentiment_vector(embedding_gen)
    
    # Test prompts
    test_prompts = [
        "The story was", 
        "The experience was", 
        "The movie made me feel"
    ]
    
    for prompt in test_prompts:
        print(f"\nPrompt: {prompt}")
        
        # Without steering
        print("No steering:")
        text = generate_with_steering(embedding_gen.model, prompt, sentiment_vector, scale=0.0)
        print(text)
        
        # With positive steering
        for scale in [1.0, 2.0, 4.0]:
            print(f"\nPositive steering (scale={scale}):")
            text = generate_with_steering(embedding_gen.model, prompt, sentiment_vector, scale=scale)
            print(text)
        
        # With negative steering
        for scale in [1.0, 2.0, 4.0]:
            print(f"\nNegative steering (scale={scale}):")
            text = generate_with_steering(embedding_gen.model, prompt, -sentiment_vector, scale=scale)
            print(text)

In [39]:
test_sentiment_steering(embedding_gen)






Prompt: The story was
No steering:


100%|██████████| 20/20 [00:01<00:00, 11.04it/s]


The story was told about how my laptop instructor went furnitureless, had the storage device fall off the front side,

Positive steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 12.58it/s]


The story was fast becoming known. At around of 1pm on April 22, 2012, shortly before the report was

Positive steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 12.30it/s]


The story was another distraction in a voracious Democrat-crafted media campaign to discredit the president at every turn,

Positive steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 13.13it/s]


The story was broken well by PLOS Quest member Joelle Walker, who wrote a short blog post describing a paper

Negative steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 12.94it/s]


The story was one of impact and uncertainty. The news of technology's advance was mostly greeted with equal celebration.


Negative steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 13.08it/s]


The story was about how a young girl runs away from home to live a way, a space free from violence.

Negative steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 13.25it/s]


The story was told not by a lecherous historian, but by a favorite among patriots — he was honored

Prompt: The experience was
No steering:


100%|██████████| 20/20 [00:01<00:00, 13.25it/s]


The experience was much more uplifting. It was my first time feeling isolated with charities. I was so in charge

Positive steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 13.24it/s]


The experience was completely new to him. Vanver was sure his frustration had reached the point of being infectious.


Positive steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 12.59it/s]


The experience was part of a four-symposium at Boston University on Friday titled "Enhancing Responsive Design and

Positive steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 12.97it/s]


The experience was adventurous and fun. Somehow discovering it was fun. Traveling was great. Wanting to get inside

Negative steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 12.93it/s]


The experience was like an adventure. There were horseflies everywhere and all the time. I was frozen in place in

Negative steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 13.09it/s]


The experience was perhaps one of the greatest battles between rival MMO campaigns ever seen. It started with solid records scoring in

Negative steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 12.98it/s]


The experience was surreal. For months now, I've been working on the Kickstarter project, printed on exquisite soft tissue

Prompt: The movie made me feel
No steering:


100%|██████████| 20/20 [00:01<00:00, 13.30it/s]


The movie made me feel guilty for declining a slice with a large glass of beer, but I pretty much understood all of my

Positive steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 13.25it/s]


The movie made me feel out of sorts. The characters expressed a shocking sentiment. I felt unworthy of them. I felt like

Positive steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 13.31it/s]


The movie made me feel depressed, despondent, helpless and all those triggering emotions, but most enjoyable I guess…


Positive steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 12.83it/s]


The movie made me feel appreciated because it determined someone's psyche and made sure nothing more startling would happen, sacrificing padding pairs and

Negative steering (scale=1.0):


100%|██████████| 20/20 [00:01<00:00, 12.70it/s]


The movie made me feel like a 13-, 14-year-old boy, and I assembled on this page to experience it

Negative steering (scale=2.0):


100%|██████████| 20/20 [00:01<00:00, 13.34it/s]


The movie made me feel worthless and uninteresting. I didn't have 'what it takes' and nostalgia may or may not

Negative steering (scale=4.0):


100%|██████████| 20/20 [00:01<00:00, 13.25it/s]

The movie made me feel like this is a metaphor for how things get done these days in the real world: Tresdin





In [36]:
sentiment_csv_path = Path('sentiment_labels_cleaned.csv')

def get_brain_derived_sentiment_vector(mapper, all_subjects_data, sentiment_labels_path=sentiment_csv_path):
    """Extract sentiment vector using brain-to-LLM mapping with labeled sentences"""
    # Load sentiment labels
    import pandas as pd
    sentiment_df = pd.read_csv(sentiment_labels_path)
    
    # Choose a subject with good mapping results
    subject_id = 'resultsZJS_SR'  # Replace with best-performing subject
    
    # Select a brain feature that mapped well
    feature_name = 'GPT_g2'  # Or best feature based on mapping results
    
    # Get the model from mapper
    model_data = mapper.models[feature_name]
    model = model_data['model']
    
    # Get words from positive and negative sentences
    pos_brain_activity = []
    neg_brain_activity = []
    
    words_data = all_subjects_data[subject_id]
    
    for word_item in words_data:
        sent_id = word_item['sentence_id']
        
        # Find sentiment for this sentence
        sent_match = sentiment_df[sentiment_df['sentence_id'] == sent_id]
        if not sent_match.empty:
            sentiment = sent_match['sentiment'].values[0]
            
            if sentiment > 0 and feature_name in word_item['eeg_features']:
                feature = word_item['eeg_features'][feature_name]
                if hasattr(feature, 'shape') and not np.isnan(feature).any():
                    pos_brain_activity.append(feature)
            elif sentiment < 0 and feature_name in word_item['eeg_features']:
                feature = word_item['eeg_features'][feature_name]
                if hasattr(feature, 'shape') and not np.isnan(feature).any():
                    neg_brain_activity.append(feature)
    
    print(f"Found {len(pos_brain_activity)} positive and {len(neg_brain_activity)} negative samples")
    
    # Average brain activity
    if len(pos_brain_activity) < 10 or len(neg_brain_activity) < 10:
        print("Not enough brain activity data found")
        return None
        
    pos_brain = np.mean(pos_brain_activity, axis=0)
    neg_brain = np.mean(neg_brain_activity, axis=0)
    
    # Get the sentiment vector in brain space
    brain_sentiment_vector = pos_brain - neg_brain
    
    # Use the mapping to project to LLM space
    llm_sentiment_vector = model.coef_.T @ brain_sentiment_vector
    
    # Normalize
    llm_sentiment_vector = llm_sentiment_vector / np.linalg.norm(llm_sentiment_vector)
    
    return llm_sentiment_vector

# Usage:
sentiment_vector = get_brain_derived_sentiment_vector(mapper, all_subjects_data)

# Test steering with the vector
test_prompts = ["The story was", "The experience was", "The movie made me feel"]
for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    
    # Without steering
    print("No steering:")
    text = generate_with_steering(embedding_gen.model, prompt, sentiment_vector, scale=0.0)
    print(text)
    
    # With positive steering
    for scale in [1.0, 2.0, 4.0]:
        print(f"\nPositive steering (scale={scale}):")
        text = generate_with_steering(embedding_gen.model, prompt, sentiment_vector, scale=scale)
        print(text)
    
    # With negative steering
    for scale in [1.0, 2.0, 4.0]:
        print(f"\nNegative steering (scale={scale}):")
        text = generate_with_steering(embedding_gen.model, prompt, -sentiment_vector, scale=scale)
        print(text)

Found 0 positive and 0 negative samples
Not enough brain activity data found

Prompt: The story was
No steering:


  0%|          | 0/20 [00:00<?, ?it/s]


TypeError: object of type 'NoneType' has no len()

In [None]:


# class SentimentSteeringExtractor:
#     def __init__(self, data_loader, embedding_generator, sentiment_path=sentiment_csv_path):
#         self.data_loader = data_loader
#         self.embedding_generator = embedding_generator
#         self.sentiment_data = self._load_sentiment_data(sentiment_path)
        
#     def _load_sentiment_data(self, path):
#         # Load sentiment labels from CSV file
#         import pandas as pd
#         try:
#             return pd.read_csv(path)
#         except Exception as e:
#             print(f"Error loading sentiment data: {e}")
#             return None
            
#     def extract_sentiment_vectors(self, subject_id):
#         """Extract sentiment vectors from EEG data"""
#         # Get EEG data for this subject
#         word_data = self.data_loader.extract_word_level_data(subject_id)
        
#         # Group words by sentiment
#         positive_words = []
#         negative_words = []
        
#         for word in word_data:
#             # Match with sentiment data
#             word_text = word['word']
#             sent_id = word['sentence_id']
            
#             # Find sentiment for this sentence
#             sentiment = self._get_sentiment(sent_id)
#             if sentiment > 0:
#                 positive_words.append(word)
#             elif sentiment < 0:
#                 negative_words.append(word)
        
#         # Compute average brain embeddings for positive and negative words
#         positive_embedding = self._compute_brain_embedding(positive_words)
#         negative_embedding = self._compute_brain_embedding(negative_words)
        
#         # Compute sentiment direction as the difference
#         sentiment_vector = positive_embedding - negative_embedding
        
#         # Normalize
#         sentiment_vector = sentiment_vector / np.linalg.norm(sentiment_vector)
        
#         return sentiment_vector
        
#     def _get_sentiment(self, sentence_id):
#         """Get sentiment label for a sentence"""
#         if self.sentiment_data is None:
#             return 0
            
#         # Find sentiment for this sentence ID
#         try:
#             sent_row = self.sentiment_data[self.sentiment_data['sentence_id'] == sentence_id]
#             if not sent_row.empty:
#                 return sent_row['sentiment'].values[0]
#         except:
#             pass
#         return 0
    
#     def _compute_brain_embedding(self, words):
#         """Compute average brain embedding from EEG features"""
#         # Initialize embedding
#         feature_name = 'FFD_t1'  # Using first fixation duration, theta band
        
#         all_features = []
#         for word in words:
#             if feature_name in word['eeg_features']:
#                 feature = word['eeg_features'][feature_name]
#                 if hasattr(feature, 'shape') and not np.isnan(feature).any():
#                     all_features.append(feature)
        
#         if not all_features:
#             return None
            
#         # Average across all words
#         brain_embedding = np.mean(all_features, axis=0)
#         return brain_embedding
    
#     def map_to_model_space(self, brain_vector, mapper):
#         """Map brain vector to model embedding space"""
#         if brain_vector is None:
#             return None
            
#         # Use trained mapper to convert brain vector to model space
#         if not hasattr(mapper, 'models') or not mapper.models:
#             print("No trained models available")
#             return None
            
#         # Get a trained model (using first available feature)
#         feature_name = next(iter(mapper.models.keys()))
#         model_data = mapper.models[feature_name]
        
#         if not isinstance(model_data, dict) or 'model' not in model_data:
#             print("Invalid model format")
#             return None
            
#         # Extract model
#         model = model_data['model']
        
#         # Map brain vector to model space using inverse mapping
#         # This is an approximation - actual mapping would depend on model type
#         try:
#             # For Ridge regression, we can use the coefficients directly
#             model_vector = model.coef_.T @ brain_vector
            
#             # Normalize
#             model_vector = model_vector / np.linalg.norm(model_vector)
#             return model_vector
#         except:
#             print("Failed to map brain vector to model space")
#             return None

# def test_sentiment_steering(subject_id, data_loader, embedding_gen, mapper):
#     """Test sentiment steering for a subject"""
    
#     # Extract sentiment vector
#     extractor = SentimentSteeringExtractor(data_loader, embedding_gen)
#     sentiment_brain_vector = extractor.extract_sentiment_vectors(subject_id)
    
#     if sentiment_brain_vector is None:
#         print("Failed to extract sentiment vector")
#         return
        
#     # Map to model space
#     sentiment_model_vector = extractor.map_to_model_space(sentiment_brain_vector, mapper)
    
#     if sentiment_model_vector is None:
#         print("Failed to map sentiment vector to model space")
#         return
    
#     # Test prompts
#     test_prompts = [
#         "The story was",
#         "The movie made me feel",
#         "Reading this book was an experience that was"
#     ]
    
#     # Test with different steering scales
#     for prompt in test_prompts:
#         print(f"\nPrompt: {prompt}")
        
#         # Without steering
#         print("No steering:")
#         text = generate_with_steering(
#             embedding_gen.model, prompt, sentiment_model_vector, scale=0.0)
#         print(text)
        
#         # With positive steering (sentiment vector)
#         for scale in [1.0, 2.0]:
#             print(f"\nPositive steering (scale={scale}):")
#             text = generate_with_steering(
#                 embedding_gen.model, prompt, sentiment_model_vector, scale=scale)
#             print(text)
            
#         # With negative steering (negative sentiment vector)
#         for scale in [1.0, 2.0]:
#             print(f"\nNegative steering (scale={scale}):")
#             text = generate_with_steering(
#                 embedding_gen.model, prompt, -sentiment_model_vector, scale=scale)
#             print(text)
        
#         print("\n" + "-"*50)

In [None]:
# # Initialize the extraction and create steering vector
# subject_id = 'resultsZJS_SR'  # Choose a subject
# extractor = SentimentSteeringExtractor(zuco_loader, embedding_gen)

# # Extract sentiment vector and map to model space
# sentiment_brain_vector = extractor.extract_sentiment_vectors(subject_id)
# sentiment_model_vector = extractor.map_to_model_space(sentiment_brain_vector, mapper)

# # Test sentiment steering
# test_sentiment_steering(subject_id, zuco_loader, embedding_gen, mapper)

Error loading sentiment data: Error tokenizing data. C error: Expected 3 fields in line 28, saw 4

Loading data from ../zuco_data/zuco1.0/task1-SR/Matlab files/resultsZJS_SR.mat


TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'

# cross subject steering (todo)

In [None]:
# def extract_steering_vector(self, feature_name='FFD_t1', method='weighted', threshold=0.1):
#     """
#     Extract a steering vector using different methods:
#     - 'weighted': Weight electrodes by correlation strength
#     - 'top_n': Use only top N electrodes 
#     - 'threshold': Use electrodes with correlation above threshold
#     """
#     if feature_name not in self.models:
#         print(f"No model trained for feature {feature_name}")
#         return None
        
#     feature_data = self.models[feature_name]
    
#     # Handle both single-feature and multi-feature formats
#     if isinstance(feature_data, dict) and 'model' in feature_data:
#         model = feature_data['model']
#         results = feature_data['results']
#     else:
#         model = feature_data  # Original format
#         results = self.models[feature_name]['results']
    
#     weights = model.coef_.T  # [embedding_dim, n_electrodes]
    
#     # Calculate correlation strength per electrode
#     correlation_means = []
#     for result in results:
#         correlation_means.append(np.array(result['correlations']))
#     electrode_correlations = np.mean(np.stack(correlation_means), axis=0)
    
#     # Select electrodes based on method
#     if method == 'weighted':
#         # Weight each electrode by its correlation strength
#         electrode_weights = np.abs(electrode_correlations)
#         electrode_weights = electrode_weights / np.sum(electrode_weights)
#         steering_vector = np.zeros(weights.shape[0])
        
#         for i, weight in enumerate(electrode_weights):
#             if not np.isnan(weight):
#                 steering_vector += weight * weights[:, i]
                
#     elif method == 'top_n':
#         n_electrodes = 10  # Default to top 10
#         # Get top N electrodes by absolute correlation
#         top_indices = np.argsort(np.abs(electrode_correlations))[-n_electrodes:]
#         steering_vector = np.mean(weights[:, top_indices], axis=1)
        
#     elif method == 'threshold':
#         # Use electrodes above correlation threshold
#         mask = np.abs(electrode_correlations) > threshold
#         if not np.any(mask):
#             print(f"No electrodes above threshold {threshold}")
#             return None
#         steering_vector = np.mean(weights[:, mask], axis=1)
    
#     else:
#         raise ValueError(f"Unknown method: {method}")
    
#     # Normalize
#     steering_vector = steering_vector / np.linalg.norm(steering_vector)
#     return steering_vector