# Developing Compute Plugins in TRACE

This notebook guides you through creating a custom **compute plugin** that calculates metrics from signals and detected events.

> 💡 **Tip:** Consider making a copy of this notebook before modifying it (e.g., `my_compute.ipynb`). If you need to restore the original, run: `trace reset-notebooks .`

**What you'll learn:**
1. Load your project and explore the available data and events
2. Create a compute module that processes event intervals
3. Visualize metric results
4. Export results with provenance tracking

## 1. Setup and Load Project

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# TRACE imports
from tracetool.project import load_project
from tracetool.data.loader import load_session_from_project

# Load the project (assumes you're running from notebooks/ folder)
project_path = Path("..").resolve()
project = load_project(project_path)

print(f"✓ Project: {project.name}")
print(f"✓ Data source: {project.get_data_path()}")

## 2. Load Data and Explore Events

Compute modules typically work on events detected by annotators.

In [None]:
# Load session
runs = load_session_from_project(project)
run = runs[0]

print(f"Run: {run.run}")
print(f"Subject: {run.subject}, Session: {run.session}")

# Show available signal groups
print("\nSignal groups:")
for name, group in run.signals.items():
    print(f"  • {name}: {group.list_channels()[:5]}...")

# Show any existing annotations
print("\nAnnotations loaded:")
if run.annotations:
    for ann_group, events in run.annotations.items():
        print(f"  • {ann_group}: {len(events)} events")
else:
    print("  (none - run an annotator first in develop_annotator.ipynb)")

## 3. Visualize Events on Signal

Before computing metrics, visualize the events you'll process.

In [None]:
# Pick a signal to visualize
# TODO: Replace with your actual group and channel names
group_name = list(run.signals.keys())[0]
channel_name = run.signals[group_name].list_channels()[0]

t, y = run.get_signal(group_name, channel_name)

fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(t, y, 'b-', linewidth=0.5, alpha=0.7, label="Signal")

# If there are annotations, plot them
colors = plt.cm.Set2.colors
for i, (ann_group, events) in enumerate(run.annotations.items()):
    color = colors[i % len(colors)]
    for event in events[:20]:  # Limit for clarity
        if event.offset:  # Interval event
            ax.axvspan(event.onset, event.offset, alpha=0.2, color=color, label=ann_group if event == events[0] else None)
        else:  # Timepoint event
            ax.axvline(event.onset, color=color, alpha=0.5, linewidth=1)

ax.set_xlabel("Time (seconds)")
ax.set_ylabel(channel_name)
ax.set_title("Signal with Detected Events")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 4. Create Your Compute Plugin

Subclass `ComputeBase` to create your metrics calculator.

**Key concepts:**
- `required_channels`: What signal data you need
- `required_events`: What event types you need (from annotators)
- `compute()`: Your metric calculation logic
- Return value: a `pd.DataFrame` with one row per event/trial
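
The per-event pattern behind `compute()` can be tried without TRACE. The sketch below is a minimal, standalone illustration under stated assumptions: `SimpleEvent` and `interval_metrics` are hypothetical names, not part of TRACE, the only event attributes assumed are `onset` and `offset` in seconds, and the signal is synthetic.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for a TRACE event; only onset/offset (seconds) are assumed.
SimpleEvent = namedtuple("SimpleEvent", ["name", "onset", "offset"])


def interval_metrics(time, values, events):
    """Return one dict of metrics per event interval."""
    dt = float(np.median(np.diff(time)))  # assumes roughly uniform sampling
    rows = []
    for event in events:
        # Select the samples that fall inside the event interval.
        mask = (time >= event.onset) & (time <= event.offset)
        segment = values[mask]
        if segment.size < 2:  # too few samples to estimate a gradient
            continue
        velocity = np.gradient(segment, dt)
        rows.append({
            "event_name": event.name,
            "duration_s": event.offset - event.onset,
            "mean": float(segment.mean()),
            "range": float(segment.max() - segment.min()),
            "peak_velocity": float(np.abs(velocity).max()),
        })
    return rows


# Synthetic 100 Hz ramp: the value equals the time, so the slope is 1 unit/s.
time = np.arange(0.0, 10.0, 0.01)
values = time.copy()
events = [SimpleEvent("a", 1.0, 2.0), SimpleEvent("b", 5.0, 5.001)]
rows = interval_metrics(time, values, events)
```

Event `b` is shorter than one sample period, so it is skipped and `rows` holds a single entry for event `a`. Passing `rows` to `pd.DataFrame(rows)` gives the one-row-per-event table described above.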

In [None]:
from tracetool.compute.base import ComputeBase
from tracetool.data.descriptors import ChannelSpec, EventSpec, RunData


class MyCompute(ComputeBase):
    """
    Example compute module that calculates metrics within event intervals.
    
    Customize this for your specific analysis needs!
    """
    
    name = "My Custom Metrics"
    version = "1.0.0"
    
    # Declare required channels
    required_channels = {
        "signal": ChannelSpec(semantic_role="my_signal", allow_derived=True),
    }
    
    # Declare required events (from an annotator)
    required_events = {
        "events": EventSpec(event_type="my-event-type", kind="interval"),
    }
    
    @classmethod
    def get_parameters(cls):
        """Define user-adjustable parameters for the GUI."""
        return [
            {
                "name": "window_ms",
                "label": "Analysis Window (ms)",
                "type": "int",
                "default": 100,
                "min": 10,
                "max": 1000,
                "step": 10,
            }
        ]
    
    def compute(self, run: RunData, signal=None, events=None, window_ms=100, **kwargs) -> pd.DataFrame:
        """
        Calculate metrics for each event.
        
        Args:
            run: The RunData object
            signal: Tuple of (time_array, value_array)
            events: List of Event objects matching required_events
            window_ms: Analysis parameter from get_parameters()
                (accepted here but not used by the example metrics below)
        
        Returns:
            DataFrame with computed metrics
        """
        time, values = signal
        dt = np.median(np.diff(time))
        
        rows = []
        for i, event in enumerate(events):
            # Extract data within event interval
            mask = (time >= event.onset) & (time <= event.offset)
            event_values = values[mask]
            
            if len(event_values) < 2:
                continue
            
            # Calculate metrics (customize these!)
            duration = event.offset - event.onset
            mean_val = np.mean(event_values)
            std_val = np.std(event_values)
            max_val = np.max(event_values)
            min_val = np.min(event_values)
            range_val = max_val - min_val
            
            # Velocity metrics (if applicable)
            velocity = np.gradient(event_values, dt)
            peak_velocity = np.max(np.abs(velocity))
            
            rows.append({
                # Identifiers (automatically populated)
                "subject": run.subject,
                "session": run.session,
                "run": run.run,
                **run.metadata,
                
                # Event info
                "event_name": event.name,
                "event_type": event.event_type,
                "onset_s": event.onset,
                "offset_s": event.offset,
                "duration_s": duration,
                
                # Computed metrics
                "mean": mean_val,
                "std": std_val,
                "max": max_val,
                "min": min_val,
                "range": range_val,
                "peak_velocity": peak_velocity,
            })
        
        return pd.DataFrame(rows)


print("✓ Compute class defined successfully!")

## 5. Configure Bindings and Run

Map semantic roles to actual channels/events.
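
How TRACE resolves bindings internally is not shown in this notebook. As a rough illustration of the idea only, the sketch below splits a `group:channel` reference and looks it up in plain nested dictionaries. `resolve_binding` and the `signals` layout are hypothetical and are not TRACE APIs.

```python
def resolve_binding(signals, reference):
    """Split a "group:channel" reference and look up the channel's data."""
    group, sep, channel = reference.partition(":")
    if not sep or not group or not channel:
        raise ValueError(f"expected 'group:channel', got {reference!r}")
    try:
        return signals[group][channel]
    except KeyError:
        raise KeyError(f"cannot resolve {channel!r} in group {group!r}") from None


# Hypothetical project data: nested dicts standing in for signal groups.
signals = {"eye": {"x": [0.0, 0.1, 0.2], "y": [1.0, 1.1, 1.2]}}

# Bindings map a role declared by the plugin to a concrete reference.
channel_bindings = {"my_signal": "eye:x"}
resolved = {
    role: resolve_binding(signals, ref)
    for role, ref in channel_bindings.items()
}
```

The point of the indirection is that the plugin only names a role, while the project decides which recorded channel fills it.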

In [None]:
from tracetool.data.descriptors import RunConfig, Event

# Create compute instance
compute = MyCompute()
instance_name = "MyCompute"

# Configure channel bindings
# TODO: Replace with your actual group:channel reference
example_group = list(run.signals.keys())[0]
example_channel = run.signals[example_group].list_channels()[0]

# If you haven't run an annotator yet, create sample events for testing
if not run.annotations:
    # Create mock events for demonstration
    t, y = run.get_signal(example_group, example_channel)
    duration = t[-1] - t[0]
    interval = duration / 5
    
    mock_events = [
        Event(
            annotator="MockAnnotator",
            name=f"Test Event {i}",
            event_type="my-event-type",
            # Offset from the first sample so events fall inside the recording
            onset=t[0] + i * interval,
            offset=t[0] + i * interval + interval * 0.8,
            confidence=1.0,
            metadata={}
        )
        for i in range(5)
    ]
    run.annotations["mock"] = mock_events
    print(f"✓ Created {len(mock_events)} mock events for demonstration")

# Set up run configuration
run.run_config = RunConfig(
    channel_bindings={
        instance_name: {
            "my_signal": f"{example_group}:{example_channel}"
        }
    },
    event_bindings={
        instance_name: {
            "events": list(run.annotations.keys())[0]  # Use first annotation group
        }
    }
)

# Run the compute
results = compute.run(run, instance_name=instance_name, window_ms=100)
print(f"\n✓ Computed {len(results)} rows of metrics")
results.head()

## 6. Visualize Metrics

In [None]:
if len(results) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    
    # Duration distribution
    ax = axes[0, 0]
    ax.hist(results["duration_s"], bins=20, edgecolor='black', alpha=0.7)
    ax.set_xlabel("Duration (s)")
    ax.set_ylabel("Count")
    ax.set_title("Event Duration Distribution")
    
    # Mean value per event
    ax = axes[0, 1]
    ax.bar(range(len(results)), results["mean"])
    ax.set_xlabel("Event Index")
    ax.set_ylabel("Mean Value")
    ax.set_title("Mean Value per Event")
    
    # Peak velocity
    ax = axes[1, 0]
    ax.bar(range(len(results)), results["peak_velocity"])
    ax.set_xlabel("Event Index")
    ax.set_ylabel("Peak Velocity")
    ax.set_title("Peak Velocity per Event")
    
    # Range
    ax = axes[1, 1]
    ax.bar(range(len(results)), results["range"])
    ax.set_xlabel("Event Index")
    ax.set_ylabel("Signal Range")
    ax.set_title("Signal Range per Event")
    
    plt.tight_layout()
    plt.show()
else:
    print("No results to visualize")

## 7. Export with Provenance Tracking

Save results with automatic provenance metadata.
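
The exact provenance schema TRACE writes is not shown in this notebook. To make the idea concrete, here is a minimal standard-library sketch of a CSV plus JSON sidecar. The function `export_with_provenance` and every field in the `provenance` dictionary are assumptions for illustration, not the TRACE format.

```python
import csv
import hashlib
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def export_with_provenance(rows, out_dir, stem, plugin_name, plugin_version, params):
    """Write rows to CSV and a JSON sidecar describing how they were produced."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / f"{stem}.csv"
    json_path = out_dir / f"{stem}_provenance.json"

    # Assumes at least one row; the first row defines the column order.
    with csv_path.open("w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    provenance = {
        "plugin": plugin_name,
        "version": plugin_version,
        "parameters": params,
        "n_rows": len(rows),
        "created_utc": datetime.now(timezone.utc).isoformat(),
        # Hash of the CSV bytes so later edits to the file can be detected.
        "csv_sha256": hashlib.sha256(csv_path.read_bytes()).hexdigest(),
    }
    json_path.write_text(json.dumps(provenance, indent=2))
    return csv_path, json_path


rows = [{"event_name": "a", "mean": 1.5}, {"event_name": "b", "mean": 2.5}]
with tempfile.TemporaryDirectory() as tmp:
    csv_path, json_path = export_with_provenance(
        rows, tmp, "sub01_ses01_run01_MyCompute",
        "My Custom Metrics", "1.0.0", {"window_ms": 100},
    )
    saved = json.loads(json_path.read_text())
```

Recording the plugin version and parameters next to the data is what lets a later reader reproduce or audit the numbers.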

In [None]:
# Export results (uncomment to run)
# results_with_export = compute.run(
#     run, 
#     instance_name=instance_name,
#     export=True,
#     project_dir=project_path,
#     window_ms=100
# )
# 
# # Results are saved to:
# #   exports/{subject}_{session}_{run}_{instance_name}.csv
# #   exports/{subject}_{session}_{run}_{instance_name}_provenance.json

print("To export, uncomment the code above and rerun this cell.")

## 8. Next Steps: Save for Production

Once your compute module works correctly:

1. **Copy the class** to `plugins/compute/my_compute.py`
2. **Register it in `tracetool.yaml`**:
   ```yaml
   plugins:
     compute:
       - plugins/compute/my_compute.py
   ```
3. **Use it in the GUI or in pipelines**, where registered plugins are discovered automatically.

### Tips
- Add multiple `required_channels` for multi-signal analysis
- Use `required_events` with specific `event_type` to filter events
- Customize `get_parameters()` to expose tuning knobs in the GUI
- Results export includes provenance JSON for reproducibility
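
The example `compute()` in section 4 accepts `window_ms`, but none of its metrics use it. One way a window parameter could feed a metric is sketched below on synthetic data. `mean_in_window` is a hypothetical helper, not a TRACE function, and it assumes time is in seconds while the parameter is in milliseconds.

```python
import numpy as np


def mean_in_window(time, values, onset, window_ms):
    """Mean of `values` over the window [onset, onset + window_ms]."""
    window_s = window_ms / 1000.0  # parameter is in ms, time axis in seconds
    mask = (time >= onset) & (time <= onset + window_s)
    if not mask.any():
        return float("nan")  # window lies outside the recording
    return float(values[mask].mean())


# One second sampled at 1 kHz, with a step from 0 to 2 at t = 0.5 s.
time = np.linspace(0.0, 1.0, 1001)
values = np.where(time < 0.5, 0.0, 2.0)

early = mean_in_window(time, values, onset=0.1, window_ms=100)
late = mean_in_window(time, values, onset=0.6, window_ms=100)
outside = mean_in_window(time, values, onset=5.0, window_ms=100)
```

Inside `compute()`, the same idea would apply per event, for example with `onset=event.onset`, so that changing the GUI parameter changes the metric.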