In [1]:
from pyMCL.pymcl import markov_cluster
from pyMCL.jax_pymcl import jax_markov_cluster

import numpy as np
import jax.numpy as jnp


def generate_random_array(n):
    return np.random.rand(n, n)

# Example usage
n = 100
random_array = generate_random_array(n)

print("numpy")
%timeit markov_cluster(random_array)

print("jax")
%timeit jax_markov_cluster(random_array)

numpy
3.97 ms ± 90.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jax


TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [None]:
print("numpy")
%timeit np_mm()
print("jax")
%timeit jnp_mm()

In [15]:
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

# Generate a random array and convert to JAX array
small_array = np.random.rand(100, 100)
small_array_jax = jnp.array(small_array)

@jit
def _prune(matrix: jnp.ndarray, threshold: float) -> jnp.ndarray:
    col_max = jnp.max(matrix, axis=0, keepdims=True)
    # Create boolean mask
    prune_mask = jnp.logical_and(matrix < threshold, matrix != col_max)
    # Use `jax.lax.select` to apply the mask
    set_matrix = lax.select(prune_mask, jnp.zeros_like(matrix), matrix)
    return set_matrix

@jit
def test(matrix):
    # Exponential
    matrix_exp = jnp.linalg.matrix_power(matrix, 2)
    # Prune
    matrix_prune = _prune(matrix_exp, 1e-5)
    # Inflation
    matrix_inf = jnp.power(matrix_prune, 2)

    return matrix_inf

test(small_array_jax)
%timeit test(small_array_jax)

67.7 μs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [51]:
test1_array = np.array(
    [
        [0.2  , 0.25 , 0.   , 0.   , 0.   , 0.333, 0.25 , 0.   , 0.   , 0.25 , 0.   , 0.   ],
        [0.2  , 0.25 , 0.25 , 0.   , 0.2  , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ],
        [0.   , 0.25 , 0.25 , 0.2  , 0.2  , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.25 , 0.2  , 0.   , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.2  , 0.   ],
        [0.   , 0.25 , 0.25 , 0.   , 0.2  , 0.   , 0.25 , 0.2  , 0.   , 0.   , 0.   , 0.   ],
        [0.2  , 0.   , 0.   , 0.   , 0.   , 0.333, 0.   , 0.   , 0.   , 0.25 , 0.   , 0.   ],
        [0.2  , 0.   , 0.   , 0.   , 0.2  , 0.   , 0.25 , 0.   , 0.   , 0.25 , 0.   , 0.   ],
        [0.   , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.2  , 0.   ],
        [0.   , 0.   , 0.   , 0.2  , 0.   , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.2  , 0.333],
        [0.2  , 0.   , 0.   , 0.   , 0.   , 0.333, 0.25 , 0.   , 0.   , 0.25 , 0.   , 0.   ],
        [0.   , 0.   , 0.   , 0.2  , 0.   , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.2  , 0.333],
        [0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.   , 0.2  , 0.   , 0.2  , 0.333]
    ]
)
def _normalize_cols(matrix: np.ndarray)->np.ndarray:
    row_sums = matrix.sum(axis=0, keepdims=True)
    return matrix / row_sums

def _prune(matrix: np.ndarray, threshold: float)->np.ndarray:
    col_max = np.max(matrix, axis=0, keepdims=True)
    prune_mask = np.logical_and(matrix < threshold, matrix != col_max)
    matrix[prune_mask] = 0
    return matrix

def _measure_convergence(matrix: np.ndarray)->float:
    col_max = np.max(matrix, axis=0, keepdims=True)
    sum_of_squares = np.sum(np.square(matrix), axis=0)
    value = np.max(col_max - sum_of_squares)
    return value

def mcl(matrix, expansion=2, inflation=2, pruning_threshold=1e-5, convergence_threshold=1e-3):
    chaos = 1
    while True:
        # expansion
        matrix_exp = np.linalg.matrix_power(matrix, expansion)
        # purne
        matrix_prune = _prune(matrix_exp, pruning_threshold)
        # inflation
        matrix_inf = np.power(matrix_prune, inflation)

        # renormalize columns
        matrix = _normalize_cols(matrix_inf)
        # asess convergence
        chaos = _measure_convergence(matrix)
        if chaos < convergence_threshold:
            break

    return matrix

%timeit mcl(test1_array)

200 μs ± 2.21 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [50]:
import jax.numpy as jnp
from jax import jit, lax

test1_array = jnp.array(
    [
        [0.2, 0.25, 0.0, 0.0, 0.0, 0.333, 0.25, 0.0, 0.0, 0.25, 0.0, 0.0],
        [0.2, 0.25, 0.25, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.25, 0.25, 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.25, 0.2, 0.0, 0.0, 0.0, 0.2, 0.2, 0.0, 0.2, 0.0],
        [0.0, 0.25, 0.25, 0.0, 0.2, 0.0, 0.25, 0.2, 0.0, 0.0, 0.0, 0.0],
        [0.2, 0.0, 0.0, 0.0, 0.0, 0.333, 0.0, 0.0, 0.0, 0.25, 0.0, 0.0],
        [0.2, 0.0, 0.0, 0.0, 0.2, 0.0, 0.25, 0.0, 0.0, 0.25, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.2, 0.2, 0.0, 0.0, 0.2, 0.2, 0.0, 0.2, 0.0],
        [0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.2, 0.2, 0.0, 0.2, 0.333],
        [0.2, 0.0, 0.0, 0.0, 0.0, 0.333, 0.25, 0.0, 0.0, 0.25, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.2, 0.2, 0.0, 0.2, 0.333],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2, 0.0, 0.2, 0.333],
    ]
)

@jit
def _normalize_cols(matrix: jnp.ndarray) -> jnp.ndarray:
    row_sums = matrix.sum(axis=0, keepdims=True)
    return matrix / row_sums

@jit
def _prune(matrix: jnp.ndarray, threshold: float) -> jnp.ndarray:
    col_max = jnp.max(matrix, axis=0, keepdims=True)
    prune_mask = jnp.logical_and(matrix < threshold, matrix != col_max)
    return jnp.where(prune_mask, 0, matrix)

@jit
def _measure_convergence(matrix: jnp.ndarray) -> float:
    col_max = jnp.max(matrix, axis=0, keepdims=True)
    sum_of_squares = jnp.sum(jnp.square(matrix), axis=0)
    value = jnp.max(col_max - sum_of_squares)
    return value


def mcl(matrix, expansion=2, inflation=2, pruning_threshold=1e-5, convergence_threshold=1e-3):
    def cond_fn(state):
        matrix, _ = state
        chaos = _measure_convergence(matrix)
        return jnp.any(jnp.greater(chaos, convergence_threshold))

    def body_fn(state):
        matrix, iteration = state
        # expansion
        matrix_exp = jnp.linalg.matrix_power(matrix, expansion)
        # prune
        matrix_prune = _prune(matrix_exp, pruning_threshold)
        # inflation
        matrix_inf = jnp.power(matrix_prune, inflation)
        # renormalize columns
        matrix = _normalize_cols(matrix_inf)
        return matrix, iteration + 1

    matrix, _ = lax.while_loop(cond_fn, body_fn, (matrix, 0))
    return matrix

# Run the MCL algorithm
mcl(test1_array)
%time mcl(test1_array)

CPU times: user 47.1 ms, sys: 3.37 ms, total: 50.5 ms
Wall time: 46.8 ms


Array([[0.2  , 0.25 , 0.   , 0.   , 0.   , 0.333, 0.25 , 0.   , 0.   ,
        0.25 , 0.   , 0.   ],
       [0.2  , 0.25 , 0.25 , 0.   , 0.2  , 0.   , 0.   , 0.   , 0.   ,
        0.   , 0.   , 0.   ],
       [0.   , 0.25 , 0.25 , 0.2  , 0.2  , 0.   , 0.   , 0.   , 0.   ,
        0.   , 0.   , 0.   ],
       [0.   , 0.   , 0.25 , 0.2  , 0.   , 0.   , 0.   , 0.2  , 0.2  ,
        0.   , 0.2  , 0.   ],
       [0.   , 0.25 , 0.25 , 0.   , 0.2  , 0.   , 0.25 , 0.2  , 0.   ,
        0.   , 0.   , 0.   ],
       [0.2  , 0.   , 0.   , 0.   , 0.   , 0.333, 0.   , 0.   , 0.   ,
        0.25 , 0.   , 0.   ],
       [0.2  , 0.   , 0.   , 0.   , 0.2  , 0.   , 0.25 , 0.   , 0.   ,
        0.25 , 0.   , 0.   ],
       [0.   , 0.   , 0.   , 0.2  , 0.2  , 0.   , 0.   , 0.2  , 0.2  ,
        0.   , 0.2  , 0.   ],
       [0.   , 0.   , 0.   , 0.2  , 0.   , 0.   , 0.   , 0.2  , 0.2  ,
        0.   , 0.2  , 0.333],
       [0.2  , 0.   , 0.   , 0.   , 0.   , 0.333, 0.25 , 0.   , 0.   ,
        0.25 , 0.   