### 테스트 데이터셋 설정

**필요한 경우만 실행할 것**

dns/datasets/test_set/synthetic/no_reverb에 있는 테스트 데이터를 eval_data로 복사한 후,

noisy의 데이터들의 이름을 noisy_fileid_12.wav와 같이 변경합니다.

In [None]:
import os
import shutil
import re


def copytree(src, dst, symlinks=False, ignore=None):
    """Copy each entry found directly under ``src`` into ``dst``.

    Sub-directories are handed to ``shutil.copytree`` together with the
    ``symlinks`` and ``ignore`` arguments; every other entry is copied with
    ``shutil.copy2``.

    Args:
        src: Directory whose entries are copied.
        dst: Directory that receives the copies.
        symlinks: Forwarded to ``shutil.copytree`` for sub-directories.
        ignore: Forwarded to ``shutil.copytree`` for sub-directories.
    """
    for entry in os.listdir(src):
        source_path = os.path.join(src, entry)
        target_path = os.path.join(dst, entry)
        if not os.path.isdir(source_path):
            shutil.copy2(source_path, target_path)
            continue
        shutil.copytree(source_path, target_path, symlinks, ignore)


# Source folder of the test data and the local folder it is copied into.
original_testset_path = os.path.join(
    "dns", "datasets", "test_set", "synthetic", "no_reverb"
)
new_testset_path = os.path.join("eval_data")

# exist_ok=True replaces the separate exists()/mkdir() pair.
os.makedirs(new_testset_path, exist_ok=True)

copytree(original_testset_path, new_testset_path)

noisy_dir = os.path.join(new_testset_path, "noisy")
noisy_files = os.listdir(noisy_dir)

# Only the trailing "_fileid_<N>.wav" part of each name is kept; the dot is
# escaped and the pattern anchored so that only a real ".wav" suffix matches.
fileid_pattern = re.compile(r"_fileid_\d+\.wav$")

for file in noisy_files:
    m = fileid_pattern.search(file)
    if m is not None:
        # os.listdir returns bare names, so both ends of the rename must be
        # joined with the directory the files actually live in.
        os.rename(
            os.path.join(noisy_dir, file),
            os.path.join(noisy_dir, "noisy" + m.group()),
        )

### 비교용 템플릿

먼저 메트릭을 뽑아내는 부분입니다

In [24]:
import os
from collections import defaultdict
from tqdm import tqdm
import time
import warnings

warnings.filterwarnings("ignore")

import numpy as np
from numpy.typing import NDArray
from scipy.io import wavfile

from pesq import pesq
from pystoi import stoi


def result_to_metric(result):
    """Turn accumulated evaluation sums into final metric values.

    Args:
        result: Mapping with the keys ``pesq_wb``, ``pesq_nb``, ``stoi``,
            ``count``, ``infer_time`` and ``length``.

    Returns:
        A ``defaultdict(float)`` holding ``pesq_wb``, ``pesq_nb`` and
        ``stoi`` (each sum divided by ``count``) and ``rtf``
        (``infer_time`` divided by ``length``).
    """
    metric = defaultdict(float)

    # The three quality scores share the same normalisation.
    for key in ("pesq_wb", "pesq_nb", "stoi"):
        metric[key] = result[key] / result["count"]

    # Real-time factor: processing time per unit of signal length.
    metric["rtf"] = result["infer_time"] / result["length"]

    return metric


def eval_metric(infer, target_name, testset_path="eval", num_files=300):
    """Score one inference function against the clean reference files.

    For every index ``i`` in ``range(num_files)`` the file
    ``<testset_path>/clean/clean_fileid_<i>.wav`` is read, ``infer`` is called
    and timed, and PESQ (wide and narrow band) and STOI are computed between
    the clean signal and the signal returned by ``infer``.  Each score is
    weighted by the number of samples of the returned signal, so the final
    values are sample-weighted averages.

    NOTE(review): ``infer`` is handed the *clean* reference signal, not the
    noisy one.  An ``infer`` that processes its second argument is therefore
    evaluated on clean speech - confirm that this is intended.

    Args:
        infer: Callable ``infer(rate, clean, i)`` returning ``(rate, wav)``.
        target_name: Name of the evaluated target; not used by this function.
        testset_path: Folder containing the ``clean`` sub-folder.
        num_files: Number of consecutive file indices, starting at 0, to try.

    Returns:
        The metric mapping built by ``result_to_metric``, or ``None`` when no
        file could be processed.
    """
    result = defaultdict(int)

    for i in tqdm(range(num_files)):
        try:
            rate, clean = wavfile.read(
                os.path.join(testset_path, "clean", "clean_fileid_{}.wav".format(i))
            )
            # As we infer on the CPU device, we don't need to sync with GPU.
            # So, we can utilize a plain wall-clock timer.
            start_time = time.perf_counter()
            rate, target_wav = infer(rate, clean, i)
            duration = time.perf_counter() - start_time
        except Exception:
            # Best effort: an index that cannot be read or inferred is
            # skipped.  Unlike a bare ``except``, KeyboardInterrupt still
            # stops the run.
            continue

        n_samples = target_wav.shape[-1]
        length = n_samples / rate

        result["pesq_wb"] += (
            pesq(rate, clean, target_wav, "wb") * n_samples
        )  # wide band
        result["pesq_nb"] += (
            pesq(rate, clean, target_wav, "nb") * n_samples
        )  # narrow band
        result["stoi"] += stoi(clean, target_wav, rate) * n_samples
        result["count"] += n_samples
        result["length"] += length
        result["infer_time"] += duration

    # A defaultdict(int) yields 0, never None, for a missing key; test for
    # "nothing accumulated" so the division in result_to_metric cannot fail.
    if not result["count"]:
        return None
    return result_to_metric(result)

infer에 들어갈 spectral_subtraction 함수입니다.

메트릭을 구하기 위해 길이가 같아야 해서 코드를 적절히 변경하였습니다.

해당 코드를 비교해보시면 변경된 부분을 쉽게 찾으실 수 있을 거예요.

In [25]:
def spectral_subtraction(rate: int, clean: NDArray, i: int):
    """Apply frame-wise magnitude spectral subtraction to a signal.

    The signal is split into 20 ms Hamming-windowed frames with 50 % overlap.
    A noise magnitude spectrum is initialised from the first five
    non-overlapping frames and subtracted from the magnitude spectrum of every
    frame (negative results are clamped to zero).  Each frame is rebuilt with
    its own phase and the frames are combined by overlap-add.  Frames whose
    segmental SNR is below ``Thres`` dB update the noise estimate.

    NOTE(review): the input parameter is called ``clean`` and ``eval_metric``
    passes the clean reference here, so the subtraction runs on clean speech
    rather than on the noisy recording - confirm that this is intended.

    Args:
        rate: Sampling rate in Hz.
        clean: Signal to process, indexed along its first axis; it must hold
            at least five frames.
        i: File index; not used by this function.

    Returns:
        ``(rate, processed)`` where ``processed`` holds
        ``(len(clean) // len2) * len2`` samples, ``len2`` being the hop size.

    Raises:
        ValueError: If ``clean`` is shorter than five frames.
    """
    len_ = 20 * rate // 1000  # frame size in samples
    PERC = 50  # window overlap in percent of frame
    len1 = len_ * PERC // 100  # overlap length
    len2 = len_ - len1  # frame length - overlap length (hop size)

    # setting default parameters
    Thres = 3  # VAD threshold in dB SNRseg
    Expnt = 1.0  # exponent applied to the magnitude spectra
    G = 0.9  # smoothing factor of the noise update

    n_noise_frames = 5
    if len(clean) < n_noise_frames * len_:
        raise ValueError(
            "signal is too short: need at least {} samples, got {}".format(
                n_noise_frames * len_, len(clean)
            )
        )

    # initial Hamming window
    win = np.hamming(len_)
    # normalization gain for overlap+add with 50% overlap
    winGain = len2 / sum(win)

    # Fixed FFT size; frames are zero-padded (or truncated) to this length.
    nFFT = 2 * 2**8

    # Initial noise estimate: mean magnitude spectrum of the first frames.
    # Indexing is 0-based so that the very first sample is included, matching
    # the frame positions used by the processing loop below.
    noise_mean = np.zeros(nFFT)
    for frame_idx in range(n_noise_frames):
        start = frame_idx * len_
        noise_mean = noise_mean + abs(np.fft.fft(win * clean[start : start + len_], nFFT))
    noise_mu = noise_mean / n_noise_frames

    # initialize various variables
    img = 1j
    x_old = np.zeros(len1)
    Nframes = len(clean) // len2 - 1
    xfinal = np.zeros((Nframes + 1) * len2)

    # === Start Processing === #
    k = 0  # 0-based start of the current frame
    for _ in range(Nframes):
        # Windowing
        insign = win * clean[k : k + len_]
        # compute fourier transform of a frame
        spec = np.fft.fft(insign, nFFT)
        # compute the magnitude
        sig = abs(spec)
        # save the phase of the input frame
        theta = np.angle(spec)
        # segmental SNR of this frame against the current noise estimate
        SNRseg = 10 * np.log10(
            np.linalg.norm(sig, 2) ** 2 / np.linalg.norm(noise_mu, 2) ** 2
        )

        # --- spectral subtraction --- #
        # Bins where the estimate exceeds the frame would go negative; they
        # are clamped to zero in one vectorised step.
        sub_speech = np.maximum(sig**Expnt - noise_mu**Expnt, 0)

        # --- implement a simple VAD detector --- #
        if SNRseg < Thres:  # Update noise spectrum
            noise_temp = (
                G * noise_mu**Expnt + (1 - G) * sig**Expnt
            )  # smoothed noise spectrum
            noise_mu = noise_temp ** (1 / Expnt)  # new noise amplitude spectrum

        # add phase
        x_phase = (sub_speech ** (1 / Expnt)) * np.exp(img * theta)
        # take the IFFT
        xi = np.fft.ifft(x_phase).real

        # --- Overlap and add --- #
        xfinal[k : k + len2] = x_old + xi[0:len1]
        x_old = xi[len1:len_]

        k = k + len2

    xfinal[k : k + len2] = x_old

    # NOTE(review): the cast to the input dtype happens before the gain is
    # applied, so integer inputs are truncated first and the product is a
    # float array.  Kept unchanged so downstream metrics see the same values.
    return rate, winGain * xfinal.astype(clean.dtype)

여기에 새로운 infer 함수를 만들어주세요

메트릭을 측정합니다.

In [26]:
# Folder that holds the evaluation wav files.
testset_path = "eval_data"


def load_noisy(_rate, _clean, i):
    """Read the noisy wav file with index ``i`` from ``testset_path``.

    The first two arguments are ignored; they exist so the function can be
    used as an ``infer`` callable.

    Returns:
        Whatever ``wavfile.read`` returns for
        ``<testset_path>/noisy/noisy_fileid_<i>.wav``.
    """
    noisy_name = "noisy_fileid_{}.wav".format(i)
    noisy_path = os.path.join(testset_path, "noisy", noisy_name)
    return wavfile.read(noisy_path)


# Every entry pairs a display name with the infer callable to evaluate.
targets = [
    {"name": "noisy", "infer": load_noisy},
    {"name": "spectral_subtraction", "infer": spectral_subtraction},
]

# Evaluate each target in turn and collect its metrics under its name.
metrics = {}
for target in targets:
    name = target["name"]
    metrics[name] = eval_metric(target["infer"], name, testset_path)
print(metrics)

100%|██████████| 300/300 [01:09<00:00,  4.29it/s]
100%|██████████| 300/300 [01:34<00:00,  3.16it/s]

{'noisy': defaultdict(<class 'float'>, {'pesq_wb': 1.5853275564692964, 'pesq_nb': 2.1636971763316417, 'stoi': 0.9156399918415755, 'rtf': 5.7368310505911805e-05}), 'spectral_subtraction': defaultdict(<class 'float'>, {'pesq_wb': 4.105553290987975, 'pesq_nb': 4.255975185624705, 'stoi': 0.9969548738942908, 'rtf': 0.01921462676669127})}





metrics를 표, 그래프 등으로 시각화합니다

In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

import numpy as np
import matplotlib.pyplot as plt

# 주어진 데이터

    

# Metric names shown along the x axis.
labels = ['pesq_nb', 'pesq_wb', 'rtf', 'stoi']

# Plot every method that is present in `metrics` instead of a fixed list of
# names, so a method that was not evaluated cannot raise a KeyError.  Entries
# without a result (None) are skipped.
method_names = [name for name, values in metrics.items() if values is not None]
n_methods = len(method_names)

# Position of each metric group.
x = np.arange(len(labels))

# Bar width: each group occupies 0.8 of its slot (0.2 per bar for 4 methods).
width = 0.8 / max(n_methods, 1)

# Create the plot.
fig, ax = plt.subplots(figsize=(12, 8))

# One set of bars per method, centred around the group position.
bar_groups = []
for idx, name in enumerate(method_names):
    heights = [metrics[name][label] for label in labels]
    offset = (idx - (n_methods - 1) / 2) * width
    bar_groups.append(ax.bar(x + offset, heights, width, label=name))

# Axis labels, title and legend.
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylabel('Values')
ax.set_title('Comparison of Different Denoising Methods')
if bar_groups:
    ax.legend()


# Write each bar's value above it.
def autolabel(rects):
    """Annotate every bar in ``rects`` with its height, two decimals."""
    for rect in rects:
        height = rect.get_height()
        ax.annotate(f'{height:.2f}',
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')


for rects in bar_groups:
    autolabel(rects)

# Adjust the layout.
fig.tight_layout()

# Show the figure.
plt.show()