In [None]:
# !conda install scipy -y
# !conda install seaborn -y
# !conda install sklearn-pandas -y



In [None]:
# Add these imports to your main.py
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import numpy as np
import torch
from torch.utils.data import DataLoader



def test_model(net, testloader, criterion, device='cpu'):
    """Evaluate ``net`` on ``testloader`` and collect its predictions.

    Args:
        net: Model to evaluate. It is switched to eval mode; the model itself
            is not moved, only the batches are sent to ``device``.
        testloader: Iterable yielding ``(inputs, labels)`` batches. It does
            not need to support ``len()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(all_predictions, all_labels, avg_test_loss, accuracy)``:
        the per-sample predicted class indices (argmax over dim 1 of the
        model output), the per-sample labels, the mean of the per-batch
        losses, and the fraction of predictions equal to their label.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    net.eval()
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    all_predictions = []
    all_labels = []

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            total_loss += criterion(outputs, labels).item()
            # Count batches here instead of calling len(testloader) so
            # loaders without __len__ are supported.
            num_batches += 1

            # Predicted class = index of the largest output along dim 1.
            _, predicted = torch.max(outputs, 1)
            num_correct += (predicted == labels).sum().item()
            num_samples += labels.numel()

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_samples == 0:
        # Fail with a clear message rather than an opaque ZeroDivisionError.
        raise ValueError("testloader yielded no samples; nothing to evaluate")

    avg_test_loss = total_loss / num_batches
    accuracy = num_correct / num_samples

    return all_predictions, all_labels, avg_test_loss, accuracy


# The name matches pytest's default ``test_*`` pattern; opt out of collection
# so importing this helper into a test module does not register it as a test.
test_model.__test__ = False

def analyze_results(predictions, true_labels, label_map, model_name="CNN"):
    """Print and plot a classification analysis of model predictions.

    Prints the overall accuracy, scikit-learn's classification report,
    per-class precision/recall and the five most frequent misclassification
    pairs, and shows the confusion matrix as a heatmap.

    Args:
        predictions: Predicted class indices, one per sample.
        true_labels: Ground-truth class indices, one per sample.
        label_map: Mapping from class name to class index. The indices are
            looked up as ``0 .. len(label_map) - 1``.
        model_name: Name used in the printed header and the plot title.

    Returns:
        Tuple ``(accuracy, cm)`` with the overall accuracy and the confusion
        matrix (rows = true class, columns = predicted class).
    """
    # Invert name -> index so class indices can be displayed by name.
    reverse_label_map = {v: k for k, v in label_map.items()}
    class_ids = list(range(len(label_map)))
    label_names = [reverse_label_map[i] for i in class_ids]

    print(f"\n{'='*50}")
    print(f"  {model_name} MODEL ANALYSIS")
    print(f"{'='*50}")

    # Overall accuracy
    accuracy = accuracy_score(true_labels, predictions)
    print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    # Classification report
    print(f"\nDetailed Classification Report:")
    print("-" * 40)
    # Pass the full class list explicitly: without it, a class missing from
    # both true_labels and predictions is dropped and the number of reported
    # classes no longer matches target_names.
    report = classification_report(true_labels, predictions,
                                   labels=class_ids,
                                   target_names=label_names,
                                   digits=4)
    print(report)

    # Confusion matrix, forced to len(label_map) x len(label_map) so that
    # row/column i always corresponds to label_names[i].
    cm = confusion_matrix(true_labels, predictions, labels=class_ids)

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=label_names, yticklabels=label_names)
    plt.title(f'{model_name} - Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

    # Per-class analysis
    print(f"\nPer-Class Analysis:")
    print("-" * 40)
    for i, label_name in enumerate(label_names):
        true_positives = cm[i, i]
        total_actual = np.sum(cm[i, :])      # samples whose true class is i
        total_predicted = np.sum(cm[:, i])   # samples predicted as class i

        # Classes with no true samples are skipped (recall is undefined).
        if total_actual > 0:
            recall = true_positives / total_actual
            precision = true_positives / total_predicted if total_predicted > 0 else 0
            print(f"{label_name:15s}: Precision={precision:.3f}, Recall={recall:.3f}, Count={total_actual}")

    # Identify problematic classes
    print(f"\nMost Confused Classes:")
    print("-" * 40)

    # Every non-zero off-diagonal cell is a (true -> predicted) confusion.
    confusion_pairs = [
        (cm[i, j], true_name, pred_name)
        for i, true_name in enumerate(label_names)
        for j, pred_name in enumerate(label_names)
        if i != j and cm[i, j] > 0
    ]

    # Largest counts first; show the top five.
    confusion_pairs.sort(reverse=True)
    for count, true_class, pred_class in confusion_pairs[:5]:
        print(f"{true_class} → {pred_class}: {count} misclassifications")

    return accuracy, cm

def plot_training_history(train_losses, train_accuracies=None):
    """Plot the training loss and, when provided, the training accuracy.

    Args:
        train_losses: Sequence of loss values, one per epoch or batch.
        train_accuracies: Optional sequence of accuracy values. When it is
            ``None`` or empty only the loss curve is drawn.
    """
    # Explicit None/length check instead of truthiness so NumPy arrays (whose
    # truth value is ambiguous) are accepted as well as lists.
    has_accuracies = train_accuracies is not None and len(train_accuracies) > 0

    # squeeze=False always returns a 2-D array of axes, so the one- and
    # two-panel layouts can share the same plotting code.
    _, axes = plt.subplots(1, 2 if has_accuracies else 1,
                           figsize=(12, 4), squeeze=False)

    loss_ax = axes[0, 0]
    loss_ax.plot(train_losses)
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch/Batch')
    loss_ax.set_ylabel('Loss')

    if has_accuracies:
        acc_ax = axes[0, 1]
        acc_ax.plot(train_accuracies)
        acc_ax.set_title('Training Accuracy')
        acc_ax.set_xlabel('Epoch/Batch')
        acc_ax.set_ylabel('Accuracy')

    plt.tight_layout()
    plt.show()

def model_summary(net):
    """Print parameter counts, approximate size and leaf layers of ``net``.

    Args:
        net: A ``torch.nn.Module``. Only its parameters are counted; buffers
            are not included in the totals.
    """
    print(f"\n{'='*50}")
    print(f"  MODEL ARCHITECTURE SUMMARY")
    print(f"{'='*50}")

    params = list(net.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    # Use each parameter's real element size rather than assuming 4-byte
    # float32, so half/double precision models are reported correctly.
    total_bytes = sum(p.numel() * p.element_size() for p in params)

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Model Size: ~{total_bytes / 1024 / 1024:.2f} MB")

    print(f"\nLayer Details:")
    print("-" * 40)
    for name, module in net.named_modules():
        # Only leaf modules (no children) are listed, so container modules
        # do not repeat the parameters of the layers they hold.
        if next(module.children(), None) is not None:
            continue
        num_params = sum(p.numel() for p in module.parameters())
        if num_params > 0:
            print(f"{name:20s}: {str(module):50s} | {num_params:,} params")

# Usage - Add this after your training is complete:

# Build the network and restore the saved weights first: `net` has to exist
# before test_model() is called on it.
net = MEGNet(num_classes=len(LABEL_MAP),
             input_channels=248,
             input_time_steps=train_data[0].shape[1])

# Load the saved state dict. map_location='cpu' lets a checkpoint written on
# a GPU load on a CPU-only machine; test_model() evaluates on CPU by default.
net.load_state_dict(torch.load('CNN_1.pth', map_location='cpu'))

# Set to evaluation mode for testing
net.eval()

print("Model loaded successfully from CNN_1.pth")

# Model summary
model_summary(net)

print("Testing the model...")

# Test the model
predictions, true_labels, test_loss, test_accuracy = test_model(net, testloader, criterion)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")

# Comprehensive analysis
analyze_results(predictions, true_labels, LABEL_MAP, "Intra-Subject CNN")

# If you tracked training history, uncomment and use:
# plot_training_history(your_training_losses)

Testing the model...


NameError: name 'net' is not defined