In [None]:
import torch
import torch.nn as nn
import torchaudio
import pandas as pd
import numpy as np
import librosa as lr
from tqdm import tqdm
tqdm.pandas()
import torch.nn.functional as F

from generator import Generator

from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility

In [None]:
# Build the SEGAN generator and load its pretrained weights onto the CPU.
# NOTE(review): torch.load unpickles the file (arbitrary code execution risk);
# only load checkpoints from a trusted source.
generator = Generator()
generator.load_state_dict(torch.load('./SEGAN_generator_weights.pkl', map_location=torch.device('cpu')))
# Inference only: put any dropout / normalisation layers into evaluation mode.
# Trailing ';' suppresses the module repr in the notebook output.
generator.eval();

In [None]:
def ssim(signal_1, signal_2):
    """Global (single-window) structural similarity between two 1-D signals.

    Computes the product of the standard SSIM luminance, contrast and
    structure terms over the whole signal. The dynamic range ``L`` is taken
    from the combined min/max of both inputs, with stabilisers
    ``c_1 = (0.01 L)^2``, ``c_2 = (0.03 L)^2`` and ``c_3 = c_2 / 2``.

    Parameters
    ----------
    signal_1, signal_2 : array_like
        Signals of identical shape.

    Returns
    -------
    float
        SSIM score; 1.0 for identical signals.

    Raises
    ------
    ValueError
        If the two signals differ in shape.
    """
    x = np.asarray(signal_1, dtype=np.float64)
    y = np.asarray(signal_2, dtype=np.float64)
    if x.shape != y.shape:
        raise ValueError(f"signals must have the same shape, got {x.shape} and {y.shape}")

    mu_1 = x.mean()
    mu_2 = y.mean()
    # Use population (ddof=0) statistics everywhere: .var() is ddof=0, so the
    # covariance must be too (np.cov defaults to ddof=1, which made ssim(x, x) > 1).
    var_1 = x.var()
    var_2 = y.var()
    cov_1_2 = ((x - mu_1) * (y - mu_2)).mean()
    std_1 = np.sqrt(var_1)
    std_2 = np.sqrt(var_2)

    # Dynamic range of the pair; max - min is already non-negative.
    L = max(x.max(), y.max()) - min(x.min(), y.min())
    if L == 0:
        # Every sample of both signals is the same constant: they are identical
        # (and the formula below would be 0/0).
        return 1.0

    k_1 = 0.01
    k_2 = 0.03
    c_1 = (k_1 * L) ** 2
    c_2 = (k_2 * L) ** 2

    l = (2 * mu_1 * mu_2 + c_1) / (mu_1 ** 2 + mu_2 ** 2 + c_1)
    c = (2 * std_1 * std_2 + c_2) / (var_1 + var_2 + c_2)
    s = (cov_1_2 + c_2 / 2) / (std_1 * std_2 + c_2 / 2)

    return float(l * c * s)

In [None]:
def _load_mono_16k(path, target_sr=16000):
    """Load an audio file as a 1-D float tensor resampled to ``target_sr`` Hz.

    ``torchaudio.load`` returns ``(channels, time)``; channels are averaged so
    the result has one sample per time step, even for multi-channel files.
    """
    signal, sr = torchaudio.load(path)
    if sr != target_sr:
        signal = torchaudio.transforms.Resample(sr, target_sr)(signal)
    return signal.mean(dim=0)


def calculate_metrics(clean_path, noisy_path):
    """Denoise ``noisy_path`` with the global ``generator`` and score it against ``clean_path``.

    The noisy signal is loaded at 16 kHz and zero-padded to a whole number of
    2**14-sample frames. Each frame is passed through ``generator`` with one
    shared latent ``z``, the outputs are concatenated, and the result is
    trimmed back to the noisy signal's length.

    Parameters
    ----------
    clean_path : str
        Path to the clean reference recording.
    noisy_path : str
        Path to the noisy recording to denoise.

    Returns
    -------
    list
        ``[wideband PESQ, STOI, SSIM]`` of the denoised signal vs. the clean one.
    """
    sample_rate = 16000
    frame_len = 2 ** 14

    clean_signal = _load_mono_16k(clean_path, sample_rate)
    noisy_signal = _load_mono_16k(noisy_path, sample_rate)
    n_samples = noisy_signal.shape[0]

    # Pad up to the next multiple of frame_len (no extra frame if already aligned).
    pad_size = (-n_samples) % frame_len
    noisy_padded = F.pad(input=noisy_signal, pad=(0, pad_size), mode='constant', value=0)

    # One latent noise tensor shared by every frame. Its shape is tied to
    # frame_len; presumably it matches the generator bottleneck — TODO confirm.
    z = nn.init.normal_(torch.Tensor(1, 1024, 8))

    # Inference only: skip autograd bookkeeping, and gather frames in a list
    # so they are concatenated once instead of quadratically.
    with torch.no_grad():
        frames = [
            generator(noisy_padded[start:start + frame_len].unsqueeze(0).unsqueeze(0), z)[0][0]
            for start in range(0, noisy_padded.shape[0], frame_len)
        ]
    denoised_signal = torch.cat(frames) if frames else torch.tensor([])
    denoised_signal = denoised_signal[:n_samples]

    # PESQ/STOI/SSIM need equal-length inputs; compare over the common span.
    n_common = min(denoised_signal.shape[0], clean_signal.shape[0])
    denoised_signal = denoised_signal[:n_common]
    clean_signal = clean_signal[:n_common]

    # PESQ (wideband)
    wb_pesq = PerceptualEvaluationSpeechQuality(sample_rate, 'wb')
    interim_denoised_pesq = wb_pesq(denoised_signal, clean_signal)

    # STOI
    denoised_stoi = short_time_objective_intelligibility(denoised_signal, clean_signal, sample_rate)

    # SSIM
    denoised_ssim = ssim(clean_signal.detach().numpy(), denoised_signal.detach().numpy())

    return [interim_denoised_pesq.item(),
            denoised_stoi.item(),
            denoised_ssim]

In [None]:
# TODO: point these at the noisy recording to denoise and its clean reference.
noisy_path = ''
clean_path = ''
# Use distinct result names: assigning to `ssim` would shadow the ssim() helper
# and make any re-run of calculate_metrics fail with "float is not callable".
pesq_score, stoi_score, ssim_score = calculate_metrics(clean_path, noisy_path)
{'pesq_wb': pesq_score, 'stoi': stoi_score, 'ssim': ssim_score}