# Image Encoding and Network Processing Example

This notebook demonstrates how to encode images and process them through a spiking neural network.

## Overview
1. Create or load an image
2. Encode it using retinal encoding (ON/OFF channels)
3. Convert to spike trains
4. Build a visual processing network
5. Process the image through the network
6. Visualize results

In [None]:
# Setup and imports
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any

# Add parent directory to path
sys.path.insert(0, os.path.abspath('../..'))

from core.encoding import RetinalEncoder, RateEncoder
from core.network import NeuromorphicNetwork, NetworkBuilder

# Configure matplotlib for inline display
%matplotlib inline
plt.rcParams['figure.dpi'] = 100

## Step 1: Create a Sample Image

We'll create a simple synthetic image with basic patterns for testing.

In [None]:
def create_sample_image(size=(64, 64)):
    """Create a synthetic grayscale test image: a centred cross with a disc on top.

    The cross consists of a vertical and a horizontal band, each 4 pixels
    wide (intensity 255), through the image centre. A filled disc of
    intensity 200 and radius ``min(size) // 4`` is then drawn over the centre.

    Args:
        size: ``(height, width)`` of the image in pixels.

    Returns:
        ``np.ndarray`` of shape ``(height, width)`` and dtype ``uint8``.
    """
    height, width = size
    image = np.zeros((height, width), dtype=np.uint8)
    cy, cx = height // 2, width // 2

    # Clamp band starts at 0: for dimensions < 4, ``c - 2`` is negative and a
    # negative slice start would wrap around to the opposite edge.
    image[:, max(cx - 2, 0):cx + 2] = 255   # vertical band
    image[max(cy - 2, 0):cy + 2, :] = 255   # horizontal band

    # Filled disc centred on the cross (overwrites the bands inside it)
    radius = min(height, width) // 4
    yy, xx = np.ogrid[:height, :width]
    mask = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2
    image[mask] = 200

    return image

# Build the 64x64 test image and display it with a colour scale
image = create_sample_image(size=(64, 64))

fig, ax = plt.subplots(figsize=(6, 6))
mappable = ax.imshow(image, cmap='gray')
ax.set_title('Sample Input Image')
fig.colorbar(mappable, ax=ax)
ax.axis('off')
plt.show()

lowest, highest = image.min(), image.max()
print(f"Image shape: {image.shape}")
print(f"Value range: [{lowest}, {highest}]")

## Step 2: Retinal Encoding

Apply retinal encoding to extract ON-center/OFF-surround and OFF-center/ON-surround responses.

In [None]:
# Encode the sample image with a retinal encoder configured for 32x32 resolution
encoder = RetinalEncoder(resolution=(32, 32))
encoded = encoder.encode(image)

# Show the input next to the ON and OFF channel responses;
# only the two encoded channels get a colour bar.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image, 'gray', 'Original Image', False),
    (encoded['on_center'], 'hot', 'ON-Center Response', True),
    (encoded['off_center'], 'hot', 'OFF-Center Response', True),
]
for ax, (panel_data, cmap, title, with_colorbar) in zip(axes, panels):
    mappable = ax.imshow(panel_data, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')
    if with_colorbar:
        plt.colorbar(mappable, ax=ax, fraction=0.046)

plt.suptitle('Retinal Encoding', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

for channel_key, channel_label in (('on_center', 'ON-center'), ('off_center', 'OFF-center')):
    print(f"{channel_label} shape: {encoded[channel_key].shape}")

## Step 3: Convert to Spike Trains

Convert the encoded image intensities to spike trains using rate encoding.

In [None]:
# Initialize rate encoder
rate_encoder = RateEncoder(max_rate=100.0)  # 100 Hz max firing rate

# Normalize both channels to the 0-1 range used for rate coding.
# A fixed /255 only gives 0-1 if the encoder emits 0-255 values, and negative
# center-surround responses would become negative rates. Instead, clip at 0 and
# divide both channels by their shared peak. A shared divisor keeps the relative
# ON/OFF strength, as the common 255 divisor did, and is identical when the peak is 255.
on_response = np.clip(np.asarray(encoded['on_center'], dtype=float), 0.0, None)
off_response = np.clip(np.asarray(encoded['off_center'], dtype=float), 0.0, None)
peak = max(on_response.max(initial=0.0), off_response.max(initial=0.0))
scale = peak if peak > 0 else 1.0  # all-zero input: avoid dividing by zero
on_normalized = on_response / scale
off_normalized = off_response / scale
assert on_normalized.max(initial=0.0) <= 1.0 and off_normalized.max(initial=0.0) <= 1.0

# Generate spike trains for 100ms simulation
duration = 100.0  # ms
on_spikes = rate_encoder.encode_array(on_normalized, duration=duration)
off_spikes = rate_encoder.encode_array(off_normalized, duration=duration)

print(f"Generated {len(on_spikes)} ON-channel spikes")
print(f"Generated {len(off_spikes)} OFF-channel spikes")

# Raster of the first spikes of each channel.
# Assumes each spike is a (neuron_id, time_ms) pair - TODO confirm against RateEncoder.encode_array
MAX_SPIKES_SHOWN = 500
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, channel, spikes in zip(axes, ('ON', 'OFF'), (on_spikes, off_spikes)):
    shown = spikes[:MAX_SPIKES_SHOWN]
    ax.scatter([spike[1] for spike in shown], [spike[0] for spike in shown],
               s=1, alpha=0.5)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Neuron ID')
    ax.set_title(f'{channel} Channel Spikes (first {MAX_SPIKES_SHOWN}/{len(spikes)})')
    ax.set_xlim(0, duration)

plt.tight_layout()
plt.show()

## Step 4: Build Visual Processing Network

Create a hierarchical network inspired by the visual cortex.

In [None]:
# Assemble the hierarchy: retinal (sensory) -> V1 -> V2 -> output (motor)
builder = NetworkBuilder()

builder.add_sensory_layer("retinal", size=1024)  # 32x32 input
builder.add_processing_layer("V1", size=256, neuron_type="lif")
builder.add_processing_layer("V2", size=64, neuron_type="adex")
builder.add_motor_layer("output", size=10)

# (source, target, connection_type, synapse_type, connection_probability),
# applied in this order
connection_specs = [
    ("retinal", "V1", "feedforward", "stdp", 0.1),
    ("V1", "V2", "feedforward", "stdp", 0.2),
    ("V2", "output", "feedforward", "stdp", 0.3),
    # Lateral V1 -> V1 connections, intended as lateral inhibition.
    # NOTE(review): nothing in this call marks the synapses as inhibitory - confirm
    ("V1", "V1", "lateral", "stp", 0.05),
]
for source, target, conn_type, syn_type, probability in connection_specs:
    builder.connect_layers(source, target,
                           connection_type=conn_type,
                           synapse_type=syn_type,
                           connection_probability=probability)

network = builder.build()

# Summarize the architecture that was actually built
info = network.get_network_info()
print("Network Architecture:")
print("=" * 40)
print(f"Total neurons: {info['total_neurons']}")
print(f"Total synapses: {info['total_synapses']}")
print("\nLayers:")
for layer_key, layer_desc in info['layers'].items():
    print(f"  {layer_key}: {layer_desc['size']} neurons ({layer_desc['neuron_type']})")
print("\nConnections:")
for conn_key, conn_desc in info['connections'].items():
    print(f"  {conn_key}: {conn_desc['synapse_type']} (p={conn_desc['connection_probability']})")

## Step 5: Run Network Simulation

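Simulate the network for the same duration as the spike trains. **Note:** as written, the ON/OFF spike trains from Step 3 are not passed to `run_simulation`, so the encoded image is not yet wired into the network's input.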

In [None]:
# Run the simulation for the same duration as the spike trains.
# NOTE(review): on_spikes / off_spikes are not passed to run_simulation here,
# so the encoded image does not visibly drive the network - TODO wire them in.
print("Running network simulation...")
results = network.run_simulation(duration=duration, dt=0.1)

print(f"\nSimulation completed: {results['final_time']} ms")

# Per-layer activity: total spike count and mean firing rate in Hz
duration_s = duration / 1000.0  # `duration` is in milliseconds
spike_counts = {}
for layer_name, spike_times in results['layer_spike_times'].items():
    n_neurons = len(spike_times)
    total_spikes = sum(map(len, spike_times))
    spike_counts[layer_name] = total_spikes
    if n_neurons > 0:
        avg_rate = total_spikes / n_neurons / duration_s
    else:
        avg_rate = 0
    print(f"{layer_name}: {total_spikes} spikes, avg rate: {avg_rate:.1f} Hz")

## Step 6: Visualize Network Activity

Create raster plots and analyze the network's response to the input image.

In [None]:
# Summary figure: images (row 0), layer rasters (row 1),
# spike counts and weight distributions (row 2)
fig = plt.figure(figsize=(15, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Row 0: original image and the two retinal channels
image_panels = [
    (image, 'gray', 'Original Image'),
    (encoded['on_center'], 'hot', 'ON-Center'),
    (encoded['off_center'], 'hot', 'OFF-Center'),
]
for col, (panel_data, cmap, title) in enumerate(image_panels):
    ax = fig.add_subplot(gs[0, col])
    ax.imshow(panel_data, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

# Row 1: spike rasters for the first neurons of each layer
MAX_RASTER_NEURONS = 50
layer_names = ['retinal', 'V1', 'V2']
for idx, layer_name in enumerate(layer_names):
    ax = fig.add_subplot(gs[1, idx])
    if layer_name not in results['layer_spike_times']:
        continue

    spike_times = results['layer_spike_times'][layer_name]
    n_neurons_to_plot = min(MAX_RASTER_NEURONS, len(spike_times))
    for neuron_id in range(n_neurons_to_plot):
        times = spike_times[neuron_id]
        # len() instead of truthiness: `if times:` raises ValueError when
        # `times` is a NumPy array with more than one element.
        if len(times) > 0:
            ax.scatter(times, [neuron_id] * len(times), s=0.5, c='black')

    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Neuron ID')
    ax.set_title(f'{layer_name} Layer')
    ax.set_xlim(0, duration)
    # Half-row padding keeps neuron 0 off the frame edge and avoids a
    # zero-height y-range when a layer has no neurons.
    ax.set_ylim(-0.5, max(n_neurons_to_plot, 1) - 0.5)

# Row 2, left: total spike count per layer (from the simulation cell)
ax4 = fig.add_subplot(gs[2, 0])
layers = list(spike_counts.keys())
counts = list(spike_counts.values())
colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(layers)))
bars = ax4.bar(layers, counts, color=colors)
ax4.set_ylabel('Total Spikes')
ax4.set_title('Spike Counts by Layer')
ax4.tick_params(axis='x', rotation=45)

# Value labels on top of each bar
for bar, count in zip(bars, counts):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(count)}', ha='center', va='bottom')

# Row 2, right: weight distributions, when the simulation returned any
ax5 = fig.add_subplot(gs[2, 1:])
weight_data = []
labels = []
for conn_name, weight_matrix in results.get('weight_matrices', {}).items():
    if weight_matrix is not None:
        weight_data.append(weight_matrix.flatten())
        labels.append(conn_name)

if weight_data:
    ax5.boxplot(weight_data)
    # Set tick labels explicitly rather than boxplot(labels=...): that keyword
    # was renamed to tick_labels in Matplotlib 3.9 and is deprecated.
    ax5.set_xticks(range(1, len(labels) + 1))
    ax5.set_xticklabels(labels)
    ax5.set_ylabel('Weight Value')
    ax5.set_title('Weight Distributions')
    ax5.tick_params(axis='x', rotation=45)
    ax5.grid(True, alpha=0.3)
else:
    ax5.text(0.5, 0.5, 'No weight matrices available',
             ha='center', va='center', transform=ax5.transAxes)
    ax5.axis('off')

plt.suptitle('Image Processing Through Spiking Neural Network', fontsize=16, fontweight='bold')
plt.show()

## Analysis: Receptive Fields

Summarize each weight matrix (shape, mean, spread, range and sparsity) as a first look at what the network has learned.

In [None]:
# Print summary statistics for each weight matrix returned by the simulation;
# connections whose matrix is None are skipped.
print("Weight Matrix Analysis:")
print("=" * 40)

weight_matrices = results.get('weight_matrices', {})
if not weight_matrices:
    print("\nNo weight matrices were returned by the simulation.")

for conn_name, weight_matrix in weight_matrices.items():
    if weight_matrix is None:
        continue
    print(f"\nConnection: {conn_name}")
    print(f"  Shape: {weight_matrix.shape}")
    # np.min / np.max raise ValueError on zero-size arrays, so report and skip
    if weight_matrix.size == 0:
        print("  (no weights - statistics skipped)")
        continue
    print(f"  Mean weight: {np.mean(weight_matrix):.4f}")
    print(f"  Std deviation: {np.std(weight_matrix):.4f}")
    print(f"  Min weight: {np.min(weight_matrix):.4f}")
    print(f"  Max weight: {np.max(weight_matrix):.4f}")
    # Fraction of entries that are exactly zero
    print(f"  Sparsity: {np.mean(weight_matrix == 0):.2%}")

## Interactive Experiment: Different Image Patterns

Try different input patterns to see how the retinal encoder responds. These patterns are only encoded and plotted; they are not run through the network.

In [None]:
def _make_test_image(pattern_type, size, seed):
    """Build a ``size x size`` uint8 test image for the named pattern."""
    image = np.zeros((size, size), dtype=np.uint8)
    mid = size // 2

    if pattern_type == 'vertical':
        # 4-pixel band through the centre (columns 30:34 for size=64)
        image[:, max(mid - 2, 0):mid + 2] = 255
    elif pattern_type == 'horizontal':
        image[max(mid - 2, 0):mid + 2, :] = 255
    elif pattern_type == 'diagonal':
        idx = np.arange(size)
        image[idx, idx] = 255                   # main diagonal
        image[idx[1:] - 1, idx[1:]] = 200       # one row above the diagonal
        image[idx[:-1] + 1, idx[:-1]] = 200     # one row below the diagonal
    elif pattern_type == 'random':
        if seed is None:
            # Unseeded: draws from NumPy's global RNG
            image = np.random.randint(0, 256, (size, size), dtype=np.uint8)
        else:
            rng = np.random.default_rng(seed)
            image = rng.integers(0, 256, (size, size), dtype=np.uint8)
    else:
        raise ValueError(
            f"Unknown pattern_type {pattern_type!r}; expected one of "
            "'vertical', 'horizontal', 'diagonal', 'random'"
        )
    return image


def test_pattern(pattern_type='vertical', size=64, seed=None):
    """Encode a synthetic test pattern and plot the ON/OFF retinal responses.

    The pattern is only passed through the notebook-level ``encoder``; it is
    not run through the network.

    Args:
        pattern_type: One of ``'vertical'``, ``'horizontal'``, ``'diagonal'``
            or ``'random'``.
        size: Side length in pixels of the square test image.
        seed: Optional seed for the ``'random'`` pattern. ``None`` draws from
            NumPy's global RNG.

    Returns:
        The result of ``encoder.encode`` for the test image. The plots read
        its ``'on_center'`` and ``'off_center'`` entries.

    Raises:
        ValueError: If ``pattern_type`` is not a supported pattern name.
    """
    test_image = _make_test_image(pattern_type, size, seed)

    # Encode only (no network simulation)
    encoded_test = encoder.encode(test_image)

    # Quick visualization
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(test_image, cmap='gray')
    axes[0].set_title(f'{pattern_type.capitalize()} Pattern')
    axes[0].axis('off')

    axes[1].imshow(encoded_test['on_center'], cmap='hot')
    axes[1].set_title('ON Response')
    axes[1].axis('off')

    axes[2].imshow(encoded_test['off_center'], cmap='hot')
    axes[2].set_title('OFF Response')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    return encoded_test

# Encode and plot each test pattern in turn
patterns = ['vertical', 'horizontal', 'diagonal', 'random']
for pattern_name in patterns:
    print(f"\nTesting {pattern_name} pattern:")
    encoded_test = test_pattern(pattern_type=pattern_name)

## Conclusions

This notebook demonstrated:

1. **Retinal Encoding**: How images are transformed into ON/OFF center-surround responses
2. **Rate Coding**: Converting pixel intensities into spike trains
3. **Hierarchical Processing**: Building networks inspired by the visual cortex
4. **STDP Learning**: How synaptic weights adapt based on spike timing
5. **Network Dynamics**: Visualizing how information flows through the network

### Next Steps

- Try loading real images using PIL or OpenCV
- Experiment with different network architectures
- Implement feature detectors (edge, corner, etc.)
- Add feedback connections for top-down processing
- Test with video sequences for temporal processing