In [None]:
import torch
from model import *
from data import *
import pandas as pd
import numpy as np
import torchaudio
from torch.utils.data import DataLoader
from tqdm import tqdm
tqdm.pandas()
from scipy.signal import wiener
import torch.nn.functional as F

from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility

In [None]:
# Pretrained DTLN denoiser used by the calculate_metrics_* functions.
# NOTE(review): frame_len=1536 / frame_hop=384 look like 32 ms / 8 ms windows at 48 kHz — confirm.
CHECKPOINT_PATH = './48khz_parameter_epoch128.pth'

DTLN = Pytorch_DTLN(frame_len=1536,
                    frame_hop=384,
                    dropout=0.3,
                    encoder_size=1024,
                    hidden_size=512,
                    LSTM_size=2)
DTLN.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu')))
# The model is only used for inference here: switch to eval mode so the
# dropout configured above (dropout=0.3) is inactive and outputs are deterministic.
# Trailing ';' suppresses the module repr in the notebook output.
DTLN.eval();

In [None]:
def ssim(signal_1, signal_2):
    """Global (single-window) structural similarity between two 1-D signals.

    Computes SSIM = l * c * s over the whole signal, where

        l = (2*mu1*mu2 + C1) / (mu1^2 + mu2^2 + C1)
        c = (2*sigma1*sigma2 + C2) / (sigma1^2 + sigma2^2 + C2)
        s = (sigma12 + C2/2) / (sigma1*sigma2 + C2/2)

    with C1 = (0.01*L)^2, C2 = (0.03*L)^2, and L the joint dynamic range
    (max minus min taken over both signals).

    Args:
        signal_1: first signal, array-like (e.g. a numpy array).
        signal_2: second signal, same shape as ``signal_1``.

    Returns:
        The SSIM value as a Python float (1.0 for identical signals).

    Raises:
        ValueError: if the signals are empty or differ in shape.
    """
    x = np.asarray(signal_1, dtype=np.float64).ravel()
    y = np.asarray(signal_2, dtype=np.float64).ravel()
    if x.size == 0 or x.shape != y.shape:
        raise ValueError(
            f"ssim expects two non-empty signals of equal length, got {x.shape} and {y.shape}"
        )

    mu_1, mu_2 = x.mean(), y.mean()
    var_1, var_2 = x.var(), y.var()
    # Population covariance (ddof=0) to match np.var above. Mixing the unbiased
    # np.cov default (ddof=1) with biased variances lets s exceed 1.
    cov_1_2 = np.mean((x - mu_1) * (y - mu_2))

    dynamic_range = max(x.max(), y.max()) - min(x.min(), y.min())
    if dynamic_range == 0:
        # Both signals are the same constant; every term would be 0/0.
        return 1.0

    k_1, k_2 = 0.01, 0.03
    c_1 = (k_1 * dynamic_range) ** 2
    c_2 = (k_2 * dynamic_range) ** 2
    sigma_1, sigma_2 = np.sqrt(var_1), np.sqrt(var_2)

    luminance = (2 * mu_1 * mu_2 + c_1) / (mu_1 ** 2 + mu_2 ** 2 + c_1)
    contrast = (2 * sigma_1 * sigma_2 + c_2) / (var_1 + var_2 + c_2)
    structure = (cov_1_2 + c_2 / 2) / (sigma_1 * sigma_2 + c_2 / 2)

    return float(luminance * contrast * structure)

In [None]:
def calculate_metrics_16khz(clean_path, noisy_path, model=None, chunk_len=2**15):
    """Denoise a noisy recording at 16 kHz and score it against its clean reference.

    Both files are loaded, downmixed to mono and resampled to 16 kHz when needed.
    The noisy signal is zero-padded to a multiple of ``chunk_len`` and passed
    through the model one chunk at a time. The chunks are glued back together
    and the result is trimmed to the original length.

    Args:
        clean_path: path to the clean reference recording.
        noisy_path: path to the noisy recording to denoise.
        model: denoiser to apply; defaults to the notebook-level ``DTLN``.
        chunk_len: number of samples passed to the model per call.

    Returns:
        ``[pesq, stoi, ssim]``: wide-band PESQ, STOI and the global SSIM from
        ``ssim()``, all as Python floats.

    Raises:
        ValueError: if the noisy recording contains no samples.

    NOTE(review): the notebook's ``DTLN`` is loaded from a 48 kHz checkpoint
    ('48khz_parameter_epoch128.pth'). Confirm it is the intended model for
    16 kHz input.
    """
    sample_rate = 16000
    if model is None:
        model = DTLN

    def _load_mono(path):
        # torchaudio returns (channels, samples). Average the channels instead of
        # flattening, which would concatenate them and double a stereo file's length.
        waveform, file_rate = torchaudio.load(path)
        if file_rate != sample_rate:
            waveform = torchaudio.transforms.Resample(file_rate, sample_rate)(waveform)
        return waveform.mean(dim=0)

    clean_signal = _load_mono(clean_path)
    noisy_signal = _load_mono(noisy_path)
    n_samples = noisy_signal.shape[0]
    if n_samples == 0:
        raise ValueError(f"no audio samples in {noisy_path!r}")

    # Pad with zeros up to the next multiple of chunk_len. Add nothing when already aligned.
    pad_size = -n_samples % chunk_len
    noisy_padded = F.pad(noisy_signal, (0, pad_size), mode='constant', value=0)

    # Inference only: don't build an autograd graph over the whole recording.
    # Assumes model maps a (1, chunk_len) batch to an output whose [0] is the
    # denoised chunk — TODO confirm against Pytorch_DTLN.forward.
    with torch.no_grad():
        denoised_chunks = [
            model(chunk.unsqueeze(0))[0]
            for chunk in torch.split(noisy_padded, chunk_len)
        ]
    denoised_signal = torch.cat(denoised_chunks)[:n_samples]

    # PESQ/STOI/SSIM need equal-length inputs; clean and noisy files may differ slightly.
    n_common = min(clean_signal.shape[0], denoised_signal.shape[0])
    clean_signal = clean_signal[:n_common]
    denoised_signal = denoised_signal[:n_common]

    # PESQ
    wb_pesq = PerceptualEvaluationSpeechQuality(sample_rate, 'wb')
    denoised_pesq = wb_pesq(denoised_signal, clean_signal)

    # STOI
    denoised_stoi = short_time_objective_intelligibility(denoised_signal, clean_signal, sample_rate)

    # SSIM
    denoised_ssim = ssim(clean_signal.detach().numpy(), denoised_signal.detach().numpy())

    return [denoised_pesq.item(),
            denoised_stoi.item(),
            float(denoised_ssim)]

In [None]:
def calculate_metrics_48khz(clean_path, noisy_path, model=None, chunk_len=2**15):
    """Denoise a noisy recording at 48 kHz and score it against its clean reference.

    Both files are loaded, downmixed to mono and resampled to 48 kHz when needed.
    The noisy signal is zero-padded to a multiple of ``chunk_len`` and passed
    through the model one chunk at a time. The chunks are glued back together
    and the result is trimmed to the original length.

    Wide-band PESQ is computed after resampling both signals from 48 kHz to 16 kHz.
    STOI and SSIM use the 48 kHz signals directly.

    Args:
        clean_path: path to the clean reference recording.
        noisy_path: path to the noisy recording to denoise.
        model: denoiser to apply; defaults to the notebook-level ``DTLN``.
        chunk_len: number of samples passed to the model per call.

    Returns:
        ``[pesq, stoi, ssim]``: wide-band PESQ, STOI and the global SSIM from
        ``ssim()``, all as Python floats.

    Raises:
        ValueError: if the noisy recording contains no samples.
    """
    sample_rate = 48000
    pesq_rate = 16000
    if model is None:
        model = DTLN

    def _load_mono(path):
        # torchaudio returns (channels, samples). Average the channels instead of
        # flattening, which would concatenate them and double a stereo file's length.
        waveform, file_rate = torchaudio.load(path)
        if file_rate != sample_rate:
            waveform = torchaudio.transforms.Resample(file_rate, sample_rate)(waveform)
        return waveform.mean(dim=0)

    clean_signal = _load_mono(clean_path)
    noisy_signal = _load_mono(noisy_path)
    n_samples = noisy_signal.shape[0]
    if n_samples == 0:
        raise ValueError(f"no audio samples in {noisy_path!r}")

    # Pad with zeros up to the next multiple of chunk_len. Add nothing when already aligned.
    pad_size = -n_samples % chunk_len
    noisy_padded = F.pad(noisy_signal, (0, pad_size), mode='constant', value=0)

    # Inference only: don't build an autograd graph over the whole recording.
    # Assumes model maps a (1, chunk_len) batch to an output whose [0] is the
    # denoised chunk — TODO confirm against Pytorch_DTLN.forward.
    with torch.no_grad():
        denoised_chunks = [
            model(chunk.unsqueeze(0))[0]
            for chunk in torch.split(noisy_padded, chunk_len)
        ]
    denoised_signal = torch.cat(denoised_chunks)[:n_samples]

    # PESQ/STOI/SSIM need equal-length inputs; clean and noisy files may differ slightly.
    n_common = min(clean_signal.shape[0], denoised_signal.shape[0])
    clean_signal = clean_signal[:n_common]
    denoised_signal = denoised_signal[:n_common]

    # PESQ: both signals are now at sample_rate regardless of the files' native
    # rates, so the downsampler must start from sample_rate.
    to_pesq_rate = torchaudio.transforms.Resample(sample_rate, pesq_rate)
    wb_pesq = PerceptualEvaluationSpeechQuality(pesq_rate, 'wb')
    denoised_pesq = wb_pesq(to_pesq_rate(denoised_signal).flatten(),
                            to_pesq_rate(clean_signal).flatten())

    # STOI
    denoised_stoi = short_time_objective_intelligibility(denoised_signal, clean_signal, sample_rate)

    # SSIM
    denoised_ssim = ssim(clean_signal.detach().numpy(), denoised_signal.detach().numpy())

    return [denoised_pesq.item(),
            denoised_stoi.item(),
            float(denoised_ssim)]

In [None]:
# 16khz
# TODO: point these at a clean/noisy recording pair before running.
clean_path = ''
noisy_path = ''
# Use distinct result names: unpacking into `ssim` would rebind the ssim()
# helper to a float and break every later calculate_metrics_* call.
pesq_16k, stoi_16k, ssim_16k = calculate_metrics_16khz(clean_path, noisy_path)
pesq_16k, stoi_16k, ssim_16k

In [None]:
# 48khz
# TODO: point these at a clean/noisy recording pair before running.
clean_path = ''
noisy_path = ''
# Use distinct result names: unpacking into `ssim` would rebind the ssim()
# helper to a float and break every later calculate_metrics_* call.
pesq_48k, stoi_48k, ssim_48k = calculate_metrics_48khz(clean_path, noisy_path)
pesq_48k, stoi_48k, ssim_48k