# Constrained Orthogonal Matching Pursuit for Audio Declipping

In [147]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

## Signal Clipping

In [148]:
def clip(s, theta_clip):
    """Hard-clip ``s`` elementwise to the symmetric interval [-theta_clip, theta_clip]."""
    lower, upper = -theta_clip, theta_clip
    return np.clip(s, lower, upper)

In [149]:
def get_theta_clip(y):
    """Estimate the clipping level of ``y`` as its peak absolute amplitude."""
    magnitudes = np.abs(y)
    return magnitudes.max()

In [150]:
def get_M_r(y):
    """
    Build the measurement matrix M_r that keeps only the reliable samples of y.

    A sample counts as reliable when its magnitude is strictly below the
    clipping level returned by ``get_theta_clip(y)``. Samples at that level
    are treated as clipped and dropped.

    Returns the rows of the N x N identity matrix at the reliable positions,
    i.e. an array of shape (N_r, N), so that ``M_r @ y`` equals ``y[reliable]``
    for a 1-D ``y``.
    """
    reliable = np.abs(y) < get_theta_clip(y)
    identity = np.eye(y.shape[0])
    return identity[reliable, :]

## Gabor Dictionary

In [151]:
# Framing parameters (sampling_rate presumably in Hz; lengths in samples)
params = dict(
    sampling_rate=16000,
    frame_length=1024,
    frame_overlap=768,
)

# One frame is processed at a time: the number of atoms K_g and the
# signal length N both equal the frame length.
K_g = N = params["frame_length"]

In [152]:
# Create the time-frequency grid
T = np.arange(0, N)
J = np.arange(0, K_g)
J, T = np.meshgrid(J, T)

# Dictionaries of shape (N, K_g)
gabor_cosine = np.cos(np.pi * (J+1/2) * (T+1/2) / K_g)
gabor_sine = np.sin(np.pi * (J+1/2) * (T+1/2) / K_g)

## Orthogonal Matching Pursuit

In [153]:
def least_squares(y, D_c, D_s, theta_clip = None, theta_max = None):
    """
    Solve the joint (currently unconstrained) least-squares problem

        min_{x_c, x_s} ||y - D_c x_c - D_s x_s||_2^2

    by stacking the two dictionaries side by side and calling
    ``np.linalg.lstsq``.

    Inputs:
    --------
    y: np.array
        Target vector of size n.
    D_c: np.array
        Cosine atoms, of shape (n, k).
    D_s: np.array
        Sine atoms, of shape (n, k_s).
    theta_clip, theta_max: float or None
        Reserved for the constrained variant (see TODO below). They are
        currently accepted but ignored.

    Outputs:
    --------
    x_c: np.array
        Coefficients for the columns of D_c (size k).
    x_s: np.array
        Coefficients for the columns of D_s (size k_s).
    """
    n_cos = D_c.shape[1]
    D = np.concatenate((D_c, D_s), axis=1)
    # rcond=None selects numpy's machine-precision cutoff for small singular values.
    x = np.linalg.lstsq(D, y, rcond=None)[0]

    # TODO: Implement optimization in the case with linear constraints
    # (theta_clip / theta_max are not used yet).

    return x[:n_cos], x[n_cos:]

In [None]:
def OMP(y, M_r, K, eps, D_c = gabor_cosine, D_s = gabor_sine, theta_clip = None, theta_max = None):
    """
    Runs Orthogonal Matching Pursuit with paired cosine/sine (Gabor) atoms.

    The atoms are restricted to the reliable samples (M_r @ D) and normalised
    to unit norm. At each iteration the index j whose pair (d_j^c, d_j^s)
    best explains the current residual is added to the support. The
    coefficients of all selected pairs are then re-fitted jointly with
    ``least_squares``.

    Inputs:
    --------
    y: np.array
        Full input signal of size N. Only the samples kept by M_r are used.
    M_r: np.array
        Measurement matrix of size (N_r, N) selecting the reliable samples.
    K: int
        Maximal number of atom pairs to select. At most K_g pairs can ever
        be selected.
    eps: float
        Stopping criterion: iterations stop once the residual norm is below eps.
    D_c: np.array
        Dictionary for the cosine atoms of size (N, K_g)
    D_s: np.array
        Dictionary for the sine atoms of size (N, K_g)
    theta_clip: float
        (Optional) Clipping value of the signal, forwarded to ``least_squares``.
        As of this notebook, least_squares does not use it yet.
    theta_max: float
        (Optional) Maximum value of the signal, forwarded to ``least_squares``.
        As of this notebook, least_squares does not use it yet.

    Outputs:
    --------
    x: np.array
        Array of size K_g equal to W_c x_c + W_s x_s. Here x_c and x_s are the
        coefficients on the normalised atoms, and W_c, W_s (diagonal,
        1/||M_r d_j||) undo the normalisation.
        NOTE(review): cosine and sine coefficients are summed into one vector,
        so the time-domain estimate D_c @ (W_c x_c) + D_s @ (W_s x_s) cannot
        be recovered from x alone. Consider returning both parts.
    residual_norms : list
        Residual norm before the first iteration, then after each iteration.

    Raises:
    --------
    np.linalg.LinAlgError
        If some atom is identically zero on the reliable samples, so that it
        cannot be normalised.
    """
    K_g = D_c.shape[1]

    y_r = M_r @ y                                                  # Of shape (N_r)

    # Restrict the atoms to the reliable samples and normalise each column.
    # w_j = 1/||M_r d_j|| is the diagonal of W_c (resp. W_s).
    MD_c = M_r @ D_c                                               # Of shape (N_r, K_g)
    MD_s = M_r @ D_s
    norms_c = np.linalg.norm(MD_c, axis=0)
    norms_s = np.linalg.norm(MD_s, axis=0)
    if np.any(norms_c == 0) or np.any(norms_s == 0):
        raise np.linalg.LinAlgError(
            "Some atom is identically zero on the reliable samples; cannot normalise it."
        )
    w_c = 1 / norms_c                                              # Of shape (K_g)
    w_s = 1 / norms_s
    d_c_norm = MD_c * w_c                                          # Of shape (N_r, K_g)
    d_s_norm = MD_s * w_s
    # <d_norm_j^c | d_norm_j^s> for every j, without forming the full
    # (K_g, K_g) cross-Gram matrix just to read its diagonal.
    d_cs_dot = np.sum(d_c_norm * d_s_norm, axis=0)                 # Of shape (K_g)
    # Determinant of each pair's 2x2 Gram matrix. This assumes no pair is
    # collinear on the reliable samples (|d_cs_dot| < 1).
    det = 1 - d_cs_dot**2

    # Residual, support and coefficients on the normalised atoms
    r = y_r
    Omega = []
    residual_norms = [np.linalg.norm(y_r)]
    x_c = np.zeros(K_g)
    x_s = np.zeros(K_g)

    for _ in tqdm(range(min(K, K_g))):

        # Stopping criterion (also covers an already-small initial residual)
        if residual_norms[-1] < eps:
            break

        # Atom selection: least-squares coefficients of r on each pair
        corr_c = r @ d_c_norm
        corr_s = r @ d_s_norm
        a_c = (corr_c - d_cs_dot * corr_s) / det
        a_s = (corr_s - d_cs_dot * corr_c) / det
        # Energy of r captured by its orthogonal projection onto span{d_c_j, d_s_j}.
        # Since ||r - proj_j||^2 = ||r||^2 - captured_j, the best pair MAXIMISES
        # captured_j, which is the same as minimising the remaining residual.
        captured = a_c * corr_c + a_s * corr_s
        captured[Omega] = -np.inf                                  # never reselect a pair
        i = int(np.argmax(captured))

        # Update support, coefficients and residual
        Omega.append(i)
        c_sel, s_sel = least_squares(y_r, d_c_norm[:, Omega], d_s_norm[:, Omega], theta_clip, theta_max)
        x_c = np.zeros(K_g)
        x_s = np.zeros(K_g)
        x_c[Omega] = c_sel
        x_s[Omega] = s_sel

        r = y_r - d_c_norm[:, Omega] @ c_sel - d_s_norm[:, Omega] @ s_sel
        residual_norms.append(np.linalg.norm(r))

    # Output: undo the column normalisation (previously used x_s twice)
    x = w_c * x_c + w_s * x_s
    return x, residual_norms

## I think this doesn't work yet; I'll debug it tomorrow. - Pierre

In [None]:
## TODO: Function to run OMP frame by frame, then recombine the overlapping frames

## Data

### Dataset

In [None]:
# Synthetic data generation
def generate_synthetic_dataset(M, N, K, theta_clip, D_c = gabor_cosine, D_s = gabor_sine):
    """
    Generate M synthetic waveforms of length N, each intended to be a sum of
    K Gabor atoms taken from D_c / D_s.

    Returns the ground-truth vectors x (rows of X), the original signals Y,
    and Y_clipped (Y clipped at theta_clip). Each is an array of shape (M, N).

    NOTE(review): not implemented yet. All three arrays are currently
    returned filled with zeros.
    """
    shape = (M, N)

    # TODO: draw the K-atom ground truth, synthesise Y and clip it at theta_clip.

    return np.zeros(shape), np.zeros(shape), np.zeros(shape)


In [None]:
## TODO: add a function to load some real data

### Exploratory Data Analysis

## Experiments