<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/SachsLab/MonkeyPFCSaccadeStudies/blob/master/StudyLocationRule/Analysis/04_analyze_shallow.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/SachsLab/IMonkeyPFCSaccadeStudies/blob/master/StudyLocationRule/Analysis/04_analyze_shallow.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Decode Intended Saccade Direction from Macaque PFC Microelectrode Recordings with Shallow ML

## Normalize Environments
Run the first two cells to normalize Local / Colab environments.

In [1]:
from pathlib import Path
import os
import sys
import datetime
import numpy as np
import matplotlib.pyplot as plt

try:
    from google.colab import files
    %tensorflow_version 2.x  # Only on colab
    os.chdir('..')
    
    if not (Path.home() / '.kaggle').is_dir():
        # Configure kaggle
        uploaded = files.upload()  # Find the kaggle.json file in your ~/.kaggle directory.
        if 'kaggle.json' in uploaded.keys():
            !mkdir -p ~/.kaggle
            !mv kaggle.json ~/.kaggle/
            !chmod 600 ~/.kaggle/kaggle.json
    
    !pip install git+https://github.com/SachsLab/indl.git
    
    if Path.cwd().stem == 'MonkeyPFCSaccadeStudies':
        os.chdir(Path.cwd().parent)

    if not (Path.cwd() / 'MonkeyPFCSaccadeStudies').is_dir():
        !git clone --single-branch --recursive https://github.com/SachsLab/MonkeyPFCSaccadeStudies.git
        sys.path.append(str(Path.cwd() / 'MonkeyPFCSaccadeStudies'))
    os.chdir('MonkeyPFCSaccadeStudies')
    
    !pip install -q kaggle
    plt.style.use('dark_background')
    IN_COLAB = True
except ModuleNotFoundError:
    import sys
    
    # chdir to MonkeyPFCSaccadeStudies
    if Path.cwd().stem == 'Analysis':
        os.chdir(Path.cwd().parent.parent)
    
    # Add indl repository to path.
    # Eventually this should already be pip installed, but it's still under heavy development so this is easier for now.
    check_dir = Path.cwd()
    while not (check_dir / 'Tools').is_dir():
        check_dir = check_dir / '..'
    indl_path = check_dir / 'Tools' / 'Neurophys' / 'indl'
    sys.path.append(str(indl_path))
    
    # Make sure the kaggle executable is on the PATH
    os.environ['PATH'] = os.environ['PATH'] + ';' + str(Path(sys.executable).parent / 'Scripts')
    
    IN_COLAB = False

# Try to clear any logs from previous runs
if (Path.cwd() / 'logs').is_dir():
    import shutil
    try:
        shutil.rmtree(str(Path.cwd() / 'logs'))
    except PermissionError:
        print("Unable to remove logs directory.")

# Additional imports
import tensorflow as tf
from indl.display import turbo_cmap
plt.rcParams.update({
    'axes.titlesize': 24,
    'axes.labelsize': 20,
    'lines.linewidth': 2,
    'lines.markersize': 5,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 18
})

%load_ext autoreload
%autoreload 2

In [2]:
# Download and unzip data
if IN_COLAB:
    data_path = Path.cwd() / 'data' / 'monkey_pfc' / 'converted'
else:
    data_path = Path.cwd() / 'StudyLocationRule' / 'Data' / 'Preprocessed'

if not (data_path).is_dir():
    !kaggle datasets download --unzip --path {str(data_path)} cboulay/macaque-8a-spikes-rates-and-saccades
    print("Finished downloading and extracting data.")
else:
    print("Data directory found. Skipping download.")

Data directory found. Skipping download.


## Parameterize

### Data Import

We will set the parameters for our `load_macaque_pfc` function.
Specifically, we are getting the spikerates, which have been smoothed and downsampled, only from trials where the outcome was correct and the d' was at least 1.0.

In [3]:
from misc.misc import load_macaque_pfc, sess_infos

load_kwargs = {
    'valid_outcomes': (0,),  # Use (0, 9) to include trials with incorrect behaviour
    'zscore': True,
    'dprime_range': (1.0, np.inf),  # Use (-np.inf, np.inf) to include all trials.
    'verbose': True,
    'y_type': 'sacClass',
    'samples_last': True  # Our EEGNet blocks expect time-samples in the last dimension.
}

### Shallow ML Model

We will use a 1-layer logistic regression to classify saccade target from presaccadic activity.

In [4]:
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression

log_reg_kwargs = {
    'solver': 'lbfgs',
    'C': 10.0,  # inverse regularization strength
    'penalty': 'l2',
    'multi_class': 'ovr',
    'max_iter': 1000
}
N_SPLITS = 10

## ML Analysis

### Calculate classification accuracy for each session

In [5]:
def get_accuracies(sess_infos, time_range=(-np.inf, np.inf),
                  log_reg_kwargs={'solver': 'lbfgs','C': 10.0,'penalty': 'l2',
                                  'multi_class': 'ovr', 'max_iter': 500}):
    
    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True)
    model = LogisticRegression(**log_reg_kwargs)

    accs_out = []
    for sess_ix, sess_info in enumerate(sess_infos):
        sess_id = sess_info['exp_code']
        print(f"\nProcessing session {sess_id}")
        X, Y, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates',
                                         time_range=time_range,
                                        **load_kwargs)
        y_preds = []
        y_true = []
        X_flat = X.reshape(-1, np.prod(X.shape[1:]))

        model = LogisticRegression(**log_reg_kwargs)

        print(f"Performing {N_SPLITS}-fold cross-validated logistic regression...")
        for kfold, (trn, tst) in enumerate(splitter.split(X, Y)):
            print(f"Fold {kfold + 1}")

            model.fit(X_flat[trn], Y[trn].ravel())
            y_preds.append(model.predict(X_flat[tst]))
            y_true.append(Y[tst].ravel())

        y_preds = np.hstack(y_preds)
        y_true = np.hstack(y_true)

        pcnt_corr = 100 * np.sum(y_preds == y_true) / len(y_preds)
        print(f"{sess_id} 8-class accuracy: % {pcnt_corr:.2f}")
        accs_out.append(pcnt_corr)
        
    return accs_out

In [6]:
accs_out_bline = get_accuracies(sess_infos, time_range=(-np.inf, 0))
accs_out_targ = get_accuracies(sess_infos, time_range=(-np.inf, 0.250))
accs_out_prego = get_accuracies(sess_infos, time_range=(-np.inf, 1.45))
accs_out_all = get_accuracies(sess_infos, time_range=(-np.inf, np.inf))


Processing session sra3_2_j_037_00+03
Found 160 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (160, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_2_j_037_00+03 8-class accuracy: % 57.50

Processing session sra3_1_j_050_00+
Found 285 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (285, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_050_00+ 8-class accuracy: % 29.82

Processing session sra3_1_j_051_00+
Found 271 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returnin

Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_051_00+ 8-class accuracy: % 73.43

Processing session sra3_1_j_052_00+
Found 277 trials, 171 timestamps(-0.25 to 1.45 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (277, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_052_00+ 8-class accuracy: % 65.70

Processing session sra3_1_m_077_00+01
Found 128 trials, 171 timestamps(-0.25 to 1.45 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (128, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_m_077_00+01 8-class accuracy: % 53.91

Proc

In [7]:
print(np.vstack((accs_out_bline, accs_out_targ, accs_out_prego, accs_out_all)).T)
print([_['exp_code'] for _ in sess_infos])

[[57.5        50.         54.375      54.375     ]
 [29.8245614  52.28070175 84.9122807  98.94736842]
 [43.17343173 60.88560886 73.43173432 96.3099631 ]
 [38.98916968 50.18050542 65.70397112 83.03249097]
 [39.0625     42.1875     53.90625    60.15625   ]
 [51.04166667 40.625      56.25       68.75      ]
 [39.0625     32.8125     71.875      70.3125    ]
 [38.82352941 56.47058824 74.11764706 76.47058824]]
['sra3_2_j_037_00+03', 'sra3_1_j_050_00+', 'sra3_1_j_051_00+', 'sra3_1_j_052_00+', 'sra3_1_m_077_00+01', 'sra3_1_m_081_00+01', 'sra3_1_m_082_00+01', 'sra3_1_m_083_00+01']


## ML Analysis 2 - SVM

In [8]:
from sklearn import svm


svm_kwargs = {
    'C': 0.1,  # inverse regularization strength
    'kernel': 'linear',  # ‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’
#     'degree': 3,  # Only used if kernel is 'poly'
#     'gamma': 'scale',  # kernel coeff, used by 'rbf', 'poly', and 'sigmoid'
    'class_weight': 'balanced',
}

def get_accuracies_svm(sess_infos, time_range=(-np.inf, np.inf),
                       svm_kwargs={'C': 0.1, 'kernel': 'linear',
                                   'class_weight': 'balanced'}):
    
    splitter = StratifiedKFold(n_splits=N_SPLITS, shuffle=True)

    accs_out = []
    for sess_ix, sess_info in enumerate(sess_infos):
        sess_id = sess_info['exp_code']
        print(f"\nProcessing session {sess_id}")
        X, Y, ax_info = load_macaque_pfc(data_path, sess_id, x_chunk='spikerates',
                                         time_range=time_range,
                                        **load_kwargs)
        y_preds = []
        y_true = []
        X_flat = X.reshape(-1, np.prod(X.shape[1:]))

        model = svm.SVC(**svm_kwargs)

        print(f"Performing {N_SPLITS}-fold cross-validated SVM...")
        for kfold, (trn, tst) in enumerate(splitter.split(X, Y)):
            print(f"Fold {kfold + 1}")

            model.fit(X_flat[trn], Y[trn].ravel())
            y_preds.append(model.predict(X_flat[tst]))
            y_true.append(Y[tst].ravel())

        y_preds = np.hstack(y_preds)
        y_true = np.hstack(y_true)

        pcnt_corr = 100 * np.sum(y_preds == y_true) / len(y_preds)
        print(f"{sess_id} 8-class accuracy: % {pcnt_corr:.2f}")
        accs_out.append(pcnt_corr)
        
    return accs_out

In [9]:
accs_out_bline = get_accuracies(sess_infos, time_range=(-np.inf, 0))
accs_out_targ = get_accuracies(sess_infos, time_range=(-np.inf, 0.250))
accs_out_prego = get_accuracies(sess_infos, time_range=(-np.inf, 1.45))
accs_out_all = get_accuracies(sess_infos, time_range=(-np.inf, np.inf))


Processing session sra3_2_j_037_00+03
Found 160 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (160, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_2_j_037_00+03 8-class accuracy: % 56.88

Processing session sra3_1_j_050_00+
Found 285 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (285, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_050_00+ 8-class accuracy: % 30.53

Processing session sra3_1_j_051_00+
Found 271 trials, 26 timestamps(-0.25 to 0.0 at 100.0 Hz), 32 channels
Returnin

Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_051_00+ 8-class accuracy: % 72.32

Processing session sra3_1_j_052_00+
Found 277 trials, 171 timestamps(-0.25 to 1.45 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (277, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_j_052_00+ 8-class accuracy: % 64.98

Processing session sra3_1_m_077_00+01
Found 128 trials, 171 timestamps(-0.25 to 1.45 at 100.0 Hz), 32 channels
Returning Y as sacClass with shape (128, 1).
Axis info has: dict_keys(['instance_data', 'instance_times', 'fs', 'timestamps', 'channel_names', 'channel_locs'])
Performing 10-fold cross-validated logistic regression...
Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Fold 6
Fold 7
Fold 8
Fold 9
Fold 10
sra3_1_m_077_00+01 8-class accuracy: % 53.91

Proc

In [10]:
print(np.vstack((accs_out_bline, accs_out_targ, accs_out_prego, accs_out_all)).T)
print([_['exp_code'] for _ in sess_infos])

[[56.875      55.         54.375      56.875     ]
 [30.52631579 53.68421053 84.9122807  97.89473684]
 [42.80442804 60.51660517 72.32472325 95.94095941]
 [40.433213   52.70758123 64.98194946 81.94945848]
 [30.46875    44.53125    53.90625    60.15625   ]
 [46.875      41.66666667 59.375      66.66666667]
 [37.5        35.9375     75.         70.3125    ]
 [43.52941176 57.64705882 74.11764706 76.47058824]]
['sra3_2_j_037_00+03', 'sra3_1_j_050_00+', 'sra3_1_j_051_00+', 'sra3_1_j_052_00+', 'sra3_1_m_077_00+01', 'sra3_1_m_081_00+01', 'sra3_1_m_082_00+01', 'sra3_1_m_083_00+01']
