In [1]:
# OPTIMIZED FNIRS VIEWER WITH PERSISTENT ZOOM AND FIXED TRIGGER HANDLING

# 1. SETUP CELL
%matplotlib widget
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import json
import warnings
from datetime import datetime, timedelta
warnings.filterwarnings('ignore')

# Check for MNE
try:
    import mne
    from mne.io import read_raw_nirx, read_raw_snirf
    FNIRS_FORMATS_AVAILABLE = True
except ImportError:
    FNIRS_FORMATS_AVAILABLE = False

# Create main container: an Output widget that is displayed immediately.
# The viewer UI is displayed inside it further down (``with ui_container:``),
# so everything the viewer renders at construction time lands in this cell.
ui_container = widgets.Output()
display(ui_container)

# 2. MAIN APPLICATION CLASS
class fNIRSViewer:
    """Interactive fNIRS signal and trigger viewer/editor built on ipywidgets.

    Loads fNIRS data (CSV, NPY, or SNIRF/NIRx when MNE is importable) plus
    optional trigger files (TSV and/or TRI), draws the signals and trigger
    markers in two stacked matplotlib axes that share the x-axis, and lets the
    user add, update, delete, clear and save triggers.  The visible x-range is
    kept in ``current_xlim`` so zoom/pan survives the full redraws performed
    by ``plot_data``.
    """

    def __init__(self):
        self.fnirs_data = None
        self.trigger_data = None
        self.time_axis = None
        self.current_view = [0, 10]
        self.visible_channels = []
        self.selected_trigger_idx = None
        self.trigger_colors = {}
        self.fig = None
        self.ax1 = None
        self.ax2 = None
        self.zoom_mode = False
        self.unique_trigger_values = []
        self.original_tsv_data = None
        self.original_tri_data = None
        self.reference_time = None
        self.current_xlim = None  # Track current zoom level
        self.clear_button = None  # Will be created in create_widgets
        
        self.create_widgets()
        self.setup_ui()
        
    def create_widgets(self):
        """Create all UI widgets and wire their callbacks."""
        # File inputs
        self.fnirs_input = widgets.Text(
            placeholder='fNIRS file path',
            layout=widgets.Layout(width='70%'))
        self.trigger_tsv_input = widgets.Text(
            placeholder='Trigger TSV file path',
            layout=widgets.Layout(width='70%'))
        self.trigger_tri_input = widgets.Text(
            placeholder='Trigger TRI file path',
            layout=widgets.Layout(width='70%'))
        
        # Buttons
        self.load_btn = widgets.Button(
            description='Load Data',
            button_style='primary')
        self.load_btn.on_click(self.load_data)
        
        # Clear triggers button
        self.clear_button = widgets.Button(
            description='Clear Triggers',
            button_style='warning',
            disabled=True)
        self.clear_button.on_click(self.clear_triggers)
        
        # Save buttons
        self.save_tsv_btn = widgets.Button(
            description='Save TSV',
            button_style='info',
            disabled=True)
        self.save_tri_btn = widgets.Button(
            description='Save TRI',
            button_style='info',
            disabled=True)
        self.save_tsv_btn.on_click(self.save_tsv)
        self.save_tri_btn.on_click(self.save_tri)
        
        # Trigger controls
        self.trigger_combobox = widgets.Combobox(
            placeholder='Select or enter trigger value',
            options=[],
            ensure_option=False,
            description='Trigger Value:',
            style={'description_width': 'initial'})
        
        self.add_trigger_btn = widgets.Button(
            description='Add Trigger',
            button_style='success')
        self.update_trigger_btn = widgets.Button(
            description='Update Trigger',
            button_style='info',
            disabled=True)
        self.delete_trigger_btn = widgets.Button(
            description='Delete Trigger',
            button_style='danger',
            disabled=True)
        
        self.add_trigger_btn.on_click(self.add_trigger)
        self.update_trigger_btn.on_click(self.update_trigger)
        self.delete_trigger_btn.on_click(self.delete_trigger)
        
        # Navigation
        self.nav_buttons = widgets.HBox([
            widgets.Button(description='Jump to Start', layout=widgets.Layout(width='auto')),
            widgets.Button(description='Zoom In (2x)', layout=widgets.Layout(width='auto')),
            widgets.Button(description='Zoom Out (0.5x)', layout=widgets.Layout(width='auto')),
            widgets.Button(description='Jump to End', layout=widgets.Layout(width='auto')),
            widgets.Button(description='Show All', layout=widgets.Layout(width='auto')),
            widgets.ToggleButton(
                description='Box Zoom',
                value=False,
                tooltip='Toggle box zoom mode',
                layout=widgets.Layout(width='auto'))
        ])
        
        self.jump_start_btn = self.nav_buttons.children[0]
        self.zoom_in_btn = self.nav_buttons.children[1]
        self.zoom_out_btn = self.nav_buttons.children[2]
        self.jump_end_btn = self.nav_buttons.children[3]
        self.show_all_btn = self.nav_buttons.children[4]
        self.zoom_toggle = self.nav_buttons.children[5]
        
        self.jump_start_btn.on_click(self.jump_to_start)
        self.zoom_in_btn.on_click(self.zoom_in)
        self.zoom_out_btn.on_click(self.zoom_out)
        self.jump_end_btn.on_click(self.jump_to_end)
        self.show_all_btn.on_click(self.show_all)
        self.zoom_toggle.observe(self.toggle_zoom_mode, 'value')
        
        # Status
        self.status = widgets.HTML("Ready to load data")
        self.trigger_status = widgets.HTML("No trigger selected")
        self.view_status = widgets.HTML("")
        
        # Plot areas
        self.plot_output = widgets.Output()
        self.channel_box = widgets.HBox([])
    
    def update_trigger_combobox(self):
        """Refresh the combobox options from ``unique_trigger_values``.

        Options are sorted numerically (so "2" comes before "10").  The
        currently entered value is kept when it is still a valid option;
        otherwise the first option is selected.  Does nothing when there are
        no known trigger values.
        """
        if len(self.unique_trigger_values) == 0:
            return
        options = [str(v) for v in sorted(self.unique_trigger_values)]
        current = self.trigger_combobox.value
        self.trigger_combobox.options = options
        self.trigger_combobox.value = current if current in options else options[0]
    
    def calculate_initial_view(self):
        """Return the initial [start, end] x-range: 10% of the data, at most 60 s."""
        if self.time_axis is not None:
            total_duration = self.time_axis[-1] - self.time_axis[0]
            view_width = min(total_duration * 0.1, 60)
            return [self.time_axis[0], self.time_axis[0] + view_width]
        return [0, 10]
        
    def setup_ui(self):
        """Arrange the widgets into ``self.ui`` (channels listed below the plot)."""
        trigger_control_box = widgets.HBox([
            self.add_trigger_btn,
            self.update_trigger_btn,
            self.delete_trigger_btn,
            self.clear_button
        ])
        
        self.ui = widgets.VBox([
            widgets.HBox([
                widgets.VBox([
                    widgets.Label('fNIRS Data:'),
                    self.fnirs_input,
                    widgets.Label('Trigger TSV:'),
                    self.trigger_tsv_input,
                    widgets.Label('Trigger TRI:'),
                    self.trigger_tri_input,
                    self.load_btn
                ]),
                widgets.VBox([
                    self.trigger_combobox,
                    trigger_control_box,
                    widgets.HBox([
                        self.save_tsv_btn,
                        self.save_tri_btn
                    ]),
                    self.trigger_status
                ])
            ]),
            self.nav_buttons,
            self.status,
            self.view_status,
            self.plot_output,
            widgets.Label('Channels:'),
            self.channel_box
        ])
    
    def toggle_zoom_mode(self, change):
        """Turn matplotlib's box-zoom tool on or off.

        ``change`` is the traitlets change dict from the ToggleButton
        observer; ``change['new']`` is the requested state.
        """
        self.zoom_mode = change['new']
        self._apply_zoom_mode()
        self.status.value = "Box zoom enabled" if self.zoom_mode else "Normal mode"

    def _apply_zoom_mode(self):
        """Make the current figure's toolbar match ``self.zoom_mode``."""
        if self.fig is None:
            return
        toolbar = getattr(self.fig.canvas, 'toolbar', None)
        if toolbar is None:
            return
        # NavigationToolbar2.zoom() is a toggle, so only call it when the
        # toolbar state differs from the requested one.  Matplotlib reports
        # the box-zoom tool as mode 'zoom rect'.
        zoom_active = str(getattr(toolbar, 'mode', '')) == 'zoom rect'
        if zoom_active != self.zoom_mode:
            toolbar.zoom()
    
    def sync_axes(self, ax):
        """Record the x-limits of ``ax`` and mirror them onto the other axes.

        The axes are created with ``sharex=True`` so matplotlib normally
        keeps them aligned already; the copy is only done when they differ and
        uses ``emit=False`` so it cannot re-enter the ``xlim_changed``
        callback that invoked this method.
        """
        if ax is self.ax1:
            other = self.ax2
        elif ax is self.ax2:
            other = self.ax1
        else:
            return
        xlim = tuple(ax.get_xlim())
        if other is not None and tuple(other.get_xlim()) != xlim:
            other.set_xlim(xlim, emit=False)
        self.current_xlim = list(xlim)
        self.update_view_status()
    
    def update_view_status(self):
        """Show the visible x-range and its share of the full recording."""
        if self.time_axis is None or len(self.time_axis) == 0:
            return
        current_view = self.current_xlim if self.current_xlim is not None else self.current_view
        duration = self.time_axis[-1] - self.time_axis[0]
        view_width = current_view[1] - current_view[0]
        # A single-sample recording has zero duration; avoid dividing by it.
        percent_view = (view_width / duration) * 100 if duration > 0 else 100.0
        self.view_status.value = (
            f"View: {current_view[0]:.1f}s to {current_view[1]:.1f}s | "
            f"Width: {view_width:.1f}s ({percent_view:.1f}% of total)"
        )
    
    def parse_tri_timestamp(self, timestamp_str):
        """Convert an ISO-8601 TRI timestamp to seconds since ``reference_time``.

        A trailing 'Z' is stripped.  If the fractional-seconds part cannot be
        parsed, it is dropped and parsing is retried.  The first timestamp
        parsed after ``reference_time`` is reset becomes the reference and
        yields 0.0.

        Raises:
            ValueError: if the string is not an ISO-8601 timestamp.
        """
        timestamp_str = timestamp_str.strip().rstrip('Z')
        
        if '.' in timestamp_str:
            try:
                current_time = datetime.fromisoformat(timestamp_str)
            except ValueError:
                # Older Pythons only accept 3 or 6 fractional digits.
                timestamp_str = timestamp_str.split('.')[0]
                current_time = datetime.fromisoformat(timestamp_str)
        else:
            current_time = datetime.fromisoformat(timestamp_str)
        
        if getattr(self, 'reference_time', None) is None:
            self.reference_time = current_time
            return 0.0
        
        return (current_time - self.reference_time).total_seconds()
    
    def clear_triggers(self, btn=None):
        """Remove all triggers, reset the selection and redraw."""
        self.trigger_data = np.empty((0, 2))
        self.selected_trigger_idx = None
        self.update_trigger_btn.disabled = True
        self.delete_trigger_btn.disabled = True
        self.trigger_status.value = "All triggers cleared"
        self.plot_data()

    def _load_fnirs_file(self, path):
        """Read fNIRS samples from ``path`` into ``time_axis``/``fnirs_data``.

        Supported layouts: ``.csv`` (first column time, remaining columns
        channels), ``.npy`` (a saved dict with 'time'/'data' keys, or a 2-D
        array whose first column is time) and, when MNE is importable,
        ``.snirf``/``.nirs``.  Other extensions leave the current data as is.
        """
        lower = path.lower()
        if lower.endswith('.csv'):
            df = pd.read_csv(path)
            self.time_axis = df.iloc[:, 0].values
            self.fnirs_data = df.iloc[:, 1:].values
        elif lower.endswith('.npy'):
            # NOTE(review): allow_pickle=True executes pickle data; only load
            # files from trusted sources.
            data = np.load(path, allow_pickle=True)
            # np.save() stores a dict as a 0-d object array; unwrap it so the
            # dict layout is actually recognised.
            if isinstance(data, np.ndarray) and data.dtype == object and data.ndim == 0:
                data = data.item()
            if isinstance(data, dict):
                self.time_axis = np.asarray(data['time'])
                self.fnirs_data = np.asarray(data['data'])
            else:
                self.time_axis = data[:, 0]
                self.fnirs_data = data[:, 1:]
        elif FNIRS_FORMATS_AVAILABLE and lower.endswith(('.snirf', '.nirs')):
            reader = read_raw_snirf if lower.endswith('.snirf') else read_raw_nirx
            raw = reader(path)
            self.fnirs_data = raw.get_data().T
            self.time_axis = np.arange(len(self.fnirs_data)) / raw.info['sfreq']

    def _read_tsv_triggers(self, path):
        """Return [time, value] rows from a tab-separated trigger file.

        Recognises BIDS-like ``Onset``/``trial_type`` columns, then
        ``time``/``value`` columns, and otherwise uses the first two columns.
        The parsed DataFrame is kept in ``original_tsv_data`` for saving.
        """
        df = pd.read_csv(path, sep='\t')
        self.original_tsv_data = df.copy()
        if 'Onset' in df.columns and 'trial_type' in df.columns:
            return df[['Onset', 'trial_type']].values
        if 'time' in df.columns and 'value' in df.columns:
            return df[['time', 'value']].values
        return df.iloc[:, :2].values

    def _read_tri_triggers(self, path):
        """Parse a ';'-separated TRI file into an (n, 2) [seconds, value] array.

        Field 0 is the timestamp and field 2 the trigger value.  Lines with
        fewer than three fields are skipped silently, and lines that fail to
        parse are reported and skipped.  The split lines are kept in
        ``original_tri_data``.
        """
        with open(path, 'r') as f:
            lines = [line.strip().split(';') for line in f if line.strip()]
        self.original_tri_data = lines

        times, values = [], []
        for line in lines:
            if len(line) < 3:
                continue
            try:
                t = self.parse_tri_timestamp(line[0])
                val = float(line[2])
            except (ValueError, IndexError) as e:
                print(f"Skipping malformed line: {line} - {str(e)}")
                continue
            times.append(t)
            values.append(val)

        if not times:
            return np.empty((0, 2))
        return np.column_stack((times, values))
    
    def load_data(self, btn):
        """Load the fNIRS file and any trigger files named in the text inputs.

        TSV and TRI triggers are merged and sorted by time, then used to
        rebuild the colour map and combobox options.  Failures are reported in
        the status widget instead of being raised.
        """
        self.status.value = "Loading..."
        
        try:
            if self.fnirs_input.value:
                self._load_fnirs_file(self.fnirs_input.value)
            
            # Reset trigger state WITHOUT redrawing.  Redrawing here would run
            # before setup_channels() has built checkboxes for the newly
            # loaded data: AttributeError on the first load, and stale channel
            # indices on later loads.
            self.trigger_data = np.empty((0, 2))
            self.selected_trigger_idx = None
            self.update_trigger_btn.disabled = True
            self.delete_trigger_btn.disabled = True
            self.trigger_status.value = "No trigger selected"
            self.reference_time = None  # TRI times are relative to their first timestamp
            
            if self.trigger_tsv_input.value:
                self.trigger_data = self._read_tsv_triggers(self.trigger_tsv_input.value)
            
            tri_error = None
            if self.trigger_tri_input.value:
                try:
                    tri_triggers = self._read_tri_triggers(self.trigger_tri_input.value)
                except Exception as e:
                    tri_error = f"Error loading TRI file: {str(e)}"
                else:
                    if len(tri_triggers) > 0:
                        if len(self.trigger_data) > 0:
                            self.trigger_data = np.vstack([self.trigger_data, tri_triggers])
                        else:
                            self.trigger_data = tri_triggers
            
            if len(self.trigger_data) > 0:
                self.trigger_data = self.trigger_data[self.trigger_data[:, 0].argsort()]
                
                self.unique_trigger_values = np.unique(self.trigger_data[:, 1]).astype(int).tolist()
                colors = plt.cm.tab10(np.linspace(0, 1, len(self.unique_trigger_values)))
                self.trigger_colors = {val: color for val, color in zip(self.unique_trigger_values, colors)}
                
                self.update_trigger_combobox()
                self.clear_button.disabled = False
            else:
                # Offer default trigger values when no triggers were loaded
                self.unique_trigger_values = [0, 1]
                self.update_trigger_combobox()
                self.clear_button.disabled = True
            
            self.current_view = self.calculate_initial_view()
            self.current_xlim = list(self.current_view)
            
            if self.fnirs_data is not None:
                self.setup_channels()
            
            # Always plot after loading to show what was loaded
            self.plot_data()
            
            # Saving is only offered for the file types that were loaded
            self.save_tsv_btn.disabled = not bool(self.trigger_tsv_input.value)
            self.save_tri_btn.disabled = not bool(self.trigger_tri_input.value)
            
            if self.fnirs_data is not None:
                num_channels = self.fnirs_data.shape[1] if len(self.fnirs_data.shape) > 1 else 1
                message = f"Loaded {num_channels} channels and {len(self.trigger_data)} triggers"
            else:
                message = f"Loaded {len(self.trigger_data)} triggers"
            # Keep a TRI failure visible instead of overwriting it.
            self.status.value = f"{message} ({tri_error})" if tri_error else message
            
        except Exception as e:
            self.status.value = f"Error: {str(e)}"
            import traceback
            traceback.print_exc()
    
    def setup_channels(self):
        """Build one checkbox per channel, ten per row; the first five are checked."""
        self.channel_checkboxes = []
        num_channels = self.fnirs_data.shape[1] if len(self.fnirs_data.shape) > 1 else 1
        
        rows = []
        current_row = []
        
        for i in range(num_channels):
            cb = widgets.Checkbox(
                value=i < 5,
                description=f'Ch {i+1}',
                indent=False,
                layout=widgets.Layout(width='80px'))
            
            cb.observe(self.plot_data_debounced, names='value')
            self.channel_checkboxes.append(cb)
            current_row.append(cb)
            
            if (i + 1) % 10 == 0 or i == num_channels - 1:
                rows.append(widgets.HBox(current_row))
                current_row = []
        
        self.channel_box.children = rows
        
    def plot_data_debounced(self, change=None):
        """Redraw now, or defer by 100 ms if the last redraw was under 100 ms ago."""
        import time
        
        current_time = time.time()
        if hasattr(self, '_last_plot_time') and current_time - self._last_plot_time < 0.1:
            # Schedule at most one deferred redraw
            if not getattr(self, '_plot_scheduled', False):
                self._plot_scheduled = True
                import threading
                threading.Timer(0.1, self._execute_plot).start()
            return
        
        self._last_plot_time = current_time
        self._plot_scheduled = False
        self.plot_data()
    
    def _execute_plot(self):
        """Run the deferred redraw scheduled by ``plot_data_debounced``."""
        import time
        self._plot_scheduled = False
        # Record this redraw so the debounce window is measured from it.
        self._last_plot_time = time.time()
        self.plot_data()
    
    def plot_data(self):
        """Redraw the signal (top) and trigger (bottom) axes from current state.

        A new figure is created on each call; the previous one is closed so
        pyplot does not keep every old figure alive.  The x-range is restored
        from ``current_xlim`` and the box-zoom mode is re-applied.
        """
        if self.trigger_data is None:
            self.trigger_data = np.empty((0, 2))

        with self.plot_output:
            clear_output(wait=True)

            # Release the previous figure before drawing (or not drawing) again.
            if self.fig is not None:
                plt.close(self.fig)
            
            if self.fnirs_data is None and len(self.trigger_data) == 0:
                return
                
            self.fig, (self.ax1, self.ax2) = plt.subplots(
                2, 1, figsize=(12, 8), 
                gridspec_kw={'height_ratios': [3, 1]},
                sharex=True,
                constrained_layout=True)
            
            if self.fnirs_data is not None:
                checkboxes = getattr(self, 'channel_checkboxes', [])
                visible_channels = [i for i, cb in enumerate(checkboxes) if cb.value]
                
                if visible_channels:
                    # Treat 1-D data as a single-column matrix.
                    signals = self.fnirs_data if self.fnirs_data.ndim > 1 else self.fnirs_data[:, np.newaxis]
                    # Stack channels with a fixed 0.2 vertical offset.
                    for row, channel_idx in enumerate(visible_channels):
                        self.ax1.plot(self.time_axis, signals[:, channel_idx] + row * 0.2,
                                      label=f'Ch {channel_idx+1}')
                    
                    self.ax1.set_ylabel('fNIRS Data')
                    self.ax1.legend(loc='upper right', ncol=min(5, len(visible_channels)))
            elif self.time_axis is None:
                # No fNIRS data: derive a dummy time axis from the triggers
                self.time_axis = np.linspace(
                    min(self.trigger_data[:, 0]) - 1 if len(self.trigger_data) > 0 else 0,
                    max(self.trigger_data[:, 0]) + 1 if len(self.trigger_data) > 0 else 10,
                    100
                )
            
            if self.fnirs_data is not None:
                self.ax1.set_title('fNIRS Data and Triggers')
            else:
                self.ax1.set_title('Trigger Data')
            
            if len(self.trigger_data) > 0:
                # One colour per trigger value; dashed markers on the signal axes
                for val in np.unique(self.trigger_data[:, 1]):
                    times = self.trigger_data[self.trigger_data[:, 1] == val, 0]
                    color = self.trigger_colors.get(int(val), 'red')
                    for t in times:
                        self.ax1.axvline(x=t, color=color, linestyle='--', alpha=0.3)
                
                self.ax2.scatter(
                    self.trigger_data[:, 0], 
                    self.trigger_data[:, 1],
                    c=[self.trigger_colors.get(int(v), 'red') for v in self.trigger_data[:, 1]],
                    s=30, alpha=0.7, picker=5)
                
                if self.selected_trigger_idx is not None and self.selected_trigger_idx < len(self.trigger_data):
                    trigger = self.trigger_data[self.selected_trigger_idx]
                    self.ax1.axvline(x=trigger[0], color='black', linewidth=2)
                    self.ax2.scatter(
                        [trigger[0]], 
                        [trigger[1]], 
                        facecolors='none', 
                        edgecolors='black', 
                        s=100, zorder=10)
            
            # Restore the persisted zoom level (initialise it on first draw)
            if self.current_xlim is None:
                self.current_xlim = list(self.current_view)
            self.ax1.set_xlim(self.current_xlim)
            
            self.ax2.set_xlabel('Time (s)')
            self.ax2.set_ylabel('Triggers')
            self.ax2.grid(True)
            
            def on_click(event):
                """Select the trigger nearest to a click on the trigger axes."""
                if (event.inaxes is not self.ax2 or self.zoom_mode
                        or len(self.trigger_data) == 0 or event.xdata is None):
                    return
                # Weight x-distance (time) more heavily than y-distance (value)
                x_dist = (self.trigger_data[:, 0] - event.xdata) * 1.0
                y_dist = (self.trigger_data[:, 1] - event.ydata) * 0.1
                distances = np.sqrt(x_dist**2 + y_dist**2)
                
                closest = np.argmin(distances)
                if distances[closest] < 0.5:
                    self.selected_trigger_idx = closest
                    trigger = self.trigger_data[closest]
                    self.trigger_combobox.value = str(int(trigger[1]))
                    self.update_trigger_btn.disabled = False
                    self.delete_trigger_btn.disabled = False
                    self.trigger_status.value = f"Selected: t={trigger[0]:.2f}s, v={int(trigger[1])}"
                    self.plot_data()
            
            def on_xlim_changed(ax):
                # Axes 'xlim_changed' callbacks receive the Axes itself, not an event.
                self.sync_axes(ax)
            
            self.fig.canvas.mpl_connect('button_press_event', on_click)
            self.ax1.callbacks.connect('xlim_changed', on_xlim_changed)
            self.ax2.callbacks.connect('xlim_changed', on_xlim_changed)

            # The new figure's toolbar starts with no active tool.
            self._apply_zoom_mode()
            
            self.update_view_status()
            plt.show()
    
    def add_trigger(self, btn):
        """Add a trigger with the combobox value at the centre of the current view."""
        try:
            center = np.mean(self.current_xlim) if self.current_xlim is not None else np.mean(self.current_view)
            new_value = int(self.trigger_combobox.value)
            new_trigger = np.array([[center, new_value]])
            
            if self.trigger_data is None or self.trigger_data.size == 0:
                self.trigger_data = new_trigger
            else:
                self.trigger_data = np.vstack([self.trigger_data, new_trigger])
            
            self.trigger_data = self.trigger_data[self.trigger_data[:, 0].argsort()]
            
            if new_value not in self.unique_trigger_values:
                self.unique_trigger_values.append(new_value)
                self.update_trigger_combobox()
            
            if new_value not in self.trigger_colors:
                self.trigger_colors[new_value] = plt.cm.tab10(len(self.trigger_colors) % 10)
            
            # Select the trigger just added (located again after sorting)
            self.selected_trigger_idx = np.where(
                (np.abs(self.trigger_data[:, 0] - center) < 1e-6) & 
                (self.trigger_data[:, 1] == new_value)
            )[0][0]
            
            self.update_trigger_btn.disabled = False
            self.delete_trigger_btn.disabled = False
            self.clear_button.disabled = False
            
            self.plot_data()
            self.status.value = f"Added trigger {new_value} at {center:.2f}s"
            
        except ValueError:
            self.status.value = "Error: Please enter a valid integer trigger value"
    
    def update_trigger(self, btn):
        """Set the selected trigger's value to the integer in the combobox."""
        if self.selected_trigger_idx is None or self.selected_trigger_idx >= len(self.trigger_data):
            return
        try:
            new_value = int(self.trigger_combobox.value)
        except ValueError:
            self.status.value = "Error: Please enter a valid integer trigger value"
            return

        old_value = self.trigger_data[self.selected_trigger_idx, 1]
        self.trigger_data[self.selected_trigger_idx, 1] = new_value
        
        if new_value not in self.unique_trigger_values:
            self.unique_trigger_values.append(new_value)
            self.update_trigger_combobox()
        
        if new_value not in self.trigger_colors:
            self.trigger_colors[new_value] = plt.cm.tab10(len(self.trigger_colors) % 10)
        
        # Forget the old value if no trigger uses it any more
        if not np.any(self.trigger_data[:, 1] == old_value):
            self.trigger_colors.pop(old_value, None)
            if old_value in self.unique_trigger_values:
                self.unique_trigger_values.remove(old_value)
                self.update_trigger_combobox()
        
        self.plot_data()
        self.status.value = f"Updated trigger to value {new_value}"
    
    def delete_trigger(self, btn):
        """Delete the selected trigger (no-op when nothing valid is selected)."""
        idx = self.selected_trigger_idx
        if idx is None or self.trigger_data is None or idx >= len(self.trigger_data):
            return

        old_value = self.trigger_data[idx, 1]
        self.trigger_data = np.delete(self.trigger_data, idx, axis=0)
        self.selected_trigger_idx = None
        self.update_trigger_btn.disabled = True
        self.delete_trigger_btn.disabled = True
        self.trigger_status.value = "Trigger deleted"
        
        # Forget the value if no trigger uses it any more; guard the removal
        # because the value list is not guaranteed to contain it.
        if not np.any(self.trigger_data[:, 1] == old_value):
            self.trigger_colors.pop(old_value, None)
            if old_value in self.unique_trigger_values:
                self.unique_trigger_values.remove(old_value)
                self.update_trigger_combobox()
        
        self.plot_data()
    
    def save_tsv(self, btn):
        """Write triggers to the TSV path, in BIDS layout if the source was BIDS."""
        if not self.trigger_tsv_input.value:
            self.status.value = "No TSV file path specified"
            return
            
        try:
            if self.original_tsv_data is not None and 'Onset' in self.original_tsv_data.columns:
                df = pd.DataFrame({
                    'Onset': self.trigger_data[:, 0],
                    'Duration': np.full(len(self.trigger_data), 10),  # Default duration
                    'trial_type': self.trigger_data[:, 1].astype(int)
                })
            else:
                df = pd.DataFrame({
                    'time': self.trigger_data[:, 0],
                    'value': self.trigger_data[:, 1].astype(int)
                })
            
            df.to_csv(self.trigger_tsv_input.value, sep='\t', index=False)
            self.status.value = f"Saved {len(self.trigger_data)} triggers to {self.trigger_tsv_input.value}"
        except Exception as e:
            self.status.value = f"Error saving TSV: {str(e)}"
    
    def save_tri(self, btn):
        """Write triggers as 'timestamp;sample;value' lines to the TRI path.

        Times are converted back to absolute timestamps using
        ``reference_time``; if none is known, the current time is used and a
        warning is kept in the final status.  The sample number is
        approximated as ``int(seconds * 1000)``.
        """
        if not self.trigger_tri_input.value:
            self.status.value = "No TRI file path specified"
            return

        warning = ""
        if getattr(self, 'reference_time', None) is None:
            self.reference_time = datetime.now()
            warning = "Warning: Using current time as reference for TRI file"
            self.status.value = warning
            
        try:
            with open(self.trigger_tri_input.value, 'w') as f:
                for seconds, value in self.trigger_data:
                    trigger_time = self.reference_time + timedelta(seconds=seconds)
                    timestamp = trigger_time.isoformat() + 'Z'
                    sample_num = int(seconds * 1000)  # Approximate sample number
                    f.write(f"{timestamp};{sample_num};{int(value)}\n")
            
            message = f"Saved {len(self.trigger_data)} triggers to {self.trigger_tri_input.value}"
            self.status.value = f"{message} ({warning})" if warning else message
        except Exception as e:
            self.status.value = f"Error saving TRI file: {str(e)}"
    
    def zoom_in(self, btn):
        """Halve the visible width around its centre."""
        if self.ax1:
            lo, hi = self.ax1.get_xlim()
            center = (lo + hi) / 2
            new_width = (hi - lo) / 2
            self.current_xlim = [center - new_width / 2, center + new_width / 2]
            self.plot_data()
    
    def zoom_out(self, btn):
        """Double the visible width, shifting (not shrinking) at data edges."""
        if not self.ax1:
            return
        lo, hi = self.ax1.get_xlim()
        center = (lo + hi) / 2
        new_width = (hi - lo) * 2
        new_min, new_max = center - new_width / 2, center + new_width / 2
        if self.time_axis is not None and len(self.time_axis) > 0:
            data_min, data_max = self.time_axis[0], self.time_axis[-1]
            if new_width >= data_max - data_min:
                new_min, new_max = data_min, data_max
            elif new_min < data_min:
                new_min, new_max = data_min, data_min + new_width
            elif new_max > data_max:
                new_min, new_max = data_max - new_width, data_max
        self.current_xlim = [new_min, new_max]
        self.plot_data()
    
    def jump_to_start(self, btn):
        """Move the view to the start of the data, keeping its width."""
        if self.time_axis is not None:
            width = self.current_xlim[1] - self.current_xlim[0] if self.current_xlim is not None else 10
            self.current_xlim = [self.time_axis[0], self.time_axis[0] + width]
            self.plot_data()
    
    def jump_to_end(self, btn):
        """Move the view to the end of the data, keeping its width."""
        if self.time_axis is not None:
            width = self.current_xlim[1] - self.current_xlim[0] if self.current_xlim is not None else 10
            self.current_xlim = [self.time_axis[-1] - width, self.time_axis[-1]]
            self.plot_data()
    
    def show_all(self, btn):
        """Show the entire time range."""
        if self.time_axis is not None:
            self.current_xlim = [self.time_axis[0], self.time_axis[-1]]
            self.plot_data()

# 3. CREATE AND DISPLAY THE VIEWER
# Construct the viewer and display its widget tree inside ``ui_container``,
# the Output widget displayed at the top of this cell.  ``viewer`` stays
# bound at module level so it can be inspected from later cells.
with ui_container:
    viewer = fNIRSViewer()
    display(viewer.ui)

Output()