In [69]:
!git clone https://github.com/asteroid-team/asteroid

Cloning into 'asteroid'...
remote: Enumerating objects: 7684, done.[K
remote: Counting objects: 100% (307/307), done.[K
remote: Compressing objects: 100% (193/193), done.[K
remote: Total 7684 (delta 143), reused 248 (delta 111), pack-reused 7377[K
Receiving objects: 100% (7684/7684), 5.98 MiB | 4.92 MiB/s, done.
Resolving deltas: 100% (4838/4838), done.


In [97]:
import torch
import numpy as np
from scipy.io.wavfile import read
from asteroid.models import BaseModel
import os

In [176]:
def cal_SISNRi(src_ref, src_est, mix):

    """
        Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
        Args:
            src_ref: numpy.ndarray, [C, T]
            src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
            mix: numpy.ndarray, [T]
        Returns:
            average_SISNRi
    """

    sisnr1 = cal_SISNR(src_ref[0], src_est[0])
    #print("SISNR 1", sisnr1)
    sisnr2 = cal_SISNR(src_ref[1], src_est[1])
    #print("SISNR 2", sisnr2)

    sisnr1b = cal_SISNR(src_ref[0], mix)
    #print("SISNR 1b", sisnr1b)

    sisnr2b = cal_SISNR(src_ref[1], mix)
    #print("SISNR 2b", sisnr2b)

    avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2

    return avg_SISNRi

def cal_SISNR(ref_sig, out_sig, eps=1e-8):

    """
        Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
        Args:
            ref_sig: numpy.ndarray, [T]
            out_sig: numpy.ndarray, [T]
        Returns:
            SISNR
    """

    assert len(ref_sig) == len(out_sig)

    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)

    ref_energy = np.sum(ref_sig ** 2) + eps

    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy

    noise = out_sig - proj

    #ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    ratio = ref_energy / (np.sum(noise ** 2) + eps)

    sisnr = 10 * np.log10(ratio + eps)

    return sisnr

def data_prep(path_orig1, path_orig2, path_est_1, path_est_2, path_mix):

    """
        Convert input .wav files to the right numpy arrays for cal_SISNRi

    """

    orig_1 = read(path_orig1)
    orig_2 = read(path_orig2)
    est_1 = read(path_est_1)
    est_2 = read(path_est_2)
    mix = read(path_mix)

    src_ref = np.vstack((np.array(orig_1[1]), np.array(orig_2[1])))
    src_est = np.vstack((np.array(est_1[1]), np.array(est_2[1])))
    mix = np.array([mix[1]])
    mix = np.squeeze(mix)

    return src_ref, src_est, mix

def separate(path_mix):

    """
        Separate mixture file and return the two paths to the separated files
    """
    model.separate(path_mix, force_overwrite=True)
    path_est_1 = "/" + path_mix[1:-4] + '_est1.wav' 
    path_est_2 = "/" + path_mix[1:-4] + '_est2.wav' 
    
    return path_est_1, path_est_2

In [171]:
# Input: Path to folder with mixtures, path to folder with s1, path to folder with s2
path_mixtures = "/content/mix_2/"
path_s1 = "/content/s1/"
path_s2 = "/content/s2/"

In [172]:
# iterate over files in that directory
paths_s1 = []
for filename in os.listdir(path_s1):
    f = os.path.join(path_s1, filename)
    # checking if it is a file
    if os.path.isfile(f):
        #print(f)
        paths_s1.append(f)

paths_s2 = []
for filename in os.listdir(path_s2):
    f = os.path.join(path_s2, filename)
    # checking if it is a file
    if os.path.isfile(f):
        #print(f)
        paths_s2.append(f)

paths_mixtures = []
for filename in os.listdir(path_mixtures):
    f = os.path.join(path_mixtures, filename)
    # checking if it is a file
    if os.path.isfile(f):
        #print(f)
        paths_mixtures.append(f)

In [115]:
# load model

model = BaseModel.from_pretrained("mpariente/DPRNNTasNet-ks2_WHAM_sepclean")



In [177]:
SISNRi = []
for i, path_mix in enumerate(paths_mixtures):
  path_est_1, path_est_2 = separate(paths_mixtures[i])
  src_ref1, src_est1, mix1 = data_prep(paths_s1[i], paths_s2[i], path_est_1, path_est_2, path_mix)
  src_ref2, src_est2, mix2 = data_prep(paths_s1[i], paths_s2[i], path_est_2, path_est_1, path_mix)

  avg_SISNRi_1 = cal_SISNRi(src_ref1, src_est1, mix1)
  avg_SISNRi_2 = cal_SISNRi(src_ref2, src_est2, mix2)

  SISNRi.append(max(avg_SISNRi_1, avg_SISNRi_2))

  print("Progress:", int(((i+1)/len(paths_mixtures))*100), "%")
  #print(SISNRi)

SISNRi = np.array(SISNRi)
print("SISNRi: ", SISNRi.mean())

Progress: 5 %
Progress: 10 %
Progress: 15 %
Progress: 20 %
Progress: 25 %
Progress: 30 %
Progress: 35 %
Progress: 40 %
Progress: 45 %
Progress: 50 %
Progress: 55 %
Progress: 60 %
Progress: 65 %
Progress: 70 %
Progress: 75 %
Progress: 80 %
Progress: 85 %
Progress: 90 %
Progress: 95 %
Progress: 100 %
SISNRi:  13.25061246981647
