In [2]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal

from IPython.display import display, Audio

from pesq import pesq

def si_sdr(s, s_hat):
    """Scale-invariant signal-to-distortion ratio, in dB.

    The estimate ``s_hat`` is projected onto the reference ``s``. This gives the
    optimally scaled target ``alpha * s``, and whatever remains is counted as
    distortion. Signal means are not removed.

    Args:
        s: reference signal (callers pass squeezed, 1-D arrays).
        s_hat: estimated signal, same length as ``s``.

    Returns:
        10 * log10(||alpha * s||^2 / ||alpha * s - s_hat||^2).
    """
    scale = np.dot(s_hat, s) / np.linalg.norm(s) ** 2
    target = scale * s
    residual = target - s_hat
    target_energy = np.linalg.norm(target) ** 2
    residual_energy = np.linalg.norm(residual) ** 2
    return 10 * np.log10(target_energy / residual_energy)


In [3]:
# Load one clean VCTK utterance and one simulated RIR.
# NOTE: absolute, machine-specific paths; adjust for your environment.
SPEECH_PATH = "/data/lemercier/databases/VCTK-Corpus/wav16/p225/p225_001.wav"
RIR_PATH = "/data/lemercier/databases/rir_simulated_aligned/rir/tt/50_0.56.wav"

speech, sr = torchaudio.load(SPEECH_PATH)
# Keep the RIR's rate separate instead of overwriting `sr`. A mismatch is then
# caught here rather than silently corrupting every downstream convolution.
rir, rir_sr = torchaudio.load(RIR_PATH)

assert rir_sr == sr, f"sample-rate mismatch: speech {sr} Hz vs RIR {rir_sr} Hz"
# Later cells `.squeeze()` these to 1-D signals, which only works for mono audio.
assert speech.shape[0] == 1 and rir.shape[0] == 1, "expected mono speech and RIR"

In [4]:
# Classical convolution
#
# Direct time-domain convolution; this is the reference reverberant signal for
# every approximation below. np.convolve returns the full linear convolution
# (len(speech) + len(rir) - 1 samples). The leading axis is restored so the
# result has shape (1, num_samples).
reverberant = np.convolve(speech.squeeze(), rir.squeeze())[np.newaxis, ...]

# Only a cell's last expression is rendered automatically, so show both players explicitly.
display(Audio(speech, rate=sr), Audio(reverberant, rate=sr))

In [5]:
# FFT convolution
#
# scipy.signal.fftconvolve computes the same full linear convolution via FFTs.
# It should therefore match `reverberant` up to floating-point round-off.
reverberant_approx = scipy.signal.fftconvolve(speech.squeeze(), rir.squeeze())[np.newaxis, ...]

# The cell ends in print(), so the players must be displayed explicitly.
display(Audio(speech, rate=sr), Audio(reverberant_approx, rate=sr))

# PESQ takes (sample rate, reference, degraded); SI-SDR needs equal-length inputs.
num_samples = min(reverberant.shape[-1], reverberant_approx.shape[-1])
print("PESQ:", pesq(sr, reverberant.squeeze(), reverberant_approx.squeeze()))
print("SI-SDR [dB]:", si_sdr(reverberant[..., :num_samples].squeeze(), reverberant_approx[..., :num_samples].squeeze()))

4.643888473510742
130.56555417862631


In [29]:
# Subband convolution
#
# Approximates time-domain convolution by convolving, separately in every
# frequency bin, the STFT frames of the speech with the STFT frames of the RIR
# (no cross-band terms). The result is compared with the STFT of the exact
# reverberant signal, then resynthesised with the inverse STFT.
#
# Settings also tried:
# - n_fft=128/hop=32 and n_fft=512/hop=128 (Hann window, default `center`)
# - n_fft=2**12/hop=2**10 (Hann)
# - n_fft=2**14/hop=n_fft//4 (Hann)
# - n_fft=hop=2**14 with no window

n_fft = 2**3
hop_length = n_fft // 4
stft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": torch.hann_window(n_fft), "return_complex": True, "center": True}
istft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": torch.hann_window(n_fft), "center": True}

rir_subbands = torch.stft(rir.squeeze(), **stft_kwargs)
speech_subbands = torch.stft(speech.squeeze(), **stft_kwargs)

# For a 1-D signal torch.stft returns (freq, frames). zip therefore pairs up
# matching frequency bins, and each np.convolve runs along the frame axis.
reverberant_subbands = np.stack([
    np.convolve(rir_subband, speech_subband)
    for rir_subband, speech_subband in zip(rir_subbands, speech_subbands)
])

# RMS error against the STFT of the exact reverberant signal. It is computed over
# the frames both representations share, since their frame counts may differ.
real_reverberant_subbands = torch.stft(torch.from_numpy(reverberant), **stft_kwargs)
num_frames = min(real_reverberant_subbands.shape[-1], reverberant_subbands.shape[-1])
subband_error = real_reverberant_subbands[..., :num_frames] - torch.from_numpy(reverberant_subbands)[..., :num_frames]
print("subband RMS error:", subband_error.abs().square().mean().sqrt())

reverberant_approx = torch.istft(torch.from_numpy(reverberant_subbands), onesided=True, **istft_kwargs)

# SI-SDR needs equal-length inputs; the resynthesis can be shorter or longer than the reference.
num_samples = min(reverberant.shape[-1], reverberant_approx.shape[-1])
print("PESQ:", pesq(sr, reverberant.squeeze(), reverberant_approx.squeeze().numpy()))
print("SI-SDR [dB]:", si_sdr(reverberant[..., :num_samples].squeeze(), reverberant_approx[..., :num_samples].squeeze().numpy()))
display(Audio(speech, rate=sr), Audio(reverberant_approx, rate=sr))

tensor(0.1532)
4.472358226776123
3.6331546393300607


In [19]:
# Convolution of STFT with fullband RIR
#
# Multiplies every STFT frame of the speech by the n_fft-point spectrum of the
# RIR, then resynthesises with the inverse STFT. A product of DFTs is a circular
# convolution within each frame, so this only approximates linear convolution.
#
# Settings also tried:
# - n_fft=128/hop=32 and n_fft=512/hop=128 (Hann window)
# - n_fft=2**14 with hop=n_fft//2 (Hann)
# - n_fft=hop=2**14 with no window

n_fft = 2**14
hop_length = n_fft // 4
# sqrt-Hann for both analysis and synthesis, so their product is a Hann window.
window = torch.sqrt(torch.hann_window(n_fft))
stft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": window, "return_complex": True, "center": True}
istft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": window, "center": True}

# rfft(n=n_fft) zero-pads a shorter RIR but silently truncates a longer one.
if rir.shape[-1] > n_fft:
    print(f"warning: RIR has {rir.shape[-1]} samples; only the first {n_fft} are used")
# Shape (freq, 1) so it broadcasts across the (freq, frames) speech STFT.
rir_fft = torch.fft.rfft(rir.squeeze(), n=n_fft).unsqueeze(-1)
speech_subbands = torch.stft(speech.squeeze(), **stft_kwargs)

reverberant_subbands = rir_fft * speech_subbands
# The output spans only the speech's STFT frames, so the reverberant tail past
# the end of the speech is not reproduced.
reverberant_approx = torch.istft(reverberant_subbands, onesided=True, **istft_kwargs)

# SI-SDR needs equal-length inputs.
num_samples = min(reverberant.shape[-1], reverberant_approx.shape[-1])
print("PESQ:", pesq(sr, reverberant.squeeze(), reverberant_approx.squeeze().numpy()))
print("SI-SDR [dB]:", si_sdr(reverberant[..., :num_samples].squeeze(), reverberant_approx[..., :num_samples].squeeze().numpy()))
display(Audio(speech, rate=sr), Audio(reverberant_approx, rate=sr))

2.805319309234619
18.44198067516163


In [40]:
# Subband convolution with zero-phase subband filter
#
# Per-frequency-bin convolution of STFT frames, as in the plain subband
# experiment, except the RIR subbands are replaced by their magnitudes. The
# RIR's phase is therefore discarded.
# Also tried: sqrt-Hann analysis/synthesis windows.

n_fft = 128
hop_length = 32
window = torch.hann_window(n_fft)
stft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": window, "return_complex": True}
istft_kwargs = {"n_fft": n_fft, "hop_length": hop_length, "window": window}

rir_subbands = torch.stft(rir.squeeze(), **stft_kwargs).abs()
speech_subbands = torch.stft(speech.squeeze(), **stft_kwargs)

# Rows are frequency bins; each bin's frame sequence is convolved independently.
reverberant_subbands = np.stack([
    np.convolve(rir_subband, speech_subband)
    for rir_subband, speech_subband in zip(rir_subbands, speech_subbands)
])
reverberant_approx = torch.istft(torch.from_numpy(reverberant_subbands), onesided=True, **istft_kwargs)

# SI-SDR needs equal-length inputs.
num_samples = min(reverberant.shape[-1], reverberant_approx.shape[-1])
print("PESQ:", pesq(sr, reverberant.squeeze(), reverberant_approx.squeeze().numpy()))
print("SI-SDR [dB]:", si_sdr(reverberant[..., :num_samples].squeeze(), reverberant_approx[..., :num_samples].squeeze().numpy()))
# The cell would otherwise end in print(), so the players must be displayed explicitly.
display(Audio(speech, rate=sr), Audio(reverberant_approx, rate=sr))

1.9781622886657715
