# Pseudo transfer entropy

# In [None]:
# Copyright (C) 2020 R. Silini, M. Chavez
#
#
# Please acknowledge and cite the use of this software and its authors
# when results are used in publications or published elsewhere.


def normalisa(a, order=2, axis=-1):
    """Scale ``a`` so that every slice along ``axis`` has unit ``order``-norm.

    Slices whose norm is zero are divided by 1 instead, so all-zero slices
    stay zero rather than becoming NaN.

    Parameters
    ----------
    a : array_like
        Input array of any rank.
    order : {int, float, inf, None}, optional
        Norm order passed to ``np.linalg.norm`` (default 2, Euclidean).
    axis : int, optional
        Axis along which the norm is taken (default: last axis).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``a``.
    """
    # keepdims=True leaves the reduced axis in place with length 1, so the
    # division broadcasts back onto ``a`` for any rank. The former
    # atleast_1d/expand_dims pair turned a 1-D (N,) input into a (1, N) result.
    norms = np.linalg.norm(a, order, axis, keepdims=True)
    norms[norms == 0] = 1
    return a / norms


def embed(x, embd, lag):
    """Return the time-delay embedding of a 1-D series.

    Row ``r`` of the result is
    ``[x[r], x[r + lag], ..., x[r + (embd - 1) * lag]]``.

    Parameters
    ----------
    x : array_like
        One-dimensional series of length N.
    embd : int
        Embedding dimension, i.e. number of output columns.
    lag : int
        Delay, in samples, between consecutive columns.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(N - (embd - 1) * lag, embd)``. It has zero
        rows when the series is too short for a single delay vector.
    """
    series = np.asarray(x, dtype=float)
    # np.arange of a non-positive count is empty, which yields a (0, embd) result.
    starts = np.arange(len(series) - (embd - 1) * lag)
    offsets = np.arange(embd) * lag
    # (Nv, 1) + (embd,) broadcasts to an (Nv, embd) index grid. One fancy
    # index replaces the O(embd * Nv) Python double loop and avoids
    # materialising embd full copies of x.
    return series[starts[:, np.newaxis] + offsets]

def _det_cov(data):
    """Return the determinant of the sample covariance of the columns of ``data``.

    Rows of ``data`` are observations and columns are variables. For a single
    column ``np.cov`` returns a 0-d array (the variance), which
    ``np.linalg.det`` rejects. That value is returned as is, which is exactly
    what the former ``dimEmb == 1`` branch did for ``Ytau``.
    """
    cov = np.cov(data.T)
    return cov if np.ndim(cov) == 0 else np.linalg.det(cov)


def _refine_three_node_links(TXY, Fs):
    """Apply the three-significant-link correction to ``TXY`` (logic unchanged).

    The correction only runs when all of the following hold:
    - exactly three ordered pairs passed the F-test (``np.sum(Fs) == 3``);
    - ``Fs`` is singular;
    - there are channels with two, one and zero significant outgoing links
      (row sums of ``Fs``).

    It then rescales selected entries of ``TXY`` by ratios of values in
    opposite directions.

    NOTE(review): ``TXY_`` is an alias of ``TXY``, not a copy. Every rescaling
    is written into ``TXY`` itself and is seen by later reads of ``TXY`` /
    ``TXY_temp`` in the same loop. This mirrors the original behaviour;
    confirm whether ``TXY.copy()`` was intended.
    NOTE(review): when several channels have zero outgoing links (possible for
    NN > 3), ``l`` holds several indices. ``exponent`` then has more than one
    element, and its truth test raises ValueError.

    Returns the (possibly modified in place) ``TXY`` array.
    """
    TXY_ = TXY
    if np.sum(Fs) == 3 and np.linalg.det(Fs) == 0:
        out_links = np.sum(Fs, axis=1)
        k = np.argwhere(out_links == 2)
        j = np.argwhere(out_links == 1)
        l = np.argwhere(out_links == 0)
        if len(k) != 0 and len(j) != 0 and len(l) != 0:
            # Fs is never written below, so these indices are loop-invariant.
            indice1 = np.where(out_links == 2)[0]
            indice2 = np.where(np.sum(Fs, axis=0) == 2)[0]
            for idx, row in enumerate(Fs):
                indexes = np.where(row == 1)[0]
                if len(indexes) > 0:
                    pairs = list(itertools.combinations(indexes, 2))
                    for pair in pairs:
                        if Fs[pair] == 1:
                            TXY_temp = np.multiply(TXY, Fs)
                            exponent = TXY_temp[k, l] / TXY_temp[j, l] - 1
                            if np.abs(exponent) > 0.5:
                                ratio = (TXY_temp[pair[0], idx] / TXY_temp[idx, pair[0]]) ** (2 * np.sign(exponent))
                                if ratio < 1:
                                    TXY_[pair] = TXY[pair] * ratio
                                # A second `if` rather than `else`: a NaN ratio takes neither branch.
                                if ratio >= 1:
                                    ratio2 = (TXY_temp[indice1, pair[0]] / TXY_temp[pair[0], indice1]) ** (2 * np.sign(exponent))
                                    if ratio2 < 1:
                                        TXY_[indice1, indice2] = TXY[indice1, indice2] * ratio2
                        if Fs[pair[::-1]] == 1:
                            TXY_temp = np.multiply(TXY, Fs)
                            exponent = TXY_temp[k, l] / TXY_temp[j, l] - 1
                            if np.abs(exponent) > 0.5:
                                ratio = (TXY_temp[pair[1], idx] / TXY_temp[idx, pair[1]]) ** (2 * np.sign(exponent))
                                if ratio < 1:
                                    TXY_[pair[::-1]] = TXY[pair[::-1]] * ratio
                                if ratio >= 1:
                                    # NOTE(review): reads TXY here, whereas the mirrored
                                    # branch above reads TXY_temp. Kept as in the original.
                                    ratio2 = (TXY[indice1, pair[1]] / TXY[pair[1], indice1]) ** (2 * np.sign(exponent))
                                    if ratio2 < 1:
                                        TXY_[indice2, indice1] = TXY[indice2, indice1] * ratio2
    return TXY_


def PseudoTransferEntropy(z, tau, dimEmb, pvalue):
    """Pairwise pseudo transfer entropy between the rows (channels) of ``z``.

    Each channel is detrended (``sps.detrend``) and scaled to unit L2 norm
    (``normalisa``). It is then delay-embedded with ``dimEmb + 1`` columns at
    lag ``tau``. For every ordered pair (i, j), i != j, with source channel i
    and target channel j::

        TXY[i, j] = 0.5 * log( |C(Xtau, Ytau)| * |C(Y, Ytau)|
                               / (|C(Y, Ytau, Xtau)| * |C(Ytau)|) )

    Here |C(.)| is the determinant of the sample covariance and Y is the last
    embedding column of channel j. Xtau and Ytau are the first ``dimEmb``
    columns of channels i and j. This is the Gaussian (covariance-determinant)
    form of the conditional mutual information between Y and Xtau given Ytau.

    With more than two channels, an F statistic compares the model of Y with
    and without Xtau. It is thresholded at the ``1 - pvalue / Npairs`` quantile
    from ``ff.ppf`` (``ff`` is presumably ``scipy.stats.f``), i.e. Bonferroni
    over the NN * (NN - 1) ordered pairs. The resulting significance mask
    drives ``_refine_three_node_links``.

    Parameters
    ----------
    z : array_like, shape (NN, T)
        One time series per row.
    tau : int
        Embedding lag, in samples.
    dimEmb : int
        Number of past samples used for each channel.
    pvalue : float
        Family-wise significance level of the F-test (only used when NN > 2).

    Returns
    -------
    numpy.ndarray, shape (NN, NN)
        ``TXY[i, j]`` is the value for source i -> target j; the diagonal is 0.
    """
    NN, T = np.shape(z)
    Npairs = NN * (NN - 1)
    TXY = np.zeros((NN, NN))
    # Fs must exist for every NN. It used to be created only when NN > 2, so
    # the post-processing step raised UnboundLocalError for 1- or 2-channel input.
    Fs = np.zeros((NN, NN))
    z = sps.detrend(z)
    z = normalisa(z)

    if NN > 2:
        thresholdF = ff.ppf(1.0 - pvalue / Npairs, dimEmb, T - 2 * dimEmb)
        Fstat = np.zeros((NN, NN))

    # A channel's embedding does not depend on the pair, so compute it once
    # per channel instead of once per ordered pair.
    past = []
    present = []
    for c in range(NN):
        emb = embed(z[c], dimEmb + 1, tau)
        present.append(emb[:, -1])
        past.append(emb[:, :-1])

    for i in range(NN):
        Xtau = past[i]
        for j in range(NN):
            if i == j:
                continue
            Y = present[j]
            Ytau = past[j]
            XtYt = np.concatenate((Xtau, Ytau), axis=1)
            YYt = np.concatenate((Y[:, np.newaxis], Ytau), axis=1)
            YYtXt = np.concatenate((YYt, Xtau), axis=1)

            det_XtYt = _det_cov(XtYt)
            det_YYt = _det_cov(YYt)
            det_YYtXt = _det_cov(YYtXt)
            det_Ytau = _det_cov(Ytau)

            TXYdum = det_XtYt * det_YYt / (det_YYtXt * det_Ytau)
            if NN > 2:
                # Residual variances of Y without (RSSc) and with (RSSu) the source's past.
                RSSc = (T - dimEmb) * det_YYt / det_Ytau
                RSSu = (T - 2 * dimEmb) * det_YYtXt / det_XtYt
                Fstat[i, j] = ((T - 2 * dimEmb) / dimEmb) * (RSSc - RSSu) / RSSu
                if Fstat[i, j] > thresholdF:
                    Fs[i, j] = 1

            TXY[i, j] = 0.5 * np.log(TXYdum)

    return _refine_three_node_links(TXY, Fs)