<a href="https://colab.research.google.com/github/tsakailab/spmlib/blob/master/demo/eg20_SeparateSignalDCTWT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy import linalg
import scipy.sparse.linalg as splinalg

# soft thresholding function
def soft(z, th):
    """Apply element-wise soft thresholding to `z` with threshold `th`.

    The magnitude of every entry is reduced by `th` and clamped at zero,
    then the original sign is restored.
    """
    shrunk_magnitude = np.maximum(np.abs(z) - th, 0)
    return np.sign(z) * shrunk_magnitude

# Fast iterative soft thresholding algorithm (FISTA)
def FISTA(A, b, x=None, tol=1e-5, maxiter=1000, tolx=1e-12, l=1., L=None):
    """Fast iterative soft thresholding (FISTA) for l1-regularized least squares.

    Each iteration applies ``x <- soft(w + A^H (b - A w) / L, l / L)``, i.e. a
    proximal-gradient step of size ``1/L`` on ``0.5*||b - A x||^2 + l*||x||_1``,
    followed by the usual FISTA extrapolation to obtain the next point ``w``.

    Parameters
    ----------
    A : m x n matrix, LinearOperator, or tuple (fA, fAT)
        With a tuple, fA(z) must compute A.dot(z) and fAT(r) must compute
        A.conj().T.dot(r).
    b : array
        Observed vector.
    x : array, optional
        Initial guess. Defaults to fAT(b).
    tol : float
        Stop as soon as the residual norm ``||b - A w||`` is not larger than tol.
    maxiter : int
        Maximum number of iterations.
    tolx : float
        Stop when the change of x between two iterations has norm below tolx.
    l : float
        Weight of the l1 term.
    L : float, optional
        Lipschitz constant used for the step size 1/L. When None it is
        estimated roughly from a single probe vector.

    Returns
    -------
    x : array
        Final iterate.
    cost_history : list of float
        One entry per iteration, ``||b - A w||^2 + l*||x||_1``.
        NOTE(review): the residual is evaluated at the extrapolated point w,
        not at x, and carries no 0.5 factor, so this is a monitoring value
        rather than the exact objective.
    """
    # define the functions that compute projections by A and its adjoint
    if isinstance(A, tuple):
        fA, fAT = A[0], A[1]
    else:
        A = splinalg.aslinearoperator(A)
        fA, fAT = A.matvec, A.rmatvec

    # initialize x (done before the Lipschitz estimate, which may need it)
    if x is None:
        x = fAT(b)

    # roughly estimate the Lipschitz constant
    if L is None:
        def estimate(probe):
            # 2 * ||A A^H p|| / ||p||; 0 signals "this probe is unusable"
            probe_norm = linalg.norm(probe)
            if probe_norm == 0:
                return 0.
            return 2 * linalg.norm(fA(fAT(probe))) / probe_norm

        L = estimate(b)
        if not (np.isfinite(L) and L > 0):
            # b is zero or annihilated by A^H: the plain estimate would be
            # NaN or 0 and the step 1/L would poison every iterate.
            L = estimate(fA(x))
        if not (np.isfinite(L) and L > 0):
            # last resort when no probe yields a usable estimate
            L = 1.

    # initialize variables
    t = 1.
    w = x.copy()
    r = b - fA(w)

    count = 0
    cost_history = []
    normr = linalg.norm(r)
    while count < maxiter and normr > tol:
        count += 1

        # proximal-gradient step from the extrapolated point w
        x_prev = x
        x = soft(w + fAT(r) / L, l/L)
        dx = x - x_prev

        # momentum update
        told = t
        t = 0.5 * (1. + np.sqrt(1. + 4. * t * t))
        w = x + ((told - 1.) / t) * dx

        r = b - fA(w)
        normr = linalg.norm(r)
        cost_history.append( normr*normr + l*np.sum(np.abs(x)) )

        if linalg.norm(dx) < tolx:
            break

    return x, cost_history

In [None]:
import numpy as np
from scipy.fftpack import dct, idct
from scipy import linalg
from pywt import wavedec, waverec, coeffs_to_array, array_to_coeffs

def dwt(data, wavelet, mode='per', level=None):
    """Decompose `data` with `wavedec` and pack the coefficients into one array.

    Returns the pair (coeff_arr, coeff_slices) produced by `coeffs_to_array`.
    """
    coeff_list = wavedec(data, wavelet, mode, level)
    return coeffs_to_array(coeff_list)

def idwt(coeff_arr_slices, wavelet, mode='per'):
    """Rebuild the data from a packed coefficient array and its slices.

    `coeff_arr_slices` is indexed as (coeff_arr, coeff_slices), the layout
    returned by `dwt`. The array is unpacked with `array_to_coeffs` in
    'wavedec' format and passed to `waverec`.
    """
    packed = coeff_arr_slices[0]
    layout = coeff_arr_slices[1]
    coeff_list = array_to_coeffs(packed, layout, output_format='wavedec')
    return waverec(coeff_list, wavelet, mode)

# fA: synthesis operator (stacked DCT + wavelet coefficients -> signal)
def reconst_dctwt(coeffs, coeff_slices, wavelet='db10', wl_weight=0.5, n_dct=None):
    """Synthesize a signal from stacked DCT and wavelet coefficients.

    Parameters
    ----------
    coeffs : 1-D array
        DCT coefficients followed by the packed wavelet coefficients.
    coeff_slices : list
        Slice layout of the wavelet part, as returned alongside the
        coefficients by `decomp_dctwt`.
    wavelet : str
        Wavelet name handed to `idwt`.
    wl_weight : float
        Weight applied to the wavelet reconstruction.
    n_dct : int, optional
        Number of leading entries of `coeffs` that are DCT coefficients
        (i.e. the signal length). Defaults to half of `coeffs`, which is only
        right when the wavelet transform yields exactly as many coefficients
        as the signal has samples; pass it explicitly otherwise.

    Returns
    -------
    The inverse DCT of the first part plus `wl_weight` times the inverse
    wavelet transform of the remainder.
    """
    if n_dct is None:
        n_dct = coeffs.shape[0] // 2
    dct_part = idct(coeffs[:n_dct], norm='ortho')
    wl_part = idwt([coeffs[n_dct:], coeff_slices], wavelet)
    return dct_part + wl_weight * wl_part

# fAT: analysis operator (signal -> stacked DCT + wavelet coefficients)
def decomp_dctwt(signal, wavelet='db10', level=None, wl_weight=0.5):
    """Analyze `signal` with both the DCT and the wavelet transform.

    Returns (coeffs, coeff_slices): `coeffs` is the orthonormal DCT of the
    signal followed by `wl_weight` times the packed wavelet coefficients, and
    `coeff_slices` is the wavelet slice layout returned by `dwt`.
    """
    wl_coeffs, coeff_slices = dwt(signal, wavelet, level=level)
    dct_coeffs = dct(signal, norm='ortho')
    stacked = np.concatenate((dct_coeffs, wl_weight * wl_coeffs), axis=0)
    return stacked, coeff_slices

In [None]:
from scipy.fftpack import next_fast_len

def separate_signal_dctwt_FISTA(signal, x=None, tol=1e-5, maxiter=1000, tolx=1e-12, l=1., L=None,
                                wavelet='db10', level = 3, wl_weight = 0.5):
    """Split `signal` into a DCT-synthesized part and a wavelet-synthesized part.

    The signal is zero-padded, then FISTA looks for a sparse coefficient vector
    whose DCT reconstruction plus `wl_weight` times its wavelet reconstruction
    matches the padded signal.

    Parameters
    ----------
    signal : 1-D sequence
        Samples to separate.
    x : array, optional
        Initial coefficient vector forwarded to FISTA.
    tol, maxiter, tolx, l, L
        Forwarded unchanged to FISTA.
    wavelet : str
        Wavelet name.
    level : int or None
        Wavelet decomposition level.
    wl_weight : float
        Weight of the wavelet dictionary relative to the DCT.

    Returns
    -------
    signal_dct : array
        Part reconstructed from the DCT coefficients, cut to len(signal).
    signal_wl : array
        Part reconstructed from the wavelet coefficients, cut to len(signal).
    x : array
        Coefficient vector: the first n entries are DCT coefficients, the rest
        wavelet coefficients, where n is the padded length.
    cost_history : list
        Cost values reported by FISTA.
    """
    length = len(signal)

    # Pad to an FFT-friendly length that is also a multiple of 2**level.
    # With periodized wavelets an odd length at any level yields extra
    # coefficients and a reconstruction whose length differs from the DCT's,
    # so a bare next_fast_len(length) (which may be odd, e.g. 75) is not enough.
    # When level is None the depth is chosen by the wavelet library, so only
    # evenness can be enforced here.
    block = 2 if level is None else 2 ** max(int(level), 0)
    n = block * next_fast_len(-(-length // block))

    b = np.zeros(n)
    b[:length] = signal[:length]

    slices = decomp_dctwt(b, wavelet, level, wl_weight)[1]

    def fA(coeffs):
        # Split at n explicitly (as done for the outputs below) instead of
        # assuming both halves of the coefficient vector have equal length.
        return idct(coeffs[:n], norm='ortho') + wl_weight * idwt((coeffs[n:], slices), wavelet)

    def fAT(y):
        return decomp_dctwt(y, wavelet, level, wl_weight)[0]

    #FISTA
    x, cost_history = FISTA(A=(fA, fAT), b=b, x=x, tol=tol, maxiter=maxiter, tolx=tolx, l=l, L=L)

    signal_dct = idct(x[:n], norm='ortho')[:length]
    signal_wl = wl_weight * idwt((x[n:], slices), wavelet)[:length]

    return  signal_dct, signal_wl, x, cost_history

In [None]:
# Example input so the demo runs end to end: a smooth oscillation plus a short
# localized bump. Replace `signal` with your own 1-D np.array() of samples.
t = np.linspace(0., 1., 1024, endpoint=False)
signal = np.cos(2. * np.pi * 8. * t)
signal[500:516] += 2.

In [None]:
# Separate `signal` with at most 100 FISTA iterations and l1 weight l=1.
signal_dct, signal_wl, x, cost_history = separate_signal_dctwt_FISTA(signal, maxiter=100, l=1.)

# Overlay the input (default color), the DCT part (red) and the wavelet part (green).
import matplotlib.pyplot as plt
plt.plot(signal, '.-', signal_dct, 'r.-', signal_wl, 'g.-')

In [None]:
# Colab-only cell: `google.colab` is not importable in other environments.
# Presumably used to bring a data file for `signal` into the runtime -- TODO confirm.
from google.colab import files
uploaded = files.upload()

In [None]:
# IPython/Colab shell escape (not plain Python): runs `ls` in the working directory.
!ls