# Model Training 

In [20]:
from bundle.DataCraft import * 

data = load_sentence_eeg_prob_data()

if data:
    # Print example of converted item
    print(f"Loaded {len(data)} items.")
    if data:
        print("Example of first item:", data[1]["character"])
        print("Example of first item:", data[1]["char_idx_in_sentence"])
        print("Example of first item:", data[1]["sentence"])
        print("Example of first item:", data[1]["eeg_chunk"][0:1])
        print("Example of first item:", data[1]["prob_chunk"])

Attempting to load processed data from: ../../data/sentences_eeg.pkl
Successfully loaded processed data.
Loaded 16270 items.
Example of first item: H
Example of first item: 1
Example of first item: THE QUICK DOG JUMPED OVER
Example of first item: [array([[ 0.22776094,  0.46970123,  0.70324725, ...,  0.14635862,
         1.0238168 , -0.9463383 ],
       [ 0.08742154,  0.2278345 ,  0.34740365, ..., -0.1466554 ,
         0.61614096, -1.0300957 ],
       [-0.06776999, -0.04082734, -0.06774805, ..., -0.3656779 ,
         0.22884615, -1.0415238 ],
       ...,
       [-1.3213954 , -0.93309444, -0.07610805, ..., -0.46508783,
        -0.98153764,  0.23164782],
       [-1.4564943 , -1.0470396 , -0.07786182, ..., -0.53110176,
        -1.0577328 ,  0.05647672],
       [-1.6199298 , -1.1973821 , -0.17433581, ..., -0.52290905,
        -1.0656656 , -0.04597962]], dtype=float32)]
Example of first item: [[1.490e-02 1.490e-02]
 [4.000e-04 4.000e-04]
 [4.000e-04 4.000e-04]
 [1.000e-04 1.000e-04]
 [1.000e

In [21]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pickle
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import time

# Define a custom dataset class
class EEGProbDataset(Dataset):
    def __init__(self, eeg_data, prob_data, labels):
        self.eeg_data = eeg_data
        self.prob_data = prob_data
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        eeg = self.eeg_data[idx]
        prob = self.prob_data[idx]
        label = self.labels[idx]
        return eeg, prob, label

# Define the RNN model
class EEGProbRNN(nn.Module):
    def __init__(self, eeg_input_size, prob_input_size, hidden_size, num_layers, num_classes, dropout=0.2):
        super(EEGProbRNN, self).__init__()
        
        # EEG branch
        self.eeg_rnn = nn.LSTM(
            input_size=eeg_input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        
        # Probability branch
        self.prob_rnn = nn.LSTM(
            input_size=prob_input_size,
            hidden_size=hidden_size // 2,  # Smaller hidden size for prob branch
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        
        # Attention mechanism for EEG
        self.eeg_attention = nn.Sequential(
            nn.Linear(hidden_size * 2, 1),
            nn.Tanh()
        )
        
        # Attention mechanism for Prob
        self.prob_attention = nn.Sequential(
            nn.Linear(hidden_size, 1),
            nn.Tanh()
        )
        
        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 2 + hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Output layer
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, eeg, prob):
        # Process EEG data
        eeg_out, _ = self.eeg_rnn(eeg)
        
        # Apply attention to EEG
        eeg_attn_weights = self.eeg_attention(eeg_out)
        eeg_attn_weights = torch.softmax(eeg_attn_weights, dim=1)
        eeg_context = torch.sum(eeg_out * eeg_attn_weights, dim=1)
        
        # Process probability data
        prob_out, _ = self.prob_rnn(prob)
        
        # Apply attention to Prob
        prob_attn_weights = self.prob_attention(prob_out)
        prob_attn_weights = torch.softmax(prob_attn_weights, dim=1)
        prob_context = torch.sum(prob_out * prob_attn_weights, dim=1)
        
        # Concatenate features from both branches
        combined = torch.cat((eeg_context, prob_context), dim=1)
        
        # Fusion layer
        fused = self.fusion(combined)
        
        # Output layer
        output = self.fc(fused)
        
        return output

# Function to prepare data for training
def prepare_data(data):
    eeg_samples = []
    prob_samples = []
    labels = []
    
    for item in data:
        if 'eeg_chunk' in item and item['eeg_chunk'] and 'converted_data' in item and 'character' in item:
            # Get the first EEG sample from the chunk
            eeg_sample = item['eeg_chunk'][0]
            
            # Get the probability matrix
            prob_sample = item['converted_data']
            
            # Get the character label
            label = item['character']
            
            eeg_samples.append(eeg_sample)
            prob_samples.append(prob_sample)
            labels.append(label)
    
    # Convert to numpy arrays
    eeg_samples = np.array(eeg_samples)
    prob_samples = np.array(prob_samples)
    
    # Encode labels
    label_encoder = LabelEncoder()
    encoded_labels = label_encoder.fit_transform(labels)
    
    return eeg_samples, prob_samples, encoded_labels, label_encoder

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device='cpu'):
    # Training history
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    # Best model tracking
    best_val_acc = 0.0
    best_model_state = None
    
    for epoch in range(num_epochs):
        start_time = time.time()
        
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for eeg, prob, labels in train_loader:
            eeg = eeg.float().to(device)
            prob = prob.float().to(device)
            labels = labels.to(device)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(eeg, prob)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Statistics
            running_loss += loss.item() * eeg.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)
        
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for eeg, prob, labels in val_loader:
                eeg = eeg.float().to(device)
                prob = prob.float().to(device)
                labels = labels.to(device)
                
                outputs = model(eeg, prob)
                loss = criterion(outputs, labels)
                
                val_running_loss += loss.item() * eeg.size(0)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_epoch_loss = val_running_loss / val_total
        val_epoch_acc = val_correct / val_total
        val_losses.append(val_epoch_loss)
        val_accs.append(val_epoch_acc)
        
        # Save best model
        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            best_model_state = model.state_dict().copy()
        
        # Print statistics
        time_elapsed = time.time() - start_time
        print(f'Epoch {epoch+1}/{num_epochs} | Time: {time_elapsed:.2f}s')
        print(f'Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}')
        print(f'Val Loss: {val_epoch_loss:.4f} | Val Acc: {val_epoch_acc:.4f}')
        print('-' * 60)
    
    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    return model, train_losses, val_losses, train_accs, val_accs

# Function to plot training history
def plot_training_history(train_losses, val_losses, train_accs, val_accs, save_dir='./results'):
    os.makedirs(save_dir, exist_ok=True)
    
    # Plot losses
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'{save_dir}/loss_history.png')
    plt.close()
    
    # Plot accuracies
    plt.figure(figsize=(10, 5))
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.title('Accuracy over epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'{save_dir}/accuracy_history.png')
    plt.close()

# Function to evaluate model on test set
def evaluate_model(model, test_loader, criterion, device='cpu', label_encoder=None):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for eeg, prob, labels in test_loader:
            eeg = eeg.float().to(device)
            prob = prob.float().to(device)
            labels = labels.to(device)
            
            outputs = model(eeg, prob)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item() * eeg.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    test_loss = test_loss / total
    test_acc = correct / total
    
    print(f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')
    
    # If label encoder is provided, print some examples
    if label_encoder is not None:
        print("\nSample predictions:")
        for i in range(min(10, len(all_labels))):
            true_label = label_encoder.inverse_transform([all_labels[i]])[0]
            pred_label = label_encoder.inverse_transform([all_preds[i]])[0]
            print(f"True: {true_label} | Predicted: {pred_label} | {'✓' if true_label == pred_label else '✗'}")
    
    return test_loss, test_acc, all_preds, all_labels

# Save model function
def save_model(model, label_encoder, save_dir='../../models/rnn'):
    os.makedirs(save_dir, exist_ok=True)
    
    # Save model
    torch.save(model.state_dict(), f'{save_dir}/eeg_prob_rnn_model.pth')
    
    # Save label encoder
    with open(f'{save_dir}/label_encoder.pkl', 'wb') as f:
        pickle.dump(label_encoder, f)
    
    print(f"Model and label encoder saved to {save_dir}")

# Main execution
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)


if data:
    # Prepare data
    print("Preparing data...")
    print(f"Data size: {len(data)} items")
    eeg_samples, prob_samples, encoded_labels, label_encoder = prepare_data(data)
    
    print(f"Data prepared:")
    print(f"  EEG samples shape: {eeg_samples.shape}")
    print(f"  Probability samples shape: {prob_samples.shape}")
    print(f"  Number of classes: {len(label_encoder.classes_)}")
    print(f"  Classes: {label_encoder.classes_}")
    
    # Split data into train, validation, and test sets
    X_train_eeg, X_test_eeg, X_train_prob, X_test_prob, y_train, y_test = train_test_split(
        eeg_samples, prob_samples, encoded_labels, test_size=0.2, random_state=42, stratify=encoded_labels
    )
    
    X_train_eeg, X_val_eeg, X_train_prob, X_val_prob, y_train, y_val = train_test_split(
        X_train_eeg, X_train_prob, y_train, test_size=0.25, random_state=42, stratify=y_train
    )
    
    print(f"Train set: {len(y_train)} samples")
    print(f"Validation set: {len(y_val)} samples")
    print(f"Test set: {len(y_test)} samples")
    
    # Create datasets
    train_dataset = EEGProbDataset(X_train_eeg, X_train_prob, y_train)
    val_dataset = EEGProbDataset(X_val_eeg, X_val_prob, y_val)
    test_dataset = EEGProbDataset(X_test_eeg, X_test_prob, y_test)
    
    # Create data loaders
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    # Get input dimensions from data
    eeg_input_size = eeg_samples.shape[2]  # Number of channels
    prob_input_size = prob_samples.shape[2]  # Number of columns in prob matrix
    num_classes = len(label_encoder.classes_)
    
    # Initialize model
    hidden_size = 128
    num_layers = 2
    model = EEGProbRNN(
        eeg_input_size=eeg_input_size,
        prob_input_size=prob_input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_classes=num_classes,
        dropout=0.3
    ).to(device)
    
    # Print model summary
    print(model)
    
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train model
    print("Starting training...")
    model, train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs=20, device=device
    )
    
    # Plot training history
    plot_training_history(train_losses, val_losses, train_accs, val_accs)
    
    # Evaluate on test set
    print("\nEvaluating on test set...")
    test_loss, test_acc, all_preds, all_labels = evaluate_model(
        model, test_loader, criterion, device=device, label_encoder=label_encoder
    )
    
    # Save model
    save_model(model, label_encoder)
    
    print("\nTraining and evaluation complete!")
else:
    print("Failed to load data. Please check the file path.")

Preparing data...
Data size: 16270 items
Data prepared:
  EEG samples shape: (0,)
  Probability samples shape: (0,)
  Number of classes: 0
  Classes: []


ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.