# Developing Annotators in TRACE

This notebook guides you through creating a custom **annotator plugin** that detects events or intervals in your time-series data.

> 💡 **Tip:** Consider making a copy of this notebook before modifying it (e.g., `my_annotator.ipynb`). To restore the original, run `tracengine reset-notebooks .`

**What you'll learn:**
1. Load your project and explore the data
2. Visualize signals with matplotlib
3. Create an annotator that detects events
4. Run the annotator and visualize detected events
5. Save your plugin for production use

## 1. Setup and Load Project

First, let's load the project and explore what data is available.

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# TRACE imports
from tracengine.project import load_project
from tracengine.data.loader import load_session_from_project

# Load the project (assumes you're running from notebooks/ folder)
project_path = Path("..").resolve()
project = load_project(project_path)

print(f"✓ Project: {project.name}")
print(f"✓ Data source: {project.get_data_path()}")
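The cell above assumes the notebook runs from the `notebooks/` folder. If you sometimes start Jupyter from elsewhere, a small helper can find the project root by searching upward for `trace-project.yaml`, the config file named in Section 7. `find_project_root` is a hypothetical helper, not part of TRACE, and it assumes the project root is the directory that contains that file.

```python
from pathlib import Path


def find_project_root(start: Path, marker: str = "trace-project.yaml") -> Path:
    """Walk up from `start` until a directory containing `marker` is found.

    Hypothetical helper: assumes the TRACE project root holds `marker`.
    """
    start = start.resolve()
    # Check the starting directory first, then each parent up to the filesystem root
    for candidate in (start, *start.parents):
        if (candidate / marker).is_file():
            return candidate
    raise FileNotFoundError(f"No {marker} found in {start} or any parent directory")
```

You could then write `project = load_project(find_project_root(Path.cwd()))` instead of hard-coding `".."`.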

## 2. Load and Explore Your Data

Load the session data and pick a run to work with.

In [None]:
# Load all runs from the session
runs = load_session_from_project(project)
print(f"✓ Found {len(runs)} runs\n")

# Pick the first run for development
run = runs[0]

print(f"Run: {run.run}")
print(f"Subject: {run.subject}, Session: {run.session}")
print(f"Metadata: {run.metadata}")
print("\nAvailable signal groups:")
for group_name, signal_group in run.signals.items():
    channels = signal_group.list_channels()
    sr = signal_group.estimate_sampling_rate()
    rate = f" @ {sr:.1f} Hz" if sr else ""  # the sampling rate may be unavailable for some groups
    print(f"  • {group_name}: {len(channels)} channels{rate}")
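`estimate_sampling_rate()` can come back empty for some groups, which is why the loop above only prints a rate when one is available. One common way to estimate a rate by hand is from the median spacing between timestamps, which tolerates an occasional dropped sample. This is a sketch of that idea, not necessarily how TRACE computes it; `estimate_rate_hz` is a hypothetical name.

```python
import statistics
from typing import Optional, Sequence


def estimate_rate_hz(timestamps: Sequence[float]) -> Optional[float]:
    """Estimate a sampling rate (Hz) from timestamps in seconds via the median step.

    Returns None when there are fewer than two samples or the median step is not positive.
    """
    if len(timestamps) < 2:
        return None
    # Differences between consecutive timestamps
    steps = [b - a for a, b in zip(timestamps, timestamps[1:])]
    median_step = statistics.median(steps)
    if median_step <= 0:
        return None
    return 1.0 / median_step
```

For example, `estimate_rate_hz(run.get_signal(group_name, channel)[0])` works on the time array of any channel.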

## 3. Visualize Your Data

Before creating an annotator, explore what your signals look like.
Use `run.get_signal(group, channel)` to get (time, values) arrays.

In [None]:
# Example: Plot signals from a specific group
# TODO: Replace with your actual group and channel names
group_name = list(run.signals.keys())[0]  # First available group
signal_group = run.signals[group_name]
channels = signal_group.list_channels()[:3]  # Plot first 3 channels

fig, axes = plt.subplots(len(channels), 1, figsize=(12, 3 * len(channels)), sharex=True)
if len(channels) == 1:
    axes = [axes]

for ax, channel in zip(axes, channels):
    t, y = run.get_signal(group_name, channel)
    ax.plot(t, y, linewidth=0.5)
    ax.set_ylabel(channel)
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel("Time (seconds)")
fig.suptitle(f"Signals from {group_name}")
plt.tight_layout()
plt.show()
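Long recordings are hard to inspect at full length, and zooming into a short window makes individual features visible. The hypothetical `time_window` helper below uses the standard library's `bisect` and assumes `t` is sorted ascending, as timestamps usually are; for NumPy arrays, `np.searchsorted` does the same job.

```python
from bisect import bisect_left, bisect_right
from typing import Sequence


def time_window(t: Sequence[float], y: Sequence[float], start: float, end: float):
    """Return the samples of (t, y) with start <= t <= end.

    Assumes `t` is sorted in ascending order.
    """
    lo = bisect_left(t, start)   # first index with t >= start
    hi = bisect_right(t, end)    # first index with t > end
    return t[lo:hi], y[lo:hi]
```

Passing `time_window(t, y, 10.0, 20.0)` to `ax.plot` instead of the full arrays shows only the 10 to 20 s span.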

## 4. Create Your Annotator

Subclass `AnnotatorBase` to create your event detector.

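**Key concepts:**
- `required_channels`: Declares the data your annotator needs, by semantic role rather than by column name
- `produces`: Whether the annotator emits single time points (`"timepoint"`) or start/end pairs (`"interval"`)
- `annotate()`: Your detection logic. It receives each resolved channel as a keyword argument named after its `required_channels` key (here, `signal`) and returns a list of `Event` objects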
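The `get_parameters()` method in the class below returns plain dictionaries with `name`, `default`, `min` and `max` keys. While experimenting in the notebook, a helper like this one can collect the defaults and reject out-of-range values before you call the annotator. `resolve_parameters` is hypothetical; the GUI may enforce these bounds itself, so treat this as a notebook convenience only.

```python
from typing import Any


def resolve_parameters(specs: list[dict[str, Any]], **overrides: Any) -> dict[str, Any]:
    """Merge user overrides into spec defaults, checking numeric bounds.

    `specs` follows the dictionary shape returned by get_parameters().
    """
    known = {spec["name"]: spec for spec in specs}
    unknown = set(overrides) - set(known)
    if unknown:
        raise KeyError(f"Unknown parameter(s): {sorted(unknown)}")
    # Start from the declared defaults, then apply the overrides
    values = {name: spec.get("default") for name, spec in known.items()}
    values.update(overrides)
    for name, value in values.items():
        spec = known[name]
        if value is None:
            continue  # nothing to check when no default or override was given
        if "min" in spec and value < spec["min"]:
            raise ValueError(f"{name}={value} is below min {spec['min']}")
        if "max" in spec and value > spec["max"]:
            raise ValueError(f"{name}={value} is above max {spec['max']}")
    return values
```

Once `MyAnnotator` is defined, `resolve_parameters(MyAnnotator.get_parameters(), threshold=0.8)` would return `{"threshold": 0.8}`, which you could unpack into `annotator.run(...)`.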

In [None]:
from tracengine.annotate.base import AnnotatorBase
from tracengine.data.descriptors import ChannelSpec, Event, RunData


class MyAnnotator(AnnotatorBase):
    """
    Example annotator that detects threshold crossings.
    
    Customize this for your specific detection needs!
    """
    
    name = "My Custom Annotator"
    version = "1.0.0"
    produces = "timepoint"  # or "interval" for start/end pairs
    
    # Declare required channels using semantic roles
    # These get resolved to actual data columns via channel bindings
    required_channels = {
        "signal": ChannelSpec(semantic_role="my_signal"),
    }
    
    @classmethod
    def get_parameters(cls):
        """Define user-adjustable parameters for the GUI."""
        return [
            {
                "name": "threshold",
                "label": "Detection Threshold",
                "type": "float",
                "default": 0.5,
                "min": 0.0,
                "max": 10.0,
                "step": 0.1,
            }
        ]
    
    def annotate(self, run: RunData, signal=None, threshold=0.5, **kwargs) -> list[Event]:
        """
        Detect events where signal exceeds threshold.
        
        Args:
            run: The RunData object
            signal: Tuple of (time_array, value_array) - resolved from required_channels
            threshold: Detection threshold from get_parameters()
        
        Returns:
            List of detected Event objects
        """
        time, values = signal
        events = []
        
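        # Find rising threshold crossings, where the signal goes from <= threshold
        # to > threshold. np.diff(...) == 1 marks the sample *before* each crossing,
        # so add 1 to report the first sample above the threshold.
        above_threshold = np.asarray(values) > threshold
        crossings = np.where(np.diff(above_threshold.astype(int)) == 1)[0] + 1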
        
        for i, idx in enumerate(crossings):
            events.append(Event(
                annotator=self.name,
                name=f"Detection {i+1}",
                event_type="threshold-crossing",  # Your event type
                onset=float(time[idx]),
                offset=None,  # Set for interval events
                confidence=1.0,
                metadata={"threshold": str(threshold)}
            ))
        
        return events


print("✓ Annotator class defined successfully!")
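Before running on real data, it helps to check the detection logic on a synthetic signal whose answer you know. The function below is a standalone version of the `np.diff` crossing approach used in `annotate()`, and it reports the time of the first sample above the threshold. `rising_crossings` is a hypothetical name and the sine test signal is illustrative.

```python
import numpy as np


def rising_crossings(time, values, threshold: float) -> np.ndarray:
    """Return times of the first sample above `threshold` after each sample at or below it."""
    above = np.asarray(values) > threshold
    # The diff of the 0/1 mask is +1 where the signal goes from <= threshold to > threshold;
    # +1 shifts from the last sample below to the first sample above
    idx = np.flatnonzero(np.diff(above.astype(int)) == 1) + 1
    return np.asarray(time)[idx]


# Synthetic check: a 1 Hz sine sampled at 100 Hz for 3 s rises through 0.5 once per cycle
t = np.arange(0, 3, 0.01)
y = np.sin(2 * np.pi * 1.0 * t)
onsets = rising_crossings(t, y, threshold=0.5)
print(len(onsets), onsets)  # 3 onsets, near 0.09, 1.09 and 2.09 s
```

Because detection works sample by sample, onsets land on the sampling grid; the true crossing of this sine is at 1/12 s (about 0.083 s), between the samples at 0.08 and 0.09 s.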

## 5. Configure Channel Bindings and Run

Before running, tell TRACE which actual data channel maps to your semantic role (`"my_signal"` → `"group:channel"`).
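A binding is just a string reference, so you can split it yourself and plot the channel with `run.get_signal()` to confirm you picked the right one. This sketch assumes the group name contains no colon (it splits on the first one); TRACE's own parser may behave differently, and both helper names are hypothetical.

```python
def parse_channel_ref(ref: str) -> tuple[str, str]:
    """Split a 'group:channel' reference into (group, channel).

    Splits on the first colon, so the group name must not contain one.
    """
    group, sep, channel = ref.partition(":")
    if not sep or not group or not channel:
        raise ValueError(f"Expected 'group:channel', got {ref!r}")
    return group, channel


def resolve_roles(bindings: dict[str, str]) -> dict[str, tuple[str, str]]:
    """Map each semantic role to its (group, channel) pair."""
    return {role: parse_channel_ref(ref) for role, ref in bindings.items()}
```

For example, `group, channel = resolve_roles({"my_signal": actual_channel})["my_signal"]` followed by `run.get_signal(group, channel)` fetches exactly the data the annotator will see.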

In [None]:
from tracengine.data.descriptors import RunConfig

# Create an annotator instance
annotator = MyAnnotator()
instance_name = "MyAnnotator"

# Configure channel bindings
# TODO: Replace with your actual group:channel reference
example_group = list(run.signals.keys())[0]
example_channel = run.signals[example_group].list_channels()[0]
actual_channel = f"{example_group}:{example_channel}"

run.run_config = RunConfig(
    channel_bindings={
        instance_name: {
            "my_signal": actual_channel  # Maps semantic role to actual channel
        }
    }
)

print(f"✓ Bound 'my_signal' → '{actual_channel}'")

# Run the annotator
events = annotator.run(run, instance_name=instance_name, threshold=0.5)
print(f"✓ Detected {len(events)} events")
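A quick sanity check after a run can catch indexing mistakes early: every onset should fall inside the recording and, for a left-to-right scan like `annotate()`, onsets should come out in ascending order. The hypothetical `check_onsets` helper works on plain floats, so you would pass it `[e.onset for e in events]` together with the first and last timestamps from `run.get_signal()`.

```python
from typing import Sequence


def check_onsets(onsets: Sequence[float], t_start: float, t_end: float) -> list[str]:
    """Return a list of problems found in detected onsets (empty if none)."""
    problems = []
    # Every onset must lie within the recorded time range
    for i, onset in enumerate(onsets):
        if not (t_start <= onset <= t_end):
            problems.append(f"event {i}: onset {onset} outside [{t_start}, {t_end}]")
    # Consecutive onsets should not go backwards in time
    for i, (a, b) in enumerate(zip(onsets, onsets[1:])):
        if b < a:
            problems.append(f"events {i} and {i + 1}: onsets not in ascending order")
    return problems
```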

## 6. Visualize Detected Events

Overlay your detected events on the signal plot.

In [None]:
# Get the signal we analyzed
t, y = run.get_signal(example_group, example_channel)

fig, ax = plt.subplots(figsize=(14, 5))

# Plot signal
ax.plot(t, y, 'b-', linewidth=0.5, alpha=0.7, label="Signal")

# Plot the threshold (keep this in sync with the value passed to annotator.run() above)
threshold = 0.5
ax.axhline(y=threshold, color='gray', linestyle='--', label=f"Threshold: {threshold}")

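# Plot detected events (only the first line gets a legend label)
for i, event in enumerate(events):
    ax.axvline(x=event.onset, color='red', alpha=0.7, linewidth=1.5,
               label="Detected event" if i == 0 else "_nolegend_")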

ax.set_xlabel("Time (seconds)")
ax.set_ylabel(example_channel)
ax.set_title(f"Detected {len(events)} events")
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Next Steps: Save for Production

Once your annotator works correctly:

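1. **Copy the class** to `plugins/annotators/my_annotator.py`, together with the imports it relies on (`numpy` and the `tracengine` imports from the cells above); imports made in earlier notebook cells are not available to the plugin file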
2. **Register in trace-project.yaml**:
   ```yaml
   plugins:
     annotators:
       - plugins/annotators/my_annotator.py
   ```
3. **Use in GUI** or **pipelines** with automatic discovery!
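Before registering the file, you can confirm it imports cleanly on its own, which catches names that were only defined in earlier notebook cells (a missing `import numpy as np`, for example). This sketch uses the standard library's `importlib.util`; it is a standalone check, not the mechanism TRACE uses for plugin discovery, and `load_class_from_file` is a hypothetical helper.

```python
import importlib.util
from pathlib import Path


def load_class_from_file(path: Path, class_name: str) -> type:
    """Import a Python file as a fresh module and return the named class."""
    spec = importlib.util.spec_from_file_location(path.stem, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Executes the file top to bottom; NameError/ImportError surface here
    spec.loader.exec_module(module)
    # Raises AttributeError if the class is missing
    return getattr(module, class_name)
```

For example, `load_class_from_file(project_path / "plugins/annotators/my_annotator.py", "MyAnnotator")` should return the class without errors.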

### Tips
- Use `produces = "interval"` for events with duration (set `offset` in Event)
- Add multiple `required_channels` for multi-signal algorithms
- Customize `get_parameters()` to expose tuning knobs in the GUI
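With `produces = "interval"`, each event needs both an `onset` and an `offset`. One way to derive them from a threshold is to find runs of consecutive samples above it. This sketch (the `threshold_intervals` name is hypothetical) returns plain `(onset, offset)` tuples that you could pass to `Event(onset=..., offset=...)` inside `annotate()`. Here the offset is the last sample above the threshold, a choice you may want to change for your data.

```python
import numpy as np


def threshold_intervals(time, values, threshold: float) -> list[tuple[float, float]]:
    """Return (onset, offset) times for each run of samples above `threshold`.

    onset is the first sample above the threshold; offset is the last sample above it.
    """
    time = np.asarray(time, dtype=float)
    above = np.asarray(values) > threshold
    # Pad with False on both ends so runs touching the edges still get a start and an end
    edges = np.diff(np.concatenate(([False], above, [False])).astype(int))
    starts = np.flatnonzero(edges == 1)      # index of the first sample in each run
    ends = np.flatnonzero(edges == -1) - 1   # index of the last sample in each run
    return [(float(time[s]), float(time[e])) for s, e in zip(starts, ends)]
```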