# Cross validation
Demonstration of the benefit of using the sklearn framework

## 1. Import Packages

Below, we import standard packages as well as functions from the accompanying .py files.

In [26]:
# Import standard packages

import os
import pickle

import numpy as np
import torch
from scipy import io, stats
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from skorch import NeuralNetRegressor
from torch import optim
from tqdm import tqdm

from Neural_Decoding.nn import FNN
from Neural_Decoding.preprocessing_funcs import LagMat

## 2. Load Data
The data for this example can be downloaded at this [link](https://www.dropbox.com/sh/n4924ipcfjqc0t6/AACPWjxDKPEzQiXKUUFriFkJa?dl=0&preview=example_data_s1.pickle). It was recorded by Raeed Chowdhury from Lee Miller's lab at Northwestern.


The data that we load is in the format described below. Another example notebook, "Example_format_data", may be helpful for putting your own data into this format.

Neural data should be a matrix of size "number of time bins" x "number of neurons", where each entry is the firing rate of a given neuron in a given time bin

The output you are decoding should be a matrix of size "number of time bins" x "number of features you are decoding"

 

In [2]:
# Directory that holds the downloaded example data -- edit to match your setup.
folder = "Decoding_data"
data_path = os.path.join(folder, "example_data_s1.pickle")

# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only open files obtained from a source you trust.
# encoding="latin1" is passed for Python 3 (presumably because the file was
# pickled under Python 2 -- confirm).
with open(data_path, "rb") as data_file:
    neural_data, vels_binned = pickle.load(data_file, encoding="latin1")

## 3. Preprocess Data

### 3A. User Inputs
The user can define what time period to use spikes from (with respect to the output).

In [3]:
# Window of neural data used for decoding, counted in time bins relative to
# the output bin.

# Number of bins of neural data prior to the output.
BIN_BEFORE = 6
# Whether to use the concurrent time bin of neural data.
BIN_CURRENT = 1
# Number of bins of neural data after the output.
BIN_AFTER = 6

### 3B. Format Covariates

In [4]:
# Set decoding input
X = neural_data.astype("float32")  # torch requires float32

# Set decoding output
y = vels_binned.astype("float32")

## 4. Run cross validation

#### User Options

KFold cross-validation

In [22]:
N_SPLITS = 3  # number of folds, passed to KFold(n_splits=...) in the cross-validation loop

In [27]:
# Seed for torch's global random number generator. It is re-applied before each
# model is cross-validated so that the stochastic parts of FNN training
# (shuffled batches, dropout, weight initialisation) repeat from run to run.
SEED = 0

# Decoders to compare, keyed by a short display name. Each estimator is a
# complete sklearn Pipeline (scaling -> lagged features -> regressor), so
# cross_validate fits every step on the training part of each fold only.
models = {
    "WF": {
        "estimator": Pipeline(
            [
                ("scaler", StandardScaler()),
                ("lagmat", LagMat(BIN_BEFORE, BIN_CURRENT, BIN_AFTER, flat=True)),
                ("linear", LinearRegression()),
            ]
        ),
    },
    "FNN": {
        "estimator": Pipeline(
            [
                ("scaler", StandardScaler()),
                ("lagmat", LagMat(BIN_BEFORE, BIN_CURRENT, BIN_AFTER, flat=True)),
                (
                    "fnn",
                    NeuralNetRegressor(
                        module=FNN,
                        lr=0.001,
                        iterator_train__shuffle=True,
                        optimizer=optim.Adam,
                        batch_size=32,
                        module__n_targets=y.shape[1],  # one output unit per decoded feature
                        module__num_units=400,
                        module__frac_dropout=0.25,
                        module__n_layers=2,
                        max_epochs=10,
                        verbose=0,
                    ),
                ),
            ]
        ),
    },
}

# Per-fold test R2 of every model, keyed by model name, so the scores remain
# available after the loop instead of being overwritten on each iteration.
cv_scores = {}

for name, spec in tqdm(models.items()):
    model = spec["estimator"]

    # Re-seed per model so a model's scores do not depend on which other
    # models were evaluated before it.
    torch.manual_seed(SEED)

    cv_results = cross_validate(
        model,
        X,
        y,
        cv=KFold(n_splits=N_SPLITS),
        scoring="r2",
        # 1 = evaluate the folds sequentially in this process; raise it to run
        # folds in parallel. NOTE(review): the seed set above may not carry
        # over to worker processes when n_jobs != 1 -- confirm before relying
        # on reproducibility in that setting.
        n_jobs=1,
        verbose=1,
    )
    cv_scores[name] = cv_results["test_score"]

    print(f"R2s {name}: {cv_results['test_score']}")

  0%|                                                                                             | 0/2 [00:00<?, ?it/s][Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    1.6s finished
 50%|██████████████████████████████████████████▌                                          | 1/2 [00:01<00:01,  1.57s/it]

R2s WF: [0.73822659 0.7493962  0.71846437]


[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   42.8s finished
100%|█████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:44<00:00, 22.19s/it]

R2s FNN: [0.83662546 0.84977973 0.83463073]



