# In [3]:
import numpy as np
import mne
import scipy.sparse as sparse
from scipy import special
from mne.stats.cluster_level import _find_clusters as find_clusters
from toolz.functoolz import curry
from scipy.stats._hypotests import _batch_generator as batch_generator
from pickle import dump

def edge_adjacency(n=116):
    """Build the adjacency between the edges of a graph with ``n`` nodes.

    Edges are enumerated in the order of ``np.tril_indices(n)`` (lower
    triangle *including* the diagonal, so self-connections count as edges).
    Two edges are adjacent when they share at least one node, which also
    makes every edge adjacent to itself.

    The matrix is assembled from a sparse edge-by-node incidence matrix, so
    no dense ``(n_edges, n_edges)`` intermediate is allocated and edge ids
    are not limited to what fits into 16 bits.

    Parameters
    ----------
    n : int, default 116
        Number of nodes.

    Returns
    -------
    scipy.sparse.coo_matrix
        Boolean matrix of shape ``(n_edges, n_edges)`` where
        ``n_edges = n * (n + 1) // 2``; stored entries are all ``True``.
    """
    rows, cols = np.tril_indices(n)
    n_edges = rows.size
    edge_ids = np.arange(n_edges)

    # incidence[e, v] > 0 iff node v is an endpoint of edge e. Diagonal
    # edges list the same node twice; the duplicates are summed, which is
    # harmless because only the sparsity pattern is used below.
    incidence = sparse.csr_matrix(
        (
            np.ones(2 * n_edges, dtype=np.int32),
            (np.tile(edge_ids, 2), np.concatenate([rows, cols])),
        ),
        shape=(n_edges, n),
    )

    # (incidence @ incidence.T)[e1, e2] counts shared endpoints, so it is
    # non-zero exactly when the two edges touch.
    shared = (incidence @ incidence.T).tocsr()
    shared.sort_indices()  # row-major entry order, as a dense conversion gives
    shared = shared.tocoo()

    return sparse.coo_matrix(
        (np.ones(shared.nnz, dtype=bool), (shared.row, shared.col)),
        shape=(n_edges, n_edges),
    )

def r_ppf(q, n):
    """Return the positive critical correlation for a two-sided level ``q``.

    Evaluates the inverse regularized incomplete beta function with both
    shape parameters equal to ``n/2 - 1`` at ``q/2``, maps the result from
    ``[0, 1]`` to ``[-1, 1]`` and flips the sign, giving the magnitude that
    cuts off a tail mass of ``q/2`` on each side.

    ``scipy.special.betaincinv`` is used in place of the deprecated (and, in
    recent SciPy releases, removed) ``btdtri``; the two are equivalent.

    Parameters
    ----------
    q : float or array_like
        Two-sided tail probability, in ``[0, 1]``.
    n : int or float
        Number of samples the correlation is computed from
        (presumably ``n > 2`` so that the shape parameter is positive).

    Returns
    -------
    float or ndarray
        Non-negative correlation threshold.
    """
    ab = n/2 - 1
    return -(special.betaincinv(ab, ab, q/2)*2 - 1)

def r_pdf(r, n):
    """Return the two-sided tail probability of a correlation value ``r``.

    Despite the name this is not a density: it evaluates twice the
    regularized incomplete beta function (both shape parameters
    ``n/2 - 1``) at ``(1 - |r|) / 2``, i.e. the mass beyond ``|r|`` on both
    sides, and is the inverse of ``r_ppf`` for non-negative ``r``.

    ``scipy.special.betainc`` is used in place of the deprecated (and, in
    recent SciPy releases, removed) ``btdtr``; the two are equivalent.

    Parameters
    ----------
    r : float or array_like
        Correlation value(s) in ``[-1, 1]``.
    n : int or float
        Number of samples the correlation is computed from
        (presumably ``n > 2`` so that the shape parameter is positive).

    Returns
    -------
    float or ndarray
        Two-sided tail probability in ``[0, 1]``.
    """
    ab = n/2 - 1
    return 2*special.betainc(ab, ab, 0.5*(1 - abs(np.float64(r))))

def pearsonr_no_p(x, y, axis):
    """Compute Pearson correlation coefficients along ``axis`` (no p-values).

    ``x`` and ``y`` must have the same number of dimensions; the remaining
    axes are broadcast against each other. Where either input is constant
    along ``axis`` the coefficient is undefined (0/0) and is reported as 0.
    Results are clipped to ``[-1, 1]`` to absorb floating-point overshoot.

    Parameters
    ----------
    x, y : array_like
        Input data with equal ``ndim``.
    axis : int
        Axis holding the observations that are correlated.

    Returns
    -------
    ndarray or numpy scalar
        Correlation coefficients with ``axis`` removed (a scalar for 1-D
        inputs).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in number of dimensions.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != y.ndim:
        raise ValueError(
            f"x and y must have the same number of dimensions, "
            f"got {x.ndim} and {y.ndim}"
        )

    x_c = x - x.mean(axis, keepdims=True)
    y_c = y - y.mean(axis, keepdims=True)
    x_sq = np.sum(np.square(x_c), axis=axis)
    y_sq = np.sum(np.square(y_c), axis=axis)
    num = np.sum(x_c*y_c, axis=axis)
    denum = np.sqrt(x_sq*y_sq)

    # 0/0 is expected for constant inputs and is handled right below, so the
    # corresponding floating-point warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        res = np.asarray(num/denum)
    # np.where (instead of in-place masking) also works for 0-d results.
    res = np.where(np.isnan(res), 0.0, res)
    return np.clip(res, -1, 1)

def clean_clusters(clusters, sums, fun=lambda x, y: True):
    """Filter clusters, keeping those with at least two members.

    Parameters
    ----------
    clusters : sequence
        Clusters; each one must support ``len``.
    sums : sequence
        Per-cluster statistic, aligned with ``clusters``.
    fun : callable, optional
        Extra predicate ``fun(cluster, cluster_sum)``; a cluster is kept
        only if it returns a truthy value. It is evaluated only for
        clusters that already have two or more members. Accepts everything
        by default.

    Returns
    -------
    tuple of (list, list)
        The surviving clusters and their sums, in the original order.
    """
    kept = [
        pos
        for pos, members in enumerate(clusters)
        if len(members) >= 2 and fun(members, sums[pos])
    ]
    return [clusters[pos] for pos in kept], [sums[pos] for pos in kept]

def cluster_inds(clusters, perminds):
    """Look up ``perminds`` at the first member of every cluster.

    Parameters
    ----------
    clusters : iterable
        Clusters, each a non-empty indexable collection of positions into
        ``perminds``.
    perminds : indexable
        Label for every position (the caller passes the permutation index
        of each flattened statistic).

    Returns
    -------
    ndarray
        One label per cluster, taken from the cluster's first member.
    """
    return np.array([perminds[members[0]] for members in clusters])

def calc_pval(arr, y):
    """Estimate a permutation p-value for ``y`` against the null ``arr``.

    Only null values with the same sign as ``y`` are compared (positive
    values when ``y`` is zero or positive). The p-value is the number of
    those whose magnitude is at least ``|y|``, divided by the total length
    of ``arr``. Counting values that are *at least as extreme* is what
    makes large observed statistics yield small p-values.

    Parameters
    ----------
    arr : array_like
        1-D null distribution.
    y : float
        Observed statistic.

    Returns
    -------
    float
        Estimated p-value, or NaN when ``arr`` is empty.
    """
    arr = np.asarray(arr)
    denum = len(arr)
    if denum == 0:
        # No null samples: the p-value is undefined.
        return np.nan
    same_sign = arr[arr < 0] if np.sign(y) < 0 else arr[arr > 0]
    num = np.sum(np.abs(same_sign) >= np.abs(y))
    return num/denum

@curry
def compare_tail(arr, crit, x, y):
    """Tell whether the p-value of ``y`` under ``arr`` is at most ``crit``.

    Curried so that ``compare_tail(arr, crit)`` yields a two-argument
    predicate. ``x`` is accepted but not used; only ``y`` is passed on to
    ``calc_pval``.
    """
    p_value = calc_pval(arr, y)
    return p_value <= crit

def perm_test(x, y, threshold, adj, permutations=1000, batch_num=250, tail=0, clusteralpha=0.05):
    """Cluster-based permutation test on correlations between ``x`` and ``y``.

    The rows of ``x`` are shuffled ``permutations`` times; for every shuffle
    the correlation with ``y`` is recomputed, clustered with ``adj`` and the
    cluster sums are collected into a null distribution. Clusters found in
    the unshuffled data are then kept if ``calc_pval`` against that null is
    at most ``clusteralpha``.

    Parameters
    ----------
    x, y : ndarray
        Data with observations along axis 0. They must have the same
        number of dimensions (assumes 2-D ``(n_samples, k)`` arrays whose
        trailing axes broadcast - TODO confirm against callers).
    threshold : float
        Cluster-forming threshold passed to ``find_clusters``.
    adj : scipy sparse matrix
        Adjacency between the entries of the correlation vector.
    permutations : int, default 1000
        Number of random permutations.
    batch_num : int, default 250
        Number of permutations clustered together in one call.
    tail : int, default 0
        Passed through to ``find_clusters``.
    clusteralpha : float, default 0.05
        Largest p-value for which an observed cluster is reported.

    Returns
    -------
    tuple of (list, list)
        The retained observed clusters and their p-values.
    """
    print(x.shape, y.shape)
    n_samples = x.shape[0]
    inds = np.arange(n_samples, dtype=int)
    perms = [np.random.permutation(inds) for _ in range(permutations)]
    observed_stat = pearsonr_no_p(x, y, axis=0)
    null_dist = []
    for batch in batch_generator(perms, batch_num):
        n_batch = len(batch)
        # An all-zero adjacency between permutations is combined with the
        # edge adjacency so a whole batch is clustered in one call.
        perms_adj = sparse.coo_array((n_batch, n_batch), dtype=int)
        temp_adj = mne.stats.combine_adjacency(perms_adj, adj)

        batch = np.array(batch)
        chunk = x[batch]
        null_part = pearsonr_no_p(chunk, y[None, :], axis=1)
        # Permutation (row) index of every element of the flattened batch.
        perms_inds = np.arange(null_part.shape[0]).repeat(null_part.shape[1])

        clus, sums = find_clusters(null_part.flatten(), threshold=threshold, adjacency=temp_adj, tail=tail)
        clus, sums = clean_clusters(clus, sums)
        if not clus:
            # Nothing survived in this batch; there is no maximum to take.
            continue
        clus_inds = cluster_inds(clus, perms_inds)

        # Start from -inf so the running maximum depends only on the
        # cluster sums (an uninitialised buffer could win the comparison).
        max_sums = np.full(clus_inds.max() + 1, -np.inf)
        np.maximum.at(max_sums, clus_inds, sums)

        # NOTE(review): this appends the per-permutation maximum once per
        # surviving cluster, so permutations with several clusters are
        # repeated and permutations without clusters add nothing - confirm
        # that this weighting of the null distribution is intended.
        null_dist += max_sums[clus_inds].tolist()

    null_dist = np.array(null_dist).flatten()

    clus, sums = find_clusters(observed_stat, threshold=threshold, adjacency=adj, tail=tail)
    fun = compare_tail(null_dist, clusteralpha)
    clus, sums = clean_clusters(clus, sums, fun=fun)
    pvals = [calc_pval(null_dist, s) for s in sums]

    return clus, pvals

# Input locations; both must be set before this script can run.
PATH_TO_DATA = None  # Path to the data
PATH_TO_CONN = None  # Path to the functional-network data
mbrt = np.load(PATH_TO_DATA)
conn = np.load(PATH_TO_CONN)

adj = edge_adjacency(116)

# Keep the lower triangle (diagonal included) of the first two axes of
# `conn`, in the same edge order that edge_adjacency uses, then transpose.
conn_flat = conn[np.tril_indices(116)].T
stat = perm_test(mbrt, conn_flat, r_ppf(0.025, 12), adj, permutations=1000, batch_num=250, tail=0)
# The context manager closes the file even if pickling fails.
with open('stats.pickle', 'wb') as f:
    dump(stat, f)