# main.ipynb

- Compute arousal metrics from behavior
- Compute spectral and multivariate differentiation
- Decode stimuli

In [1]:
import multiprocessing
from itertools import cycle, product
from pathlib import Path

import numpy as np
import pandas as pd
import scipy
import scipy.spatial
import sklearn
import yaml
from joblib import Parallel, delayed
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import (StratifiedKFold, cross_val_score,
                                     cross_validate)
from sklearn.pipeline import Pipeline
from sklearn.utils import parallel_backend
from tqdm.auto import tqdm
from umap import UMAP

import metadata
from spectral_differentiation import compute_spectral_differentiation, TRUNCATION_TOLERANCE
from analysis import (BLOCK_STIMULI, CONTINUOUS_NATURAL, SESSIONS,
                      UNSCRAMBLED_SCRAMBLED, get_cells)
import analysis
from load import load_preprocessed_data, load_preprocessed_events, make_multiindex
from metadata import METADATA, STIMULUS_METADATA

In [2]:
# Register tqdm with pandas for `progress_apply` (used below on a groupby to show a progress bar)
tqdm.pandas()

In [3]:
# Number of CPUs on this machine; used to cap the number of joblib workers below
N_CPUS = multiprocessing.cpu_count()

In [12]:
# Directory (relative to the working directory) for every result file written by this notebook
OUTPUT_DIR = Path('results')
OUTPUT_DIR.mkdir(exist_ok=True)

## Load data

In [4]:
def loader(session):
    """Load one session's preprocessed tables.

    Returns:
        A ``(dff, events, data)`` tuple for `session`.
    """
    events = load_preprocessed_events(session)
    dff, data = load_preprocessed_data(session)
    return (dff, events, data)


loaded_data = Parallel(verbose=5, n_jobs=min(N_CPUS, len(SESSIONS)))(
    delayed(loader)(session)
    for session in SESSIONS
)

[Parallel(n_jobs=44)]: Using backend LokyBackend with 44 concurrent workers.
[Parallel(n_jobs=44)]: Done   2 out of  44 | elapsed:    9.4s remaining:  3.3min
[Parallel(n_jobs=44)]: Done  11 out of  44 | elapsed:   14.4s remaining:   43.1s
[Parallel(n_jobs=44)]: Done  20 out of  44 | elapsed:   19.1s remaining:   23.0s
[Parallel(n_jobs=44)]: Done  29 out of  44 | elapsed:   26.0s remaining:   13.4s
[Parallel(n_jobs=44)]: Done  38 out of  44 | elapsed:   38.7s remaining:    6.1s
[Parallel(n_jobs=44)]: Done  44 out of  44 | elapsed:   53.5s finished


In [5]:
# Split the per-session (dff, events, data) tuples into three dicts keyed by session
DFF, EVENTS, DATA = {}, {}, {}
for session, (dff, events, data) in zip(SESSIONS, loaded_data):
    DFF[session], EVENTS[session], DATA[session] = dff, events, data

In [6]:
# Attach label columns to dF/F and event dataframes
for session, data in tqdm(DATA.items()):
    labels = data.loc[:, ["stimulus", "trial", "stimulus_is_scrambled_pair", "stimulus_is_block"]]
    # Same left join on the shared index for both tables: DFF first, then EVENTS
    for table in (DFF, EVENTS):
        table[session] = table[session].merge(
            labels,
            how="left",
            left_index=True,
            right_index=True,
        )

  0%|          | 0/44 [00:00<?, ?it/s]

# Behavior

Mean pupil diameter, locomotion fraction, and running velocity, averaged per session and per trial.

In [7]:
# Concatenate every session's data into one long table, moving the index levels
# (presumably including `session`, which is grouped on below) into columns
behavior = (
    pd.concat(DATA.values())
    .reset_index()
)

In [8]:
# Behavioral measures to average
cols = ["normalized_pupil_diameter", "locomotion", "filtered_velocity"]

In [13]:
# Average the behavioral measures by session and by trial; each table is also written to disk
groupings = {
    "session": ["session"],
    "trial": ["session", "stimulus", "trial"],
}
behavior_by = {}
for grouping_name, grouping in tqdm(groupings.items()):
    averaged = behavior.loc[:, grouping + cols].groupby(grouping).mean()
    behavior_by[grouping_name] = averaged
    averaged.to_parquet(OUTPUT_DIR / f"behavior_by_{grouping_name}.parquet")

  0%|          | 0/2 [00:00<?, ?it/s]

# Events

### Mean event magnitude per trial

In [14]:
grouping = ['session', 'stimulus', 'trial']


def worker(data):
    """Average each cell column of `data` within every (session, stimulus, trial) group."""
    columns = get_cells(data) + grouping
    per_trial = data.reset_index().loc[:, columns].groupby(grouping)
    return per_trial.mean()

In [15]:
# Per-trial mean event magnitude of each cell, computed per session in parallel and concatenated
magnitudes = pd.concat(Parallel(n_jobs=min(len(SESSIONS), N_CPUS), verbose=5)(
    delayed(worker)(data) for data in tqdm(EVENTS.values())
))

[Parallel(n_jobs=44)]: Using backend LokyBackend with 44 concurrent workers.
[Parallel(n_jobs=44)]: Done   2 out of  44 | elapsed:    1.9s remaining:   38.9s
[Parallel(n_jobs=44)]: Done  11 out of  44 | elapsed:    2.8s remaining:    8.5s
[Parallel(n_jobs=44)]: Done  20 out of  44 | elapsed:    4.4s remaining:    5.2s
[Parallel(n_jobs=44)]: Done  29 out of  44 | elapsed:    5.6s remaining:    2.9s
[Parallel(n_jobs=44)]: Done  38 out of  44 | elapsed:    6.2s remaining:    1.0s
[Parallel(n_jobs=44)]: Done  44 out of  44 | elapsed:    6.4s finished


In [16]:
# Average over cells (NaN entries, e.g. cells from other sessions, are skipped) for one value per trial
mean_magnitudes = magnitudes.mean(axis='columns').rename('mean_magnitude')

### Remove transients from dF/F

In [17]:
# Length of transient to remove, in seconds
TAU = 0.2  # s

# Convert to a whole number of two-photon frames. Compare with a tolerance instead of
# `float.is_integer()`: products such as 0.1 * 30 are integral in principle but evaluate
# to 3.0000000000000004 in floating point.
tau_samples = TAU * metadata.TWOP_SAMPLE_RATE
assert np.isclose(tau_samples, round(tau_samples)), (
    f"TAU * TWOP_SAMPLE_RATE must be a whole number of frames; got {tau_samples}"
)
tau_samples = int(round(tau_samples))

In [18]:
def worker(events, dff):
    """Return `dff` with the samples following each event blanked and re-interpolated.

    For every (frame, cell) where `events` is nonzero, that cell's dF/F values at
    frames ``frame .. frame + tau_samples - 1`` (clipped to the last frame) are set to
    NaN. They are then filled by linear interpolation between the surrounding samples
    (``limit_area='inside'``). Non-cell columns are carried through in their original order.

    Raises:
        KeyError: if `events` lacks any of the cell columns of `dff`.
    """
    cells = get_cells(dff)
    cell_set = set(cells)
    # Index the events by the dF/F cell columns so cell positions line up with `no_transients`
    frame, cell = np.nonzero(events.loc[:, cells].to_numpy())
    no_transients = dff.loc[:, cells].to_numpy(copy=True)
    # Keep the original column order (a set difference would make it arbitrary)
    noncell_cols = dff.loc[:, [col for col in dff.columns if col not in cell_set]]

    last_frame = dff.shape[0] - 1
    for t in range(tau_samples):
        no_transients[np.minimum(frame + t, last_frame), cell] = np.nan

    return (
        pd.DataFrame(
            data=no_transients,
            index=dff.index,
            columns=cells,
        )
        .join(noncell_cols)
        .interpolate(method='linear', limit_area='inside')
    )

In [19]:
# Copies of each session's (events, dF/F) tables, in SESSIONS order, for the transient-removal worker
arg_list = [
    (EVENTS[session].copy(), DFF[session].copy())
    for session in tqdm(SESSIONS)
]

  0%|          | 0/44 [00:00<?, ?it/s]

In [20]:
# Remove transients from every session in parallel; joblib returns results in input
# order, so they can be zipped back with SESSIONS
results = Parallel(n_jobs=min(len(SESSIONS), N_CPUS), verbose=5)(
    delayed(worker)(*args) for args in tqdm(arg_list)
)

DFF_NO_TRANSIENTS = dict(zip(SESSIONS, results))

[Parallel(n_jobs=44)]: Using backend LokyBackend with 44 concurrent workers.
[Parallel(n_jobs=44)]: Done   2 out of  44 | elapsed:    1.6s remaining:   33.7s
[Parallel(n_jobs=44)]: Done  11 out of  44 | elapsed:    6.6s remaining:   19.7s
[Parallel(n_jobs=44)]: Done  20 out of  44 | elapsed:   13.0s remaining:   15.7s
[Parallel(n_jobs=44)]: Done  29 out of  44 | elapsed:   19.7s remaining:   10.2s
[Parallel(n_jobs=44)]: Done  38 out of  44 | elapsed:   23.6s remaining:    3.7s
[Parallel(n_jobs=44)]: Done  44 out of  44 | elapsed:   25.0s finished


# Spectral differentiation

### Build parameter space

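Combinations of parameters to use for the sensitivity analysis. State length, distance metric, and frequency spacing are fully crossed. The window type, window parameter, and overlap are varied together as three fixed settings rather than crossed with each other.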

In [21]:
# Sensitivity-analysis grid built by successive cross joins:
# 6 state lengths x 3 metrics x 2 frequency spacings x 3 window settings = 108 rows.
args = (
    pd.DataFrame({
        # The length of the 'neurophysiological state' of the population, in seconds
        'state_length': [0.2, 0.5, 1, 2, 5, 10],
    })
    .merge(pd.DataFrame({
        # The metric with which to measure distances between neurophysiological states
        'metric': ["cityblock", "euclidean", "chebyshev"],
    }), how='cross')
    .merge(pd.DataFrame({
        # Whether to use log-spaced frequency bins in the spectral estimation
        'log_frequency': [False, True],
    }), how='cross')
    # The three columns below share one DataFrame, so they are paired row-wise
    # (None/-1/0, tukey/0.25/0.125, kaiser/14/0.5) rather than crossed with each other
    .merge(pd.DataFrame({
        # Which window to use
        'window': [
            None,
            'tukey',
            'kaiser',
        ],
        # The window parameter
        'window_param': [
            -1,
            1/4,
            14,
        ],
        # The fraction of window overlap
        'overlap': [
            0,
            1/8,
            1/2
        ],
    }), how='cross')
)
args

Unnamed: 0,state_length,metric,log_frequency,window,window_param,overlap
0,0.2,cityblock,False,,-1.00,0.000
1,0.2,cityblock,False,tukey,0.25,0.125
2,0.2,cityblock,False,kaiser,14.00,0.500
3,0.2,cityblock,True,,-1.00,0.000
4,0.2,cityblock,True,tukey,0.25,0.125
...,...,...,...,...,...,...
103,10.0,chebyshev,False,tukey,0.25,0.125
104,10.0,chebyshev,False,kaiser,14.00,0.500
105,10.0,chebyshev,True,,-1.00,0.000
106,10.0,chebyshev,True,tukey,0.25,0.125


In [22]:
# Parameters used in main analysis
main_params = {
    'state_length': 1.0,
    'metric': 'euclidean',
    'window': None,
    'window_param': -1.0,
    'overlap': 0.0,
    'log_frequency': False,
}
# Sanity check: the main parameter set must be one of the rows of the sensitivity grid
assert main_params in args.to_dict(orient='records')
# As a Series (index = parameter name) for use as worker params and for row matching below
main_params = pd.Series(main_params)
main_params

state_length           1.0
metric           euclidean
window                None
window_param          -1.0
overlap                0.0
log_frequency        False
dtype: object

In [23]:
# Pair every ('dff', (session, dF/F)) item with every parameter row of the grid.
# NOTE(review): this overwrites the `args` DataFrame, so re-running this cell alone fails.
args = list(product(
    [('dff', session_item) for session_item in DFF.items()],
    [row for _, row in args.iterrows()],
))

### Compute all parameter combinations

In [24]:
def worker(datatype, session, data, params):
    """Spectral differentiation of `data` for one parameter set, as a tidy DataFrame.

    The Series returned by `compute_spectral_differentiation` becomes a
    `differentiation` column. The parameter values, `datatype` and `session` are
    recorded as constant columns.
    """
    params = params.to_dict()
    differentiation = compute_spectral_differentiation(
        data,
        fs=metadata.TWOP_SAMPLE_RATE,
        **params,
    )
    differentiation.name = "differentiation"
    tidy = differentiation.reset_index(drop=False)
    return tidy.assign(**params, datatype=datatype, session=session)


results = Parallel(n_jobs=min(len(args), N_CPUS), verbose=5)(
    delayed(worker)(datatype, session, data, params)
    for (datatype, (session, data)), params in tqdm(args)
)
# Collect into one DataFrame and attach stimulus metadata
results = pd.concat(results).reset_index(drop=True).merge(STIMULUS_METADATA, on="stimulus")
# Only consider block stimuli
results = results.loc[results["stimulus_is_block"]]
print("Done.")

print(results.shape)
results.head()

  peaks, properties = scipy.signal.find_peaks(
  peaks, properties = scipy.signal.find_peaks(
  peaks, properties = scipy.signal.find_peaks(
  peaks, properties = scipy.signal.find_peaks(
  peaks, properties = scipy.signal.find_peaks(
[Parallel(n_jobs=160)]: Using backend LokyBackend with 160 concurrent workers.
[Parallel(n_jobs=160)]: Done 130 tasks      | elapsed:  1.2min
[Parallel(n_jobs=160)]: Done 328 tasks      | elapsed:  1.4min
[Parallel(n_jobs=160)]: Done 562 tasks      | elapsed:  1.5min
[Parallel(n_jobs=160)]: Done 832 tasks      | elapsed:  1.7min
[Parallel(n_jobs=160)]: Done 1138 tasks      | elapsed:  1.9min
[Parallel(n_jobs=160)]: Done 1480 tasks      | elapsed:  2.3min
[Parallel(n_jobs=160)]: Done 1858 tasks      | elapsed:  2.7min
[Parallel(n_jobs=160)]: Done 2272 tasks      | elapsed:  3.0min
[Parallel(n_jobs=160)]: Done 2722 tasks      | elapsed:  3.3min
[Parallel(n_jobs=160)]: Done 3208 tasks      | elapsed:  3.8min
[Parallel(n_jobs=160)]: Done 3730 tasks      | ela

Done.
(570240, 17)


Unnamed: 0,stimulus,trial,differentiation,state_length,metric,log_frequency,window,window_param,overlap,datatype,session,stimulus_type,stimulus_is_scrambled_pair,stimulus_is_block,stimulus_is_continuous,stimulus_code,stimulus_filename
0,conspecifics,0.0,17.202698,0.2,cityblock,False,,-1.0,0.0,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
1,conspecifics,1.0,15.73695,0.2,cityblock,False,,-1.0,0.0,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
2,conspecifics,2.0,16.225434,0.2,cityblock,False,,-1.0,0.0,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
3,conspecifics,3.0,20.011245,0.2,cityblock,False,,-1.0,0.0,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
4,conspecifics,4.0,12.592473,0.2,cityblock,False,,-1.0,0.0,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy


### Attach metadata, behavior, mean event magnitude, and compute normalizations and transforms

In [25]:
# Metadata for the analysed sessions only, without bookkeeping columns
metadata_subset = METADATA[METADATA.index.isin(SESSIONS)].drop(
    columns=[
        "operator",
        "qc",
        "notes",
        "all_ophys_experiments",
        "all_ophys_sessions",
        "valid",
    ]
)

In [26]:
def postprocess(results):
    """Attach session metadata, mean event magnitude and trial-averaged behavior to `results`.

    Also adds `normalized differentiation` (differentiation / sqrt(ncells)) and its
    base-10 log. Non-positive values would give -inf/NaN in the log column.
    """
    trial_keys = ['session', 'stimulus', 'trial']
    merged = (
        results
        .merge(metadata_subset, on="session")
        .merge(mean_magnitudes, on=trial_keys)
        .merge(behavior_by['trial'], on=trial_keys)
    )
    normalized = merged['differentiation'] / np.sqrt(merged['ncells'])
    merged['normalized differentiation'] = normalized
    merged['log(normalized differentiation)'] = np.log10(normalized)
    return merged

In [27]:
# Full sensitivity-analysis table with metadata, behavior and normalizations attached
df = postprocess(results)

print(df.shape)
df.head()

(570240, 33)


Unnamed: 0,stimulus,trial,differentiation,state_length,metric,log_frequency,window,window_param,overlap,datatype,...,start_time,order,sex,date_of_birth,mean_magnitude,normalized_pupil_diameter,locomotion,filtered_velocity,normalized differentiation,log(normalized differentiation)
0,conspecifics,0.0,17.202698,0.2,cityblock,False,,-1.0,0.0,dff,...,2018-07-05 08:29:08.784,0,M,2018-03-18,0.000961,0.523037,0.0,-0.020912,2.68661,0.429205
1,conspecifics,0.0,13.742288,0.2,cityblock,False,tukey,0.25,0.125,dff,...,2018-07-05 08:29:08.784,0,M,2018-03-18,0.000961,0.523037,0.0,-0.020912,2.146185,0.331667
2,conspecifics,0.0,3.52614,0.2,cityblock,False,kaiser,14.0,0.5,dff,...,2018-07-05 08:29:08.784,0,M,2018-03-18,0.000961,0.523037,0.0,-0.020912,0.55069,-0.259092
3,conspecifics,0.0,15.66024,0.2,cityblock,True,,-1.0,0.0,dff,...,2018-07-05 08:29:08.784,0,M,2018-03-18,0.000961,0.523037,0.0,-0.020912,2.445719,0.388406
4,conspecifics,0.0,12.49222,0.2,cityblock,True,tukey,0.25,0.125,dff,...,2018-07-05 08:29:08.784,0,M,2018-03-18,0.000961,0.523037,0.0,-0.020912,1.950957,0.290248


### Events

In [28]:
# Spectral differentiation of the event traces, using only the main parameters.
# NOTE: `worker` here is the spectral-differentiation worker defined earlier; `worker` is
# redefined later in the notebook, so this cell depends on execution order.
results = Parallel(n_jobs=min(len(EVENTS), N_CPUS), verbose=5)(
    delayed(worker)('events', session, data, main_params)
    for session, data in EVENTS.items()
        
)

results = (
    # Put results into a DataFrame
    pd.concat(results)
    .reset_index(drop=True)
    # Add stimulus metadata
    .merge(STIMULUS_METADATA, on="stimulus")
)
# Only consider block stimuli
results = results.loc[results["stimulus_is_block"]]
print("Done.")

[Parallel(n_jobs=44)]: Using backend LokyBackend with 44 concurrent workers.
[Parallel(n_jobs=44)]: Done   2 out of  44 | elapsed:   17.9s remaining:  6.2min
[Parallel(n_jobs=44)]: Done  11 out of  44 | elapsed:   18.8s remaining:   56.4s
[Parallel(n_jobs=44)]: Done  20 out of  44 | elapsed:   20.8s remaining:   25.0s
[Parallel(n_jobs=44)]: Done  29 out of  44 | elapsed:   22.6s remaining:   11.7s
[Parallel(n_jobs=44)]: Done  38 out of  44 | elapsed:   23.7s remaining:    3.7s
[Parallel(n_jobs=44)]: Done  44 out of  44 | elapsed:   24.3s finished


Done.


In [29]:
# Event-based results with metadata, behavior and normalizations attached
main_events = postprocess(results)
main_events.shape

  result = getattr(ufunc, method)(*inputs, **kwargs)


(5280, 33)

### No transients

In [30]:
# Spectral differentiation of the transient-free dF/F, using only the main parameters
# (labelled datatype 'dff'; they are kept apart from the main dF/F results by being written to a separate file)
results = Parallel(n_jobs=min(len(DFF_NO_TRANSIENTS), N_CPUS), verbose=5)(
    delayed(worker)('dff', session, data, main_params)
    for session, data in DFF_NO_TRANSIENTS.items()
        
)

results = (
    # Put results into a DataFrame
    pd.concat(results)
    .reset_index(drop=True)
    # Add stimulus metadata
    .merge(STIMULUS_METADATA, on="stimulus")
)
# Only consider block stimuli
results = results.loc[results["stimulus_is_block"]]
print("Done.")

print(results.shape)
results.head()

[Parallel(n_jobs=44)]: Using backend LokyBackend with 44 concurrent workers.
[Parallel(n_jobs=44)]: Done   2 out of  44 | elapsed:    1.7s remaining:   35.4s
[Parallel(n_jobs=44)]: Done  11 out of  44 | elapsed:    3.6s remaining:   10.8s
[Parallel(n_jobs=44)]: Done  20 out of  44 | elapsed:    6.6s remaining:    7.9s
[Parallel(n_jobs=44)]: Done  29 out of  44 | elapsed:    9.0s remaining:    4.7s
[Parallel(n_jobs=44)]: Done  38 out of  44 | elapsed:   10.2s remaining:    1.6s
[Parallel(n_jobs=44)]: Done  44 out of  44 | elapsed:   10.6s finished


Done.
(5280, 17)


Unnamed: 0,stimulus,trial,differentiation,state_length,metric,window,window_param,overlap,log_frequency,datatype,session,stimulus_type,stimulus_is_scrambled_pair,stimulus_is_block,stimulus_is_continuous,stimulus_code,stimulus_filename
0,conspecifics,0.0,45.763957,1.0,euclidean,,-1.0,0.0,False,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
1,conspecifics,1.0,37.440279,1.0,euclidean,,-1.0,0.0,False,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
2,conspecifics,2.0,53.230197,1.0,euclidean,,-1.0,0.0,False,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
3,conspecifics,3.0,50.683034,1.0,euclidean,,-1.0,0.0,False,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy
4,conspecifics,4.0,18.912759,1.0,euclidean,,-1.0,0.0,False,dff,717208879,natural,False,True,True,4,04-multiple-mice-deep.mp4.npy


In [31]:
# Transient-free results with metadata, behavior and normalizations attached
main_no_transients = postprocess(results)
main_no_transients.shape

(5280, 33)

## Write results to disk

### All parameters (dF/F)

In [32]:
# Every parameter combination of the sensitivity analysis
df.to_parquet(OUTPUT_DIR/'sensitivity_analysis.parquet')

### Main parameters (dF/F)

In [33]:
# Rows of the sensitivity analysis matching the main parameter set;
# a parameter whose main value is None is matched by a missing value
is_main = np.ones(len(df), dtype=bool)
for param, value in main_params.items():
    is_main &= (df[param].isna() if value is None else df[param] == value).to_numpy()
main = df.loc[is_main]
main.shape

(5280, 33)

In [34]:
# dF/F results for the main parameter set only
main.to_parquet(OUTPUT_DIR/'main.parquet')

### Events

In [35]:
# Event-based results for the main parameter set
main_events.to_parquet(OUTPUT_DIR/'main__events.parquet')

### No transients

In [36]:
# Transient-free dF/F results for the main parameter set
main_no_transients.to_parquet(OUTPUT_DIR/'main__no-transients.parquet')

# Multivariate differentiation

## Dimensionality reduction

In [37]:
# One UMAP job per (datatype, session): every session for 'dff', then every session for 'events'
args = (
    pd.DataFrame({
        'datatype': ["dff", "events"],
    }).merge(pd.DataFrame({
        'session': SESSIONS,
    }), how='cross')
)

# Pair each job's parameters with its raw input table ('events' -> EVENTS, otherwise DFF)
args = [
    (row, (EVENTS[row['session']] if row['datatype'] == 'events' else DFF[row['session']]))
    for _, row in args.iterrows()
]

In [38]:
def preprocess_dff(dff):
    """Keep only the rows recorded during scrambled-pair stimuli.

    Rows whose `stimulus_is_scrambled_pair` flag is missing are dropped as well.
    """
    is_scrambled_pair = dff["stimulus_is_scrambled_pair"].eq(True)
    return dff[is_scrambled_pair]

In [39]:
STATE_LENGTH = 1.0  # seconds

def preprocess_events(events):
    """Prepare one session's events for dimensionality reduction.

    Steps:
    - restrict to scrambled-pair stimuli;
    - bin into STATE_LENGTH-second states (label columns keep their first value, cell
      columns are summed);
    - keep only states selected by `analysis.active_frames`, i.e. those with at least
      one event;
    - index the result by (session, binned index) via `make_multiindex`.
    """
    # Take the session from the events themselves. The previous code read it from the
    # leaked global `data`, which labelled every session with the same session ID.
    session = events.index.get_level_values('session')[0]
    # Select stimuli of interest
    events = events.loc[events["stimulus_is_scrambled_pair"] == True, :]
    # Bin events
    agg = {
        **{col: 'first' for col in events.columns},
        **{col: 'sum' for col in get_cells(events)}
    }
    events = analysis.bin_data(events, STATE_LENGTH, agg=agg)
    # Select frames with at least one event (this selection was previously discarded)
    events = events.loc[analysis.active_frames(events)]
    # Include session data in index for later concatenation & grouping
    events.index = make_multiindex(session, events.index)
    return events

In [40]:
# UMAP hyperparameters shared by every embedding (parsed with PyYAML's safe loader)
with open('umap-params.yml', mode='rt') as f:
    UMAP_PARAMS = yaml.safe_load(f)

In [41]:
def worker(params, data):
    """Embed one session's preprocessed population activity with UMAP.

    `params['datatype']` selects the preprocessing: 'dff' or 'events'. Returns one row
    per sample with `umap_<i>` columns, the `stimulus`/`trial` labels, and the values of
    `params` as constant columns.

    Raises:
        ValueError: if `params['datatype']` is neither 'dff' nor 'events'. Previously an
            unknown datatype silently embedded the raw, unpreprocessed data.
    """
    datatype = params['datatype']
    if datatype == 'dff':
        data = preprocess_dff(data)
    elif datatype == 'events':
        data = preprocess_events(data)
    else:
        raise ValueError(f"unknown datatype: {datatype!r}")

    labels = data.loc[:, ["stimulus", "trial"]]
    signal = data.loc[:, get_cells(data)]

    transform = UMAP(**UMAP_PARAMS).fit_transform(signal)

    return (
        pd.DataFrame(
            data=transform,
            index=data.index,
            columns=[f"umap_{i}" for i in range(transform.shape[1])],
        )
        .join(labels)
        .assign(**params)
    )

In [42]:
%%time

# Fit one UMAP embedding per (datatype, session) job in parallel
results = Parallel(n_jobs=min(len(args), N_CPUS), verbose=5)(
    delayed(worker)(params, data) for params, data in tqdm(args)
)

  0%|          | 0/88 [00:00<?, ?it/s]

[Parallel(n_jobs=88)]: Using backend LokyBackend with 88 concurrent workers.
[Parallel(n_jobs=88)]: Done   3 out of  88 | elapsed:  3.8min remaining: 107.4min
[Parallel(n_jobs=88)]: Done  21 out of  88 | elapsed:  4.3min remaining: 13.8min
[Parallel(n_jobs=88)]: Done  39 out of  88 | elapsed:  4.6min remaining:  5.8min
[Parallel(n_jobs=88)]: Done  57 out of  88 | elapsed:  4.7min remaining:  2.6min
[Parallel(n_jobs=88)]: Done  75 out of  88 | elapsed:  4.9min remaining:   51.2s
[Parallel(n_jobs=88)]: Done  88 out of  88 | elapsed:  5.2min finished


CPU times: user 10.9 s, sys: 41.3 s, total: 52.2 s
Wall time: 5min 11s


In [43]:
# Stack all (datatype, session) embeddings into one DataFrame
transforms = pd.concat(results)

----

## Mean centroid distance

In [44]:
def centroid(points):
    """Return the mean of `points` over axis 0 as a (1, n_features) row array.

    Accepts a DataFrame or any array-like. Previously a DataFrame failed here because
    the Series returned by `DataFrame.mean` has no `reshape`.
    """
    return np.asarray(points).mean(axis=0).reshape(1, -1)


def centroid_distance(points):
    """Euclidean distance from each point (row) to the centroid of all points.

    Args:
        points: DataFrame or array-like of shape (n_points, n_features).

    Returns:
        Array of shape (n_points, 1).

    Raises:
        ValueError: if `points` is not 2-D.
    """
    points = np.asarray(points)
    if points.ndim != 2:
        raise ValueError("points must be 2D")
    return scipy.spatial.distance.cdist(points, centroid(points), metric="euclidean")


def mean_centroid_distance(points):
    """Mean Euclidean distance of the points (rows) to their centroid."""
    return np.mean(centroid_distance(points))

In [45]:
# Mean distance to the centroid of each group's UMAP embedding. Groups are formed
# by all non-UMAP columns (labels plus the job parameters).
umap_cols = [col for col in transforms.columns if col.startswith("umap")]
group_cols = [col for col in transforms.columns if not col.startswith("umap")]
results = pd.Series(
    data=(
        transforms
        .droplevel('session', axis='index')
        .groupby(group_cols)
        # Pass only the embedding coordinates to the distance computation. Otherwise
        # the (string) grouping columns reach `cdist`, depending on pandas version.
        [umap_cols]
        .progress_apply(mean_centroid_distance)
    ),
    name='mean_centroid_distance',
).reset_index()

results["log(mean_centroid_distance)"] = np.log10(results["mean_centroid_distance"])

results = (
    results
    .merge(STIMULUS_METADATA, left_on="stimulus", right_index=True)
    .merge(METADATA, on="session")
)

  0%|          | 0/4400 [00:00<?, ?it/s]

### Write results to disk

In [46]:
# dF/F-based mean centroid distances.
# NOTE(review): this rebinds the name of the `mean_centroid_distance` function defined above.
mean_centroid_distance = results.loc[
    results['datatype'] == 'dff'
]
mean_centroid_distance.to_parquet(OUTPUT_DIR/"mean_centroid_distance.parquet")

In [47]:
# Event-based mean centroid distances
mean_centroid_distance_events = results.loc[
    results['datatype'] == 'events'
]
mean_centroid_distance_events.to_parquet(OUTPUT_DIR/"mean_centroid_distance__events.parquet")

# Decoding

## Helper functions

In [48]:
def flatten_trial(data, target_length):
    """Truncate a trial to `target_length` frames and flatten it into a 1-D feature vector.

    The `trial` and `stimulus` label columns are dropped before flattening, so the
    result is the row-major concatenation of the remaining columns frame by frame.

    Raises:
        ValueError: if the trial's length differs from `target_length` by more than
            TRUNCATION_TOLERANCE frames.
    """
    if abs(len(data) - target_length) > TRUNCATION_TOLERANCE:
        # Previously this message referenced an undefined name (`tolerance`), which
        # raised a NameError instead of this ValueError.
        raise ValueError(
            f"Expected length within tolerance {TRUNCATION_TOLERANCE} of "
            f"{target_length}; got {len(data)}"
        )
    return (
        data.iloc[:target_length]
        .drop(axis="columns", labels=["trial", "stimulus"])
        .to_numpy()
        .flatten()
    )


def get_labelled_data(session, stimuli, get_class):
    """Return per-trial feature vectors and labels for one session.

    Reshapes and concatenates dF/F data for all `stimuli`. X must be of shape
    (n_samples, n_features):
    - samples are trials;
    - features are the dF/F value of each cell at every frame within the trial.

    Each trial is truncated to the shortest trial's frame count so that all trials
    have the same number of features.

    Args:
        session: key into DFF.
        stimuli: stimulus names to include.
        get_class: maps a stimulus name to its class label.

    Returns:
        Tuple ``(samples, classes, trials, trial_stimuli)``. `samples` has one row per
        (trial, stimulus) group. The other arrays give that group's class, trial and
        stimulus, in the same order.
    """
    dff = DFF[session]
    # Select stimulus subset and relevant columns
    dff = dff.loc[
        dff["stimulus"].isin(stimuli),
        get_cells(dff) + ['stimulus', 'trial']
    ]
    by_trial = dff.groupby(["trial", "stimulus"])
    # Frame count of the shortest trial
    target_length = int(by_trial.size().min())
    X = by_trial.apply(flatten_trial, target_length=target_length)
    trial_stimuli = X.index.get_level_values("stimulus")
    classes = trial_stimuli.map(get_class)
    trials = X.index.get_level_values("trial")
    # Stack the per-trial vectors into an (n_samples, n_features) array
    samples = np.stack(X.to_numpy())
    return samples, classes.to_numpy(), trials.to_numpy(), trial_stimuli.to_numpy()

## Compute

In [49]:
# joblib convention: -1 uses all available CPUs for the per-session decoding jobs
N_JOBS = -1
N_SPLITS = 5
# Stratified K-fold without shuffling, so the folds are deterministic
CV = StratifiedKFold(n_splits=N_SPLITS)
PIPELINE = Pipeline(
    [
        # A float n_components keeps enough components to explain 99% of the variance
        ("reduce_dimensionality", PCA(n_components=0.99)),
        # Shrinkage LDA; "auto" picks the shrinkage analytically (Ledoit-Wolf)
        ("classify", LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")),
    ]
)

In [50]:
def identity(x):
    """Return the stimulus name unchanged (used to decode stimulus identity)."""
    return x


# Map a stimulus name to its `stimulus_type` (used to decode stimulus category)
category = STIMULUS_METADATA.stimulus_type.get

In [51]:
%%time

def worker(session, stimulus_set, get_class, scoring):
    X, y, trials, stimuli = get_labelled_data(
        session,
        stimuli=stimulus_set,
        get_class=get_class,
    )
    if scoring == "balanced_accuracy":
        scores = cross_val_score(PIPELINE, X, y, cv=CV, scoring=scoring, n_jobs=1)
        return pd.DataFrame([{scoring: scores.mean(), 'session': session}])
    if scoring == "f1_score":
        labels = sorted(np.unique(y))
        scoring = {
            label: sklearn.metrics.make_scorer(
                sklearn.metrics.f1_score,
                labels=[label],
                average="macro",
            )
            for label in labels
        }
        scores = cross_validate(PIPELINE, X, y, cv=CV, scoring=scoring, n_jobs=1)
        scores["session"] = session
        return pd.DataFrame(scores)
    raise ValueError("invalid scoring method")


for stimulus_set, scoring, get_class in tqdm(
    [
        # Figure 6
        (UNSCRAMBLED_SCRAMBLED, "balanced_accuracy", category),
        # Figure S8
        (BLOCK_STIMULI, "balanced_accuracy", identity),
        # Figure S9
        (CONTINUOUS_NATURAL, "f1_score", identity),
    ]
):
    if get_class is identity:
        decode_target = "identity"
    if get_class is category:
        decode_target = "category"

    if stimulus_set == UNSCRAMBLED_SCRAMBLED:
        stimulus_set_name = "unscrambled_scrambled"
    if stimulus_set == BLOCK_STIMULI:
        stimulus_set_name = "all"
    if stimulus_set == CONTINUOUS_NATURAL:
        stimulus_set_name = "continuous_natural"

    output_path = OUTPUT_DIR/f"decoding__stimuli-{stimulus_set_name}__decode-{decode_target}__scoring-{scoring}.parquet"
    
    chance = 1 / len(set(map(get_class, stimulus_set)))

    with parallel_backend("loky", n_jobs=N_JOBS):
        results = Parallel()(
            delayed(worker)(session, stimulus_set, get_class, scoring)
            for session in SESSIONS
        )

    result_df = pd.concat(results).merge(METADATA, on="session")
    result_df["chance"] = chance

    result_df.to_parquet(output_path)

  0%|          | 0/3 [00:00<?, ?it/s]

CPU times: user 6min 12s, sys: 2min 52s, total: 9min 4s
Wall time: 9min 56s
