# Homework description

In this assignment we are going to train a neural network for noise reduction.

This assignment is worth 15 points in total, plus 3 bonus points.

## Plan:
0. Datasets
1. Volume normalization, gain, RMS: everything we need to mix signal with noise [2 points]
2. Room impulse response (RIR): what we need to simulate acoustics and perform partial dereverberation [1 point]
3. On-the-fly data generation [5 points]
4. Neural network architecture [3 points]
5. Loss function [1 point]
6. Training loop [3 points]
7. Streaming implementation for the neural network [bonus, 3 points]

## A homework submission should include:
1. the completed notebook
2. tensorboard logs
3. 5 examples of input-output files from the trained model in .wav format

# 0. Datasets

We are going to use clean speech and room impulse responses from the DNS Challenge dataset.
For speech we use a random subsample; for RIRs we use the full smallroom partition to avoid extreme reverberation levels.

Originally DNS Challenge data comes in 48 kHz sample rate. We [down-sampled](https://librosa.org/doc/0.10.1/generated/librosa.resample.html#librosa-resample) it to 16 kHz in advance.

For noise we are going to use the MUSAN dataset. Why not DNS Challenge? Training progress becomes visible sooner with MUSAN.

Let's download the data:

In [None]:
from io import BytesIO
import os
import requests
from urllib.parse import urlencode
from zipfile import ZipFile

base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/d/ECHrgBGJrrGQqw'

final_url = base_url + urlencode(dict(public_key=public_key))
response = requests.get(final_url)
download_url = response.json()['href']
response = requests.get(download_url)

path_to_dataset = 'data/homework_16_kHz'    # Choose any appropriate local path

zipfile = ZipFile(BytesIO(response.content))
zipfile.extractall(path=path_to_dataset)

In [None]:
import os

# load_data

ROOT_DATA = os.path.join(path_to_dataset, "homework_1_16kHz")

DATA_PATHS = {
    "speech": os.path.join(ROOT_DATA, "clean_train"),
    "noise": os.path.join(ROOT_DATA, "musan/noise"),
    "rir": os.path.join(ROOT_DATA, "impulse_responses_all/SLR26/simulated_rirs_48k/smallroom"),
}

How many audio files do we have?

In [None]:
from glob import glob


def list_wavs_in_folder_recursively(path: str) -> list[str]:
    return sorted(glob(os.path.join(path, "**", "*.wav"), recursive=True))

for key, folder in DATA_PATHS.items():
    paths = list_wavs_in_folder_recursively(folder)
    print(f"{key}: {len(paths)}")

In [None]:
from __future__ import annotations

import librosa  # to plot mel-spectrograms
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import scipy.signal as sig
from tqdm.notebook import tqdm

SR = 16_000

In [None]:
def build_spec_for_plot(waveform):
    mel_spec = librosa.feature.melspectrogram(y=waveform, sr=SR)
    min_val = 1e-10
    mel_spec = np.clip(mel_spec, min_val, None)
    mel_spec[-1, -1] = min_val
    mel_spec_db = 10 * np.log10(mel_spec)
#     return mel_spec_db
    return np.flip(mel_spec_db, axis=0)

**Let's select sample files**

In [None]:
rel_path = "read_speech/book_00007_chp_0008_reader_01326_64_seg_2.wav"
SAMPLE_SIGNAL, _sr = sf.read(os.path.join(DATA_PATHS["speech"], rel_path))

rel_path = "free-sound/noise-free-sound-0001.wav"
SAMPLE_NOISE, _sr = sf.read(os.path.join(DATA_PATHS["noise"], rel_path))

SAMPLE_RIR, _sr = sf.read(os.path.join(
    ROOT_DATA,
    "impulse_responses_all/SLR28/RIRS_NOISES/real_rirs_isotropic_noises/air_type1_air_binaural_office_0_1_1ch.wav",
))

# 1. Volume normalization, gain, RMS: everything we need to mix signal with noise [2 points]

## RMS-full-scale

Let's suppose we have an audio signal $x=(x[0], x[1], ..., x[T-1])$.

A basic measure of its loudness would be its L2-norm normalized by the signal length. It is also referred to as RMS (root-mean-square):
$$\text{rms}_{\text{raw}}(x) = ||x|| = ||x||_2 = \sqrt{\frac{1}{T}\sum_{t=0}^{T-1} x[t]^2}$$

It is convenient to express RMS in decibels. Decibels involve logarithm computation and are only applicable to dimensionless physical quantities, typically to ratios.

So we need to define a reference value to normalize by. Since we live in a world where signals are represented with floating-point values ranging between -1 and 1, two options are typically adopted:

$$\text{rms}_{\text{ref}, \sin} = \sqrt{\frac{1}{2\pi}\int_{0}^{2\pi} \sin^2(t)\,dt} = \frac{1}{\sqrt{2}} \approx 0.707$$
which corresponds to the rms of a full-scale sine wave (its mean square is 0.5), or
$$\text{rms}_{\text{ref, square}} = 1$$
which corresponds to the rms of a full-scale [square wave](https://en.wikipedia.org/wiki/Square_wave).

Both options are used, leading to confusion in the industry.

In this assignment **let's use:**
$$\text{rms}_{\text{ref}} = \text{rms}_{\text{ref, square}} = 1$$

Thus we get:

$$\text{rms}_\text{dB}(x) = 20\log_{10}\frac{||x||}{\text{rms}_{\text{ref, square}}} = 20\log_{10}||x|| = 10\log_{10}\left(\frac{1}{T}\sum_{t=0}^{T-1}x[t]^2\right)$$

**The latter is called RMS-full-scale (or RMS-fs)**, i.e. RMS relative to full scale of 1.
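To make the two reference conventions concrete, here is a minimal sketch that measures a full-scale sine and a full-scale square wave against the square-wave reference of 1. The helper name `rms_db_sketch` and the 100 Hz test tone are illustrative assumptions, not part of the assignment API.

```python
import numpy as np


def rms_db_sketch(x: np.ndarray) -> float:
    """Hypothetical helper: RMS in dB relative to a full scale of 1."""
    mean_square = np.mean(np.square(x))
    return float(10 * np.log10(mean_square))


t = np.arange(16_000) / 16_000               # 1 second at 16 kHz
sine = np.sin(2 * np.pi * 100 * t)           # full-scale sine, 100 whole periods
square = np.where(sine >= 0, 1.0, -1.0)      # full-scale square wave, values in {-1, +1}

sine_db = rms_db_sketch(sine)                # about -3.01 dB
square_db = rms_db_sketch(square)            # 0 dB
```

With this convention a full-scale sine sits about 3 dB below the reference, which is exactly the offset between the two conventions mentioned above.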

In [None]:
def eval_mean_square(x: np.ndarray) -> float:
    """
    Computes mean-square of x
    """
    raise NotImplementedError("Your code")


def power_to_db(x: float) -> float:
    """
    Computes 10log10(x)
    """
    raise NotImplementedError("Your code")


def eval_rms_db(x: np.ndarray) -> float:
    """
    Computes rms-fs of x
    """
    rms_square_raw = # your code
    rms_db = # your code
    return rms_db

In [None]:
eps = 1e-12

assert abs(power_to_db(0.01) + 20) < eps
assert abs(power_to_db(0.1) + 10) < eps
assert abs(power_to_db(1) - 0) < eps
assert abs(power_to_db(10) - 10) < eps
assert abs(power_to_db(100) - 20) < eps

assert abs(eval_rms_db(np.ones(999) * 0.1) + 20) < eps
assert abs(eval_rms_db(np.ones(999) * 1) - 0) < eps
assert abs(eval_rms_db(np.ones(531) * 10) - 20) < eps

## Gain, normalization

When a signal $x$ is multiplied by a scalar factor of $\alpha \geq 0$, it corresponds to addition in the world of decibels:
$$\text{rms}_\text{dB}(\alpha x) = 20 \log_{10} ||\alpha x|| = 20\log_{10}\alpha + 20\log_{10} ||x|| = 20\log_{10}\alpha + \text{rms}_\text{dB}(x)$$

The multiplication of $x$ by $\alpha \geq 0$ is often referred to as **gain by $\boldsymbol{G}$ dB**, where $G = 20\log_{10}\alpha$, which can be either positive or negative (or even negative infinity if $\alpha=0$).
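As a quick numeric illustration of the relation $G = 20\log_{10}\alpha$, the sketch below checks that a gain of 6 dB roughly doubles the amplitude and shifts the RMS level by the same 6 dB. The names `db_to_factor_sketch` and `rms_db` are hypothetical and only serve this example.

```python
import numpy as np


def db_to_factor_sketch(gain_db: float) -> float:
    """Hypothetical helper: amplitude factor alpha with 20 * log10(alpha) == gain_db."""
    return 10.0 ** (gain_db / 20.0)


def rms_db(v: np.ndarray) -> float:
    """Hypothetical helper: RMS in dB relative to a full scale of 1."""
    return float(10 * np.log10(np.mean(v ** 2)))


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 16_000)

alpha = db_to_factor_sketch(6.0)             # close to 2
shift_db = rms_db(alpha * x) - rms_db(x)     # the level moves by the applied gain
```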

The inverse relationship can be inferred to express the scalar factor from the gain in decibels:

In [None]:
def gain_to_mult(gain_db: float) -> float:
    """
    Finds the positive scalar factor which corresponds to given gain_db
    """
    raise NotImplementedError("Your code")
    
    
def mult_to_gain(mult: float) -> float:
    """
    Finds the gain in dB from positive scalar factor
    """
    raise NotImplementedError("Your code")

In [None]:
eps = 1e-12

assert abs(gain_to_mult(-20) - 0.1) < eps
assert abs(gain_to_mult(0) - 1) < eps
assert abs(gain_to_mult(20) - 10) < eps

for mult in [1., 0.265, 10.5]:
    gain_db = mult_to_gain(mult)
    mult_power = mult ** 2  # if x is multiplied by mult, x ** 2 is multiplied by mult ** 2
    gain_db_from_power = power_to_db(mult_power)
    assert abs(gain_db_from_power - gain_db) < eps
    mult_restored = gain_to_mult(gain_db)
    assert abs(mult - mult_restored) < eps

Now we know how to apply gain in decibels:

In [None]:
def apply_gain(x: np.ndarray, gain_db: float) -> np.ndarray:
    mult = # your code
    result = # your code
    return result

In [None]:
x = np.ones(10)

assert np.allclose(apply_gain(x, -20), x / 10)
assert np.allclose(apply_gain(x, 0), x)
assert np.allclose(apply_gain(x, 20), x * 10)


for _ in range(100):
    x = np.random.uniform(-1, 1, 16_000)
    gain = np.random.uniform(-20, 20)
    rms_before = eval_rms_db(x)
    rms_expected = rms_before + gain
    
    gained = apply_gain(x, gain)
    rms_after = eval_rms_db(gained)
    
    assert abs(rms_after - rms_expected) < 1e-12

Let's implement a function which will normalize an input signal to a desired level of $\text{rms}_\text{dB}$.

How should it work?

1. Calculate $\text{rms}_\text{dB}(\text{signal})$
2. Calculate gain in decibels: $g = \text{rms}_\text{dB, target} - \text{rms}_\text{dB}(\text{signal})$
3. Apply the gain to the signal

In [None]:
def normalize_to_rms(x: np.ndarray, target_rms_db: float) -> np.ndarray:
    """
    Normalizes signal x to target_rms_db
    """
    raise NotImplementedError("Your Code")

In [None]:
for _ in range(100):
    x = np.random.uniform(-1, 1, 16_000)
    target_rms_db = np.random.uniform(-20, 20)
    normalized = normalize_to_rms(x, target_rms_db)
    rms_after_normalization = eval_rms_db(normalized)
    assert abs(target_rms_db - rms_after_normalization) < 1e-12
    
print("Ok")

**Let's play with it:**

In [None]:
x = np.copy(SAMPLE_SIGNAL)

gain_db = -10
x_gained = apply_gain(x, gain_db)
x_normalized = normalize_to_rms(x, -20)

_, ax = plt.subplots()
ax.plot(x_normalized, label="rms: -20 dB")
ax.plot(x, label="raw")
ax.plot(x_gained, label=f"gained by {gain_db} dB")
ax.legend()
ax.grid()
plt.show()

## SNR

SNR (signal-to-noise ratio) is expressed in decibels and is defined as:

$$\text{SNR} = 10\log_{10}\frac{||\text{signal}||^2}{||\text{noise}||^2} = 10\log_{10}||\text{signal}||^2 - 10\log_{10}||\text{noise}||^2 = \text{rms}_{\text{dB}}(\text{signal}) - \text{rms}_{\text{dB}}(\text{noise})$$

Also, **SNR can be used as a quality metric** or even as a loss function for gradient descent.

Given a ground truth signal $y$ and its estimate $\hat y$, we define noise as $\hat y - y$. Slightly abusing notation we get:

$$\text{SNR}(\hat y, y) = 10 \log_{10} \frac{||y||^2}{||\hat y - y||^2}$$
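A small sketch of SNR used as a metric, assuming it is taken as the power of the ground truth over the power of the error $\hat y - y$ (so that higher is better). The function name `snr_db_sketch` and the synthetic data are illustrative assumptions.

```python
import numpy as np


def snr_db_sketch(estimate: np.ndarray, reference: np.ndarray) -> float:
    """Hypothetical helper: 10 * log10(reference power / error power)."""
    error = estimate - reference
    return float(10 * np.log10(np.sum(reference ** 2) / np.sum(error ** 2)))


rng = np.random.default_rng(0)
reference = rng.uniform(-1.0, 1.0, 16_000)
error = 0.1 * rng.uniform(-1.0, 1.0, 16_000)   # error amplitude is 10x smaller

snr_good = snr_db_sketch(reference + error, reference)       # close to +20 dB
snr_bad = snr_db_sketch(reference + 10 * error, reference)   # close to 0 dB
```

An amplitude ratio of 10 corresponds to 20 dB, so the better estimate scores about 20 dB higher than the worse one.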

In [None]:
def eval_snr(estimate: np.ndarray, signal: np.ndarray) -> float:
    """
    evaluates SNR as a quality metric
    """
    raise NotImplementedError("Your code")

In [None]:
for _ in range(100):

    signal = np.random.uniform(-1, 1, 16_000)
    noise = np.random.uniform(-1, 1, 16_000) * np.random.uniform(0.3, 2)
    mixture = signal + noise
    snr = eval_rms_db(signal) - eval_rms_db(noise)
    snr_est = eval_snr(mixture, signal)
    assert abs(snr - snr_est) < 1e-12

Now we have everything to generate a mixture of speech and noise with defined signal loudness (RMS) and SNR:

1. Normalize signal to $\text{rms}_\text{target, signal}$
2. Normalize noise to $\text{rms}_\text{target, noise} = \text{rms}_\text{target, signal} - \text{SNR}$
3. Add noise to signal. What if shapes don't match? Let's just assume they match and enforce it outside the function.

In [None]:
def mix_speech_with_noise(signal, noise, rms_signal, snr):
    signal_normalized = # your code
    rms_noise_target = # your code
    noise_normalized = # your code
    mixture = signal_normalized + noise_normalized
    return mixture

In [None]:
for _ in range(100):
    signal = np.random.uniform(-1, 1, 16_000)
    noise = np.random.uniform(-1, 1, 16_000) * np.random.uniform(0.3, 2)
    rms_signal = np.random.uniform(-20, 20)
    snr = np.random.uniform(-20, 20)
    mixture = mix_speech_with_noise(signal, noise, rms_signal, snr)
    
    signal_gained = normalize_to_rms(signal, rms_signal)
    snr_est = eval_snr(mixture, signal_gained)

    assert abs(snr - snr_est) < 1e-12
    
print("Ok")

In [None]:
for snr in [-10, 10]:
    min_len = min(len(SAMPLE_SIGNAL), len(SAMPLE_NOISE))
    signal = SAMPLE_SIGNAL[:min_len]
    noise = SAMPLE_NOISE[:min_len]

    mixture = mix_speech_with_noise(signal, noise, -20, snr)

    spec = build_spec_for_plot(mixture)
    _, ax = plt.subplots()
    ax.set_title(f"SNR: {snr} dB")
    ax.imshow(spec)
    plt.show()

# 2. Room impulse response (RIR): what we need to simulate acoustics and perform partial dereverberation [1 point]

A common approach to simulating acoustics is to convolve the signal with a room impulse response (RIR).

It follows from the linear acoustic model (which is accurate enough to be used in practice) and the assumption of time-invariance (i.e. that room acoustics does not change over time or it changes slowly).

For input signal $x$ and RIR $r$:

$$x_{\text{reverberated}} = x * r$$

A RIR is defined as the reverberated version of the unit impulse, i.e. the (1, 0, 0, 0, ...) signal.

An RIR can be listened to and it sounds like a click.
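The definition above can be checked on a toy example: convolving a unit impulse with an RIR gives back the RIR, and delaying the impulse simply delays the response. The three-tap `toy_rir` below is made up for illustration, and `np.convolve` stands in for any linear convolution routine.

```python
import numpy as np

toy_rir = np.array([1.0, 0.0, 0.5, 0.0, 0.25])   # direct sound plus two echoes (made up)

impulse = np.zeros(8)
impulse[0] = 1.0                                  # unit impulse (1, 0, 0, ...)
response = np.convolve(impulse, toy_rir, mode="full")

delayed_impulse = np.zeros(8)
delayed_impulse[3] = 1.0                          # the same impulse, 3 samples later
delayed_response = np.convolve(delayed_impulse, toy_rir, mode="full")

# Full convolution output has length len(x) + len(r) - 1
expected_len = len(impulse) + len(toy_rir) - 1
```

Here `response` starts with the RIR itself and `delayed_response` contains the same taps shifted by 3 samples, which is the time-invariance assumption in action.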

**Let's take a look at an impulse response.**

This will be a real impulse response from a relatively highly reverberant environment.

In [None]:
rir = np.copy(SAMPLE_RIR)

_, ax = plt.subplots()
ax.plot(rir)
ax.grid()
plt.show()

**This is how RIR is convolved with a signal**

In [None]:
convolved = sig.convolve(SAMPLE_SIGNAL, SAMPLE_RIR, mode="full")

In [None]:
_, axes = plt.subplots(nrows=2, figsize=(16, 7), sharex=True, sharey=True)
for ax, (name, data) in zip(axes, [
    ("raw", SAMPLE_SIGNAL),
    ("convolved", convolved),
]):
    ax.set_title(name)
    ax.plot(data)
    ax.grid()
plt.show()

_, axes = plt.subplots(nrows=2, figsize=(16, 7), sharex=True, sharey=True)
for ax, (name, data) in zip(axes, [
    ("raw", SAMPLE_SIGNAL),
    ("convolved", convolved),
]):
    spec = build_spec_for_plot(data)
    ax.set_title(name)
    ax.imshow(spec)
    ax.set_aspect("auto")
plt.show()

**Note the shapes:**

In [None]:
assert len(convolved) == len(SAMPLE_SIGNAL) + len(SAMPLE_RIR) - 1

The output is longer, and, as can be seen from the spectrogram, the appended part mainly consists of the reverberation tail.

**Let's take a closer look at what a RIR looks like**

**These are helper functions** that evaluate the windowed power of a signal (RIR) and plot it on the dB scale:

In [None]:
def get_win_power(rir, win_size=160):
    win = np.hanning(win_size)
    win /= win.sum()
    rir_sq = np.square(rir)
    win_power = sig.convolve(rir_sq, win, mode="valid")
    return win_power


def plot_win_power_db(win_power_db, ax):
    lines = ax.plot(win_power_db)
    return lines[0]


def plot_rir_for_rt60(rir, ax):
    win_power = get_win_power(rir)
    win_power_db = 10 * np.log10(win_power)
    line = plot_win_power_db(win_power_db, ax)
    return line

In [None]:
_, ax = plt.subplots()
win_power = get_win_power(rir)
win_power_db = 10 * np.log10(win_power)
plot_win_power_db(win_power_db, ax)
ax.set_title("RIR power decay")
ax.set_ylabel("Power (db-fs)")
ax.grid()
plt.show()

**We can observe the following pattern:**

First the power drops abruptly, and then it decays linearly on the log (dB) scale.

The slope of the linear fit defines the $\boldsymbol{rt_{60}}$ property of a RIR (and, by extension, of a room).

$\boldsymbol{rt_{60}}$ (reverb time 60) is the time it takes the linear fit to decay by 60 dB. It is measured in seconds.

Why 60 dB? It is roughly the difference between the loudest and the quietest levels of a symphony orchestra.

**For the curious:** [More about rt60](https://svantek.com/academy/rt60-reverberation-time/), they measure it directly, not from an RIR.
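To see how a slope on the dB scale turns into a reverberation time, here is a sketch on an idealised decay curve, not on a real RIR. The sample rate, the 0.5 s ground-truth value, and the use of `np.polyfit` for the line fit are all assumptions of this example.

```python
import numpy as np

sr = 16_000
true_rt60_sec = 0.5
n = np.arange(sr)                                  # sample indices for 1 second

# Idealised decay: power drops by exactly 60 dB every true_rt60_sec seconds
power_db = -60.0 * (n / sr) / true_rt60_sec

# Fit a line; the slope is measured in dB per sample
slope_db_per_sample, intercept_db = np.polyfit(n, power_db, deg=1)

# Samples needed for a 60 dB drop, converted to seconds
estimated_rt60_sec = -60.0 / slope_db_per_sample / sr
```

On a real RIR the fit should only cover the linear part of the decay, after the initial abrupt drop and before the noise floor.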

**Let's fit a linear regression estimator:**

In [None]:
from sklearn.linear_model import LinearRegression

linear_regression = LinearRegression()

linear_trend_start = 500
linear_trend_end = 8000

x = np.arange(linear_trend_end)[linear_trend_start: linear_trend_end]
y = win_power_db[linear_trend_start: linear_trend_end]

linear_regression.fit(x[:, None], y);
linear_fit = linear_regression.predict(x[:, None])

In [None]:
_, ax = plt.subplots()

ax.set_title("RIR power decay")
ax.set_ylabel("Power (db-fs)")
plot_rir_for_rt60(rir, ax)
ax.plot(x, linear_fit)
ax.grid()

plt.show()

In [None]:
coef = linear_regression.coef_.item()
# intercept = linear_regression.intercept_
rt_60_sec = # your code: use coef to evaluate rt60

assert abs(rt_60_sec - 0.7) < 0.02, rt_60_sec

**RIR decay**

To prepare targets for partial dereverberation, the RIR is decayed.

How we will do it:

1. Find the argmax of the RIR and keep the next 20 ms, as well as the part before the argmax, unchanged. This part of the RIR corresponds to direct sound and early reverberation.

2. The remaining part should be decayed exponentially, at -60 dB per 0.3 s.

This is something between the way it is done in [PoCoNet](https://arxiv.org/pdf/2008.04470.pdf) and [Cruse](https://arxiv.org/pdf/2101.09249.pdf).

In [None]:
def decay_rir(rir: np.ndarray, decay_rt_60_sec = 0.3, sr=SR) -> np.ndarray:
    """
    Decays a RIR as described above
    """
    main_tap = np.argmax(rir).item()
    early_reverb_duration_sec = 0.020
    early_reverb_duration_frames = int(early_reverb_duration_sec * sr)

    # your code
    return rir_decayed


rir_decayed = decay_rir(rir)

In [None]:
_, ax = plt.subplots()

ax.set_title("RIR power decay")
ax.set_ylabel("Power (db-fs)")
plot_rir_for_rt60(rir, ax)
plot_rir_for_rt60(rir_decayed, ax)

ax.set_ylim(-120, -23)
x_ticks = ax.get_xticks()
x_tick_labels = x_ticks / SR
ax.set_xticks(x_ticks, x_tick_labels)
ax.grid()

plt.show()

We don't provide an assertion test here, but this is what it should look like.

Pay attention to your decay rate, it should not deviate too much.

![title](assets/pictures/rir_decayed.png)

# 3. On-the-fly data generation [5 points]

**Why?**

We are going to train a model on synthetic mixtures of signals, noises with acoustics simulation via RIR convolution.

Why do we train a model on synthetic data? It is the most straightforward way to obtain corresponding (mixture, signal) pairs.

It can seem natural to simulate all the data in advance and train on it.

But we shall take another approach: we will generate training mixtures on-the-fly. Data will be generated in parallel with forward-backward passes on GPU -- thanks to PyTorch's DataLoader class.

By generating data on the fly we can both increase training data diversity and save disk space.

In [None]:
from vqe.data.sampling import list_wavs_in_folder, SignalSampler, RirSampler

**Efficient audio chunk reading:**

We are going to train on fixed-length chunks of audio.

A naive approach to reading a chunk of an audio file would be to read the full file and then crop it.

But we can do better.

The `sf.read` function accepts `start` and `stop` arguments. When they are provided, `audio[start: stop]` is read directly.

In [None]:
rel_path = "read_speech/book_00002_chp_0005_reader_11980_15_seg_1.wav"
path = os.path.join(DATA_PATHS["speech"], rel_path)

x, _sr = sf.read(path)
print(f"total duration: {len(x)} frames or {len(x) / SR} seconds")

crop_size_sec = 1
crop_size_frames = int(crop_size_sec * SR)

start = 16_000
stop = start + crop_size_frames

In [None]:
%%timeit

x = sf.read(path)[0][start: stop]

In [None]:
%%timeit

x, _sr = sf.read(path, start=start, stop=stop)

**Let's implement a class that will read raw signal and noise data:**

In [None]:
class SignalSampler:
    def __init__(
        self,
        paths: list[str],
        crop_size_sec: float = 5.0,
        min_rms_db: float | None = -38,
        sr: int = SR,
    ) -> None:
        """
        paths: list of absolute paths to the files we are going to sample from
        crop_size_sec: the size of generated chunks, in seconds
        min_rms_db: chunks with RMS lower than this should be discarded
        sr: samplerate
        """
        self.paths = paths
        self.crop_size_frames = int(crop_size_sec * sr)
        self.min_rms_db = min_rms_db
        self.sr = sr

    def _sample_from_single_file(
        self, path: str, crop_size_frames: int | None = None
    ) -> np.ndarray:
        """
        Reads a random crop of size crop_size_frames from path.
        If the file is shorter, reads the full file.
           
        Use sf.read(..., start=start_index, stop=end_index) to read
        chunks efficiently.
        """
        if crop_size_frames is None:
            crop_size_frames = self.crop_size_frames
        with sf.SoundFile(path) as f:
            if f.samplerate != self.sr:
                assert False, (path, f.samplerate, self.sr)
            file_duration_frames: int = f.frames
        if file_duration_frames < crop_size_frames:
            # your code
            return # your code
        # your code; keep in mind the corner case when file_duration_frames == crop_size_frames
        x, _sr = sf.read(path, start=..., stop=...)
        return x

    def __call__(self) -> np.ndarray:
        """
        Generates a chunk of audio data of length self.crop_size_frames.
        
        1. Samples a random file from self.paths
        
        2. Reads its random crop of the target size (initialized as self.crop_size_frames).
           If the file is shorter, reads the full file.
           <this should be done in self._sample_from_single_file>
           
        3. Checks RMS of the crop.
           The crop is discarded if its rms is lower than self.min_rms_db.
           Otherwise it is accumulated
           
        4. Returns the concatenation of accumulated crops
           if their total length reaches self.crop_size_frames.
           Otherwise sets target size (for 2) to duration_frames_remaining and repeats 1-4
        """
        chunks: list[np.ndarray] = []
        duration_frames_remaining = self.crop_size_frames
        while duration_frames_remaining > 0:
            path = # your code: sample a random path
            chunk = self._sample_from_single_file(path, duration_frames_remaining)
            if self.min_rms_db is not None:
                # your code
                
                if chunk_rms_db < self.min_rms_db:
                    continue
            # your code
        result = np.concatenate(chunks)

        assert result.ndim == 1, result.shape
        assert len(result) == self.crop_size_frames

        return result

In [None]:
for crop_size_sec in [1, 3]:
    for key in ["speech", "noise"]:
        print(f"crop_size_sec: {crop_size_sec}, data: {key}")
        sampler = SignalSampler(
            list_wavs_in_folder_recursively(DATA_PATHS[key]),
            crop_size_sec=crop_size_sec
        )
        n_samples = 1000
        for _ in tqdm(range(n_samples)):
            chunk = sampler()
            assert chunk.ndim == 1, chunk.shape
            assert len(chunk) == crop_size_sec * SR, chunk.shape
        print("Ok")

**Now let's define a similar sampler for RIRs.**

This guy is simpler, because it does not read chunks: it should read full RIRs.

In [None]:
class RirSampler:
    def __init__(self, paths: list[str], sr: int = SR) -> None:
        """
        paths: list of absolute paths to the files we are going to sample from
        sr: samplerate
        """
        self.paths = paths
        self.sr = sr

    def __call__(self) -> np.ndarray:
        """
        Samples a random path and reads the full audio file from it
        """
        # your code
        assert sr == self.sr, (path, sr, self.sr)
        return rir

In [None]:
sampler = RirSampler(list_wavs_in_folder_recursively(DATA_PATHS["rir"]))
n_samples = 1000
for _ in tqdm(range(n_samples)):
    rir = sampler()
    assert rir.ndim == 1, rir.shape
print("Ok")

**We have learnt to sample RIRs and chunks of signal/noise. Now's the time to learn how to mix them.**

In [None]:
import typing as tp


def convolve_same_length(x: np.ndarray, rir: np.ndarray) -> np.ndarray:
    """
    Convolves signal with rir and crops the result to have the original shape
    """
    convolved = sig.convolve(x, rir, mode="full")
    result = convolved[: len(x)]  # we crop out the reverb-only segment
    return result


class RandomMixtureSampler:
    """
    Inspired by PoCoNet: https://arxiv.org/pdf/2008.04470.pdf
    """

    def __init__(
        self,
        sig_sampler: tp.Callable[[], np.ndarray],  # SignalSampler
        noise_sampler: tp.Callable[[], np.ndarray],  # SignalSampler
        rir_sampler: tp.Callable[[], np.ndarray],  # RirSampler
        prob_rir_sig: float = 0.5,  # prob to convolve signal with a RIR
        prob_rir_noise: float = 0.5,  # prob to convolve noise with a RIR
        normalization_rms_db: float = -20,  # normalization level used for signal and noise before gain
        noise_gain_range_db: tuple[float, float] = (-5, 5),  # gain applied to noise
        mixture_gain_range_db: tuple[float, float] = (-25, 5),  # gain applied to final mixture
        partial_dereverb: bool = True,  # whether to do partial dereverberation
        *,
        sr: int = 16_000,  # samplerate
    ) -> None:
        self.sig_sampler = sig_sampler
        self.noise_sampler = noise_sampler
        self.rir_sampler = rir_sampler
        self.prob_rir_sig = prob_rir_sig
        self.prob_rir_noise = prob_rir_noise
        self.sr = sr

        self.normalization_rms_db = normalization_rms_db
        self.noise_gain_range_db = noise_gain_range_db
        self.mixture_gain_range_db = mixture_gain_range_db

        self.partial_dereverb = partial_dereverb

    def sample_noise_rms_db(self) -> float:
        """
        Samples the rms_db for noise which is relevant before mixing with signal.

        Noise is first normalized to normalization_rms_db
        and then gained by Uniform(*self.noise_gain_range_db).

        These 2 operations can be implemented as a single normalize_to_rms operation
        with the final rms.
        """
        raise NotImplementedError("your code")

    def sample_mixture_gain(self) -> float:
        """
        Uniform(*self.mixture_gain_range_db)
        """
        raise NotImplementedError("your code")

    def __call__(self) -> tuple[np.ndarray, np.ndarray]:
        signal = self.sig_sampler()
        noise = self.noise_sampler()
        if np.random.binomial(1, self.prob_rir_sig):
            rir_signal = # your code
            signal_input = # your code: signal which will be part of the input, not the target
            if self.partial_dereverb:
                rir_signal_decayed = # your code
                signal_target = # your code: target signal
            else:
                signal_target = np.copy(
                    signal_input
                )  # np.copy is crucial to avoid double scaling
        else:
            signal_input = signal
            signal_target = np.copy(
                signal
            )  # np.copy is crucial to avoid double scaling
        del signal
        if np.random.binomial(1, self.prob_rir_noise):
            rir_noise = # your code
            noise =  # your code

        # signal_input and signal_target should be multiplied by the same factor to match each other
        # (normalize_to_rms returns a scaled array, so the scalar factor is derived from the gain in dB)
        mult_signal = gain_to_mult(
            self.normalization_rms_db - eval_rms_db(signal_target)
        )
        signal_input *=  # your code
        signal_target *=  # your code

        noise_rms_db = self.sample_noise_rms_db()
        mult_noise = gain_to_mult(noise_rms_db - eval_rms_db(noise))
        noise *=  # your code

        mixture = signal_input + noise

        mixture_gain_db = self.sample_mixture_gain()
        mixture_mult =  # your code

        mixture *=  # your code
        signal_target *=  # your code: target should be scaled with mixture for them to match each other

        mixture = mixture.astype(np.float32)
        signal_target = signal_target.astype(np.float32)

        return mixture, signal_target

In [None]:
from vqe.data.mixing import RandomMixtureSampler

In [None]:
for crop_size_sec in [1, 5]:
    print(f"crop_size_sec: {crop_size_sec}")
    sampler = RandomMixtureSampler(
        sig_sampler=SignalSampler(
            list_wavs_in_folder_recursively(DATA_PATHS["speech"]),
            crop_size_sec=crop_size_sec
        ),
        noise_sampler=SignalSampler(
            list_wavs_in_folder_recursively(DATA_PATHS["noise"]),
            crop_size_sec=crop_size_sec
        ),
        rir_sampler=RirSampler(
            list_wavs_in_folder_recursively(DATA_PATHS["rir"])
        ),
    )
    n_samples = 1000
    for _ in tqdm(range(n_samples)):
        mixture, signal = sampler()
        assert len(mixture) == len(signal) == crop_size_sec * SR, (mixture.shape, signal.shape)
    print("Ok")

**Sanity check**

It is easy to introduce bugs when handling gains. Here is a simple way to check for them.

We should listen to the mixture, the signal and their difference (mixture - signal).

The difference is the sum of noise and late reverberation. No distinct speech should remain in it.

An even simpler sanity check: turn partial dereverberation off in the sampler. Then the difference should be just the noise, and it should not contain any trace of the speech signal.

Let's generate some tracks in 2 modes:
1. No-dereverb
2. Full

In [None]:
def show_samples(sampler, n_samples=4):
    _, axes = plt.subplots(ncols=3, nrows=n_samples, figsize=(16, 10))
    for sample_idx in range(n_samples):
        mixture, target = sampler()
        interference = mixture - target

        spec_mixture = build_spec_for_plot(mixture)
        spec_target = build_spec_for_plot(target)
        spec_interf = build_spec_for_plot(interference)

        ax = axes[sample_idx][0]
        ax.imshow(spec_mixture)
        ax.set_aspect("auto")

        ax = axes[sample_idx][1]
        ax.imshow(spec_target)
        ax.set_aspect("auto")

        ax = axes[sample_idx][2]
        ax.imshow(spec_interf)
        ax.set_aspect("auto")
    plt.show()

In [None]:
crop_size_sec = 5

sig_sampler = SignalSampler(
    list_wavs_in_folder_recursively(DATA_PATHS["speech"]),
    crop_size_sec=crop_size_sec
)
noise_sampler = SignalSampler(
    list_wavs_in_folder_recursively(DATA_PATHS["noise"]),
    crop_size_sec=crop_size_sec
)
rir_sampler = RirSampler(list_wavs_in_folder_recursively(DATA_PATHS["rir"]))

print("dereverb off:")
sampler = RandomMixtureSampler(
    sig_sampler=sig_sampler,
    noise_sampler=noise_sampler,
    rir_sampler=rir_sampler,
    partial_dereverb=False,
)
show_samples(sampler)
print("*" * 50)

print("Full")
sampler = RandomMixtureSampler(
    sig_sampler=sig_sampler,
    noise_sampler=noise_sampler,
    rir_sampler=rir_sampler,
)
show_samples(sampler)

**Wrapping our sampler to PyTorch Dataset**

On `__getitem__` it will ignore the input index and return a sample from the sampler.
We also define a `dummy_duration` variable which will simulate the size of a dataset.

In [None]:
import torch
import torch.utils.data as Data


class Dataset(Data.Dataset):
    def __init__(self, sampler: RandomMixtureSampler, dummy_duration: int):
        self.sampler = sampler
        self.dummy_duration = dummy_duration

    def __len__(self) -> int:
        return self.dummy_duration

    def __getitem__(self, index):
        """
        Ignores index and returns a sample from self.sampler, converted to float32
        """
        mixture, target = self.sampler()
        mixture = mixture.astype(np.float32)
        target = target.astype(np.float32)
        return mixture, target

    
sampler = RandomMixtureSampler(
    sig_sampler=sig_sampler,
    noise_sampler=noise_sampler,
    rir_sampler=rir_sampler,
)
dataset = Dataset(sampler, 100_000)
dataset[0]

Now we can use PyTorch DataLoader with our sampler, which should be really fast. If the throughput is higher than 10 batches per second (note that it outputs batches, not single samples), it is more than enough.

Pay attention to the `num_workers` parameter.

In [None]:
def worker_init_fn(worker_id):
    """setting different numpy seeds for different workers"""
    np.random.seed(np.random.get_state()[1][0] + worker_id)
    
    
loader = Data.DataLoader(
    dataset, batch_size=10, num_workers=8,
    worker_init_fn=worker_init_fn
)

for idx, batch in enumerate(tqdm(loader)):
    if idx == 100:
        break

**Data is ready**

# 4. Neural Network Architecture [3 points]

**Our network will be a 2D UNet operating in STFT domain.**

It will implement the Complex Spectral Mapping scheme (i.e. complex spectrum input -- complex spectrum output) with 2 decoders.

In [None]:
import torch
import torch.nn as nn
import torchaudio as tha

Let's implement [causal convolution](https://paperswithcode.com/method/causal-convolution).

For a moment we may think of it as a 1D convolution.

Imagine we have an input: [1, 2, 3, 4, 5, 6, 7]
and our kernel size is 3, let's say the kernel is [1/3, 1/3, 1/3].

How will convolution process the input without padding? The input will be chunked into frames: [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]] and for each frame the dot product with the kernel will be calculated, resulting in
[(1/3 + 2/3 + 3/3), (2/3 + 3/3 + 4/3), (3/3 + 4/3 + 5/3), (4/3 + 5/3 + 6/3), (5/3 + 6/3 + 7/3)] = [2, 3, 4, 5, 6].

With a causal convolution we want every i-th chunk to end exactly on position i, i.e. we want the chunks to be:
[[\*, \*, 1], [\*, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]] where the stars mean it's unclear what to place on their positions.

These chunks correspond to a convolution without padding applied to the following input: [\*, \*, 1, 2, 3, 4, 5, 6, 7].
This is exactly a convolution with padding equal to (kernel_size - 1, 0), i.e. kernel_size - 1 positions on the left and none on the right. For the stars, let's use zero padding.

For a 2D convolution, the idea does not change, because it only needs to be causal in 1 dimension, the one that corresponds to time.
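The moving-average example above can be checked numerically. The following is a minimal pure-Python sketch, not the PyTorch layer requested below; the helper names `conv1d_valid` and `causal_conv1d` are made up for this illustration, and exact fractions are used to avoid rounding noise.

```python
from fractions import Fraction


def conv1d_valid(signal, kernel):
    """Sliding dot product without padding (what deep-learning libraries call convolution)."""
    k = len(kernel)
    return [
        sum(s * w for s, w in zip(signal[i:i + k], kernel))
        for i in range(len(signal) - k + 1)
    ]


def causal_conv1d(signal, kernel):
    """Left-pad with kernel_size - 1 zeros, then apply the unpadded convolution."""
    padded = [0] * (len(kernel) - 1) + list(signal)
    return conv1d_valid(padded, kernel)


signal = [1, 2, 3, 4, 5, 6, 7]
kernel = [Fraction(1, 3)] * 3

valid = conv1d_valid(signal, kernel)    # 5 outputs, one per full frame
causal = causal_conv1d(signal, kernel)  # 7 outputs, output i sees only inputs <= i

# Causality check: changing a future sample must not change earlier outputs.
changed = list(signal)
changed[5] = 100
causal_changed = causal_conv1d(changed, kernel)
print(valid, causal, causal_changed[:5] == causal[:5])
```

The causal output has the same length as the input, and its first two values are computed partly from the zero padding. With a dilated kernel the same idea applies, except that the left padding grows to (kernel_size - 1) * dilation.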

In [None]:
class CausalConv1D(nn.Sequential):
    """
    A sequential of ConstantPad1d and Conv1d.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
    ):
        # your code
        padding_layer = # your code
        conv_layer = # your code
        super().__init__(
            padding_layer,
            conv_layer,
        )

Let's implement the Double Modified Gated TCM (the one with Primal and Dual domains; note that it also employs a residual connection, which is not shown in the picture) as it is defined in [the paper](https://arxiv.org/pdf/2011.01561.pdf).

<p float="left">
  <img src="assets/pictures/tcm.png" alt="MgTcm" width="200"/>
  <img src="assets/pictures/tcm_double.png" alt="DMgTcm" width="400"/>
</p>

In [None]:
class GatedBranch(nn.Module):
    def __init__(self, dilation: int, n_channels: int = 64, kernel_size: int = 5):
        super().__init__()
        # your code


class DoubleModifiedGatedTcm(nn.Module):
    """
    MgTcm here
    https://arxiv.org/pdf/2011.01561.pdf
    """
    def __init__(self, dilation: int, c_in: int = 256, c_hidden: int = 64):
        super().__init__()
        # your code

Let's build a 2D UNet with the following parameters:
- num encoder-decoder layer sets: 5
- num decoders: 2: 1 for real part and 1 for imaginary part of complex spectrogram
- for skip connections let's use addition through a point-wise conv (i.e. kernel_size=1) instead of concatenation. 1 point-wise conv for each of the 2 decoders.
- kernel_size: (1, 4): 1 for the time dimension and 4 for the frequency dimension. With a time dimension size of 1, a convolution will always be causal.
- stride: (1, 2): 1 for time, 2 for frequencies. A causal model may not use stride != 1 in the time dimension.
- num out channels: 64 the final output layers
- middle: TCM (i.e. reshape 2D to 1D, compressing channels and frequencies into a single dimension, linear projection, TCM, linear projection, reshape 1D to 2D)

TCM parameters:

- dilations: (1, 2, 4, 8, 16, 32) x 2
- layer: Double Modified Gated TCM
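Before writing the UNet, it can help to trace how the frequency axis shrinks under the settings above (kernel size 4, stride 2, five layers, 161 input bins). The sketch below uses the standard output-size formulas for a strided convolution and a transposed convolution with dilation 1. It assumes no padding along frequency, which the assignment does not specify, so treat the numbers as one possible configuration rather than the required one.

```python
def conv_out_size(n, kernel=4, stride=2, padding=0):
    """Output length of a strided convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out_size(n, kernel=4, stride=2, padding=0, output_padding=0):
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: frequency size after each of the 5 layers (assumption: padding=0).
encoder_sizes = [161]
for _ in range(5):
    encoder_sizes.append(conv_out_size(encoder_sizes[-1]))

# Decoder: extra output_padding each layer would need to mirror the encoder exactly.
reversed_sizes = encoder_sizes[::-1]
output_paddings = [
    n_target - conv_transpose_out_size(n_in)
    for n_in, n_target in zip(reversed_sizes[:-1], reversed_sizes[1:])
]

print(encoder_sizes)
print(output_paddings)
```

Under this assumption the last two decoder layers come out one bin short unless an output padding of 1 is used. This matters because additive skip connections require the encoder and decoder feature maps to have identical shapes.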


In [None]:
class UNetEngine(nn.Module):
    """
    Input: complex spectrum of shape (batch_size, 2, time, n_frequencies=161)
    Output: complex spectrum of the same shape as the input
    """
    # your code
    
    def forward(self, spec):
        assert spec.ndim == 4, spec.shape
        assert spec.shape[1] == 2, (spec.shape, "complex")
        # your code

In [None]:
model = UNetEngine()

x = torch.rand(3, 2, 298, 161)
out = model(x)
assert x.shape == out.shape, (x.shape, out.shape)

In [None]:
class ComplexSpectumMappingModel(nn.Module):
    def __init__(self, engine):
        super().__init__()
        self.stft = tha.transforms.Spectrogram(
            n_fft=320,
            win_length=320,
            hop_length=160,
            power=None,
            window_fn=torch.hann_window,
        )

        self.istft = tha.transforms.InverseSpectrogram(
            n_fft=320,
            win_length=320,
            hop_length=160,
            window_fn=torch.hann_window,
        )
        self.engine = engine
        
    def forward(self, waveform):
        assert waveform.ndim == 2, waveform.shape
        spec = self.stft(waveform)  # complex
        spec_ri = torch.view_as_real(spec)  # (b, f, t, 2)
        model_input = torch.permute(spec_ri, [0, 3, 2, 1])  # (b, 2, t, f)
        model_output = self.engine(model_input)
        # (b, 2, t, f) -> (b, f, t, 2)
        spec_enhanced = torch.permute(model_output, [0, 3, 2, 1]).contiguous()
        spec_enhanced = torch.view_as_complex(spec_enhanced)
        wave_enhanced = self.istft(spec_enhanced, length=waveform.shape[-1])
        return wave_enhanced

In [None]:
def build_model():
    engine = UNetEngine()
    model = ComplexSpectumMappingModel(engine)
    return model

In [None]:
model = build_model()

x = torch.rand(3, 48_000)
out = model(x)

assert x.shape == out.shape
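The STFT settings above (n_fft=320, hop_length=160) determine the spectrogram shape that the engine receives. The arithmetic below is a small sketch assuming the 16 kHz sample rate of the prepared data, a one-sided spectrum, and centered frames, which are the torchaudio defaults; `stft_shape` is a hypothetical helper name used only here.

```python
SR = 16_000   # assumption: sample rate of the prepared dataset
N_FFT = 320
HOP = 160


def stft_shape(n_samples, n_fft=N_FFT, hop=HOP):
    """Return (n_frequencies, n_frames) for a one-sided, centered STFT."""
    n_freq = n_fft // 2 + 1
    n_frames = 1 + n_samples // hop
    return n_freq, n_frames


n_freq, n_frames = stft_shape(48_000)  # 3 seconds of audio
window_ms = 1000 * N_FFT / SR
hop_ms = 1000 * HOP / SR
print(n_freq, n_frames, window_ms, hop_ms)
```

This gives the 161 frequency bins mentioned in the engine docstring, with a 20 ms window and a 10 ms hop.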

# 5. Loss function [1 point]

In [None]:
class SpectralLoss(nn.Module):
    """
    Operates on waveforms.
    
    Computes spectrograms from them and evaluates complex spectral and magnitude losses.
    """
    def __init__(
        self,
        n_fft=1024,
        win_size: int | None = None,
        hop_size: int | None = None,
        mult_complex: float = 0.3,
        criterion: (
            tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None
        ) = None,
        window_fn=torch.hann_window,
    ):
        super().__init__()
        self.stft = tha.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_size,
            hop_length=hop_size,
            window_fn=window_fn,
            power=None,
        )
        if criterion is None:
            criterion = nn.MSELoss()
        self.criterion = criterion
        assert 0 <= mult_complex <= 1, mult_complex
        self.mult_complex = mult_complex

    def forward(self, waveform_est: torch.Tensor, waveform_target: torch.Tensor):
        """
        Apply self.stft to est and target to get complex spectrograms.
        Compute self.criterion on the complex spectrograms.
        
        Compute magnitude spectrograms and compute self.criterion on them.
        
        The final loss should be the sum of magnitude and complex parts.
        """
        assert waveform_est.ndim == 2, (waveform_est.shape, "batched")
        assert waveform_target.ndim == 2, (waveform_target.shape, "batched")
        # your code
        return loss_final, loss_complex, loss_magnitude

In [None]:
criterion = SpectralLoss()

est = torch.rand(3, 48_000)
target = torch.rand(3, 48_000)

criterion(est, target)
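To see why the loss above has both a complex and a magnitude part, consider a single spectrogram bin. The toy sketch below uses Python's built-in complex numbers rather than tensors, and the two helper names are made up for this illustration. It shows that a magnitude-only error cannot see a phase mistake, while the complex error can.

```python
def squared_complex_error(est, target):
    """Squared distance between two complex values (sensitive to phase)."""
    return abs(est - target) ** 2


def squared_magnitude_error(est, target):
    """Squared distance between the magnitudes only (blind to phase)."""
    return (abs(est) - abs(target)) ** 2


target_bin = 1 + 0j
est_bin = 0 + 1j  # same magnitude as the target, phase rotated by 90 degrees

complex_err = squared_complex_error(est_bin, target_bin)
magnitude_err = squared_magnitude_error(est_bin, target_bin)
print(complex_err, magnitude_err)
```

The magnitude error is zero although the estimate is clearly wrong, which is one reason to combine the two terms.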

# 6. Train loop [3 points]

Now everything is ready. Let's check if our model can learn something.

General parameters

In [None]:
NUM_WORKERS = 2  # parallel data generation

LR = 2e-4
BATCH_SIZE = 10  # the bigger the better
MAX_GRAD_NORM = 4  # for clipping


if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    PIN_MEMORY = True
else:
    DEVICE = torch.device("cpu")
    PIN_MEMORY = False

Metrics

In [None]:
from torchmetrics.audio import SignalNoiseRatio, ScaleInvariantSignalNoiseRatio


metrics = torch.nn.ModuleDict(
    {
        "SNR": SignalNoiseRatio(),
        "SI-SNR": ScaleInvariantSignalNoiseRatio(),
    }
).to(DEVICE)

Data

In [None]:
from sklearn.model_selection import train_test_split


SNR_RANGE = (-5, 5)  # range for SNR in data generation
CROP_SIZE_SEC = 5


test_size = 0.1
split_speech = train_test_split(
    list_wavs_in_folder_recursively(DATA_PATHS["speech"]),
    random_state=2967,
    test_size=test_size,
)
split_noise = train_test_split(
    list_wavs_in_folder_recursively(DATA_PATHS["noise"]),
    random_state=8701,
    test_size=test_size,
)
split_rir = train_test_split(
    list_wavs_in_folder_recursively(DATA_PATHS["rir"]),
    random_state=9807,
    test_size=test_size,
)


loaders = {}
for split_idx, mode in enumerate(["train", "val"]):
    sig_sampler = SignalSampler(split_speech[split_idx], crop_size_sec=CROP_SIZE_SEC)
    noise_sampler = SignalSampler(split_noise[split_idx], crop_size_sec=CROP_SIZE_SEC)
    rir_sampler = RirSampler(split_rir[split_idx])
    mixture_sampler = RandomMixtureSampler(
        sig_sampler=sig_sampler,
        noise_sampler=noise_sampler,
        rir_sampler=rir_sampler,
        prob_rir_noise=0.5,
        prob_rir_sig=0.5,
        normalization_rms_db=-20,
        noise_gain_range_db=(-SNR_RANGE[1], -SNR_RANGE[0]),
        mixture_gain_range_db=(-25, 5),
        sr=SR,
        partial_dereverb=True,
    )
    batches_per_epoch = 1024 if mode == "train" else 256
    dataset = Dataset(mixture_sampler, dummy_duration=BATCH_SIZE * batches_per_epoch)
    loader = Data.DataLoader(
        dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        worker_init_fn=worker_init_fn
    )
    loaders[mode] = loader

Model, optimizer, criterion

In [None]:
model = build_model().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = SpectralLoss()

Tensorboard logger:

In [None]:
from torch.utils.tensorboard import SummaryWriter

SAVE_SOUND_FREQ = 512
LOG_FREQ = 4
SAVE_SNAPSHOT_FREQ = 512

writer = SummaryWriter()

In [None]:
def run_epoch(loader, training: bool, global_step_idx: int = 0):
    mode_name = "train" if training else "val"
    
    for evaluator in metrics.values():
        evaluator.reset()
    loss_storage = [[] for _ in range(3)]
    for step_idx, (mixture, target) in enumerate(tqdm(loader)):
        mixture = mixture.to(DEVICE)
        target = target.to(DEVICE)
        est = # your code: run the model

        loss_components = criterion(est, target)
        loss = loss_components[0]

        for storage, component in zip(loss_storage, loss_components):
            storage.append(component.item())
            
        with torch.no_grad():
            for name, evaluator in metrics.items():
                value_out = evaluator(est, target).mean().item()
            
        if training:
            optimizer.zero_grad()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), MAX_GRAD_NORM
            )
            optimizer.step()

        if (step_idx % LOG_FREQ == 0) and training:
            loss_logs = [sum(x) / len(x) for x in loss_storage]
            writer.add_scalar(f"{mode_name}/loss/final", loss_logs[0], global_step_idx)
            writer.add_scalar(f"{mode_name}/loss/complex", loss_logs[1], global_step_idx)
            writer.add_scalar(f"{mode_name}/loss/magnitude", loss_logs[2], global_step_idx)
            loss_storage = [[] for _ in range(3)]

            writer.add_scalar(f"{mode_name}/grad_norm", grad_norm.item(), global_step_idx)

            for name, evaluator in metrics.items():
                value_out = evaluator.compute().item()
                writer.add_scalar(f"{mode_name}/metrics/{name}", value_out, global_step_idx,)
                evaluator.reset()

        if step_idx % SAVE_SOUND_FREQ == 0:
            with torch.no_grad():
                path_samples = f"samples/{mode_name}"
                os.makedirs(path_samples, exist_ok=True)
                sf.write(
                    f"{path_samples}/{global_step_idx:05d}_mixture.wav",
                    mixture[0].cpu().numpy(),
                    SR,
                )
                sf.write(
                    f"{path_samples}/{global_step_idx:05d}_clean.wav",
                    target[0].cpu().numpy(),
                    SR,
                )
                sf.write(
                    f"{path_samples}/{global_step_idx:05d}_noise.wav",
                    mixture[0].cpu().numpy() - target[0].cpu().numpy(),
                    SR,
                )
                sf.write(
                    f"{path_samples}/{global_step_idx:05d}_est.wav",
                    est[0].cpu().numpy(),
                    SR,
                )
        if step_idx % SAVE_SNAPSHOT_FREQ == 0 and training:
            torch.save(model.state_dict(), "state_dict_latest.pt")
        if training:
            global_step_idx += 1
            
    if not training:
        for name, evaluator in metrics.items():
            value_out = evaluator.compute().item()
            print(f"{mode_name}/metrics/{name}", value_out)
            writer.add_scalar(f"{mode_name}/metrics/{name}", value_out, global_step_idx,)
        
    return global_step_idx

Train loop itself

In [None]:
global_step_idx = 0
while True:
    global_step_idx = run_epoch(loaders["train"], True, global_step_idx)
    with torch.no_grad():
        run_epoch(loaders["val"], False, global_step_idx)

### What to expect?

SNR and SI-SNR should go up on both train and validation. They should reach 1-2 dB in a couple of dozen minutes. In 8 hours on a 1080ti GPU they should reach around 10 dB. The model starts with random predictions, so the initial growth is noticeable from the first minutes; however, this alone does not mean that the model is implemented correctly.
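For intuition about the two metrics, here is a minimal pure-Python sketch of their usual definitions on plain lists of samples. It is an illustration, not the torchmetrics implementation; library versions may differ in details such as removing the mean before computing SI-SNR. The sample values are arbitrary.

```python
import math


def snr_db(est, target):
    """Signal-to-noise ratio in dB: target power over error power."""
    signal_power = sum(t * t for t in target)
    noise_power = sum((e - t) ** 2 for e, t in zip(est, target))
    return 10 * math.log10(signal_power / noise_power)


def si_snr_db(est, target):
    """Scale-invariant SNR in dB: project est onto target, compare with the residual."""
    scale = sum(e * t for e, t in zip(est, target)) / sum(t * t for t in target)
    projection = [scale * t for t in target]
    residual = [e - p for e, p in zip(est, projection)]
    return 10 * math.log10(
        sum(p * p for p in projection) / sum(r * r for r in residual)
    )


target = [1.0, -1.0, 2.0, -2.0]
est = [1.1, -0.9, 1.8, -2.1]
quiet_est = [0.5 * e for e in est]  # same estimate at half the volume

print(snr_db(est, target), snr_db(quiet_est, target))
print(si_snr_db(est, target), si_snr_db(quiet_est, target))
```

Halving the volume of the estimate lowers SNR but leaves SI-SNR unchanged, so a gap between the two curves in TensorBoard may simply indicate a level mismatch between output and target.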

### Scoring
An SI-SNR of 5 dB on the train set must be exceeded to complete the training loop part. Points for the main loop will be given based on the final SI-SNR. 10 dB on the validation set will definitely result in full points.

### What if the hardware is not powerful enough?
Colab instances should be powerful enough. However, if that doesn't work for you, please try tuning the parameters and making the network smaller. If the models are still too big, try a [grouped RNN](https://arxiv.org/pdf/2101.09249.pdf). [This model](https://arxiv.org/pdf/2008.06412.pdf) should be cheaper in computation (at the cost of worse quality). Finally, there is a bonus task for streaming inference ahead.

# 7. Bonus: streaming inference [3 points]

The model you implemented above should be streaming-friendly.

During classwork we implemented streaming STFT and ISTFT with a toy VQE. In this assignment you are required to implement streaming end-to-end inference for your network, i.e. take 10-ms chunks of audio as input and output 10-ms chunks of enhanced audio.

You will probably need to re-implement your model in PyTorch and load an adapted version of the state_dict from the offline model.
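As a starting point for the chunked interface, the sketch below shows only the buffering step: at 16 kHz a 10-ms chunk is 160 samples, which equals the STFT hop, so each incoming chunk can be joined with the previous one to form a 320-sample analysis frame. `StreamingFramer` and `iter_chunks` are hypothetical names, the zero initial history is an assumption, and windowing, FFT, the network and overlap-add are deliberately left out.

```python
SR = 16_000        # assumption: sample rate of the prepared dataset
CHUNK = SR // 100  # 160 samples = 10 ms, equal to the STFT hop


class StreamingFramer:
    """Keeps the previous chunk so every new chunk yields one full analysis frame."""

    def __init__(self, chunk_size=CHUNK):
        self.chunk_size = chunk_size
        self.prev = [0.0] * chunk_size  # assumption: zero history before the first chunk

    def push(self, chunk):
        assert len(chunk) == self.chunk_size, len(chunk)
        frame = self.prev + list(chunk)  # 2 * chunk_size samples
        self.prev = list(chunk)
        return frame


def iter_chunks(signal, chunk_size=CHUNK):
    """Yield consecutive full chunks; a trailing partial chunk is dropped."""
    for start in range(0, len(signal) - chunk_size + 1, chunk_size):
        yield signal[start:start + chunk_size]


signal = [float(i) for i in range(3 * CHUNK)]
framer = StreamingFramer()
frames = [framer.push(chunk) for chunk in iter_chunks(signal)]
print(len(frames), len(frames[0]))
```

The same pattern of carrying state between calls extends to the causal convolutions and the TCM, where each layer needs to remember its own left context.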