In [None]:
import mne
import pandas as pd
import numpy as np
# from scipy.signal import butter, filtfilt # No longer needed with simplified approach
import os
from datetime import datetime, timezone

# --- Configuration ---
# Every path below is resolved against the kernel's current working directory,
# so that directory must contain the `sample_data/` folder.
WORKSPACE_ROOT = os.getcwd()

# Recording to replay; files live at sample_data/<EXPERIMENT_ID>/<EXPERIMENT_ID>.{edf,csv}.
SAMPLE_DATA_FOLDER = "sample_data"
EXPERIMENT_ID = "E23B6B24FX14_1743611361000"

# Both files share a directory and stem and differ only in extension.
edf_file_path, csv_file_path = (
    os.path.join(WORKSPACE_ROOT, SAMPLE_DATA_FOLDER, EXPERIMENT_ID, f"{EXPERIMENT_ID}{extension}")
    for extension in (".edf", ".csv")
)

print(f"Using workspace root: {WORKSPACE_ROOT}")

# Placeholder; overwritten with the EDF's sampling rate after the file is loaded.
FS_ORIGINAL = 0

# --- Load sleep stages ---
# Only the 'Sleep stage' column is read; the rest of the CSV is not needed here.
print(f"Loading sleep stages from: {csv_file_path}")
stages_df = pd.read_csv(csv_file_path, usecols=["Sleep stage"])

# --- Load the EDF recording (preload=True reads all samples into memory) ---
print(f"Loading EDF data from: {edf_file_path}")
raw = mne.io.read_raw_edf(edf_file_path, verbose="INFO", preload=True)
FS_ORIGINAL = int(raw.info["sfreq"])  # integer rate; any fractional part is truncated

# session_start_time_utc is still useful for metadata, even if not for stage alignment.
def _infer_session_start_from_filename(path):
    """Parse a ``<ID>_<epoch_ms>.<ext>`` filename into an aware UTC datetime.

    The extension is stripped before parsing: ``int("1743611361000.edf")`` would raise
    ``ValueError``, which previously made this fallback always fall through to "now".

    Raises:
        IndexError: the stem contains no underscore.
        ValueError / OverflowError / OSError: the suffix is not a usable epoch-ms timestamp.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    timestamp_ms = int(stem.rsplit('_', 1)[1])  # last '_'-separated field is the timestamp
    return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc)


if raw.info['meas_date']:
    session_start_time_utc = raw.info['meas_date']
else:
    try:
        session_start_time_utc = _infer_session_start_from_filename(edf_file_path)
        print(f"Inferred session start time from filename: {session_start_time_utc}")
    except (IndexError, ValueError, OverflowError, OSError) as e:
        print(f"Could not infer session start time from filename ({e}), setting to now (UTC).")
        session_start_time_utc = datetime.now(timezone.utc)

# Normalize to an aware UTC datetime so .isoformat() downstream is unambiguous.
if session_start_time_utc.tzinfo is None:
    print("Warning: session_start_time_utc was naive, localizing to UTC.")
    session_start_time_utc = session_start_time_utc.replace(tzinfo=timezone.utc)
elif session_start_time_utc.tzinfo != timezone.utc:
    print(f"Warning: session_start_time_utc was {session_start_time_utc.tzinfo}, converting to UTC.")
    session_start_time_utc = session_start_time_utc.astimezone(timezone.utc)

print(f"Session start time (UTC): {session_start_time_utc}")
print(f"Original sampling frequency: {FS_ORIGINAL} Hz")

# --- Prepare EEG Data ---
# Every channel in the EDF is used as-is; nothing is filtered by channel type here.
EEG_CHANNELS_TO_USE = raw.ch_names
print(f"Using available EEG channels from EDF: {EEG_CHANNELS_TO_USE}")
raw_eeg = raw.copy()  # work on a copy so `raw` keeps its full channel set
raw_eeg.pick_channels(EEG_CHANNELS_TO_USE, ordered=False)  # operates in place on the copy
eeg_data = raw_eeg.get_data()  # (n_channels, n_times) array

# EOG is intentionally not loaded for this simulation.
eog_data = None
print("EOG data processing and derivation has been removed for this simulation.")

# --- Calculate total duration and sleep stage interval ---
total_samples_eeg = eeg_data.shape[1]
simulation_duration_seconds = total_samples_eeg / FS_ORIGINAL
print(f"Total samples in EEG data: {total_samples_eeg}")
print(f"Calculated simulation duration: {simulation_duration_seconds:.2f} seconds")

# Stage rows are treated as evenly spaced over the whole recording, so each row
# covers duration / row_count seconds.
n_sleep_stages_entries = len(stages_df)
if n_sleep_stages_entries == 0:
    # No scored stages: a single default row spans the entire recording.
    sleep_stage_interval_seconds = simulation_duration_seconds
    print("Warning: No sleep stage entries found in CSV. Simulation will use a single stage or default.")
    stages_df = pd.DataFrame({'Sleep stage': [0]})  # Default to Wake
    n_sleep_stages_entries = 1
    print("Defaulting to a single 'Wake' stage for the entire duration.")
else:
    sleep_stage_interval_seconds = simulation_duration_seconds / n_sleep_stages_entries
    print(f"Number of sleep stage entries in CSV: {n_sleep_stages_entries}")
    print(f"Calculated sleep stage interval: {sleep_stage_interval_seconds:.2f} seconds per stage entry.")

print("Data loading and initial preparation complete.")
print(f"EEG data shape: {eeg_data.shape}")
print("EOG data: Not used in this simulation.")
print(f"Sampling Frequency (FS_ORIGINAL): {FS_ORIGINAL} Hz")

In [None]:
import numpy as np
import pandas as pd
from datetime import datetime, timezone
import scipy.signal
import sys # Added sys import
import os # Added os import

# Put the workspace root first on sys.path so the project-local packages imported
# below (dl_alertness_detection, app) resolve from the notebook. This must run
# before those imports.

if WORKSPACE_ROOT not in sys.path:
    sys.path.insert(0, WORKSPACE_ROOT)

from dl_alertness_detection import predict_alertness_ema # Real alertness model called by the simulation loop
from app.main import needed_len # Size (in samples) of the EEG buffer passed to predict_alertness_ema

# --- Simulation Constants ---
# Signal processing
FS_TARGET = 125  # Hz; EEG is resampled to this rate before processing
SECONDS_PER_WINDOW = 1  # the simulation advances one window (second) per step

# REM audio-cue policy
REM_SLEEP_STAGE_VALUE = 3  # 'Sleep stage' value treated as REM
MAX_SUCCESSIVE_REM_CUES = 2  # cues per sequence, and max sequences per REM bout
REM_AUDIO_CUE_INTERVAL_SECONDS = 10  # minimum spacing between REM cue sequences

# Alertness tuning (currently unused in this notebook; the imported model applies its own EMA)
ALERTNESS_THRESHOLD_FOR_ACTION = 0.6  # example threshold
ALERTNESS_EMA_ALPHA = 0.1

# --- Helper Functions ---
def get_sleep_stage_at_time(current_sim_time_seconds, session_start_iso, stages_df, sim_stage_interval_seconds, sim_total_stages_count):
    """
    Look up the sleep stage active at a simulation timestamp using fixed-interval indexing.

    Stage row ``k`` covers ``[k * interval, (k + 1) * interval)`` seconds from the start
    of the simulation.

    Args:
        current_sim_time_seconds: Seconds since simulation start.
        session_start_iso: Session start as an ISO string. Not used by the lookup; kept
            for interface compatibility.
        stages_df: DataFrame with a 'Sleep stage' column, one row per interval.
        sim_stage_interval_seconds: Seconds covered by each stage row; must be > 0.
        sim_total_stages_count: Number of stage rows to use. Clamped to ``len(stages_df)``
            so a mismatched count cannot index past the end of the frame.

    Returns:
        The row's 'Sleep stage' value. Returns the last row's value when the time is past
        the covered range. Returns 0 when the DataFrame is None/empty or the time is negative.

    Raises:
        ValueError: if ``sim_stage_interval_seconds`` is not positive.
    """
    if stages_df is None or stages_df.empty:
        print("Warning: Sleep stages DataFrame is empty or None.")
        return 0  # Default to awake or unknown if no data

    # A zero interval would raise ZeroDivisionError; a negative one yields meaningless indices.
    if sim_stage_interval_seconds <= 0:
        raise ValueError(f"sim_stage_interval_seconds must be positive, got {sim_stage_interval_seconds}")

    stage_values = stages_df['Sleep stage']
    stage_count = min(sim_total_stages_count, len(stage_values))
    stage_index = int(current_sim_time_seconds // sim_stage_interval_seconds)

    if 0 <= stage_index < stage_count:
        return stage_values.iloc[stage_index]
    if stage_index >= stage_count:
        # Past the covered range: hold the last known stage.
        print(f"Warning: current_sim_time_seconds {current_sim_time_seconds} exceeds stage data coverage. Using last stage.")
        return stage_values.iloc[-1]
    # Only reachable for negative times.
    print(f"Warning: Calculated stage_index {stage_index} is out of bounds. Defaulting to 0.")
    return 0

# Report the buffer length the imported alertness model expects.

print(f"Constants and helper functions defined. needed_len: {needed_len}")

In [None]:
# Simulated audio-cue state, kept in module globals so the simulation and results cells can read it.
sim_metadata_audio_cue_timestamps = []
last_audio_cue_time = -float('inf')

def fire_rem_audio_cues_sequence_sim(current_time_seconds):
    """Simulate one REM audio-cue sequence and record the time each cue would play.

    The sequence has ``MAX_SUCCESSIVE_REM_CUES`` cues spaced ``REM_AUDIO_CUE_INTERVAL_SECONDS``
    apart, starting at ``current_time_seconds``. A cue is skipped if it would start less than
    ``REM_AUDIO_CUE_INTERVAL_SECONDS`` after the most recently recorded cue of any sequence.
    That time is tracked in the global ``last_audio_cue_time``.

    Side effects:
        Appends each recorded cue time to the global ``sim_metadata_audio_cue_timestamps``
        and advances ``last_audio_cue_time``.

    Returns:
        True if at least one cue was recorded; False if every cue was skipped.
    """
    global last_audio_cue_time, sim_metadata_audio_cue_timestamps

    print(f"SIM AUDIO: Attempting to fire REM audio cue sequence at {current_time_seconds:.2f}s")
    cues_recorded = 0
    for i in range(MAX_SUCCESSIVE_REM_CUES):
        cue_initiation_time = current_time_seconds + (i * REM_AUDIO_CUE_INTERVAL_SECONDS)
        # Guard against overlapping with a cue already recorded by an earlier sequence.
        if cue_initiation_time < last_audio_cue_time + REM_AUDIO_CUE_INTERVAL_SECONDS:
            print(f"SIM AUDIO: Cue {i+1} at {cue_initiation_time:.2f}s would be too soon. Skipping.")
            continue

        # Only the timestamp at which the cue would play is recorded; no time elapses here.
        sim_metadata_audio_cue_timestamps.append(cue_initiation_time)
        print(f"SIM AUDIO: Cue {i+1}/{MAX_SUCCESSIVE_REM_CUES} scheduled/fired at {cue_initiation_time:.2f}s (Simulated)")
        last_audio_cue_time = cue_initiation_time
        cues_recorded += 1

    # Report honestly: previously this returned True even when every cue was skipped.
    return cues_recorded > 0

print("Simulated audio cue function defined.")

In [None]:
from scipy.signal import resample_poly
import math 
import numpy as np
import pandas as pd
from datetime import datetime, timezone, timedelta
import scipy.signal
import time # Added time import for verbosity

# This cell should be run after the previous cells defining constants, helpers, and importing predict_alertness_ema

def real_time_processing_simulation(
    eeg_data_original, 
    fs_original, 
    fs_target, 
    stages_df_sim, 
    simulation_duration_seconds_sim, 
    session_start_iso_sim, # This is the correct parameter name
    sim_stage_interval_seconds, # Interval for sleep stage updates
    sim_total_stages_count, # Total number of stage entries
    eeg_channels_to_use_sim # List of EEG channel names used
):
    """
    Replay a pre-recorded session window by window, as if it were arriving in real time.

    For each window of ``SECONDS_PER_WINDOW`` seconds the function:
    - looks up the scored sleep stage;
    - pushes the window's EEG into a rolling buffer of ``needed_len`` samples;
    - scores alertness with ``predict_alertness_ema`` once the buffer holds only real data;
    - fires simulated REM audio cues, subject to a minimum interval and a per-REM-bout limit.

    Args:
        eeg_data_original: EEG array, 1-D (samples) or 2-D (channels, samples). Only the
            first channel is used.
        fs_original: Sampling rate of ``eeg_data_original`` in Hz.
        fs_target: Rate (Hz) the selected channel is resampled to before processing.
        stages_df_sim: DataFrame with a 'Sleep stage' column.
        simulation_duration_seconds_sim: Recording length in seconds.
        session_start_iso_sim: Session start as an ISO-format string (metadata only).
        sim_stage_interval_seconds: Seconds covered by each stage row.
        sim_total_stages_count: Number of stage rows.
        eeg_channels_to_use_sim: Channel names, stored in the results for provenance.

    Returns:
        dict containing:
        - per-window ``timestamps``, ``sleep_stages`` and ``alertness_scores`` arrays;
        - ``audio_cue_timestamps``;
        - the run configuration.

    Side effects:
        Clears the global ``sim_metadata_audio_cue_timestamps``. Resets the global
        ``last_audio_cue_time`` read by ``fire_rem_audio_cues_sequence_sim``, so a re-run
        does not inherit cue state from a previous run.
    """
    global sim_metadata_audio_cue_timestamps, last_audio_cue_time

    print(f"Simulation starting with duration: {simulation_duration_seconds_sim}s, Target FS: {fs_target}Hz")
    print(f"EEG channels used in simulation: {eeg_channels_to_use_sim}") # Expecting one channel for current model

    # --- Select one EEG channel ---
    # The alertness model takes a single trace. Promote 1-D input to (1, samples) and keep
    # row 0. Indexing (rather than squeeze()) stays 1-D even for a single-sample recording.
    if eeg_data_original.ndim == 1:
        eeg_data_original = eeg_data_original.reshape(1, -1)
    if eeg_data_original.shape[0] > 1:
        print(f"Warning: Original EEG data has {eeg_data_original.shape[0]} channels. Using the first channel for alertness prediction.")
    eeg_single_channel_original = eeg_data_original[0, :]

    # --- Resample to the processing rate ---
    num_original_samples = len(eeg_single_channel_original)
    num_target_samples = int(num_original_samples * fs_target / fs_original)
    eeg_resampled = scipy.signal.resample(eeg_single_channel_original, num_target_samples)
    print(f"EEG data resampled from {num_original_samples} to {num_target_samples} samples.")

    # --- Per-run output accumulators ---
    sim_timestamps = []
    sim_sleep_stages = []
    sim_alertness_scores = []

    # --- Reset shared audio-cue state ---
    if 'sim_metadata_audio_cue_timestamps' not in globals():
        sim_metadata_audio_cue_timestamps = []
    else:
        sim_metadata_audio_cue_timestamps.clear()
    # fire_rem_audio_cues_sequence_sim skips any cue starting within REM_AUDIO_CUE_INTERVAL_SECONDS
    # of this global. Without a reset, a second run compares against the previous run's last
    # cue and silently drops cues.
    last_audio_cue_time = -float('inf')

    # --- Loop state ---
    last_rem_audio_cue_time = -float('inf')
    successive_rem_cues_fired = 0
    current_alertness_ema_sim = 0.5  # Initial value

    eeg_buffer_alertness = np.zeros(needed_len)  # zero-padded until enough real samples arrive
    samples_pushed = 0  # real samples written into the buffer so far

    samples_per_window = int(SECONDS_PER_WINDOW * fs_target)
    # Report progress roughly every 10 s of simulated time, whatever the window length.
    progress_every_windows = max(1, round(10 / SECONDS_PER_WINDOW))

    total_windows = int(simulation_duration_seconds_sim // SECONDS_PER_WINDOW)
    print(f"Total windows to process: {total_windows}")

    for i in range(total_windows):
        current_sim_time_seconds = i * SECONDS_PER_WINDOW

        # 1. Sleep stage scored for this moment.
        current_sleep_stage = get_sleep_stage_at_time(
            current_sim_time_seconds,
            session_start_iso_sim,
            stages_df_sim,
            sim_stage_interval_seconds,
            sim_total_stages_count,
        )

        # 2. EEG samples for this window.
        start_sample_idx = int(current_sim_time_seconds * fs_target)
        end_sample_idx = start_sample_idx + samples_per_window
        if end_sample_idx > len(eeg_resampled):
            print(f"Reached end of resampled EEG data at {current_sim_time_seconds}s. Stopping simulation.")
            break
        current_eeg_window = eeg_resampled[start_sample_idx:end_sample_idx]

        # 3. Slide the new window into the right end of the fixed-length buffer.
        eeg_buffer_alertness = np.roll(eeg_buffer_alertness, -len(current_eeg_window))
        eeg_buffer_alertness[-len(current_eeg_window):] = current_eeg_window
        samples_pushed += len(current_eeg_window)

        # 4. Score alertness only once no zero padding remains. The count includes the window
        #    just added, so scoring starts on the first window that fills the buffer.
        if samples_pushed >= needed_len:
            current_alertness_ema_sim = predict_alertness_ema(eeg_buffer_alertness, ema_span=20)
        else:
            current_alertness_ema_sim = 0.5  # placeholder until the buffer is full

        # 5. REM audio-cue policy: enforce a minimum interval and a per-REM-bout limit.
        # NOTE(review): one sequence schedules cues up to
        # (MAX_SUCCESSIVE_REM_CUES - 1) * REM_AUDIO_CUE_INTERVAL_SECONDS ahead. The next
        # sequence may therefore start before those cues end, and its first cue is then
        # skipped inside fire_rem_audio_cues_sequence_sim.
        if current_sleep_stage == REM_SLEEP_STAGE_VALUE:
            interval_elapsed = (current_sim_time_seconds - last_rem_audio_cue_time) >= REM_AUDIO_CUE_INTERVAL_SECONDS
            under_successive_limit = successive_rem_cues_fired < MAX_SUCCESSIVE_REM_CUES
            if interval_elapsed and under_successive_limit:
                print(f"  Firing REM audio cue sequence at {current_sim_time_seconds:.2f}s. Successive count: {successive_rem_cues_fired + 1}")
                fire_rem_audio_cues_sequence_sim(current_sim_time_seconds)  # records to the global list
                last_rem_audio_cue_time = current_sim_time_seconds
                successive_rem_cues_fired += 1
        else:
            # Leaving REM resets the per-bout limit.
            successive_rem_cues_fired = 0

        # Store metadata for this window.
        sim_timestamps.append(current_sim_time_seconds)
        sim_sleep_stages.append(current_sleep_stage)
        sim_alertness_scores.append(current_alertness_ema_sim)

        if (i + 1) % progress_every_windows == 0:
            print(f"Progress: {current_sim_time_seconds:.0f}s / {simulation_duration_seconds_sim:.0f}s. Stage: {current_sleep_stage}, Alertness: {current_alertness_ema_sim:.3f}")

    print("Simulation loop finished.")

    results = {
        "timestamps": np.array(sim_timestamps),
        "sleep_stages": np.array(sim_sleep_stages),
        "alertness_scores": np.array(sim_alertness_scores),
        "audio_cue_timestamps": np.array(sim_metadata_audio_cue_timestamps),
        "fs_target": fs_target,
        "simulation_duration_seconds": simulation_duration_seconds_sim,
        "session_start_iso": session_start_iso_sim,
        "sleep_stage_interval_simulated": sim_stage_interval_seconds,
        "eeg_channels_used": eeg_channels_to_use_sim
    }
    return results

print("Simulation loop function `real_time_processing_simulation` defined and ready.")
# Usage (executed in the next cell), with globals prepared by the earlier cells:
# sim_results = real_time_processing_simulation(
#     eeg_data, FS_ORIGINAL, FS_TARGET, stages_df,
#     simulation_duration_seconds, session_start_time_utc.isoformat(),
#     sleep_stage_interval_seconds, n_sleep_stages_entries, EEG_CHANNELS_TO_USE,
# )

In [None]:
# --- Execute the Simulation ---
# Requires the data-loading, constants/helpers, audio-cue and simulation-function cells above.

print("Starting simulation execution...")

# Reset all simulated audio-cue state so re-running this cell matches a fresh kernel.
# `last_audio_cue_time` matters too: fire_rem_audio_cues_sequence_sim skips any cue within
# REM_AUDIO_CUE_INTERVAL_SECONDS of it, so a stale value from a previous run would
# silently suppress cues.
if 'sim_metadata_audio_cue_timestamps' not in globals():
    sim_metadata_audio_cue_timestamps = []
else:
    sim_metadata_audio_cue_timestamps.clear()
last_audio_cue_time = -float('inf')

sim_results = real_time_processing_simulation(
    eeg_data,
    FS_ORIGINAL,
    FS_TARGET,
    stages_df,
    simulation_duration_seconds,
    session_start_time_utc.isoformat(),  # the function expects an ISO string, not a datetime
    sleep_stage_interval_seconds,
    n_sleep_stages_entries,
    EEG_CHANNELS_TO_USE,
)

print("Simulation finished. Results dictionary created.")
print(f"Number of audio cues triggered: {len(sim_results['audio_cue_timestamps'])}")
print(f"First few alertness scores: {sim_results['alertness_scores'][:10]}")
print(f"Simulated sleep stage interval: {sim_results['sleep_stage_interval_simulated']}")
print(f"Simulated sleep stage interval: {sim_results['sleep_stage_interval_simulated']}")

In [None]:
# --- Save Simulation Results ---
import numpy as np
import os

output_filename = "simulation_results.npz"
output_path = os.path.join(WORKSPACE_ROOT, output_filename)  # written to the workspace root

# Persist every entry of the results dict as a named array in one compressed archive.
if 'sim_results' not in globals():
    print("'sim_results' not found. Please run the simulation cell first.")
else:
    try:
        np.savez_compressed(output_path, **sim_results)
    except Exception as e:
        print(f"Error saving simulation results: {e}")
    else:
        print(f"Simulation results successfully saved to: {output_path}")

In [None]:
# --- Plot Simulation Results ---
import numpy as np
import matplotlib.pyplot as plt
import os

# Fall back to the kernel's working directory (the same default the configuration cell uses)
# rather than a machine-specific absolute path.
if 'WORKSPACE_ROOT' not in globals():
    WORKSPACE_ROOT = os.getcwd()

results_filename = "simulation_results.npz"
results_path = os.path.join(WORKSPACE_ROOT, results_filename)

if not os.path.exists(results_path):
    print(f"Results file not found at {results_path}. Please run the saving cell first.")
else:
    try:
        # The context manager closes the .npz archive once the arrays are read. The arrays
        # used here are plain numeric/string arrays, so pickled object arrays are refused.
        with np.load(results_path, allow_pickle=False) as data:
            print("Loaded data keys:", list(data.keys()))
            timestamps = data['timestamps']
            sleep_stages = data['sleep_stages']
            alertness_scores = data['alertness_scores']
            audio_cue_timestamps = data['audio_cue_timestamps']
        print(f"Audio cue timestamps ({len(audio_cue_timestamps)}): {audio_cue_timestamps}")

        fig, ax1 = plt.subplots(figsize=(15, 7))

        # Sleep stages on the left y-axis.
        color = 'tab:blue'
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Sleep Stage', color=color)
        ax1.plot(timestamps, sleep_stages, color=color, linestyle='-', marker='.', label='Sleep Stage')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.set_yticks(np.unique(sleep_stages))  # one tick per stage value present
        ax1.grid(True)

        # Alertness on a right-hand y-axis sharing the time axis.
        ax2 = ax1.twinx()
        color = 'tab:red'
        ax2.set_ylabel('Alertness Score (EMA)', color=color)
        ax2.plot(timestamps, alertness_scores, color=color, linestyle='--', label='Alertness Score')
        ax2.tick_params(axis='y', labelcolor=color)
        ax2.set_ylim(0, 1)  # assumes alertness scores lie in [0, 1]

        # One vertical line per cue. Only the first is labelled so the legend lists it once;
        # labels starting with '_' are excluded from legends.
        for cue_index, cue_time in enumerate(audio_cue_timestamps):
            ax1.axvline(x=cue_time, color='tab:green', linestyle=':', linewidth=2,
                        label='Audio Cue (REM)' if cue_index == 0 else '_nolegend_')

        ax1.set_title('Simulation Results: Sleep Stages, Alertness, and Audio Cues')

        # Merge both axes' entries into a single legend.
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='upper right')

        fig.tight_layout()  # keeps the right-hand y-label from being clipped
        plt.show()

    except Exception as e:
        print(f"Error plotting results: {e}")
        import traceback
        traceback.print_exc()