In [1]:
import numpy as np
import os
import pickle
import warnings
import tkinter as tk
from tkinter import filedialog
from collections import defaultdict
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib
import spikeinterface.full as si
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from concurrent.futures import ProcessPoolExecutor

# Configuration graphique
matplotlib.use('qt5agg')
warnings.filterwarnings('ignore')

In [2]:
def save_obj(obj, name):
    name = name if name.endswith('.pkl') else name + '.pkl'
    path = os.path.normpath(name)
    if os.path.dirname(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    name = name if name.endswith('.pkl') else name + '.pkl'
    with open(os.path.normpath(name), 'rb') as f:
        return pickle.load(f)
    
def load_data(input_path, channel_id, nb_channels=256, dtype="uint16", voltage_resolution=0.1042, chunk_size=1000000):
    """Charge les données par blocs pour permettre l'affichage d'un tqdm."""
    m = np.memmap(os.path.normpath(input_path), dtype=dtype, mode='r')
    nb_samples = m.size // nb_channels
    data = np.empty(nb_samples, dtype=float)
    
    # On itère par blocs d'échantillons
    for i in tqdm(range(0, nb_samples, chunk_size), desc="Loading Trigger Channel"):
        end = min(i + chunk_size, nb_samples)
        # Slicing local au bloc
        block = m[i*nb_channels : end*nb_channels]
        data[i:end] = block[channel_id::nb_channels].astype(float)
    
    data = (data + np.iinfo('int16').min) / voltage_resolution
    return data, nb_samples

def detect_onsets(data, threshold):
    """Détection vectorisée des fronts montants."""
    # Cette opération est quasi-instantanée, le tqdm n'est utile que sur le recalage si nécessaire
    test = (data[:-1] < threshold) & (data[1:] >= threshold)
    indices = np.where(test)[0]
    
    # Recalage fin
    pbar = tqdm(total=5, desc="Getting Triggers", leave=False) # Souvent fini en < 5 itérations
    while True:
        to_shift = (indices > 0) & (data[indices - 1] < data[indices])
        if not np.any(to_shift):
            break
        indices[to_shift] -= 1
        pbar.update(1)
    pbar.close()
    return indices

def run_sanity_check(triggers, sampling_rate, maximal_jitter=0.25e-3):
    if len(triggers) < 2: return np.array([])
    inter_triggers = np.diff(triggers)
    val, counts = np.unique(inter_triggers, return_counts=True)
    mode_val = val[np.argmax(counts)]
    errors = np.where(np.abs(inter_triggers - mode_val) >= maximal_jitter * sampling_rate)[0]
    if errors.size > 0:
        print(f"⚠️ Erreurs triggers : {len(errors)}")
    else:
        print(f"✅ Triggers OK ({len(triggers)})")
    return triggers[errors].astype('int64')

def image_projection(image, setup_id):
    if setup_id == 2:
        return np.flipud(np.rot90(image))
    elif setup_id == 3:
        return np.fliplr(image)
    return image

def checkerboard_from_binary(nb_frames, nb_checks, path, setup_id):
    """
    Version optimisée utilisant la lecture en bloc et le déballage de bits NumPy.
    Gagne un facteur >100x en vitesse.
    """
    total_bits = nb_frames * nb_checks * nb_checks
    nb_bytes = (total_bits + 7) // 8
    
    with open(path, mode='rb') as f:
        raw_data = np.frombuffer(f.read(nb_bytes), dtype=np.uint8)

    all_bits = np.unpackbits(raw_data, bitorder='little')[:total_bits]
    
    all_frames = all_bits.reshape((nb_frames, nb_checks, nb_checks))
    
    checkerboard = np.zeros((nb_frames, nb_checks, nb_checks), dtype='uint8')
    
    for frame in tqdm(range(nb_frames), desc="Traitement des frames"):
        image = all_frames[frame].astype(float)
        checkerboard[frame] = image_projection(image, setup_id).astype('uint8')
        
    return checkerboard


In [3]:
def extract_from_sequence(cell_spikes, triggers, nb_repeats, stim_frequency, nb_frames_by_sequence, sequence_portion=(0.5, 1)):
    # 1. Pre-calculate indices and dimensions
    f0, f1 = sequence_portion
    nb_frames_portion = int((f1 - f0) * nb_frames_by_sequence)
    
    # Pre-calculate all start and end triggers for all repeats at once
    start_indices = np.arange(nb_repeats) * nb_frames_by_sequence + int(f0 * nb_frames_by_sequence)
    end_indices = np.arange(nb_repeats) * nb_frames_by_sequence + int(f1 * nb_frames_by_sequence)
    
    t_starts = triggers[start_indices]
    t_ends = triggers[end_indices]
    
    spike_trains = []
    spikes_counts = np.zeros((nb_repeats, nb_frames_portion))

    # 2. Use binary search (searchsorted) to find spikes within the range [min(t_starts), max(t_ends)]
    # This avoids scanning spikes that are outside the entire experiment portion
    relevant_indices = np.searchsorted(cell_spikes, [t_starts.min(), t_ends.max()])
    relevant_spikes = cell_spikes[relevant_indices[0]:relevant_indices[1]]

    # 3. Optimized Loop
    for i in range(nb_repeats):
        ts = t_starts[i]
        te = t_ends[i]
        
        # Binary search is much faster than boolean masking for sorted data
        idx_s, idx_e = np.searchsorted(relevant_spikes, [ts, te])
        spike_seq = relevant_spikes[idx_s:idx_e] - ts
        
        spike_trains.append(spike_seq)

        # Vectorized histogram (histogram is still good, but we limit the data it sees)
        counts, _ = np.histogram(spike_seq, bins=nb_frames_portion, range=(0, te - ts))
        spikes_counts[i, :] = counts

    return {
        "spike_trains": spike_trains,
        "counted_spikes": spikes_counts,
        "psth": spikes_counts.sum(axis=0) / nb_repeats * stim_frequency
    }


def compute_3D_sta(data, checkerboard, nb_frames_by_sequence, temporal_dimension):
    # 1. Flatten the spikes into a 1D array across all sequences/frames
    # We only care about the frames where spikes could actually trigger a window
    nb_frames_half = int(nb_frames_by_sequence / 2)
    spikes = data["counted_spikes"][:, temporal_dimension:nb_frames_half].flatten()
    total_spikes = np.sum(spikes)
    
    if total_spikes == 0:
        return np.zeros((temporal_dimension, checkerboard.shape[1], checkerboard.shape[2]))

    # 2. Identify indices where spikes occurred to avoid multiplying by zero
    spike_indices = np.where(spikes > 0)[0]
    weights = spikes[spike_indices]
    
    # 3. Vectorized window extraction
    # We map the flattened spike index back to the checkerboard index
    # idx_start = (seq * nb_frames_half) + frame - temporal_dimension
    # Since 'spikes' is already sliced from temporal_dimension, we adjust:
    sta = np.zeros((temporal_dimension, checkerboard.shape[1], checkerboard.shape[2]))
    
    for i, w in zip(spike_indices, weights):
        # Calculate the start/end in the checkerboard for this specific spike
        # The offset accounts for the 'temporal_dimension' skip in the spike data
        seq_idx = i // (nb_frames_half - temporal_dimension)
        frame_in_seq = i % (nb_frames_half - temporal_dimension) + temporal_dimension
        
        start = seq_idx * nb_frames_half + frame_in_seq - temporal_dimension
        end = start + temporal_dimension
        sta += w * checkerboard[start:end, :, :]

    # 4. Normalization
    sta /= total_spikes
    sta -= np.mean(sta)
    max_val = np.max(np.abs(sta))
    if max_val > 0:
        sta /= max_val
        
    return sta
def get_temporal_spatial_sta(sta_3D):
    # Use unravel_index on the absolute max to find the "peak" pixel and time
    idx_max = np.argmax(np.abs(sta_3D))
    best_t, best_x, best_y = np.unravel_index(idx_max, sta_3D.shape)
    
    # Extract slices
    sta_temporal = sta_3D[:, best_x, best_y]
    sta_spatial = sta_3D[best_t, :, :]
    
    # Vectorized normalization
    max_spatial = np.max(np.abs(sta_spatial))
    if max_spatial > 0:
        sta_spatial = sta_spatial / max_spatial
        
    return sta_temporal, sta_spatial, (best_t, best_x, best_y)



def process_single_electrode(args):
    """Function to process a single electrode - must be top-level for pickling."""
    electrode, mapping_info, spike_train, triggers, params = args
    row, col = mapping_info
    
    # Extract params for clarity
    nb_repeats = params['nb_repeats']
    stim_freq = params['stim_freq']
    nb_frames = params['nb_frames']
    temp_dim = params['temp_dim']
    checkerboard = params['checkerboard']

    # 1. Compute Raster Data
    res_r = extract_from_sequence(spike_train, triggers, nb_repeats, stim_freq, nb_frames, (0.5, 1))
    
    # 2. Compute STA Data
    res_s = extract_from_sequence(spike_train, triggers, nb_repeats, stim_freq, nb_frames, (0, 0.5))
    sta_3d = compute_3D_sta(res_s, checkerboard, nb_frames, temp_dim)
    _, sta_spat, _ = get_temporal_spatial_sta(sta_3d)

    return electrode, {
        'raster_spikes': res_r["spike_trains"],
        'sta_spatial': sta_spat
    }

def plot_stitched_sta(data_source, mapping, grid_size=16, padding=3):
    """
    Stitches individual 2D STAs into a single large 16x16 grid for fast rendering.
    """
    # 1. Handle data format (list of tuples from parallel vs. dictionary)
    if isinstance(data_source, list):
        data_dict = dict(data_source)
    else:
        data_dict = data_source

    if not data_dict:
        print("Error: No processed data found.")
        return

    # 2. Get dimensions from the first available electrode
    first_elec_id = next(iter(data_dict))
    h, w = data_dict[first_elec_id]['sta_spatial'].shape

    # 3. Create a giant empty canvas initialized with NaNs (for white padding)
    canvas_h = grid_size * h + (grid_size - 1) * padding
    canvas_w = grid_size * w + (grid_size - 1) * padding
    full_canvas = np.full((canvas_h, canvas_w), np.nan) 

    # 4. Fill the canvas
    for electrode, (row, col) in mapping.items():
        if electrode not in data_dict:
            continue
            
        sta = data_dict[electrode]['sta_spatial'].copy()
        
        # Local normalization for visibility
        vmax = np.max(np.abs(sta))
        if vmax > 0:
            sta /= vmax
        
        # Calculate pixel coordinates
        y_start = row * (h + padding)
        x_start = col * (w + padding)
        
        full_canvas[y_start : y_start + h, x_start : x_start + w] = sta

    # 5. Rendering
    plt.figure(figsize=(14, 14))
    
    # Configure colormap to show NaNs as white
    current_cmap = plt.cm.get_cmap('bwr').copy()
    current_cmap.set_bad(color='white') 

    plt.imshow(full_canvas, cmap=current_cmap, vmin=-1, vmax=1, interpolation='nearest')
    
    plt.axis('off')
    plt.title(f"Stitched STA Grid ({grid_size}x{grid_size})", fontsize=16, pad=20)
    plt.tight_layout()
    plt.show()

In [4]:
# --- Paramètres d'acquisition ---
SAMPLING_RATE = 20000
TOTAL_CHANNELS = 256
TRIGGER_CHANNEL = 126
DATA_TYPE = 'uint16'

# --- Choix du Setup ---
SETUP = 3  # 1 pour MEA1, 2 pour MEA2, 3 pour Opto

if SETUP == 1:
    DMD_POLARITY = 1
    PIXEL_SIZE = 2.3
    TRIGGER_THRESHOLD = 150e+3
elif SETUP == 2:
    DMD_POLARITY = 1
    PIXEL_SIZE = 3.5
    TRIGGER_THRESHOLD = 150e+3
elif SETUP == 3:
    DMD_POLARITY = -1
    PIXEL_SIZE = 2.8
    TRIGGER_THRESHOLD = 170e+3

# --- Paramètres Stimulus & Analyse ---
NB_CHECKS = 40
NB_FRAMES_SEQ = 1200
TEMPORAL_DIM = 30
PLOT_RASTER = True  # Équivalent au input "y/n"

In [5]:
# Sélection fichier
root = tk.Tk(); root.withdraw()
raw_path = filedialog.askopenfilename(title='Select a Checkerboard RAW file...')
if not raw_path: 
    print('File not Found')
    exit()
else:
    print(f"Selected File : {raw_path}")
mapping = load_obj('./electrodes_mapping_MEA_MCS_256.pkl')


Selected File : /mnt/Extra/PulsatingGrating/20250604_PulsingGrating/RAW_Files/01_ShiftingWhiteNoise_30Hz_30ND50%.raw


In [6]:
# 1. Traitement Spikes
print("Extraction des spikes...")
rec = si.read_binary(raw_path, sampling_frequency=SAMPLING_RATE, num_channels=TOTAL_CHANNELS, dtype=DATA_TYPE)

# Conversion spécifique SI pour passer en signé proprement
# rec = si.unsigned_to_signed(rec)

rec_filt = si.common_reference(si.bandpass_filter(rec))


peaks = detect_peaks(rec_filt, method="by_channel", peak_sign="neg", detect_threshold=6, n_jobs=10, progress_bar=True)
spike_trains_mua = defaultdict(list)
for p in peaks: spike_trains_mua[p[1]].append(p[0] / SAMPLING_RATE)



Extraction des spikes...


detect peaks using by_channel:   0%|          | 0/1846 [00:00<?, ?it/s]

In [7]:
# 2. Triggers
print("Lecture triggers...")
trig_raw, _ = load_data(raw_path, channel_id=TRIGGER_CHANNEL)
trig_idx = detect_onsets(trig_raw, TRIGGER_THRESHOLD)
run_sanity_check(trig_idx, SAMPLING_RATE)
triggers = trig_idx / SAMPLING_RATE
nb_repeats = len(triggers) // NB_FRAMES_SEQ

# --- Compute STIM_FREQ dynamically ---
# Calculate the mean time between consecutive frames (triggers)
# frequency = 1 / mean_inter_trigger_interval
avg_dt = np.mean(np.diff(triggers))
STIM_FREQ = int(round(1.0 / avg_dt))

print(f"Detected Stimulus Frequency: {STIM_FREQ} Hz")
print(f"Number of repeats: {nb_repeats}")


Lecture triggers...


Loading Trigger Channel:   0%|          | 0/37 [00:00<?, ?it/s]

Getting Triggers:   0%|          | 0/5 [00:00<?, ?it/s]

✅ Triggers OK (54000)
Detected Stimulus Frequency: 30 Hz
Number of repeats: 45


In [7]:
# 3. Stimulus
print("Chargement stimulus...")
stim_path = "./binarysource1000Mbits"
checkerboard = checkerboard_from_binary(nb_repeats * (NB_FRAMES_SEQ // 2), NB_CHECKS, stim_path, SETUP)

Chargement stimulus...


Traitement des frames:   0%|          | 0/55200 [00:00<?, ?it/s]

In [8]:
# 3. Stimulus for Shifting White Noise
MEA = 3 #int(input("\nEnter MEA number 2 or 3 :"))
if MEA==2: 
    threshold  = 150e+3
    pxl_size_dmd = 3.5
    size_dmd = [864, 864]      # dimensions of the DMD, in pixels
    polarity = 0
if MEA==3:
    threshold  = 170e+3   
    size_dmd = [760, 1020]      # dimensions of the DMD, in pixels
    pxl_size_dmd = 2.5          # The size of one pixel of the DMD in µm? on the camera or in reality?
    polarity = -1


In [16]:
from binfile import *

vec_file = "/mnt/Extra/PulsatingGrating/20250907_BarcodeXPulsatingGrating/VEC_Files/20250512_4_SWN_48pixCh_6pixShift_30Hz_MEA3.vec"
bin_file = "/mnt/Extra/PulsatingGrating/20250907_BarcodeXPulsatingGrating/VEC_Files/20250512_4_SWN_48pixCh_6pixShift_30Hz_MEA3.bin"

vec_trigs = np.loadtxt(vec_file)[1:]
binObj = BinFile(bin_file, size_dmd[0], size_dmd[1], MEA, mode='r')

shift_x_in_pix = 6
shift_y_in_pix = 6
checkerboard = []
num_unrepeated_frames = vec_trigs.shape[0]//2  #Keeping only the random frames knowing half are repeated and random one idex start at 0

for vec in tqdm(vec_trigs):
    vec_index = int(vec[1])
    if vec_index < num_unrepeated_frames:
        checker = binObj.read_frame(vec_index)
        checkerboard.append((checker[::shift_x_in_pix, ::shift_y_in_pix]/checker.max()))
checkerboard = np.array(checkerboard)

  0%|          | 0/54000 [00:00<?, ?it/s]

In [18]:
# 1. Prepare shared parameters
params = {
    'nb_repeats': nb_repeats,
    'stim_freq': STIM_FREQ,
    'nb_frames': NB_FRAMES_SEQ,
    'temp_dim': TEMPORAL_DIM,
    'checkerboard': checkerboard # Note: Large arrays can be slow to pass between processes
}

# 2. Filter tasks
tasks = [
    (elec, mapping[elec], np.array(spike_trains_mua[elec]), triggers, params)
    for elec in mapping.keys()
    if elec not in [127, 128, 255, 256] and elec in spike_trains_mua
]

processed_data = {}

# 3. Run in Parallel
# Adjust max_workers to the number of physical cores you want to use (e.g., 4, 8, or None for all)
with ProcessPoolExecutor(max_workers=10) as executor:
    results = list(tqdm(executor.map(process_single_electrode, tasks), 
                        total=len(tasks), 
                        desc="Parallel Analysis"))

# 4. Convert list of tuples back to dictionary
processed_data = dict(results)

Parallel Analysis:   0%|          | 0/252 [00:00<?, ?it/s]

IndexError: index 54000 is out of bounds for axis 0 with size 54000

In [32]:
if PLOT_RASTER :
    fig_r, axs_r = plt.subplots(16, 16, figsize=(15, 15))

    for electrode, (row, col) in tqdm(mapping.items()):
        ax_r = axs_r[row, col]

        # Check if we have data for this electrode
        if electrode not in processed_data:
            ax_r.axis('off')
            continue

        data = processed_data[electrode]

        # --- Plot Raster ---
        ax_r.eventplot(data['raster_spikes'], linewidths=0.1, color='black')
        ax_r.set_xticks([]); ax_r.set_yticks([])

    plt.show(block = False)

  0%|          | 0/256 [00:00<?, ?it/s]

In [33]:
if PLOT_RASTER:
    # On utilise une résolution de DPI raisonnable pour ne pas ramer à l'affichage
    fig_r, axs_r = plt.subplots(16, 16, figsize=(16, 16))
    
    data_dict = dict(processed_data) if isinstance(processed_data, list) else processed_data
    
    for electrode, (row, col) in tqdm(mapping.items(), desc="Plotting Rasters"):
        ax_r = axs_r[row, col]
    
        ax_r.set_xticks([])
        ax_r.set_yticks([])
        for spine in ax_r.spines.values():
            spine.set_visible(False)
    
        if electrode not in data_dict:
            # On laisse les électrodes vides en gris très clair pour voir la grille
            ax_r.set_facecolor('#f9f9f9') 
            continue
        data = data_dict[electrode]
        ax_r.eventplot(data['raster_spikes'], 
                       linewidths=0.1, 
                       color='black', 
                       rasterized=True) 
    
    
    # 4. Ajustement manuel (beaucoup plus rapide que tight_layout)
    plt.subplots_adjust(wspace=0.1, hspace=0.1, left=0.01, right=0.99, bottom=0.01, top=0.99)
    
    plt.show(block = False)

Plotting Rasters:   0%|          | 0/256 [00:00<?, ?it/s]

In [34]:
plot_stitched_sta(processed_data, mapping)

In [19]:
!python STA_MU_Exec.py

Selected File : /home/guiglaz/Documents/MultiunitFromRaw/01_Checkerboard_30Hz_16px_40sq_30ND50%.raw
Lecture triggers...
Loading Trigger Channel: 100%|██████████████████| 75/75 [00:01<00:00, 53.08it/s]
✅ Triggers OK (110462)                                                          
Detected Stimulus Frequency: 30 Hz
Number of repeats: 92
Extraction des spikes...
detect peaks using by_channel: 100%|████████| 3703/3703 [01:11<00:00, 51.96it/s]
Chargement stimulus...
Traitement des frames: 100%|██████████| 55200/55200 [00:00<00:00, 402962.94it/s]
Parallel Analysis: 100%|██████████████████████| 252/252 [00:24<00:00, 10.09it/s]
100%|█████████████████████████████████████████| 256/256 [00:12<00:00, 21.04it/s]
Press any key to close...^C
Traceback (most recent call last):
  File "/home/guiglaz/Documents/MultiunitFromRaw/STA_MU_Exec.py", line 415, in <module>
    input('Press any key to close...')
KeyboardInterrupt
[0m