# Basic Usage of STDP-based MNIST Classification

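This notebook walks through the basic workflow of the STDP-based (spike-timing-dependent plasticity) MNIST classifier: loading data, training a small network, testing it, and inspecting the learned weights.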
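STDP strengthens a synapse when a presynaptic spike shortly precedes a postsynaptic one and weakens it when the order is reversed. The sketch below is a simplified, pair-based, trace-driven version of that idea for a single synapse in plain NumPy. The function name `run_trace_stdp` and all constants (`tau_pre`, `tau_post`, `eta_plus`, `eta_minus`, `w_max`) are illustrative assumptions, not the rule or parameters used in `diehl_cook_spiking_mnist_brian2.py`.

```python
import numpy as np

def run_trace_stdp(pre_times, post_times, w0=0.5, t_end=100.0, dt=0.5,
                   tau_pre=20.0, tau_post=20.0, eta_plus=0.01,
                   eta_minus=0.005, w_max=1.0):
    """Simulate one synapse under a simplified trace-based STDP rule.

    Illustrative sketch only: times are in ms, constants are made up.
    Returns the final synaptic weight.
    """
    n_steps = int(round(t_end / dt))
    pre_steps = {int(round(t / dt)) for t in pre_times}
    post_steps = {int(round(t / dt)) for t in post_times}
    decay_pre = np.exp(-dt / tau_pre)
    decay_post = np.exp(-dt / tau_post)
    x_pre, x_post, w = 0.0, 0.0, w0
    for step in range(n_steps):
        # Both traces decay exponentially every time step
        x_pre *= decay_pre
        x_post *= decay_post
        if step in pre_steps:
            # Pre spike after a recent post spike: depression scaled by post trace
            w -= eta_minus * x_post
            x_pre = 1.0
        if step in post_steps:
            # Post spike after a recent pre spike: potentiation scaled by pre trace
            w += eta_plus * x_pre
            x_post = 1.0
        # Keep the weight within [0, w_max]
        w = min(max(w, 0.0), w_max)
    return w
```

With a 5 ms gap, `run_trace_stdp([10.0], [15.0])` (pre before post) ends slightly above 0.5, while `run_trace_stdp([15.0], [10.0])` (post before pre) ends slightly below it.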

## Table of Contents
1. Setup and Imports
2. Loading and Visualizing MNIST Data
3. Training a Small Network
4. Testing and Visualization
5. Weight Analysis

## 1. Setup and Imports

First, let's import all necessary libraries and set up our environment.

In [None]:
import sys
sys.path.append('..')  # make the project root (parent directory) importable

import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
from functions.data import get_labeled_data, get_data_subset
from functions.quick_analysis import quick_analyze

# Use Brian2's Cython code-generation target (requires Cython and a C compiler)
prefs.codegen.target = 'cython'

# Render matplotlib figures inline in the notebook
%matplotlib inline

## 2. Loading and Visualizing MNIST Data

Let's load the MNIST training set and display the first ten images with their labels.

In [None]:
# Load training data
training = get_labeled_data('training', bTrain=True, MNIST_data_path='../mnist/')

# Display some examples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(training['x'][i].reshape(28, 28), cmap='gray')
    # .item() turns a one-element label array (or NumPy scalar) into a plain number
    ax.set_title(f'Label: {np.asarray(training["y"][i]).item()}')
    ax.axis('off')
plt.tight_layout()
plt.show()
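The network does not see raw pixels. In the Diehl & Cook (2015) scheme that `diehl_cook_spiking_mnist_brian2.py` is named after, each of the 784 pixels drives a Poisson spike generator whose rate is proportional to its intensity, up to 63.75 Hz for a pixel value of 255, and each image is presented for 350 ms. Assuming that scheme (the script may implement it differently), a NumPy sketch of the encoding follows; the function name `poisson_spike_train` and the 0.5 ms time step are illustrative choices.

```python
import numpy as np

def poisson_spike_train(image, duration_ms=350.0, dt_ms=0.5,
                        max_rate_hz=63.75, rng=None):
    """Rate-code an 8-bit image as (Bernoulli-approximated) Poisson spikes.

    Returns a boolean array of shape (n_steps, n_pixels) that is True where a
    pixel's input neuron spikes in that time step.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Firing rate per pixel: 0 Hz for black (0) up to max_rate_hz for white (255)
    rates_hz = np.asarray(image, dtype=float).reshape(-1) / 255.0 * max_rate_hz
    n_steps = int(round(duration_ms / dt_ms))
    # Spike probability per step; a good approximation while rate * dt << 1
    p_spike = rates_hz * dt_ms / 1000.0
    # Broadcast (n_steps, n_pixels) uniform draws against per-pixel probabilities
    return rng.random((n_steps, rates_hz.size)) < p_spike
```

For example, `poisson_spike_train(training['x'][0])` would give a 700 × 784 spike raster for the first training image; `reshape(-1)` accepts the image either as 28 × 28 or already flattened.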

## 3. Training a Small Network

Now let's train a network on 1,000 training images (`train_size = 1000`) for a single epoch. This is a quick demonstration of the process, not a full training run.

In [None]:
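# Load the training script as a module (path is relative to this notebook's folder).
# Its __name__ becomes "mnist_stdp", so any `if __name__ == "__main__":` block is not run.
# The cells below rely on the script exposing a module-level `args` namespace and `main()`.
import importlib.util
spec = importlib.util.spec_from_file_location("mnist_stdp", "../diehl_cook_spiking_mnist_brian2.py")
mnist_stdp = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mnist_stdp)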

# Configure for a quick training run
mnist_stdp.args.train = True
mnist_stdp.args.test = False
mnist_stdp.args.train_size = 1000
mnist_stdp.args.epochs = 1
mnist_stdp.args.verbose = True

# Start training
mnist_stdp.main()
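In the Diehl & Cook setup, if the excitatory layer fires fewer than five spikes while an image is shown, the input intensity is raised and the same image is presented again. A minimal sketch of that retry loop follows, assuming a hypothetical `run_example(intensity)` callable that presents the current image and returns the excitatory spike count; it stands in for the Brian2 simulation and is not a function from the script.

```python
from typing import Callable, Tuple

def present_with_adaptive_intensity(
    run_example: Callable[[float], int],
    start_intensity: float = 2.0,
    step: float = 1.0,
    min_spikes: int = 5,
    max_tries: int = 50,
) -> Tuple[float, int]:
    """Re-present one example with rising input intensity until enough spikes occur.

    Returns the intensity that first produced at least `min_spikes` spikes and
    the spike count observed at that intensity.
    """
    intensity = start_intensity
    for _ in range(max_tries):
        n_spikes = run_example(intensity)
        if n_spikes >= min_spikes:
            return intensity, n_spikes
        # Too little activity: raise the input rates and retry the same image
        intensity += step
    raise RuntimeError(f"fewer than {min_spikes} spikes after {max_tries} tries")
```

In the original Diehl & Cook code, pixel rates are computed as pixel / 8 × intensity, so a starting intensity of 2 gives the 63.75 Hz maximum and each unit step adds 31.875 Hz at a fully white pixel.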

## 4. Testing and Visualization

Next, let's test the trained network on 100 test images (`test_size = 100`) and summarize the results with `quick_analyze`.

In [None]:
# Configure for testing
mnist_stdp.args.train = False
mnist_stdp.args.test = True
mnist_stdp.args.test_size = 100

# Run tests
results = mnist_stdp.main()

# Analyze results
quick_analyze(results)
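Because STDP training is unsupervised, labels only enter at readout: each excitatory neuron is assigned the digit class for which its average response was highest, and an image is classified as the class whose assigned neurons fire most on average. Assuming spike counts have already been collected into an `(n_examples, n_neurons)` array, a NumPy sketch of that readout follows; `assign_labels` and `predict` are illustrative names, not functions from `quick_analyze` or the training script.

```python
import numpy as np

def assign_labels(spike_counts, labels, n_classes=10):
    """Assign each neuron the class with its highest mean spike count (-1 if silent)."""
    spike_counts = np.asarray(spike_counts, dtype=float)
    labels = np.asarray(labels)
    mean_counts = np.zeros((n_classes, spike_counts.shape[1]))
    for c in range(n_classes):
        mask = labels == c
        if mask.any():
            mean_counts[c] = spike_counts[mask].mean(axis=0)
    assignments = mean_counts.argmax(axis=0)
    # Neurons that never fired for any class get no label
    assignments[mean_counts.max(axis=0) == 0] = -1
    return assignments

def predict(spike_counts, assignments, n_classes=10):
    """Predict each example's class from the mean count of the neurons assigned to it."""
    spike_counts = np.asarray(spike_counts, dtype=float)
    class_scores = np.zeros((spike_counts.shape[0], n_classes))
    for c in range(n_classes):
        members = assignments == c
        if members.any():
            class_scores[:, c] = spike_counts[:, members].mean(axis=1)
    return class_scores.argmax(axis=1)
```

Accuracy on a held-out set is then the fraction of examples where `predict(...)` matches the true label.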

## 5. Weight Analysis

Finally, let's plot the input-to-excitatory weights (`XeAe`) to see which pixel patterns individual excitatory neurons have become selective to.

In [None]:
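# Load the input-to-excitatory (XeAe) weights, stored as (pre, post, weight) triples.
# NOTE: the 'random/' folder name suggests these are the initial random weights;
# to see learned patterns, point this at the XeAe.npy file written by the training run.
weights = np.load('../weights/random/XeAe.npy', allow_pickle=True)
# Dense matrix: 784 input pixels (28 x 28) x 400 excitatory neurons
w_matrix = np.zeros((784, 400))
for i, j, w in weights:
    w_matrix[int(i), int(j)] = w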

# Plot the receptive fields (784 input weights shown as 28 x 28) of the first 10 excitatory neurons
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(w_matrix[:, i].reshape(28, 28), cmap='hot_r')
    ax.axis('off')
plt.suptitle('Learned Weight Patterns')
plt.tight_layout()
plt.show()
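Ten neurons are only a sample. To view all 400 receptive fields at once, the columns of `w_matrix` can be tiled into a 20 × 20 grid of 28 × 28 patches. A sketch, assuming the 784 × 400 layout built above (the helper name `weights_to_grid` is illustrative):

```python
import numpy as np

def weights_to_grid(w_matrix, img_side=28, grid_side=20):
    """Tile each neuron's input weights (one column) into a single 2-D image.

    Neuron k lands in grid row k // grid_side and grid column k % grid_side.
    """
    n_input, n_neurons = w_matrix.shape
    if n_input != img_side ** 2 or n_neurons != grid_side ** 2:
        raise ValueError("w_matrix must be (img_side**2, grid_side**2)")
    # (n_neurons, n_input) -> (grid_row, grid_col, y, x)
    patches = np.asarray(w_matrix).T.reshape(grid_side, grid_side, img_side, img_side)
    # Interleave to (grid_row, y, grid_col, x), then flatten into one image
    return patches.transpose(0, 2, 1, 3).reshape(grid_side * img_side,
                                                 grid_side * img_side)
```

In the notebook, `plt.imshow(weights_to_grid(w_matrix), cmap='hot_r')` would then show every neuron's pattern in a single 560 × 560 image.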