In [17]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.ndimage import gaussian_filter1d
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LogisticRegression

from typing import Any, Dict, List

from utils import set_params
from utils import load_pickle, extract_used_data
from utils import pca_fit

from utils.config import Params

In [18]:

def baseline_z_trials(X, eps=1e-9):
    mu = X.mean(axis=1, keepdims=True)
    sd = X.std(axis=1, ddof=1, keepdims=True)
    return (X - mu) / (sd + eps)

def fit_pca_transform(Xtr, Xte, k=8, random_state=0):
    pca = PCA(n_components=k, svd_solver='full', random_state=random_state)
    Ztr = pca.fit_transform(Xtr)
    Zte = pca.transform(Xte)
    return Ztr, Zte, pca

def make_dct_basis(Tp, k=8):
    n = np.arange(Tp)
    B = [np.ones(Tp)/np.sqrt(Tp)]
    for j in range(1, k):
        B.append(np.sqrt(2/Tp)*np.cos(np.pi*(n+0.5)*j/Tp))
    return np.stack(B, axis=1)  # (Tp,k)

def basis_project(X, B):  # X:(n,Tp), B:(Tp,k)
    return X @ B

def energy_distance_stat(X, Y):
    d_xx = cdist(X, X)
    d_yy = cdist(Y, Y)
    d_xy = cdist(X, Y)
    return 2*np.mean(d_xy) - np.mean(d_xx) - np.mean(d_yy)

def rbf_mmd2_stat(X, Y, gamma=None):
    XY = np.vstack([X, Y])
    if gamma is None:
        d = cdist(XY, XY)
        med = np.median(d[np.triu_indices_from(d, k=1)])
        gamma = 1.0 / (2*(med**2 + 1e-12))
    Kxx = np.exp(-gamma*cdist(X, X)**2); np.fill_diagonal(Kxx, 0)
    Kyy = np.exp(-gamma*cdist(Y, Y)**2); np.fill_diagonal(Kyy, 0)
    Kxy = np.exp(-gamma*cdist(X, Y)**2)
    n, m = len(X), len(Y)
    mmd2 = (Kxx.sum()/(n*(n-1)) + Kyy.sum()/(m*(m-1)) - 2*Kxy.mean())
    return mmd2, gamma

def two_sample_permutation(XA, XB, proj='pca', k=8,
                           sigma_bins=5, step=31, n_perm=5000, random_state=0,
                           test='energy'):
    rng = np.random.default_rng(random_state)
    ZA = baseline_z_trials(XA)
    ZB = baseline_z_trials(XB)
    Tp = ZA.shape[1]
    
    def project_pair(A, B):
        if proj == 'pca':
            Z = np.vstack([A, B])
            Ztr, Zte, pca = fit_pca_transform(Z, Z, k=min(k, Z.shape[0]-2), random_state=rng.integers(1e9))
            A_p, B_p = Ztr[:len(A)], Ztr[len(A):]
            return A_p, B_p
        elif proj == 'dct':
            Bm = make_dct_basis(Tp, k=min(k, Tp))
            return basis_project(A, Bm), basis_project(B, Bm)
        else:
            raise ValueError("proj must be 'pca' or 'dct'")

    A_p, B_p = project_pair(ZA, ZB)

    if test == 'energy':
        stat_obs = energy_distance_stat(A_p, B_p)
    elif test == 'mmd':
        stat_obs, gamma = rbf_mmd2_stat(A_p, B_p, gamma=None)
    else:
        raise ValueError("test must be 'energy' or 'mmd'")

    X = np.vstack([ZA, ZB])
    y = np.r_[np.zeros(len(ZA), dtype=int), np.ones(len(ZB), dtype=int)]
    ge = 0
    for _ in range(n_perm):
        rng.shuffle(y)
        A = X[y == 0]; B = X[y == 1]
        A_pp, B_pp = project_pair(A, B)
        if test == 'energy':
            stat = energy_distance_stat(A_pp, B_pp)
        else:
            stat, _ = rbf_mmd2_stat(A_pp, B_pp, gamma=None)
        if stat >= stat_obs:
            ge += 1
    pval = (ge + 1) / (n_perm + 1)
    out = {"stat": stat_obs, "p": pval, "proj": proj, "k": k, "Tp": Tp, "test": test}
    if test == 'mmd':
        out["note"] = "MMD核宽用中位数距离启发式（每次置换内重估）"
    return out


def decode_cv_permutation(XA, XB, proj='pca', k=8,
                          sigma_bins=5, step=31, clf_type='lda',
                          n_perm=2000, random_state=0):
    rng = np.random.default_rng(random_state)
    ZA = baseline_z_trials(XA)
    ZB = baseline_z_trials(XB)
    Xall = np.vstack([ZA, ZB])
    yall = np.r_[np.zeros(len(ZA), dtype=int), np.ones(len(ZB), dtype=int)]

    skf = StratifiedKFold(n_splits=len(Xall))  # LOOCV
    y_true, y_score = [], []

    for tr, te in skf.split(Xall, yall):
        Xtr, Xte = Xall[tr], Xall[te]
        ytr, yte = yall[tr], yall[te]

        if proj == 'pca':
            Ztr, Zte, _ = fit_pca_transform(Xtr, Xte, k=min(k, len(tr)-2), random_state=rng.integers(1e9))
        else:
            Bm = make_dct_basis(Xtr.shape[1], k=min(k, Xtr.shape[1]))
            Ztr, Zte = basis_project(Xtr, Bm), basis_project(Xte, Bm)

        if clf_type == 'lda':
            clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
        else:
            clf = LogisticRegression(penalty='l2', C=1.0, solver='liblinear', max_iter=2000)

        clf.fit(Ztr, ytr)
        y_true.append(yte[0])
        y_score.append(clf.predict_proba(Zte)[0, 1])

    auc_obs = roc_auc_score(y_true, y_score)

    ge = 0
    for _ in range(n_perm):
        yperm = yall.copy(); rng.shuffle(yperm)
        y_true_p, y_score_p = [], []
        for tr, te in skf.split(Xall, yperm):
            Xtr, Xte = Xall[tr], Xall[te]
            ytr, yte = yperm[tr], yperm[te]
            if proj == 'pca':
                Ztr, Zte, _ = fit_pca_transform(Xtr, Xte, k=min(k, len(tr)-2), random_state=rng.integers(1e9))
            else:
                Bm = make_dct_basis(Xtr.shape[1], k=min(k, Xtr.shape[1]))
                Ztr, Zte = basis_project(Xtr, Bm), basis_project(Xte, Bm)
            if clf_type == 'lda':
                clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
            else:
                clf = LogisticRegression(penalty='l2', C=1.0, solver='liblinear', max_iter=2000)
            clf.fit(Ztr, ytr)
            y_true_p.append(yte[0])
            y_score_p.append(clf.predict_proba(Zte)[0,1])
        auc_p = roc_auc_score(y_true_p, y_score_p)
        if auc_p >= auc_obs:
            ge += 1
    pval = (ge + 1) / (n_perm + 1)
    return {"auc": auc_obs, "p": pval, "proj": proj, "k": k, "clf": clf_type}

In [19]:
data_path = "../../data/"
results_path = "../../results/"
sub_directory = 'flexible_shift'

params = set_params(tt_preset='basic',
                         bt_preset='basic',
                         data_path=data_path,
                         results_path=results_path,
                         sub_directory=sub_directory,
                         pca_n_components=20,
                         len_pos_average=30)

In [20]:
data = load_pickle('../../data/flexible_shift/RDH01-PFCsep2.pkl')
data = extract_used_data(data)

In [21]:
trials = []

for i in range(data['shape'][0]):
    for j in range(data['shape'][1]):
        
        if data['simple_firing'][i,j] is None:
            continue
        
        shape_ij = data['simple_firing'][i,j].shape
        num_trails = shape_ij[1]
        
        for k in range(num_trails):
            
            trial_data = {
                'trial_type': params.tt[i],
                'behavior_type': params.bt[j],
                'firing': data['simple_firing'][i,j][:,k,:],
                'index': data['type_index'][i,j][k],
                'lick': data['pos_lick_type'][i,j][k],
                'reward': data['pos_reward_type'][i,j][k]
            }
            
            trials.append(trial_data)
            
trials.sort(key=lambda t: t['index'])

num_neurons = trials[0]['firing'].shape[0]
len_position = trials[0]['firing'].shape[1]

In [22]:
trail_type_A = 'pattern_ACB'
trail_type_B = 'position_ACB'

trial_A = [t for t in trials if t['trial_type'] == trail_type_A and t['behavior_type'] == 'correct']
trial_B = [t for t in trials if t['trial_type'] == trail_type_B and t['behavior_type'] == 'correct']

In [None]:
# sigma_bins = 5 
# step = len_position // 100
# k = 6

# all_results = []

# for neuron_id in range(num_neurons):
    
#     firing_A = np.array([t['firing'][neuron_id,:] for t in trial_A])
#     firing_B = np.array([t['firing'][neuron_id,:] for t in trial_B])
    
#     res_energy = two_sample_permutation(
#         firing_A, firing_B, proj='pca', k=k, sigma_bins=sigma_bins, step=step,
#         n_perm=5000, test='energy', random_state=0
#     )

#     # res_mmd = two_sample_permutation(
#     #     firing_A, firing_B, proj='dct', k=k, sigma_bins=sigma_bins, step=step,
#     #     n_perm=5000, test='mmd', random_state=1
#     # )


#     all_results.append({
#         "neuron": neuron_id,
#         "energy": res_energy,
#     })
    

KeyboardInterrupt: 

In [None]:
for idx in range(num_neurons):
    
    p = all_results[idx]["energy"]["p"]\
    
    if p < 0.05:
        print(all_results[idx])

{'neuron': 1, 'energy': {'stat': 2.0833719702153797, 'p': 0.011797640471905619, 'proj': 'pca', 'k': 6, 'Tp': 103, 'test': 'energy'}, 'mmd': {'stat': 0.09837599317764667, 'p': 0.013197360527894421, 'proj': 'dct', 'k': 6, 'Tp': 103, 'test': 'mmd', 'note': 'MMD核宽用中位数距离启发式（每次置换内重估）'}}
{'neuron': 2, 'energy': {'stat': 3.7487614421727393, 'p': 0.0001999600079984003, 'proj': 'pca', 'k': 6, 'Tp': 103, 'test': 'energy'}, 'mmd': {'stat': 0.1950248488926687, 'p': 0.0001999600079984003, 'proj': 'dct', 'k': 6, 'Tp': 103, 'test': 'mmd', 'note': 'MMD核宽用中位数距离启发式（每次置换内重估）'}}
{'neuron': 3, 'energy': {'stat': 2.271144348567786, 'p': 0.007598480303939212, 'proj': 'pca', 'k': 6, 'Tp': 103, 'test': 'energy'}, 'mmd': {'stat': 0.11481643671641595, 'p': 0.004399120175964807, 'proj': 'dct', 'k': 6, 'Tp': 103, 'test': 'mmd', 'note': 'MMD核宽用中位数距离启发式（每次置换内重估）'}}
{'neuron': 5, 'energy': {'stat': 2.063780466977871, 'p': 0.008198360327934412, 'proj': 'pca', 'k': 6, 'Tp': 103, 'test': 'energy'}, 'mmd': {'stat': 0.1137

In [None]:
print(all_results[0])

{'neuron': 0, 'energy': {'stat': 1.8935247970913824, 'p': 0.05158968206358728, 'proj': 'pca', 'k': 6, 'Tp': 103, 'test': 'energy'}, 'mmd': {'stat': 0.05463609204252262, 'p': 0.05578884223155369, 'proj': 'dct', 'k': 6, 'Tp': 103, 'test': 'mmd', 'note': 'MMD核宽用中位数距离启发式（每次置换内重估）'}}
