# Event-Based Eye Tracking Tutorial

This notebook introduces event-based eye tracking and covers how to:
1. Load and visualize event camera data from the **Ini-30** and **3ET** datasets
2. Convert raw events into representations suitable for neural networks
3. Check a model against the Speck hardware constraints
4. Formulate and evaluate the eye tracking problem

**Reference resources** (for learning, not copying!):
- [RETINA repository](https://github.com/pbonazzi/retina) - Reference SNN implementation
- [3ET repository](https://github.com/qinche106/cb-convlstm-eyetracking) - Reference ConvLSTM implementation
- [Sinabs tutorials](https://sinabs.readthedocs.io/v3.1.1/tutorials/tutorials.html) - SNN training
- [Speck tutorials](https://sinabs.readthedocs.io/v3.1.1/speck/tutorials.html) - Hardware deployment

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# You can also set paths directly here
INI30_PATH = os.getenv("INI30_DATA_PATH", "./retina/evs_ini30")
THREET_PATH = os.getenv("THREET_DATA_PATH", "./3et_data")

## 1. What is Event-Based Vision?

Traditional cameras capture frames at fixed intervals (e.g., 30 FPS). Event cameras (Dynamic Vision Sensors - DVS) work differently:

- Each pixel **independently** detects changes in brightness
- When the (log) brightness at a pixel changes by more than a set threshold, that pixel emits an **event**
- Events are **asynchronous** with microsecond resolution

An event is a tuple: `(t, x, y, p)`
- `t`: timestamp (microseconds)
- `x, y`: pixel coordinates
- `p`: polarity (+1 for brightness increase, -1 for decrease; some libraries encode it as 1/0 or as a boolean instead)
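
As a minimal sketch, a batch of events can be held in a NumPy structured array with one field per tuple element. The field names, dtypes, and the four toy events below are illustrative assumptions, not a format required by either dataset.

```python
import numpy as np

# Illustrative dtype: one record per event (t in microseconds, p in {-1, +1}).
event_dtype = np.dtype([('t', '<i8'), ('x', '<i4'), ('y', '<i4'), ('p', '<i4')])

# Four made-up events, sorted by timestamp.
toy_events = np.array(
    [(1000, 12, 7, 1), (1003, 13, 7, 1), (1010, 12, 8, -1), (1042, 40, 22, 1)],
    dtype=event_dtype,
)

# Fields are accessed by name and behave like ordinary arrays.
duration_us = int(toy_events['t'][-1] - toy_events['t'][0])
num_on = int(np.sum(toy_events['p'] > 0))
num_off = len(toy_events) - num_on
print(f"{len(toy_events)} events over {duration_us} us, {num_on} ON / {num_off} OFF")
```

Keeping the events sorted by `t` makes it cheap to slice out a time window later.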

**Advantages for eye tracking:**
- High temporal resolution (microsecond timestamps, roughly 1 MHz, vs. 30-120 Hz frame rates)
- Low latency
- Low power consumption
- No motion blur

## 2. Loading Ini-30 Dataset

The Ini-30 dataset contains:
- 30 subjects with real DVS eye recordings
- Events stored in `.aedat4` format
- Labels: pupil center coordinates at ~50 Hz

Download from: https://zenodo.org/records/11203260

In [None]:
import pandas as pd

# Check dataset structure
ini30_path = Path(INI30_PATH)
print(f"Dataset path: {ini30_path}")
print(f"\nSubjects found: {sorted([d.name for d in ini30_path.iterdir() if d.is_dir()])}")

In [None]:
# Load the master annotations file
silver_csv = ini30_path / "silver.csv"
labels_df = pd.read_csv(silver_csv, delimiter='\t')
print(f"Total samples: {len(labels_df)}")
print(f"\nColumns: {labels_df.columns.tolist()}")
print(f"\nSample entries:")
labels_df.head()

In [None]:
# Load events from one subject using dv-processing
try:
    import dv_processing as dv
    
    subject = "ID_001"
    aedat_path = ini30_path / subject / "events.aedat4"
    
    # Open the aedat4 file
    reader = dv.io.MonoCameraRecording(str(aedat_path))
    
    print(f"Camera name: {reader.getCameraName()}")
    print(f"Resolution: {reader.getEventResolution()}")
    
    # Read first batch of events
    events = reader.getNextEventBatch()
    print(f"\nFirst batch: {len(events)} events")
    print(f"Time range: {events.timestamps()[0]} - {events.timestamps()[-1]} us")
    
except ImportError:
    print("dv-processing not installed. Install with: pip install dv-processing")
    print("Alternative: use the code from the RETINA repository")

In [None]:
# Visualize events as an accumulated frame
try:
    import dv_processing as dv
    
    # Read more events for visualization
    reader = dv.io.MonoCameraRecording(str(aedat_path))
    
    # Accumulate events over 50ms
    accumulator = dv.Accumulator(reader.getEventResolution())
    accumulator.setMinPotential(0.0)
    accumulator.setMaxPotential(1.0)
    accumulator.setNeutralPotential(0.5)
    accumulator.setEventContribution(0.15)
    accumulator.setDecayFunction(dv.Accumulator.Decay.LINEAR)
    accumulator.setDecayParam(1e-6)
    
    # Read events for ~50ms
    all_events = dv.EventStore()
    start_time = None
    while reader.isRunning():
        events = reader.getNextEventBatch()
        if events is None:
            break
        all_events.add(events)
        if start_time is None:
            start_time = events.timestamps()[0]
        if events.timestamps()[-1] - start_time > 50000:  # 50ms
            break
    
    accumulator.accept(all_events)
    frame = accumulator.generateFrame()
    
    # Load corresponding label
    annotations = pd.read_csv(ini30_path / subject / "annotations.csv")
    first_label = annotations.iloc[0]
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Accumulated frame
    axes[0].imshow(frame.image, cmap='gray')
    axes[0].scatter([first_label['center_x']], [first_label['center_y']], 
                    c='red', s=100, marker='x', linewidths=2, label='Pupil center')
    axes[0].set_title(f'Accumulated Events (50ms) - {subject}')
    axes[0].legend()
    
    # Event scatter plot
    coords = all_events.coordinates()
    pols = all_events.polarities()
    sample_idx = np.random.choice(len(coords), min(5000, len(coords)), replace=False)
    
    axes[1].scatter(coords[sample_idx, 0], coords[sample_idx, 1], 
                    c=pols[sample_idx], cmap='coolwarm', s=1, alpha=0.5)
    axes[1].scatter([first_label['center_x']], [first_label['center_y']], 
                    c='green', s=100, marker='x', linewidths=2)
    axes[1].set_xlim(0, 640)
    axes[1].set_ylim(480, 0)
    axes[1].set_title('Event Scatter (sampled, red=ON, blue=OFF)')
    axes[1].set_xlabel('x')
    axes[1].set_ylabel('y')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nLabel coordinates: ({first_label['center_x']:.1f}, {first_label['center_y']:.1f})")

except Exception as e:
    print(f"Visualization error: {e}")

## 3. Loading 3ET Synthetic Dataset

The 3ET dataset contains:
- 22 subjects with synthetic event-based eye recordings
- Events stored in `.h5` format (pre-processed event frames)
- Labels: pupil center (x, y) coordinates

**Download options:**
- [H5 event frames (recommended)](https://drive.google.com/drive/folders/16qH_wv_oVNysJARtHIUrIXbHjOygfq_i)
- [AEDAT raw + video files](https://drive.google.com/drive/folders/1HeOS5YBLruzHjwMKyBQfVTc_mJbsy_R1)

**Repository:** https://github.com/qinche106/cb-convlstm-eyetracking

In [None]:
import h5py

threet_path = Path(THREET_PATH)
print(f"3ET path: {threet_path}")

# Check if data exists
if threet_path.exists():
    # 3ET dataset structure varies - check what's available
    h5_files = list(threet_path.rglob("*.h5"))
    if h5_files:
        print(f"Found {len(h5_files)} H5 files")
        print(f"Example files: {[f.name for f in h5_files[:5]]}")
    else:
        print("No H5 files found. Download the dataset from Google Drive.")
else:
    print(f"Path {threet_path} does not exist. Set THREET_DATA_PATH in .env")
    print("Download from: https://drive.google.com/drive/folders/16qH_wv_oVNysJARtHIUrIXbHjOygfq_i")

In [None]:
# Load a sample from 3ET (if available)
# Note: 3ET H5 files contain pre-processed event frames, not raw events
try:
    h5_files = list(threet_path.rglob("*.h5"))
    if h5_files:
        h5_file = h5_files[0]
        print(f"Loading: {h5_file.name}")
        
        with h5py.File(h5_file, 'r') as f:
            print(f"\nH5 file structure:")
            for key in f.keys():
                print(f"  {key}: shape={f[key].shape}, dtype={f[key].dtype}")
            
            # Try to load and visualize
            if 'vector' in f.keys():
                data = f['vector'][:]
                print(f"\nData shape: {data.shape}")
                
                # Visualize a few frames
                if len(data.shape) >= 2:
                    fig, axes = plt.subplots(1, min(4, len(data)), figsize=(16, 4))
                    if len(data) == 1:
                        axes = [axes]
                    for i, ax in enumerate(axes):
                        frame_data = data[i].squeeze() if len(data[i].shape) > 2 else data[i]
                        ax.imshow(frame_data, cmap='gray')
                        ax.set_title(f'Frame {i}')
                        ax.axis('off')
                    plt.suptitle(f'3ET Event Frames from {h5_file.name}')
                    plt.tight_layout()
                    plt.show()
    else:
        print("No H5 files found.")
except Exception as e:
    print(f"Could not load 3ET data: {e}")

## 4. Event Representations

Raw events need to be converted into a format suitable for neural networks. Common representations:

### 4.1 Event Frames (Histogram)
Accumulate events over a time window into a 2D histogram.

```
frame[y, x] = count of events at (x, y) in time window
```

### 4.2 Voxel Grid (Temporal Bins)
Divide time into bins and create a 3D tensor.

```
voxel[t_bin, y, x] = weighted sum of events in time bin t_bin at (x, y)
```

### 4.3 Time Surfaces
Store the most recent timestamp at each pixel.

```
surface[y, x] = most recent event timestamp at (x, y)
```
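
A time surface can be sketched as follows. The helper name `events_to_time_surface` and the decay constant `tau_us` are assumptions made for illustration. The exponential decay is a common variant added here so that the output lies in [0, 1]; the plain definition above stores the raw timestamp. Timestamps are assumed to be non-negative integers in microseconds.

```python
import numpy as np

def events_to_time_surface(t, x, y, height, width, tau_us=50_000.0):
    """Exponentially decayed time surface (hypothetical helper).

    Each pixel holds exp(-(t_ref - t_last) / tau_us), where t_last is the
    most recent event time at that pixel and t_ref is the latest event time
    in the batch. Pixels that never fired stay at 0.
    """
    t = np.asarray(t, dtype=np.int64)
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.int64)

    surface = np.zeros((height, width), dtype=np.float32)
    if t.size == 0:
        return surface

    # Most recent timestamp per pixel; -1 marks pixels without any event.
    last_t = np.full((height, width), -1, dtype=np.int64)
    # np.maximum.at handles repeated (y, x) indices correctly.
    np.maximum.at(last_t, (y, x), t)

    fired = last_t >= 0
    t_ref = t.max()
    surface[fired] = np.exp(-(t_ref - last_t[fired]) / tau_us)
    return surface

# Toy usage: pixel (0, 0) fires twice, pixel (0, 1) fires once.
toy_surface = events_to_time_surface(
    t=[0, 50_000, 100_000], x=[0, 1, 0], y=[0, 0, 0], height=2, width=2
)
print(toy_surface)
```

A smaller `tau_us` emphasizes only the most recent activity, while a larger one keeps older edges visible for longer.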

In [None]:
def events_to_frame(events, height=480, width=640, separate_polarity=True):
    """
    Convert events to a frame representation.
    
    Args:
        events: structured array with fields 't', 'x', 'y', 'p'
        height, width: sensor dimensions
        separate_polarity: if True, create 2 channels (ON/OFF)
    
    Returns:
        frame: (C, H, W) tensor where C=2 if separate_polarity else C=1
    """
    if separate_polarity:
        frame = np.zeros((2, height, width), dtype=np.float32)
        
        # ON events (p=1)
        on_mask = events['p'] == 1
        np.add.at(frame[0], (events['y'][on_mask], events['x'][on_mask]), 1)
        
        # OFF events (p=0 or p=-1)
        off_mask = ~on_mask
        np.add.at(frame[1], (events['y'][off_mask], events['x'][off_mask]), 1)
    else:
        frame = np.zeros((1, height, width), dtype=np.float32)
        np.add.at(frame[0], (events['y'], events['x']), 1)
    
    return frame


def events_to_voxel_grid(events, num_bins=5, height=480, width=640):
    """
    Convert events to a voxel grid representation.
    
    Args:
        events: structured array with fields 't', 'x', 'y', 'p'
        num_bins: number of temporal bins
        height, width: sensor dimensions
    
    Returns:
        voxel: (num_bins, H, W) tensor
    """
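    voxel = np.zeros((num_bins, height, width), dtype=np.float32)
    if len(events) == 0:
        return voxel
    
    # Subtract the first timestamp in int64 before converting to float:
    # absolute microsecond timestamps exceed float32 precision.
    t = events['t'].astype(np.int64)
    t_rel = (t - t.min()).astype(np.float64)
    t_normalized = t_rel / (t_rel.max() + 1e-6)  # Normalize to [0, 1)
    t_bins = (t_normalized * num_bins).astype(np.int32)  # Equal-width bins
    t_bins = np.clip(t_bins, 0, num_bins - 1)
    
    # Polarity as weight: ON -> +1, OFF (stored as 0 or -1) -> -1
    p = np.where(events['p'] > 0, 1.0, -1.0).astype(np.float32)
    
    # Vectorized accumulation; np.add.at handles repeated indices correctly
    np.add.at(voxel, (t_bins, events['y'], events['x']), p)
    
    return voxel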


print("Event representation functions defined.")

In [None]:
# Demonstrate different representations
try:
    import dv_processing as dv
    
    # Load some events
    reader = dv.io.MonoCameraRecording(str(ini30_path / "ID_001" / "events.aedat4"))
    all_events = dv.EventStore()
    start_time = None
    while reader.isRunning():
        events = reader.getNextEventBatch()
        if events is None:
            break
        all_events.add(events)
        if start_time is None:
            start_time = events.timestamps()[0]
        if events.timestamps()[-1] - start_time > 30000:  # 30ms
            break
    
    # Convert to numpy structured array
    coords = all_events.coordinates()
    times = all_events.timestamps()
    pols = all_events.polarities()
    
    events_array = np.zeros(len(times), dtype=[('t', '<i8'), ('x', '<i4'), ('y', '<i4'), ('p', '<i4')])
    events_array['t'] = times
    events_array['x'] = coords[:, 0]
    events_array['y'] = coords[:, 1]
    events_array['p'] = pols
    
    # Create representations
    frame = events_to_frame(events_array, separate_polarity=True)
    voxel = events_to_voxel_grid(events_array, num_bins=5)
    
    # Visualize
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Frame representation
    axes[0, 0].imshow(frame[0], cmap='Reds')
    axes[0, 0].set_title('Event Frame - ON channel')
    
    axes[0, 1].imshow(frame[1], cmap='Blues')
    axes[0, 1].set_title('Event Frame - OFF channel')
    
    axes[0, 2].imshow(frame[0] - frame[1], cmap='coolwarm')
    axes[0, 2].set_title('Event Frame - Combined')
    
    # Voxel grid (show 3 time bins)
    for i, bin_idx in enumerate([0, 2, 4]):
        axes[1, i].imshow(voxel[bin_idx], cmap='coolwarm', vmin=-5, vmax=5)
        axes[1, i].set_title(f'Voxel Grid - Time bin {bin_idx}')
    
    plt.tight_layout()
    plt.show()
    
    print(f"Frame shape: {frame.shape}")
    print(f"Voxel shape: {voxel.shape}")

except Exception as e:
    print(f"Visualization error: {e}")

## 5. Speck Hardware Constraints

To deploy on Speck, your model must satisfy these constraints:

| Constraint | Requirement |
|------------|-------------|
| **Neuron type** | IAF (Integrate-and-Fire) only |
| **Activation** | Spiking neurons (no ReLU/Sigmoid) |
| **Pooling** | Sum pooling only (no Max/Avg) |
| **Weights** | 8-bit quantized |
| **Neuron states** | 16-bit |
| **Max layers** | 9 convolutional layers |
| **Max input** | 128x128 pixels |
| **Max channels** | 1024 per layer |

### Memory Constraints

Each layer is mapped to one core on the chip and must fit within that core's memory:

**Kernel Memory (KMT):**
$$KMT = c \times 2^{\lceil\log_2(k_x \times k_y)\rceil + \lceil\log_2(f)\rceil}$$

**Neuron Memory (NM):**
$$NM = f \times f_x \times f_y$$

Where:
- $c$ = input channels
- $f$ = output channels (filters)
- $k_x, k_y$ = kernel size
- $f_x, f_y$ = output feature map size
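
The two formulas can be evaluated per layer with a few lines of Python. This is a sketch under stated assumptions: the helper names are hypothetical, the example layer sizes are made up, and the results still have to be compared against the per-core limits listed in the Speck documentation, which are not reproduced here.

```python
def ceil_log2(n):
    """Exact ceil(log2(n)) for a positive integer n."""
    return (n - 1).bit_length()

def kernel_memory(c, f, kx, ky):
    """KMT = c * 2^(ceil(log2(kx * ky)) + ceil(log2(f)))."""
    return c * 2 ** (ceil_log2(kx * ky) + ceil_log2(f))

def neuron_memory(f, fx, fy):
    """NM = f * fx * fy."""
    return f * fx * fy

# Example layer: 2 -> 16 channels, 3x3 kernel, 32x32 output feature map.
print(kernel_memory(c=2, f=16, kx=3, ky=3))   # -> 512
print(neuron_memory(f=16, fx=32, fy=32))      # -> 16384

# Example layer: 16 -> 32 channels, 3x3 kernel, 16x16 output feature map.
print(kernel_memory(c=16, f=32, kx=3, ky=3))  # -> 8192
print(neuron_memory(f=32, fx=16, fy=16))      # -> 8192
```

Both exponents are rounded up to the next power of two, so a 3x3 kernel (9 weights) costs as much kernel memory as a 4x4 kernel (16 weights).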

In [None]:
# Example: Check if a model is Speck-compatible using sinabs
try:
    import torch
    import torch.nn as nn
    import sinabs
    import sinabs.layers as sl
    from sinabs.backend.dynapcnn import DynapcnnNetwork
    
    # Define a simple Speck-compatible SNN
    class SimpleSNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.seq = nn.Sequential(
                # Conv layer 1
                nn.Conv2d(2, 16, kernel_size=3, stride=2, padding=1, bias=False),
                sl.IAFSqueeze(batch_size=1),
                
                # Conv layer 2  
                nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
                sl.IAFSqueeze(batch_size=1),
                
                # Conv layer 3
                nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
                sl.IAFSqueeze(batch_size=1),
                
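                # Output: the readout also ends in a spiking layer, because each
                # weight layer mapped to a chip core must be followed by one
                nn.Flatten(),
                nn.Linear(32 * 8 * 8, 2, bias=False),  # 2 outputs: x, y
                sl.IAFSqueeze(batch_size=1),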
            )
        
        def forward(self, x):
            return self.seq(x)
    
    # Create model
    model = SimpleSNN()
    print("Model architecture:")
    print(model)
    
    # Try to convert for Speck
    input_shape = (2, 64, 64)  # 2 channels, 64x64 input
    
    dynapcnn_net = DynapcnnNetwork(
        snn=model.seq,
        input_shape=input_shape,
        discretize=True
    )
    
    print(f"\nModel converted successfully!")
    print(f"Number of DynapCNN layers: {len(dynapcnn_net.sequence)}")
    
    # Check memory mapping
    try:
        config = dynapcnn_net.make_config(device="speck2fmodule")
        print("\nMemory validation: PASSED - Model fits on Speck!")
    except Exception as e:
        print(f"\nMemory validation: FAILED - {e}")
    
except ImportError as e:
    print(f"sinabs not installed or import error: {e}")
    print("Install with: pip install sinabs")

## 6. Eye Tracking Problem Formulation

**Input**: Event stream from DVS camera looking at an eye

**Output**: Pupil center coordinates $(x, y)$

**Approach**: 
1. Accumulate events into frames/voxels (temporal windowing)
2. Feed through SNN
3. Predict coordinates (regression)
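
The temporal windowing in step 1 can be sketched with `np.searchsorted`. The helper name `window_slices` is hypothetical. The sketch assumes that event timestamps are sorted and that each label carries a timestamp on the same clock as the events, which should be checked against the dataset's annotation files.

```python
import numpy as np

def window_slices(event_t, label_t, window_us):
    """Index ranges of the events that precede each label (hypothetical helper).

    For every label time T, returns [start, stop) such that
    event_t[start:stop] holds the events with T - window_us < t <= T.
    event_t must be sorted in ascending order.
    """
    event_t = np.asarray(event_t, dtype=np.int64)
    label_t = np.asarray(label_t, dtype=np.int64)
    start = np.searchsorted(event_t, label_t - window_us, side='right')
    stop = np.searchsorted(event_t, label_t, side='right')
    return np.stack([start, stop], axis=1)

# Toy usage: events every 10 us, two labels, 20 us window.
toy_event_t = np.array([0, 10, 20, 30, 40, 50])
toy_slices = window_slices(toy_event_t, label_t=[20, 50], window_us=20)
for start, stop in toy_slices:
    print(toy_event_t[start:stop])
```

Each `[start, stop)` range can then be passed to a representation function to produce one input frame per label.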

**Evaluation Metric**: Euclidean distance in pixels

$$\text{Error} = \sqrt{(x_{pred} - x_{true})^2 + (y_{pred} - y_{true})^2}$$

### Typical Pipeline

```
DVS Events -> Temporal Windowing -> Event Frames -> SNN -> (x, y) prediction
                                      |
                              Downscale to 64x64
                              (Speck compatible)
```
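
One way to do the downscaling step is to rescale the event coordinates before accumulating, which amounts to summing events over blocks of sensor pixels. This is a minimal sketch: the helper name is hypothetical and the reference implementations may downscale differently. Mapping a non-square sensor onto a square grid does not preserve the aspect ratio, so the labels must be rescaled with the same factors.

```python
import numpy as np

def downscaled_event_frame(x, y, p, src_width, src_height, dst_size=64):
    """Accumulate events into a (2, dst_size, dst_size) ON/OFF count frame."""
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.int64)
    on = np.asarray(p) > 0  # works for both 1/0 and +1/-1 polarity encodings

    # Integer rescaling maps [0, src) onto [0, dst_size) without overflow.
    x_ds = x * dst_size // src_width
    y_ds = y * dst_size // src_height

    frame = np.zeros((2, dst_size, dst_size), dtype=np.float32)
    np.add.at(frame[0], (y_ds[on], x_ds[on]), 1)    # ON channel
    np.add.at(frame[1], (y_ds[~on], x_ds[~on]), 1)  # OFF channel
    return frame

# Toy usage with a 640x480 sensor: corner events land in the corner cells.
toy_frame = downscaled_event_frame(
    x=[0, 639, 639, 5], y=[0, 479, 479, 3], p=[1, 1, 0, 0],
    src_width=640, src_height=480,
)
print(toy_frame.shape, toy_frame.sum())
```

Because events are counted rather than averaged, the total number of events is preserved after downscaling.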

In [None]:
def euclidean_error(predictions, targets):
    """
    Calculate Euclidean distance error.
    
    Args:
        predictions: (N, 2) array of (x, y) coordinates
        targets: (N, 2) array of (x, y) coordinates
    
    Returns:
        Mean Euclidean distance in pixels
    """
    return np.mean(np.sqrt(np.sum((predictions - targets) ** 2, axis=1)))


# Example
predictions = np.array([[30.5, 25.2], [31.0, 24.8], [29.8, 25.5]])
targets = np.array([[32.0, 25.0], [32.0, 25.0], [32.0, 25.0]])

error = euclidean_error(predictions, targets)
print(f"Example predictions: {predictions}")
print(f"Ground truth: {targets[0]}")
print(f"Mean Euclidean error: {error:.2f} pixels")

## 7. Your Challenge

Now that you understand the data and constraints, it's time to build something **new**!

### The Goal

Create a novel SNN-based eye tracking solution that:
1. Predicts pupil coordinates with good accuracy
2. Runs on Speck hardware (or the Specksim simulator)
3. **Brings something new** - don't just reproduce existing work!


### Ideas to Explore

**Architecture innovations:**
- Attention mechanisms for SNNs
- Skip connections / residual learning
- Different temporal aggregation strategies
- Efficient micro-architectures

**Training strategies:**
- Novel surrogate gradients
- Knowledge distillation
- Self-supervised pre-training
- Data augmentation for events

**Input representations:**
- Learned event encodings
- Multi-scale processing
- Adaptive temporal windowing
- Hybrid representations

### Learning Resources

- [Sinabs BPTT Tutorial](https://sinabs.readthedocs.io/v3.1.1/tutorials/bptt.html)
- [ANN to SNN Conversion](https://sinabs.readthedocs.io/v3.1.1/tutorials/ann_to_snn_conversion.html)
- [Speck Deployment Guide](https://sinabs.readthedocs.io/v3.1.1/speck/nmnist.html)
- [Specksim](https://sinabs.readthedocs.io/v3.1.1/speck/specksim.html)

### Key Tips

1. **Validate Speck compatibility early** - don't train a model that won't fit!
2. **Start simple, iterate fast** - get something working, then improve
3. **Use 64x64 input** - standard for Speck deployment
4. **IAF neurons only** - no LIF or other neuron types on Speck
5. **Sum pooling only** - no max or average pooling
6. **Be creative!** - the best solutions often come from unexpected directions

## References

**Papers:**
- [RETINA: Low-Power Eye Tracking with Event Camera and Spiking Hardware](https://arxiv.org/abs/2312.00425)
- [3ET: Efficient Event-based Eye Tracking using a Change-Based ConvLSTM Network](https://ieeexplore.ieee.org/document/10389541)

**Code Repositories:**
- [RETINA](https://github.com/pbonazzi/retina) - Reference SNN implementation
- [3ET ConvLSTM](https://github.com/qinche106/cb-convlstm-eyetracking) - Reference ConvLSTM implementation
- [Sinabs Library](https://sinabs.readthedocs.io/v3.1.1/) - SNN framework for Speck
- [Tonic](https://tonic.readthedocs.io/) - Event camera data loading

**Datasets:**
- [Ini-30 on Zenodo](https://zenodo.org/records/11203260) - Real DVS eye recordings
- [3ET Synthetic Dataset](https://drive.google.com/drive/folders/16qH_wv_oVNysJARtHIUrIXbHjOygfq_i) - Synthetic event frames