In [None]:
import torch
import torchaudio
import torchaudio.transforms as tt
import torchaudio.functional as ff
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset

In [None]:
import os
import sys

# Location of a local asteroid checkout. The default is the original
# machine-specific path; set the ASTEROID_PATH environment variable to point
# at a different checkout instead of editing this cell.
ASTEROID_PATH = os.environ.get(
    "ASTEROID_PATH",
    r"C:\Users\toviste\Local_Documents\Local_Python\asteroid",
)
# Guard so that re-running the cell does not keep appending duplicate entries.
if ASTEROID_PATH not in sys.path:
    sys.path.append(ASTEROID_PATH)
import asteroid.dsp.beamforming as bf

In [None]:
# Sample rate (Hz) the assets are expected to have; the loaded files are
# checked against it right after they are read.
SAMPLE_RATE = 16000
# Clean-speech and noise recordings. The values returned by download_asset are
# what is later passed to torchaudio.load.
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean_speech.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")

In [None]:
def generate_mixture(waveform_clean, waveform_noise, target_snr):
    """Mix clean speech with noise rescaled to hit ``target_snr`` (in dB).

    The SNR is measured from the mean power of each waveform, and the noise
    is multiplied by the amplitude gain that moves that SNR to the target.

    Args:
        waveform_clean: Clean signal tensor.
        waveform_noise: Noise tensor; must be addable to ``waveform_clean``.
        target_snr: Desired signal-to-noise ratio of the mixture, in dB.

    Returns:
        ``waveform_clean`` plus the rescaled noise.

    Note:
        ``waveform_noise`` is rescaled **in place**. After the call the
        caller's tensor holds the noise at the level present in the mixture.
    """
    clean_power = (waveform_clean ** 2).mean()
    noise_power = (waveform_noise ** 2).mean()
    snr_before = 10 * torch.log10(clean_power / noise_power)

    # Power ratios use 10*log10 and amplitudes use 20*log10, hence the /20.
    noise_gain = 10 ** (-(target_snr - snr_before) / 20)

    # In-place on purpose: the caller's tensor is updated as a side effect.
    waveform_noise.mul_(noise_gain)
    return waveform_clean + waveform_noise

In [None]:
# Load both recordings and make sure they share the expected sample rate.
waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms with a desired SNR (in dB).
# NOTE: generate_mixture rescales `waveform_noise` in place, so from this point on
# `waveform_noise` holds the noise at the level that is present in `waveform_mix`.
target_snr = 3
waveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)

In [None]:
# Cast to double precision and add a leading batch axis.
# The `dim() >= 3` guard keeps the cell safe to re-run: a tensor that already
# carries the batch axis is left as it is instead of growing another dimension
# (which would break the 4-index slicing used on the STFTs later on).
waveform_mix, waveform_clean, waveform_noise = (
    w.to(torch.double) if w.dim() >= 3 else w.to(torch.double).unsqueeze(0)
    for w in (waveform_mix, waveform_clean, waveform_noise)
)

In [None]:
# STFT configuration shared by the forward and inverse transforms.
N_FFT = 1024
N_HOP = 256
# power=None requests the complex-valued STFT rather than a magnitude/power
# spectrogram; the complex values are what the inverse transform consumes.
stft = tt.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)
istft = tt.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

# STFTs of the mixture and of its two components (clean speech, scaled noise).
stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)

In [None]:
# Channel index (axis 1 of the STFT tensors) that get_irms picks the masks from.
REFERENCE_CHANNEL = 0

In [None]:
def get_irms(stft_clean, stft_noise, reference_channel=None):
    """Compute ideal ratio masks (IRMs) for speech and noise.

    Each mask is the per-bin share of the total power:
    ``|S|^2 / (|S|^2 + |N|^2)`` for speech and
    ``|N|^2 / (|S|^2 + |N|^2)`` for noise.

    Args:
        stft_clean: STFT of the clean signal. It is indexed with four
            subscripts, so it must be 4-D with the channel on axis 1.
        stft_noise: STFT of the noise signal, same layout as ``stft_clean``.
        reference_channel: Index on axis 1 to keep. ``None`` (default) uses
            the module-level ``REFERENCE_CHANNEL``.

    Returns:
        Tuple ``(irm_speech, irm_noise)`` with axis 1 indexed out.
        Bins where both inputs are exactly zero get a mask value of 0 in
        both outputs instead of the NaN a plain 0/0 division would produce.
    """
    if reference_channel is None:
        reference_channel = REFERENCE_CHANNEL

    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    power_total = power_clean + power_noise
    # Clamp to the smallest positive normal number so silent bins divide to 0
    # rather than NaN; any denominator at or above that value is unchanged.
    power_total = torch.clamp(power_total, min=torch.finfo(power_total.dtype).tiny)

    irm_speech = power_clean / power_total
    irm_noise = power_noise / power_total
    return irm_speech[:, reference_channel, :, :], irm_noise[:, reference_channel, :, :]


# Speech and noise masks for the reference channel (get_irms indexes out axis 1).
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)

In [None]:
# Mask-weighted statistics of the mixture STFT, one set per source.
# NOTE(review): "scm" presumably stands for spatial covariance matrix; confirm the
# expected shapes of `x` and `mask` against asteroid's compute_scm documentation.
scm_speech = bf.compute_scm(x=stft_mix, mask=irm_speech)
scm_noise = bf.compute_scm(x=stft_mix, mask=irm_noise)

In [None]:
# Beamformer from asteroid.dsp.beamforming, constructed with its default arguments.
beamformer = bf.RTFMVDRBeamformer()

In [None]:
# Apply the beamformer to the mixture STFT using the speech and noise SCMs.
# NOTE(review): solution='evd' selects one of the beamformer's solver variants;
# see the asteroid RTFMVDRBeamformer documentation for the alternatives.
stft_enhanced = beamformer(mix=stft_mix, target_scm=scm_speech, noise_scm=scm_noise, solution='evd')
# Back to the time domain; `length` pins the output to the mixture's sample count.
waveform_enhanced = istft(stft_enhanced, length=waveform_mix.shape[-1])

In [None]:
# Clean recording: batch item 0, channel 0, first 4 seconds of samples.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(waveform_clean[0, 0, :4 * SAMPLE_RATE])
ax.set(title="Clean speech", xlabel="Sample index", ylabel="Amplitude")
fig.tight_layout()

In [None]:
# Noisy mixture: batch item 0, channel 0, first 4 seconds of samples.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(waveform_mix[0, 0, :4 * SAMPLE_RATE])
ax.set(title="Noisy mixture", xlabel="Sample index", ylabel="Amplitude")
fig.tight_layout()

In [None]:
# Beamformer output: batch item 0, first 4 seconds of samples.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(waveform_enhanced[0, :4 * SAMPLE_RATE])
ax.set(title="Enhanced speech (beamformer output)", xlabel="Sample index", ylabel="Amplitude")
fig.tight_layout()

In [None]:
# Write the enhanced signal to disk at the sample rate read from the clean file.
# NOTE(review): waveform_enhanced is presumably float64 (the inputs were cast to
# double); confirm the active torchaudio backend accepts that dtype, or cast to
# float32 before saving.
torchaudio.save("enhanced.wav", waveform_enhanced, sr)