In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from numpy.random import choice, rand, randn, permutation

import helperFunctions as hf 

In [2]:
# % Follow the synthetic test procedure as outlined in the paper
# % "K-SVD: An Algorithm for Designing Overcomplete Dictionaries for
# % Sparse Representation." IEEE Trans. on Sig. Proc., Vol. 54, No. 11, 2006.
#
# For every training-set size and every trial:
#   1. draw a random dictionary D with unit-norm columns,
#   2. synthesize Y as noisy combinations of `sparsity` distinct atoms of D,
#   3. learn Dhat from Y with hf.kSVD,
#   4. count how many true atoms of D reappear in Dhat.
# NOTE(review): the recorded run aborted with
# "ValueError: k must be between 1 and min(A.shape), k=1", raised from inside
# hf.kSVD (helperFunctions, not visible here). Presumably a rank-1 SVD is
# requested on a single-column error matrix when an atom is used by only one
# signal -- TODO confirm and fix in helperFunctions.

SEED = 0
np.random.seed(SEED)  # seeds the legacy global RNG used by randn/rand/choice

numSignalsRange = [250, 500, 1000, 1500]
dimSignal   = 20
numAtoms    = 50
numTrials   = 10
sparsity    = 3     # atoms combined per synthetic signal; also K-SVD's target sparsity
sigma       = 0.2   # std of the additive white Gaussian noise (modify for noise level)
recoveryTol = 0.01  # a true atom d counts as recovered if 1 - |<d, dhat>| < recoveryTol


def make_signals(D, numSignals, sparsity, sigma):
    """Synthesize noisy sparse combinations of dictionary atoms.

    Returns an array of shape (D.shape[0], numSignals). Each column is a linear
    combination of `sparsity` *distinct* columns of D, with coefficients drawn
    i.i.d. uniform on [0, 5), plus white Gaussian noise of std `sigma`.
    """
    dim, nAtoms = D.shape
    Y = np.zeros((dim, numSignals))
    for k in range(numSignals):
        # replace=False: the paper combines *different* atoms in each signal
        rInds = choice(nAtoms, sparsity, replace=False)
        Y[:, k] = D[:, rInds] @ (5 * rand(sparsity)) + sigma * randn(dim)
    return Y


def count_recovered_atoms(D, Dhat, tol):
    """Count the true atoms (columns of D) matched by at least one column of Dhat.

    A true atom d is recovered when 1 - |dhat^T d| < tol for some learned atom
    dhat. Each true atom is counted at most once. This assumes the columns of
    Dhat are unit-norm -- TODO confirm hf.kSVD normalizes its atoms.
    """
    similarity = np.abs(Dhat.T @ D)              # rows: learned atoms, cols: true atoms
    matched = (1 - similarity < tol).any(axis=0)  # one flag per true atom
    return int(matched.sum())


countMat = np.zeros((len(numSignalsRange), numTrials))

for nS, numSignals in enumerate(numSignalsRange):
    for nT in range(numTrials):
        # % random dictionary D with unit-norm columns (the one K-SVD should recover)
        D = normalize(randn(dimSignal, numAtoms), norm='l2', axis=0)
        Y = make_signals(D, numSignals, sparsity, sigma)

        # % dictionary learning via K-SVD!
        Dhat, Xhat = hf.kSVD(Y, numAtoms, sparsity)

        # % compare learned dictionary to the true one
        count = count_recovered_atoms(D, Dhat, recoveryTol)
        print(count)
        countMat[nS, nT] = count

2
0
0
0
0
0
2
1
1


ValueError: k must be between 1 and min(A.shape), k=1

In [None]:
# x-coordinates for the recovery plot. Row nS repeats numSignalsRange[nS] once
# per trial, matching the (len(numSignalsRange), numTrials) layout of countMat.
# Derived from the experiment config so the plotted sizes always match the data
# (the hard-coded version used 2000 where the experiment ran 1500).
ind_mat = np.repeat(np.asarray(numSignalsRange, dtype=float)[:, None], numTrials, axis=1)
ind_mat.shape

In [None]:
# Recovered-atom count vs. training-set size: one thin line per trial, plus the
# mean across trials. Column ii of ind_mat/countMat holds trial ii.
fig, ax = plt.subplots()
colors = ['r', 'b', 'm', 'g']
for ii in range(countMat.shape[1]):  # every trial, not just the first 4
    ax.plot(ind_mat[:, ii], countMat[:, ii], marker='s', alpha=0.5,
            color=colors[ii % len(colors)])
ax.plot(ind_mat[:, 0], countMat.mean(axis=1), marker='o', color='k',
        linewidth=2, label='mean over trials')

ax.legend()
ax.set_xlabel('Number of Signals')
ax.set_ylabel('Atoms Recovered')
ax.set_title('Recovery vs. Data Size');