In [208]:
import numpy as np
import scipy
import scipy.linalg

def gen_instance(n, latent_dist, random_seed=123):
    """Build a random linear SCM instance plus reference quantities for ICA.

    Parameters
    ----------
    n : int
        Number of variables.
    latent_dist : callable
        Called as ``latent_dist(num_samples, num_columns)``; here it is
        invoked once as ``latent_dist(100000, 1)`` to estimate moments.
    random_seed : int, default 123
        Seed passed to ``np.random.seed`` (global NumPy RNG state is reset).

    Returns
    -------
    A : (n, n) ndarray
        Strictly lower-triangular weight matrix, entries drawn from U(0.2, 1).
    perm : (n,) ndarray
        Random shuffle of ``arange(n)``.
    invperm : (n,) ndarray
        ``argsort(perm)``, i.e. the inverse of ``perm``.
    M0 : (n, n) ndarray
        Orthogonal factor of the (variance-scaled, column-permuted) mixing
        matrix; the matrix in the decorrelated ICA model ``x = M z``.
    k40 : (n,) ndarray
        Empirical ``E[u^4] / E[u^2]^2 - 3`` of the latent sample, repeated n times.
    VU : (n, n) ndarray
        Diagonal matrix holding the empirical latent variance.
    """
    np.random.seed(random_seed)

    # Keep only the part strictly below the diagonal.
    A = np.tril(np.random.uniform(.2, 1, size=(n, n)), k=-1)

    perm = np.arange(n)
    np.random.shuffle(perm)
    invperm = np.argsort(perm)

    # Debugging quantities: these describe what ICA should be recovering
    # (e.g. the orthogonal un-mixing matrix).
    latent_sample = latent_dist(100000, 1)
    second_moment = np.mean(latent_sample**2)
    fourth_moment = np.mean(latent_sample**4)
    k40 = np.ones(n) * ((fourth_moment / second_moment**2) - 3)
    VU = np.diag(np.ones(n) * np.var(latent_sample))

    # Mixing matrix with permuted columns inverted, scaled by the latent std.
    mixing = np.linalg.inv((np.eye(n) - A)[:, perm]) @ (VU**(1/2))

    # Inverse square root of mixing @ mixing.T via the SVD of `mixing`.
    left, sing, _ = scipy.linalg.svd(mixing, full_matrices=False)
    inv_sqrt_cov = left @ (left / sing).T
    M0 = inv_sqrt_cov @ mixing  # orthogonal matrix of the decorrelated ICA: x = M z

    return A, perm, invperm, M0, k40, VU

In [209]:
def gen_data(A, N, latent_dist, perm, random_seed=123):
    """Sample observations from the linear SCM and shuffle their columns.

    Parameters
    ----------
    A : (n, n) ndarray
        Weight matrix of the SCM.
    N : int
        Number of samples, forwarded to ``latent_dist``.
    latent_dist : callable
        Called as ``latent_dist(N, n)`` to draw the latent noise.
    perm : sequence of int
        Column order applied to the generated data.
    random_seed : int, default 123
        Seed passed to ``np.random.seed`` (global NumPy RNG state is reset).

    Returns
    -------
    ndarray
        ``U @ pinv(I - A).T`` with its columns reordered by ``perm``.
    """
    np.random.seed(random_seed)
    n_vars = A.shape[1]
    latents = latent_dist(N, n_vars)
    reduced_form = np.linalg.pinv(np.eye(n_vars) - A)
    observed = latents @ reduced_form.T
    return observed[:, perm]

In [210]:
import matplotlib.pyplot as plt

def fast_ica(X, *, M0=None, k40=None, whiten='svd', n_iters=10):
    """Kurtosis-based FastICA with deflation.

    Parameters
    ----------
    X : (N, n) ndarray
        Data matrix, one sample per row. The data is not centered here:
        second moments are taken about zero.
    M0 : (n, n) ndarray, optional
        Debugging aid. When given, the current vector is printed and plotted
        in "coefficient space" (``|M0.T @ w|``) before and after the
        fixed-point iterations of each component. Requires ``plt``.
    k40 : (n,) ndarray, optional
        Debugging aid. The initial debug coefficients are scaled by
        ``sqrt(|k40|)``. When ``M0`` is given but ``k40`` is omitted, the
        coefficients are shown unscaled.
    whiten : str, default 'svd'
        ``'eigh'`` whitens via an eigendecomposition of ``X.T @ X / N``;
        any other value whitens via the SVD of ``X.T``.
    n_iters : int, default 10
        Number of fixed-point updates per component.

    Returns
    -------
    (n, n) ndarray
        Un-mixing matrix expressed with respect to the original
        (pre-whitening) data; row ``d`` is the ``d``-th component found.

    Notes
    -----
    Initial vectors are drawn from the global NumPy RNG (``np.random.normal``),
    so results depend on its current state.
    """
    N, n = X.shape

    # PCA step: de-correlate and normalize the data.
    if whiten == 'eigh':
        V = X.T @ X / N
        D, U = scipy.linalg.eigh(V)
        # Clip tiny negative eigenvalues caused by round-off before the sqrt.
        Vsqrt = (U * np.sqrt(np.clip(D, 0, np.inf))) @ U.T
        Vsqrtinv = np.linalg.inv(Vsqrt)
    else:
        U, s, _ = scipy.linalg.svd(X.T, full_matrices=False)
        Vsqrtinv = U @ (U / s).T * np.sqrt(N)
    X = X @ Vsqrtinv.T  # rebinding only; the caller's array is not modified

    debug = M0 is not None
    if debug:
        # Without kurtosis information fall back to unscaled coefficients
        # instead of failing on np.abs(None).
        debug_scale = np.ones(n) if k40 is None else np.sqrt(np.abs(k40))

    # Initialize transformation matrix and orthogonal projection.
    W = np.zeros((n, n))
    P = np.eye(n)
    XP = X @ P
    for d in range(n):
        w = np.random.normal(0, 1, size=n)  # draw a random initial vector
        w = P @ w  # project out directions already discovered
        w /= np.linalg.norm(w, ord=2)  # normalize vector to unit norm
        if debug:  # look at the vector in the "coefficient space"
            b = np.abs(M0.T @ w) * debug_scale
            i = np.argmax(b)  # algorithm should be converging to this true column
            print(b, i)
            plt.scatter(np.arange(n), b)
            plt.show()
        for _ in range(n_iters):
            w = XP.T @ (XP @ w)**3 / N - 3 * w  # kurtosis fixed-point update
            w = P @ w  # project out already discovered directions
            w /= np.linalg.norm(w, ord=2)  # normalize to unit norm
        if debug:
            coeffs = np.abs(M0.T @ w)
            print(coeffs)
            plt.scatter(np.arange(n), coeffs)
            plt.show()
        W[d, :] = w  # store vector found
        P -= w.reshape(-1, 1) @ w.reshape(1, -1)  # update the orthogonal projection operator
        XP = XP @ P  # remove this direction from the data

    return W @ Vsqrtinv  # transform with respect to pre-decorrelated data

In [212]:
import networkx as nx

def align_matrix(W):
    """Reorder the rows of W so that its diagonal carries "large" entries.

    Entries whose magnitude exceeds the ``n * (n - 1) // 2``-th smallest
    magnitude are treated as non-zero. A bipartite graph between rows
    (nodes ``0..n-1``) and columns (nodes ``n..2n-1``) is built with those
    indicators as edge weights, and a maximum-cardinality, maximum-weight
    matching picks one row per column. This aligns W with a permuted lower
    triangular matrix where the same permutation applies to rows and columns.

    Parameters
    ----------
    W : (n, n) ndarray

    Returns
    -------
    ndarray
        ``W`` with its rows reordered so that row ``j`` is the row matched
        to column ``j``.
    """
    n = W.shape[1]
    magnitudes = np.abs(W)
    cutoff = np.sort(magnitudes.flatten())[(n * (n - 1) // 2) - 1]

    # Edges are inserted row-major; the weight is the "non-zero" indicator.
    graph = nx.Graph()
    graph.add_edges_from(
        (row, n + col, {'weight': (magnitudes[row, col] > cutoff)})
        for row in range(n)
        for col in range(n)
    )

    pairs = nx.max_weight_matching(graph, maxcardinality=True)
    # Put the column node (the larger id) first in each pair, then order the
    # pairs by column so position j holds the row assigned to column j.
    by_column = sorted(sorted(edge, reverse=True) for edge in pairs)
    row_order = [row for _, row in by_column]
    return W[row_order]

In [213]:
import itertools

def attempt_lower_tri(B):
    """Try to order B as a lower triangular matrix by a symmetric permutation.

    Greedily picks, one at a time, an index ``j`` not yet chosen whose row is
    entirely zero on all columns not yet chosen (its own column included).
    Applying the resulting order to both rows and columns of ``B`` gives a
    matrix with no entries on or above the diagonal.

    A row counts as empty only if every remaining entry is exactly zero;
    entries of opposite sign that happen to sum to zero do not qualify.

    Parameters
    ----------
    B : (n, n) ndarray

    Returns
    -------
    list of int or False
        The order found, or ``False`` when at some step no remaining row is
        empty on the remaining columns.
    """
    n = B.shape[0]
    perm = []
    remaining = np.ones(n, dtype=bool)  # indices not yet placed in `perm`
    for _ in range(n):
        for j in range(n):
            if remaining[j] and not np.any(B[j, remaining]):
                perm.append(j)
                remaining[j] = False
                break
        else:
            # No remaining row is empty on the remaining columns.
            return False
    return perm

def get_lower_tri_perm(W):
    """Find an order making W approximately lower triangular.

    The diagonal is ignored. Starting with the ``n * (n + 1) // 2`` smallest
    magnitudes zeroed out, the function asks ``attempt_lower_tri`` for an
    exact ordering of the thresholded matrix; on failure it raises the
    threshold to the next sorted magnitude and tries again.

    Parameters
    ----------
    W : (n, n) ndarray

    Returns
    -------
    list of int
        Order to apply to both rows and columns of W.
    """
    n = W.shape[0]
    pruned = W.copy()
    np.fill_diagonal(pruned, 0)

    sorted_mags = np.sort(np.abs(pruned).flatten())
    cut = (n * (n + 1) // 2) - 1
    pruned = pruned * (np.abs(pruned) > sorted_mags[cut])

    while True:
        perm = attempt_lower_tri(pruned)
        if perm is not False:
            return perm
        # Clip the next-smallest magnitude as well and retry.
        cut += 1
        pruned = pruned * (np.abs(pruned) > sorted_mags[cut])

def get_lower_tri_perm_lowdim(W):
    """Brute-force search for the order that best lower-triangularizes W.

    Every permutation of ``arange(n)`` is applied to both rows and columns,
    and the one with the smallest absolute mass strictly above the diagonal
    is kept (the first such permutation in iteration order on ties). Only
    practical for a small number of variables.

    Parameters
    ----------
    W : (n, n) ndarray

    Returns
    -------
    list
        The winning permutation.
    """
    n = W.shape[0]
    lowest_mass = np.inf
    winner = None
    for candidate in itertools.permutations(np.arange(n)):
        order = list(candidate)
        permuted = W[order][:, order]
        upper_mass = np.linalg.norm(np.triu(permuted, k=1).flatten(), ord=1)
        if upper_mass < lowest_mass:
            lowest_mass, winner = upper_mass, candidate
    return list(winner)

In [214]:
from sklearn.decomposition import FastICA

def lingam(X, *, M0=None, k40=None, whiten='svd', n_iters=10):
    """Run the full LiNGAM pipeline: ICA, row alignment, order search.

    Parameters
    ----------
    X : (N, n) ndarray
        Data matrix, one sample per row.
    M0, k40, whiten, n_iters
        Forwarded unchanged to ``fast_ica``.

    Returns
    -------
    best : list
        Order found for the variables.
    ndarray
        The aligned un-mixing matrix with ``best`` applied to both its rows
        and its columns.
    """
    unmixing = fast_ica(X, M0=M0, k40=k40, whiten=whiten, n_iters=n_iters)
    # W = FastICA(whiten='unit-variance', max_iter=200, random_state=1543).fit(X).components_

    aligned = align_matrix(unmixing)

    # Exhaustive search is only used for up to 5 variables.
    find_order = get_lower_tri_perm_lowdim if X.shape[1] <= 5 else get_lower_tri_perm
    best = find_order(aligned)

    return best, aligned[best][:, best]

In [232]:
def latent_dist(N, n):
    """Draw an (N, n) array of i.i.d. Uniform(-0.1, 0.1) latent noise."""
    return np.random.uniform(-.1, .1, size=(N, n))


# Five-variable instance (default seed), then 2000 samples from it.
A, perm, invperm, M0, k40, VU0 = gen_instance(5, latent_dist)
X = gen_data(A, 2000, latent_dist, perm, random_seed=23423)

# M0 / k40 are not passed, so fast_ica runs without its debug plots.
best, W = lingam(X, n_iters=100)

In [233]:
np.array(best) # causal order discovered by lingam()

array([1, 4, 2, 0, 3])

In [234]:
invperm # true causal order: argsort of the column shuffle `perm` from gen_instance

array([1, 4, 2, 0, 3], dtype=int64)

In [235]:
np.array_equal(best, invperm) # test that we recovered the true causal order (shape-safe comparison)

True