In [None]:
import gc
import numpy as np
import scipy.io as sio  # For saving

# For clearing cache
import shutil
import os


from one.api import ONE
from brainbox.io.one import SpikeSortingLoader, SessionLoader
from brainbox.behavior.training import compute_performance
from brainbox.population.decode import get_spike_counts_in_bins

# Connect to the public IBL Alyx server; downloaded data lands in one.cache_dir
one = ONE(base_url='https://openalyx.internationalbrainlab.org')
print(one.cache_dir)

# =========================================
# ⚙️ Analysis settings
# =========================================
# Spike count window, relative to stimulus onset
window_start = 0          # in seconds
window_end = 300 / 1000   # in seconds

# Lower bounds: a region / condition is only reported when it has strictly
# more neurons / trials than these values.
nNeurons_lb = 10
# NOTE(review): this was assigned twice in a row (20, then 2), so 2 was the
# value actually in effect. Kept at 2 to preserve behavior; restore 20 if the
# lower value was only meant for debugging.
nTrials_lb = 2
RT_ub = 2000 # trials with RT more than this dur (in ms) will be removed

# To debug on a single session, replace the search below with e.g.
# eids = ['7af49c00-63dd-4fed-b2e0-1b3bd945b20b']

# Find spike-sorted sessions from the Churchland lab
lab = 'churchlandlab'
eids = one.search(
    lab=lab,
    dataset_types=['spikes.times']
)
print(f">>> {lab} sessions with spike sorting: {len(eids)}")


In [None]:
import psutil

# Report the resident set size (RSS) of the current kernel process.
process = psutil.Process()
mem_info = process.memory_info()
rss_mib = mem_info.rss / 1024**2  # bytes -> MiB
print(f"Memory used: {rss_mib:.2f} MB")

In [None]:
# An example eid for debugging a single session by hand:
# eid = 'aad23144-0e52-4eac-80c5-c4ee2decb198'
# pids, labels = one.eid2pid(eid)
# trials = one.load_object(eid, 'trials', collection='alf')
# contrast_levels = np.unique(trials['contrastLeft'])
# prior_levels = np.unique(trials['probabilityLeft'])

# Memory snapshot taken right before the session loop. The figure used to be
# computed and silently discarded; it is now printed so it can be compared
# against the earlier reading.
process = psutil.Process()
mem_info = process.memory_info()
print(f"Memory used: {mem_info.rss / 1024**2:.2f} MB")


In [None]:
# Loop over sessions (EIDs)
#
# For every session, probe, brain region and quality setting this reports how
# many neurons are available and, for regions with enough neurons, how many
# trials survive filtering in each (left contrast, prior) condition.

# RT_ub is configured in milliseconds, whereas the trial event times are in
# seconds (the same time base window_end is compared against below). Convert
# once so the reaction-time filter actually has an effect.
rt_ub_s = RT_ub / 1000.0

for indEid, eid in enumerate(eids, start=1):

    # ================================
    # 📦 Load trial and probe info
    # ================================
    pids, labels = one.eid2pid(eid)
    trials = one.load_object(eid, 'trials', collection='alf')
    contrast_left = trials['contrastLeft']
    prior_left = trials['probabilityLeft']
    stim_on = trials['stimOn_times']

    # NaN never compares equal to anything, so a NaN contrast level could only
    # ever select zero trials; skip it explicitly.
    # NOTE(review): only contrastLeft is used for the condition split — confirm
    # that ignoring trials without a left contrast is intended.
    contrast_levels = np.unique(contrast_left[~np.isnan(contrast_left)])
    prior_levels = np.unique(prior_left)

    # assert that prior_levels has three levels: 0.2, 0.5 and 0.8
    assert np.array_equal(prior_levels, [0.2, 0.5, 0.8]), f"Unexpected prior_levels: {prior_levels}"

    # ================================
    # 🧪 Trial selection
    # ================================
    # Trial filters depend only on the session, not on probe / region /
    # quality, so compute them once here instead of inside the innermost loop.

    # Keep trials with a defined reaction time no longer than the upper bound
    rt = trials['response_times'] - stim_on
    rt_mask = (rt <= rt_ub_s) & ~np.isnan(rt)

    # Keep trials whose first movement does not fall inside the spike-count
    # window (trials with an undefined movement time are kept)
    movement_latency = trials['firstMovement_times'] - stim_on
    no_early_movement_mask = (movement_latency >= window_end) | np.isnan(movement_latency)

    base_trial_mask = rt_mask & no_early_movement_mask

    # (contrast, prior, stimulus-onset times of the surviving trials)
    condition_trials = []
    for contrast in contrast_levels:
        for prior in prior_levels:
            trial_mask = (contrast_left == contrast) & (prior_left == prior) & base_trial_mask
            condition_trials.append((contrast, prior, stim_on[trial_mask]))

    del rt, rt_mask, movement_latency, no_early_movement_mask, base_trial_mask

    for label in labels:
        print(f"\n📂 Session #{indEid}/{len(eids)}: {eid} | Probe: {label}")

        # ================================
        # 🔬 Load spikes and clusters
        # ================================
        ssl = SpikeSortingLoader(eid=eid, one=one, pname=label)
        spikes, clusters, channels = ssl.load_spike_sorting()
        clusters = ssl.merge_clusters(spikes, clusters, channels)
        regions_all = np.unique(clusters['acronym'])

        for target_region in regions_all:
            region_mask = clusters['acronym'] == target_region

            for flag_good_quality in [False, True]:  # True = label==1

                # ================================
                # 🧠 Neuron selection
                # ================================
                # Select neurons that have matched region and quality
                if flag_good_quality:
                    valid_cluster_mask = region_mask & (clusters['label'] == 1)
                else:
                    valid_cluster_mask = region_mask
                valid_cluster_ids = clusters['cluster_id'][valid_cluster_mask]

                if len(valid_cluster_ids) <= nNeurons_lb:
                    continue

                print(f"🎯 Region: {target_region}; Quality: {flag_good_quality} | Neurons: {len(valid_cluster_ids)}")

                for contrast, prior, valid_trial_times in condition_trials:
                    if len(valid_trial_times) > nTrials_lb:
                        print(f"  ✅Contrast: {contrast} | Prior: {prior} | Trials: {len(valid_trial_times)}")

        # Spike data is the large allocation; release it before the next probe
        del spikes, clusters, channels, ssl
        gc.collect()

    del trials, condition_trials
    gc.collect()

    # Clear cache
    # The previous path, <cache_dir>/<eid>, does not match where ONE stores a
    # session's files, so the clean-up appeared never to run. Ask ONE for the
    # session's local folder instead (None when it cannot be resolved).
    session_cache_path = one.eid2path(eid)

    if session_cache_path is not None and os.path.exists(session_cache_path):
        shutil.rmtree(session_cache_path)
        print(f"🧹 Cleared cache for session {eid}")

# Drop the loop's bookkeeping names. `pids` and `labels` are never bound when
# the session list is empty, and none of the names exist when this is re-run,
# so remove them tolerantly instead of with a bare `del`.
for _name in ('pids', 'eids', 'labels'):
    globals().pop(_name, None)
del _name
print('✅ DONE')

# Clear all cache at the end
# WARNING: this removes the entire ONE cache directory, not only the files
# downloaded by this notebook.
if os.path.isdir(one.cache_dir):
    shutil.rmtree(one.cache_dir)
print("🧹 All cache cleared.")
