# pseudo transfer entropy 

# In [None]:
# -*- coding: utf-8 -*-

# Copyright (c) 2020 Riccardo Silini
# Adapted and modified from a MATLAB routine written by M. Chavez
# Please acknowledge and cite the use of this software and its authors
# when results are used in publications or published elsewhere.

"""Functions to compute pseudo transfer entropy (pTE).

This module provides a set of functions to compute pTE between different
time series.

Functions
---------

  * normalisa -- L2 normalization (other norm orders are supported); can be
    replaced by the sklearn.preprocessing.normalize(*args) function.
  * embed -- generates matrices containing segments of the original time
    series, depending on the embedding size chosen.
  * timeshifted -- creates circularly time-shifted surrogates. The sign of the
    shift means that the time series that must be shifted is the independent
    one.
  * pTE -- computes the pseudo transfer entropy between time series.

Libraries required
------------------
numpy, scipy (scipy.signal), collections (standard library)
"""

# These imports used to sit inside the docstring above, so they were never
# executed and every function failed with a NameError.
from collections import deque

import numpy as np
import scipy.signal as sps


def normalisa(a, order=2, axis=-1):
    """Divide ``a`` by its ``order``-norm taken along ``axis``.

    Slices whose norm is exactly zero are divided by 1, so they are returned
    unchanged instead of producing NaNs. The norm array is promoted to at
    least 1-D before being re-expanded along ``axis``. As a consequence, a
    1-D input of length n comes back with shape (1, n).
    """
    norms = np.atleast_1d(np.linalg.norm(a, order, axis))
    # Replace zero norms by one (same dtype) to avoid division by zero.
    divisor = np.where(norms == 0, np.ones_like(norms), norms)
    return a / np.expand_dims(divisor, axis)


def embed(x, embd, lag):
    """Build the delay-embedding matrix of a one-dimensional series.

    Row ``t`` of the result is
    ``[x[t], x[t + lag], ..., x[t + (embd - 1) * lag]]``.

    Parameters
    ----------
    x : array_like
        One-dimensional time series.
    embd : int
        Number of samples per embedded vector (embedding size).
    lag : int
        Spacing, in samples, between consecutive entries of a vector.

    Returns
    -------
    ndarray of float64, shape (max(len(x) - (embd - 1) * lag, 0), embd)
        One embedded vector per row. The result is empty when ``x`` is too
        short to fit a single vector.
    """
    x = np.asarray(x, dtype=float)
    n_vectors = max(len(x) - (embd - 1) * lag, 0)
    offsets = np.arange(embd) * lag
    starts = np.arange(n_vectors)
    # idx[t, k] = t + k * lag. A single fancy-indexing gather replaces the
    # former element-by-element Python double loop.
    idx = starts[:, np.newaxis] + offsets[np.newaxis, :]
    return x[idx]


def timeshifted(timeseries, shift):
    """Return ``timeseries`` circularly rotated by ``shift`` as a numpy array.

    A positive ``shift`` moves samples toward the end: the last ``shift``
    samples wrap around to the front. A negative ``shift`` rotates the other
    way. Shifts larger than the series length wrap modulo that length.
    """
    samples = list(timeseries)
    if samples:
        split = len(samples) - shift % len(samples)
        samples = samples[split:] + samples[:split]
    return np.asarray(samples)


def pTE(z, tau=1, dimEmb=1, Fs=None):
    """Returns pseudo transfer entropy.

    For every ordered pair of series the value is computed as
    ``0.5 * log(det C(Xt,Yt) * det C(Y,Yt) / (det C(Y,Yt,Xt) * det C(Yt)))``.
    Here ``C`` is a sample covariance matrix, ``Y`` is the latest sample of
    each embedded window of the target, ``Yt`` is the target's earlier
    samples and ``Xt`` is the source's earlier samples. The series are
    detrended and normalised first.

    Parameters
    ----------
    z : array
        array of arrays, containing all the time series (one series per row).
    tau : integer
        delay of the embedding.
    dimEmb : integer
        embedding dimension, or model order.
    Fs : array, optional
        EXPERIMENTAL / work in progress. Square 0/1 matrix of links, used to
        attenuate fake causalities arising in systems of three or more
        processes. When omitted (default), no correction is applied.

    Returns
    -------
    pte : array
        array of arrays. The dimension is (# time series, # time series).
        The diagonal is 0, while the off diagonal term (i, j) corresponds
        to the pseudo transfer entropy from time series i to time series j.
    """
    NN, _ = np.shape(z)
    pte = np.zeros((NN, NN))
    z = normalisa(sps.detrend(z))

    # A series' embedding does not depend on the pair it belongs to, so
    # build each one once instead of once per (i, j) pair.
    embeddings = [embed(z[c], dimEmb + 1, tau) for c in range(NN)]

    for i in range(NN):
        Xtau = embeddings[i][:, :-1]  # earlier samples of the source
        for j in range(NN):
            if i == j:
                continue
            Yembd = embeddings[j]
            Y = Yembd[:, -1:]  # latest sample of the target (kept 2-D)
            Ytau = Yembd[:, :-1]  # earlier samples of the target
            XtYt = np.hstack((Xtau, Ytau))
            YYt = np.hstack((Y, Ytau))
            YYtXt = np.hstack((YYt, Xtau))

            ptedum = _cov_det(XtYt) * _cov_det(YYt) / (
                _cov_det(YYtXt) * _cov_det(Ytau))
            pte[i, j] = 0.5 * np.log(ptedum)

    if Fs is not None:
        pte = _correct_indirect_links(pte, Fs)
    return pte


def _cov_det(samples):
    """Return the determinant of the covariance of the columns of ``samples``.

    ``np.cov`` returns a 0-d array when there is a single variable, and
    ``np.linalg.det`` rejects 0-d input. ``np.atleast_2d`` restores the 1x1
    matrix, so every embedding dimension uses the same code path.
    """
    return np.linalg.det(np.atleast_2d(np.cov(samples, rowvar=False)))


def _correct_indirect_links(TXY, Fs):
    """EXPERIMENTAL: attenuate fake causalities in three-process systems.

    Work in progress carried over from the original implementation. It acts
    only when ``Fs`` contains exactly three ones and has zero determinant,
    and only when some row of ``Fs`` sums to 2, some to 1 and some to 0.
    Otherwise it returns an unchanged copy of ``TXY``.

    NOTE(review): the original code referenced the undefined names ``Fs`` and
    ``TXY``. Here ``Fs`` is supplied by the caller and ``TXY`` is taken to be
    the uncorrected pTE matrix. Corrections are written to a copy, so every
    ratio is computed from the uncorrected values. Confirm this matches the
    intended algorithm.
    """
    import itertools

    Fs = np.asarray(Fs)
    TXY_ = np.array(TXY, dtype=float, copy=True)

    if np.sum(Fs) == 3 and np.linalg.det(Fs) == 0:
        k = np.argwhere(np.sum(Fs, axis=1) == 2)
        j = np.argwhere(np.sum(Fs, axis=1) == 1)
        l = np.argwhere(np.sum(Fs, axis=1) == 0)
        if len(k) != 0 and len(j) != 0 and len(l) != 0:
            indice1 = np.where(np.sum(Fs, axis=1) == 2)[0]
            indice2 = np.where(np.sum(Fs, axis=0) == 2)[0]
            # TXY is never modified here, so the masked matrix is computed once.
            TXY_temp = np.multiply(TXY, Fs)
            for idx, row in enumerate(Fs):
                indexes = np.where(row == 1)[0]
                for pair in itertools.combinations(indexes, 2):
                    if Fs[pair] == 1:
                        exponent = TXY_temp[k, l] / TXY_temp[j, l] - 1
                        if np.abs(exponent) > 0.5:
                            ratio = (TXY_temp[pair[0], idx] / TXY_temp[idx, pair[0]]) ** (2 * np.sign(exponent))
                            if ratio < 1:
                                TXY_[pair] = TXY[pair] * ratio
                            if ratio >= 1:
                                ratio2 = (TXY_temp[indice1, pair[0]] / TXY_temp[pair[0], indice1]) ** (2 * np.sign(exponent))
                                if ratio2 < 1:
                                    TXY_[indice1, indice2] = TXY[indice1, indice2] * ratio2
                    if Fs[pair[::-1]] == 1:
                        exponent = TXY_temp[k, l] / TXY_temp[j, l] - 1
                        if np.abs(exponent) > 0.5:
                            ratio = (TXY_temp[pair[1], idx] / TXY_temp[idx, pair[1]]) ** (2 * np.sign(exponent))
                            if ratio < 1:
                                TXY_[pair[::-1]] = TXY[pair[::-1]] * ratio
                            if ratio >= 1:
                                ratio2 = (TXY[indice1, pair[1]] / TXY[pair[1], indice1]) ** (2 * np.sign(exponent))
                                if ratio2 < 1:
                                    TXY_[indice2, indice1] = TXY[indice2, indice1] * ratio2
    return TXY_

