# Tutorial 1: MNIST Digit Classification with VSA

This tutorial shows how to use VSAX to classify images of handwritten digits, using a small MNIST-style dataset.

## What You'll Learn

- How to encode images as hypervectors
- How to create class prototypes using VSA
- How to perform similarity-based classification
- How to compare different VSA models (FHRR, MAP, Binary)

## Why VSA for Classification?

Vector Symbolic Architectures (VSAs) bring several useful properties to classification:
- **Interpretable**: Class representations are explicit hypervectors
- **Few-shot learning**: Can learn from few examples per class
- **Compositional**: Can combine features naturally
- **Efficient**: GPU-accelerated with JAX

## Setup

First, let's import the necessary libraries.

```python
import sys

# Make the repository root importable (adjust if your layout differs)
sys.path.insert(0, '../..')

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from vsax import VSAMemory, create_binary_model, create_fhrr_model, create_map_model

# Import our dataset helpers
sys.path.insert(0, '../utils')
from dataset_helpers import load_mnist_digits, normalize_images

print("✓ Imports successful")
```

## Load and Explore MNIST Data

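We'll use scikit-learn's digits dataset: 1,797 grayscale 8x8 images of handwritten digits with pixel intensities from 0 to 16. It is much smaller than the original MNIST (28x28 images), which keeps this tutorial fast.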

```python
# Load MNIST digits
X_train, X_test, y_train, y_test = load_mnist_digits()

# Normalize to [0, 1]
X_train = normalize_images(X_train)
X_test = normalize_images(X_test)

print(f"Training samples: {len(X_train)}")
print(f"Test samples: {len(X_test)}")
print(f"Image dimensions: {X_train.shape[1]} pixels (8x8 flattened)")
print(f"Classes: {np.unique(y_train)}")
```

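The `dataset_helpers` module is not shown in this tutorial. To give a rough idea of what a helper like `normalize_images` might do, here is a hypothetical sketch (`normalize_images_sketch` is an illustrative name, not the real helper) that rescales intensities into [0, 1] by dividing by the largest value. The actual helper may scale differently, for example with a fixed maximum of 16.

```python
import numpy as np

def normalize_images_sketch(images):
    """Hypothetical stand-in for normalize_images: scale pixel values into [0, 1]."""
    images = np.asarray(images, dtype=np.float32)
    max_val = images.max()
    if max_val == 0:
        # Avoid dividing by zero on an all-blank input
        return images
    return images / max_val

# Two flattened toy "images" with integer intensities in 0..16
raw = np.array([[0, 4, 8, 16], [2, 2, 0, 0]])
scaled = normalize_images_sketch(raw)
print(scaled.min(), scaled.max())  # 0.0 1.0
```
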
```python
# Visualize some examples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flat):
    ax.imshow(X_train[i].reshape(8, 8), cmap='gray')
    ax.set_title(f"Digit: {y_train[i]}")
    ax.axis('off')
plt.tight_layout()
plt.show()
```

## VSA-Based Classification

### Step 1: Create VSA Model

Let's start with the FHRR (Fourier Holographic Reduced Representation) model, which uses complex-valued hypervectors whose binding can be undone exactly.

```python
# Create FHRR model with 1024 dimensions
model = create_fhrr_model(dim=1024)
memory = VSAMemory(model)

print(f"Model: {model.rep_cls.__name__}")
print(f"Operations: {model.opset.__class__.__name__}")
print(f"Dimension: {model.dim}")
```

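To make "exact unbinding" concrete, here is a small NumPy sketch of FHRR-style hypervectors that is independent of VSAX's internals: each component is a unit-magnitude complex number exp(iθ) with a random phase, binding is element-wise multiplication, and unbinding multiplies by the complex conjugate. Because every component has magnitude 1, unbinding recovers the original vector up to floating-point error. The helper names are illustrative, not VSAX functions.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024

def random_phasor(dim, rng):
    """Random FHRR-style hypervector: unit-magnitude complex numbers with uniform phases."""
    return np.exp(1j * rng.uniform(-np.pi, np.pi, size=dim))

def bind(a, b):
    # Element-wise multiplication adds the phases of the two vectors
    return a * b

def unbind(c, b):
    # Multiplying by the conjugate subtracts b's phases again
    return c * np.conj(b)

role, filler = random_phasor(dim, rng), random_phasor(dim, rng)
bound = bind(role, filler)
recovered = unbind(bound, role)
print(np.allclose(recovered, filler))  # True
```
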
### Step 2: Encode Images as Hypervectors

Each image is encoded by:
1. Creating a random basis hypervector for each pixel position
2. Scaling each basis vector by the pixel intensity
3. Bundling all scaled pixel vectors together
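
In linear-algebra terms, steps 1 to 3 form a weighted sum of basis vectors, which is a matrix-vector product. The NumPy sketch below assumes random bipolar (±1) basis vectors as a stand-in for the model's pixel vectors and plain summation as the bundle; VSAX's `bundle` may also normalize the result, so this illustrates the idea rather than the library's exact output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, dim = 64, 1024

# Step 1: one random basis hypervector per pixel position (bipolar stand-in)
basis = rng.choice([-1.0, 1.0], size=(n_pixels, dim))

# A toy 8x8 "image" with intensities in [0, 1], flattened to 64 values
image = rng.uniform(0.0, 1.0, size=n_pixels)

# Steps 2 and 3 as an explicit loop: scale each basis vector, then sum
loop_encoding = np.zeros(dim)
for i, intensity in enumerate(image):
    loop_encoding = loop_encoding + intensity * basis[i]

# The same weighted bundle as a single matrix-vector product
matmul_encoding = image @ basis

print(np.allclose(loop_encoding, matmul_encoding))  # True
```
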

```python
# Create basis vectors for each of the 64 pixel positions
pixel_names = [f"pixel_{i}" for i in range(64)]
memory.add_many(pixel_names)

print(f"Created {len(pixel_names)} pixel basis vectors")
```

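Random basis vectors can act as distinct "pixel addresses" because, in high dimensions, independently drawn random vectors are nearly orthogonal: their cosine similarity concentrates around 0 with a spread on the order of 1/√dim. A quick NumPy check, using bipolar vectors as an assumed stand-in for however VSAX samples its basis:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n = 1024, 64

# 64 random bipolar hypervectors, one per pixel position
vecs = rng.choice([-1.0, 1.0], size=(n, dim))

# Pairwise cosine similarities between all basis vectors
unit = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
sims = unit @ unit.T
off_diag = sims[~np.eye(n, dtype=bool)]

print(f"max |cos| between distinct vectors: {np.abs(off_diag).max():.3f}")
print(f"1/sqrt(dim) = {1 / np.sqrt(dim):.3f}")
```
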
```python
def encode_image(image, model, memory):
    """Encode an image as a hypervector.

    Args:
        image: Flattened image array (64,)
        model: VSAModel instance
        memory: VSAMemory with pixel basis vectors

    Returns:
        Encoded hypervector
    """
    # Get all pixel basis vectors
    pixel_vecs = [memory[f"pixel_{i}"].vec for i in range(64)]

    # Weight each pixel's basis vector by its intensity
    scaled_vecs = []
    for i, intensity in enumerate(image):
        if intensity > 0:  # Only include active pixels
            scaled = pixel_vecs[i] * intensity
            scaled_vecs.append(scaled)

    if len(scaled_vecs) == 0:
        # Return zero vector for blank images
        return jnp.zeros(model.dim, dtype=pixel_vecs[0].dtype)

    # Bundle all scaled pixel vectors
    return model.opset.bundle(*scaled_vecs)


# Test encoding
test_encoding = encode_image(X_train[0], model, memory)
print(f"Encoded shape: {test_encoding.shape}")
print(f"Encoded dtype: {test_encoding.dtype}")
```

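A useful sanity check is that this encoding preserves similarity: images that share bright pixels get encodings pointing in similar directions, while images with disjoint bright pixels get nearly orthogonal encodings. The sketch below uses NumPy stand-ins (random bipolar basis vectors and a plain weighted sum in place of `encode_image`), so its numbers will not match VSAX exactly.

```python
import numpy as np

rng = np.random.default_rng(1)
dim = 1024
basis = rng.choice([-1.0, 1.0], size=(64, dim))

def encode(image):
    """Weighted bundle of pixel basis vectors (stand-in for encode_image)."""
    return image @ basis

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

top = np.zeros(64)
top[:32] = 1.0     # bright pixels in the top four rows of the 8x8 image
bottom = np.zeros(64)
bottom[32:] = 1.0  # bright pixels only in the bottom four rows
noisy_top = np.clip(top + rng.normal(0.0, 0.1, size=64), 0.0, 1.0)

sim_same = cosine(encode(top), encode(noisy_top))
sim_diff = cosine(encode(top), encode(bottom))
print(f"top vs noisy top: {sim_same:.2f}")
print(f"top vs bottom:    {sim_diff:.2f}")
```
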
### Step 3: Create Class Prototypes

For each digit class (0-9), we create a prototype by bundling the encodings of all training examples of that class. The resulting prototype is similar to each of those encodings, much like an average.

```python
# Encode all training images
print("Encoding training images...")
train_encodings = []
for img in X_train:
    train_encodings.append(encode_image(img, model, memory))
train_encodings = jnp.stack(train_encodings)

print(f"Encoded {len(train_encodings)} training images")
```

```python
# Create prototype for each digit class
prototypes = {}
for digit in range(10):
    # Get all encodings for this digit
    digit_mask = y_train == digit
    digit_encodings = train_encodings[digit_mask]

    # Bundle (superpose) the class's encodings into a single prototype
    prototype = model.opset.bundle(*digit_encodings)
    prototypes[digit] = prototype

    print(f"Digit {digit}: {digit_mask.sum()} training examples")

print(f"\nCreated {len(prototypes)} class prototypes")
```

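The per-class loop can also be written as one matrix product with a one-hot label matrix. This NumPy sketch assumes bundling means summing and then dividing by the class count (a mean); VSAX's `bundle` may normalize differently, in which case only the final scaling changes. The encodings and labels are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, dim, n_classes = 200, 1024, 10

encodings = rng.normal(size=(n_train, dim))        # placeholder encodings
labels = rng.integers(0, n_classes, size=n_train)  # placeholder labels

# One-hot matrix: row i has a 1 in column labels[i]
one_hot = np.eye(n_classes)[labels]  # shape (n_train, n_classes)
counts = one_hot.sum(axis=0)         # examples per class

# Sum each class's encodings, then divide by its count to get a mean prototype
prototypes = (one_hot.T @ encodings) / counts[:, None]

# Same result as an explicit per-class loop
loop_protos = np.stack([encodings[labels == c].mean(axis=0) for c in range(n_classes)])
print(np.allclose(prototypes, loop_protos))  # True
```
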
### Step 4: Classify Test Images

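To classify a test image, we encode it and choose the class whose prototype has the highest cosine similarity to the encoding. Cosine similarity divides the dot product by both vector norms, so a prototype cannot win just by being longer; for complex FHRR vectors we take the real part of the conjugated dot product, so vectors with opposite phases count as dissimilar.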

```python
def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity for real or complex hypervectors."""
    # Promote bool/int vectors to float; complex vectors stay complex
    a = jnp.asarray(a, dtype=jnp.result_type(a, jnp.float32))
    b = jnp.asarray(b, dtype=jnp.result_type(b, jnp.float32))
    # jnp.vdot conjugates its first argument; keep the real part
    dot = jnp.real(jnp.vdot(a, b))
    # eps avoids division by zero for blank (all-zero) encodings
    return dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + eps)


def classify_image(image, model, memory, prototypes):
    """Classify an image using prototype matching.

    Args:
        image: Flattened image array
        model: VSAModel instance
        memory: VSAMemory with pixel basis vectors
        prototypes: Dict mapping digit -> prototype vector

    Returns:
        Predicted digit (0-9)
    """
    # Encode the test image
    encoding = encode_image(image, model, memory)

    # Cosine similarity to each class prototype
    similarities = {
        digit: float(cosine_similarity(encoding, prototype))
        for digit, prototype in prototypes.items()
    }

    # Return digit with highest similarity
    return max(similarities, key=similarities.get)


# Test on a few examples
print("Testing classification on first 5 test images:")
for i in range(5):
    pred = classify_image(X_test[i], model, memory, prototypes)
    print(f"  Image {i}: True={y_test[i]}, Predicted={pred}, {'✓' if pred == y_test[i] else '✗'}")
```

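To see why the similarity needs normalizing, consider two prototypes where one simply has a larger norm. With a raw dot product the larger prototype wins even when the query points in a different direction; cosine similarity compares directions only. The absolute value of a complex dot product has a second problem: it ignores phase, so a vector and its negation look maximally similar. A small NumPy illustration with made-up vectors:

```python
import numpy as np

def cosine(a, b):
    # np.vdot conjugates its first argument; the real part is the inner product
    # of the vectors viewed as real vectors of twice the length
    return float(np.real(np.vdot(a, b)) / (np.linalg.norm(a) * np.linalg.norm(b)))

query = np.array([1.0, 0.0, 0.0])
prototypes = {
    "A": np.array([0.9, 0.1, 0.0]),  # nearly the same direction, small norm
    "B": np.array([5.0, 5.0, 5.0]),  # different direction, large norm
}

by_dot = max(prototypes, key=lambda k: abs(np.vdot(query, prototypes[k])))
by_cos = max(prototypes, key=lambda k: cosine(query, prototypes[k]))
print(by_dot, by_cos)  # B A

# A complex vector and its negation: |vdot| is as large as possible,
# while cosine similarity reports -1
z = np.exp(1j * np.array([0.3, 1.2, 2.0]))
print(abs(np.vdot(z, -z)), cosine(z, -z))
```
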
```python
# Classify all test images
print("Classifying all test images...")
predictions = []
for img in X_test:
    pred = classify_image(img, model, memory, prototypes)
    predictions.append(pred)

predictions = np.array(predictions)
accuracy = accuracy_score(y_test, predictions)

print(f"\nTest Accuracy: {accuracy:.2%}")
```

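Classifying one image at a time keeps the code readable, but once encodings are stacked into arrays the whole test set can be scored at once: normalize every row, take one matrix product, and pick the arg-max per row. A NumPy sketch with placeholder data; `classify_batch` is a hypothetical helper, not part of VSAX.

```python
import numpy as np

def classify_batch(test_encodings, prototype_matrix, eps=1e-8):
    """Index of the most cosine-similar prototype for each row.

    test_encodings: (n_test, dim); prototype_matrix: (n_classes, dim).
    Works for real or complex arrays.
    """
    test_unit = test_encodings / (np.linalg.norm(test_encodings, axis=1, keepdims=True) + eps)
    proto_unit = prototype_matrix / (np.linalg.norm(prototype_matrix, axis=1, keepdims=True) + eps)
    # Conjugate the queries so complex inputs give Re(vdot) per pair
    sims = np.real(test_unit.conj() @ proto_unit.T)  # (n_test, n_classes)
    return sims.argmax(axis=1)

rng = np.random.default_rng(0)
protos = rng.normal(size=(10, 256))
# Queries: noisy copies of prototypes 3, 7 and 0
queries = protos[[3, 7, 0]] + 0.1 * rng.normal(size=(3, 256))
print(classify_batch(queries, protos))  # [3 7 0]
```

In the notebook, the prototype matrix could be built with `jnp.stack([prototypes[d] for d in range(10)])`; because its rows are ordered 0 to 9, the returned index is the predicted digit.
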
### Step 5: Evaluate Performance

```python
# Classification report
print("Classification Report:")
print(classification_report(y_test, predictions))
```

```python
# Confusion matrix
cm = confusion_matrix(y_test, predictions)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'FHRR Model - Confusion Matrix (Accuracy: {accuracy:.2%})')
plt.show()
```

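In scikit-learn's confusion matrix each row holds the examples of one true class, so dividing the diagonal by the row sums gives per-class recall (the fraction of each digit classified correctly), which makes the weakest classes easy to spot. A NumPy sketch on a made-up 3-class matrix; in the notebook you would pass `cm` instead.

```python
import numpy as np

def per_class_recall(cm):
    """Diagonal of a confusion matrix divided by its row sums (true-class counts)."""
    cm = np.asarray(cm, dtype=float)
    return np.diag(cm) / cm.sum(axis=1)

cm_example = np.array([
    [8, 1, 1],
    [0, 9, 1],
    [2, 3, 5],
])
recall = per_class_recall(cm_example)
print(recall)  # [0.8 0.9 0.5]
print("Hardest class:", int(recall.argmin()))  # Hardest class: 2
```
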
## Compare Different VSA Models

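Let's run the same pipeline with FHRR, MAP (Multiply-Add-Permute), and Binary models. Note that the Binary model below uses a larger dimension (10,000) than the other two (1,024), so the models are not compared at equal dimension.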

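For orientation, here is a sketch of the textbook MAP operations on bipolar vectors: binding is element-wise multiplication (its own inverse, since each component squared is 1), bundling is addition followed by taking the sign, and permutation is a cyclic shift. This is the general scheme, assumed for illustration; it does not describe how VSAX's `create_map_model` implements it.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024

def random_bipolar(dim, rng):
    return rng.choice([-1.0, 1.0], size=dim)

def bind(a, b):
    # Element-wise product; binding again with the same vector undoes it
    return a * b

def bundle(*vecs):
    # Sum, then take the sign (an odd number of inputs avoids ties)
    return np.sign(np.sum(vecs, axis=0))

def permute(a, shift=1):
    # Cyclic shift, commonly used to encode order or position
    return np.roll(a, shift)

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

x, y, z = (random_bipolar(dim, rng) for _ in range(3))
print(np.array_equal(bind(bind(x, y), y), x))  # True
s = bundle(x, y, z)
# A bundle of three inputs agrees with each input on about 3/4 of components
print(round(cosine(s, x), 2))
# A shifted vector is nearly orthogonal to the original
print(round(cosine(permute(x), x), 2))
```
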
```python
def evaluate_model(model_name, model_fn, dim):
    """Evaluate a VSA model on MNIST classification.

    Args:
        model_name: Name of the model
        model_fn: Factory function to create model
        dim: Dimension for the model

    Returns:
        Test accuracy
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {model_name} Model (dim={dim})")
    print('='*60)

    # Create model and memory
    model = model_fn(dim=dim)
    memory = VSAMemory(model)
    memory.add_many([f"pixel_{i}" for i in range(64)])

    # Encode training images and create prototypes
    print("Encoding training images...")
    train_encodings = [encode_image(img, model, memory) for img in X_train]
    train_encodings = jnp.stack(train_encodings)

    prototypes = {}
    for digit in range(10):
        digit_mask = y_train == digit
        digit_encodings = train_encodings[digit_mask]
        prototypes[digit] = model.opset.bundle(*digit_encodings)

    # Classify test images
    print("Classifying test images...")
    predictions = [classify_image(img, model, memory, prototypes) for img in X_test]
    predictions = np.array(predictions)

    accuracy = accuracy_score(y_test, predictions)
    print(f"Test Accuracy: {accuracy:.2%}")

    return accuracy
```

```python
# Compare models
results = {}

results['FHRR'] = evaluate_model('FHRR', create_fhrr_model, dim=1024)
results['MAP'] = evaluate_model('MAP', create_map_model, dim=1024)
results['Binary'] = evaluate_model('Binary', create_binary_model, dim=10000)
```

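The Binary model runs at dim=10000. Binary hypervectors are commonly used at around 10,000 bits, partly because each component holds only one bit. As a sketch of the textbook binary scheme (assumed here, not taken from VSAX's implementation): binding is XOR, bundling is a bitwise majority vote, and similarity is 1 minus twice the normalized Hamming distance, which equals the cosine similarity of the same vectors mapped to ±1.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 10_000

def random_binary(dim, rng):
    return rng.integers(0, 2, size=dim, dtype=np.uint8)

def bind(a, b):
    # XOR binding is its own inverse
    return np.bitwise_xor(a, b)

def bundle(*vecs):
    # Bitwise majority vote; an odd number of inputs avoids ties
    counts = np.sum(vecs, axis=0)
    return (2 * counts > len(vecs)).astype(np.uint8)

def similarity(a, b):
    # 1 for identical vectors, about 0 for unrelated ones
    return 1.0 - 2.0 * np.mean(a != b)

x, y, z = (random_binary(dim, rng) for _ in range(3))
print(np.array_equal(bind(bind(x, y), y), x))  # True
print(round(similarity(bundle(x, y, z), x), 2))
print(round(similarity(x, y), 2))
```
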
```python
# Plot comparison
plt.figure(figsize=(10, 6))
model_names = list(results.keys())
accuracies = [results[m] for m in model_names]

bars = plt.bar(model_names, accuracies, color=['#3498db', '#2ecc71', '#e74c3c'])
plt.ylabel('Test Accuracy', fontsize=12)
plt.title('VSA Model Comparison on MNIST Digits', fontsize=14, fontweight='bold')
plt.ylim([0, 1])

# Add accuracy labels on bars
for bar, acc in zip(bars, accuracies):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{acc:.2%}', ha='center', va='bottom', fontsize=11)

plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

print("\nFinal Results:")
# Use `name` so the FHRR `model` defined earlier is not overwritten with a string
for name, acc in results.items():
    print(f"  {name}: {acc:.2%}")
```

## Key Takeaways

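1. **VSA for Classification**: We classified handwritten digits with a prototype-based VSA classifier, without any gradient-based training
2. **Simple Approach**: The method has three steps: encode images, build class prototypes, and match by similarity
3. **Model Comparison**: The same pipeline runs unchanged with FHRR, MAP, and Binary models; when comparing their accuracies, remember that the Binary model used a larger dimension (10,000 vs. 1,024)
4. **Interpretable**: Each class has an explicit prototype hypervector that represents it
5. **Scalable**: VSAX runs on JAX, which can use GPUs; the per-pixel and per-image Python loops in this tutorial favor readability and would need to be vectorized (for example with `jax.vmap` or matrix products) for larger datasets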

## Next Steps

- Try different encoding strategies (e.g., using ScalarEncoder)
- Experiment with different dimensions
- Use fewer training examples (few-shot learning)
- Try on full MNIST (28x28 images)
- Explore Tutorial 2: Knowledge Graph Reasoning
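
For the few-shot experiment, one option is to keep only k training examples per digit before building prototypes. A minimal NumPy sketch; `few_shot_indices` is an illustrative name, not part of VSAX or the tutorial's helpers.

```python
import numpy as np

def few_shot_indices(labels, k, seed=0):
    """Pick up to k random training indices per class, without replacement."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    chosen = []
    for c in np.unique(labels):
        class_idx = np.flatnonzero(labels == c)
        take = min(k, len(class_idx))
        chosen.append(rng.choice(class_idx, size=take, replace=False))
    return np.sort(np.concatenate(chosen))

labels = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
idx = few_shot_indices(labels, k=2)
print(np.bincount(labels[idx]))  # [2 2 2]
```

In the notebook this would look like `idx = few_shot_indices(y_train, k=5)` followed by `X_train[idx], y_train[idx]`.
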