In [3]:
# ===============================
# Imports and Configs
# ===============================
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns

from bundle.DataCraft import load_sentence_eeg_prob_data

# ===============================
# Constants
# ===============================
NUM_CLASSES = 36  # default output width of EEGCNN's final linear layer
MODEL_PATH = "../../model/ecd/trained_eegcnn_model_selected_channels_set2.pth"
_USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _USE_CUDA else torch.device("cpu")
# Indices kept along the last axis of every eeg_chunk before it reaches the model.
SELECTED_CHANNELS = [
    10, 33, 48, 50,
    52, 55, 59, 61,
]

# ===============================
# Dataset
# ===============================
class EEGDataset(Dataset):
    """Map-style dataset yielding ``(eeg_tensor, class_id)`` pairs.

    Args:
        data: Sequence of records. Each record provides an ``"eeg_chunk"``
            array-like with three axes and a ``"character"`` label.
        label_encoder: Fitted encoder exposing ``transform`` (e.g. sklearn
            ``LabelEncoder``). It maps characters to integer class ids.
        channels: Indices kept along the last axis of each chunk. ``None``
            (the default) uses the module-level ``SELECTED_CHANNELS``.

    Raises:
        Whatever ``label_encoder.transform`` raises for labels it has not
        seen. This now happens at construction time instead of mid-iteration.
    """

    def __init__(self, data, label_encoder, channels=None):
        self.data = data
        self.label_encoder = label_encoder
        self.channels = list(SELECTED_CHANNELS if channels is None else channels)
        # Encode every label once. The original per-item transform call paid
        # sklearn's input-validation overhead on every __getitem__.
        characters = [item["character"] for item in data]
        if characters:
            self._labels = np.asarray(label_encoder.transform(characters), dtype=np.int64)
        else:
            self._labels = np.empty(0, dtype=np.int64)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chunk = np.asarray(self.data[idx]["eeg_chunk"], dtype=np.float32)
        # Fancy indexing returns a fresh copy, so from_numpy never aliases
        # the caller's data.
        chunk = chunk[:, :, self.channels]
        # The leading singleton axis is the single input channel that
        # EEGCNN's first Conv3d expects.
        eeg = torch.from_numpy(chunk).unsqueeze(0)
        label = torch.tensor(self._labels[idx], dtype=torch.long)
        return eeg, label

# ===============================
# Model Definition
# ===============================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate for 5-D ``(N, C, D, H, W)`` tensors.

    Each channel is globally average-pooled to one value. That vector goes
    through a ``C -> C // reduction -> C`` bottleneck MLP ending in a sigmoid,
    and the input is rescaled channel-wise by the resulting weights in (0, 1).

    Args:
        channels: Number of input/output channels ``C``.
        reduction: Bottleneck reduction factor. The hidden width is clamped to
            at least 1, so small ``channels`` still get a working gate.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # channels // reduction is 0 when channels < reduction. That produces a
        # zero-width Linear, which collapses the gate to a constant 0.5.
        hidden = max(1, channels // reduction)
        self.pool = nn.AdaptiveAvgPool3d(1)
        # Submodule layout (fc.0 / fc.2) is kept so existing checkpoints load.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Unpacking five dims rejects inputs that are not batched 5-D tensors.
        b, c, _, _, _ = x.size()
        weights = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1, 1)
        # (b, c, 1, 1, 1) broadcasts over D, H, W.
        return x * weights

class EEGCNN(nn.Module):
    """3-D CNN classifier with squeeze-and-excitation gating.

    Architecture:
        - Three Conv3d -> BatchNorm3d -> ReLU stages with channels
          1 -> 32 -> 64 -> 128. Each uses a 3x3x3 kernel with padding 1, so
          D/H/W are preserved.
        - An SEBlock after stages 2 and 3.
        - Global average pooling.
        - A 128 -> 64 -> ``num_classes`` head, with dropout (p=0.3) after the
          first linear layer.

    Input is an ``(N, 1, D, H, W)`` float tensor; output is ``(N, num_classes)``
    logits. Attribute names are part of the saved state_dict keys, so they must
    not be renamed.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(32)

        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.se1 = SEBlock(64)

        self.conv3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(128)
        self.se2 = SEBlock(128)

        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.se1(x)
        x = torch.relu(self.bn3(self.conv3(x)))
        x = self.se2(x)
        # flatten(.., 1) keeps the batch axis. The previous .squeeze() also
        # dropped it when N == 1, returning (num_classes,) logits and breaking
        # argmax(dim=1) on a size-1 final batch.
        x = torch.flatten(self.pool(x), 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)

# ===============================
# Load Data and Model
# ===============================

# The training split is loaded only to fit the label encoder. LabelEncoder
# sorts classes, so ids are deterministic given the same label set; this
# presumably mirrors the encoder used at training time (confirm).
raw_data = load_sentence_eeg_prob_data("../../data/sentences_eeg_train_set2.pkl")
if not raw_data:
    raise ValueError("Failed to load training data.")
all_labels = [item["character"] for item in raw_data]
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Rebinding raw_data releases the (unused) training records.
raw_data = load_sentence_eeg_prob_data("../../data/sentences_eeg_val_set2.pkl")
if not raw_data:
    raise ValueError("Failed to load data.")
# Only a 20% hold-out of the validation file is evaluated. random_state=3
# presumably reproduces the split used elsewhere (confirm).
_, test_set = train_test_split(raw_data, test_size=0.2, random_state=3)
test_dataset = EEGDataset(test_set, label_encoder)
test_loader = DataLoader(test_dataset, batch_size=64)

model = EEGCNN()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # freeze BatchNorm statistics and disable dropout

# ===============================
# Evaluation with Character Output (Print First 30 Only)
# ===============================
MAX_PRINTED = 30  # number of sample predictions echoed as characters
correct = total = 0
y_true, y_pred = [], []
printed = 0  # predictions echoed so far

with torch.no_grad():
    for inputs, targets in test_loader:
        # Move the predictions to the host once. Targets already live on
        # the CPU as produced by the DataLoader, so they never visit the device.
        preds = model(inputs.to(DEVICE)).argmax(dim=1).cpu().numpy()
        batch_true = targets.numpy()

        correct += int((preds == batch_true).sum())
        total += len(batch_true)

        y_true.extend(batch_true)
        y_pred.extend(preds)

        # Decode only the rows that will actually be printed.
        if printed < MAX_PRINTED:
            n_show = min(MAX_PRINTED - printed, len(preds))
            pred_chars = label_encoder.inverse_transform(preds[:n_show])
            true_chars = label_encoder.inverse_transform(batch_true[:n_show])
            for pred_char, true_char in zip(pred_chars, true_chars):
                print(f"Predicted: {pred_char} | True: {true_char}")
            printed += n_show

# Accuracy and Report
print(f"\nTest Accuracy: {correct / total:.2%}")
print("\nClassification Report:")
# LabelEncoder ids are 0..n_classes-1 in classes_ order, so every class is
# listed in the report even if it never occurs in the test split.
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(label_encoder.classes_)),
    target_names=label_encoder.classes_,
    zero_division=0
))

# # ===============================
# # Visualizations
# # ===============================
# def plot_conv1_activations(model, input_tensor, output_dir, index):
#     model.eval()
#     with torch.no_grad():
#         input_tensor = input_tensor.unsqueeze(0).to(DEVICE)
#         x1 = F.relu(model.bn1(model.conv1(input_tensor)))
#     x1 = x1.squeeze(0)[:, 30, :, :]
#     fig, axes = plt.subplots(4, 8, figsize=(20, 10))
#     fig.suptitle("Conv1 Activations at Frame 30", fontsize=16)
#     for i in range(32):
#         ax = axes[i // 8][i % 8]
#         sns.heatmap(x1[i].cpu(), ax=ax, cmap='plasma', cbar=False)
#         ax.set_title(f"F{i}")
#         ax.axis('off')
#     plt.tight_layout()
#     plt.savefig(os.path.join(output_dir, f"sample_{index}_conv1_activations.png"))
#     plt.close()
# 
# 
# def plot_conv2_activations(model, input_tensor, output_dir, index):
#     model.eval()
#     with torch.no_grad():
#         input_tensor = input_tensor.unsqueeze(0).to(DEVICE)
#         x1 = F.relu(model.bn1(model.conv2(input_tensor)))
#     x1 = x1.squeeze(0)[:, 30, :, :]
#     fig, axes = plt.subplots(4, 8, figsize=(20, 10))
#     fig.suptitle("Conv1 Activations at Frame 30", fontsize=16)
#     for i in range(32):
#         ax = axes[i // 8][i % 8]
#         sns.heatmap(x1[i].cpu(), ax=ax, cmap='plasma', cbar=False)
#         ax.set_title(f"F{i}")
#         ax.axis('off')
#     plt.tight_layout()
#     plt.savefig(os.path.join(output_dir, f"sample_{index}_conv1_activations.png"))
#     plt.close()
# 
# def plot_conv2_activations(model, input_tensor, output_dir, index):
#     model.eval()
#     with torch.no_grad():
#         input_tensor = input_tensor.unsqueeze(0).to(DEVICE)  # (1, 1, 31, 78, 8)
#         x1 = F.relu(model.bn1(model.conv1(input_tensor)))    # (1, 32, 31, 78, 8)
#         x2 = F.relu(model.bn2(model.conv2(x1)))              # (1, 64, 31, 78, 8)
# 
#     x2 = x2.squeeze(0)[:, 30, :, :]  # Shape: (64, 78, 8)
# 
#     fig, axes = plt.subplots(8, 8, figsize=(20, 16))
#     fig.suptitle("Conv2 Activations at Frame 30", fontsize=16)
# 
#     for i in range(64):
#         ax = axes[i // 8][i % 8]
#         sns.heatmap(x2[i].cpu(), ax=ax, cmap='plasma', cbar=False)
#         ax.set_title(f"F{i}")
#         ax.axis('off')
# 
#     plt.tight_layout()
#     plt.savefig(os.path.join(output_dir, f"sample_{index}_conv2_activations.png"))
#     plt.close()
# 
# def plot_conv3_activations(model, input_tensor, output_dir, index):
#     model.eval()
#     with torch.no_grad():
#         input_tensor = input_tensor.unsqueeze(0).to(DEVICE)
#         x1 = F.relu(model.bn1(model.conv1(input_tensor)))    # (1, 32, 31, 78, 8)
#         x2 = F.relu(model.bn2(model.conv2(x1)))              # (1, 64, 31, 78, 8)
#         x2 = model.se1(x2)
#         x3 = F.relu(model.bn3(model.conv3(x2)))              # (1, 128, 31, 78, 8)
# 
#     x3 = x3.squeeze(0)[:, 30, :, :]  # Shape: (128, 78, 8)
# 
#     fig, axes = plt.subplots(8, 16, figsize=(24, 16))
#     fig.suptitle("Conv3 Activations at Frame 30", fontsize=16)
# 
#     for i in range(128):
#         ax = axes[i // 16][i % 16]
#         sns.heatmap(x3[i].cpu(), ax=ax, cmap='plasma', cbar=False)
#         ax.set_title(f"F{i}", fontsize=8)
#         ax.axis('off')
# 
#     plt.tight_layout()
#     plt.savefig(os.path.join(output_dir, f"sample_{index}_conv3_activations.png"))
#     plt.close()
# 
# # ===============================
# # Main Visualization Caller
# # ===============================
# def create_visualizations_model(data, model, output_dir="../../visualizations"):
#     if not data:
#         print("No data to visualize.")
#         return
# 
#     os.makedirs(output_dir, exist_ok=True)
#     index = 10
#     if index >= len(data):
#         print("Invalid index.")
#         return
# 
#     sample = data[index]
#     chunk = np.array(sample["eeg_chunk"], dtype=np.float32)[:, :, SELECTED_CHANNELS]
#     chunk[30] *= 3.0
#     input_tensor = torch.tensor(chunk).unsqueeze(0)
# 
#     plot_conv1_activations(model, input_tensor, output_dir, index)
#     plot_conv2_activations(model, input_tensor, output_dir, index)
#     plot_conv3_activations(model, input_tensor, output_dir, index)
# 
# 
#     print(f"Visualizations generated at: {output_dir}")
# 
# # Run visualization
# create_visualizations_model(raw_data, model)

Attempting to load processed data from: ../../data/sentences_eeg_train_set2.pkl
Successfully loaded processed data.
Attempting to load processed data from: ../../data/sentences_eeg_val_set2.pkl
Successfully loaded processed data.
Predicted: F | True: F
Predicted: O | True: O
Predicted: K | True: H
Predicted: E | True: E
Predicted: K | True: H
Predicted: F | True: F
Predicted: C | True: T
Predicted: K | True: K
Predicted: K | True: H
Predicted: R | True: R
Predicted: L | True: L
Predicted: R | True: R
Predicted: I | True: I
Predicted: N | True: C
Predicted: E | True: Y
Predicted: D | True: D
Predicted: R | True: G
Predicted: A | True: A
Predicted: E | True: O
Predicted: R | True: B
Predicted: E | True: E
Predicted: O | True: O
Predicted: K | True: H
Predicted: E | True: E
Predicted: D | True: D
Predicted: N | True: C
Predicted: O | True: O
Predicted: K | True: H
Predicted: A | True: A
Predicted: L | True: L

Test Accuracy: 66.43%

Classification Report:
              precision    recall