# SpikeInterface v0.101.2 - Adapted by Rodrigo Noseda - October 2024

This notebook uses SpikeInterface to analyze a multichannel dataset recorded with Cambridge NeuroTech probes.
The dataset was extracted using an Open Ephys DAQ and Bonsai-rx and is stored as binary (.bin) files.
Note: the handling of event timestamps still needs some work.

# 0. Preparation <a class="anchor" id="preparation"></a>

In [None]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import os
import csv
import glob
from datetime import datetime
import warnings
warnings.simplefilter("ignore")
%matplotlib widget
print(f"SpikeInterface Version: {si.__version__}")

# 1. Load Recording and Paths <a class="anchor" id="loading"></a>

In [None]:
# Setting file paths and basic parameters
base_folder = Path('D:/Ephys_C2DRG/')
data_folder = Path("D:/Ephys_C2DRG/2023_9_28/")
#Pasted directly from explorer "C:\Users\rodri\Documents\Bonsai-RN\Bonsai_DataRN\2023_3_21\"

# Collect the raw ephys files of this session. os.listdir() returns entries in
# arbitrary order, so the list is sorted explicitly: the timestamp/TTL/event
# CSV files are loaded with sorted(glob(...)) and the recording segments must
# follow the same (lexicographic) order to stay aligned with them.
recording_paths_list = sorted(
    data_folder / filename
    for filename in os.listdir(data_folder)
    if filename.startswith('RawEphysData') and filename.endswith('.bin')
)
print('Recording Files List:')
print(recording_paths_list)

# parameters associated to the recording in bin format
num_channels = 64 #must be known a priori; modify in probe below accordingly.
fs = 30000
gain_to_uV = 0.195
offset_to_uV = 0
rec_dtype = "float32"
time_axis = 0
time_format = "%H:%M:%S.%f"
# Leave two cores free for the rest of the system, but never go below one
# worker: os.cpu_count() may return None, and on machines with <= 2 cores the
# plain subtraction would give 0 or a negative job count.
n_cpus = os.cpu_count()
n_jobs = max(1, (n_cpus or 1) - 2) #n_jobs = -1 :equal to the number of cores.
job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)

In [None]:
# Read all .bin files into one recording object, then concatenate.
# `num_channels` is the supported keyword of read_binary; `num_chan` is the
# deprecated spelling and is not accepted by newer SpikeInterface releases.
rec = si.read_binary(recording_paths_list, num_channels=num_channels, sampling_frequency=fs,
                     dtype=rec_dtype, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV,
                     time_axis=time_axis, is_filtered=False)
# Kilosort does not support segments, so everything is concatenated into a
# single ConcatenateSegmentRecording.
recordings_list = [rec]
recording = si.concatenate_recordings(recordings_list)

# Filtering: 300-6000 Hz band-pass, then global median common reference.
recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
# Keep the processing stages together so they can be plotted side by side.
recording_layers = dict(raw=recording,
                        filt=recording_f,
                        common=recording_cmr)
recording_cmr

In [None]:
# Preview seconds 0-5 of channel 5 from the common-referenced recording.
preview_kwargs = dict(time_range=[0, 5], channel_ids=[5], return_scaled=True,
                      show_channel_ids=True, backend="ipywidgets")
w = si.plot_traces(recording_cmr, **preview_kwargs)

# 2. Preprocessing: Artifact Removal <a class="anchor" id="preprocessing"></a>

In [None]:
def time_difference_in_seconds(start_time, end_time):
    """Return the signed number of seconds from start_time to end_time.

    Both arguments are time-of-day values (datetime.time). They are placed on
    the same dummy date so they can be subtracted; the result is negative when
    end_time is earlier in the day than start_time.
    """
    anchor_day = datetime.min
    elapsed = datetime.combine(anchor_day, end_time) - datetime.combine(anchor_day, start_time)
    return elapsed.total_seconds()

# Reference clock: time-of-day taken from the first row of TimestampsEphys_0.csv.
timestamp_start_df = pd.read_csv(data_folder / 'TimestampsEphys_0.csv', nrows=1, header=None, names=['Timestamps'])
timestamp_start = pd.to_datetime(timestamp_start_df['Timestamps'][0]).time()

# For every Timestamps*.csv keep only its first and last row (segment start / end).
tms_files = sorted(glob.glob(os.path.join(data_folder, "Timestamps*.csv")))
per_file_bounds = []
for tms_file in tms_files:
    tms_df = pd.read_csv(tms_file, header=None)
    per_file_bounds.append(pd.DataFrame({'Start_Times': tms_df.iloc[0], 'End_Times': tms_df.iloc[-1]}))
concatenated_segment_times = pd.concat(per_file_bounds, ignore_index=True) if per_file_bounds else pd.DataFrame()

# Parse the raw stamps into datetime.time values (the date part is dropped).
for bound_col in ('Start_Times', 'End_Times'):
    concatenated_segment_times[bound_col] = pd.to_datetime(concatenated_segment_times[bound_col]).dt.time

# Express each bound as seconds elapsed since timestamp_start, then derive the duration.
for bound_col, seconds_col in (('Start_Times', 'Start_Times_seconds'), ('End_Times', 'End_Times_seconds')):
    concatenated_segment_times[seconds_col] = concatenated_segment_times[bound_col].apply(
        lambda t: time_difference_in_seconds(timestamp_start, t))
concatenated_segment_times['Segment_duration_seconds'] = (
    concatenated_segment_times['End_Times_seconds'] - concatenated_segment_times['Start_Times_seconds'])
# print(concatenated_segment_times)  # uncomment to inspect the table

In [None]:
# Gather the TTL timestamps from every TTL*.csv file into one table.
ttl_files = sorted(glob.glob(os.path.join(data_folder, "TTL*.csv")))
ttl_tables = [pd.read_csv(ttl_file, header=None, names=['TTL_Times']) for ttl_file in ttl_files]
concatenated_ttl_times = pd.concat(ttl_tables, ignore_index=True) if ttl_tables else pd.DataFrame()

# Parse into datetime.time values (date part dropped) ...
ttl_clock_times = pd.to_datetime(concatenated_ttl_times['TTL_Times']).dt.time
concatenated_ttl_times['TTL_Times'] = ttl_clock_times
# ... and express each TTL as seconds elapsed since timestamp_start.
concatenated_ttl_times['time_diff_seconds'] = concatenated_ttl_times['TTL_Times'].apply(
    lambda t: time_difference_in_seconds(timestamp_start, t))
# print(concatenated_ttl_times)  # uncomment to inspect the table

# Gather stimulus start/end stamps (first two columns) from every Events*.csv file.
events_files = sorted(glob.glob(os.path.join(data_folder, "Events*.csv")))
event_tables = [
    pd.read_csv(event_file, header=None, usecols=[0, 1], names=['Stim_start', 'Stim_end'])
    for event_file in events_files
]
concatenated_event_times = pd.concat(event_tables, ignore_index=True) if event_tables else pd.DataFrame()

# Parse both columns into datetime.time values (date part dropped).
for stim_col in ('Stim_start', 'Stim_end'):
    concatenated_event_times[stim_col] = pd.to_datetime(concatenated_event_times[stim_col]).dt.time

# Seconds elapsed since timestamp_start for each edge, then the stimulus duration.
for stim_col, seconds_col in (('Stim_start', 'time_diff_start_seconds'), ('Stim_end', 'time_diff_end_seconds')):
    concatenated_event_times[seconds_col] = concatenated_event_times[stim_col].apply(
        lambda t: time_difference_in_seconds(timestamp_start, t))
concatenated_event_times['stim_duration'] = (
    concatenated_event_times['time_diff_end_seconds'] - concatenated_event_times['time_diff_start_seconds'])
# print(concatenated_event_times)  # uncomment to inspect the table

In [None]:
# Plot a single channel trace so artifact times can be picked with key presses
# (the x coordinates are stored by the onkey handler for artifact removal).
coordinates_x_end = [] #holds the end timestamp of the short artifact (< 10ms)
coordinates_x_long = [] #holds the start and end timestamp of a long artifact (> 10ms). Usually ~1 sec
start_time = 9000 * fs  # window start, in frames
end_time = 9660 * fs    # window end, in frames
fig, ax = plt.subplots()
trace = recording_f.get_traces(start_frame=start_time, end_frame=end_time, channel_ids=[45], return_scaled=True)
# Do not name this vector `time_axis`: that name already holds the `time_axis`
# parameter passed to read_binary, and overwriting it with an array breaks the
# loading cell when it is re-run afterwards.
# The vector is sized from the returned trace so x and y always have the same
# length, even if fewer frames than requested come back.
trace_times_s = (start_time + np.arange(trace.shape[0])) / fs
ax.plot(trace_times_s, trace)

# Number of accepted 'w' presses so far; its parity tells whether the next
# press opens a new [start, end] pair (even) or closes the last one (odd).
click_count = 0

def onkey(event):
    """Record the cursor x position when 'z' or 'w' is pressed over the plot.

    'z' appends the x position to coordinates_x_end (end of a short artifact).
    'w' alternately starts a new [start, None] row in coordinates_x_long and
    fills in the end of the latest row. Presses with no x position
    (event.xdata is None) and all other keys are ignored.
    """
    global click_count
    pressed = event.key
    if pressed not in ('z', 'w'):
        return
    x_pos = event.xdata
    if x_pos is None:
        return

    if pressed == 'z':
        coordinates_x_end.append(x_pos)
        print(f"Key 'z' pressed at: x={x_pos}")
        return

    # pressed == 'w'
    if click_count % 2:
        # Second press of the pair: complete the most recent row.
        coordinates_x_long[-1][1] = x_pos
    else:
        # First press of the pair: open a new row with the end still unknown.
        coordinates_x_long.append([x_pos, None])
    print(f"Key 'w' pressed at: x={x_pos}")
    click_count += 1

# Register onkey for key-press events on the figure canvas, then show the figure.
canvas = fig.canvas
cid = canvas.mpl_connect('key_press_event', onkey)
plt.show()

In [None]:
# Build (start, end) artifact periods in seconds and append them to the CSV log.
# Short (TTL) artifacts are marked by their end only; 10 ms before it is removed.
filename = 'artifacts_coordinates_x.csv'
sec = 0.01 # 10ms
# float64 on purpose: around t = 9000 s a float32 can only resolve ~1 ms
# (about 30 samples at 30 kHz), which is too coarse for a 10 ms window that is
# written out with 6 decimals.
arr = np.asarray(coordinates_x_end, dtype=np.float64)
new_col = arr - sec
coordinates_end = np.column_stack((new_col, arr))
# reshape(-1, 2) keeps a (0, 2) shape when no long artifact was marked, so the
# stacking below does not fail on an empty list.
coordinates_long = np.array(coordinates_x_long, dtype=np.float64).reshape(-1, 2)
# An unpaired 'w' press leaves None (converted to NaN) as the end time; such a
# row is not a usable period, so it is dropped instead of being written.
coordinates_long = coordinates_long[np.isfinite(coordinates_long).all(axis=1)]
# np.vstack replaces np.row_stack, which is a deprecated alias of it.
artifacts_coordinates_x = np.vstack((coordinates_end, coordinates_long))
# Append (mode 'a') the rows, formatted with 6 decimals, creating the file if needed.
with open(data_folder / filename, 'a', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerows([f"{value:.6f}" for value in row] for row in artifacts_coordinates_x)

In [None]:
# Build the sorted list of [start_frame, end_frame] periods from the
# artifacts_coordinates_x csv file (times in seconds -> frames).
filename = 'artifacts_coordinates_x.csv'
unique_periods = set()
with open(data_folder / filename, 'r', newline='') as csvfile:
    for row in csv.reader(csvfile):
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            continue
        start_s, end_s = float(row[0]), float(row[1])
        if not (np.isfinite(start_s) and np.isfinite(end_s)):
            # Rows written from an unfinished long-artifact pair contain nan.
            continue
        # Frames are integer sample indices: round to the nearest sample
        # rather than leaving floats to be truncated downstream.
        # The set removes duplicates, which appear when the cell that appends
        # to the csv file is run more than once.
        unique_periods.add((round(start_s * fs), round(end_s * fs)))
triggers_in_frames = [list(period) for period in sorted(unique_periods)]
print(triggers_in_frames)

In [None]:
# Zero out the artifact periods in the common-referenced recording.
# list_periods holds (start_frame, end_frame) pairs, one list per segment.
silence_options = dict(list_periods=triggers_in_frames, seed=0, mode='zeros')
recording_clean = si.silence_periods(recording_cmr, **silence_options)

In [None]:
# Compare the raw, common-referenced and artifact-silenced traces on channels 8 and 45.
recording_layers2 = {"raw": recording, "common": recording_cmr, "clean": recording_clean}
w = si.plot_traces(
    recording_layers2,
    time_range=[3380, 3480],
    channel_ids=[8, 45],
    return_scaled=True,
    show_channel_ids=True,
    backend="ipywidgets",
)

In [None]:
# Write the cleaned recording to <data_folder>/recording_clean in binary format,
# replacing any previous copy; job_kwargs controls the chunked, parallel write.
clean_folder = data_folder / "recording_clean"
recording_clean.save(format='binary', folder=clean_folder, overwrite=True, **job_kwargs)