In [None]:
# parameters
import os
# Folder holding the per-session NWB files; "~" is expanded on the next line.
foldername = "~/area2_population_analysis/s1-kinematics/actpas_NWB/"
foldername = os.path.expanduser(foldername)
# Session selector: keep exactly one `monkey = ...` assignment active.
# The commented-out names below are other sessions that were used.
# monkey = "Han_20171207"
# monkey = 'Duncan_20190710'
# monkey = "Chips_20170913"
# monkey = "Lando_20170731"

# monkey = "Han_20171201"
# monkey = "Han_20171204"
# monkey = 'Duncan_20191016'
# monkey = 'Duncan_20191106'

monkey = "Lando_20170803"

# Plain string concatenation: this relies on `foldername` ending with "/".
filename = foldername + monkey + "_COactpas_TD_offset6.nwb"

In [None]:
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib as mpl
from sklearn.linear_model import Ridge, LinearRegression, Lasso
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedShuffleSplit
import math
import multiprocessing as mp
from scipy.linalg import orth

from Neural_Decoding.preprocessing_funcs import get_spikes_with_history
from Area2_analysis.lr_funcs import process_train_test, gaussian_filter1d_oneside, gaussian_filter1d_twoside, comp_cc, xcorr, r2_score
from Area2_analysis.lr_funcs import get_sses_pred, get_sses_mean, nans
# from Area2_analysis.lr_funcs import fit_and_predict, sub_and_predict, pred_with_new_weights
# from Area2_analysis.lr_funcs import fit_and_predict_lasso, sub_and_predict_lasso, 
from Area2_analysis.lr_funcs import fit_and_predict_MC, calc_proj, principal_angles, angle_between
# Directory where figures are saved (used by the plt.savefig calls below).
# NOTE(review): absolute, user-specific path -- adjust before running on
# another machine.
figDir = '/Users/sherryan/Desktop/paper/'
# Default font size for every matplotlib figure in this notebook.
matplotlib.rc('font', size=18)


In [None]:

# Load the session, resample to 10-ms bins, smooth the spikes and add a
# 20-dimensional PCA of the z-scored smoothed spikes as field 'PCA_40'.
dataset_10ms = NWBDataset(filename, split_heldout=False)

dataset_10ms.resample(10) #in 10-ms bin, has to resample first for Duncan
bin_width = dataset_10ms.bin_width
print(bin_width)

dataset_10ms.smooth_spk(40, name='smth_40')
n_dims = 20
all_data = np.array(dataset_10ms.data.spikes_smth_40)
print(all_data.shape)

# PCA cannot be fit on NaNs. Previously this case skipped the fit and then
# failed a few lines later with a NameError on `PCA_data`; fail here with an
# explicit message instead.
if np.isnan(all_data).any():
    raise ValueError(
        "spikes_smth_40 contains NaN values; cannot compute the PCA_40 field "
        f"for {filename}"
    )

scaler = StandardScaler()
X = scaler.fit_transform(all_data)
pca = PCA(n_components=n_dims,random_state = 42)
PCA_data = pca.fit_transform(X)
print(PCA_data.shape)
dataset_10ms.add_continuous_data(PCA_data,'PCA_40')
print('PCA total var explained:',sum(pca.explained_variance_ratio_))

In [None]:
dataset = dataset_10ms
trial_info = dataset.trial_info

# Session-level counts.
n_trials = trial_info.shape[0]
n_neurons = dataset.data.spikes.shape[1]
print(n_trials, 'total trials')
print(n_neurons, 'neurons')

# Boolean trial masks; every condition mask is restricted to valid trials
# (split != 'none').
valid_mask   = (trial_info['split'] != 'none')
active_mask  = (trial_info.ctr_hold_bump == 0) & valid_mask
passive_mask = (trial_info.ctr_hold_bump == 1) & valid_mask
nan_mask     = trial_info.ctr_hold_bump.isna() & valid_mask

# Report how many trials fall under each mask.
for _mask, _label in (
    (valid_mask,   'valid trials'),
    (active_mask,  'active trials'),
    (passive_mask, 'passive trials'),
    (nan_mask,     'reach bump trials'),
):
    print(trial_info.loc[_mask].shape[0], _label)

In [None]:
def build_mask_and_cond_dict(
    dataset,
    base_mask,                 # e.g. active_mask / passive_mask / nan_mask
    align_field,
    align_range,
    dir_col='cond_dir',
    drop_to_four_if_not_full8=True,
):
    """
    Restrict a trial mask to trials that survive an alignment and have a
    canonical direction, and build integer condition labels for them.

    Args:
      dataset: object exposing `trial_info` (DataFrame with a 'trial_id'
        column) and `make_trial_data(align_field, align_range, ignored_trials)`.
      base_mask: boolean Series over trials; starting selection.
      align_field, align_range: forwarded to `dataset.make_trial_data`.
      dir_col: trial_info column holding the direction in degrees.
      drop_to_four_if_not_full8: when True and the trials do not cover all
        eight 45-degree directions, only the cardinal directions that are
        present (0/90/180/270) are kept.

    Returns:
      new_mask: base_mask limited to the kept trials.
      cond_dict: int array; index of each kept trial's direction within the
        direction set in use, in the same order as trial_ids_kept.
      trial_ids_kept: trial IDs in order of first appearance in the aligned
        data.
    """
    info = dataset.trial_info

    # Trials that actually come back from the alignment.
    aligned = dataset.make_trial_data(
        align_field=align_field,
        align_range=align_range,
        ignored_trials=~base_mask
    )
    aligned_ids = aligned['trial_id'].drop_duplicates().to_numpy()

    # Direction (mod 360) of every aligned trial, in aligned order.
    by_id = info if info.index.name == 'trial_id' else info.set_index('trial_id')
    angles = (by_id.loc[aligned_ids, dir_col] % 360).to_numpy().astype(float)

    eight_dirs = np.array([0, 45, 90, 135, 180, 225, 270, 315], dtype=float)
    cardinal_dirs = np.array([0, 90, 180, 270], dtype=float)

    # Discard anything that is not a multiple of 45 degrees (NaN included).
    canonical = np.isin(angles, eight_dirs)
    aligned_ids, angles = aligned_ids[canonical], angles[canonical]

    # Choose the direction set used for labelling.
    present = np.unique(angles)
    if (not drop_to_four_if_not_full8) or set(present) == set(eight_dirs):
        cond_dirs = eight_dirs
    else:
        cond_dirs = cardinal_dirs[np.isin(cardinal_dirs, present)]

    # Keep only trials whose direction belongs to that set.
    in_set = np.isin(angles, cond_dirs)
    trial_ids_kept = aligned_ids[in_set]
    kept_angles = angles[in_set]

    new_mask = base_mask & info['trial_id'].isin(trial_ids_kept)

    # cond_dirs is sorted and every kept angle is an exact member of it, so
    # searchsorted returns each angle's index within cond_dirs.
    cond_dict = np.searchsorted(cond_dirs, kept_angles).astype(int)

    return new_mask, cond_dict, trial_ids_kept

In [None]:
# Rebuild the base masks from trial_info so this cell can be re-run on its own
# (the mask variables are overwritten further down).
dataset = dataset_10ms
trial_info = dataset.trial_info

valid_mask   = (trial_info['split'] != 'none')
active_mask  = (trial_info.ctr_hold_bump == 0) & valid_mask
passive_mask = (trial_info.ctr_hold_bump == 1) & valid_mask
nan_mask     = trial_info.ctr_hold_bump.isna() & valid_mask

# Build condition labels that match the alignment used later.
# Active trials aligned to movement onset; `active_mask` is replaced by the
# restricted mask returned here.
active_mask,  active_cond_dict,  active_trial_ids  = build_mask_and_cond_dict(
    dataset, active_mask,  align_field='move_onset_time',  align_range=(-100, 0), dir_col='cond_dir'
)
print('active (onset):', len(active_trial_ids), 'trials')
print('active_cond_dict_onset length:', len(active_cond_dict))

# Active trials aligned to movement offset. This starts from the
# already-restricted onset mask, so the offset set is a subset of the onset set.
active_mask_offset, active_cond_dict_offset, active_trial_ids_offset = build_mask_and_cond_dict(
    dataset, active_mask,  align_field='move_offset_time', align_range=(-100, 0), dir_col='cond_dir'
)
print('active (offset):', len(active_trial_ids_offset), 'trials')
print('active_cond_dict_offset length:', len(active_cond_dict_offset))

# Passive trials, aligned with the same 'move_onset_time' field.
passive_mask, passive_cond_dict, passive_trial_ids = build_mask_and_cond_dict(
    dataset, passive_mask, align_field='move_onset_time',  align_range=(-100, 0), dir_col='cond_dir'
)
print('passive:', len(passive_trial_ids), 'trials')
print('passive_cond_dict length:', len(passive_cond_dict))

# Trial counts after restriction.
active_n_trials  = len(active_trial_ids)
passive_n_trials = len(passive_trial_ids)

In [None]:
import numpy as np
from scipy.stats import ttest_rel, wilcoxon
## TODO: corrections for multiple comparisons

def df_to_trial_matrix(df, field='spikes', agg='mean'):
    """
    Collapse a trial-aligned DataFrame to one time series per trial.

    Args:
      df: DataFrame with a 'trial_id' column and a `field` column group
        holding one column per neuron/feature.
      field: name of the column (group) to aggregate.
      agg: 'mean' -> mean over neurons, 'sum' -> sum over neurons.

    Returns:
      A: [n_trials, n_time], rows in `groupby('trial_id')` order
        (sorted by trial id).

    Raises:
      ValueError: if `agg` is not 'mean' or 'sum', or if `df` has no trials.
    """
    # Reject unknown modes instead of silently falling back to a sum.
    if agg not in ('mean', 'sum'):
        raise ValueError(f"agg must be 'mean' or 'sum', got {agg!r}")

    per_trial = []
    for _, trial in df.groupby('trial_id'):
        # Item access (not getattr) so field names that collide with
        # DataFrame attributes still resolve to the column.
        X = trial[field].to_numpy()  # [n_time, n_neurons]
        if X.ndim == 1:
            # A single-column field comes back 1-D; treat it as one neuron.
            X = X[:, None]
        per_trial.append(X.mean(axis=1) if agg == 'mean' else X.sum(axis=1))

    if not per_trial:
        raise ValueError("df contains no trials to stack")
    return np.vstack(per_trial)

def window_mean_matrix(A, x_axis, window):
    """
    Average each row of A over the time bins inside `window`.

    A is [n_trials, n_time]; x_axis gives the time of each column. Both
    window edges are inclusive. Returns an array of length n_trials.
    """
    in_window = np.logical_and(x_axis >= window[0], x_axis <= window[1])
    selected = A[:, in_window]
    return selected.mean(axis=1)

# ---- config ----
# Time axis (ms, relative to movement onset) matching the aligned data below.
plot_range = (-1000, 500)
x_axis = np.arange(plot_range[0], plot_range[1], dataset_10ms.bin_width)

# Active trials aligned to movement onset.
active_df = dataset_10ms.make_trial_data(
    align_field='move_onset_time', align_range=plot_range,
    ignored_trials=~active_mask, allow_overlap=True
)

# Population-average activity per trial.
A = df_to_trial_matrix(active_df, field='spikes', agg='mean')  # [n_trials, n_time]

# Per-trial baseline: mean over the window -250..-150 ms (edges inclusive).
baseline_window = (-250, -150)
baseline_vals = window_mean_matrix(A, x_axis, baseline_window)

# Test windows of width `window_size` (ms) centred every `step_size` ms
# from -150 ms up to, but not including, 0 ms.
window_size = 20
step_size = 10
test_centers = np.arange(-150, 0, step_size)

# One p-value per window centre for each test.
# No multiple-comparison correction is applied (see the TODO at the top of this cell).
p_ttest, p_wilcoxon = [], []
for c in test_centers:
    w = (c - window_size/2, c + window_size/2)
    test_vals = window_mean_matrix(A, x_axis, w)

    # Paired t-test across trials: window mean vs. baseline mean.
    p_ttest.append(ttest_rel(test_vals, baseline_vals, nan_policy='omit').pvalue)

    # Wilcoxon signed-rank test; skipped (p recorded as 1.0) when all paired
    # differences are numerically zero.
    diff = test_vals - baseline_vals
    if np.allclose(diff, 0):
        p_wilcoxon.append(1.0)
    else:
        p_wilcoxon.append(wilcoxon(test_vals, baseline_vals).pvalue)

# Plot p-values against window centre; the dashed red line marks p = 0.05.
plt.plot(test_centers, p_wilcoxon, marker='o', label='wilcoxon')
plt.plot(test_centers, p_ttest, marker='o', label='ttest')
plt.axhline(0.05, color='red', linestyle='--')
plt.xlabel('Time relative to movement onset (ms)')
plt.ylabel('p-value')
plt.legend()
plt.show()

In [None]:
def psth_from_df(df, x_axis, bin_width_ms, n_neurons, field='spikes'):
    """
    Population PSTH from trial-aligned data.

    Sums `field` over neurons per bin, converts to spikes/s/neuron using
    `bin_width_ms` and `n_neurons`, and returns
    (mean over trials, SEM over trials with ddof=1, per-trial rates).
    `x_axis` is accepted for call-site symmetry but is not used here.
    """
    # [n_trials, n_time] counts per bin, summed across neurons
    counts = df_to_trial_matrix(df, field=field, agg='sum')
    rates = counts / bin_width_ms * 1000.0 / n_neurons  # spikes/s/neuron
    n_rows = rates.shape[0]
    trial_mean = rates.mean(axis=0)
    trial_sem = rates.std(axis=0, ddof=1) / np.sqrt(n_rows)
    return trial_mean, trial_sem, rates

# NOTE: this cell overwrites plot_range, x_axis and active_df from the
# significance-test cell above.
plot_range = (-300, 600)
x_axis = np.arange(plot_range[0], plot_range[1], dataset_10ms.bin_width)

# Active and passive trials, both aligned on 'move_onset_time'.
active_df = dataset_10ms.make_trial_data(
    align_field='move_onset_time', align_range=plot_range,
    ignored_trials=~active_mask, allow_overlap=True
)
passive_df = dataset_10ms.make_trial_data(
    align_field='move_onset_time', align_range=plot_range,
    ignored_trials=~passive_mask, allow_overlap=True
)

# Mean and SEM firing rate (sp/s/neuron) per time bin for each condition.
active_mean, active_sem, _ = psth_from_df(active_df, x_axis, dataset_10ms.bin_width, n_neurons)
passive_mean, passive_sem, _ = psth_from_df(passive_df, x_axis, dataset_10ms.bin_width, n_neurons)

fig, ax = plt.subplots(figsize=(8,6))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Active: black line with a +/- SEM band.
ax.plot(x_axis, active_mean, linewidth=2, color='k', label='Active')
ax.fill_between(x_axis, active_mean-active_sem, active_mean+active_sem, alpha=0.2, color='k')

# Passive: red line with a +/- SEM band.
ax.plot(x_axis, passive_mean, linewidth=2, color='red', label='Passive')
ax.fill_between(x_axis, passive_mean-passive_sem, passive_mean+passive_sem, alpha=0.2, color='red')

# Dashed vertical line at the alignment time (0 ms).
ax.axvline(0, color='k', linestyle='--')
ax.set_xlim([-300, 600])
ax.set_ylabel('Average firing rate (sp/s/neuron)')
ax.set_xlabel('Time after movement onset (ms)')
ax.legend(frameon=False)
plt.tight_layout()
# Saved into figDir as "<monkey>_psth.pdf".
plt.savefig(figDir + monkey + '_psth.pdf', dpi='figure')
plt.show()

In [None]:
def pick_ref(align_field, onset_ref, offset_ref):
    """Return `offset_ref` when `align_field` contains "offset", else `onset_ref`."""
    if "offset" in align_field:
        return offset_ref
    return onset_ref

def align_cond_dict_to_df(df, ref_trial_ids, ref_cond_dict):
    """
    Look up condition labels for the trials present in `df`.

    Args:
      df: DataFrame with a 'trial_id' column (one row per time bin).
      ref_trial_ids: trial IDs for which a label is known.
      ref_cond_dict: labels, index-aligned with `ref_trial_ids`.

    Returns:
      kept_ids: trial IDs of `df` (order of first appearance) that have a label.
      cond_df: int labels aligned with `kept_ids`.
      dropped_ids: trial IDs of `df` that have no label in the reference.
    """
    trial_ids_df = df['trial_id'].drop_duplicates().to_numpy()
    ref_map = {tid: int(ref_cond_dict[i]) for i, tid in enumerate(ref_trial_ids)}

    # Explicit bool dtype: for an empty df a bare np.array([]) would be a
    # float array, which cannot be used as a mask or inverted with `~`.
    keep = np.array([tid in ref_map for tid in trial_ids_df], dtype=bool)
    kept_ids = trial_ids_df[keep]
    dropped_ids = trial_ids_df[~keep]

    cond_df = np.array([ref_map[tid] for tid in kept_ids], dtype=int)
    return kept_ids, cond_df, dropped_ids

In [None]:
def compute_global_zscore(dataset, x_field, z_cfg):
    """
    Per-feature mean and std of `x_field` over the trials/window in `z_cfg`.

    Args:
      dataset: object exposing `data` and `make_trial_data`.
      x_field: name of the feature group to normalise.
      z_cfg: dict with keys "align_field", "align_range" and "trial_mask"
        (boolean mask of trials to include).

    Returns:
      mean, std: arrays with one entry per feature, NaNs ignored. Features
      with zero std get std = 1 so that the `(X - mean) / std` applied by the
      callers yields 0 for them instead of NaN/inf.
    """
    df = dataset.make_trial_data(
        align_field=z_cfg["align_field"],
        align_range=z_cfg["align_range"],
        ignored_trials=~z_cfg["trial_mask"],
    )
    dim = dataset.data[x_field].shape[1]
    X = df[x_field].to_numpy().reshape((-1, dim))
    mean = np.nanmean(X, axis=0)
    std  = np.nanstd(X, axis=0)
    # A constant feature (e.g. a unit silent in this window) has std == 0;
    # dividing by it would inject NaN/inf into every downstream fit.
    std = np.where(std == 0, 1.0, std)
    return mean, std

In [None]:
def get_epoch_X_mean_and_labels(dataset, x_field, epoch_cfg, mean, std, onset_ref, offset_ref):
    """
    Build per-trial, time-averaged, z-scored features for one epoch, together
    with the direction and condition label of every trial.

    `epoch_cfg` supplies "name", "align_field", "align_range" and
    "trial_mask". `onset_ref` / `offset_ref` are (trial_ids, cond_dict)
    pairs; which one is used depends on whether the align field contains
    "offset". Trials without a label in the chosen reference are dropped.

    Returns a dict with keys name, X_mean ([n_trials, dim]), dirs (degrees
    mod 360), cond_dict, trial_ids and dropped_ids.
    """
    align_field = epoch_cfg["align_field"]
    epoch_df = dataset.make_trial_data(
        align_field=align_field,
        align_range=epoch_cfg["align_range"],
        ignored_trials=~epoch_cfg["trial_mask"],
    )

    # Condition labels for the trials actually present in this epoch.
    ref_ids, ref_cond = pick_ref(align_field, onset_ref, offset_ref)
    trial_ids, cond_dict, dropped = align_cond_dict_to_df(epoch_df, ref_ids, ref_cond)

    # Drop rows of unlabelled trials so the reshape below lines up.
    labelled = epoch_df[epoch_df["trial_id"].isin(trial_ids)]

    n_features = dataset.data[x_field].shape[1]
    flat = labelled[x_field].to_numpy().reshape((-1, n_features))
    flat = (flat - mean) / std

    # [n_trials, n_time, n_features] -> mean over time, ignoring NaNs.
    n_kept = len(trial_ids)
    bins_per_trial = flat.shape[0] // n_kept
    per_trial = flat.reshape((n_kept, bins_per_trial, n_features))
    X_mean = np.nanmean(per_trial, axis=1)

    # Direction of each kept trial (used as the cos/sin regression target).
    info = dataset.trial_info
    by_id = info if info.index.name == "trial_id" else info.set_index("trial_id")
    dirs = (by_id.loc[trial_ids, "cond_dir"] % 360).to_numpy().astype(float)

    return {
        "name": epoch_cfg["name"],
        "X_mean": X_mean,
        "dirs": dirs,
        "cond_dict": cond_dict,
        "trial_ids": trial_ids,
        "dropped_ids": dropped,
    }

In [None]:
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.linear_model import Ridge

def find_axes_iterative(X, y, cond_dict, N=10, r2_thresh=0.0,
                        alpha_grid=None, n_folds=10, random_state=42):
    """
    Iteratively fit ridge-regression axes predicting `y` from `X`, removing
    each fitted axis from `X` before fitting the next.

    Args:
      X: [n_trials, dim] feature matrix.
      y: regression target, one row per trial.
      cond_dict: per-trial condition labels used to stratify the CV folds.
      N: number of iterations (axes) to fit.
      r2_thresh: an axis is kept only if its cross-validated R^2 exceeds this.
      alpha_grid: ridge alphas to search; defaults to np.logspace(-3, 3, 7).
      n_folds: number of stratified CV folds.
      random_state: seed for the shuffled StratifiedKFold.

    Returns:
      axes_kept: rows of the [N, dim] weight matrix whose CV R^2 > r2_thresh.
      r2_cv: [N] cross-validated R^2 for every iteration (kept or not).
    """
    if alpha_grid is None:
        alpha_grid = np.logspace(-3, 3, 7)

    n_trials, dim = X.shape
    X_proc = X.copy()  # working copy; deflated after every iteration
    axes_all = np.full((N, dim), np.nan)
    r2_cv = np.full((N,), np.nan)

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    for i in range(N):
        # Axis i: ridge weights fit on all trials (alpha chosen by grid search).
        reg = GridSearchCV(Ridge(), {'alpha': alpha_grid}).fit(X_proc, y)
        w = reg.best_estimator_.coef_.reshape(-1)
        axes_all[i, :] = w

        # Held-out predictions, concatenated in fold order.
        true_concat = np.full((n_trials, 1), np.nan)
        pred_concat = np.full((n_trials, 1), np.nan)
        save = 0  # write offset into the concatenated arrays
        for tr, te in skf.split(np.arange(n_trials), cond_dict):
            lr = GridSearchCV(Ridge(), {'alpha': alpha_grid})
            lr.fit(X_proc[tr], y[tr])
            y_pred = lr.predict(X_proc[te]).reshape(-1, 1)

            n = len(te)
            true_concat[save:save+n] = y[te].reshape(-1, 1)
            pred_concat[save:save+n] = y_pred
            save += n

        # R^2 = 1 - SSE(prediction) / SSE(mean), pooled over all folds.
        sses = get_sses_pred(true_concat, pred_concat)
        sses_mean = get_sses_mean(true_concat)
        r2_cv[i] = 1 - np.sum(sses) / np.sum(sses_mean)

        # Deflate: subtract calc_proj's output for axis w from the features.
        # NOTE(review): presumably calc_proj returns the projection of X_proc
        # onto w -- confirm in Area2_analysis.lr_funcs.
        X_proc = X_proc - calc_proj(X_proc, w.reshape(-1, 1)).T

    axes_kept = axes_all[r2_cv > r2_thresh, :]
    return axes_kept, r2_cv

def compute_axes_cos_sin(epoch_data, N=10, r2_thresh=0.0, proj_out_axes=None):
    """
    Find axes that predict cos(direction) and sin(direction) from the
    epoch's trial-averaged features.

    If `proj_out_axes` is given and non-empty, the result of
    `calc_proj(X, proj_out_axes.T).T` is subtracted from the features first.

    Returns (axes, r2x, r2y): the stacked cos axes followed by the sin axes
    (shape (0, dim) when neither search kept anything), and the
    per-iteration CV R^2 arrays of the cos and sin searches.
    """
    features = epoch_data["X_mean"].copy()
    if proj_out_axes is not None and proj_out_axes.size:
        features = features - calc_proj(features, proj_out_axes.T).T

    angles_rad = np.deg2rad(epoch_data["dirs"])
    labels = epoch_data["cond_dict"]

    cos_target = np.cos(angles_rad).reshape(-1, 1)
    sin_target = np.sin(angles_rad).reshape(-1, 1)

    axes_x, r2x = find_axes_iterative(features, cos_target, labels, N=N, r2_thresh=r2_thresh)
    axes_y, r2y = find_axes_iterative(features, sin_target, labels, N=N, r2_thresh=r2_thresh)

    if axes_x.size or axes_y.size:
        axes = np.vstack([axes_x, axes_y])
    else:
        axes = np.zeros((0, features.shape[1]))
    return axes, r2x, r2y

In [None]:
def project_continuous(dataset, x_field, mean, std, axes):
    """
    Z-score the continuous `x_field` data with `mean`/`std` and project it
    onto the rows of `axes`; returns [n_samples, n_axes] (zero columns when
    `axes` is empty).

    NOTE: this function is defined again in a later cell of this notebook;
    after Run All the later definition is the one in effect.
    """
    X = dataset.data[x_field].to_numpy()
    Xz = (X - mean) / std
    return Xz @ axes.T if axes.size else np.zeros((Xz.shape[0], 0))

In [None]:
import os

def run_cdfb_pipeline(
    monkey,
    dataset,
    x_field,
    onset_ref,          # (active_trial_ids, active_cond_dict)
    offset_ref,         # (active_trial_ids_offset, active_cond_dict_offset)
    z_cfg,
    epoch_cfgs,         # list of configs
    N=10,
    r2_thresh=0.0,
    save_dir=".",
    save_tag="v6_zscore",
):
    """
    Compute per-epoch axes, project the continuous data onto them and save
    both to .npz files.

    Steps: (1) global z-score statistics from `z_cfg`; (2) for each epoch
    config, in list order, fit cos/sin axes, optionally after projecting out
    the axes of the epoch named in its "proj_out" key; (3) project the whole
    recording onto every axis set; (4) save weights and projections.

    Returns (weights_path, data_path).

    NOTE: this function is defined again in a later cell of this notebook;
    after Run All the later definition is the one in effect.
    """
    # 1) global zscore
    mean, std = compute_global_zscore(dataset, x_field, z_cfg)

    # 2) compute axes per epoch (with optional proj_out)
    axes_dict = {}
    meta = {"epochs": [], "x_field": x_field, "z_cfg": z_cfg, "N": N, "r2_thresh": r2_thresh}

    for ep in epoch_cfgs:
        ep_data = get_epoch_X_mean_and_labels(dataset, x_field, ep, mean, std, onset_ref, offset_ref)

        # NOTE(review): a proj_out name with no axes computed yet resolves to
        # None here, so nothing is projected out and no error is raised.
        proj_out_name = ep.get("proj_out", None)
        proj_out_axes = axes_dict.get(proj_out_name, None) if proj_out_name else None

        axes, r2x, r2y = compute_axes_cos_sin(ep_data, N=N, r2_thresh=r2_thresh, proj_out_axes=proj_out_axes)
        axes_dict[ep["name"]] = axes

        meta["epochs"].append({
            "name": ep["name"],
            "align_field": ep["align_field"],
            "align_range": ep["align_range"],
            "trial_mask_name": ep.get("mask_name", ""),
            "n_trials": int(len(ep_data["trial_ids"])),
            "n_dropped_not_in_ref": int(len(ep_data["dropped_ids"])),
            "proj_out": proj_out_name,
            "axes_shape": tuple(axes.shape),
            "r2x": r2x.tolist(),
            "r2y": r2y.tolist(),
        })

        print(f"[{ep['name']}] axes {axes.shape}, trials {len(ep_data['trial_ids'])}, dropped(not-in-ref) {len(ep_data['dropped_ids'])}")

    # 3) project continuous
    proj_dict = {f"{name}_proj": project_continuous(dataset, x_field, mean, std, axes)
                 for name, axes in axes_dict.items()}

    # common combos if present: CD columns first, then the FB columns
    if "CD" in axes_dict and "FB_onset" in axes_dict:
        proj_dict["CD_FB_onset_proj"] = np.hstack([proj_dict["CD_proj"], proj_dict["FB_onset_proj"]])
    if "CD" in axes_dict and "FB_offset" in axes_dict:
        proj_dict["CD_FB_offset_proj"] = np.hstack([proj_dict["CD_proj"], proj_dict["FB_offset_proj"]])

    # 4) save 2 npz
    weights_path = os.path.join(save_dir, f"{monkey}_{save_tag}_cdfb_weights_{x_field}.npz")
    data_path    = os.path.join(save_dir, f"{monkey}_{save_tag}_cdfb_data_{x_field}.npz")

    # `meta` is stored as a 1-element object array, so reading it back needs
    # np.load(..., allow_pickle=True).
    np.savez(weights_path, mean=mean, std=std,
             **{f"{k}_axes": v for k, v in axes_dict.items()},
             meta=np.array([meta], dtype=object))
    np.savez(data_path, **proj_dict)

    print("Saved weights:", weights_path)
    print("Saved data   :", data_path)
    return weights_path, data_path

In [None]:
def project_continuous(dataset, x_field, mean, std, axes):
    """
    Project the full continuous `x_field` signal onto `axes`.

    The data are z-scored with the supplied `mean` and `std`, then multiplied
    by `axes.T`. Returns [n_samples, n_axes]; when `axes` is empty the result
    has zero columns.
    """
    z_scored = (dataset.data[x_field].to_numpy() - mean) / std
    if not axes.size:
        return np.zeros((z_scored.shape[0], 0))
    return z_scored @ axes.T

In [None]:
import os

def run_cdfb_pipeline(
    monkey,
    dataset,
    x_field,
    onset_ref,          # (active_trial_ids, active_cond_dict)
    offset_ref,         # (active_trial_ids_offset, active_cond_dict_offset)
    z_cfg,
    epoch_cfgs,         # list of configs
    N=10,
    r2_thresh=0.0,
    save_dir=".",
    save_tag="v6_zscore",
):
    """
    Compute per-epoch axes, project the continuous data onto them and save
    both to .npz files.

    Args:
      monkey: session name, used as the file-name prefix.
      dataset: dataset object providing `data`, `trial_info`, `make_trial_data`.
      x_field: feature group to analyse (e.g. 'spikes').
      onset_ref, offset_ref: (trial_ids, cond_dict) pairs used to label
        onset- and offset-aligned epochs respectively.
      z_cfg: dict with "align_field", "align_range", "trial_mask" defining
        the data used for the global mean/std.
      epoch_cfgs: list of dicts with "name", "align_field", "align_range",
        "trial_mask" and optionally "mask_name" and "proj_out" (name of an
        EARLIER epoch whose axes are projected out before fitting).
      N, r2_thresh: forwarded to the axis search.
      save_dir: output directory (created if missing).
      save_tag: tag inserted into the output file names.

    Returns:
      (weights_path, data_path)

    Raises:
      ValueError: if an epoch's "proj_out" names an epoch whose axes have not
        been computed yet (misspelled, or listed later in `epoch_cfgs`).
    """
    # 1) global zscore
    mean, std = compute_global_zscore(dataset, x_field, z_cfg)

    # 2) compute axes per epoch (with optional proj_out)
    axes_dict = {}
    meta = {"epochs": [], "x_field": x_field, "z_cfg": z_cfg, "N": N, "r2_thresh": r2_thresh}

    for ep in epoch_cfgs:
        # Resolve proj_out first. A missing name used to resolve to None
        # silently, producing axes that were NOT orthogonalised as requested.
        proj_out_name = ep.get("proj_out", None)
        if proj_out_name:
            if proj_out_name not in axes_dict:
                raise ValueError(
                    f"Epoch {ep['name']!r} requests proj_out={proj_out_name!r}, "
                    f"but no earlier epoch with that name has been computed "
                    f"(available: {sorted(axes_dict)})."
                )
            proj_out_axes = axes_dict[proj_out_name]
        else:
            proj_out_axes = None

        ep_data = get_epoch_X_mean_and_labels(dataset, x_field, ep, mean, std, onset_ref, offset_ref)

        axes, r2x, r2y = compute_axes_cos_sin(ep_data, N=N, r2_thresh=r2_thresh, proj_out_axes=proj_out_axes)
        axes_dict[ep["name"]] = axes

        meta["epochs"].append({
            "name": ep["name"],
            "align_field": ep["align_field"],
            "align_range": ep["align_range"],
            "trial_mask_name": ep.get("mask_name", ""),
            "n_trials": int(len(ep_data["trial_ids"])),
            "n_dropped_not_in_ref": int(len(ep_data["dropped_ids"])),
            "proj_out": proj_out_name,
            "axes_shape": tuple(axes.shape),
            "r2x": r2x.tolist(),
            "r2y": r2y.tolist(),
        })

        print(f"[{ep['name']}] axes {axes.shape}, trials {len(ep_data['trial_ids'])}, dropped(not-in-ref) {len(ep_data['dropped_ids'])}")

    # 3) project continuous
    proj_dict = {f"{name}_proj": project_continuous(dataset, x_field, mean, std, axes)
                 for name, axes in axes_dict.items()}

    # common combos if present: CD columns first, then the FB columns
    if "CD" in axes_dict and "FB_onset" in axes_dict:
        proj_dict["CD_FB_onset_proj"] = np.hstack([proj_dict["CD_proj"], proj_dict["FB_onset_proj"]])
    if "CD" in axes_dict and "FB_offset" in axes_dict:
        proj_dict["CD_FB_offset_proj"] = np.hstack([proj_dict["CD_proj"], proj_dict["FB_offset_proj"]])

    # 4) save 2 npz
    # Make sure the target directory exists so np.savez cannot fail on it.
    os.makedirs(save_dir, exist_ok=True)
    weights_path = os.path.join(save_dir, f"{monkey}_{save_tag}_cdfb_weights_{x_field}.npz")
    data_path    = os.path.join(save_dir, f"{monkey}_{save_tag}_cdfb_data_{x_field}.npz")

    # `meta` is stored as a 1-element object array, so reading it back needs
    # np.load(..., allow_pickle=True).
    np.savez(weights_path, mean=mean, std=std,
             **{f"{k}_axes": v for k, v in axes_dict.items()},
             meta=np.array([meta], dtype=object))
    np.savez(data_path, **proj_dict)

    print("Saved weights:", weights_path)
    print("Saved data   :", data_path)
    return weights_path, data_path

In [None]:
# (trial_ids, cond_dict) references used to label onset- / offset-aligned epochs.
onset_ref  = (active_trial_ids, active_cond_dict)
offset_ref = (active_trial_ids_offset, active_cond_dict_offset)

# Data used to compute the global per-feature mean/std.
z_cfg = dict(
    align_field="move_onset_time",
    align_range=(-100, 1500),
    trial_mask=active_mask,          # use whichever mask you want the mean/std computed over
)

# Epochs are processed in list order; "proj_out" names an epoch listed
# earlier whose axes are projected out before fitting.
epoch_cfgs = [
    dict(name="CD", align_field="move_onset_time", align_range=(-100, 0),  trial_mask=active_mask, mask_name="active"),
    dict(name="FB_onset", align_field="move_onset_time", align_range=(200, 400), trial_mask=active_mask, mask_name="active",
         proj_out="CD"),
    # optional
    dict(name="FB_offset", align_field="move_offset_time", align_range=(-100, 0), trial_mask=active_mask_offset, mask_name="active_offset",
         proj_out="CD"),
]

# Runs on the raw (unsmoothed) 'spikes' field; outputs are written to the
# current directory with the "v6_zscore_unsmoothed" tag.
weights_path, data_path = run_cdfb_pipeline(
    monkey=monkey,
    dataset=dataset_10ms,
    x_field="spikes",
    onset_ref=onset_ref,
    offset_ref=offset_ref,
    z_cfg=z_cfg,
    epoch_cfgs=epoch_cfgs,
    N=10,
    r2_thresh=0.0,
    save_dir=".",
    save_tag="v6_zscore_unsmoothed",
)

In [None]:
dataset = dataset_10ms
x_field = 'spikes'
# Re-load the unsmoothed projections written by run_cdfb_pipeline. The file
# name is rebuilt by hand, so it must match the save_dir="." and
# save_tag="v6_zscore_unsmoothed" used in the previous cell.
unsmoothed = np.load(monkey + f'_v6_zscore_unsmoothed_cdfb_data_{x_field}.npz')
gaussian_kernel_width = 40  # ms
# Kernel width expressed in bins (truncated to int).
# NOTE(review): assumes gaussian_filter1d_twoside takes sigma in samples -- confirm.
sigma = int(gaussian_kernel_width / bin_width)

# Keys of the npz file to smooth. The mapped values are only bound to
# `dataset_name` in the loop below and are currently unused.
proj_keys = {
    'CD_proj':     'unsmoothed_CD_proj',
    'FB_onset_proj':     'unsmoothed_FB_proj',
    'CD_FB_onset_proj':  'unsmoothed_CD_FB_proj',
}

smoothed_data = {}

for key, dataset_name in proj_keys.items():
    X = unsmoothed[key].astype(np.float64)
    # Smooth along time (axis 0).
    X_sm = gaussian_filter1d_twoside(X, sigma, axis=0)

    smoothed_data[key] = X_sm
    # dataset.add_continuous_data(X_sm, f'smoothed_{key}_{x_field}')

# NOTE(review): nothing is added to the dataset in this cell (the call above
# is commented out), so the message below is misleading; the smoothed arrays
# are only written to the npz file.
print("Added smoothed data to dataset:")
print(dataset.data.keys().unique(0))
np.savez(
    monkey + f'_v6_zscore_smoothed_cdfb_data_{x_field}.npz',
    **smoothed_data
)


In [None]:
# npz key -> name of the continuous field added to the dataset.
rename_map = {
    'CD_proj':            'CD_proj',
    'FB_onset_proj':      'FB_proj',
    'CD_FB_onset_proj':   'CD_FB_proj',
}
dataset = dataset_10ms
x_field = 'spikes'

# Smoothed projections written by the previous cell.
# NOTE: the name `data` is reused for a trial DataFrame in the plotting cells below.
data = np.load(monkey + f'_v6_zscore_smoothed_cdfb_data_{x_field}.npz')
print("Keys in file:", data.files)

# NOTE(review): re-running this cell adds the same field names again;
# whether add_continuous_data tolerates that is not visible here -- confirm.
for file_key, dataset_key in rename_map.items():
    dataset.add_continuous_data(data[file_key], dataset_key)

print(dataset.data.keys().unique(0))

In [None]:
# Direction set and colours used by this cell and the passive-trial cell below.
if len(np.unique(active_cond_dict_offset)) == 8:
    plot_dir = np.array([0,45,90,135,180,225,270,315])
    directions = np.array([0,45,90,135,180,225,270,315])
else:
    plot_dir = np.array([0,90,180,270])
    directions = np.array([0,90,180,270])
cmap = plt.get_cmap('coolwarm',len(plot_dir))
custom_palette = [mpl.colors.rgb2hex(cmap(i)) for i in range(len(plot_dir))]
plot_field = 'CD_FB_proj'
N = dataset_10ms.data[plot_field].shape[1]
order = range(N)

# Condition-averaged projections over active trials, aligned to movement onset.
pred_range = (-200, 1100)
trial_mask = active_mask
n_timepoints = int((pred_range[1] - pred_range[0])/dataset_10ms.bin_width)
data = dataset_10ms.make_trial_data(align_field='move_onset_time', align_range=pred_range, ignored_trials=~trial_mask, allow_overlap=True)
n_trials = data['trial_id'].nunique()
trials_pca = nans([n_trials,n_timepoints,N])
plot_trial_ids = []
for i, (trial_id, trial) in enumerate(data.groupby('trial_id')):
    trials_pca[i,:,:]=trial[plot_field].to_numpy()
    plot_trial_ids.append(trial_id)
print(trials_pca.shape)

# Label every row of trials_pca by trial id. active_cond_dict was built for a
# different align range and in aligned-data order, while groupby iterates in
# sorted trial-id order and may contain a different set of trials; a
# positional match could therefore assign labels to the wrong trials.
# Trials without a known label get -1 and never match a direction.
cond_lookup = dict(zip(active_trial_ids, active_cond_dict))
cond_dict = np.array([cond_lookup.get(tid, -1) for tid in plot_trial_ids], dtype=int)

x_axis = np.arange(pred_range[0], pred_range[1], dataset_10ms.bin_width)

# define some useful time points
move_idx=0
ret_idx = 500

plot_dims = N

# Mean over all plotted samples, per dimension (subtracted from every trace).
pca_mean = np.mean(data[plot_field].to_numpy(),axis = 0)

# squeeze=False keeps `ax` indexable even when there is a single dimension.
fig,ax=plt.subplots(plot_dims,1,figsize=(10,N+4),squeeze=False)
ax = ax[:, 0]
for i in range(plot_dims):
    dim = order[i]
    for j in range(len(plot_dir)):
        color = custom_palette[j]
        dir_idx = int(np.flatnonzero(directions == plot_dir[j])[0])
        cond_mean_proj = np.mean(trials_pca[cond_dict == dir_idx, :, dim], axis = 0)
        ax[i].plot(x_axis,cond_mean_proj - pca_mean[dim],linewidth=2.25,color = color,label = plot_dir[j])

    # Per-axis decoration, done once per subplot rather than once per direction.
    ax[i].axvline(move_idx, color='k',linewidth = 1)
    ax[i].axvline(ret_idx, color='k',linewidth = 1)

    ax[i].set_xlim([-200,1000])
    # ax[i].set_ylim([-.5, .5])
    ax[i].axhline(0,color ='k',ls = '--')
    if i<plot_dims-1:
        ax[i].set_xticks([])
    else:
        ax[i].set_xlabel('Time after movement onset (ms)')

    ax[i].set_yticks([])
    ax[i].set_ylabel('Dim. '+str(i+1))

ax[0].set_title('Active trials')

plt.legend(bbox_to_anchor = (1, 1), loc = 'upper left')
plt.tight_layout()
# plt.savefig(figDir + monkey + '_ylim_cdfb_active_smooth.pdf',dpi = 'figure')

In [None]:
# Plot condition-averaged projections over passive trials, one colour per
# direction. Reuses plot_field, N, order, plot_dir, directions and
# custom_palette from the active-trial plotting cell above.
pred_range = (-100, 600)
trial_mask = passive_mask
n_timepoints = int((pred_range[1] - pred_range[0])/dataset_10ms.bin_width)
data = dataset_10ms.make_trial_data(align_field='move_onset_time', align_range=pred_range, ignored_trials=~trial_mask, allow_overlap=True)
n_trials = data['trial_id'].nunique()
trials_pca = nans([n_trials,n_timepoints,N])
plot_trial_ids = []
for i, (trial_id, trial) in enumerate(data.groupby('trial_id')):
    trials_pca[i,:,:]=trial[plot_field].to_numpy()
    plot_trial_ids.append(trial_id)
print(trials_pca.shape)

# Label every row of trials_pca by trial id. passive_cond_dict was built for
# a different align range and in aligned-data order, while groupby iterates
# in sorted trial-id order and may contain a different set of trials; a
# positional match could therefore assign labels to the wrong trials.
# Trials without a known label get -1 and never match a direction.
cond_lookup = dict(zip(passive_trial_ids, passive_cond_dict))
cond_dict = np.array([cond_lookup.get(tid, -1) for tid in plot_trial_ids], dtype=int)

x_axis = np.arange(pred_range[0], pred_range[1], dataset_10ms.bin_width)

# define some useful time points
move_idx=0
ret_idx = 120

plot_dims = N

# Mean over all plotted samples, per dimension (subtracted from every trace).
pca_mean = np.mean(data[plot_field].to_numpy(),axis = 0)

# squeeze=False keeps `ax` indexable even when there is a single dimension.
fig,ax=plt.subplots(plot_dims,1,figsize=(10,N+4),squeeze=False)
ax = ax[:, 0]
for i in range(plot_dims):
    dim = order[i]
    for j in range(len(plot_dir)):
        color = custom_palette[j]
        dir_idx = int(np.flatnonzero(directions == plot_dir[j])[0])
        cond_mean_proj = np.mean(trials_pca[cond_dict == dir_idx, :, dim], axis = 0)
        ax[i].plot(x_axis,cond_mean_proj - pca_mean[dim],linewidth=2.25,color = color,label = plot_dir[j])

    # Per-axis decoration, done once per subplot rather than once per direction.
    ax[i].axvline(move_idx, color='k',linewidth = 1)
    # ax[i].axvline(120, color='k',linewidth = 1)
    ax[i].axvline(ret_idx, color='k',linewidth = 1)
    ax[i].set_xlim([-100,500])
    # ax[i].set_ylim([-.5, .5])
    ax[i].axhline(0,color ='k',ls = '--')
    if i<plot_dims-1:
        ax[i].set_xticks([])
    else:
        ax[i].set_xlabel('Time after movement onset (ms)')

    ax[i].set_yticks([])
    ax[i].set_ylabel('Dim. '+str(i+1))

ax[0].set_title('Passive trials')

plt.legend(bbox_to_anchor = (1, 1), loc = 'upper left')
plt.tight_layout()
# plt.savefig(figDir + monkey + '_ylim_cdfb_passive_smooth.pdf',dpi = 'figure')

In [None]:
# Show the top-level field names currently present in the dataset's data frame.
dataset.data.keys().unique(0)

In [None]:
# Number of columns in the 'CD_proj' field.
# NOTE(review): `data` is whatever was last bound to that name -- under Run All
# that is the trial DataFrame from the passive plotting cell, not the npz file.
n_cd_dims = data['CD_proj'].shape[1]
print(n_cd_dims,"CD dimensions")

In [None]:
# Decoding settings shared by all models below.
y_field = 'hand_vel'                  # target field to decode
lag_axis = np.arange(-300, 320, 20)   # lags to sweep: -300..300 in steps of 20 (ms, per the progress print)
norm_x    = True                      # forwarded to fit_and_predict_MC
n_splits = 20                         # number of splits per lag (second axis of the R^2 arrays)

In [None]:
def run_decoding_for_feature(
    name,            # e.g. 'cd_only', 'SC_cd_only', 'fb_only'
    x_field,         # e.g. 'CD_proj', 'FB_proj', 'CD_FB_proj', 'spikes_smth_40', 'PCA_40'
    dataset,
    trial_mask,
    cond_dict,
    lag_axis,
    pred_range,
    y_field='hand_vel',
    norm_x=True,
    pos_bool=False,
    n_cd_dims=0,
    n_splits=20,
):
    """
    Sweep `fit_and_predict_MC` over every lag in `lag_axis` for one feature set.

    Returns a dict keyed by model `name`:
      x_r2_<name>, y_r2_<name>: [n_lags, n_splits] R^2 of target columns 0 and 1
      r2_<name>:                [n_lags, n_splits] overall R^2
      <name>_coefs:             [n_lags, y_dim, n_features] coefficients
    """
    n_lags = len(lag_axis)
    n_features = dataset.data[x_field].shape[1]   # columns of the predictor field
    n_targets = dataset.data[y_field].shape[1]    # e.g. 2 for x,y velocity

    overall_r2 = np.full((n_lags, n_splits), np.nan)
    per_target_r2 = np.full((n_lags, n_splits, n_targets), np.nan)
    weights = np.full((n_lags, n_targets, n_features), np.nan)

    # Everything except the lag is identical for every call.
    fixed_kwargs = dict(
        align_field='move_onset_time',
        align_range=pred_range,
        x_field=x_field,
        y_field=y_field,
        norm_x=norm_x,
        pos_bool=pos_bool,
        split_pred=False,
        n_cd_dims=n_cd_dims,
        n_splits=n_splits,
        cond_dict=cond_dict,
    )

    for lag_i, lag in enumerate(lag_axis):
        print(f"{name}: lag {lag} ms ({lag_i+1}/{n_lags})")

        r2, coef, _, _vel_df, r2_arr = fit_and_predict_MC(
            dataset, trial_mask, lag=lag, **fixed_kwargs
        )

        overall_r2[lag_i, :] = r2
        per_target_r2[lag_i, ...] = r2_arr
        weights[lag_i, ...] = coef

    # target column 0 -> x, target column 1 -> y
    return {
        f"x_r2_{name}": per_target_r2[:, :, 0],
        f"y_r2_{name}": per_target_r2[:, :, 1],
        f"r2_{name}": overall_r2,
        f"{name}_coefs": weights,
    }


In [None]:
# One entry per decoding condition: which trials to use, their condition
# labels, and the window (relative to 'move_onset_time') to decode over.
condition_configs = [
    {
        # active trials, short window
        "cond_name":  "early_act",
        "trial_mask": active_mask,
        "cond_dict":  active_cond_dict,
        "pred_range": (-100, 120),
    },
    {
        # active trials, long window
        "cond_name":  "act",
        "trial_mask": active_mask,
        "cond_dict":  active_cond_dict,
        "pred_range": (-100, 1000),
    },
    {
        # passive trials, same short window as "early_act"
        "cond_name":  "pas",
        "trial_mask": passive_mask,
        "cond_dict":  passive_cond_dict,
        "pred_range": (-100, 120),
    },
]

In [None]:
# One tuple per decoding model, unpacked in the loop below as
# (name, x_field, pos_bool, n_cd_dims). Each "SC_*" entry is its counterpart
# with pos_bool=True.
model_configs = [
    #  name          x_field           pos_bool   n_cd_dims
    ("cd_only",     "CD_proj",        False,     0),
    ("SC_cd_only",  "CD_proj",        True,      0),

    ("fb_only",     "FB_proj",        False,     0),
    ("SC_fb_only",  "FB_proj",        True,      0),

    ("cd_fb",       "CD_FB_proj",     False,     n_cd_dims),
    ("SC_cd_fb",    "CD_FB_proj",     True,      n_cd_dims),

    ("nrn",         "spikes_smth_40", False,     0),
    ("pc",          "PCA_40",         False,     0),
]

In [None]:
# Run every model for every condition and save one npz file per condition.
# NOTE: this loop rebinds the globals `trial_mask` and `cond_dict`, and reads
# the global `dataset`.
for cond_cfg in condition_configs:
    cond_name        = cond_cfg["cond_name"]
    trial_mask       = cond_cfg["trial_mask"]
    cond_dict        = cond_cfg["cond_dict"]
    pred_range_cond  = cond_cfg["pred_range"]

    print(f"\n=== Running condition: {cond_name} ===")
    all_results = {}  # results of all models for this condition

    for (name, x_field_i, pos_bool_i, n_cd_dims_i) in model_configs:

        res = run_decoding_for_feature(
            name=name,
            x_field=x_field_i,
            dataset=dataset,
            trial_mask=trial_mask,
            cond_dict=cond_dict,
            lag_axis=lag_axis,
            pred_range=pred_range_cond,   # the condition controls pred_range
            y_field=y_field,
            norm_x=norm_x,
            pos_bool=pos_bool_i,
            n_cd_dims=n_cd_dims_i,
            n_splits=n_splits,
        )
        all_results.update(res)

    # File name pattern matches the one rebuilt by load_decoding_results below.
    save_name = f"{monkey}_MC_norm_zscore_smooth40_spikes_{y_field}_{cond_name}_r2s.npz"
    np.savez(save_name, **all_results)
    print("Saved:", save_name)

In [None]:
import os

# Default lag axis for the plotting helpers; same values as `lag_axis` in the
# decoding-settings cell above.
LAG_AXIS = np.arange(-300, 320, 20) 
# Default line width for the R^2 curves.
LW = 3

def load_decoding_results(monkey, y_field, cond_name,
                          base_dir='.', 
                          prefix='MC_norm_zscore_smooth40_spikes'):
    """
    Load the saved decoding results for one session and condition.

    The file is expected at
      <base_dir>/<monkey>_<prefix>_<y_field>_<cond_name>_r2s.npz
    and the object returned by `np.load` is handed back unchanged.
    """
    file_name = "{}_{}_{}_{}_r2s.npz".format(monkey, prefix, y_field, cond_name)
    full_path = os.path.join(base_dir, file_name)
    results = np.load(full_path)
    print("Loaded:", full_path)
    return results

def plot_r2_for_model(data, model_name, lag_axis=LAG_AXIS, color='brown', label=None,
                      ax=None, lw=LW, alpha_fill=0.3, print_stats=True):
    """
    Plot mean +/- std of R^2 against lag for a single model.

    Parameters
    ----------
    data : mapping
        Indexed with ``f"r2_{model_name}"``; the entry is averaged over
        axis 1 (per the original note its shape is [n_lags, n_splits]).
    model_name : str
        Model whose ``r2_*`` entry is plotted.
    lag_axis : array-like
        x-values, one per row of the R^2 array.
    color, lw, alpha_fill
        Line colour, line width and opacity of the +/- std band.
    label : str or None
        Legend label; defaults to ``model_name``.
    ax : matplotlib Axes or None
        Axes to draw on. If None, a new 5.5 x 4 figure is created and its
        top/right spines are hidden; a supplied axes is left styled as is.
    print_stats : bool
        If True, print the peak of the mean curve and the first lag at which
        the mean reaches 95/90/80/70 % of that peak.

    Returns
    -------
    The axes that was drawn on.
    """
    if label is None:
        label = model_name

    if ax is None:
        _, ax = plt.subplots(figsize=(5.5, 4))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    r2 = data[f"r2_{model_name}"]          # shape [n_lags, n_splits]
    mean_r2 = np.nanmean(r2, axis=1)
    std_r2 = np.nanstd(r2, axis=1)

    ax.plot(lag_axis, mean_r2, linewidth=lw, color=color, label=label)
    ax.fill_between(lag_axis, mean_r2 - std_r2, mean_r2 + std_r2,
                    color=color, alpha=alpha_fill)

    if print_stats:
        # "\u00b2" is the superscript two; written as an escape so the source
        # stays ASCII and cannot be mis-encoded.
        if np.isnan(mean_r2).all():
            # np.nanargmax raises ValueError on all-NaN input; report instead
            # of crashing after the curve has already been drawn.
            print(f"{model_name}: no finite R\u00b2 values to summarize")
        else:
            best_idx = np.nanargmax(mean_r2)
            best_lag = lag_axis[best_idx]
            best_r2 = mean_r2[best_idx]
            print(f"{model_name}: best R\u00b2 = {best_r2:.3f} at lag {best_lag} ms")

            for frac in [0.95, 0.9, 0.8, 0.7]:
                good_r2 = best_r2 * frac
                good_idx = np.where(mean_r2 >= good_r2)[0]
                if len(good_idx) > 0:
                    good_lag = lag_axis[good_idx[0]]
                    print(f"  {frac*100:.0f}% of max (R\u00b2={good_r2:.3f}) first at lag {good_lag} ms")
                else:
                    # Only reachable when the peak is negative, because then
                    # frac * peak lies above the peak itself.
                    print(f"  {frac*100:.0f}% of max: no lag reaches this threshold")

    return ax

def get_mean_std(data, key_prefix, model_name):
    """
    Return the NaN-ignoring mean and standard deviation over axis 1 of
    ``data[key_prefix + model_name]``.

    key_prefix: 'r2_' / 'x_r2_' / 'y_r2_'
    The looked-up array is noted in the original as [n_lags, n_splits], so the
    two returned arrays hold one value per lag.
    """
    scores = data["{}{}".format(key_prefix, model_name)]
    return np.nanmean(scores, axis=1), np.nanstd(scores, axis=1)

def plot_r2_comparison_overall(data, lag_axis=LAG_AXIS, title=None, ylim=(-0.1, 0.9),
                               ax=None, lw=LW):
    """
    Plot total R^2 against lag for the models cd / fb / cd+fb / nrn / pc.

    Reads the ``'r2_*'`` entries of ``data`` through ``get_mean_std`` and
    draws each model's mean curve with a +/- std band.

    Parameters
    ----------
    data : mapping
        Must provide 'r2_cd_only', 'r2_fb_only', 'r2_cd_fb', 'r2_nrn', 'r2_pc'.
    lag_axis : array-like
        x-values, one per lag.
    title : str or None
        Axes title; omitted when None.
    ylim : (float, float)
        y-axis limits.
    ax : matplotlib Axes or None
        Axes to draw on; a new 5.5 x 4 figure is created when None.
    lw : line width of the mean curves.

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5.5, 4))

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # (model_name, color, linestyle, label, alpha_fill)
    # Labels name the non-SC models that are actually plotted here; the
    # SC_* variants are only drawn by plot_r2_SC_vs_nonSC.
    models = [
        ('cd_only', 'green',     '-',  'cd',      0.3),
        ('fb_only', 'magenta',   '-',  'fb',      0.3),
        ('cd_fb',   'brown',     '-',  'cd+fb',   0.3),
        ('nrn',     'grey',      '--', 'neurons', 0.5),
        ('pc',      'lightgrey', '--', 'PCs',     0.5),
    ]

    for model_name, color, ls, label, alpha_fill in models:
        mean_r2, std_r2 = get_mean_std(data, 'r2_', model_name)
        ax.plot(lag_axis, mean_r2, linewidth=lw, color=color,
                linestyle=ls, label=label)
        ax.fill_between(lag_axis, mean_r2 - std_r2, mean_r2 + std_r2,
                        color=color, alpha=alpha_fill)

        # "\u00b2" is the superscript two, written as an escape to keep the source ASCII.
        if np.isnan(mean_r2).all():
            # np.nanargmax raises ValueError on all-NaN input.
            print(f"[overall] {model_name}: no finite R\u00b2 values")
        else:
            best_idx = np.nanargmax(mean_r2)
            print(f"[overall] {model_name}: max R\u00b2 = {mean_r2[best_idx]:.3f} at lag {lag_axis[best_idx]} ms")

    ax.axvline(0, color='k', linestyle='--')
    ax.set_xlabel('Time lag (ms)')
    ax.set_ylabel('R\u00b2')
    ax.set_ylim(*ylim)
    if title is not None:
        ax.set_title(title)
    ax.legend(fontsize=8)
    # Lay out the figure that owns `ax`, which need not be pyplot's current
    # figure when the caller supplied the axes.
    ax.figure.tight_layout()
    return ax

def plot_r2_comparison_xy(data, comp='x', lag_axis=LAG_AXIS, title=None,
                          ylim=(-0.1, 0.9), ax=None, lw=LW):
    """
    Plot per-component R^2 against lag for the models cd / fb / cd+fb / nrn / pc.

    Parameters
    ----------
    data : mapping
        Must provide '<comp>_r2_<model>' entries for the five models.
    comp : 'x' or 'y'
        Selects the 'x_r2_*' or 'y_r2_*' entries.
    lag_axis : array-like
        x-values, one per lag.
    title : str or None
        Axes title; omitted when None.
    ylim : (float, float)
        y-axis limits.
    ax : matplotlib Axes or None
        Axes to draw on; a new 5.5 x 4 figure is created when None.
    lw : line width of the mean curves.

    Returns
    -------
    The axes that was drawn on. No legend is drawn by this function.
    """
    assert comp in ('x', 'y'), f"comp must be 'x' or 'y', got {comp!r}"
    prefix = f"{comp}_r2_"

    if ax is None:
        _, ax = plt.subplots(figsize=(5.5, 4))

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # (model_name, color, linestyle, label, alpha_fill)
    # Labels name the non-SC models that are actually plotted here.
    models = [
        ('cd_only', 'green',     '-',  'cd',      0.3),
        ('fb_only', 'magenta',   '-',  'fb',      0.3),
        ('cd_fb',   'brown',     '-',  'cd+fb',   0.3),
        ('nrn',     'grey',      '--', 'neurons', 0.5),
        ('pc',      'lightgrey', '--', 'PCs',     0.5),
    ]

    for model_name, color, ls, label, alpha_fill in models:
        mean_r2, std_r2 = get_mean_std(data, prefix, model_name)
        ax.plot(lag_axis, mean_r2, linewidth=lw, color=color,
                linestyle=ls, label=label)
        ax.fill_between(lag_axis, mean_r2 - std_r2, mean_r2 + std_r2,
                        color=color, alpha=alpha_fill)

        # "\u00b2" is the superscript two, written as an escape to keep the source ASCII.
        if np.isnan(mean_r2).all():
            # np.nanargmax raises ValueError on all-NaN input.
            print(f"[{comp.upper()}] {model_name}: no finite R\u00b2 values")
        else:
            best_idx = np.nanargmax(mean_r2)
            print(f"[{comp.upper()}] {model_name}: max R\u00b2 = {mean_r2[best_idx]:.3f} at lag {lag_axis[best_idx]} ms")

    ax.axvline(0, color='k', linestyle='--')
    ax.set_xlabel('Time lag (ms)')
    ax.set_ylabel(f"{comp.upper()} R\u00b2")
    ax.set_ylim(*ylim)
    if title is not None:
        ax.set_title(title)
    # ax.legend(fontsize=8)
    # Lay out the figure that owns `ax`, which need not be pyplot's current
    # figure when the caller supplied the axes.
    ax.figure.tight_layout()
    return ax

def plot_r2_SC_vs_nonSC(data, lag_axis=LAG_AXIS, title=None,
                        ylim=(-0.1, 0.9), ax=None, lw=LW):
    """
    Compare total R^2 ('r2_*' entries) of each model with its SC_ counterpart:
    cd_only vs SC_cd_only, fb_only vs SC_fb_only, cd_fb vs SC_cd_fb.

    Each pair shares a colour; the non-SC curve is dashed and faint, the SC
    curve is solid.

    Parameters
    ----------
    data : mapping
        Must provide the 'r2_*' entries for the six models listed above.
    lag_axis : array-like
        x-values, one per lag.
    title : str or None
        Axes title; omitted when None.
    ylim : (float, float)
        y-axis limits.
    ax : matplotlib Axes or None
        Axes to draw on; a new 5.5 x 4 figure is created when None.
    lw : line width of the mean curves.

    Returns
    -------
    The axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(5.5, 4))

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    def _peak_summary(mean_curve):
        # One-line description of the curve's maximum; tolerates all-NaN
        # curves, on which np.nanargmax would raise ValueError.
        if np.isnan(mean_curve).all():
            return "no finite values"
        idx = np.nanargmax(mean_curve)
        return f"max {mean_curve[idx]:.3f} at {lag_axis[idx]} ms"

    # (non-SC model, SC model, shared colour, non-SC label, SC label)
    pair_models = [
        ('cd_only',   'SC_cd_only',  'green',   'cd',   'SC cd'),
        ('fb_only',   'SC_fb_only',  'magenta', 'fb',   'SC fb'),
        ('cd_fb',     'SC_cd_fb',    'brown',   'cd+fb','SC cd+fb'),
    ]

    for base, sc, color, base_label, sc_label in pair_models:
        mean_base, std_base = get_mean_std(data, 'r2_', base)
        mean_sc,   std_sc   = get_mean_std(data, 'r2_', sc)

        # Non-SC: lighter, dashed
        ax.plot(lag_axis, mean_base, linewidth=lw, color=color,
                linestyle='--', label=base_label, alpha=0.4)
        ax.fill_between(lag_axis, mean_base - std_base, mean_base + std_base,
                        color=color, alpha=0.1)

        # SC: solid
        ax.plot(lag_axis, mean_sc, linewidth=lw, color=color,
                linestyle='-', label=sc_label, alpha=0.9)
        ax.fill_between(lag_axis, mean_sc - std_sc, mean_sc + std_sc,
                        color=color, alpha=0.3)

        print(f"[SC compare] {base}: {_peak_summary(mean_base)}")
        print(f"[SC compare] {sc}:   {_peak_summary(mean_sc)}")

    ax.axvline(0, color='k', linestyle='--')
    ax.set_xlabel('Time lag (ms)')
    # "\u00b2" is the superscript two, written as an escape to keep the source ASCII.
    ax.set_ylabel('R\u00b2')
    ax.set_ylim(*ylim)
    if title is not None:
        ax.set_title(title)
    ax.legend(fontsize=8)
    # Lay out the figure that owns `ax`, which need not be pyplot's current
    # figure when the caller supplied the axes.
    ax.figure.tight_layout()
    return ax


In [None]:
def plot_condition_overview(monkey, y_field, cond_name,
                            base_dir='.',
                            fig_dir='.',      
                            save_fig=False):
    """
    Load one condition's decoding results and show four figures:

      1. overall R^2 (cd / fb / cd+fb / nrn / pc)
      2. X R^2
      3. Y R^2
      4. SC vs non-SC

    Parameters
    ----------
    monkey, y_field, cond_name : str
        Identify the results archive (see load_decoding_results) and are used
        in the figure titles and file names.
    base_dir : str
        Directory containing the results archive.
    fig_dir : str
        Directory the PDFs are written to when ``save_fig`` is True; it is
        created if it does not exist.
    save_fig : bool
        If True, each figure is saved as
        ``<fig_dir>/{monkey}_{y_field}_{cond_name}_<panel>.pdf`` before it is shown.

    Returns
    -------
    None
    """
    data = load_decoding_results(monkey, y_field, cond_name, base_dir=base_dir)
    title_prefix = f"{monkey} {y_field} {cond_name}"
    title_suffix = f"{monkey}_{y_field}_{cond_name}"

    ylim_overall = (-0.1, 0.9)

    # (title tag, file-name tag, plotting function, extra keyword arguments)
    panels = [
        ("overall",      "overall_r2",     plot_r2_comparison_overall, {}),
        ("X",            "x_r2",           plot_r2_comparison_xy,      {"comp": "x"}),
        ("Y",            "y_r2",           plot_r2_comparison_xy,      {"comp": "y"}),
        ("SC vs non-SC", "SC_vs_nonSC_r2", plot_r2_SC_vs_nonSC,        {}),
    ]

    if save_fig and fig_dir:
        # savefig does not create missing directories.
        os.makedirs(fig_dir, exist_ok=True)

    for title_tag, file_tag, plot_fn, extra_kwargs in panels:
        # Each plotting helper creates its own figure and returns the Axes.
        ax = plot_fn(
            data, lag_axis=LAG_AXIS,
            title=f"{title_prefix} {title_tag}",
            ylim=ylim_overall,
            **extra_kwargs
        )
        if save_fig:
            fname = os.path.join(fig_dir, f"{title_suffix}_{file_tag}.pdf")
            # Save the figure that was just drawn rather than whichever
            # figure pyplot currently considers active.
            ax.figure.savefig(fname, dpi='figure', bbox_inches='tight')
            print("Saved:", fname)
        plt.show()

In [None]:
# monkey = "Lando_20170731"
# monkey = "Chips_20170913"

# Plot the decoding overview for each saved condition of the current `monkey`
# (set in an earlier cell).
y_field  = 'hand_vel'
# Resolve locations relative to the user's home directory, as the first cell
# does for the data folder, instead of hard-coding one user's absolute path.
base_dir = os.path.expanduser('~/area2_population_analysis')
fig_dir = os.path.expanduser('~/Desktop/paper/')

# passive
plot_condition_overview(monkey, y_field, 'pas', base_dir=base_dir, fig_dir=fig_dir, save_fig=False)

# act
plot_condition_overview(monkey, y_field, 'early_act', base_dir=base_dir, fig_dir=fig_dir, save_fig=False)
## whole trial
plot_condition_overview(monkey, y_field, 'act', base_dir=base_dir, fig_dir=fig_dir, save_fig=False)


# # plot for a single condition, i.e. passive cd+fb:
# data_pas = load_decoding_results(monkey, y_field, 'pas')
# fig, ax = plt.subplots(figsize=(5.5, 4))
# ax.spines['right'].set_visible(False)
# ax.spines['top'].set_visible(False)
# plot_r2_for_model(data_pas, 'cd_fb', lag_axis=LAG_AXIS,
#                   color='brown', label='cd+fb', ax=ax)
# ax.set_xlabel('Time lag (ms)')
# ax.set_ylabel('R²')
# ax.set_ylim([-0.1, 0.85])
# plt.tight_layout()
# plt.show()