# SpikeInterface v0.101.2 - Adapted by Rodrigo Noseda - October 2024

Use SpikeInterface to analyze a multichannel dataset recorded with Cambridge NeuroTech probes.
The data were acquired with an Open Ephys DAQ and Bonsai-Rx and saved as raw binary (.bin) files.
Note: the event-timestamp handling still needs work.

# 0. Preparation <a class="anchor" id="preparation"></a>

In [None]:
import spikeinterface.full as si
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import os
import csv
import glob
from datetime import datetime
import warnings
# Silence every warning for the whole session. NOTE: this also hides library
# deprecation warnings (e.g. from SpikeInterface); relax it when debugging.
warnings.simplefilter("ignore")
# IPython magic: interactive matplotlib backend (ipympl), required for the
# key-press artifact picker used later in this notebook.
%matplotlib widget
print(f"SpikeInterface Version: {si.__version__}")

# 1. Load Recording and Paths <a class="anchor" id="loading"></a>

In [None]:
# Setting file paths and basic parameters
import re  # only needed for the natural-sort key below

base_folder = Path('D:/Ephys_C2DRG/')
data_folder = base_folder / '2023_9_19'  # session folder; change the date for other sessions


def _natural_sort_key(path):
    """Return a sort key for ``path`` that compares embedded integers numerically.

    Ensures e.g. ``RawEphysData_2.bin`` sorts before ``RawEphysData_10.bin``;
    plain string sorting would not, and ``os.listdir`` has no guaranteed order.
    """
    return [int(part) if part.isdecimal() else part.lower()
            for part in re.split(r'(\d+)', path.name)]


# The files are read and concatenated in this order, so sort them deterministically.
recording_paths_list = sorted(
    (path for path in data_folder.iterdir()
     if path.name.startswith('RawEphysData') and path.name.endswith('.bin')),
    key=_natural_sort_key,
)
if not recording_paths_list:
    raise FileNotFoundError(f"No 'RawEphysData*.bin' files found in {data_folder}")
print('Recording Files List:')
print(recording_paths_list)

# parameters associated to the recording in bin format
num_channels = 64  # must be known a priori; keep any probe definition consistent with it
fs = 30000  # sampling frequency (Hz)
gain_to_uV = 0.195
offset_to_uV = 0
rec_dtype = "float32"
time_axis = 0  # 0 -> time along the first axis, i.e. data laid out as (samples, channels)
time_format = "%H:%M:%S.%f"
# os.cpu_count() may return None; clamp so n_jobs is always a positive worker count,
# leaving two cores free for the OS/notebook when the machine has enough of them.
n_cpus = os.cpu_count() or 1
n_jobs = max(1, n_cpus - 2)
job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=True)

In [None]:
# Load all recording files: read_binary returns one recording with one segment per file,
# in the order of recording_paths_list.
# `num_channels` replaces the deprecated `num_chan` keyword (its warning is hidden by the
# global warnings filter, so the old spelling would break silently on upgrade).
rec = si.read_binary(recording_paths_list, num_channels=num_channels, sampling_frequency=fs,
                     dtype=rec_dtype, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV,
                     time_axis=time_axis, is_filtered=False)
# Kilosort does not support multi-segment recordings, so merge every segment into a
# single continuous segment (ConcatenateSegmentRecording).
recordings_list = [rec]
recording = si.concatenate_recordings(recordings_list)
# Filtering: 300-6000 Hz band-pass, then a global median common reference.
recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
# Processing stages to compare side by side with plot_traces.
recording_layers = dict(raw=recording,
                        filt=recording_f,
                        common=recording_cmr)

In [None]:
# Interactive comparison of the raw / filtered / re-referenced layers for one channel.
w = si.plot_traces(
    recording_layers,
    channel_ids=[45],
    time_range=[569, 600],
    return_scaled=True,
    show_channel_ids=True,
    backend="ipywidgets",
)

# 2. Preprocessing: Artifact Removal <a class="anchor" id="preprocessing"></a>

In [None]:
# Function to calculate time difference in seconds
def time_difference_in_seconds(start_time, end_time):
    """Return ``end_time - start_time`` in seconds for two ``datetime.time`` values.

    Both times are anchored to the same (arbitrary) calendar day, so the result
    is negative whenever ``end_time`` is earlier in the day than ``start_time``
    (e.g. a session crossing midnight is not unwrapped).
    """
    same_day = datetime.min
    start_dt, end_dt = (datetime.combine(same_day, t) for t in (start_time, end_time))
    elapsed = end_dt - start_dt
    return elapsed.total_seconds()

# Load the Timestamps CSV file and extract the recording start time.
# Only the time-of-day is kept, so every offset below ignores the calendar date.
timestamp_start_df = pd.read_csv(data_folder / 'TimestampsEphys_0.csv', nrows=1, header=None, names=['Timestamps'])
timestamp_start = pd.to_datetime(timestamp_start_df['Timestamps'].iloc[0]).time()


def _concat_csv_files(file_paths, names, **read_csv_kwargs):
    """Read headerless CSV files and stack them row-wise, in the given order.

    Builds the list of frames first and concatenates once (concatenating inside
    the loop is quadratic). Returns an empty DataFrame with the requested
    ``names`` as columns when ``file_paths`` is empty, so the column operations
    below still work instead of raising KeyError.
    """
    frames = [pd.read_csv(path, header=None, names=names, **read_csv_kwargs) for path in file_paths]
    if not frames:
        return pd.DataFrame(columns=names)
    return pd.concat(frames, ignore_index=True)


def _seconds_since_start(time_column):
    """Return each ``datetime.time`` in ``time_column`` as seconds after ``timestamp_start``."""
    return time_column.apply(lambda t: time_difference_in_seconds(timestamp_start, t))


# Load TTL times from CSV files, calculate time in seconds
ttl_files = sorted(glob.glob(os.path.join(data_folder, "TTL*.csv")))
concatenated_ttl_times = _concat_csv_files(ttl_files, names=['TTL_Times'])
# Parse the timestamps and keep only the time-of-day part (datetime.time objects)
concatenated_ttl_times['TTL_Times'] = pd.to_datetime(concatenated_ttl_times['TTL_Times']).dt.time
concatenated_ttl_times['time_diff_seconds'] = _seconds_since_start(concatenated_ttl_times['TTL_Times'])
#print(concatenated_ttl_times)

# Load event (stimulus start/end) times from CSV files, calculate times in seconds
events_files = sorted(glob.glob(os.path.join(data_folder, "Events*.csv")))
concatenated_event_times = _concat_csv_files(events_files, names=['Stim_start', 'Stim_end'], usecols=[0, 1])
# Parse the timestamps and keep only the time-of-day part (datetime.time objects)
concatenated_event_times['Stim_start'] = pd.to_datetime(concatenated_event_times['Stim_start']).dt.time
concatenated_event_times['Stim_end'] = pd.to_datetime(concatenated_event_times['Stim_end']).dt.time
concatenated_event_times['time_diff_start_seconds'] = _seconds_since_start(concatenated_event_times['Stim_start'])
concatenated_event_times['time_diff_end_seconds'] = _seconds_since_start(concatenated_event_times['Stim_end'])
concatenated_event_times['stim_duration'] = (
    concatenated_event_times['time_diff_end_seconds'] - concatenated_event_times['time_diff_start_seconds']
)
#print(concatenated_event_times)

In [None]:
# Plot a short window of a few channels so artifact times can be marked by pressing a key
# over the figure (the key-press handler stores the x position, in seconds).
coordinates_x = []
# NOTE: despite their names, start_time/end_time are sample (frame) indices, not seconds.
start_time = 6311 * fs
end_time = 6312 * fs
fig, ax = plt.subplots()
trace = recording_f.get_traces(start_frame=start_time, end_frame=end_time, channel_ids=[8, 43, 44], return_scaled=True)
# Use a distinct name: assigning to `time_axis` here would overwrite the read_binary
# `time_axis` parameter and break the loading cell if it is re-run.
time_s = np.arange(start_time, end_time) / fs
ax.plot(time_s, trace)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (µV)")  # return_scaled=True -> microvolts

# Key-press handler (despite its name, it is connected to 'key_press_event').
def onclick(event):
    """Record the x position (seconds) of a key press over the axes.

    Presses outside the data area (``event.xdata is None``) are ignored.
    """
    x = event.xdata
    if x is None:
        return
    coordinates_x.append(x)  # store the key-press x coordinate
    print(f"Key pressed at: x={x}")
# Route key presses (not mouse clicks) on the figure to onclick; `cid` is the
# connection id, usable with fig.canvas.mpl_disconnect(cid).
cid = fig.canvas.mpl_connect(
    'key_press_event',
    onclick,
)
plt.show()  # display the interactive figure

In [None]:
# Write a fresh single-column CSV of artifact timestamps (seconds) collected by the key handler.
def save_to_csv(filename, float_list):
    """Write each value of ``float_list`` as its own row of ``filename``, overwriting the file."""
    rows = ([value] for value in float_list)
    with open(filename, 'w', newline='') as out_file:
        csv.writer(out_file).writerows(rows)
# NOTE(review): this writes 'artifacts_coordinates_xb.csv', while the later cells append to
# and read 'artifacts_coordinates_x.csv' — confirm whether the "b" suffix is intentional.
filename = 'artifacts_coordinates_xb.csv'
save_to_csv(data_folder / filename, float_list=coordinates_x)

In [None]:
# Add more artifact timestamps to an existing coordinates_x CSV file.
def append_to_csv(filename, float_list):
    """Append each value of ``float_list`` as its own row at the end of ``filename`` (created if missing)."""
    rows = ([value] for value in float_list)
    with open(filename, 'a', newline='') as out_file:
        csv.writer(out_file).writerows(rows)
# Add the newly picked coordinates to the main artifact-coordinates file.
filename = 'artifacts_coordinates_x.csv'
append_to_csv(data_folder / filename, float_list=coordinates_x)

In [None]:
# Create a sorted list of artifact trigger times (seconds and frames) from the coordinates_x CSV file
filename = 'artifacts_coordinates_x.csv'
with open(data_folder / filename, 'r', newline='') as csvfile:
    # Flatten every non-blank cell into one list of floats (seconds); blank cells
    # (e.g. from a trailing comma) are skipped instead of crashing float().
    triggers_in_sec = [float(item) for row in csv.reader(csvfile) for item in row if item.strip()]
# Round to microsecond precision and sort chronologically.
triggers_in_sec = sorted(round(t, 6) for t in triggers_in_sec)
# Convert seconds to the nearest sample index.
triggers_in_frames = [round(t * fs) for t in triggers_in_sec]

#Creates list of tuples with start_frame and end_frame for artifact removal.
def create_tuples(triggers_in_frames, n_frames=300):
    """Build one ``(start_frame, end_frame)`` artifact period per trigger.

    Each trigger frame marks the END of an artifact; the period starts
    ``n_frames`` samples earlier (default 300, i.e. 10 ms at 30 kHz). The start
    is clamped at 0 so a trigger within the first ``n_frames`` samples never
    yields a negative frame index.

    Parameters
    ----------
    triggers_in_frames : iterable of int
        Artifact end frames.
    n_frames : int, optional
        Length of each period in samples (default 300).

    Returns
    -------
    list of tuple(int, int)
        ``(start_frame, end_frame)`` pairs, in the same order as the input.
    """
    return [(max(frame - n_frames, 0), frame) for frame in triggers_in_frames]
# Artifact periods to silence: one (start_frame, end_frame) tuple per trigger.
list_periods = create_tuples(triggers_in_frames=triggers_in_frames)
print(list_periods)

In [None]:
# Create a sorted list of long-artifact (noise) times, in seconds and frames, from a CSV file
from operator import itemgetter
filename = 'long_artifacts_coordinates_x.csv'
with open(data_folder / filename, 'r', newline='') as csvfile:
    # Flatten every non-blank cell into one list of floats (seconds).
    noise_in_sec = [float(item) for row in csv.reader(csvfile) for item in row if item.strip()]
# Round to microsecond precision and sort chronologically.
noise_in_sec = sorted(round(t, 6) for t in noise_in_sec)

# Positions (in the sorted list) of the marks to keep.
# NOTE(review): presumably every second mark of start/end pairs — confirm intent.
a = [1, 3, 5, 7]
if max(a, default=-1) >= len(noise_in_sec):
    raise IndexError(f"'{filename}' has {len(noise_in_sec)} entries; positions {a} are required")
# Explicit indexing instead of itemgetter(*a): itemgetter returns a bare float (not a
# tuple) when `a` holds a single index, which would break the frame conversion below.
noise_in_sec = tuple(noise_in_sec[i] for i in a)
noise_in_frames = [round(t * fs) for t in noise_in_sec]
# One single-element list per selected mark.
noise1 = noise_in_frames[0:1]
noise2 = noise_in_frames[1:2]
noise3 = noise_in_frames[2:3]
noise4 = noise_in_frames[3:4]
print(noise_in_sec)

In [None]:
# Zero out every artifact period. silence_periods takes one list of (start_frame, end_frame)
# tuples per segment; recording_cmr has a single segment here (it was concatenated), and a
# flat list is passed for it — presumably accepted for single-segment recordings (confirm).
recording_clean = si.silence_periods(
    recording_cmr,
    list_periods=list_periods,
    mode='zeros',
    seed=0,
)

In [None]:
# Compare raw, re-referenced and artifact-cleaned traces over the first minute.
recording_layers2 = dict(
    raw=recording,
    common=recording_cmr,
    clean=recording_clean,
)
w = si.plot_traces(
    recording_layers2,
    channel_ids=[15, 16, 17, 42, 61],
    time_range=[0, 60],
    return_scaled=True,
    show_channel_ids=True,
    backend="ipywidgets",
)

In [None]:
# Write the cleaned recording to disk in binary format, parallelized via job_kwargs.
recording_clean.save(
    folder=data_folder / "recording_clean",
    format='binary',
    **job_kwargs,
)

In [None]:
# Create a structured array
#data = np.array([(1.5, 2.5, 'Hello'), (3.0, 4.0, 'World')], 
#                 dtype=[('col1', 'f8'), ('col2', 'f8'), ('col3', 'U100')])

# Save the array to a .npy file
#np.save('my_array.npy', data)