In [1]:
import torch
import torchaudio
import torch.nn.functional as F
from main import load_cfg
from build import LightningNet
import pysepm
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

Global seed set to 0


In [7]:
# Restore the network from the saved checkpoint, then move it to the GPU and
# put it into evaluation mode.
model = LightningNet.load_from_checkpoint(
    'log/DNS/version_0/epoch=380.ckpt'
)
model = model.cuda()
model.eval()
# print() is the last statement so the cell shows a blank line instead of
# echoing the value returned by model.eval().
print()




### Evaluate on DNS dev test_set no_reverb

In [10]:
def take_last(name):
    """Return the part of a file name that follows its last underscore.

    The file extension is removed first. Only the final extension is stripped,
    so dots that occur earlier in the name (e.g. ``snr2.5``) do not truncate it.

    Args:
        name: File name, with or without an extension.

    Returns:
        The trailing underscore-separated field as a string; the whole stem
        when the name contains no underscore.
    """
    stem, _extension = os.path.splitext(name)
    return stem.rsplit('_', 1)[-1]

# Evaluate the model on the DNS synthetic no_reverb test set: enhance every
# noisy file and score it against the clean file paired with it by sort order.
DNS_CLEAN_DIR = '/mnt/g/WQR/DNS_datasets/test_set/synthetic/no_reverb/clean'
DNS_NOISY_DIR = '/mnt/g/WQR/DNS_datasets/test_set/synthetic/no_reverb/noisy'
CLIP_LEN = 63744     # number of samples handed to the model per forward pass
SAMPLE_RATE = 16000  # sampling rate passed to the pysepm metrics

clean_files = os.listdir(DNS_CLEAN_DIR)
noisy_files = os.listdir(DNS_NOISY_DIR)
# Both listings are ordered by the same key so that zip() pairs matching files.
clean_files.sort(key=take_last)
noisy_files.sort(key=take_last)

csig_list, cbak_list, covl_list = [], [], []
pesq_list, stoi_list = [], []

# Inference only: no_grad() keeps autograd from recording every forward pass.
with torch.no_grad():
    for clean_name, noisy_name in tqdm(zip(clean_files, noisy_files),
                                       total=min(len(clean_files), len(noisy_files))):
        # [0][0] selects the waveform tensor and then its first channel.
        noisy_wav = torchaudio.load(os.path.join(DNS_NOISY_DIR, noisy_name))[0][0].cuda()
        # The reference is only consumed by the numpy-based metrics, so it stays on the CPU.
        target_wav = torchaudio.load(os.path.join(DNS_CLEAN_DIR, clean_name))[0][0].numpy()
        real_length = len(noisy_wav)

        if real_length <= CLIP_LEN:
            # Short input: zero-pad to one clip, enhance, cut back to the true length.
            padded_wav = F.pad(noisy_wav, (0, CLIP_LEN - real_length))
            enhanced_wav = model(padded_wav.unsqueeze(0))
            enhanced_wav = enhanced_wav.squeeze()[:real_length].detach().cpu().numpy()
        else:
            # Long input: zero-pad to a whole number of clips (ceiling division),
            # enhance clip by clip and stitch the outputs back together.
            num_clips = -(-real_length // CLIP_LEN)
            padded_wav = F.pad(noisy_wav, (0, num_clips * CLIP_LEN - real_length))
            enhanced_wav = np.zeros((num_clips * CLIP_LEN,))
            for k in range(num_clips):
                noisy_clip = padded_wav[k * CLIP_LEN:(k + 1) * CLIP_LEN]
                enhanced_clip = model(noisy_clip.unsqueeze(0))
                enhanced_wav[k * CLIP_LEN:(k + 1) * CLIP_LEN] = (
                    enhanced_clip.squeeze().detach().cpu().numpy())
            enhanced_wav = enhanced_wav[:real_length]

        csig, cbak, covl = pysepm.composite(target_wav, enhanced_wav, SAMPLE_RATE)
        # Element 1 of the pysepm.pesq result is the score that gets reported.
        pesq = pysepm.pesq(target_wav, enhanced_wav, SAMPLE_RATE)[1]
        stoi = pysepm.stoi(target_wav, enhanced_wav, SAMPLE_RATE)

        csig_list.append(csig)
        cbak_list.append(cbak)
        covl_list.append(covl)
        pesq_list.append(pesq)
        stoi_list.append(stoi)

def mean(lst):
    """Return the arithmetic mean of ``lst``: the sum of its items divided by their count."""
    total = sum(lst)
    count = len(lst)
    return total / count

# Print the average of each metric over all evaluated files, one value per
# line, in the order PESQ, CSIG, CBAK, COVL, STOI.
for scores in (pesq_list, csig_list, cbak_list, covl_list, stoi_list):
    print(mean(scores))

150it [01:32,  1.62it/s]

2.0135125796000164
3.654630264974779
2.8993064844403107
2.8363363047181758
0.8250538345700641





### Result on DNS
no_reverb: 2.83, 96.6


### Evaluate on VBD

In [3]:
# Evaluate the model on the VBD test set: enhance every noisy file and score it
# against the clean file paired with it by sorted file name.
VBD_CLEAN_DIR = '/home/wqr/vbd/clean_testset_wav'
VBD_NOISY_DIR = '/home/wqr/vbd/noisy_testset_wav'
CLIP_LEN = 63744     # number of samples handed to the model per forward pass
SAMPLE_RATE = 16000  # sampling rate passed to the pysepm metrics

csig_list, cbak_list, covl_list = [], [], []
pesq_list, stoi_list = [], []
clean_list = os.listdir(VBD_CLEAN_DIR)
noisy_list = os.listdir(VBD_NOISY_DIR)
# Both listings are sorted the same way so that zip() pairs matching files.
clean_list.sort()
noisy_list.sort()

# Inference only: no_grad() keeps autograd from recording every forward pass.
with torch.no_grad():
    for clean_name, noisy_name in tqdm(zip(clean_list, noisy_list),
                                       total=min(len(clean_list), len(noisy_list))):
        # [0][0] selects the waveform tensor and then its first channel.
        noisy_wav = torchaudio.load(os.path.join(VBD_NOISY_DIR, noisy_name))[0][0].cuda()
        # The reference is only consumed by the numpy-based metrics, so it stays on the CPU.
        target_wav = torchaudio.load(os.path.join(VBD_CLEAN_DIR, clean_name))[0][0].numpy()
        real_length = len(noisy_wav)

        if real_length <= CLIP_LEN:
            # Short input: zero-pad to one clip, enhance, cut back to the true length.
            padded_wav = F.pad(noisy_wav, (0, CLIP_LEN - real_length))
            enhanced_wav = model(padded_wav.unsqueeze(0))
            enhanced_wav = enhanced_wav.squeeze()[:real_length].detach().cpu().numpy()
        else:
            # Long input: zero-pad to a whole number of clips (ceiling division),
            # enhance clip by clip and stitch the outputs back together.
            num_clips = -(-real_length // CLIP_LEN)
            padded_wav = F.pad(noisy_wav, (0, num_clips * CLIP_LEN - real_length))
            enhanced_wav = np.zeros((num_clips * CLIP_LEN,))
            for k in range(num_clips):
                noisy_clip = padded_wav[k * CLIP_LEN:(k + 1) * CLIP_LEN]
                enhanced_clip = model(noisy_clip.unsqueeze(0))
                enhanced_wav[k * CLIP_LEN:(k + 1) * CLIP_LEN] = (
                    enhanced_clip.squeeze().detach().cpu().numpy())
            enhanced_wav = enhanced_wav[:real_length]

        csig, cbak, covl = pysepm.composite(target_wav, enhanced_wav, SAMPLE_RATE)
        # Element 1 of the pysepm.pesq result is the score that gets reported.
        pesq = pysepm.pesq(target_wav, enhanced_wav, SAMPLE_RATE)[1]
        stoi = pysepm.stoi(target_wav, enhanced_wav, SAMPLE_RATE)

        csig_list.append(csig)
        cbak_list.append(cbak)
        covl_list.append(covl)
        pesq_list.append(pesq)
        stoi_list.append(stoi)


def mean(lst):
    """Average the items of ``lst`` by dividing their sum by the item count."""
    n_items = len(lst)
    return sum(lst) / n_items

# Print the average of each metric over all evaluated files, one value per
# line, in the order PESQ, CSIG, CBAK, COVL, STOI.
for scores in (pesq_list, csig_list, cbak_list, covl_list, stoi_list):
    print(mean(scores))

824it [02:27,  5.59it/s]

3.061062902790829
4.4385565697845495
3.501054016143857
3.780937904295498





### Result on VBD
PESQ: 3.06  
CSIG: 4.44  
CBAK: 3.50  
COVL: 3.78

### Strip the leading prefix from the pretrained weight keys

In [3]:
# Export the checkpoint weights as a plain state dict: remove the first
# dot-separated component from every key (keys without a dot are kept as-is)
# and save the renamed mapping.
#
# The loaded dict gets its own name instead of `model`, so the network object
# that the evaluation cells call is not replaced by a dict when this cell runs.
# map_location='cpu' lets the cell run without a GPU and stores CPU tensors,
# so the exported file can be loaded on machines without CUDA.
checkpoint_state = torch.load('log/DNS/version_0/epoch=380.ckpt',
                              map_location='cpu')['state_dict']
stripped_state = {
    key.split('.', 1)[-1]: value
    for key, value in checkpoint_state.items()
}
torch.save(stripped_state, 'log/DNS/version_0/DNS_pretrained_model.pth')