In [None]:
import os
import sys
import inspect
import numpy as np
import pandas as pd
from pprint import pprint
import matplotlib.pyplot as plt
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
import seaborn as sns
from spikeinterface import WaveformExtractor, extract_waveforms
from spikeinterface import extract_waveforms
from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import generate_multi_columns_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from IPython.display import Markdown, display

In [None]:
# Display the installed SpikeInterface version as the cell output.
# NOTE(review): this notebook imports WaveformExtractor / extract_waveforms, which presumably
# ties it to a specific SpikeInterface release line — confirm and record the expected version.
si.__version__

# Sorting Algorithm

In [None]:
# only need to run the first time, then can be commented out
# Runs `git clone` through the shell; by default git places the checkout in a folder named
# "ironclust" under the kernel's current working directory.
# NOTE(review): the path handed to set_ironclust_path in the next cell must point at this checkout.
!git clone https://github.com/flatironinstitute/ironclust

In [None]:
# Set the path to the ironclust folder so the IronClust sorter wrapper can find it.
# NOTE(review): the path is absolute and machine-specific, and mixes "\" and "/" separators;
# presumably fine on Windows, but it has to be edited on any other machine.
ss.IronClustSorter.set_ironclust_path(r'D:\MEA_Analysis\SI_files/ironclust')

# Reading Files

In [None]:
## Collect the sub-folders of the data directory whose names end in ".raw"
directory_folders = r"D:\MEA_Analysis\SI_files"

# Keep only entries that are directories AND carry the ".raw" suffix.
sub_folders = list(
    filter(
        lambda entry: os.path.isdir(os.path.join(directory_folders, entry)) and entry.endswith(".raw"),
        os.listdir(directory_folders),
    )
)

print('directory_folders: ', directory_folders)
print('sub_folders: ', sub_folders)



In [None]:

# Folder that holds the MCS .raw recordings
directory_path = r"D:\MEA_Analysis\SI_files"

# Every directory entry whose name ends in ".raw" is treated as a recording
mea_files = [entry for entry in os.listdir(directory_path) if entry.endswith('.raw')]

# Number of recordings opened per batch, and a running total of opened files
batch_size = 5
file_counter = 0

# Walk over the file list one batch at a time
for batch_start in range(0, len(mea_files), batch_size):
    batch_files = mea_files[batch_start:batch_start + batch_size]

    # The extractor list is rebuilt for every batch, so only the last batch's
    # extractors (and the last `recording`) survive the loop.
    extractors = []

    for mea_file in batch_files:
        file_path = os.path.join(directory_path, mea_file)
        recording = se.MCSRawRecordingExtractor(file_path)
        extractors.append(recording)
        file_counter += 1

        # Report channel count, duration (frames / sampling frequency) and progress
        num_channels = recording.get_num_channels()
        duration = recording.get_num_frames() / recording.get_sampling_frequency()
        print(f"File: {mea_file}, Num channels: {num_channels}, Duration: {duration} s")
        print(f"Files Read: {file_counter}")

# Plotting Signal Traces

In [None]:
# Channel ids written out row by row as an 8 x 8 grid.
# NOTE(review): '14' appears five times (the four corners plus row 1, column 4) —
# presumably a placeholder for positions without an electrode; confirm against the MEA map.
ch_list_sorted = """
14  9 11 14 15 18 20 14
 6  8 10 13 16 19 21 23
 4  5  7 12 17 22 24 25
 1  0  2  3 26 27 29 28
58 59 57 56 33 32 30 31
55 54 52 47 42 37 35 34
53 51 49 46 43 40 38 36
14 50 48 45 44 41 39 14
""".split()

# Side length of the square grid
n = int(np.sqrt(len(ch_list_sorted)))

# Arrange the flat list as an n x n grid
matrix = np.asarray(ch_list_sorted).reshape(n, n)

# One counter-clockwise quarter turn (np.rot90 defaults to k=1)
rotated_matrix = np.rot90(matrix)

# Read the rotated grid back row by row as a list of plain Python strings
rotated_list = rotated_matrix.ravel().tolist()

print(rotated_list)

In [None]:
### Plotting selected channels.
# NOTE(review): `recording` is whatever the file-loading loop assigned last, so these
# figures show the last recording that was opened — confirm that this is the intended file.

# Time window handed to the trace plots. The cell previously defined (50, 51) here but
# hard-coded (50, 70) in both plot calls; the value actually used is kept so the figures
# do not change, and it is now set in one place.
time_range = (50, 70)

channel_list = ['53', '51', '49', '46', '43', '40', '38', '36']


def _plot_channel_stack(rec, channel_ids, time_range, color=None):
    """Draw one borderless trace panel per channel, stacked vertically.

    Parameters
    ----------
    rec : recording object accepted by ``sw.plot_timeseries``
    channel_ids : list of channel ids; one subplot row is created per id
    time_range : (start, stop) window forwarded to ``sw.plot_timeseries``
    color : colour forwarded to ``sw.plot_timeseries`` (``None`` keeps its default)

    Returns
    -------
    (fig, axs) : the figure and a 1-D array holding one Axes per channel
    """
    # squeeze=False keeps a 2-D axes array even for a single channel, so indexing
    # never fails when channel_ids has length 1.
    fig, axs_grid = plt.subplots(len(channel_ids), 1,
                                 figsize=(40, 1 * len(channel_ids)),
                                 squeeze=False)
    axs = axs_grid[:, 0]

    for ax, channel_id in zip(axs, channel_ids):
        sw.plot_timeseries(rec,
                           channel_ids=[channel_id],
                           time_range=time_range,
                           mode='auto',
                           return_scaled=True,
                           cmap='RdBu',
                           show_channel_ids=True,
                           color_groups=True,
                           color=color,
                           clim=None,
                           tile_size=1500,
                           seconds_per_row=0.2,
                           with_colorbar=True,
                           add_legend=None,
                           backend=None,
                           ax=ax)

        # Fixed y-range so every channel is drawn on the same scale; adjust as needed
        ax.set_ylim([-50, 50])

        # Large tick labels so the channel ids stay readable on the wide figure
        ax.tick_params(axis='both', labelsize=40)

        # No x ticks / x label, and no spines (borders) around each panel
        ax.set_xticks([])
        ax.set_xlabel('')
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    # Adjust the layout, remove vertical spacing between panels and display the plot
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0)
    plt.show()
    return fig, axs


#################### Raw traces ############################################################
fig, axs = _plot_channel_stack(recording, channel_list, time_range, color=None)

#################### Filtered traces #######################################################
# High-pass filtered copy of the recording (band=200, Butterworth, second-order sections)
rec_f = spre.filter(recording, band=200, btype="highpass", filter_mode="sos", ftype='butter')

fig, axs = _plot_channel_stack(rec_f, channel_list, time_range, color='blue')


## Setting Probe

In [None]:
# Load the probe description from a probeinterface JSON file (path is relative to the
# working directory) and plot it with contact ids and device indices shown.
probe = read_probeinterface('MCS_60channel_200_30_updated_contact_ids-11-88.json')
plot_probe_group(probe,with_contact_id=True, with_device_index=True)

# Contact ids are the names of the electrodes as indicated on the MCS map and in MC_Rack.
# Device channel ids are the indices of the streams according to the wiring.

# Two-column lookup table: contact id -> device channel index
probe_df=probe.to_dataframe(complete=True).loc[:, ["contact_ids", "device_channel_indices"]]
display(probe_df)

In [None]:
# Contact ids are the names of the electrodes as indicated on the MCS map and in MC_Rack.
# Device channel ids are the indices of the streams according to the wiring.
# Rebuilds the contact-id / device-index table and adds an "electrode_names" column.
probe_df=probe.to_dataframe(complete=True).loc[:, ["contact_ids", "device_channel_indices"]]
# NOTE(review): electrode_names is not defined in any cell visible in this notebook, so this
# line raises NameError on a fresh kernel unless it is defined elsewhere — confirm its source.
probe_df["electrode_names"]=electrode_names
display(probe_df)

In [None]:
# Channels that will be excluded from sorting due to high noise levels.
# Each inner list holds contact ids (electrode names as they appear in probe_df["contact_ids"]).
# NOTE(review): this definition used to sit inside a triple-quoted string, so the loop below
# raised NameError on a fresh kernel; the values are the ones from that disabled block —
# confirm they are still the intended ones, and what each inner list stands for.
exc_channels = [
    [42, 84],
    [46],
]

# Translate every contact id into its device channel index, keeping the list-of-lists shape.
exc_channels_ind = []
for contact_group in exc_channels:
    indices = []
    for contact_id in contact_group:
        # Positions of the rows whose contact id matches (contact ids are compared as strings)
        matches = np.where(probe_df["contact_ids"] == str(contact_id))[0]
        if len(matches) > 0:
            # Take the first match and look it up by position, not by index label
            indices.append(probe_df["device_channel_indices"].iloc[matches[0]])
        else:
            # Report unknown contact ids instead of failing on an empty match
            print(f"No match found for {contact_id}")
    exc_channels_ind.append(indices)

print(exc_channels_ind)

In [None]:
# Include channels: contact ids (electrode names as they appear in probe_df["contact_ids"]).
# NOTE(review): both definitions below used to sit inside a triple-quoted string, so the cell
# raised NameError on a fresh kernel and kept appending to a stale list on re-runs; the values
# are the ones from that disabled block — confirm they are still the intended ones.
inc_channels = [16, 17, 25, 26, 27, 28,
                35, 36, 37, 38, 45, 46, 47, 48
                ]
# Start from an empty list on every run so re-running the cell does not accumulate duplicates
inc_channels_ind = []

## This part of code changed by Rola
for i in inc_channels:
    condition = probe_df["contact_ids"] == str(i)
    indices = np.where(condition)[0]  # array of matching row positions
    if len(indices) > 0:
        # If there are multiple matches, this takes the first. Adjust accordingly.
        index = indices[0]
        inc_channels_ind.append(probe_df["device_channel_indices"].iloc[index])
    else:
        # No row carries this contact id: report it and carry on
        print(f"No match found for {i}")

# Show the device indices in lookup order, then sorted in place
print(inc_channels_ind)
inc_channels_ind.sort()
print(inc_channels_ind)

# Optional

In [None]:
# options for parallel processing:
# n_workers is passed as n_jobs, and chunk_memory as chunk_memory, to the save() and
# extract_waveforms() calls in the processing cells below.
n_workers = 20
chunk_memory = "1000M"

# options for automatic curation/quality control
snr_thresh = 5
isi_viol_thresh = 0.2
# Disabled combined query; isi_viol_thresh is only referenced from this commented-out line.
#query = f"snr > {snr_thresh} & isi_violations_rate < {isi_viol_thresh}"

#digital data for stimulation timestamps
# NOTE(review): these three values are not read by any cell visible in this notebook —
# presumably consumed by stimulation-handling code elsewhere; confirm or remove.
n_digital_bits = 16
input_stim_bit = 4 
output_stim_bit = 5

In [None]:
# Default job settings applied by SpikeInterface to chunked / parallel operations.
# n_jobs reuses n_workers from the options cell (same value as the previous literal 20),
# so the worker count is configured in one place only.
global_job_kwargs = dict(n_jobs=n_workers, chunk_duration="1s")
si.set_global_job_kwargs(**global_job_kwargs)

In [None]:
# Display the recording extractor classes exposed by spikeinterface.extractors
# (inspection only; nothing below depends on this cell).
se.recording_extractor_full_list

In [None]:
################### PROCESSING, SORTING and CURATION #################################################


# Load every recording, attach the probe, high-pass filter, apply a global median
# reference and save the result to a "preprocessed_<file>" folder.

for mea_file in mea_files:
    file_path = os.path.join(directory_path, mea_file)
    # Open the recording through its full path. Passing the bare file name (as before)
    # only worked when the working directory happened to be directory_path.
    recording = se.MCSRawRecordingExtractor(file_path)
    print(recording)
    # Mark the raw data as unfiltered
    recording.annotate(is_filtered=False)
    channel_ids = recording.get_channel_ids()
    fs = recording.get_sampling_frequency()
    num_chan = recording.get_num_channels()
    num_segments = recording.get_num_segments()
    print(f'Channel ids: {channel_ids}')
    print(f'Sampling frequency: {fs}')
    print(f'Number of channels: {num_chan}')
    print(f"Number of segments: {num_segments}")
    # Attach the probe loaded in the "Setting Probe" section
    recording_prb = recording.set_probes(probe)
    # High-pass filter (band=100, Butterworth, second-order sections)
    recording_f = spre.filter(recording_prb, band=100, btype="highpass", filter_mode="sos", ftype='butter')
    # Re-reference against the median across all channels
    recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')
    # Written relative to the current working directory
    rec_preprocessed = recording_cmr.save(folder="preprocessed_" + str(mea_file), progress_bar=True, n_jobs=n_workers, chunk_memory=chunk_memory)

# Run IronClust on every preprocessed recording and store the resulting sorting.
for mea_file in mea_files:
    print(mea_file)

    # Reload the preprocessed recording saved by the previous stage
    recording_loaded = si.load_extractor(f"preprocessed_{mea_file}/")
    print (recording_loaded)

    # Working folder used by the sorter for this recording
    output_folder = f"results_IC_{mea_file}"
    print(output_folder)

    # NOTE: despite its name, recording_IC holds the object returned by run_sorter
    recording_IC = ss.run_sorter(
        'ironclust',
        recording_loaded,
        output_folder=output_folder,
        detect_threshold=5,
        verbose=True,
    )

    # Persist the sorter output so later stages can reload it
    recording_IC_saved = recording_IC.save(folder=f"sorting_IC_{mea_file}")


# Automatic curation: keep only the units whose SNR exceeds snr_thresh (set in the options
# cell). The ISI-violation criterion is not applied.
for mea_file in mea_files:
    recording_loaded = si.load_extractor("preprocessed_" + str(mea_file) + "/")
    sorting_loaded = si.load_extractor('sorting_IC_' + str(mea_file) + "/")
    # Waveforms go to "wf_<file>"; overwrite=True replaces any existing folder of that name
    recording_we = si.extract_waveforms(recording_loaded, sorting_loaded, folder="wf_" + str(mea_file), progress_bar=True, n_jobs=n_workers, chunk_memory=chunk_memory, overwrite=True)
    # Quality metrics table, indexed by unit id (uses the sqm alias imported at the top)
    recording_qc = sqm.compute_quality_metrics(recording_we)
    # Threshold comes from the shared constant instead of a hard-coded literal
    keep_units_recording = recording_qc.query(f"snr > {snr_thresh}")
    keep_unit_ids_recording = keep_units_recording.index.values
    recording_sorting_autocur = sorting_loaded.select_units(keep_unit_ids_recording)
    recording_sorting_autocur.save(folder="autocurated_sorting_IC_" + str(mea_file))

# Export the spike train of every curated unit into one .csv file per recording.
for mea_file in mea_files:
    sorting = si.load_extractor('autocurated_sorting_IC_' + str(mea_file) + "/")
    curated_ids = sorting.get_unit_ids()
    n = len(curated_ids)

    # One entry per unit: key = unit id as a string, value = that unit's spike train
    # (values are exactly what get_unit_spike_train returns — presumably sample indices; confirm)
    data = {str(uid): sorting.get_unit_spike_train(uid) for uid in curated_ids}

    # Build with units as rows (tolerates trains of different length), then flip so each
    # unit becomes a column
    df = pd.DataFrame.from_dict(data, orient='index').transpose()
    df.to_csv('autocurated_timestamps_'+ str(mea_file) +'.csv')




# Extract unit ids and estimated unit positions into one csv file per recording.
for mea_file in mea_files:
    # Load the preprocessed recording and the autocurated sorting that belong to THIS file.
    # Previously the loaded recording was discarded and the waveforms were extracted from
    # recording_loaded / sorting_loaded left over from the curation loop (i.e. the last
    # file's recording and the un-curated sorting), so positions did not match the unit ids.
    sorting = si.load_extractor('autocurated_sorting_IC_' + str(mea_file) + "/")
    recording_loaded = si.load_extractor("preprocessed_" + str(mea_file) + "/")
    # overwrite=True replaces any existing "wf_<file>" waveform folder
    recording_we = si.extract_waveforms(recording_loaded, sorting, folder="wf_" + str(mea_file), progress_bar=True, n_jobs=n_workers, chunk_memory=chunk_memory, overwrite=True)

    # Extract unit IDs and positions (one location row per unit)
    unit_ids = sorting.get_unit_ids()
    unit_loc = spost.compute_unit_locations(recording_we)

    # Prepare data for DataFrame
    data = {'unit_id': [], 'x': [], 'y': [], 'z': []}
    for unit_id, position in zip(unit_ids, unit_loc):
        data['unit_id'].append(unit_id)
        data['x'].append(position[0])
        data['y'].append(position[1])
        # Some localisation methods return only (x, y); store NaN rather than raising
        # IndexError when no third coordinate is present.
        data['z'].append(position[2] if len(position) > 2 else np.nan)

    # Create DataFrame
    df = pd.DataFrame(data)

    # Export to CSV (written next to the raw recordings in directory_path)
    df.to_csv(os.path.join(directory_path, f"unit_positions_{mea_file}.csv"), index=False)

In [None]:
# Extract unit ids and estimated unit positions into one csv file per recording,
# printing intermediate values along the way.
for mea_file in mea_files:
    # Load preprocessed recording and autocurated sorting

    recording_loaded = si.load_extractor(str("preprocessed_" + str(mea_file) + "/"))
    sorting_autocur_loaded = si.load_extractor('autocurated_sorting_IC_' + str(mea_file) + "/")
    # overwrite=True replaces any existing "wf_<file>" waveform folder
    recording_we = si.extract_waveforms(recording_loaded, sorting_autocur_loaded, folder= "wf_" + str(mea_file), progress_bar=True,n_jobs=n_workers, chunk_memory=chunk_memory, overwrite=True)

    # Extract unit IDs and positions
    unit_ids = sorting_autocur_loaded.get_unit_ids()
    print(str(mea_file), unit_ids, 'number of units:', len(unit_ids))
    unit_loc=spost.compute_unit_locations(recording_we)
    print(str(mea_file), unit_loc, 'loc:', len(unit_loc))


    # Prepare data for DataFrame
    data = {'unit_id': [], 'x': [], 'y': [], 'z': []}
    print(r'data', data)
    for unit_id, position in zip(unit_ids, unit_loc):
        data['unit_id'].append(unit_id)
        data['x'].append(position[0])
        data['y'].append(position[1])
        # Some localisation methods return only (x, y); store NaN rather than raising
        # IndexError when no third coordinate is present.
        data['z'].append(position[2] if len(position) > 2 else np.nan)

    # Report the last unit id; skipped when the sorting has no units, where the loop
    # variable would otherwise be undefined or stale from a previous file.
    if len(unit_ids) > 0:
        print(r'unit_id', unit_ids[-1])
    # Create DataFrame
    df = pd.DataFrame(data)

    print(r'df', df, 'number of units:', len(df))

    # Export to CSV (written next to the raw recordings in directory_path)
    df.to_csv(os.path.join(directory_path, f"unit_positions_{mea_file}.csv"), index=False)