In [None]:
%matplotlib inline

In [None]:
import numpy as np
from scipy import linalg

# Orthogonal matching pursuit (OMP)
#
# Y. C. Pati, R. Rezaiifar, and P. S. Krishnaprasad, 
# "Orthogonal matching pursuit: Recursive function approximation with applications to wavelet decomposition."
# The Twenty-Seventh Asilomar Conference on Signals, Systems and Computers, pp. 40-44, 1993.
#
def OMP(A, b, tol=1e-5, maxnnz=np.inf):
    """Greedy sparse approximation of ``b`` in the columns of ``A`` (OMP).

    At each step the column of ``A`` most correlated (in absolute value)
    with the current residual is added to the support, and the
    coefficients on the whole support are re-fitted by least squares.

    Parameters
    ----------
    A : ndarray, shape (m, n)
        Dictionary / design matrix.
    b : ndarray, shape (m,)
        Signal to approximate.
    tol : float, optional
        Stop once the residual 2-norm ``||b - A x||`` is <= ``tol``.
    maxnnz : int or float, optional
        Maximum number of selected columns (default: unlimited).

    Returns
    -------
    x : ndarray, shape (n,)
        Coefficient vector, zero outside the selected support.

    Notes
    -----
    The support size is additionally capped at ``min(m, n)``, and a column
    already in the support is never selected again. In exact arithmetic
    the least-squares residual is orthogonal to the selected columns.
    Without these guards, the greedy step can keep re-selecting the same
    column forever when no further progress is possible: duplicate or
    collinear columns, a rank-deficient ``A``, or a ``tol`` below the
    attainable residual.
    """
    m, n = A.shape
    # More than min(m, n) atoms cannot increase the rank of A[:, supp].
    max_atoms = min(maxnnz, m, n)
    supp = []
    x = np.zeros(n)
    r = b.copy()
    while len(supp) < max_atoms and linalg.norm(r) > tol:
        corr = np.abs(A.T.dot(r))
        # Exclude already-selected columns so every iteration grows the support.
        corr[supp] = -np.inf
        s = int(np.argmax(corr))
        supp.append(s)
        Asupp = A[:, supp]
        # Re-fit all coefficients on the current support (the "orthogonal" step).
        x[supp] = np.linalg.lstsq(Asupp, b, rcond=None)[0]
        r = b - Asupp.dot(x[supp])
    return x

In [None]:
from time import perf_counter

from sklearn import metrics

#%% Demo: sparse recovery with OMP
# Fixed seed (instead of a time-based one) so that Restart & Run All
# reproduces the same problem instance.
SEED = 0
rng = np.random.RandomState(SEED)

# Problem size (other sizes tried: 512 x 2048, 2000 x 4000).
m, n = 1024, 8192

# Random Gaussian design matrix; entries have variance 1/m, so columns
# have unit squared norm in expectation.
A = rng.randn(m, n) / np.sqrt(m)

# k-sparse Gaussian ground-truth coefficient vector.
k = 100        # number of nonzeros
stdx = 1.      # std. dev. of the nonzero entries
snr = 10.      # ||A x_true|| / ||noise||

x_true = np.zeros(n)
T = np.sort(rng.choice(n, k, replace=False))
print('True support of %d nonzeros = ' % (k))
print(T)
x_true[T] = rng.randn(k) * stdx

# Noiseless measurements.
b = A.dot(x_true)

# Add noise rescaled to the requested SNR. Its norm is also the OMP stopping
# tolerance: stop once the residual has shrunk to the noise level.
normb = linalg.norm(b)
noise = rng.randn(m)
noise = noise / linalg.norm(noise) * normb / snr
tol = linalg.norm(noise)
b = b + noise

print("Running OMP..")
t0 = perf_counter()
x_est = OMP(A, b, tol=tol)   # alternative: fix the sparsity with OMP(A, b, maxnnz=k)
print('done in %.2fs.' % (perf_counter() - t0))

Tpred = np.nonzero(x_est)[0]
print('Predicted support of %d nonzeros = ' % (Tpred.size))
print(Tpred)

# Support recovery scored as binary classification; positive class = "index is in the support".
support_true = x_true != 0
support_est = x_est != 0
print(metrics.classification_report(support_true, support_est,
                                    labels=[False, True],
                                    target_names=['off-support', 'in-support']))
print(metrics.confusion_matrix(support_true, support_est, labels=[False, True]))

In [None]:
import matplotlib.pyplot as plt

# Relative l2 recovery error ||x_est - x_true|| / ||x_true||.
print('rel. error = %.2e' % (linalg.norm(x_est-x_true)/linalg.norm(x_true)))

# Overlay the coefficients. Filled green dots are the ground truth and
# hollow red circles are the OMP estimate, so a correctly recovered entry
# shows up as a dot inside a circle.
idx = np.arange(n)
fig, ax = plt.subplots()
ax.plot(idx, x_true, 'g.', markersize=8, mec='green', label='True')
ax.plot(idx, x_est, 'ro', mfc='none', markersize=8, mec='red', mew=1, label='Estimated')
ax.set(title='OMP sparse recovery', xlabel='coefficient index', ylabel='coefficient value')
ax.legend(loc='upper right', shadow=False)
plt.show()