In [None]:
# Some features of this code may need updating (always check the installed SpikeInterface version and updates in the libraries).
# Purpose: extract individual neuronal unit data from raw data recorded with standard multielectrode arrays (MEAs; here Multi Channel Systems (MCS) 8x8-electrode arrays).
# Before using SpikeInterface, the MC_Rack data must be converted to binary (.raw) format with the Multi Channel Systems Data Tool.
import os
import sys
import inspect
import numpy as np
import pandas as pd
from pprint import pprint
import matplotlib.pyplot as plt
import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
import seaborn as sns
from spikeinterface import WaveformExtractor, extract_waveforms
from spikeinterface import extract_waveforms
from probeinterface import Probe, ProbeGroup
from probeinterface.plotting import plot_probe, plot_probe_group
from probeinterface import generate_multi_columns_probe
from probeinterface import write_probeinterface, read_probeinterface
from probeinterface import write_prb, read_prb
from IPython.display import Markdown, display

In [None]:
# Display the installed SpikeInterface version: several calls in this notebook depend on it.
si.__version__

# Sorting Algorithm

In [None]:
# Clone the IronClust sorter repository into the current working directory.
# Only needs to be run the first time; afterwards it can be commented out.
!git clone https://github.com/flatironinstitute/ironclust

In [None]:
# Tell SpikeInterface where the local IronClust checkout lives (edit this path for your machine).
ironclust_path = r'D:\......./ironclust'
ss.IronClustSorter.set_ironclust_path(ironclust_path)

# Reading Files

In [None]:
## Only needed when the recordings are stored in sub-folders.
# Root folder holding those sub-folders (edit for your machine).
directory_folders = r"D:\........"
# To list the sub-folders whose names end in ".raw", uncomment:
# sub_folders = [
#     name for name in os.listdir(directory_folders)
#     if os.path.isdir(os.path.join(directory_folders, name)) and name.endswith(".raw")
# ]
# print('directory_folders: ', directory_folders)
# print('sub_folders: ', sub_folders)

In [None]:
# Folder containing the MCS .raw recordings (edit for your machine).
directory_path = r"D:\....."

# Every later step iterates over `mea_files`.
# - Sorting makes the processing order reproducible (os.listdir order is arbitrary).
# - Only regular files are kept, because sub-folders may also be named "*.raw".
# - The extension check is case-insensitive, so ".RAW" files are not silently skipped.
mea_files = sorted(
    file for file in os.listdir(directory_path)
    if file.lower().endswith('.raw') and os.path.isfile(os.path.join(directory_path, file))
)
if not mea_files:
    print(f"No .raw files found in {directory_path}")

# Sanity check: open each recording and report its channel count and duration,
# `batch_size` files at a time.
batch_size = 5
file_counter = 0
for i in range(0, len(mea_files), batch_size):
    batch_files = mea_files[i:i + batch_size]
    for mea_file in batch_files:
        file_path = os.path.join(directory_path, mea_file)
        recording = se.MCSRawRecordingExtractor(file_path)
        file_counter += 1
        duration_s = recording.get_num_frames() / recording.get_sampling_frequency()
        print(f"File: {mea_file}, Num channels: {recording.get_num_channels()}, Duration: {duration_s} s")
        print(f"Files Read: {file_counter}")

## Setting Probe

In [None]:
# Load the MEA layout from its probeinterface JSON description and plot it with labels.
probe = read_probeinterface('MCS_60channel_200_30_updated_contact_ids-11-88.json')
plot_probe_group(probe, with_contact_id=True, with_device_index=True)

# Lookup table used by the channel-selection cells:
#  - contact_ids: electrode names as indicated on the MCS map and in MC_Rack
#  - device_channel_indices: index of the corresponding data stream, according to the wiring
probe_columns = ["contact_ids", "device_channel_indices"]
probe_df = probe.to_dataframe(complete=True)[probe_columns]
display(probe_df)

In [None]:
# Exclude channels that show high noise levels.
# Channels are given by their MCS contact ids (electrode names) and grouped as nested
# lists; empty groups are allowed.
# NOTE(review): the meaning of each group is not visible here (per recording? per batch?) — TODO confirm.
# Example of an alternative selection:
# exc_channels = [
#     [42, 84],
#     [46],
# ]
exc_channels = [
    [15], [21], [31], [41], [51], [61], [71], [12], [22], [32], [42], [52], [62], []
]

# Build the contact-id -> device-channel-index map once.
# The previous per-channel np.where search had three problems:
# - it relied on a deprecated array-to-scalar conversion,
# - it crashed on unknown ids,
# - it used a positional index as a label.
contact_to_device = {
    str(contact): int(device)
    for contact, device in zip(probe_df["contact_ids"], probe_df["device_channel_indices"])
}

exc_channels_ind = []
for channel_group in exc_channels:
    indices = []
    for contact_id in channel_group:
        key = str(contact_id)
        if key in contact_to_device:
            indices.append(contact_to_device[key])
        else:
            # Report and skip unknown electrodes instead of aborting the cell.
            print(f"No match found for {contact_id}")
    exc_channels_ind.append(indices)

print(exc_channels_ind)

In [None]:
# Include selected channels in sorting (MCS contact ids / electrode names).

inc_channels = [16, 17, 25, 26, 27, 28,
                35, 36, 37, 38, 45, 46, 47, 48,
                55, 56, 57, 58, 65, 66, 67, 68,
                75, 76, 77, 78, 85, 86, 87]

# Contact-id -> device-channel-index lookup built from the probe table.
contact_to_device = {
    str(contact): int(device)
    for contact, device in zip(probe_df["contact_ids"], probe_df["device_channel_indices"])
}

# Initialised here so the cell works on a fresh kernel (previously NameError) and does
# not accumulate duplicate indices when re-run.
inc_channels_ind = []
for contact_id in inc_channels:
    key = str(contact_id)
    if key in contact_to_device:
        inc_channels_ind.append(contact_to_device[key])
    else:
        print(f"No match found for {contact_id}")

inc_channels_ind.sort()
print(inc_channels_ind)

# Optional

In [None]:
# --- Parallel processing (passed as n_jobs / chunk_memory to save and waveform extraction) ---
n_workers, chunk_memory = 20, "1000M"

# --- Automatic curation / quality control thresholds ---
snr_thresh, isi_viol_thresh = 4.5, 0.2
# Combined query, currently unused:
# query = f"snr > {snr_thresh} & isi_violations_rate < {isi_viol_thresh}"

In [None]:
# Default job settings for SpikeInterface's parallelised functions.
# n_jobs reuses `n_workers` from the options cell, so the worker count is configured
# in one place instead of being duplicated as a literal.
global_job_kwargs = dict(n_jobs=n_workers, chunk_duration="1s")
si.set_global_job_kwargs(**global_job_kwargs)

# Sorting in Steps

In [None]:
###### Sorting in Steps ##### 

### Step 01 Preprocessing

In [None]:
# Step 01 Preprocessing
# For every recording:
# - attach the probe layout,
# - apply a 100 Hz Butterworth high-pass filter (SOS mode),
# - apply a global median common reference,
# - save the result to "preprocessed_<file>".
# NOTE(review): inc_channels_ind / exc_channels_ind are computed in the channel-selection
# cells but never applied here, so all channels are preprocessed and sorted. If a channel
# selection is intended, restrict the recording before filtering — TODO confirm.
# NOTE(review): saving may fail on a re-run if "preprocessed_<file>" already exists.

# List to keep track of files that had errors
error_files_preprocessing = []
for mea_file in mea_files:
    try:
        print(f"Processing file: {mea_file}")

        file_path = os.path.join(directory_path, mea_file)
        recording = se.MCSRawRecordingExtractor(file_path)
        print(recording)
        recording.annotate(is_filtered=False)

        # Set probes and preprocess
        recording_prb = recording.set_probes(probe)
        recording_f = spre.filter(recording_prb, band=100, btype="highpass", filter_mode="sos", ftype='butter')
        recording_cmr = spre.common_reference(recording_f, reference='global', operator='median')

        # Save preprocessed recording (parallelised with the options-cell settings)
        rec_preprocessed = recording_cmr.save(folder=f"preprocessed_{mea_file}", progress_bar=False,
                                              n_jobs=n_workers, chunk_memory=chunk_memory)
        print(f"Preprocessing and saving completed for file: {mea_file}")

    except Exception as e:
        # Best effort: record the failure and continue with the next file.
        print(f"An error occurred while processing {mea_file}: {e}")
        error_files_preprocessing.append(mea_file)

# list the files that had errors
if error_files_preprocessing:
    print("\nThe following files had errors and were not processed successfully:")
    for error_file in error_files_preprocessing:
        print(error_file)
else:
    print("\nAll files were processed successfully.")

### Step 02: Spike sorting and saving results

In [None]:
# Step 02 Results and Sorting
# Run IronClust on each preprocessed recording (written by Step 01) and save the
# sorting to "sorting_IC_<file>". Files that raise are collected in `error_files`
# and listed at the end.

error_files = []
for mea_file in mea_files:
    print(f"Processing file: {mea_file}")
    try:
        recording_loaded = si.load_extractor(f"preprocessed_{mea_file}/")
        # Working/output folder for IronClust itself
        output_folder = f"results_IC_{mea_file}"
        # run_sorter returns a sorting object (spike trains), despite the variable name
        recording_IC = ss.run_sorter(
            'ironclust',
            recording_loaded,
            output_folder=output_folder,
            detect_threshold=5,
            verbose=True,
        )
        recording_IC_saved = recording_IC.save(folder=f"sorting_IC_{mea_file}")
    except Exception as e:
        print(f"An error occurred while processing {mea_file}: {e}")
        error_files.append(mea_file)
    else:
        print(f"Saved sorted data to: sorting_IC_{mea_file}")

# Summary of failed files
if not error_files:
    print("\nAll files were processed successfully.")
else:
    print("\nThe following files had errors and were not processed successfully:")
    print("\n".join(error_files))

### Step 03: Autocuration and export of unit spike timestamps

In [None]:
# Step 03 Autocuration and export the .CSV of sorted timestamps
# Automatic curation keeps only units whose SNR exceeds `snr_thresh` (options cell).
# NOTE(review): `isi_viol_thresh` is defined in the options cell, but the ISI-violation
# criterion is currently NOT applied. Extend `curation_query` if it should be.
curation_query = f"snr > {snr_thresh}"

# List the files that had errors
error_files_autocuration = []
error_files_csv = []

for mea_file in mea_files:
    try:
        print(f"Processing autocuration for file: {mea_file}")
        recording_loaded = si.load_extractor(f"preprocessed_{mea_file}/")
        sorting_loaded = si.load_extractor(f"sorting_IC_{mea_file}/")
        recording_we = si.extract_waveforms(recording_loaded, sorting_loaded, folder=f"wf_{mea_file}",
                                            progress_bar=False, n_jobs=n_workers,
                                            chunk_memory=chunk_memory, overwrite=True)
        recording_qc = sqm.compute_quality_metrics(recording_we)
        keep_units_recording = recording_qc.query(curation_query)
        keep_unit_ids_recording = keep_units_recording.index.values
        recording_sorting_autocur = sorting_loaded.select_units(keep_unit_ids_recording)
        recording_sorting_autocur.save(folder=f"autocurated_sorting_IC_{mea_file}")

    except Exception as e:
        print(f"An error occurred during autocuration for {mea_file}: {e}")
        error_files_autocuration.append(mea_file)

# Export the spike times of each curated recording into a .csv file.
# Layout: one column per unit (named by unit id), one row per spike.
# Values come from get_unit_spike_train, i.e. sample (frame) indices by default, not seconds.
# Units have different spike counts, so shorter columns are padded. The nullable "Int64"
# dtype keeps the indices as integers; NaN padding would turn them into floats such as 1234.0.
for mea_file in mea_files:
    try:
        print(f"Exporting timestamps to CSV for file: {mea_file}")
        sorting = si.load_extractor(f"autocurated_sorting_IC_{mea_file}/")
        spike_trains = {
            str(unit_id): pd.Series(sorting.get_unit_spike_train(unit_id), dtype="Int64")
            for unit_id in sorting.get_unit_ids()
        }
        df = pd.DataFrame(spike_trains)
        df.to_csv(f'autocurated_timestamps_{mea_file}.csv')

    except Exception as e:
        print(f"An error occurred while exporting timestamps for {mea_file}: {e}")
        error_files_csv.append(mea_file)

# list the files that had errors
if error_files_autocuration:
    print("\nThe following files had errors during autocuration and were not processed successfully:")
    for error_file in error_files_autocuration:
        print(error_file)
else:
    print("\nAll files were autocurated successfully.")

if error_files_csv:
    print("\nThe following files had errors during CSV export and were not processed successfully:")
    for error_file in error_files_csv:
        print(error_file)
else:
    print("\nAll files were exported to CSV successfully.")


### Step 04 Extraction of Unit IDs and Positions

In [None]:
# Step 04 Unit ID and Position
# For each autocurated sorting:
# - re-extract waveforms (the unit set changed in Step 03),
# - estimate unit locations,
# - write unit_id / x / y / z to "unit_positions_<file>.csv" inside `directory_path`.


def _unit_positions_frame(unit_ids, unit_locations):
    """Build a DataFrame with columns unit_id, x, y, z (one row per unit).

    Depending on the location method / SpikeInterface version, a location may hold only
    (x, y) instead of (x, y, z) — TODO confirm for the installed version. Indexing
    position[2] unconditionally would then raise IndexError, so z is NaN when it is absent.
    """
    rows = []
    for unit_id, position in zip(unit_ids, unit_locations):
        coords = np.ravel(position)
        rows.append({
            'unit_id': unit_id,
            'x': coords[0],
            'y': coords[1],
            'z': coords[2] if coords.size > 2 else np.nan,
        })
    # Explicit columns keep the CSV header even when there are no units.
    return pd.DataFrame(rows, columns=['unit_id', 'x', 'y', 'z'])


# List the files that had errors
error_files_unit_position = []

for mea_file in mea_files:
    try:
        print(f"Processing unit IDs and positions for file: {mea_file}")
        # Load preprocessed recording and autocurated sorting
        recording_loaded = si.load_extractor(f"preprocessed_{mea_file}/")
        sorting_autocur_loaded = si.load_extractor(f"autocurated_sorting_IC_{mea_file}/")

        recording_we = si.extract_waveforms(recording_loaded, sorting_autocur_loaded,
                                            folder=f"wf_{mea_file}", progress_bar=True, n_jobs=n_workers,
                                            chunk_memory=chunk_memory, overwrite=True)
        # Extract unit IDs and positions
        unit_ids = sorting_autocur_loaded.get_unit_ids()
        print(f"{mea_file}: Unit IDs - {unit_ids}, number of units: {len(unit_ids)}")
        unit_loc = spost.compute_unit_locations(recording_we)
        print(f"{mea_file}: Unit locations - {unit_loc}, number of locations: {len(unit_loc)}")

        df = _unit_positions_frame(unit_ids, unit_loc)
        print(f"{mea_file}: DataFrame -\n{df}, number of units: {len(df)}")
        df.to_csv(os.path.join(directory_path, f"unit_positions_{mea_file}.csv"), index=False)
        print(f"{mea_file}: CSV export completed successfully.")
    except Exception as e:
        print(f"An error occurred while processing {mea_file}: {e}")
        error_files_unit_position.append(mea_file)

# list the files that had errors
if error_files_unit_position:
    print("\nThe following files had errors and were not processed successfully:")
    for error_file in error_files_unit_position:
        print(error_file)
else:
    print("\nAll files were processed successfully.")