# Data Shape Analysis and Visualization

In [29]:
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import pandas as pd

def load_sentence_eeg_prob_data(filepath="../../data/sentence_eeg_prob_data.pkl"):
    """Read the pickled, fully processed dataset and hand it back as a list.

    Progress and every failure are reported with ``print``; errors raised while
    opening or unpickling the file are caught rather than propagated.

    Args:
        filepath: Location of the pickle file to read.

    Returns:
        The unpickled object when it is a ``list``. ``None`` when the file does
        not exist, when loading raises, or when the object is not a list.
    """
    print(f"Attempting to load processed data from: {filepath}")

    # Guard clause: nothing to open, so report and stop here.
    if not os.path.exists(filepath):
        print(f"Error: File not found at {filepath}.")
        return None

    try:
        with open(filepath, "rb") as handle:
            loaded = pickle.load(handle)
        print("Successfully loaded processed data.")

        # Only a list is accepted as the processed dataset.
        if not isinstance(loaded, list):
            print(f"Error: Loaded object is not a list (type: {type(loaded)}). Returning None.")
            return None
        return loaded
    except Exception as exc:
        print(f"An unexpected error occurred during loading processed data: {exc}")
        return None

def _save_figure(output_dir, filename, message):
    """Lay out, save and close the current pyplot figure, then print *message*."""
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{filename}")
    plt.close()
    print(message)


def create_visualizations(data, output_dir="../../visualizations"):
    """Create and save PNG charts describing the structure of the dataset.

    Charts written (when the needed fields are present): character
    distribution, sentence-length histogram, character-position histogram,
    EEG chunk-size histogram, a heatmap and a line plot of one EEG sample,
    top-prediction probability histogram, most frequently predicted
    characters, a relative-size overview of the record fields, and a text
    summary image.

    Args:
        data: List of record mappings. Every record must provide
            'character', 'sentence' and 'char_idx_in_sentence'. The keys
            'eeg_chunk', 'next_char_probabilities' and 'prefix' are used when
            present. NOTE(review): 'eeg_chunk' is assumed to be a non-empty
            sequence of 2-D NumPy arrays laid out as (time steps, channels);
            this is inferred from the axis labels -- confirm against the
            preprocessing code.
        output_dir: Directory that receives the PNG files; created if missing.

    Returns:
        ``output_dir`` after all charts are written, or ``None`` when ``data``
        is empty.
    """
    if not data:
        print("No data to visualize.")
        return

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print(f"Generating visualizations in {output_dir}...")

    # 1. Character Distribution Bar Chart
    char_counter = Counter(item['character'] for item in data)
    # most_common() sorts by descending count; ties keep first-seen order.
    chars, counts = zip(*char_counter.most_common())

    plt.figure(figsize=(14, 8))
    plt.bar(range(len(chars)), counts, color='skyblue')
    plt.xticks(range(len(chars)), chars, rotation=45)
    plt.title('Character Distribution in Dataset', fontsize=16)
    plt.xlabel('Characters', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    _save_figure(output_dir, "character_distribution.png",
                 "✓ Character distribution chart saved")

    # 2. Sentence Length Histogram
    # Computed once here and reused by the summary image (section 9).
    unique_sentences = {item['sentence'] for item in data}
    sentence_lengths = [len(sentence) for sentence in unique_sentences]
    mean_sentence_length = np.mean(sentence_lengths)

    plt.figure(figsize=(12, 6))
    plt.hist(sentence_lengths, bins=20, color='lightgreen', edgecolor='black')
    plt.axvline(mean_sentence_length, color='red', linestyle='dashed', linewidth=2,
                label=f'Mean: {mean_sentence_length:.2f}')
    plt.title('Distribution of Sentence Lengths', fontsize=16)
    plt.xlabel('Sentence Length (characters)', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    _save_figure(output_dir, "sentence_length_distribution.png",
                 "✓ Sentence length distribution chart saved")

    # 3. Character Position in Sentences
    char_positions = [item['char_idx_in_sentence'] for item in data]

    plt.figure(figsize=(12, 6))
    plt.hist(char_positions, bins=30, color='salmon', edgecolor='black')
    plt.title('Distribution of Character Positions in Sentences', fontsize=16)
    plt.xlabel('Character Position Index', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    _save_figure(output_dir, "character_position_distribution.png",
                 "✓ Character position distribution chart saved")

    # 4. EEG Chunk Size Distribution
    # Records that actually carry a non-empty EEG chunk; sections 5, 6 and 9
    # reuse this selection instead of re-filtering the dataset.
    items_with_eeg = [item for item in data if 'eeg_chunk' in item and item['eeg_chunk']]
    chunk_sizes = [len(item['eeg_chunk']) for item in items_with_eeg]

    if chunk_sizes:
        mean_chunk_size = np.mean(chunk_sizes)
        # The figure is opened only when there is something to draw, so no
        # empty figure is left open when the dataset has no EEG chunks.
        plt.figure(figsize=(12, 6))
        plt.hist(chunk_sizes, bins=20, color='lightblue', edgecolor='black')
        plt.axvline(mean_chunk_size, color='red', linestyle='dashed', linewidth=2,
                    label=f'Mean: {mean_chunk_size:.2f}')
        plt.title('Distribution of EEG Chunk Sizes', fontsize=16)
        plt.xlabel('Number of Samples in Chunk', fontsize=14)
        plt.ylabel('Frequency', fontsize=14)
        plt.legend()
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        _save_figure(output_dir, "eeg_chunk_size_distribution.png",
                     "✓ EEG chunk size distribution chart saved")

    # 5. EEG Sample Heatmap
    sample_item = items_with_eeg[0] if items_with_eeg else None
    eeg_sample = None
    if sample_item is not None:
        # Get the first sample from the chunk
        eeg_sample = sample_item['eeg_chunk'][0]

        plt.figure(figsize=(14, 8))
        # Transposed so that rows are channels and columns are time steps.
        sns.heatmap(eeg_sample.T, cmap='viridis', cbar_kws={'label': 'Amplitude'})
        plt.title(f"EEG Sample Heatmap for Character '{sample_item['character']}'", fontsize=16)
        plt.xlabel('Time Steps', fontsize=14)
        plt.ylabel('Channels', fontsize=14)
        _save_figure(output_dir, "eeg_sample_heatmap.png", "✓ EEG sample heatmap saved")

        # 6. EEG Sample Line Plot (for first few channels)
        plt.figure(figsize=(14, 8))
        num_channels_to_plot = min(5, eeg_sample.shape[1])  # Plot up to 5 channels
        for i in range(num_channels_to_plot):
            plt.plot(eeg_sample[:, i], label=f'Channel {i+1}')

        plt.title(f"EEG Signal for First {num_channels_to_plot} Channels", fontsize=16)
        plt.xlabel('Time Steps', fontsize=14)
        plt.ylabel('Amplitude', fontsize=14)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        _save_figure(output_dir, "eeg_sample_lineplot.png", "✓ EEG sample line plot saved")

    # 7. Next Character Probability Distribution
    # Collect the single most probable next character (and its probability)
    # for every record that has a non-empty probability mapping.
    top_chars = []
    top_probs = []

    for item in data:
        if 'next_char_probabilities' in item and item['next_char_probabilities']:
            top_char, top_prob = max(item['next_char_probabilities'].items(), key=lambda x: x[1])
            top_chars.append(top_char)
            top_probs.append(top_prob)

    if top_chars:
        # Create a histogram of top probabilities
        plt.figure(figsize=(12, 6))
        plt.hist(top_probs, bins=20, color='plum', edgecolor='black')
        plt.title('Distribution of Top Prediction Probabilities', fontsize=16)
        plt.xlabel('Probability', fontsize=14)
        plt.ylabel('Frequency', fontsize=14)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        _save_figure(output_dir, "top_probability_distribution.png",
                     "✓ Top probability distribution chart saved")

        # Create a bar chart of the 20 most frequently predicted characters
        predicted_chars, predicted_counts = zip(*Counter(top_chars).most_common(20))

        plt.figure(figsize=(14, 8))
        plt.bar(range(len(predicted_chars)), predicted_counts, color='orchid')
        plt.xticks(range(len(predicted_chars)), predicted_chars, rotation=45)
        plt.title('Top 20 Most Frequently Predicted Next Characters', fontsize=16)
        plt.xlabel('Characters', fontsize=14)
        plt.ylabel('Frequency', fontsize=14)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        _save_figure(output_dir, "top_predicted_characters.png",
                     "✓ Top predicted characters chart saved")

    # 8. Data Structure Overview Diagram
    # Define the structure components
    components = [
        "Character", "Prefix", "Sentence", "Position Index",
        "EEG Chunk", "Next Char Probabilities"
    ]

    # Define sizes (approximate relative memory usage), taken from the first
    # record only.
    first_item = data[0]
    if 'eeg_chunk' in first_item and first_item['eeg_chunk']:
        eeg_size = sum(chunk_sample.nbytes for chunk_sample in first_item['eeg_chunk'])
        # A present-but-empty/None probability mapping counts as size 0,
        # matching the falsy check used when collecting top predictions.
        if 'next_char_probabilities' in first_item and first_item['next_char_probabilities']:
            probs_size = len(first_item['next_char_probabilities']) * 5
        else:
            probs_size = 0
        sizes = [
            1,  # Character
            len(first_item['prefix']) if 'prefix' in first_item else 0,  # Prefix
            len(first_item['sentence']) if 'sentence' in first_item else 0,  # Sentence
            4,  # Position Index (int)
            eeg_size,  # EEG Chunk
            probs_size,  # Next Char Probs
        ]
    else:
        # Default sizes if the first record has no EEG chunk
        sizes = [1, 10, 20, 4, 1000, 50]

    # Normalize sizes for visualization (total is always > 0: the constant
    # entries alone sum to 5).
    total = sum(sizes)
    sizes = [size / total for size in sizes]

    # Create a horizontal bar chart
    colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0', '#ffb3e6']
    y_pos = np.arange(len(components))

    plt.figure(figsize=(14, 10))
    plt.barh(y_pos, sizes, color=colors)
    plt.yticks(y_pos, components)
    plt.xlabel('Relative Size (normalized)', fontsize=14)
    plt.title('Data Structure Components Overview', fontsize=16)

    # Add size annotations just to the right of each bar
    for i, v in enumerate(sizes):
        plt.text(v + 0.01, i, f"{v:.2f}", va='center')

    _save_figure(output_dir, "data_structure_overview.png",
                 "✓ Data structure overview diagram saved")

    # 9. Create a summary image with key statistics
    total_items = len(data)
    unique_chars = len(char_counter)
    unique_sentences_count = len(unique_sentences)
    avg_sentence_length = mean_sentence_length

    if eeg_sample is not None:
        eeg_shape = f"{eeg_sample.shape[0]} × {eeg_sample.shape[1]}"
        avg_chunk_size = np.mean(chunk_sizes)
    else:
        eeg_shape = "N/A"
        avg_chunk_size = 0

    # Create text for the summary
    summary_text = (
        "DATA SHAPE SUMMARY\n"
        "==================\n\n"
        f"Total Items: {total_items}\n"
        f"Unique Characters: {unique_chars}\n"
        f"Unique Sentences: {unique_sentences_count}\n"
        f"Avg. Sentence Length: {avg_sentence_length:.2f} chars\n"
        f"EEG Sample Shape: {eeg_shape}\n"
        f"Avg. Chunk Size: {avg_chunk_size:.2f} samples\n"
    )

    plt.figure(figsize=(10, 8))
    plt.axis('off')
    plt.text(0.1, 0.5, summary_text, fontsize=14, family='monospace')
    _save_figure(output_dir, "data_summary.png", "✓ Data summary image saved")

    print(f"\nAll visualizations saved to {output_dir}/")
    return output_dir


# Load and Analyze Data

In [30]:
# First attempt: the loader's built-in default location.
data = load_sentence_eeg_prob_data()
if data is None:
    # Fallback: look for the pickle next to the notebook instead.
    data = load_sentence_eeg_prob_data("sentence_eeg_prob_data.pkl")

# An empty list counts as a failed load as well, hence the truthiness test.
if not data:
    print("Failed to load data. Please check the file path.")
else:
    # Render every chart and report where the images were written.
    output_dir = create_visualizations(data)
    print(f"\nTo view the visualizations, check the files in the {output_dir} directory.")

Attempting to load processed data from: ../../data/sentence_eeg_prob_data.pkl
Successfully loaded processed data.
Generating visualizations in ../../visualizations...
✓ Character distribution chart saved
✓ Sentence length distribution chart saved
✓ Character position distribution chart saved
✓ EEG chunk size distribution chart saved
✓ EEG sample heatmap saved
✓ EEG sample line plot saved
✓ Top probability distribution chart saved
✓ Top predicted characters chart saved
✓ Data structure overview diagram saved
✓ Data summary image saved

All visualizations saved to ../../visualizations/

To view the visualizations, check the files in the ../../visualizations directory.
