This tutorial demonstrates how to:
- fit tuning curves given behavioral labels (e.g. position)
- perform state-space decoding in the fashion of [Denovellis et al. (2021)](https://elifesciences.org/articles/64505). The decoder gives the posterior probability of the *label* and the *dynamics type* (which they call the continuous and discrete variables, respectively).
    - *dynamics type* specifies the temporal prior of the label. When the dynamics type is *continuous*, the temporal prior of the label is a Gaussian random walk, with a movement variance specified by the user. When the dynamics type is *fragmented*, the temporal prior of the label is uniform across all possible bins.
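
To make the two temporal priors concrete, here is a minimal NumPy sketch of the corresponding transition matrices over a 1D set of label bins. It is an illustration only, not the library's implementation: the function names, the 1D bin layout, and the parameter `movement_std` (whose square plays the role of the movement variance) are assumptions made for this example.

```python
import numpy as np

def continuous_transition(n_bins, bin_size, movement_std):
    """Gaussian random-walk transition matrix over 1D label bins (hypothetical helper).

    Row i holds the probability of moving from bin i to every other bin.
    """
    centers = np.arange(n_bins) * bin_size
    dist = centers[:, None] - centers[None, :]
    trans = np.exp(-0.5 * (dist / movement_std) ** 2)
    return trans / trans.sum(axis=1, keepdims=True)  # normalize each row to sum to 1

def fragmented_transition(n_bins):
    """Uniform transition matrix: every bin is equally likely at the next time step."""
    return np.full((n_bins, n_bins), 1.0 / n_bins)

cont = continuous_transition(n_bins=50, bin_size=3.0, movement_std=6.0)
frag = fragmented_transition(n_bins=50)

# Under the continuous prior the label stays close to its previous bin;
# under the fragmented prior it can jump anywhere with equal probability.
print(cont[25, 24:27], frag[25, 24:27])
```

With the continuous matrix, the posterior can only drift smoothly between time steps, whereas the fragmented matrix lets it jump to any bin, which is what allows the decoder to tell the two dynamics types apart.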


# Import

In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
import poor_man_gplvm as pmg
import poor_man_gplvm.plot_helper as ph
import pynapple as nap

# Load the data (ignore this section and replace with your own data)
Some words on data preprocessing. We highly recommend [pynapple](https://pynapple.org/) as an entry point for neural data analysis in Python. Its objects wrap NumPy arrays and add useful functionality such as restricting to time intervals, aligning to common timestamps, and turning spike times into counts. For this tutorial, we need:
- *spk_times*: pynapple TsGroup, obtained from a list of spike times (from the entire recording) for each unit.
- *position_tsdf*: pynapple TsdFrame, obtained from an array of shape (n_time, n_columns), timestamps, and column names. Each column is one behavior label we will decode (it doesn't have to be position).
- *behavior_ep*: pynapple IntervalSet, obtained from arrays of start and end times of the behavior epochs over which the tuning curves are computed.
- *speed_tsd* (optional): pynapple Tsd, obtained from an array of shape (n_time,) and timestamps. Here it is used to select the locomotion epochs to include in the tuning curve computation.


In [18]:
import sys,os
sys.path.append('../../poor_gplvm/code')
import preprocess_roman_tmaze as preprt



data_dir_full = preprt.db_roman.iloc[0]['data_dir_full']

prep_res = nap.load_folder(os.path.join(data_dir_full, "derivatives"))  


spk_times = prep_res["spk_times"]
ripple_intervals = prep_res["ripple_intervals"]
position_tsdf = prep_res["position_tsdf"]
behavior_ep = prep_res['behavior_ep']
speed_tsd = prep_res['speed_tsd']

# Prepare the data

## turn spike train into a matrix (TsdFrame, n_time x n_neuron) of spike counts 
Optional: use a mask to select only the pyramidal cells. This is easy if the relevant mask, e.g. *is_pyr* (whether a unit is a pyramidal cell), is stored as metadata in the TsGroup.


In [15]:
spk_times_pyr = spk_times[spk_times['is_pyr']]  # keep only the pyramidal cells
spk_mat = spk_times_pyr.count(0.1, ep=behavior_ep)  # spike counts in 0.1 s bins within the behavior epochs
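
For readers who are not using pynapple, the sketch below shows the same idea in plain NumPy: binning each unit's spike times into a (n_time, n_neuron) count matrix. It is a minimal illustration, not pynapple's implementation; `count_spikes` and the toy spike times are hypothetical.

```python
import numpy as np

def count_spikes(spike_times_per_unit, start, end, bin_size):
    """Bin spike times into a (n_time, n_neuron) count matrix (hypothetical helper)."""
    n_bins = int(round((end - start) / bin_size))
    edges = start + bin_size * np.arange(n_bins + 1)
    # One histogram per unit, stacked as columns.
    counts = np.stack(
        [np.histogram(spikes, bins=edges)[0] for spikes in spike_times_per_unit],
        axis=1,
    )
    return edges[:-1], counts  # left bin edges and the count matrix

# Toy data: two units, spike times in seconds.
toy_spikes = [np.array([0.1, 0.2, 1.7]), np.array([0.6, 0.7, 0.8, 1.9])]
bin_starts, toy_counts = count_spikes(toy_spikes, start=0.0, end=2.0, bin_size=0.5)
print(toy_counts.shape)  # (4, 2): four time bins, two units
```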

## prepare the labels and hyperparameters
### labels
In the paradigm of spatial navigation, *label_l* can be a time series of:
- linearized positions 
- 2D positions 
- [choice port ID](https://www.nature.com/articles/s41586-024-08397-7)
- linearized positions + direction
- 2D positions + direction

Indeed, in contrast to existing spatial decoding libraries, we allow an arbitrary number of label dimensions. The practical limit is memory: if you are already using 2D positions, any extra dimensions should not have too many discretized bins.
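
A quick back-of-the-envelope calculation shows why the number of bins per dimension matters. The bin counts and neuron count below are made-up numbers, and whether a dense transition matrix over joint bins is ever stored is an assumption made for illustration, not a statement about the library.

```python
import numpy as np

# Hypothetical bin counts per label dimension: x, y, and running direction.
n_bins_per_dim = (40, 40, 2)
n_joint_bins = int(np.prod(n_bins_per_dim))  # total number of joint label bins

# Tuning curves: one float64 value per joint bin and neuron.
n_neuron = 100
tuning_megabytes = n_joint_bins * n_neuron * 8 / 1e6

# A dense bin-to-bin transition matrix would grow quadratically with the joint bin count.
transition_megabytes = n_joint_bins ** 2 * 8 / 1e6

# Joint bins can be addressed with a single flat index and mapped back.
flat_index = np.ravel_multi_index((10, 25, 1), n_bins_per_dim)
multi_index = tuple(int(i) for i in np.unravel_index(flat_index, n_bins_per_dim))

print(n_joint_bins, tuning_megabytes, transition_megabytes, multi_index)
```

Doubling the number of bins in one extra dimension doubles the tuning curve size but quadruples any quantity that scales with pairs of joint bins, which is why extra dimensions are best kept coarse.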

Even for a linearizable maze like the alternating T-maze, I personally still prefer using 2D positions as labels. For a linear track, I would use 1D position + direction, although 2D could give a subtler picture, as hinted by [Zutshi et al. (2025)](https://www.nature.com/articles/s41586-024-08397-7).


### Hyperparameters
- *label_bin_size*: bin size for discretizing the labels.
- *smooth_std*: the standard deviation of the Gaussian kernel for smoothing the tuning curves. If None, no smoothing is applied.
- *occupancy_threshold*: the minimum occupancy (in seconds) for a label bin to be considered valid.

Each of the above can be: 1) a single number that applies to all dimensions; 2) a list with one value per dimension; or 3) for multiple mazes, a dictionary of {maze_key: val}, where val takes form 1) or 2).
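
The three accepted forms can be thought of as shorthand for one fully expanded form, {maze_key: [one value per dimension]}. The helper below is a hypothetical sketch of that expansion, written only to illustrate the convention; it is not part of the library, and the library may resolve these inputs differently.

```python
def expand_hyperparam(value, n_dims):
    """Expand a scalar, per-dimension list, or {maze_key: ...} dict (hypothetical helper).

    n_dims maps each maze key to its number of label dimensions.
    Returns {maze_key: [one value per dimension]}.
    """
    expanded = {}
    for maze_key, n_dim in n_dims.items():
        # Form 3: a dictionary keyed by maze; otherwise the same value is shared by all mazes.
        val = value[maze_key] if isinstance(value, dict) else value
        if isinstance(val, (list, tuple)):
            # Form 2: one value per dimension.
            if len(val) != n_dim:
                raise ValueError(f"{maze_key}: expected {n_dim} values, got {len(val)}")
            expanded[maze_key] = list(val)
        else:
            # Form 1: a single value applied to every dimension.
            expanded[maze_key] = [val] * n_dim
    return expanded

n_dims = {"familiar": 2, "novel": 2}
shared = expand_hyperparam(3.0, n_dims)
per_maze = expand_hyperparam({"familiar": 3.0, "novel": [3.0, None]}, n_dims)
print(shared, per_maze)
```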

Here we will demonstrate the more general syntax assuming multiple mazes, but know that it can be simplified.

In [17]:
label_d = {}
label_d['familiar'] = position_tsdf[['x', 'y']].restrict(behavior_ep[0])  # restrict limits the x, y coordinates to the first behavior epoch

ep_d = {}
ep_d['familiar'] = speed_tsd.restrict(behavior_ep[0]).threshold(5).time_support  # keep only the times when speed exceeds 5

label_bin_size_d = {}
label_bin_size_d['familiar'] =  3.

smooth_std_d = {}
smooth_std_d['familiar'] = 3.
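
The speed-thresholding step used for `ep_d` above amounts to finding the contiguous time intervals where speed exceeds a threshold. A minimal NumPy sketch of that idea follows; it is not how pynapple implements it, and `above_threshold_intervals` and the toy speed trace are hypothetical.

```python
import numpy as np

def above_threshold_intervals(t, speed, thresh):
    """Return (start, end) times of contiguous runs where speed > thresh (hypothetical helper)."""
    above = speed > thresh
    # Pad with False so that runs touching either end of the recording are closed.
    padded = np.concatenate([[False], above, [False]])
    change = np.diff(padded.astype(int))
    starts = np.flatnonzero(change == 1)     # index of the first sample in each run
    ends = np.flatnonzero(change == -1) - 1  # index of the last sample in each run
    return np.column_stack([t[starts], t[ends]])

# Toy data: 8 samples at 0.5 s resolution.
t = np.arange(8) * 0.5
speed = np.array([1.0, 6.0, 7.0, 2.0, 1.0, 8.0, 9.0, 10.0])
intervals = above_threshold_intervals(t, speed, thresh=5.0)
print(intervals)
```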


## Optional: a second, novel linear maze, with the time window given by `behavior_ep[1]`

In [None]:
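# import numpy as np  # needed for np.stack and np.abs below

# novel_lin = position_tsdf[['lin']].restrict(behavior_ep[1])
# novel_lin_dir = novel_lin.derivative() > 0  # running direction: True when the linearized position increases

# beh_tsdf_novel = nap.TsdFrame(d=np.stack([novel_lin.d, novel_lin_dir.d.astype(int)], axis=1), t=novel_lin.t, columns=['lin', 'dir'])
# label_d['novel'] = beh_tsdf_novel

# speed_tsd_novel = np.abs(novel_lin.derivative())
# ep_d['novel'] = speed_tsd_novel.threshold(5).time_support

# label_bin_size_d['novel'] = [3., 1.]  # one bin size per label dimension (lin, dir)
# smooth_std_d['novel'] = [3, None]  # no smoothing along the direction dimension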

# Compute tuning curves
Tuning curves are computed in six steps:
1. Discretize the multi-dimensional labels into bins.
2. Compute the occupancy of each bin.
3. Drop the low-occupancy bins.
4. Count the spikes emitted within each bin.
5. Smooth the occupancy and spike counts with a Gaussian kernel.
6. Compute the firing rate in Hz: FR = count_smoothed / occupancy_smoothed.
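
The steps above can be sketched for a single 1D label in plain NumPy. This is a simplified illustration under stated assumptions, not the library's code: the function names are hypothetical, low-occupancy bins are zeroed before smoothing and reported as NaN (the library may handle them differently), and the data are simulated.

```python
import numpy as np

def gaussian_smooth(x, std_bins):
    """Smooth a 1D array with a normalized Gaussian kernel (std in units of bins)."""
    half = int(np.ceil(4 * std_bins))
    kernel = np.exp(-0.5 * (np.arange(-half, half + 1) / std_bins) ** 2)
    kernel /= kernel.sum()
    # Zero-pad so the output has the same length as the input.
    return np.convolve(np.pad(x, half), kernel, mode="valid")

def tuning_curve_1d(label, counts, dt, bin_size, smooth_std, occupancy_threshold):
    """label: (n_time,), counts: (n_time, n_neuron). Returns bin edges and rates in Hz."""
    # 1) Discretize the label into bins.
    edges = np.arange(label.min(), label.max() + bin_size, bin_size)
    n_bins = len(edges) - 1
    idx = np.clip(np.digitize(label, edges) - 1, 0, n_bins - 1)
    # 2) Occupancy of each bin, in seconds.
    occupancy = np.bincount(idx, minlength=n_bins) * dt
    # 3) Mark low-occupancy bins as invalid.
    valid = occupancy >= occupancy_threshold
    # 4) Spike counts emitted within each bin.
    spikes = np.zeros((n_bins, counts.shape[1]))
    np.add.at(spikes, idx, counts)
    # 5) Smooth occupancy and spike counts (invalid bins contribute nothing).
    std_bins = smooth_std / bin_size
    occ_s = gaussian_smooth(np.where(valid, occupancy, 0.0), std_bins)
    spk_s = np.stack(
        [gaussian_smooth(np.where(valid, spikes[:, j], 0.0), std_bins)
         for j in range(spikes.shape[1])],
        axis=1,
    )
    # 6) Firing rate in Hz; invalid bins are left as NaN.
    rate = np.full_like(spk_s, np.nan)
    rate[valid] = spk_s[valid] / occ_s[valid][:, None]
    return edges, rate

# Simulated place cell with a 20 Hz field centered at 50 cm.
rng = np.random.default_rng(0)
dt, bin_size = 0.1, 3.0
position = rng.uniform(0.0, 100.0, size=20000)
true_rate = 20.0 * np.exp(-0.5 * ((position - 50.0) / 8.0) ** 2)
sim_counts = rng.poisson(true_rate * dt)[:, None]

edges, rate = tuning_curve_1d(position, sim_counts, dt, bin_size,
                              smooth_std=3.0, occupancy_threshold=0.1)
centers = edges[:-1] + bin_size / 2
peak_center = centers[np.nanargmax(rate[:, 0])]
print(peak_center, np.nanmax(rate))
```

Smoothing the counts and the occupancy separately before dividing, rather than smoothing the ratio, keeps poorly sampled bins from dominating the estimate.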

In [23]:
from importlib import reload
reload(pmg)


In [21]:
tuning_res = pmg.get_tuning_supervised(
    label_l=label_d,                   # {maze_key: nap.TsdFrame of labels}
    spk_mat=spk_mat,                   # nap.TsdFrame, n_time x n_neuron
    ep=ep_d,                           # {maze_key: nap.IntervalSet} (optional)
    label_bin_size=label_bin_size_d,   # cm
    smooth_std=smooth_std_d,           # cm, Gaussian kernel std
    occupancy_threshold=0.1,           # seconds
)


