# In [None]:
# %%
import spike.spike_analysis.spike_collection as sc
import spike.spike_analysis.spike_recording as sr
import spike.spike_analysis.firing_rate_calculations as fr
import spike.spike_analysis.normalization as norm
import spike.spike_analysis.single_cell as single_cell
import spike.spike_analysis.spike_collection as collection
import spike.spike_analysis.zscoring as zscoring
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import behavior.boris_extraction as boris
import matplotlib.pyplot as plt
import pickle
import re

# %% [markdown]
# ## Formatting for Pandas

# %%
# Make notebook DataFrame output show every row/column/cell in full instead of
# pandas' truncated view. (The earlier `max_colwidth=0` call was removed: it
# was overwritten a few lines later by `None`, and its comment was misleading.)

# Show all rows
pd.set_option("display.max_rows", None)

# Show all columns
pd.set_option("display.max_columns", None)

# Don't truncate column contents (None = no limit on cell width)
pd.set_option("display.max_colwidth", None)

# Intended to expand output to the full screen width.
# NOTE(review): confirm how your pandas version treats width=0 (None is the
# documented value for terminal auto-detection).
pd.set_option("display.width", 0)


# %%
# Absolute, machine-specific location of the saved SpikeCollection JSON;
# edit this for your own checkout.
spike_collection_json_path = (
    r'C:\Users\thoma\Code\ResearchCode\diff_fam_social_memory_ephys'
    r'\spike_collection.json\spike_collection.json'
)

# %% [markdown]
# ## Reloading the two main packages (spike_collection and zscoring)

# %%
import importlib

# Re-execute these project modules so edits to their source files take effect
# without restarting the kernel. importlib.reload only re-runs the named
# module itself, not the modules it imports.
for _module in (sc, zscoring):
    importlib.reload(_module)

# %% [markdown]
# ### Loading the SpikeCollection object with its recordings and their corresponding event dicts

# %%
# Rebuild the SpikeCollection from the JSON file at spike_collection_json_path.
sp = sc.SpikeCollection.load_collection(spike_collection_json_path)

# %%
# Event names of the first recording are simply the keys of its event dict
# (keys are unique by construction, so no de-duplication is needed).
rec_events = sp.recordings[0].event_dict
event_names = [*rec_events.keys()]
print("Unique event names:", event_names)

# %% [markdown]
# #### Unique event names: ['alone_rewarded', 'alone_rewarded_baseline', 'high_comp', 'high_comp_lose', 'high_comp_lose_baseline', 'high_comp_win', 'high_comp_win_baseline', 'lose', 'low_comp', 'low_comp_lose', 'low_comp_lose_baseline', 'low_comp_win', 'low_comp_win_baseline', 'overall_pretone', 'win']
# 

# %% [markdown]
# ### Verifying that spike times are stored as sample timestamps, not milliseconds

# %%
# Pick any one recording and unit
recording = sp.recordings[0]  # or choose a specific one
unit_id = list(recording.unit_timestamps.keys())[0]  # get the first available good unit

# Extract the raw spike timestamps
raw_spikes = recording.unit_timestamps[unit_id]

# Show the first few spikes
print(f"Raw spike timestamps for unit {unit_id}:")
print(raw_spikes[:10])

# Convert to milliseconds
converted_spikes_ms = raw_spikes * (1000 / recording.sampling_rate)
print("\nConverted to milliseconds:")
print(converted_spikes_ms[:10])

# Also print min/max to check range
print(f"\nMin raw spike: {raw_spikes.min()} | Max raw spike: {raw_spikes.max()}")
print(f"Min spike time in ms: {converted_spikes_ms.min():.2f} ms | Max: {converted_spikes_ms.max():.2f} ms")


# %% [markdown]
# ### Z-Score for an event using baselines of all events in a recording

# %%
import numpy as np
import pandas as pd

def _good_units(recording, verbose=False):
    """Return the unit ids to analyse for *recording*.

    Uses ``recording.good_units`` when that attribute exists and is not None;
    otherwise falls back to every unit labelled ``"good"`` in
    ``recording.labels_dict``.
    """
    units = getattr(recording, "good_units", None)
    if units is None:
        units = [unit_id for unit_id, label in recording.labels_dict.items() if label == "good"]
        if verbose:
            print("Using labels_dict to determine good units.")
            print(f"Good units found: {units}\n\nFrom labels_dict: {recording.labels_dict}\n\n")
    return units


def _window_bounds(event_windows):
    """Return float arrays of window starts (element 0) and ends (element 1)."""
    starts = np.array([window[0] for window in event_windows], dtype=float)
    ends = np.array([window[1] for window in event_windows], dtype=float)
    return starts, ends


def _counts_in_windows(sorted_spikes, starts, ends):
    """Count spikes in each half-open window ``[start, end)``.

    ``sorted_spikes`` must be sorted ascending. #(s < end) - #(s < start) equals
    the number of spikes with start <= s < end.
    """
    return (np.searchsorted(sorted_spikes, ends, side="left")
            - np.searchsorted(sorted_spikes, starts, side="left"))


def run_zscore_global_baseline(recording, event_name, pre_window=10, SD=1.65, verbose=False):
    """
    Z-score event firing rates against a baseline pooled over all event types.

    For every unit, the baseline pool holds one ``pre_window``-second window
    ending at each *unique* event onset in ``recording.event_dict`` (all event
    types together). An onset shared by several event types contributes a
    single baseline window, so overlapping event types do not duplicate
    samples. Spike counts are divided by window duration, so event windows of
    any length are compared with the baseline in the same unit (Hz).

    Windows are half-open ``[start, end)`` in milliseconds. Spike timestamps
    are converted to ms with ``1000 / recording.sampling_rate``.
    NOTE(review): baseline windows that begin before the recording starts are
    not clipped; confirm this is acceptable for early events.

    Parameters:
    - recording: SpikeRecording-like object exposing ``event_dict``,
      ``unit_timestamps``, ``sampling_rate``, ``name`` and either
      ``good_units`` or ``labels_dict``.
    - event_name: key of ``recording.event_dict`` to score.
    - pre_window: baseline length in seconds (must be at least 1 ms).
    - SD: |z| threshold for labelling a unit "increase" / "decrease".
    - verbose: if True, print progress information.

    Returns:
    - pandas DataFrame with one row per unit and the columns "Recording",
      "Event name", "Unit number", "Global Pre-event M" (Hz),
      "Global Pre-event SD" (Hz, population SD with ddof=0), "Event M" (Hz),
      "Event Z-Score" (NaN when the baseline SD is 0 or no usable windows
      exist) and "sig" ("increase", "decrease" or "not sig").

    Raises:
    - KeyError: if ``event_name`` is not a key of ``recording.event_dict``.
    - ValueError: if ``pre_window`` is shorter than 1 ms.
    """
    baseline_ms = int(pre_window * 1000)
    if baseline_ms <= 0:
        raise ValueError(f"pre_window must be at least 0.001 s, got {pre_window!r}")
    baseline_s = baseline_ms / 1000

    units = _good_units(recording, verbose)
    ms_per_sample = 1000 / recording.sampling_rate

    # Step 1: collect every unique event onset across all event types; each
    # onset defines one baseline window [onset - pre_window, onset).
    onsets = set()
    for ev_type, event_windows in recording.event_dict.items():
        ev_starts, _ = _window_bounds(event_windows)
        onsets.update(ev_starts.tolist())
        if verbose:
            print(f"Event {ev_type}: {len(ev_starts)} windows; "
                  f"{len(onsets)} unique baseline windows pooled so far")
    baseline_ends = np.array(sorted(onsets), dtype=float)
    baseline_starts = baseline_ends - baseline_ms

    # Step 2: target-event windows. Only windows with a positive duration can
    # produce a firing rate.
    event_starts, event_ends = _window_bounds(recording.event_dict[event_name])
    event_dur_s = (event_ends - event_starts) / 1000
    usable = event_dur_s > 0
    event_starts, event_ends, event_dur_s = event_starts[usable], event_ends[usable], event_dur_s[usable]

    # Step 3: per-unit baseline statistics, event rate, z-score and label.
    rows = []
    for unit_id in units:
        # Convert once per unit (samples -> ms) and sort for binary searching.
        spikes_ms = np.sort(np.ravel(np.asarray(recording.unit_timestamps[unit_id], dtype=float)) * ms_per_sample)

        baseline_rates = _counts_in_windows(spikes_ms, baseline_starts, baseline_ends) / baseline_s
        event_rates = _counts_in_windows(spikes_ms, event_starts, event_ends) / event_dur_s

        # Guard empty pools explicitly instead of letting np.mean warn on [].
        b_mean = float(np.mean(baseline_rates)) if baseline_rates.size else np.nan
        b_sd = float(np.std(baseline_rates)) if baseline_rates.size else np.nan
        ev_mean = float(np.mean(event_rates)) if event_rates.size else np.nan

        if np.isnan(b_sd) or b_sd == 0 or np.isnan(ev_mean):
            zscore = np.nan
        else:
            zscore = (ev_mean - b_mean) / b_sd

        if np.isnan(zscore):
            sig = "not sig"
        elif zscore > SD:
            sig = "increase"
        elif zscore < -SD:
            sig = "decrease"
        else:
            sig = "not sig"

        if verbose:
            print(f"Unit {unit_id}: baseline {b_mean:.3f} +/- {b_sd:.3f} Hz, "
                  f"event {ev_mean:.3f} Hz, z = {zscore:.3f} ({sig})")

        rows.append({
            "Recording": recording.name,
            "Event name": event_name,
            "Unit number": unit_id,
            "Global Pre-event M": b_mean,
            "Global Pre-event SD": b_sd,
            "Event M": ev_mean,
            "Event Z-Score": zscore,
            "sig": sig,
        })

    # Explicit columns keep the schema stable even when there are no units.
    columns = ["Recording", "Event name", "Unit number", "Global Pre-event M",
               "Global Pre-event SD", "Event M", "Event Z-Score", "sig"]
    return pd.DataFrame(rows, columns=columns)


# %%
# Pooled-baseline z-scores for one event type in the first recording.
rec = sp.recordings[0]
event_name = "alone_rewarded"  # any key of rec.event_dict can be used here

df = run_zscore_global_baseline(
    rec,
    event_name,
    pre_window=10,
    SD=1.65,
)
df.head(20)