In [6]:
import numpy as np
from scipy.stats import entropy
from scipy.sparse import issparse
from joblib import Parallel, delayed
from scipy.sparse import csr_matrix
import time

def mutual_information_matrix_serial(matrix, nbins=20, n_jobs=-1):
    """
    Computes the mutual information matrix in parallel, working directly with sparse matrices,
    and only computes the upper triangular part of the matrix.
    """
    if not issparse(matrix):
        matrix = csr_matrix(matrix)

    n_features = matrix.shape[0]
    mi_matrix = np.zeros((n_features, n_features))

    def compute_pairwise_mi(vi, vj, nbins=20):
        joint_counts, _, _ = np.histogram2d(vi, vj, bins=nbins)
        if joint_counts.sum() == 0:
            return 0  # No mutual information if no overlap
        joint_prob = joint_counts / (joint_counts.sum() + 1e-10)

        marginal_i = joint_prob.sum(axis=1) + 1e-10
        marginal_j = joint_prob.sum(axis=0) + 1e-10

        h_xy = entropy(joint_prob.flatten(), base=2)
        h_x = entropy(marginal_i, base=2)
        h_y = entropy(marginal_j, base=2)

        return float(h_x + h_y - h_xy)

    for i in range(n_features):
        for j in range(i, n_features):
            vi = matrix[i, :].toarray().flatten() if issparse(matrix) else matrix[i, :]
            vj = matrix[j, :].toarray().flatten() if issparse(matrix) else matrix[j, :]
            mi_matrix[i, j] = compute_pairwise_mi(vi, vj, nbins=nbins)
            if i != j:
                mi_matrix[j, i] = mi_matrix[i, j]  # Exploit symmetry
    return mi_matrix


In [None]:
# vec1 = [1, 2, 3, 0, 0]  # Row 0
# vec2 = [4, 0, 6, 0, 0]  # Row 1
# vec3 = [0, 1, 3, 7, 9]  # Row 2
# vec4 = [5, 0, 0, 0 ,2] # Row 3
# matrix = [ vec1 , vec2, vec3, vec3]
# sparse_matrix = csr_matrix(matrix)

In [7]:
from scipy.sparse import random as sparse_random
from scipy.io import mmwrite

# Generate a sparse random matrix with 1000 rows and 5000 columns
# Density of the matrix is set to 0.01 (1% non-zero elements)
sparse_matrix = sparse_random(5000, 5000, density=0.01, format='csr')


mmwrite("sparse_matrix.mtx", sparse_matrix)

In [3]:
start_time = time.time()

mi_matrix = mutual_information_matrix_serial(sparse_matrix, nbins=20, n_jobs=-1)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.4f} seconds") 

sparse_matrix_mi = csr_matrix(mi_matrix)
print(sparse_matrix_mi)
mmwrite("sparse_matrix_mi.mtx", sparse_matrix_mi)

Elapsed time: 591.4128 seconds
<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 1000000 stored elements and shape (1000, 1000)>
  Coords	Values
  (0, 0)	0.08648760903092427
  (0, 1)	9.420314531990992e-05
  (0, 2)	0.001549447524230002
  (0, 3)	0.00011493756192126892
  (0, 4)	0.004310319340525559
  (0, 5)	0.0018164588544711535
  (0, 6)	0.00011258155871585185
  (0, 7)	0.0020293543704757444
  (0, 8)	0.00011258058279708938
  (0, 9)	0.0031197733992681387
  (0, 10)	9.233747467682352e-05
  (0, 11)	0.002214613839367441
  (0, 12)	0.0035140042299334484
  (0, 13)	0.00011476426096238002
  (0, 14)	9.875123338720648e-05
  (0, 15)	0.0013639988262479086
  (0, 16)	8.810503988537777e-05
  (0, 17)	9.090068226275116e-05
  (0, 18)	9.265163942409615e-05
  (0, 19)	0.00011274820697315668
  (0, 20)	0.00010032937815018794
  (0, 21)	0.00011082802055487062
  (0, 22)	8.432561813756512e-05
  (0, 23)	0.00010865607498786178
  (0, 24)	0.0016378665976523044
  :	:
  (999, 975)	0.00012451360994444882
  (999, 

In [8]:
from scipy.sparse import issparse
from scipy.stats import entropy
import numpy as np
from joblib import Parallel, delayed

def mutual_information_matrix_parallel(matrix, nbins=20, n_jobs=-1):
    """
    Computes the mutual information matrix in parallel, working directly with sparse matrices,
    and computes the full matrix (including the diagonal elements).
    """
    if not issparse(matrix):
        matrix = matrix.tocsr()

    n_features = matrix.shape[0]
    mi_matrix = np.zeros((n_features, n_features))

    def compute_pairwise_mi(i, j, matrix, nbins=20):
        """
        Computes mutual information between row i and row j of the sparse matrix.
        """
        vi = matrix[i, :].toarray().flatten() if issparse(matrix) else matrix[i, :]
        vj = matrix[j, :].toarray().flatten() if issparse(matrix) else matrix[j, :]
        
        joint_counts, _, _ = np.histogram2d(vi, vj, bins=nbins)
        if joint_counts.sum() == 0:
            return 0  # No mutual information if no overlap
        joint_prob = joint_counts / (joint_counts.sum())

        marginal_i = joint_prob.sum(axis=1) + 1e-8
        marginal_j = joint_prob.sum(axis=0) + 1e-8

        h_xy = entropy(joint_prob.flatten(), base=2)
        h_x = entropy(marginal_i, base=2)
        h_y = entropy(marginal_j, base=2)

        return float(h_x + h_y - h_xy)

    # Parallelizing the pairwise mutual information computation
    jobs = [(i, j) for i in range(n_features) for j in range(i, n_features)]  # Includes diagonal
    results = Parallel(n_jobs=n_jobs)(
        delayed(compute_pairwise_mi)(i, j, matrix, nbins) for i, j in jobs
    )

    # Fill the matrix with the results
    for idx, (i, j) in enumerate(jobs):
        mi_matrix[i, j] = results[idx]
        mi_matrix[j, i] = results[idx]  # Exploit symmetry to avoid duplicate computation

    return mi_matrix


In [9]:
start_time = time.time()

mi_matrix = mutual_information_matrix_parallel(sparse_matrix, nbins=20, n_jobs=-1)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.4f} seconds") 

sparse_matrix_mi = csr_matrix(mi_matrix)
print(sparse_matrix_mi)
mmwrite("sparse_matrix_mi.mtx", sparse_matrix_mi)

Elapsed time: 4365.2635 seconds
<Compressed Sparse Row sparse matrix of dtype 'float64'
	with 25000000 stored elements and shape (5000, 5000)>
  Coords	Values
  (0, 0)	0.08595718539621552
  (0, 1)	8.695767045516223e-05
  (0, 2)	0.0001009282230546582
  (0, 3)	0.0021832049015041155
  (0, 4)	0.00010265814458118583
  (0, 5)	9.906287545818904e-05
  (0, 6)	0.00010296923787081469
  (0, 7)	0.00012718322926646985
  (0, 8)	0.00011097445947733098
  (0, 9)	0.00010094093701806806
  (0, 10)	0.0001068829646203584
  (0, 11)	9.042029584752087e-05
  (0, 12)	0.00011710334953562995
  (0, 13)	0.00012103499616814006
  (0, 14)	0.0017203431515214473
  (0, 15)	0.00011506174719264073
  (0, 16)	8.070379777949666e-05
  (0, 17)	0.0013403142832507375
  (0, 18)	7.867074161838072e-05
  (0, 19)	0.00011288742424414577
  (0, 20)	9.669157042743737e-05
  (0, 21)	8.228068031501667e-05
  (0, 22)	0.00011694402415221572
  (0, 23)	0.00012938554579056127
  (0, 24)	0.00013740750799048906
  :	:
  (4999, 4975)	0.000106136075548546

In [6]:
print(mi_matrix.shape)

(500, 500)
