## Helper function for background thread


In [23]:
import asyncio

# Dedicated event loop whose default executor runs the background jobs.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)


def background(f):
    """Decorator: run ``f`` in the default executor of the module-level ``loop``.

    Calling the decorated function submits ``f(*args, **kwargs)`` to the
    executor and immediately returns the future produced by
    ``loop.run_in_executor`` instead of the function's result.

    Args:
        f: Callable to execute in the executor.

    Returns:
        A wrapper with the same call signature as ``f`` that returns a future.
    """
    # Local import keeps this block self-contained; the file's import cell
    # does not bring in functools.
    import functools

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor only forwards positional arguments, so bind every
        # argument up front; passing **kwargs straight through raises TypeError.
        return loop.run_in_executor(None, functools.partial(f, *args, **kwargs))

    return wrapped

## Enhancement functions

In [24]:
from numpy.typing import NDArray
import numpy as np


def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` (1 when ``x`` is 0)."""
    if x == 0:
        return 1
    # Bits needed to represent x - 1 give the exponent of the next power of two.
    exponent = (x - 1).bit_length()
    return 1 << exponent


def spectral_subtraction(rate: int, noisy: NDArray):
    """Denoise a 1-D signal by subtracting an estimated noise magnitude spectrum.

    The signal is cut into 20 ms Hamming-windowed frames with 50 % overlap.
    The noise magnitude spectrum is initialised from the first five frames
    and updated on frames whose segmental SNR falls below ``Thres`` dB. Each
    frame's magnitude has the noise estimate subtracted (negative results are
    clamped to zero), is recombined with the noisy phase and overlap-added.

    Args:
        rate: Sampling rate in Hz.
        noisy: 1-D array of samples.

    Returns:
        Array with the same length and dtype as ``noisy``. Inputs shorter
        than five frames cannot provide a noise estimate and are returned
        unchanged (as a copy).
    """
    noisy = np.asarray(noisy)

    frame_len = 20 * rate // 1000  # frame size in samples
    # The overlap-add below needs two equal halves; an odd frame length
    # (e.g. 441 samples at 22050 Hz) would otherwise cause a shape mismatch.
    frame_len += frame_len % 2
    PERC = 50  # window overlap in percent of frame
    overlap = frame_len * PERC // 100  # overlap length
    hop = frame_len - overlap  # window length - overlap length

    # default parameters
    Thres = 3  # VAD threshold in dB SNRseg
    Expnt = 1.0  # spectral exponent (1.0 -> magnitude domain)
    G = 0.9  # smoothing factor for the noise update
    noise_frames = 5  # leading frames used for the initial noise estimate

    if frame_len < 2 or len(noisy) < noise_frames * frame_len:
        return noisy.copy()

    win = np.hamming(frame_len)
    # normalization gain for overlap+add with 50% overlap
    winGain = hop / sum(win)

    nFFT = 2 * next_power_of_2(frame_len)

    # Initial noise estimate: mean magnitude of the leading frames, starting
    # at sample 0 (the previous start index of 1 skipped the first sample).
    noise_mean = np.zeros(nFFT)
    for i in range(noise_frames):
        start = i * frame_len
        noise_mean += np.abs(np.fft.fft(win * noisy[start : start + frame_len], nFFT))
    noise_mu = noise_mean / noise_frames

    x_old = np.zeros(overlap)
    Nframes = len(noisy) // hop - 1
    xfinal = np.zeros(noisy.shape[0])

    k = 0  # start sample of the current frame
    for _ in range(Nframes):
        # window the frame and transform it
        spec = np.fft.fft(win * noisy[k : k + frame_len], nFFT)
        sig = np.abs(spec)
        theta = np.angle(spec)  # noisy phase, reused for synthesis

        # segmental SNR; guard the noise norm like the sibling enhancers do
        noise_norm = np.linalg.norm(noise_mu, 2)
        if noise_norm == 0:
            noise_norm = 1e-10  # avoid division by zero
        with np.errstate(divide="ignore"):  # log10(0) -> -inf is acceptable here
            SNRseg = 10 * np.log10(np.linalg.norm(sig, 2) ** 2 / noise_norm**2)

        # --- spectral subtraction --- #
        sub_speech = sig**Expnt - noise_mu**Expnt
        # bins where the noise estimate exceeds the signal are zeroed
        sub_speech[sub_speech < 0] = 0

        # --- simple VAD: update the noise spectrum on low-SNR frames --- #
        if SNRseg < Thres:
            noise_temp = G * noise_mu**Expnt + (1 - G) * sig**Expnt
            noise_mu = noise_temp ** (1 / Expnt)

        # restore the phase and go back to the time domain
        x_phase = (sub_speech ** (1 / Expnt)) * np.exp(1j * theta)
        xi = np.fft.ifft(x_phase).real

        # --- overlap and add --- #
        xfinal[k : k + hop] = x_old + xi[:overlap]
        x_old = xi[overlap:frame_len]

        k += hop

    # flush the second half of the last frame
    xfinal[k : k + hop] = x_old

    return (winGain * xfinal).astype(noisy.dtype)


def mmse(rate: int, noisy: NDArray) -> NDArray:
    """Denoise a 1-D signal with a per-bin MMSE-style spectral gain.

    The signal is cut into 20 ms Hamming-windowed frames with 50 % overlap.
    For every frame the a-posteriori SNR ``gammak`` (capped at 40) and a
    smoothed SNR estimate ``ksi`` are combined into a gain built from the
    modified Bessel functions I0 and I1. The gained magnitude is recombined
    with the noisy phase and overlap-added. The noise power is initialised
    from the first five frames and updated on frames whose SNR falls below
    ``Thres`` dB.

    Args:
        rate: Sampling rate in Hz.
        noisy: 1-D array of samples.

    Returns:
        Array with the same length and dtype as ``noisy`` (the tail that is
        not covered by complete frames is zero). Inputs shorter than five
        frames cannot provide a noise estimate and are returned unchanged
        (as a copy).
    """
    # `sp` is not imported anywhere in the visible part of this file, so the
    # Bessel functions are imported locally to keep the function runnable.
    from scipy import special

    noisy = np.asarray(noisy)

    frame_len = 20 * rate // 1000
    # The overlap-add below needs two equal halves; round odd lengths up.
    frame_len += frame_len % 2
    PERC = 50
    overlap = frame_len * PERC // 100
    hop = frame_len - overlap

    aa = 0.98  # smoothing factor of the ksi recursion
    Thres = 3  # SNR threshold (dB) below which the noise estimate is updated
    mu = 0.98  # smoothing factor of the noise-power update
    c = np.sqrt(np.pi) / 2
    ksi_min = 10 ** (-25 / 10)
    gamma_max = 40
    power_floor = 1e-20  # keeps divisions by the noise power finite
    noise_frames = 5  # leading frames used for the initial noise estimate

    if frame_len < 2 or len(noisy) < noise_frames * frame_len:
        return noisy.copy()

    win = np.hamming(frame_len)
    winGain = hop / sum(win)

    nFFT = 2 * next_power_of_2(frame_len)

    # Initial noise estimate from the leading frames, starting at sample 0
    # (the previous start index of 1 skipped the first sample).
    noise_mean = np.zeros(nFFT)
    for i in range(noise_frames):
        start = i * frame_len
        noise_mean += np.abs(np.fft.fft(win * noisy[start : start + frame_len], nFFT))
    noise_mu = noise_mean / noise_frames
    noise_mu2 = noise_mu**2

    x_old = np.zeros(overlap)
    Nframes = len(noisy) // hop - 1
    xfinal = np.zeros(Nframes * hop)
    Xk_prev = None  # enhanced power spectrum of the previous frame

    k = 0  # start sample of the current frame
    for _ in range(Nframes):
        spec = np.fft.fft(win * noisy[k : k + frame_len], nFFT)
        sig = np.abs(spec)
        sig2 = sig**2
        theta = np.angle(spec)

        noise_norm = np.linalg.norm(noise_mu, 2)
        if noise_norm == 0:
            noise_norm = 1e-10  # avoid division by zero
        with np.errstate(divide="ignore"):  # log10(0) -> -inf is acceptable here
            SNRpos = 10 * np.log10(np.linalg.norm(sig, 2) ** 2 / noise_norm**2)

        safe_noise2 = np.maximum(noise_mu2, power_floor)
        gammak = np.minimum(sig2 / safe_noise2, gamma_max)

        if Xk_prev is None:
            # first frame: no previous estimate to smooth with
            ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
        else:
            ksi = aa * Xk_prev / safe_noise2 + (1 - aa) * np.maximum(gammak - 1, 0)
            ksi = np.maximum(ksi_min, ksi)

        if SNRpos < Thres:
            noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
            noise_mu = np.sqrt(noise_mu2)

        vk = gammak * ksi / (1 + ksi)
        bessel_term = (1 + vk) * special.iv(0, vk / 2) + vk * special.iv(1, vk / 2)
        numerator = c * np.sqrt(vk) * np.exp(-0.5 * vk) * bessel_term
        # gammak is 0 exactly where the frame has no energy; the gain there is
        # irrelevant, and 0/0 would otherwise poison Xk_prev (and through it
        # every later frame) with NaN.
        hw = np.divide(
            numerator, gammak, out=np.zeros_like(numerator), where=gammak > 0
        )

        mmse_speech = hw * sig
        Xk_prev = mmse_speech**2

        x_phase = mmse_speech * np.exp(1j * theta)
        xi_w = np.fft.ifft(x_phase, nFFT).real

        xfinal[k : k + hop] = x_old + xi_w[:overlap]
        x_old = xi_w[overlap:frame_len]

        k += hop

    # Apply the gain in floating point; cast to the input dtype only once, at
    # the very end, so integer input is not quantized before scaling.
    xfinal = winGain * xfinal

    if len(xfinal) < len(noisy):
        xfinal = np.pad(xfinal, (0, len(noisy) - len(xfinal)), "constant")
    else:
        xfinal = xfinal[: len(noisy)]

    return xfinal.astype(noisy.dtype)


def wiener_filtering(rate: int, noisy: NDArray) -> NDArray:
    """Denoise a 1-D signal with a Wiener-type per-bin gain.

    The signal is cut into 20 ms sine-windowed frames with 50 % overlap. A
    spectral-subtraction estimate of the clean magnitude feeds the gain
    ``S^2 / (S^2 + mel * N^2)``, where ``mel`` falls linearly from
    ``mel_max`` at -5 dB to 1 at 20 dB of estimated SNR. The noise magnitude
    is initialised from the first five frames and updated on frames whose SNR
    falls below ``Thres`` dB.

    Args:
        rate: Sampling rate in Hz.
        noisy: 1-D array of samples.

    Returns:
        Array with the same length and dtype as ``noisy`` (the tail that is
        not covered by complete frames is zero). Inputs shorter than five
        frames cannot provide a noise estimate and are returned unchanged
        (as a copy).
    """
    noisy = np.asarray(noisy)

    frame_len = 20 * rate // 1000
    # The overlap-add below needs two equal halves; round odd lengths up.
    frame_len += frame_len % 2
    PERC = 50
    overlap = frame_len * PERC // 100
    hop = frame_len - overlap

    Thres = 3  # SNR threshold (dB) below which the noise estimate is updated
    Expnt = 1.0
    G = 0.9  # smoothing factor for the noise update
    noise_frames = 5  # leading frames used for the initial noise estimate

    # Parameters of the SNR-dependent noise weighting (loop invariants).
    mel_max = 10
    mel_0 = (1 + 4 * mel_max) / 5
    s = 25 / (mel_max - 1)

    if frame_len < 2 or len(noisy) < noise_frames * frame_len:
        return noisy.copy()

    i = np.arange(frame_len)
    win = np.sqrt(2 / (frame_len + 1)) * np.sin(np.pi * (i + 1) / (frame_len + 1))
    winGain = hop / sum(win)

    nFFT = 2 * next_power_of_2(frame_len)

    # Initial noise estimate from the leading frames, starting at sample 0
    # (the previous start index of 1 skipped the first sample).
    noise_mean = np.zeros(nFFT)
    for j in range(noise_frames):
        start = j * frame_len
        noise_mean += np.abs(np.fft.fft(win * noisy[start : start + frame_len], nFFT))
    noise_mu = noise_mean / noise_frames

    x_old = np.zeros(overlap)
    Nframes = len(noisy) // hop - 1
    xfinal = np.zeros(Nframes * hop)

    k = 0  # start sample of the current frame
    for _ in range(Nframes):
        spec = np.fft.fft(win * noisy[k : k + frame_len], nFFT)
        sig = np.abs(spec)
        theta = np.angle(spec)

        noise_norm = np.linalg.norm(noise_mu, 2)
        if noise_norm == 0:
            noise_norm = 1e-10  # avoid division by zero

        sub_speech = sig**Expnt - noise_mu**Expnt
        sub_speech[sub_speech < 0] = 0

        with np.errstate(divide="ignore"):  # log10(0) -> -inf is acceptable here
            SNRpos = 10 * np.log10(np.linalg.norm(sig, 2) ** 2 / noise_norm**2)
            # Use the guarded norm here too, so a zero noise estimate cannot
            # turn the SNR into NaN.
            SNRpri = 10 * np.log10(
                np.linalg.norm(sub_speech, 2) ** 2 / noise_norm**2
            )

        # Noise weighting: constant outside [-5, 20] dB, linear in between.
        if SNRpri < -5.0:
            mel = mel_max
        elif SNRpri > 20.0:
            mel = 1
        else:
            mel = mel_0 - SNRpri / s

        speech_power = sub_speech**2
        denom = speech_power + mel * noise_mu**2
        # Bins where both the speech and the noise estimate are zero get a
        # gain of 0 instead of 0/0 = NaN.
        G_k = np.divide(
            speech_power, denom, out=np.zeros_like(speech_power), where=denom > 0
        )
        wf_speech = G_k * sig

        if SNRpos < Thres:
            noise_temp = G * noise_mu**Expnt + (1 - G) * sig**Expnt
            noise_mu = noise_temp ** (1 / Expnt)

        x_phase = wf_speech * np.exp(1j * theta)
        xi = np.fft.ifft(x_phase).real

        xfinal[k : k + hop] = x_old + xi[:overlap]
        x_old = xi[overlap:frame_len]

        k += hop

    # Apply the gain in floating point; cast to the input dtype only once, at
    # the very end, so integer input is not quantized before scaling.
    xfinal = winGain * xfinal

    if len(xfinal) < len(noisy):
        xfinal = np.pad(xfinal, (0, len(noisy) - len(xfinal)), "constant")
    else:
        xfinal = xfinal[: len(noisy)]

    return xfinal.astype(noisy.dtype)

## User Configuration

In [25]:
# Registry of the available denoisers; ENHANCE must be one of its keys.
enhancement_functions = dict(
    spectral_subtraction=spectral_subtraction,
    mmse=mmse,
    wiener_filtering=wiener_filtering,
)

# Input file, length of one processing chunk (seconds) and selected denoiser.
AUDIO_PATH = "./Noise2_mono.wav"
CHUNK_SECONDS = 5
ENHANCE = "spectral_subtraction"

## Main code

In [26]:
import pyaudio
from scipy.io import wavfile
import time
import enlighten


def read_wave(path):
    """Read a WAV file as a mono signal with leading zero samples removed.

    Args:
        path: Path of the WAV file to read.

    Returns:
        Tuple ``(rate, data)``: the sampling rate reported by
        ``wavfile.read`` and a 1-D sample array. Multi-channel files are
        averaged across channels (which yields floating-point samples);
        single-channel 1-D data keeps its original dtype.
    """
    rate, data = wavfile.read(path)
    if data.ndim > 1:
        # Average over the channel axis. Reshaping to (-1, 2) only worked for
        # exactly two channels and mixed up samples for any other count.
        data = data.mean(axis=1)
    data = np.trim_zeros(data, "f")
    return rate, data


# Audio backend handle and the clip to play.
p = pyaudio.PyAudio()
rate, data = read_wave(AUDIO_PATH)

# Sizes: total samples, samples per processing chunk, clip length in seconds.
nsamples = len(data)
chunk = rate * CHUNK_SECONDS
data_seconds = nsamples / rate

# Shared playback state: enhanced output buffer, play position, and the start
# of the next chunk that still has to be handed to the worker.
result = np.zeros(data.shape, dtype=np.int16)
current_frame = 0
next_chunk_frame = 0


# Two progress bars that share one layout.
bar_format = (
    "{desc}{desc_pad}{percentage:3.0f}%|{bar}| "
    "{count:{len_total}d} "
    "[{rate:.2f}{unit_pad}{unit}/s]"
)
manager = enlighten.get_manager()
_counter_options = dict(total=nsamples, unit="frames", bar_format=bar_format)
processed = manager.counter(desc="Processing", color="green", **_counter_options)
playing = manager.counter(desc="Playing", color="white", **_counter_options)

# Denoiser selected in the user configuration.
enhance = enhancement_functions[ENHANCE]


@background
def process(rate, data, start=None):
    """Enhance one chunk of ``data`` and store it in the global ``result``.

    Args:
        rate: Sampling rate in Hz, forwarded to the enhancement function.
        data: Full input signal.
        start: Index of the first sample of the chunk. When omitted, the
            current playback position is used (the previous behaviour).
    """
    global result
    global processed
    if start is None:
        start = current_frame
    segment = data[start : start + chunk]
    if len(segment) == 0:
        return  # nothing left to process past the end of the signal
    result[start : start + len(segment)] = enhance(rate, segment.astype(np.int16))
    # Count the samples actually processed so the final, shorter chunk does
    # not push the bar past its total.
    processed.update(len(segment))


def callback(in_data, frame_count, time_info, status):
    """PyAudio stream callback: schedule upcoming chunks and return audio.

    Args:
        in_data: Unused (output-only stream).
        frame_count: Number of frames requested for this buffer.
        time_info: Unused.
        status: Unused.

    Returns:
        Tuple ``(out_data, flag)``: the next ``frame_count`` samples of
        ``result`` as bytes, and ``paContinue`` or, once the end of
        ``result`` has been reached, ``paComplete``.
    """
    global current_frame
    global result
    global next_chunk_frame
    global playing

    # Keep processing ahead of the play position. Each job receives the
    # start of its own chunk: reading the moving play position inside the
    # worker made the first chunk run twice and misaligned the later ones.
    while next_chunk_frame < len(data) and current_frame + chunk >= next_chunk_frame:
        process(rate, data, next_chunk_frame)
        next_chunk_frame += chunk

    # NOTE(review): playback does not wait for the worker; samples that have
    # not been enhanced yet are played as the zeros `result` starts with.
    start = current_frame
    current_frame += frame_count
    out = result[start:current_frame]
    playing.update(len(out))

    flag = pyaudio.paContinue if current_frame < len(result) else pyaudio.paComplete
    return out.tobytes(), flag


# Open a mono 16-bit output stream that pulls its audio from `callback`.
player = p.open(
    format=pyaudio.paInt16,
    channels=1,
    rate=rate,
    output=True,
    stream_callback=callback,
)

try:
    # Idle until the stream reports that playback has finished.
    while player.is_active():
        time.sleep(0.1)
except KeyboardInterrupt:
    pass  # deliberate: Ctrl-C simply ends playback early
finally:
    # Release the audio device even if the wait loop raised.
    player.stop_stream()
    player.close()
    p.terminate()