# Clusterless Decoding

## Overview

_Developer Note:_ If you plan to make a PR in the future, copy this notebook
first and give the copy a name starting with `temp` (a prefix ignored by
`.gitignore`) to avoid future conflicts.

This is one notebook in a multi-part series on Spyglass.

- To set up your Spyglass environment and database, see
  [the Setup notebook](./00_Setup.ipynb)
- This tutorial assumes you've already
  [extracted waveforms](./41_Extracting_Clusterless_Waveform_Features.ipynb), as well as loaded
  [position data](./20_Position_Trodes.ipynb). If decoding in 1D, the position data should also be
  [linearized](./24_Linearization.ipynb).

Clusterless decoding can be performed on either 1D or 2D data. We will start with 2D data.

## Elements of Clusterless Decoding
- **Position Data**: This is the data that we want to decode. It can be 1D or 2D.
- **Spike Waveform Features**: These are the features that we will use to decode the position data.
- **Decoding Model Parameters**: This is how we define the model that we will use to decode the position data.

## Grouping Data
An important concept will be groups. Groups are tables that allow us to specify collections of data. We will use groups in two situations here:

1. We want to decode from more than one tetrode (or probe), so we will create a group that contains all of the tetrodes that we want to decode from.
2. Similarly, we will create a group for the position data that we want to decode, so that we can decode position data from multiple sessions.
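
Conceptually, a group is one named entry that points to many member entries. In Spyglass this takes the form of a master table with a part table (the `UnitFeatures` and `Position` parts queried later in this notebook). The plain-Python sketch below only illustrates that one-to-many shape; the `Group` class and the example IDs are hypothetical and are not part of the Spyglass API.

```python
from dataclasses import dataclass, field


@dataclass
class Group:
    """Hypothetical stand-in for a group: one named master entry plus
    a list of member keys (analogous to part-table rows)."""

    nwb_file_name: str
    group_name: str
    members: list = field(default_factory=list)

    def add(self, key: dict) -> None:
        # Ignore exact duplicates, similar in spirit to skip_duplicates=True
        if key not in self.members:
            self.members.append(key)


tetrode_group = Group("example_.nwb", "example_group")
for merge_id in ["sort-a", "sort-b", "sort-a"]:
    tetrode_group.add({"spikesorting_merge_id": merge_id})

print(len(tetrode_group.members))
```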

### Grouping Waveform Features
Let's start with grouping the Waveform Features. We will first inspect the waveform features that we have extracted to figure out the primary keys of the data that we want to decode from. We need to use the tables `SpikeSortingSelection` and `SpikeSortingOutput` to figure out the `merge_id` associated with `nwb_file_name` to get the waveform features associated with the NWB file of interest.


In [1]:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs
from spyglass.decoding.v1.waveform_features import (
    UnitWaveformFeaturesSelection,
    UnitWaveformFeatures,
)


nwb_copy_file_name = "IM-1875_darling_20250720_.nwb"

sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "clusterless_thresholder",
    "sorter_param_name": "default_clusterless",
}

feature_key = {"features_param_name": "amplitude"}

(
    UnitWaveformFeaturesSelection.proj(merge_id="spikesorting_merge_id")
    * SpikeSortingOutput.CurationV1
    * sgs.SpikeSortingSelection
) & SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, sources=["v1"], as_dict=True
)

[2025-08-19 15:07:07,888][INFO]: DataJoint 0.14.4 connected to scrater@lmf-db.cin.ucsf.edu:3306
[15:07:15][INFO] Spyglass: Initializing UserEnvironment for spikesorting: scrater_spyglass_01


| merge_id | features_param_name | sorting_id | curation_id | recording_id | sorter | sorter_param_name | nwb_file_name | interval_list_name |
|---|---|---|---|---|---|---|---|---|
| 0bf44ff7-1e0f-e60f-cdef-d1d74ba29a9f | amplitude | 643002d5-daec-45f9-bf7f-da7c46f4cc50 | 0 | 7ae104cd-b080-4d81-a2b7-1a778e7cb30b | clusterless_thresholder | default_clusterless | IM-1875_darling_20250720_.nwb | 4fa06ba9-b763-4e2b-b153-051a1afd51e2 |


In [2]:
from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection

# find the merge ids that correspond to the sorter key restrictions
merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, sources=["v1"], as_dict=True
)

# find the previously populated waveform selection keys that correspond to these sorts
waveform_selection_keys = (
    UnitWaveformFeaturesSelection().proj(merge_id="spikesorting_merge_id")
    & merge_ids
    & feature_key
).fetch(as_dict=True)
for key in waveform_selection_keys:
    key["spikesorting_merge_id"] = key.pop("merge_id")

UnitWaveformFeaturesSelection & waveform_selection_keys

| spikesorting_merge_id | features_param_name |
|---|---|
| 0bf44ff7-1e0f-e60f-cdef-d1d74ba29a9f | amplitude |


We will create a group called `sac_test_group` that contains all of the tetrodes that we want to decode from, using the `create_group` method. This method takes the NWB file name, the name of the group, and the keys of the entries that we want to include in the group.
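
The `keys` argument is a list of dictionaries, each identifying one `UnitWaveformFeaturesSelection` entry by its primary key. The rename loop in the previous cell is needed because the query projected `spikesorting_merge_id` to `merge_id` so it could be joined with the merge table. A non-mutating version of that rename might look like the sketch below; `rename_field` and the example values are hypothetical.

```python
def rename_field(keys, old, new):
    """Return copies of the dicts in `keys` with field `old` renamed to `new`."""
    renamed = []
    for key in keys:
        key = dict(key)  # copy so the caller's dicts are left untouched
        key[new] = key.pop(old)
        renamed.append(key)
    return renamed


fetched = [{"merge_id": "example-merge-id", "features_param_name": "amplitude"}]
group_keys = rename_field(fetched, "merge_id", "spikesorting_merge_id")
print(group_keys)
```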

In [3]:
from spyglass.decoding.v1.clusterless import UnitWaveformFeaturesGroup

UnitWaveformFeaturesGroup().create_group(
    nwb_file_name=nwb_copy_file_name,
    group_name="sac_test_group",
    keys=waveform_selection_keys,
)
UnitWaveformFeaturesGroup & {"waveform_features_group_name": "sac_test_group"}

| nwb_file_name | waveform_features_group_name |
|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group |


We can confirm that `sac_test_group` is associated with the tetrodes that we want to decode from by querying the `UnitFeatures` part table:

In [4]:
UnitWaveformFeaturesGroup.UnitFeatures & {
    "nwb_file_name": nwb_copy_file_name,
    "waveform_features_group_name": "sac_test_group",
}

| nwb_file_name | waveform_features_group_name | spikesorting_merge_id | features_param_name |
|---|---|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group | 0bf44ff7-1e0f-e60f-cdef-d1d74ba29a9f | amplitude |


### Grouping Position Data

We will now create a group, also called `sac_test_group`, that contains all of the position data that we want to decode from. As before, we will use the `create_group` method, which takes the NWB file name, the name of the group, and the keys of the entries that we want to include in the group.

We use the `PositionOutput` table to find the `merge_id` associated with `nwb_file_name`, which identifies the position data for the NWB file of interest. In this case, we only have one position entry to insert, but we could insert multiple entries if we wanted to decode from multiple sessions.

Note that the `upsample_rate` parameter sets the rate, in Hz, to which position data will be upsampled for decoding. This is useful if we want to decode at a finer time scale than the position sampling rate. In practice, a value of 500 Hz is used in many analyses. Omitting this parameter or passing a null value will default to the position sampling rate.
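
As a rough illustration of what upsampling means, the sketch below linearly interpolates 30 Hz position samples onto a 500 Hz grid with `np.interp`. The helper `upsample_position` is hypothetical, and linear interpolation is an assumption for illustration; Spyglass's own upsampling may differ in details such as how gaps in tracking are handled.

```python
import numpy as np


def upsample_position(time, position, upsample_rate):
    """Linearly interpolate position onto a uniform grid at `upsample_rate` Hz.

    time: (n_time,) sample times in seconds
    position: (n_time, n_dims) position samples
    """
    new_time = np.arange(time[0], time[-1], 1.0 / upsample_rate)
    new_position = np.column_stack(
        [np.interp(new_time, time, position[:, dim]) for dim in range(position.shape[1])]
    )
    return new_time, new_position


# Made-up 30 Hz samples over 1 s, moving at a constant 10 cm/s along x
time = np.arange(0.0, 1.0, 1 / 30)
position = np.column_stack([10 * time, np.zeros_like(time)])
new_time, new_position = upsample_position(time, position, upsample_rate=500)
print(len(time), "->", len(new_time), "samples")
```

Linear interpolation adds samples but no new information about movement between the original frames; it mainly lets the decoder step in smaller time bins.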

You will also want to specify the names of the position variables (via the `position_variables` argument) if they differ from the defaults, `position_x` and `position_y`.

In [7]:
from spyglass.position import PositionOutput
import spyglass.position as sgp


sgp.v1.TrodesPosParams.insert1(
    {
        "trodes_pos_params_name": "default_decoding",
        "params": {
            "max_LED_separation": 9.0,
            "max_plausible_speed": 300.0,
            "position_smoothing_duration": 0.125,
            "speed_smoothing_std_dev": 0.100,
            "orient_smoothing_std_dev": 0.001,
            "led1_is_front": 1,
            "is_upsampled": 1,
            "upsampling_sampling_rate": 250,
            "upsampling_interpolation_method": "linear",
        },
    },
    skip_duplicates=True,
)

trodes_s_key = {
    "nwb_file_name": nwb_copy_file_name,
    # must be an interval that exists in RawPosition for this file
    "interval_list_name": "pos 0 valid times",
    "trodes_pos_params_name": "default_decoding",
}
sgp.v1.TrodesPosSelection.insert1(trodes_s_key, skip_duplicates=True)
sgp.v1.TrodesPosV1.populate(trodes_s_key)

PositionOutput.TrodesPosV1 & trodes_s_key

In [8]:
from spyglass.decoding.v1.core import PositionGroup

position_merge_ids = (
    PositionOutput.TrodesPosV1
    & {
        "nwb_file_name": nwb_copy_file_name,
        "interval_list_name": "pos 0 valid times",
        "trodes_pos_params_name": "default_decoding",
    }
).fetch("merge_id")

PositionGroup().create_group(
    nwb_file_name=nwb_copy_file_name,
    group_name="sac_test_group",
    keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
    upsample_rate=500,
)

PositionGroup & {
    "nwb_file_name": nwb_copy_file_name,
    "position_group_name": "sac_test_group",
}

| nwb_file_name | position_group_name | position_variables | upsample_rate (Hz) |
|---|---|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group | =BLOB= | 500.0 |


In [9]:
(
    PositionGroup
    & {"nwb_file_name": nwb_copy_file_name, "position_group_name": "sac_test_group"}
).fetch1("position_variables")

['position_x', 'position_y']

In [10]:
PositionGroup.Position & {
    "nwb_file_name": nwb_copy_file_name,
    "position_group_name": "sac_test_group",
}

| nwb_file_name | position_group_name | pos_merge_id |
|---|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group | e1b91be0-9154-bfaa-7372-8f2dfb026088 |


## Decoding Model Parameters

We will use the `non_local_detector` package to decode the data. This package is highly flexible and allows several different types of models to be used. In this case, we will use the `ContFragClusterlessClassifier` to decode the data. This has two discrete states: Continuous and Fragmented, which correspond to different types of movement models. To read more about this model, see:
> Denovellis, E.L., Gillespie, A.K., Coulter, M.E., Sosa, M., Chung, J.E., Eden, U.T., and Frank, L.M. (2021). Hippocampal replay of experience at real-world speeds. eLife 10, e64505. [10.7554/eLife.64505](https://doi.org/10.7554/eLife.64505).
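
To make the two states concrete: in the paper above, the continuous state moves the decoded position smoothly (a random walk), while the fragmented state allows jumps to any position (a uniform transition). The NumPy sketch below builds both kinds of transition matrix on a 1D grid. The bin size, `movement_std`, and the 0.98 "stickiness" of the discrete-state matrix are illustrative assumptions, not `non_local_detector`'s defaults or implementation.

```python
import numpy as np


def random_walk_transition(bin_centers, movement_std):
    """Gaussian random-walk transition matrix on a 1D grid (rows sum to 1)."""
    diff = bin_centers[:, None] - bin_centers[None, :]
    transition = np.exp(-0.5 * (diff / movement_std) ** 2)
    return transition / transition.sum(axis=1, keepdims=True)


def uniform_transition(n_bins):
    """Every position is equally likely at the next time step."""
    return np.full((n_bins, n_bins), 1.0 / n_bins)


bin_centers = np.arange(0.0, 100.0, 2.0)  # 50 bins of 2 cm on a 1 m track
continuous = random_walk_transition(bin_centers, movement_std=6.0)
fragmented = uniform_transition(len(bin_centers))

# Discrete-state transitions (rows: from Continuous/Fragmented, cols: to).
# A "sticky" matrix favors staying in the same state from one step to the next.
discrete = np.array([[0.98, 0.02],
                     [0.02, 0.98]])
```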

Let's first look at the model and the default parameters:


In [11]:
from non_local_detector.models import ContFragClusterlessClassifier

ContFragClusterlessClassifier()

You can change these parameters like so: 

In [11]:
from non_local_detector.models import ContFragClusterlessClassifier

ContFragClusterlessClassifier(
    clusterless_algorithm_params={
        "block_size": 10000,
        "position_std": 12.0,
        "waveform_std": 24.0,
    },
)
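
The `position_std` and `waveform_std` parameters act as kernel widths for a kernel density estimate of how often spikes with a given waveform feature (mark) occur at each position, which is the core idea of clusterless decoding. The sketch below is a simplified, un-batched version of such an estimate on made-up data; it assumes isotropic Gaussian kernels and is not `non_local_detector`'s implementation. In the real package, `block_size` presumably limits how many spikes are processed at once to bound memory (an assumption); the sketch does no blocking.

```python
import numpy as np


def gaussian_kernel(differences, std):
    """Isotropic Gaussian kernel evaluated on the last axis of `differences`."""
    n_dims = differences.shape[-1]
    norm = (2.0 * np.pi * std**2) ** (-n_dims / 2)
    return norm * np.exp(-0.5 * np.sum((differences / std) ** 2, axis=-1))


def mark_intensity(decode_marks, place_bins, position_samples, enc_positions,
                   enc_marks, position_std, waveform_std, duration):
    """Kernel density estimate of the joint mark intensity lambda(m, x).

    decode_marks: (n_spikes, n_features) waveform features of spikes to decode
    place_bins: (n_bins, n_dims) position bin centers
    position_samples: (n_time, n_dims) animal position sampled during encoding
    enc_positions: (n_enc, n_dims) animal position at each encoding spike
    enc_marks: (n_enc, n_features) waveform features of each encoding spike
    duration: length of the encoding period in seconds
    Returns an (n_spikes, n_bins) array of rates.
    """
    # Occupancy density pi(x): where the animal spent its time
    occupancy = gaussian_kernel(
        place_bins[:, None, :] - position_samples[None, :, :], position_std
    ).mean(axis=1)
    # Position kernel for each (bin, encoding spike) pair: (n_bins, n_enc)
    position_weight = gaussian_kernel(
        place_bins[:, None, :] - enc_positions[None, :, :], position_std
    )
    # Waveform kernel for each (decoding spike, encoding spike) pair: (n_spikes, n_enc)
    mark_weight = gaussian_kernel(
        decode_marks[:, None, :] - enc_marks[None, :, :], waveform_std
    )
    # Joint density p(x, m), averaged over encoding spikes: (n_spikes, n_bins)
    joint_density = mark_weight @ position_weight.T / len(enc_marks)
    mean_rate = len(enc_marks) / duration
    return mean_rate * joint_density / occupancy[None, :]


rng = np.random.default_rng(0)
place_bins = np.arange(0.0, 100.0, 2.0)[:, None]            # 1D track, 2 cm bins
position_samples = rng.uniform(0.0, 100.0, size=(2000, 1))  # even exploration
# Two made-up "units": small amplitudes near 20 cm, large amplitudes near 80 cm
enc_positions = np.vstack([rng.normal(20, 3, (200, 1)), rng.normal(80, 3, (200, 1))])
enc_marks = np.vstack([rng.normal(100, 10, (200, 4)), rng.normal(300, 10, (200, 4))])
decode_marks = np.array([[100.0] * 4, [300.0] * 4])

rate = mark_intensity(decode_marks, place_bins, position_samples, enc_positions,
                      enc_marks, position_std=6.0, waveform_std=24.0, duration=100.0)
print(place_bins[rate.argmax(axis=1), 0])
```

Because the waveform kernel is negligible for encoding spikes with very different amplitudes, each decoding spike's rate peaks where spikes with similar waveforms were recorded, without assigning spikes to units.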

This is how to insert the model parameters into the database:

In [12]:
from spyglass.decoding.v1.core import DecodingParameters


DecodingParameters.insert1(
    {
        "decoding_param_name": "contfrag_clusterless",
        "decoding_params": ContFragClusterlessClassifier(),
        "decoding_kwargs": dict(),
    },
    skip_duplicates=True,
)

DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}

| decoding_param_name | decoding_params | decoding_kwargs |
|---|---|---|
| contfrag_clusterless | =BLOB= | =BLOB= |


We can retrieve these parameters and rebuild the model like so:

In [13]:
model_params = (
    DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}
).fetch1()

ContFragClusterlessClassifier(**model_params["decoding_params"])
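
Rebuilding the model with `**` unpacking works as long as the stored parameters form a dictionary whose keys match the model's constructor arguments. The toy round trip below illustrates that pattern, assuming the parameters are stored as an attribute dictionary; `ToyClassifier` is hypothetical, and how nested objects inside the real parameters are serialized is not shown.

```python
class ToyClassifier:
    """Hypothetical model whose __init__ stores each argument as an attribute."""

    def __init__(self, position_std=6.0, waveform_std=24.0, block_size=10000):
        self.position_std = position_std
        self.waveform_std = waveform_std
        self.block_size = block_size


original = ToyClassifier(position_std=12.0)
stored = dict(vars(original))      # what a database row might hold
rebuilt = ToyClassifier(**stored)  # same pattern as the cell above
print(vars(rebuilt))
```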

## Decoding

Now that we have grouped the data and defined the model parameters, we have set up all of the table entries needed to decode. We now use the `ClusterlessDecodingSelection` table to fully specify the data and parameters that we want to decode with.

This has:
- `waveform_features_group_name`: the name of the group that contains the waveform features that we want to decode from
- `position_group_name`: the name of the group that contains the position data that we want to decode from
- `decoding_param_name`: the name of the decoding parameters that we want to use
- `nwb_file_name`: the name of the NWB file that we want to decode from
- `encoding_interval`: the interval of time that we want to train the initial model on
- `decoding_interval`: the interval of time that we want to decode from
- `estimate_decoding_params`: whether or not we want to estimate the decoding parameters


The first four parameters should be familiar from the previous sections.


### Decoding and Encoding Intervals
The `encoding_interval` is the interval of time on which the initial model is trained. The `decoding_interval` is the interval of time that we want to decode. These two intervals can be the same, but they do not have to be. For example, we may want to train the model on a long interval but decode only a short one. This is useful when the interval we want to decode is short and not representative of the entire session; in that case, we train the model on a longer interval that is representative of the entire session.

These intervals come from the `IntervalList` table, which contains the `nwb_file_name` and `interval_list_name` that we need to specify the `encoding_interval` and `decoding_interval`. Here we encode on `00_r1` and decode the short interval `epoch0_block1_trial16`. The commented-out cell below shows how you could instead insert a custom short interval, such as `test decoding interval`.
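
Each `valid_times` entry in `IntervalList` is an array of `[start, end]` rows in seconds. Restricting data to an interval then amounts to a mask like the one sketched below. The helper `in_intervals` and the times are hypothetical, and treating the endpoints as inclusive is an assumption.

```python
import numpy as np


def in_intervals(times, valid_times):
    """Boolean mask of `times` falling inside any [start, end] row of `valid_times`."""
    times = np.asarray(times)
    valid_times = np.asarray(valid_times).reshape(-1, 2)
    mask = np.zeros(times.shape, dtype=bool)
    for start, end in valid_times:
        mask |= (times >= start) & (times <= end)
    return mask


encoding_valid_times = [[0.0, 100.0]]   # e.g. a long interval to train on
decoding_valid_times = [[40.0, 55.0]]   # e.g. a single short trial
spike_times = np.array([5.0, 42.0, 50.0, 60.0, 120.0])

encoding_spikes = spike_times[in_intervals(spike_times, encoding_valid_times)]
decoding_spikes = spike_times[in_intervals(spike_times, decoding_valid_times)]
print(encoding_spikes, decoding_spikes)
```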


### Estimating Decoding Parameters
The last parameter is `estimate_decoding_params`. This is a boolean that specifies whether or not we want to estimate the decoding parameters. If this is `True`, then we will estimate the initial conditions and discrete transition matrix from the data.

NOTE: If estimating parameters, we need to treat times outside the decoding interval as missing. This means that for times outside the decoding interval, the spiking data are not used; only the state transition matrix and the previous time step are. This may or may not be desirable depending on the length of this missing interval.
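
The effect described in the note can be seen in a minimal forward-filtering sketch of a discrete hidden Markov model: at time steps marked missing, the likelihood is replaced by ones, so the estimate is driven only by the previous step and the transition matrix, and over a long gap it relaxes toward the chain's stationary distribution. This two-state toy is illustrative only and is not `non_local_detector`'s implementation.

```python
import numpy as np


def forward_filter(initial, transition, likelihood, observed):
    """Normalized forward (filtering) pass of a discrete HMM.

    initial: (n_states,) initial state probabilities
    transition: (n_states, n_states), rows = from-state, columns = to-state
    likelihood: (n_time, n_states) p(data_t | state)
    observed: (n_time,) bool; False marks missing time steps
    """
    n_time, n_states = likelihood.shape
    posterior = np.zeros((n_time, n_states))
    prior = initial
    for t in range(n_time):
        # Missing steps get a flat likelihood, so only the prediction remains
        lik = likelihood[t] if observed[t] else np.ones(n_states)
        unnormalized = prior * lik
        posterior[t] = unnormalized / unnormalized.sum()
        prior = posterior[t] @ transition  # predict the next time step
    return posterior


transition = np.array([[0.9, 0.1], [0.1, 0.9]])
initial = np.array([0.5, 0.5])
# The data favor state 0 early and state 1 late...
likelihood = np.array([[0.9, 0.1], [0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
# ...but the last two steps fall outside the decoding interval
observed = np.array([True, True, False, False])
posterior = forward_filter(initial, transition, likelihood, observed)
print(posterior.round(3))
```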


In [15]:
from spyglass.decoding.v1.clusterless import ClusterlessDecodingSelection

ClusterlessDecodingSelection() & {"nwb_file_name": nwb_copy_file_name}

| nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|


In [16]:
from spyglass.common import IntervalList

IntervalList & {"nwb_file_name": nwb_copy_file_name}

| nwb_file_name | interval_list_name | valid_times | pipeline |
|---|---|---|---|
| IM-1875_darling_20250720_.nwb | 00_r1 | =BLOB= | |
| IM-1875_darling_20250720_.nwb | 4fa06ba9-b763-4e2b-b153-051a1afd51e2 | =BLOB= | spikesorting_artifact_v1 |
| IM-1875_darling_20250720_.nwb | 7ae104cd-b080-4d81-a2b7-1a778e7cb30b | =BLOB= | spikesorting_recording_v1 |
| IM-1875_darling_20250720_.nwb | epoch0_block1 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial1 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial10 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial11 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial12 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial13 | =BLOB= | hex_maze |
| IM-1875_darling_20250720_.nwb | epoch0_block1_trial14 | =BLOB= | hex_maze |


In [None]:
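# # Optional: insert a custom 15 s decoding interval that starts at the
# # beginning of the 00_r1 interval of this file
# run_valid_times = (
#     IntervalList
#     & {"nwb_file_name": nwb_copy_file_name, "interval_list_name": "00_r1"}
# ).fetch1("valid_times")
# decoding_start = run_valid_times[0][0]
# decoding_interval_valid_times = [[decoding_start, decoding_start + 15.0]]

# IntervalList.insert1(
#     {
#         "nwb_file_name": nwb_copy_file_name,
#         "interval_list_name": "test decoding interval",
#         "valid_times": decoding_interval_valid_times,
#     },
#     skip_duplicates=True,
# )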

Once we have figured out the keys that we need, we can insert the `ClusterlessDecodingSelection` into the database.

In [17]:
selection_key = {
    "waveform_features_group_name": "sac_test_group",
    "position_group_name": "sac_test_group",
    "decoding_param_name": "contfrag_clusterless",
    "nwb_file_name": nwb_copy_file_name,
    "encoding_interval": "00_r1",
    "decoding_interval": "epoch0_block1_trial16",
    "estimate_decoding_params": False,
}

ClusterlessDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)

ClusterlessDecodingSelection & selection_key

| nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group | sac_test_group | contfrag_clusterless | 00_r1 | epoch0_block1_trial16 | 0 |


In [19]:
ClusterlessDecodingSelection() & {"nwb_file_name": nwb_copy_file_name}

| nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|
| IM-1875_darling_20250720_.nwb | sac_test_group | sac_test_group | contfrag_clusterless | 00_r1 | epoch0_block1_trial16 | 0 |


To run decoding, we simply populate the `ClusterlessDecodingV1` table. This runs the decoding and inserts the results into the database (including an entry in the `DecodingOutput` merge table), from which we can then retrieve them.

In [20]:
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1

ClusterlessDecodingV1.populate(selection_key)

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  position_df[column][st:en] = np.nan


Encoding models:   0%|          | 0/1 [00:00<?, ?electrode/s]

Non-Local Likelihood:   0%|          | 0/1 [00:00<?, ?electrode/s]

  results = xr.Dataset(


{'success_count': 1, 'error_list': []}

We can now see it as an entry in the `DecodingOutput` table.

In [21]:
from spyglass.decoding.decoding_merge import DecodingOutput

DecodingOutput.ClusterlessDecodingV1 & selection_key

| merge_id | nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|---|
| ee6a9b07-03f7-1f55-2a4d-13f37885f585 | IM-1875_darling_20250720_.nwb | sac_test_group | sac_test_group | contfrag_clusterless | 00_r1 | epoch0_block1_trial16 | 0 |


We can load the results of the decoding:

In [22]:
decoding_results = (ClusterlessDecodingV1 & selection_key).fetch_results()
decoding_results



Finally, if we delete entries from these tables, we can use the `cleanup` method to delete the corresponding results files from the file system:

In [23]:
DecodingOutput().cleanup()

[10:26:49][INFO] Spyglass: Cleaning up decoding outputs
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_259e8498-1eb6-4e84-a0f6-7575c4ab9b87.nc
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_8404dc49-081c-48c7-b448-34767512e8ed.nc
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_259e8498-1eb6-4e84-a0f6-7575c4ab9b87.pkl
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_8404dc49-081c-48c7-b448-34767512e8ed.pkl
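
The idea behind a cleanup step like this is to compare the results files on disk with those still referenced by database entries and remove the rest; the log above shows `.nc` and `.pkl` files being removed. The sketch below only *finds* such files in a temporary directory. `find_orphaned_files` is hypothetical and is not how `DecodingOutput.cleanup` is implemented.

```python
import tempfile
from pathlib import Path


def find_orphaned_files(directory, referenced_names, suffixes=(".nc", ".pkl")):
    """Return results files in `directory` not referenced by any database entry."""
    referenced = set(referenced_names)
    return sorted(
        path
        for path in Path(directory).iterdir()
        if path.suffix in suffixes and path.name not in referenced
    )


with tempfile.TemporaryDirectory() as tmp:
    for name in ["kept.nc", "kept.pkl", "orphan.nc", "orphan.pkl", "notes.txt"]:
        (Path(tmp) / name).touch()
    orphans = [p.name for p in find_orphaned_files(tmp, ["kept.nc", "kept.pkl"])]

print(orphans)
```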


## Visualization of Decoding Output

The output of decoding can be challenging to visualize with static graphs, especially if the decoding is performed on 2D data.

We can interactively visualize the output of decoding using the [figurl](https://github.com/flatironinstitute/figurl) package. This package lets you create a visualization of the decoding output that can be viewed in a web browser, which is useful for exploring the decoding output over time and sharing the results with others.

**NOTE**: You will need a kachery cloud instance to use this feature. If you are a member of the Frank lab, you should have access to the Frank lab kachery cloud instance. If you are not a member of the Frank lab, you can create your own kachery cloud instance by following the instructions [here](https://github.com/flatironinstitute/kachery-cloud/blob/main/doc/create_kachery_zone.md).

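For each user, you will need to run `kachery-cloud-init` in the terminal and follow the instructions to associate your computer with your GitHub user on the kachery-cloud network. Depending on your installed packages, you may instead be asked to set the `KACHERY_API_KEY` environment variable (or to set `KACHERY_ZONE` to `scratch`), as in the message printed when running the visualization cell.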
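
Before plotting, the visualization cell sums the acausal posterior over the discrete states (`.unstack("state_bins").sum("state")`), leaving a distribution over position alone. The NumPy sketch below shows the same marginalization on a small made-up array. It assumes the stacked state-bin axis is ordered state by state, whereas the real code relies on xarray's labeled `unstack`, which does not depend on ordering.

```python
import numpy as np

n_time, n_states, n_position_bins = 3, 2, 4

# Made-up posterior: each column is one (state, position bin) pair, with all
# Continuous bins first and then all Fragmented bins; each row sums to 1
rng = np.random.default_rng(1)
posterior_stacked = rng.random((n_time, n_states * n_position_bins))
posterior_stacked /= posterior_stacked.sum(axis=1, keepdims=True)

# Unstack to (time, state, position) and sum out the discrete state
posterior_by_state = posterior_stacked.reshape(n_time, n_states, n_position_bins)
posterior_position = posterior_by_state.sum(axis=1)

# Summing over position instead gives the probability of each discrete state
state_probability = posterior_by_state.sum(axis=2)
print(posterior_position.shape, state_probability.shape)
```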


In [23]:
from non_local_detector.visualization import (
    create_interactive_2D_decoding_figurl,
)

(
    position_info,
    position_variable_names,
) = ClusterlessDecodingV1.fetch_position_info(selection_key)
results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values
position_info = position_info.loc[results_time[0] : results_time[-1]]

env = ClusterlessDecodingV1.fetch_environments(selection_key)[0]
spike_times, _ = ClusterlessDecodingV1.fetch_spike_data(selection_key)


create_interactive_2D_decoding_figurl(
    position_time=position_info.index.to_numpy(),
    position=position_info[position_variable_names],
    env=env,
    results=decoding_results,
    posterior=decoding_results.acausal_posterior.isel(intervals=0)
    .unstack("state_bins")
    .sum("state"),
    spike_times=spike_times,
    head_dir=position_info["orientation"],
    speed=position_info["speed"],
)

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  position_df[column][st:en] = np.nan


Kachery let's scientist store data files in the cloud for the purpose of using
cloud-based visualization tools and collaborating with others. This is a free
service when used for scientific research purposes. In order to use it, you must
register using your GitHub account, provide your email, and briefly describe the
purpose of the research. To register, visit https://kachery.vercel.app. Then set
the KACHERY_API_KEY environment variable to your API key.

Alternatively, you can use the "scratch" zone which is subject to regular deletion
of files by setting the KACHERY_ZONE environment variable to "scratch".

For more information, visit https://github.com/magland/kachery.



Exception: KACHERY_API_KEY environment variable is not set

## GPUs
Using GPUs for decoding can give a significant speedup. This is achieved using the [jax](https://jax.readthedocs.io/en/latest/) package.

### Ensuring jax can find a GPU
Assuming you've set up a GPU, we can use `jax.devices()` to check that the decoding code can see it. If a GPU is available, it will be listed.

In the following instance, we do not have a GPU:

In [25]:
import jax

jax.devices()

[CpuDevice(id=0)]
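
To check for a GPU programmatically, for example before starting a long decoding job, you can inspect each device's `platform` attribute. A minimal sketch; the helper name `gpu_available` is our own, and it simply returns `False` if JAX is not installed.

```python
def gpu_available():
    """True if JAX is installed and reports at least one GPU device."""
    try:
        import jax
    except ImportError:
        return False
    return any(device.platform == "gpu" for device in jax.devices())


print(gpu_available())
```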

### Selecting a GPU
If you do have multiple GPUs, you can use the `jax` package to set the device (GPU) that you want to use. For example, if you want to use the second GPU, you can use the following code (uncomment first):

In [26]:
# device_id = 1  # devices are zero-indexed, so 1 is the second GPU
# device = jax.devices()[device_id]
# jax.config.update("jax_default_device", device)
# device

### Monitoring GPU Usage

You can see which GPUs are occupied (if you have multiple GPUs) by running the command `nvidia-smi` in
a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage. 
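
Picking a GPU can also be scripted. `nvidia-smi` can print just the fields you need with `--query-gpu=index,memory.used --format=csv,noheader,nounits`; the sketch below parses that output and returns the index of the GPU using the least memory, which you could then use as `device_id` in the selection cell above. The helper names are hypothetical, `least_used_gpu` requires `nvidia-smi` on your `PATH`, and JAX's device order may not match `nvidia-smi`'s indices (for example when `CUDA_VISIBLE_DEVICES` is set), so double-check before relying on it.

```python
import subprocess


def parse_memory_used(csv_text):
    """Parse 'index, memory.used' lines into {gpu_index: MiB used}."""
    usage = {}
    for line in csv_text.strip().splitlines():
        index, memory_used = (field.strip() for field in line.split(","))
        usage[int(index)] = int(memory_used)
    return usage


def least_used_gpu():
    """Query nvidia-smi and return the index of the GPU with the least memory in use."""
    output = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    usage = parse_memory_used(output)
    return min(usage, key=usage.get)


# Example with made-up output from a three-GPU machine
usage = parse_memory_used("0, 10240\n1, 512\n2, 30000\n")
best = min(usage, key=usage.get)
print(best)
```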

We can monitor GPU use with the terminal command `watch -n 0.1 nvidia-smi`, which
updates the `nvidia-smi` output every 100 ms. This won't work in a notebook, as
the notebook won't display the updates.

Other ways to monitor GPU usage are:

- A 
  [jupyter widget by nvidia](https://github.com/rapidsai/jupyterlab-nvdashboard)
  to monitor GPU usage in the notebook
- A [terminal program](https://github.com/peci1/nvidia-htop) like `nvidia-smi`,
  with more information about which GPUs are being used and by whom.