In [None]:
def sparse_dictionary_learning(Y, K, L, iters=10, D_initial=None, algo='OMP', samples=100, with_errors=False):
    """
    Learn a dictionary D and a sparse representation A such that Y is approximated by D A.

    The routine looks for a (d x K) matrix D (the dictionary) and a (K x N) matrix A (the sparse
    representation) which minimise ||Y - D A||, subject to the constraint that each column of A has at most
    L non-zero elements. Each iteration draws a random batch of columns from Y, sparse-codes the batch
    against the current dictionary, and then updates the dictionary with ``update_dictionary_kSVD``.

    :param Y: The (d x N) matrix representing the N different d-dimensional given signals.
    :param K: An integer representing the size of the dictionary.
    :param L: An integer representing the maximum number of "atoms", D[:, k], in the dictionary that each
        sparse representation vector, A[:, i], can use.

    Note: This algorithm is written under the assumption that: 0 < L < d < K < N

    :param iters: The number of iterations this will run for.
    :param D_initial: The initial guess for the (d x K) matrix D. If not None, the columns of this matrix
        must be unit length. If None, K distinct columns of Y are drawn at random and normalised.
    :param algo: A string selecting the sparse representation algorithm. Either algo = 'OMP' for
        Orthogonal Matching Pursuit, or algo = 'MP' for Matching Pursuit.
    :param samples: The number of random columns to draw from the training data Y at each step. Values
        larger than N are clamped to N.
    :param with_errors: A boolean which determines if the output includes the list of the error values at
        each step of the iteration.

    :return: (D, A) or, when with_errors is True, (D, A, errors)
        D: The (d x K) matrix representing the dictionary of K different atoms, where the atoms are
            d-dimensional vectors.
        A: The (K x N) matrix of the N different K-dimensional sparse representations of the columns of Y,
            computed against the final dictionary.
        errors: Optional. The list of norms ||Y_batch - D A_batch|| recorded at each iteration, measured
            on that iteration's sampled batch.

    :raises ValueError: If algo is neither 'OMP' nor 'MP'.
    """

    # Resolve the sparse coding routine first so that an unsupported name fails immediately
    # with a clear message instead of an unbound-variable error later on.
    if algo == 'OMP':
        sparse_rep = find_sparse_rep_OMP
    elif algo == 'MP':
        sparse_rep = find_sparse_rep
    else:
        raise ValueError("Unknown algo {!r}: expected 'OMP' or 'MP'".format(algo))

    Y_full = np.asarray(Y)
    n_signals = Y_full.shape[1]

    # Get Initial D
    # `is None` is required here: `==` on a NumPy array compares element-wise.
    if D_initial is None:
        D = Y_full[:, random.sample(range(n_signals), k=K)]
        D = D / np.linalg.norm(D, axis=0)
    else:
        # Work on a copy so the caller's array is not modified if a later step updates D in place.
        D = np.array(D_initial, dtype=float)

    # random.sample raises if asked for more items than the population holds.
    batch_size = min(samples, n_signals)

    # Initialize the list of error values
    errors = []

    for _ in range(iters):
        # Always sample from the full training set; the batch gets its own name so the
        # population size does not shrink to `samples` after the first iteration.
        columns = random.sample(range(n_signals), k=batch_size)
        Y_batch = Y_full[:, columns]

        A = sparse_rep(Y_batch, D, L)
        D = update_dictionary_kSVD(Y_batch, D, A)

        if with_errors:
            # NOTE(review): A was computed before the dictionary update; this matches the error only
            # if update_dictionary_kSVD also refreshes A in place - confirm against that helper.
            errors.append(np.linalg.norm(Y_batch - np.dot(D, A)))

    # Final sparse coding of every signal against the learned dictionary
    A = sparse_rep(Y_full, D, L)

    if with_errors:
        return (D, A, errors)
    else:
        return (D, A)

