# STDP-based MNIST Classification Tutorial

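This notebook demonstrates how to use the STDP-trained spiking neural network for MNIST digit classification. It covers four steps: loading and visualizing the data, setting the network parameters, inspecting the pretrained weights, and running the network on test examples.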

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
from functions.data import get_labeled_data

## 1. Load and Visualize MNIST Data

In [None]:
# Read the labelled MNIST training split via the project's data helper.
training = get_labeled_data('mnist/training')

# Preview the first ten training digits on a 2 x 5 grid; each panel is
# titled with the first entry of that example's label record.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i in range(axes.size):
    ax = axes.flat[i]
    ax.imshow(training['x'][i], cmap='gray')
    ax.set_title(f'Label: {training["y"][i][0]}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## 2. Network Parameters

In [None]:
# ---- Network size ----
n_input = 784   # one input per pixel of a 28x28 image
n_e = 400       # number of excitatory neurons
n_i = n_e       # inhibitory population is the same size as the excitatory one

# ---- Neuron parameters (suffix _e: excitatory, _i: inhibitory) ----
# Excitatory population
v_rest_e = -65. * mV
v_reset_e = -65. * mV
v_thresh_e = -52. * mV
refrac_e = 5. * ms
# Inhibitory population
v_rest_i = -60. * mV
v_reset_i = -45. * mV
v_thresh_i = -40. * mV
refrac_i = 2. * ms

# ---- STDP parameters (names suffixed _ee) ----
# Time constants; presumably those of the pre-/post-synaptic traces — confirm
# against the learning-rule equations.
tc_pre_ee = 20*ms
tc_post_1_ee = 20*ms
tc_post_2_ee = 40*ms
# Two separate learning rates: one tied to the "pre" side and one to the
# "post" side of the rule (per their names — verify in the update equations).
nu_ee_pre = 0.0001
nu_ee_post = 0.01
wmax_ee = 1.0               # upper bound on the weight, per the name
# Exponents used by the rule; the post value mirrors the pre value.
exp_ee_pre = 0.2
exp_ee_post = exp_ee_pre

## 3. Examine Learned Weights

In [None]:
# Load the pretrained weights (the file name XeAe suggests input -> excitatory
# connections). The plotting below treats the array as a dense matrix with one
# row per input pixel (n_input = 28*28) and one column per neuron. Fail early
# with a readable message if the file has a different layout, instead of an
# opaque reshape error further down.
weights = np.load('../weights/XeAe.npy')
assert weights.ndim == 2 and weights.shape[0] == n_input, (
    f'expected a dense ({n_input}, n_neurons) weight matrix, got shape {weights.shape}'
)

# Show the weight column of each of the first `n_plots` neurons as a 28x28 image.
n_plots = 25
n_cols = 5
n_rows = -(-n_plots // n_cols)  # ceiling division so every plot gets a panel
# squeeze=False keeps `axes` a 2-D array even when there is a single row/column.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(2.4 * n_cols, 2.4 * n_rows), squeeze=False
)
n_shown = min(n_plots, weights.shape[1])  # never index past the last neuron
for i, ax in enumerate(axes.flat):
    # Hide frames on every panel, including any unused ones past n_shown.
    ax.axis('off')
    if i < n_shown:
        w = weights[:, i].reshape(28, 28)
        ax.imshow(w, cmap='viridis')
        ax.set_title(f'Neuron {i}', fontsize=8)
fig.suptitle('Learned weights (first neurons)')
plt.tight_layout()
plt.show()

## 4. Test Network Performance

In [None]:
# Number of test examples to evaluate. Running the network is slow, so
# start small and raise this value once everything works.
n_test = 100

# Read the labelled MNIST test split (bTrain=False selects the test files).
testing = get_labeled_data('mnist/testing', bTrain=False)