In [None]:
# Model Evaluation with Spike Count and Spike Timing Accuracy

import torch
import numpy as np

def evaluate_snn_performance(model, test_loader, target_label=2, threshold=0.5):
    """
    Evaluate SNN model performance: accuracy, spike count error, and spike timing error.

    Args:
    - model: Trained SNN model; put into eval mode and called as ``model(inputs)``.
    - test_loader: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
      ``inputs`` is transposed on dims 1 and 2 before the forward pass;
      ``labels`` must be a tensor (its ``shape``/``size(0)`` are used).
    - target_label: Label whose samples are expected to emit one spike; every
      other sample is expected to emit none (expected count 1.0 vs 0.0).
    - threshold: A sample is predicted as class 1 when its summed output
      exceeds this value, otherwise as class 0.

    Returns:
    - accuracy: Fraction of samples whose thresholded prediction equals its label.
    - avg_spike_count_error: Mean absolute difference between the summed output
      and the expected spike count, averaged over all elements (not per batch),
      so a smaller final batch is not over-weighted.
    - avg_spike_timing_error: Always 0.0 -- timing error is a placeholder until
      ground-truth spike times are available.

    Raises:
    - ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    correct_predictions = 0
    total_samples = 0
    spike_count_error_sum = 0.0
    spike_count_elements = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.transpose(1, 2)  # Adjust input dimensions if necessary
            outputs = model(inputs)  # Forward pass
            # Sum over dim 1 -- assumed to be the time axis after the transpose (TODO confirm)
            outputs_sum = outputs.sum(dim=1)

            # A model with one output per sample may yield (B, 1) while labels are (B,).
            # Without alignment, `==` and `-` broadcast to (B, B) and corrupt the metrics.
            if outputs_sum.shape != labels.shape and outputs_sum.numel() == labels.numel():
                outputs_sum = outputs_sum.reshape(labels.shape)

            # Convert to binary predictions based on spike counts
            predicted_labels = (outputs_sum > threshold).float()
            correct_predictions += (predicted_labels == labels).sum().item()
            total_samples += labels.size(0)

            # Spike count error: |actual summed output - expected count (1 for target, else 0)|
            expected_spikes = (labels == target_label).float()
            abs_errors = torch.abs(outputs_sum - expected_spikes)
            # Accumulate sum and count so the final mean weights every element equally
            spike_count_error_sum += abs_errors.sum().item()
            spike_count_elements += abs_errors.numel()

    if total_samples == 0 or spike_count_elements == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    accuracy = correct_predictions / total_samples
    avg_spike_count_error = spike_count_error_sum / spike_count_elements
    # Placeholder: no ground-truth spike timings are available, so timing error is reported as 0.
    avg_spike_timing_error = 0.0

    return accuracy, avg_spike_count_error, avg_spike_timing_error

# Example usage -- `model` and `test_loader` are expected to be defined by earlier cells.
snn_results = evaluate_snn_performance(model, test_loader)
accuracy, avg_spike_count_error, avg_spike_timing_error = snn_results

print(f"Accuracy: {accuracy:.2f}")
print(f"Average Spike Count Error: {avg_spike_count_error:.2f}")
print(f"Average Spike Timing Error: {avg_spike_timing_error:.2f}")


In [None]:
# Precision, Recall and F1 Score (Binary Classification)

def precision_recall_f1(test_outputs, test_labels):
    """
    Calculate precision, recall, and F1 score for binary classification without
    using external metric libraries.

    Args:
    - test_outputs: Predicted labels (binary 0/1) -- list, NumPy array, or CPU tensor.
    - test_labels: Ground-truth labels (binary 0/1), element-aligned with ``test_outputs``.

    Returns:
    - precision: TP / (TP + FP), or 0.0 when nothing was predicted positive.
    - recall: TP / (TP + FN), or 0.0 when there are no positive labels.
    - f1_score: Harmonic mean of precision and recall, or 0.0 when both are 0.
    """
    # Convert to arrays first: with plain Python lists, `list == 1` is a single bool,
    # which would silently make every count below zero.
    predicted = np.asarray(test_outputs)
    actual = np.asarray(test_labels)

    tp = int(np.sum((predicted == 1) & (actual == 1)))
    fp = int(np.sum((predicted == 1) & (actual == 0)))
    fn = int(np.sum((predicted == 0) & (actual == 1)))

    # Explicit zero-division guards instead of an epsilon in every denominator,
    # so exact ratios (e.g. a perfect 1.0) are returned unbiased.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    return precision, recall, f1_score

# `test_outputs` and `test_labels` (binary predictions and ground truth) are
# expected to have been collected by an earlier cell.
prf_scores = precision_recall_f1(test_outputs, test_labels)
precision, recall, f1_score = prf_scores

print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")


In [None]:
# Visualisation of Input and Output spike trains
# NOTE(review): `plt`, `spikes_tensor` and `outputs` must be defined by earlier cells
# (matplotlib is not imported in the visible imports cell) -- confirm on a fresh kernel.

def plot_spike_raster(spikes, title, ylabel):
    """Plot spike activity as a binary raster, with the first axis of `spikes` along x.

    Args:
    - spikes: Tensor or array, expected 2-D (rows x channels) -- it is transposed
      before ``imshow``.
    - title: Figure title.
    - ylabel: Label for the y-axis (the channel axis).
    """
    # detach().cpu() keeps the conversion safe for tensors that require grad or live on GPU
    if hasattr(spikes, "detach"):
        spikes_np = spikes.detach().cpu().numpy()
    else:
        spikes_np = np.asarray(spikes)
    fig, ax = plt.subplots(figsize=(10, 6))
    image = ax.imshow(spikes_np.T, cmap='binary', interpolation='nearest')
    fig.colorbar(image, ax=ax, label='Spike Activity')
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel(ylabel)
    plt.show()


# Input spikes: first 100 rows and first 10 features
plot_spike_raster(spikes_tensor[:100, :10], 'Input Spike Trains (First 100 Samples)', 'Neurons/Features')

# Output spikes (predicted labels): first 100 output spike trains
plot_spike_raster(outputs[:100], 'Predicted Output Spike Trains (First 100 Samples)', 'Output Neurons')


In [None]:
# Measuring Accuracy
# Assumes `test_outputs` holds one score per sample (e.g. a sigmoid probability) and
# `test_labels` the matching binary labels, both collected by an earlier cell.

# Fail loudly instead of raising IndexError or silently ignoring extra labels
if len(test_outputs) != len(test_labels):
    raise ValueError(
        f"test_outputs ({len(test_outputs)}) and test_labels ({len(test_labels)}) differ in length"
    )
if len(test_outputs) == 0:
    raise ValueError("test_outputs is empty; cannot compute accuracy")

correct = 0
total = 0
for score, label in zip(test_outputs, test_labels):
    predicted_label = 1 if score > 0.5 else 0  # binary decision threshold of 0.5
    if predicted_label == label:
        correct += 1
    total += 1

accuracy = correct / total
print(f"Accuracy: {accuracy:.2f}")
