# Create EEG Sentences


In [6]:
import pickle
import os

def load_sample_groups(filepath="../data/sample_groups.pkl"):
    """Load the grouped-samples dictionary from a pickle file.

    Every failure is reported with ``print`` and signalled by returning
    ``None``; ordinary exceptions are not propagated to the caller.

    Note: ``pickle.load`` can execute arbitrary code while deserializing, so
    only point this at files you produced yourself or otherwise trust.

    Args:
        filepath (str): Path to the pickle file containing the sample groups
            dictionary. Defaults to "../data/sample_groups.pkl".

    Returns:
        dict: The loaded sample_groups dictionary, or None if the file is
              missing, cannot be unpickled, or does not contain a dict.
              Expected structure (only the top-level type is validated here):
              {char: {set_name: [chunk1, chunk2, ...], ...}, ...}
              where each chunk is a list of numpy arrays (samples).
    """
    print(f"Attempting to load sample groups from: {filepath}")

    # Open directly instead of checking os.path.exists first: one code path
    # for a missing file and no window between the check and the open.
    try:
        with open(filepath, "rb") as f:
            loaded_groups = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}. Please run the grouping script first.")
        return None
    except (pickle.UnpicklingError, EOFError):
        # An empty or truncated pickle raises EOFError rather than UnpicklingError.
        print(f"Error: Could not unpickle data from {filepath}. File might be corrupted.")
        return None
    except Exception as e:
        # Deliberate best-effort boundary: report anything else and return None.
        print(f"An unexpected error occurred during loading: {e}")
        return None

    if not isinstance(loaded_groups, dict):
        print(f"Error: Loaded object is not a dictionary (type: {type(loaded_groups)}). Returning None.")
        return None

    # Announce success only once the object is known to be usable.
    print("Successfully loaded sample groups dictionary.")
    return loaded_groups


# Load the grouped samples once for this cell. The result is falsy when
# loading failed (None) or when the dictionary is empty.
sample_groups_data = load_sample_groups()

# Walk a single (character, set) entry to illustrate the nested layout.
if sample_groups_data:
    print("\n--- Example Access --- ")
    target_char = "A"  # Example character
    target_set = "set1"  # Example set

    if target_char not in sample_groups_data or target_set not in sample_groups_data[target_char]:
        print(f"Could not find data for Character {target_char} and Set {target_set}.")
    else:
        groups_for_set = sample_groups_data[target_char][target_set]
        num_chunks = len(groups_for_set)
        print(f"Character {target_char}, Set {target_set} has {num_chunks} chunks.")

        if num_chunks > 0:
            # The first chunk is itself a list of samples.
            first_chunk = groups_for_set[0]
            num_samples_in_chunk = len(first_chunk)
            print(f"  First chunk contains {num_samples_in_chunk} samples.")

            if num_samples_in_chunk > 0:
                # Presumably a numpy array; only its .shape attribute is used here.
                first_sample_in_chunk = first_chunk[0]
                print(f"    Shape of the first sample in the first chunk: {first_sample_in_chunk.shape}")
                # You can now work with 'first_sample_in_chunk' or the entire 'first_chunk' list


Attempting to load sample groups from: ../data/sample_groups.pkl
Successfully loaded sample groups dictionary.

--- Example Access --- 
Character A, Set set1 has 4 chunks.
  First chunk contains 30 samples.
    Shape of the first sample in the first chunk: (78, 64)


# Add next-char prediction probabilities

In [None]:
# Add next-char prediction probabilities to each DFM sample as an additional feature

import torch
import torch.nn as nn
import torch.nn.functional as F
import string


# Character vocabulary, in index order: a-z, then A-Z, then 0-9.
# NOTE(review): presumably this order must match the one used when the
# pretrained checkpoint was produced — confirm before changing it.
all_chars = [*string.ascii_lowercase, *string.ascii_uppercase, *string.digits]

# Two lookup tables over the same ordering: index -> char and char -> index.
idx2char = dict(enumerate(all_chars))
char2idx = {ch: idx for idx, ch in idx2char.items()}
vocab_size = len(all_chars)

print("Vocabulary size:", vocab_size)

class CharPredictor(nn.Module):
    """Next-character model: embedding -> single-layer LSTM -> linear head.

    The linear head returns raw scores (one per vocabulary entry); callers
    apply softmax themselves when probabilities are needed.
    """

    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128):
        """Build the three layers.

        Args:
            vocab_size (int): Number of distinct characters; used for both the
                embedding table size and the output width.
            embed_dim (int): Size of each character embedding.
            hidden_dim (int): Size of the LSTM hidden state.
        """
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they must stay
        # stable for previously saved checkpoints to load.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Score the next character from the final LSTM hidden state.

        Args:
            x: Tensor of character indices, batch-first (batch, seq_len).

        Returns:
            Tensor of unnormalized scores with vocab_size entries per sequence.
        """
        # Only the final hidden state is used; per-step outputs and the cell
        # state are discarded.
        _, (final_hidden, _) = self.lstm(self.embedding(x))
        # Drop the leading num_layers axis (size 1 for this single-layer LSTM).
        return self.fc(final_hidden.squeeze(0))
    
def load_model(path="../model/api/char_predictor.pth"):
    """Build a CharPredictor and load pretrained weights into it.

    Args:
        path (str): Path to a saved state_dict for CharPredictor.
            Defaults to "../model/api/char_predictor.pth".

    Returns:
        CharPredictor: The model with weights loaded, switched to eval mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. when the
        file is missing or its keys/shapes do not match the model.
    """
    model = CharPredictor(vocab_size)
    # map_location="cpu" lets a checkpoint that was saved from a GPU be read on
    # a machine without CUDA. The model is constructed on the CPU here, so the
    # resulting model is the same either way.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you
    # trust (consider weights_only=True where the installed torch supports it).
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded from {path}.")
    return model


def predict_next_chars(model, sentence, top_k=5):
    """Return the most likely next characters after ``sentence``.

    Characters of ``sentence`` that are not in ``char2idx`` are silently
    dropped before the model is called.

    Args:
        model: Model mapping a (1, seq_len) index tensor to per-character
            scores. It is switched to eval mode as a side effect.
        sentence (str): Context text preceding the character to predict.
        top_k (int): Maximum number of candidates to return. Values larger
            than the number of output classes are clamped to that number.

    Returns:
        dict: {character: probability rounded to 4 decimals}, ordered from
        most to least likely.

    Raises:
        ValueError: If ``sentence`` contains no known character.
    """
    model.eval()
    with torch.no_grad():
        input_seq = [char2idx[ch] for ch in sentence if ch in char2idx]
        if not input_seq:
            raise ValueError("Input sentence must contain at least one known character.")

        # Add a batch dimension of 1, then remove it again after softmax.
        batch = torch.tensor(input_seq).unsqueeze(0)
        probs = F.softmax(model(batch), dim=-1).squeeze(0)

        # torch.topk raises if k exceeds the number of classes; clamp so a
        # generous top_k simply returns every class.
        k = min(top_k, probs.size(-1))
        top_probs, top_indices = torch.topk(probs, k)

        return {
            idx2char[idx]: round(prob, 4)
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        }

# Load the pretrained next-character model when this cell runs
# (reads the checkpoint from load_model's default path).
nlp_model = load_model()


# Intended usage (left disabled): `prefix` is the context before the target char.

# nlp_prob_dict = predict_next_chars(nlp_model, prefix, top_k=6)
