# Licence

In [None]:
https://github.com/seorim0/Speech_enhancement_with_Pytorch

MIT License

Copyright (c) 2021 seorim0

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

# Requirements

In [None]:
# Install any extra libraries that are not bundled with the default Colab runtime, e.g.:
# !pip install "library_name"
!pip install pesq
!pip install pystoi

Collecting pystoi
  Downloading pystoi-0.3.3.tar.gz (7.0 kB)
Building wheels for collected packages: pystoi
  Building wheel for pystoi (setup.py) ... [?25l[?25hdone
  Created wheel for pystoi: filename=pystoi-0.3.3-py2.py3-none-any.whl size=7793 sha256=4b3f5991ad63a316047c2ab9f983f4f7216a7f059fc75df4c0335555ec164e4d
  Stored in directory: /root/.cache/pip/wheels/46/4a/ad/3ab460193ed0535430b4b1575f255aa6bae69df17453628e86
Successfully built pystoi
Installing collected packages: pystoi
Successfully installed pystoi-0.3.3


In [None]:
import os
import re
import time
import math
import torch
import shutil
import logging
import numpy as np
from pesq import pesq
import torch.nn as nn
from pystoi import stoi
from scipy import interpolate
import matplotlib.pylab as plt
import scipy.io.wavfile as wav
import matplotlib.pyplot as plt
import torch.nn.functional as F
from scipy.signal import get_window
from torch.utils.data import Dataset, DataLoader

# config

In [None]:
"""
Configuration for train_interface

You can check the essential information,
and if you want to change model structure or training method,
you have to change this file.
"""
#######################################################################
#                                 path                                #
#######################################################################
job_dir = './'  # 'FILE PATH for saving models'
chkpt_model = None  # 'FILE PATH (if you have pretrained model..)'
chkpt = str("EPOCH")
# Always bind chkpt_path so later cells can test it instead of hitting a
# NameError when no pretrained model is configured.
chkpt_path = None
if chkpt_model is not None:
    chkpt_path = job_dir + chkpt_model + '/chkpt_' + chkpt + '.pt'

#######################################################################
#                         possible setting                            #
#######################################################################
# the list you can do
model_list = ['CRN']
loss_list = ['MSE', 'SDR', 'SI-SNR', 'SI-SDR']
mask_type = ['Direct(None make)', 'E', 'C', 'R']
window_type = ['hanning']

# experiment number setting
expr_num = 'EXPERIMENT_NUMBER'
DEVICE = 'cuda'  # if you want to run the code with 'cpu', change 'cpu'
#######################################################################
#                           current setting                           #
#######################################################################
current_model = model_list[0]
current_loss = loss_list[0]

masking_mode = mask_type[1]
window = window_type[0]
skip_type = True   # False, if you want to remove 'skip connection'
# The comparison already yields a bool; no conditional expression needed.
direct_mapping = masking_mode == 'Direct(None make)'

# hyper-parameters
max_epochs = 100
learning_rate = 0.001
batch = 10

# kernel size
dccrn_kernel_num = [32, 64, 128, 256, 256, 256]
#######################################################################
#                         model information                           #
#######################################################################
fs = 16000          # sampling rate in Hz
win_len = 400       # analysis window length in samples
win_inc = 100       # hop between consecutive windows in samples
ola_ratio = win_inc / win_len
fft_len = 512 #2048 #512
sam_sec = fft_len / fs
frm_samp = fs * (fft_len / fs)

rnn_layers = 2
rnn_input_size = 512
rnn_units = 128
#######################################################################
#                      setting error check                            #
#######################################################################
# if the setting is wrong, stop here with an explicit message
if win_len > fft_len:
    # init_kernels() slices the FFT basis to win_len rows, so a window
    # longer than the FFT cannot be represented.
    raise ValueError(
        'win_len ({}) must not exceed fft_len ({})'.format(win_len, fft_len)
    )

#######################################################################
#                           print setting                             #
#######################################################################
print('--------------------  C  O  N  F  I  G  ----------------------')
print('--------------------------------------------------------------')
print('MODEL INFO : {}'.format(current_model))
print('LOSS INFO : {}'.format(current_loss))
print('SKIP : {}'.format(skip_type))
print('MASKING INFO : {}'.format(masking_mode))
print('\nBATCH : {}'.format(batch))
print('LEARNING RATE : {}'.format(learning_rate))
print('--------------------------------------------------------------')
print('--------------------------------------------------------------\n')


--------------------  C  O  N  F  I  G  ----------------------
--------------------------------------------------------------
MODEL INFO : CRN
LOSS INFO : SDR
SKIP : True
MASKING INFO : E

BATCH : 10
LEARNING RATE : 0.001
--------------------------------------------------------------
--------------------------------------------------------------



# Dataloader

In [None]:
def create_dataloader(mode, type=0, snr=0):
    """Build the ``DataLoader`` for one dataset split.

    Args:
        mode: ``'train'`` or ``'valid'``.
        type: forwarded unchanged to ``Wave_Dataset`` (the name shadows the
            builtin but is kept so keyword callers keep working).
        snr: forwarded unchanged to ``Wave_Dataset``.

    Returns:
        A ``DataLoader`` using the global ``batch`` size. The training loader
        shuffles, pins memory and drops the last incomplete batch; the
        validation loader keeps the dataset order and every sample.

    Raises:
        ValueError: if ``mode`` is not one of the supported splits (the
            function previously fell through and returned ``None``).
    """
    if mode not in ('train', 'valid'):
        raise ValueError(
            "Unsupported mode {!r}; expected 'train' or 'valid'.".format(mode)
        )

    if mode == 'train':
        return DataLoader(
            dataset=Wave_Dataset(mode, type, snr),
            batch_size=batch,
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=True,
            sampler=None
        )

    return DataLoader(
        dataset=Wave_Dataset(mode, type, snr),
        batch_size=batch, shuffle=False, num_workers=0
    )

class Wave_Dataset(Dataset):
    """Dataset of (noisy, clean) waveform pairs.

    ``self.input`` is indexed as ``[sample][0]`` for the network input and
    ``[sample][1]`` for the target. The notebook ships with placeholder
    arrays of ones; replace them with ``np.load`` of a real dataset file.
    """

    def __init__(self, mode, type, snr):
        """Load the data for the requested split.

        Args:
            mode: ``'train'`` or ``'valid'``.
            type: accepted for interface compatibility; not used here.
            snr: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if ``mode`` is not a known split. Previously the
                object was created without ``self.input`` and failed later
                in ``__len__``.
        """
        if mode == 'train':
            banner, num_items = '<Training dataset>', 300
        elif mode == 'valid':
            banner, num_items = '<Validation dataset>', 50
        else:
            raise ValueError(
                "Unsupported mode {!r}; expected 'train' or 'valid'.".format(mode)
            )

        self.mode = mode
        print(banner)
        print('Load the data...')

        # Placeholder data with layout (num_items, 2, 16000).
        self.input = np.ones((num_items, 2, 16000))
        # To use a real dataset instead:
        # self.input_path = "DATASET_FILE_PATH"
        # self.input = np.load(self.input_path)
        # # if you want to use a part of the dataset
        # self.input = self.input[:500]

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.input)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` as torch tensors sharing the numpy memory."""
        pair = self.input[idx]

        # transform to torch from numpy
        inputs = torch.from_numpy(pair[0])
        targets = torch.from_numpy(pair[1])

        return inputs, targets

# Tools for model

In [None]:
############################################################################
#                         for convolutional STFT                           #
############################################################################
# this is from conv_stft https://github.com/huyanxin/DeepComplexCRN
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build the convolution kernel that implements a windowed (i)STFT.

    Args:
        win_len: window length in samples (also the kernel width).
        win_inc: hop size; accepted for interface symmetry, not used here.
        fft_len: FFT size ``N``; the kernel has ``2 * (N // 2 + 1)`` output
            channels (real parts followed by imaginary parts).
        win_type: name understood by ``scipy.signal.get_window``; ``None`` or
            the string ``'None'`` selects a rectangular window.
        invers: if True, return the pseudo-inverse kernel used for synthesis.

    Returns:
        ``(kernel, window)`` as float32 tensors with shapes
        ``(2 * (N // 2 + 1), 1, win_len)`` and ``(1, win_len, 1)``.
    """
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        # 'hanning' is a legacy spelling of the Hann window; use the canonical
        # name so the call does not depend on SciPy accepting the alias.
        if win_type == 'hanning':
            win_type = 'hann'
        window = get_window(win_type, win_len, fftbins=True)  # **0.5

    N = fft_len
    # Rows of the identity are unit impulses, so their rFFT is the DFT basis;
    # only the first win_len rows (taps) are kept.
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]  # insert the in_channels axis expected by conv1d
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    The fixed analysis kernel from ``init_kernels`` is stored as a buffer, so
    it follows the module across devices but is never trained.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        """
        Args:
            win_len: analysis window length in samples.
            win_inc: hop size in samples (the convolution stride).
            fft_len: FFT size; defaults to the next power of two >= win_len.
            win_type: window name forwarded to ``init_kernels``.
            feature_type: ``'complex'`` makes ``forward`` return the stacked
                real/imag tensor; any other value returns (magnitude, phase).
            fix: unused; kept for interface compatibility.
        """
        super(ConvSTFT, self).__init__()

        if fft_len is None:
            # Built-in int(): the np.int alias was removed in NumPy 1.24.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """Transform waveforms to the time-frequency domain.

        Args:
            inputs: tensor of shape [B, T] or [B, 1, T].

        Returns:
            For ``feature_type == 'complex'``: the convolution output with
            real channels followed by imaginary channels along dim 1.
            Otherwise: a ``(mags, phase)`` tuple.
        """
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # Zero-pad both ends by (win_len - stride) samples before framing.
        edge = self.win_len - self.stride
        inputs = F.pad(inputs, [edge, edge])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs

        dim = self.dim // 2 + 1  # number of one-sided frequency bins
        real = outputs[:, :dim, :]
        imag = outputs[:, dim:, :]
        mags = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag, real)
        return mags, phase


class ConviSTFT(nn.Module):
    """Inverse STFT implemented as a strided transposed 1-D convolution.

    The synthesis kernel, the window and an identity "enframe" kernel are
    stored as buffers; none of them is trained.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        """
        Args:
            win_len: synthesis window length in samples.
            win_inc: hop size in samples (the transposed-convolution stride).
            fft_len: FFT size; defaults to the next power of two >= win_len.
            win_type: window name forwarded to ``init_kernels``.
            feature_type: stored on the module; not read by ``forward``.
            fix: unused; kept for interface compatibility.
        """
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # Built-in int(): the np.int alias was removed in NumPy 1.24.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
        phase: [B, N//2+1, T] (if not none)

        Returns the reconstructed waveform with the edge padding removed.
        """

        if phase is not None:
            # Rebuild the stacked real/imag representation from polar form.
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)

        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        # Overlap-added squared window, used to undo the window weighting.
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)

        outputs = outputs / (coff + 1e-8)

        # Drop the (win_len - stride) samples at each end. When that amount is
        # zero the slice [0:-0] would be empty, so skip the trim in that case.
        edge = self.win_len - self.stride
        if edge > 0:
            outputs = outputs[..., edge:-edge]

        return outputs


############################################################################
#                             for complex rnn                              #
############################################################################
def get_casual_padding1d():
    """Unimplemented placeholder; takes no arguments and returns None."""


def get_casual_padding2d():
    """Unimplemented placeholder; takes no arguments and returns None."""


class cPReLU(nn.Module):
    """PReLU applied separately to the real and imaginary halves of a tensor.

    The input is split into two equal chunks along ``complex_axis``; each
    chunk has its own ``nn.PReLU`` and the results are concatenated back
    along the same axis.
    """

    def __init__(self, complex_axis=1):
        super(cPReLU, self).__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        """Activate both halves of ``inputs`` and re-join them."""
        real_half, imag_half = torch.chunk(inputs, 2, self.complex_axis)
        activated = [self.r_prelu(real_half), self.i_prelu(imag_half)]
        return torch.cat(activated, self.complex_axis)


############################################################################
#                     for real 2-D convolution layers                      #
############################################################################
class RealConv2d(nn.Module):
    """2-D convolution whose last-axis padding is applied manually.

    The wrapped ``nn.Conv2d`` pads only the first spatial axis
    (``padding[0]``). The last axis is padded in ``forward`` so that it can be
    left-only (causal) or symmetric.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=1,
            groups=1,
            causal=True,
            complex_axis=1,
    ):
        '''
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size : input [B,C,D,T] kernel size in [D,T]
            padding : input [B,C,D,T] padding in [D,T]
            causal: if causal, will padding time dimension's left side,
                    otherwise both
            complex_axis: accepted for interface compatibility; unused

        '''
        super(RealConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation

        # The T axis is padded in forward(), hence the 0 here.
        self.conv = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            padding=[self.padding[0], 0],
            dilation=self.dilation,
            groups=self.groups,
        )

        nn.init.normal_(self.conv.weight.data, std=0.05)
        nn.init.constant_(self.conv.bias, 0.)

    def forward(self, inputs):
        """Pad the last axis (left-only when causal) and convolve."""
        time_pad = self.padding[1]
        left_only = time_pad != 0 and self.causal
        right_pad = 0 if left_only else time_pad
        # F.pad order: [last-axis left, last-axis right, D top, D bottom]
        padded = F.pad(inputs, [time_pad, right_pad, 0, 0])
        return self.conv(padded)


class RealConvTranspose2d(nn.Module):
    """Thin wrapper around ``nn.ConvTranspose2d`` with fixed initialisation.

    Weights are drawn from a normal distribution (std 0.05) and the bias is
    set to zero.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            output_padding=(0, 0),
            groups=1
    ):
        '''
            in_channels: number of input channels
            out_channels: number of output channels
            kernel_size, stride, padding, output_padding, groups:
                forwarded to nn.ConvTranspose2d
        '''
        super(RealConvTranspose2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        self.conv = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size,
            self.stride,
            padding=self.padding,
            output_padding=output_padding,
            groups=self.groups,
        )

        nn.init.normal_(self.conv.weight.data, std=0.05)
        nn.init.constant_(self.conv.bias, 0.)

    def forward(self, inputs):
        """Apply the transposed convolution to ``inputs``."""
        return self.conv(inputs)


############################################################################
#                         for data normalization                           #
############################################################################
# get mu and sig
def get_mu_sig(data):
    """Compute the per-utterance mean and standard deviation of ``data``.

    Statistics are taken over axis 0 of each utterance, so an utterance of
    shape (frames, dim) yields ``dim`` values and a 1-D utterance yields a
    single value.

    Args:
        data: sequence of numpy arrays, one per utterance.

    Returns:
        mu: float16 array of per-utterance means.
        sig: float16 array of per-utterance standard deviations, where values
            below 1e-5 are replaced by 1.0 so later division is safe.
    """
    # Get mean.
    mu = [np.mean(utt, 0) for utt in data]

    # Get standard deviation.
    variances = [np.mean(np.square(utt - m), 0) for utt, m in zip(data, mu)]
    sig = np.sqrt(variances)

    # Assign unit variance. The boolean mask works for scalar and vector
    # statistics alike; the former element-wise `if sig[n] < 1e-5` raised
    # ValueError as soon as an utterance had more than one feature dimension.
    sig[sig < 1e-5] = 1.0
    return np.float16(mu), np.float16(sig)


def get_statistics_inp(inp):
    """Return the normalisation statistics of the input data.

    Args:
        inp: input data, passed straight to ``get_mu_sig``

    Returns:
        (mu_inp, sig_inp): the mean and standard deviation produced by
        ``get_mu_sig`` for ``inp``
    """
    statistics = get_mu_sig(inp)
    mu_inp, sig_inp = statistics
    return mu_inp, sig_inp


# normalize [-1 1]
def normalize_dataset(dataset):
    """Peak-normalise every noisy/clean pair to the range [-1, 1] in place.

    Args:
        dataset: indexable as ``dataset[i][0]`` (noisy) and ``dataset[i][1]``
            (clean); each entry is scaled by its own absolute maximum.

    Returns:
        The same ``dataset`` object, modified in place.

    Note:
        An all-zero (silent) signal is left untouched. Dividing it by its
        zero peak used to turn the whole signal into NaN.
    """
    for i in range(len(dataset)):
        for channel in (0, 1):  # 0: noisy, 1: clean
            peak = np.max(np.abs(dataset[i][channel]))
            if peak > 0:
                dataset[i][channel] = dataset[i][channel] / peak
    return dataset


############################################################################
#                       for plotting the samples                           #
############################################################################
def hann_window(win_samp):
    """Return a float32 Hann window of ``win_samp`` samples.

    Sample ``k`` (1-based) is ``0.5 - 0.5 * cos(2*pi*k / (win_samp + 1))``,
    so the zero-valued end points are not part of the window.
    """
    positions = np.arange(1, win_samp + 1, 1.0, dtype=np.float64)
    angle = (2.0 * np.pi * positions) / (win_samp + 1)
    return np.float32(0.5 - 0.5 * np.cos(angle))


def fig2np(fig):
    """Convert a drawn Matplotlib figure into an RGB uint8 array.

    Args:
        fig: a figure whose canvas has already been drawn
            (``fig.canvas.draw()``).

    Returns:
        A writable array of shape (height, width, 3) and dtype uint8.
    """
    canvas = fig.canvas

    # Preferred path: read the RGBA buffer and drop the alpha channel.
    # tostring_rgb() is deprecated in recent Matplotlib releases.
    if hasattr(canvas, 'buffer_rgba'):
        rgba = np.asarray(canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])

    # Fallback for canvases without buffer_rgba(). frombuffer replaces the
    # deprecated binary mode of np.fromstring; copy() keeps the result
    # writable, as it was before.
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return data.reshape(canvas.get_width_height()[::-1] + (3,)).copy()


def plot_spectrogram_to_numpy(input_wav, fs, n_fft, n_overlap, win, mode, clim, label):
    """Render the spectrogram of a waveform tensor and return it as an image.

    Args:
        input_wav: torch tensor holding the waveform (moved to CPU here).
        fs: sampling rate passed to ``plt.specgram`` as ``Fs``.
        n_fft: FFT size passed as ``NFFT``.
        n_overlap: value passed as ``noverlap``.
        win: window passed as ``window``.
        mode: only ``'phase'`` is forwarded to ``plt.specgram``; any other
            value uses Matplotlib's default mode.
        clim: colour limits handed to ``plt.clim``.
        label: colour-bar label, or None for an unlabeled colour bar.

    Returns:
        The image array produced by ``fig2np`` for the rendered figure.
    """
    # cuda to cpu
    samples = input_wav.cpu().detach().numpy()

    fig, ax = plt.subplots(figsize=(12, 3))

    specgram_kwargs = dict(
        NFFT=int(n_fft), Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet'
    )
    if mode == 'phase':
        specgram_kwargs['mode'] = mode
    _, _, _, cax = plt.specgram(samples, **specgram_kwargs)

    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    plt.tight_layout()
    plt.clim(clim)

    colorbar_kwargs = {} if label is None else {'label': label}
    fig.colorbar(cax, **colorbar_kwargs)

    fig.canvas.draw()
    image = fig2np(fig)
    plt.close()
    return image


def plot_mask_to_numpy(mask, fs, n_fft, n_overlap, win, clim1, clim2, cmap):
    """Render a mask as an image with time/frequency tick labels.

    Args:
        mask: 2-D array; its first axis is treated as the frame axis and the
            array is transposed before display.
        fs: sampling rate in Hz, used only to place tick labels.
        n_fft: frame length used for the tick computation.
        n_overlap: used as the frame shift in the tick computation.
        win: unused; kept for interface compatibility.
        clim1, clim2: lower and upper colour limits.
        cmap: Matplotlib colormap.

    Returns:
        The image array produced by ``fig2np`` for the rendered figure.
    """
    frame_num = mask.shape[0]
    signal_length = frame_num * n_overlap + n_fft
    nyquist_khz = fs / 1000 / 2

    # Tick positions: every 0.5 s on the frame axis, every 1 kHz on the bin axis.
    xt = np.arange(0, np.floor(10 * signal_length / fs) / 10, step=0.5) / (signal_length / fs) * frame_num + 1e-8
    yt = (n_fft / 2) / nyquist_khz * np.arange(0, nyquist_khz + 1)
    x_labels = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5)
    y_labels = np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt)))

    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(np.transpose(mask), aspect='auto', origin='lower', interpolation='none', cmap=cmap)

    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (kHz)')
    plt.xticks(xt, x_labels)
    plt.yticks(yt, y_labels)
    plt.tight_layout()
    plt.colorbar(im, ax=ax)
    im.set_clim(clim1, clim2)

    fig.canvas.draw()
    image = fig2np(fig)
    plt.close()
    return image


def plot_error_to_numpy(estimated, target, fs, n_fft, n_overlap, win, mode, clim1, clim2, label):
    """Render the spectrogram difference between an estimate and its target.

    Args:
        estimated: estimated waveform passed to ``plt.specgram``.
        target: reference waveform passed to ``plt.specgram``.
        fs: sampling rate in Hz.
        n_fft: FFT size passed as ``NFFT``.
        n_overlap: value passed as ``noverlap`` (also used as the frame shift
            when computing tick positions).
        win: window passed as ``window``.
        mode: None plots the difference of the two spectrograms in dB;
            otherwise ``mode`` is forwarded to ``plt.specgram`` and the raw
            difference is plotted.
        clim1, clim2: lower and upper colour limits.
        label: colour-bar label.

    Returns:
        The image array produced by ``fig2np`` for the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))

    specgram_kwargs = dict(NFFT=n_fft, Fs=int(fs), window=win, noverlap=n_overlap, cmap='jet')
    use_db = mode == None
    if not use_db:
        specgram_kwargs['mode'] = mode

    # Both spectrograms are computed with identical settings, estimate first.
    pxx1, _, _, _ = plt.specgram(estimated, **specgram_kwargs)
    pxx2, _, _, _ = plt.specgram(target, **specgram_kwargs)

    if use_db:
        difference = 10 * np.log10(pxx1) - 10 * np.log10(pxx2)
    else:
        difference = pxx1 - pxx2
    im = ax.imshow(difference, aspect='auto', origin='lower', interpolation='none', cmap='jet')

    frame_num = pxx1.shape[1]
    signal_length = frame_num * n_overlap + n_fft
    nyquist_khz = fs / 1000 / 2

    x_labels = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5)
    xt = x_labels / (signal_length / fs) * frame_num
    yt = (n_fft / 2) / nyquist_khz * np.arange(0, nyquist_khz + 1)

    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (kHz)')
    plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
    plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
    plt.tight_layout()
    plt.colorbar(im, ax=ax, label=label)
    im.set_clim(clim1, clim2)

    fig.canvas.draw()
    image = fig2np(fig)
    plt.close()
    return image



############################################################################
#                                for run.py                                #
############################################################################
def near_avg_index(array):
    """Return the index of the element closest to the mean of ``array``.

    When several elements are equally close, the first one wins.
    """
    center = np.mean(array)
    gaps = [abs(center - array[pos]) for pos in range(len(array))]
    return gaps.index(min(gaps))


def max_index(array):
    """Return the index of the largest element of ``array``.

    If the maximum occurs more than once, the index of its last occurrence
    is returned, because the scan does not stop at the first match.
    """
    largest = np.max(array)

    for position, value in enumerate(array):
        if value == largest:
            index = position
    return index


def min_index(array):
    """Return the index of the smallest element of ``array``.

    If the minimum occurs more than once, the index of its last occurrence
    is returned, because the scan does not stop at the first match.
    """
    smallest = np.min(array)

    for position, value in enumerate(array):
        if value == smallest:
            index = position
    return index


class Bar(object):
    """Console progress bar that wraps a dataloader-like iterable.

    Iterating a ``Bar`` yields the batches of the wrapped loader while
    printing ``done/total: [====>....] - ETA`` on a single line. The object
    can be iterated again after a full pass (e.g. once per epoch).
    """

    def __init__(self, dataloader):
        """
        Args:
            dataloader: iterable exposing ``dataset`` and ``batch_size``
                attributes and supporting ``len()``.

        Raises:
            ValueError: if one of the required attributes is missing.
        """
        if not hasattr(dataloader, 'dataset'):
            raise ValueError('Attribute `dataset` not exists in dataloder.')
        if not hasattr(dataloader, 'batch_size'):
            raise ValueError('Attribute `batch_size` not exists in dataloder.')

        self.dataloader = dataloader
        self.iterator = iter(dataloader)
        self.dataset = dataloader.dataset
        self.batch_size = dataloader.batch_size
        self._idx = 0            # number of batches already returned
        self._batch_idx = 0      # number of samples covered so far
        self._time = []          # the two most recent step timestamps
        self._DISPLAY_LENGTH = 50

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        return self

    def __next__(self):
        # After a completed pass the iterator is dropped; create a new one
        # lazily so a second loop over the same Bar yields batches again.
        if self.iterator is None:
            self.iterator = iter(self.dataloader)

        # Keep only the two latest timestamps so the ETA tracks the most
        # recent step instead of staying frozen at the first one.
        self._time.append(time.time())
        del self._time[:-2]

        self._batch_idx += self.batch_size
        if self._batch_idx > len(self.dataset):
            self._batch_idx = len(self.dataset)

        try:
            batch = next(self.iterator)
        except StopIteration:
            self._reset()
            self.iterator = None
            raise
        self._display()

        self._idx += 1
        if self._idx >= len(self.dataloader):
            self._reset()

        return batch

    def _display(self):
        """Print the current progress line (overwriting the previous one)."""
        if len(self._time) > 1:
            t = (self._time[-1] - self._time[-2])
            eta = t * (len(self.dataloader) - self._idx)
        else:
            eta = 0

        rate = self._idx / len(self.dataloader)
        len_bar = int(rate * self._DISPLAY_LENGTH)
        bar = ('=' * len_bar + '>').ljust(self._DISPLAY_LENGTH, '.')
        idx = str(self._batch_idx).rjust(len(str(len(self.dataset))), ' ')

        tmpl = '\r{}/{}: [{}] - ETA {:.1f}s'.format(
            idx,
            len(self.dataset),
            bar,
            eta
        )
        print(tmpl, end='')
        if self._batch_idx == len(self.dataset):
            print()

    def _reset(self):
        """Clear the counters so the next pass starts from zero."""
        self._idx = 0
        self._batch_idx = 0
        self._time = []


# Tools for loss

In [None]:
############################################################################
#               for model structure & loss function                        #
############################################################################
def remove_dc(data):
    """Subtract the mean over the last axis from ``data`` (DC removal)."""
    offset = torch.mean(data, -1, keepdim=True)
    return data - offset


def l2_norm(s1, s2):
    """Return the inner product of ``s1`` and ``s2`` over the last axis.

    The reduced axis is kept with size 1. Despite the name no square root is
    taken, so ``l2_norm(x, x)`` is the energy (squared L2 norm) of ``x``.
    """
    product = s1 * s2
    return torch.sum(product, -1, keepdim=True)


def sdr(s1, s2, eps=1e-8):
    """Return the batch-mean signal-to-distortion measure of ``s1`` against ``s2``.

    Args:
        s1: signal tensor with time on the last axis.
        s2: tensor compared against ``s1``; ``s1 - s2`` is the distortion.
        eps: constant added to the denominator.

    Returns:
        Scalar tensor: the mean of
        ``10 * log10(E(s1)**2 / (E(s1 - s2)**2 + eps))`` where ``E`` is the
        energy returned by ``l2_norm``.
    """
    # NOTE(review): the energies are squared once more before the ratio, which
    # doubles the value in dB compared with a plain energy ratio - confirm
    # that this is intended.
    distortion = s1 - s2
    signal_energy = l2_norm(s1, s1)
    distortion_energy = l2_norm(distortion, distortion)
    ratio = signal_energy ** 2 / (distortion_energy ** 2 + eps)
    return torch.mean(10 * torch.log10(ratio))


def si_snr(s1, s2, eps=1e-8):
    """Return the batch-mean scale-invariant SNR of ``s1`` with respect to ``s2``.

    ``s1`` is projected onto ``s2``; the projection is treated as the target
    component and the remainder as noise.

    Args:
        s1: estimated signal, time on the last axis.
        s2: reference signal, same shape.
        eps: constant guarding the divisions and the logarithm.

    Returns:
        Scalar tensor holding the mean SI-SNR in dB.
    """
    # s1 = remove_dc(s1)
    # s2 = remove_dc(s2)
    cross_energy = l2_norm(s1, s2)
    reference_energy = l2_norm(s2, s2)

    # Component of s1 that lies along s2, and what is left over.
    projected = cross_energy / (reference_energy + eps) * s2
    residual = s1 - projected

    projected_energy = l2_norm(projected, projected)
    residual_energy = l2_norm(residual, residual)
    snr_db = 10 * torch.log10((projected_energy) / (residual_energy + eps) + eps)
    return torch.mean(snr_db)


def si_sdr(reference, estimation, eps=1e-8):
    """Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) for torch tensors.

    Follows "SDR - Half-Baked or Well Done?"
    (http://www.merl.com/publications/docs/TR2019-013.pdf).

    Args:
        reference: torch.Tensor, [..., T]
        estimation: torch.Tensor, [..., T]
        eps: small constant added to every denominator and to the argument
            of the logarithm.

    Returns:
        Scalar tensor: ``10 * log10`` of the mean (over all leading axes) of
        the target-to-residual energy ratio.

    Note:
        ``eps`` is added to the denominators. It was previously added after
        the divisions, which gave NaN for an all-zero reference and inf for a
        perfect estimate.
    """
    reference_energy = torch.sum(reference ** 2, dim=-1, keepdim=True)

    # This is $\alpha$ after Equation (3) in the paper.
    optimal_scaling = torch.sum(reference * estimation, dim=-1, keepdim=True) / (reference_energy + eps)

    # This is $e_{\text{target}}$ in Equation (4) in the paper.
    projection = optimal_scaling * reference

    # This is $e_{\text{res}}$ in Equation (4) in the paper.
    noise = estimation - projection

    ratio = torch.sum(projection ** 2, dim=-1) / (torch.sum(noise ** 2, dim=-1) + eps)

    # The ratios are averaged before the logarithm is taken.
    ratio = torch.mean(ratio)
    return 10 * torch.log10(ratio + eps)

# Tools for score

In [None]:
###############################################################################
#                           PESQ (another ref)                                #
###############################################################################
# interface to PESQ evaluation, taking in two waveforms as input
def cal_pesq(dirty_wavs, clean_wavs):
    """Compute the wide-band PESQ score of every (dirty, clean) pair.

    Args:
        dirty_wavs: sequence of degraded/enhanced waveforms.
        clean_wavs: sequence of the matching clean reference waveforms.

    Returns:
        List with one PESQ score per pair, computed at the global sampling
        rate ``fs``.
    """
    scores = []
    for i in range(len(dirty_wavs)):
        # pesq() takes the reference signal first and the degraded signal
        # second; the two were previously passed in the opposite order.
        pesq_score = pesq(fs, clean_wavs[i], dirty_wavs[i], 'wb')
        scores.append(pesq_score)
    return scores


###############################################################################
#                                     STOI                                    #
###############################################################################
def cal_stoi(estimated_speechs, clean_speechs):
    """Compute the STOI score of every (estimated, clean) pair.

    Args:
        estimated_speechs: sequence of enhanced waveforms.
        clean_speechs: sequence of the matching clean reference waveforms.

    Returns:
        List with one (non-extended) STOI score per pair.
    """
    stoi_scores = []
    for i in range(len(estimated_speechs)):
        # Use the notebook-level sampling rate `fs`, as cal_pesq does. The
        # former `cfg.fs` referred to a config module that is not defined in
        # this notebook.
        stoi_score = stoi(clean_speechs[i], estimated_speechs[i], fs, extended=False)
        stoi_scores.append(stoi_score)
    return stoi_scores


###############################################################################
#                                     SNR                                     #
###############################################################################
def cal_snr(s1, s2, eps=1e-8):
    """Return the SNR in dB of ``s1`` measured against the clean signal ``s2``.

    Args:
        s1: noisy (or enhanced) signal as a numpy array.
        s2: clean reference signal, same shape.
        eps: constant added inside the logarithm.

    Returns:
        ``10 * log10(var(s2) / var(s1 - s2) + eps)``, or 100 when the noise
        variance is exactly zero.
    """
    def _variance(values):
        # Mean squared deviation from the mean.
        centered = values - np.mean(values)
        return np.sum(np.mean(centered ** 2))

    var_signal = _variance(s2)       # variance of the original data
    var_noise = _variance(s1 - s2)   # variance of the noise

    if var_noise == 0:
        return 100  # clean
    return (np.log10(var_signal/var_noise + eps))*10


def cal_snr_array(estimated_speechs, clean_speechs):
    """Return a list with the ``cal_snr`` score of every (estimated, clean) pair."""
    return [
        cal_snr(estimated_speechs[k], clean_speechs[k])
        for k in range(len(estimated_speechs))
    ]

# Model

In [None]:
#######################################################################
#                            real network                             #
#######################################################################
class CRN(nn.Module):
    """Real-valued convolutional recurrent network for speech enhancement.

    Pipeline: ``ConvSTFT`` front-end -> stack of ``RealConv2d`` + batch-norm +
    PReLU blocks (encoder) -> unidirectional LSTM + linear projection ->
    stack of transposed-convolution blocks (decoder) -> ``ConviSTFT``.

    Besides its constructor arguments the class reads the notebook-level
    settings ``rnn_input_size``, ``dccrn_kernel_num``, ``skip_type``,
    ``direct_mapping`` and ``current_loss``.
    """

    def __init__(
            self,
            rnn_layers=rnn_layers,
            rnn_units=rnn_units,
            win_len=win_len,
            win_inc=win_inc,
            fft_len=fft_len,
            win_type=window,
            kernel_size=5
    ):
        """Build the STFT front-end, encoder, recurrent core and decoder.

        Args:
            rnn_layers: stored as ``self.hidden_layers``.
                NOTE(review): it is not forwarded to ``nn.LSTM`` below, so the
                LSTM keeps PyTorch's default of one layer - confirm intent.
            rnn_units: hidden size of the LSTM.
            win_len: window length handed to ``ConvSTFT`` / ``ConviSTFT``.
            win_inc: window increment handed to ``ConvSTFT`` / ``ConviSTFT``.
            fft_len: FFT length handed to ``ConvSTFT`` / ``ConviSTFT``.
            win_type: window type handed to ``ConvSTFT`` / ``ConviSTFT``.
            kernel_size: first element of every convolution kernel tuple
                ``(kernel_size, 2)``.
        """
        super(CRN, self).__init__()

        # for fft
        self.win_len = win_len
        self.win_inc = win_inc
        self.fft_len = fft_len
        self.win_type = win_type

        input_dim = win_len
        output_dim = win_len

        self.rnn_input_size = rnn_input_size
        self.rnn_units = rnn_units
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_layers = rnn_layers
        self.kernel_size = kernel_size
        # Channel plan: 2 is prepended so that kernel_num[0] // 2 == 1 input
        # channel feeds the first encoder block.
        self.kernel_num = [2] + dccrn_kernel_num

        # bidirectional=True
        bidirectional = False

        self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'real')
        self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'real')

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for idx in range(len(self.kernel_num) - 1):
            self.encoder.append(
                nn.Sequential(
                    RealConv2d(
                        self.kernel_num[idx] // 2,
                        self.kernel_num[idx + 1] // 2,
                        kernel_size=(self.kernel_size, 2),
                        stride=(2, 1),
                        padding=(2, 1)
                    ),
                    nn.BatchNorm2d(self.kernel_num[idx + 1] // 2),
                    nn.PReLU()
                )
            )

        self.enhance = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_units,
            dropout=0.0,
            bidirectional=bidirectional,
            batch_first=False
        )
        # Projects the LSTM output back to the LSTM input width so it can be
        # reshaped to the encoder's output layout in forward().
        # (Attribute name kept as-is: it is part of the checkpoint keys.)
        self.tranform = nn.Linear(self.rnn_units, self.rnn_input_size)

        # Decoder blocks are created from the deepest level outwards. Every
        # block except the last one (idx == 1) is followed by batch-norm and
        # PReLU; the last block is the bare transposed convolution.
        for idx in range(len(self.kernel_num) - 1, 0, -1):
            if skip_type:
                # With skip connections forward() concatenates the running
                # output with the matching encoder output along dim 1 before
                # each block, hence the kernel_num[idx] input channels.
                out_channels = self.kernel_num[idx - 1] // 2
                deconv = RealConvTranspose2d(
                    self.kernel_num[idx],
                    out_channels,
                    kernel_size=(self.kernel_size, 2),
                    stride=(2, 1),
                    padding=(2, 0),
                    output_padding=(1, 0)
                )
            else:
                # NOTE(review): the encoder blocks are built with
                # kernel_num[idx] // 2 output channels while this branch
                # expects kernel_num[idx] input channels - confirm that the
                # non-skip configuration is valid.
                out_channels = self.kernel_num[idx - 1]
                deconv = nn.ConvTranspose2d(
                    self.kernel_num[idx],
                    out_channels,
                    kernel_size=(self.kernel_size, 2),
                    stride=(2, 1),
                    padding=(2, 0),
                    output_padding=(1, 0)
                )
            block = [deconv]
            if idx != 1:
                block += [nn.BatchNorm2d(out_channels), nn.PReLU()]
            self.decoder.append(nn.Sequential(*block))

        self.flatten_parameters()

    def flatten_parameters(self):
        """Call ``flatten_parameters()`` on the recurrent core if it is an LSTM."""
        if isinstance(self.enhance, nn.LSTM):
            self.enhance.flatten_parameters()

    def forward(self, inputs, targets=0):
        """Enhance ``inputs``.

        Args:
            inputs: noisy input handed to ``self.stft``.
            targets: clean reference handed to ``self.stft``; only used (and
                then required) when ``direct_mapping`` is enabled.

        Returns:
            * ``direct_mapping`` enabled: the tuple
              ``(estimated magnitudes, target magnitudes, estimated waveform)``.
            * otherwise (T-F masking): the estimated waveform only.
            The waveform is clamped to ``[-1, 1]`` in both modes.
        """
        mags, phase = self.stft(inputs)

        out = mags.unsqueeze(1)
        out = out[:, :, 1:]  # drop index 0 along dim 2
        encoder_out = []

        for layer in self.encoder:
            out = layer(out)
            encoder_out.append(out)

        batch_size, channels, dims, lengths = out.size()
        # LSTM is batch_first=False: move the last axis to the front and merge
        # channels and dims into the feature axis.
        out = out.permute(3, 0, 1, 2)

        rnn_in = torch.reshape(out, [lengths, batch_size, channels * dims])
        out, _ = self.enhance(rnn_in)
        out = self.tranform(out)
        out = torch.reshape(out, [lengths, batch_size, channels, dims])

        out = out.permute(1, 2, 3, 0)

        if skip_type:  # use skip connection
            for idx in range(len(self.decoder)):
                out = torch.cat([out, encoder_out[-1 - idx]], 1)
                out = self.decoder[idx](out)
                out = out[..., 1:]  # drop the first entry of the last axis
        else:
            for idx in range(len(self.decoder)):
                out = self.decoder[idx](out)
                out = out[..., 1:]

        out = out.squeeze(1)

        if direct_mapping:  # spectral mapping
            # Fixed: the parameter is called `targets`; the previous `target`
            # was an undefined name and raised NameError in this mode.
            target_mags, _ = self.stft(targets)

            # NOTE(review): unlike the masking branch, `out` is not padded
            # back after the `[:, :, 1:]` slice before it is combined with
            # `phase` - confirm that the shapes agree.
            out_real = out * torch.cos(phase)
            out_imag = out * torch.sin(phase)

            out_spec = torch.cat([out_real, out_imag], 1)

            out_wav = self.istft(out_spec)
            out_wav = torch.squeeze(out_wav, 1)
            out_wav = torch.clamp_(out_wav, -1, 1)

            return out, target_mags, out_wav
        else:  # T-F masking
            # Re-insert (as zeros, at the front of dim -2) the entry removed
            # by the `[:, :, 1:]` slice so the mask lines up with `mags`.
            out = F.pad(out, [0, 0, 1, 0])
            mask_mags = torch.tanh(out)
            est_mags = mask_mags * mags
            out_real = est_mags * torch.cos(phase)
            out_imag = est_mags * torch.sin(phase)

            out_spec = torch.cat([out_real, out_imag], 1)

            out_wav = self.istft(out_spec)
            out_wav = torch.squeeze(out_wav, 1)
            out_wav = torch.clamp_(out_wav, -1, 1)

            return out_wav

    def get_params(self, weight_decay=0.0):
        """Return optimizer parameter groups with weight decay on non-bias terms.

        Parameters whose name contains ``'bias'`` go into a group with
        ``weight_decay`` 0.0; all others use the given ``weight_decay``.
        """
        weights, biases = [], []
        for name, param in self.named_parameters():
            if 'bias' in name:
                biases += [param]
            else:
                weights += [param]
        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params

    def loss(self, estimated, target):
        """Evaluate the loss selected by the notebook-level ``current_loss``.

        Supported values: ``'MSE'``, ``'SDR'``, ``'SI-SNR'``, ``'SI-SDR'``.
        The three ratio metrics are negated so that the result can be
        minimised.

        Raises:
            ValueError: if ``current_loss`` is not one of the supported names
                (previously this silently returned ``None``).
        """
        if current_loss == 'MSE':
            return F.mse_loss(estimated, target, reduction='mean')
        elif current_loss == 'SDR':
            return -sdr(target, estimated)
        elif current_loss == 'SI-SNR':
            return -(si_snr(estimated, target))
        elif current_loss == 'SI-SDR':
            return -(si_sdr(target, estimated))
        raise ValueError('Unsupported loss type: %r' % (current_loss,))

# Trainer

In [None]:
#######################################################################
#                             For train                               #
#######################################################################
def model_train(model, optimizer, train_loader, DEVICE):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: network providing ``__call__`` and a ``loss(estimated, target)``
            method.
        optimizer: optimizer stepping the model's parameters.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        DEVICE: device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches, as a
        tensor detached from the autograd graph.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no batches.
    """
    train_loss = 0
    batch_num = 0

    model.train()
    for inputs, targets in Bar(train_loader):
        batch_num += 1

        # to cuda
        inputs = inputs.float().to(DEVICE)
        targets = targets.float().to(DEVICE)

        if direct_mapping:
            # spectral mapping: the loss is taken on the magnitudes
            output_mag, target_mag, _ = model(inputs, targets)
            loss = model.loss(output_mag, target_mag)
        else:
            # T-F masking: the loss is taken on the waveforms
            outputs = model(inputs)
            loss = model.loss(outputs, targets)
        # # if you want to check the scale of the loss
        # print('loss: {:.4}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy: adding `loss` itself would chain the
        # autograd history of every batch together for the whole epoch.
        train_loss += loss.detach()
    train_loss /= batch_num

    return train_loss


#######################################################################
#                           For validation                            #
#######################################################################
def model_validate(model, validation_loader, dir_to_save, epoch, DEVICE):
    """Evaluate the model on the validation set for one epoch.

    Per-sample PESQ and STOI scores are appended to the file
    ``<dir_to_save>/Epoch_<epoch>_SCORES``.

    Args:
        model: network providing ``__call__`` and a ``loss(estimated, target)``
            method.
        validation_loader: iterable yielding ``(inputs, targets)`` batches.
        dir_to_save: directory receiving the score file.
        epoch: epoch number used in the score file name.
        DEVICE: device the batches are moved to.

    Returns:
        ``(validation_loss, avg_pesq, avg_stoi)``, each averaged over the
        number of batches (the PESQ/STOI values are first averaged within
        every batch).

    Raises:
        ZeroDivisionError: if ``validation_loader`` yields no batches.
    """
    validation_loss = 0
    batch_num = 0

    avg_pesq = 0
    avg_stoi = 0

    model.eval()
    score_path = dir_to_save + '/Epoch_' + '%d_SCORES' % epoch
    # The context manager guarantees the score file is closed (it used to be
    # left open after every epoch).
    with open(score_path, 'a') as f_score, torch.no_grad():
        for inputs, targets in Bar(validation_loader):
            batch_num += 1

            # to cuda
            inputs = inputs.float().to(DEVICE)
            targets = targets.float().to(DEVICE)

            if direct_mapping:
                # Same call as in training: the model takes no
                # `direct_mapping` keyword, the mode is a global setting.
                output_mag, target_mag, outputs = model(inputs, targets)
                loss = model.loss(output_mag, target_mag)
            else:
                outputs = model(inputs)
                loss = model.loss(outputs, targets)

            validation_loss += loss

            # estimate the output speech with pesq and stoi
            estimated_wavs = outputs.cpu().detach().numpy()
            clean_wavs = targets.cpu().detach().numpy()

            # Local names deliberately differ from the `pesq` / `stoi`
            # metric functions to avoid shadowing them.
            pesq_scores = cal_pesq(estimated_wavs, clean_wavs)
            stoi_scores = cal_stoi(estimated_wavs, clean_wavs)

            # pesq: 0.1 better / stoi: 0.01 better
            for pesq_score, stoi_score in zip(pesq_scores, stoi_scores):
                f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq_score, stoi_score))

            avg_pesq += sum(pesq_scores) / len(inputs)
            avg_stoi += sum(stoi_scores) / len(inputs)

    validation_loss /= batch_num
    avg_pesq /= batch_num
    avg_stoi /= batch_num

    return validation_loss, avg_pesq, avg_stoi

# Train_interface

In [None]:
###############################################################################
#                        Helper function definition                           #
###############################################################################
# Write training related parameters into the log file.
def write_status_to_log_file(fp, total_parameters):
    """Write a timestamp line and a parameter-count line to ``fp``.

    Args:
        fp: writable text file object.
        total_parameters: number of model parameters; the size in MBytes is
            estimated with 4 bytes per parameter.
    """
    # Sample the clock once so all six fields describe the same instant
    # (separate calls could straddle a second/minute/day boundary).
    now = time.localtime()
    fp.write('%d-%d-%d %d:%d:%d\n' %
             (now.tm_year, now.tm_mon, now.tm_mday,
              now.tm_hour, now.tm_min, now.tm_sec))
    fp.write('total params   : %d (%.2f M, %.2f MBytes)\n' %
             (total_parameters,
              total_parameters / 1000000.0,
              total_parameters * 4.0 / 1000000.0))


# Calculate the size of total network.
def calculate_total_params(our_model):
    """Return the total number of scalar entries over all model parameters."""
    return sum(param.numel() for param in our_model.parameters())


###############################################################################
#         Parameter Initialization and Setting for model training             #
###############################################################################
# Set device
# Runs on the CPU; use torch.device('cuda') instead to train on a GPU.
DEVICE = torch.device('cpu')

# Set model
model = CRN().to(DEVICE)
# Set optimizer and learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
total_params = calculate_total_params(model)

###############################################################################
#                        Confirm model information                            #
###############################################################################
# A single time.localtime() sample supplies all six fields (year ... second),
# so the printed timestamp cannot mix values from different instants.
print('%d-%d-%d %d:%d:%d\n' % time.localtime()[:6])
# The MBytes figure assumes 4 bytes per parameter.
print('total params   : %d (%.2f M, %.2f MBytes)\n' %
      (total_params,
       total_params / 1000000.0,
       total_params * 4.0 / 1000000.0))

###############################################################################
#                              Create Dataloader                              #
###############################################################################
train_loader = create_dataloader(mode='train')
validation_loader = create_dataloader(mode='valid')

###############################################################################
#                        Set a log file to store progress.                    #
#               Set a hps file to store hyper-parameters information.         #
###############################################################################
if chkpt_model is not None:  # Load the checkpoint
    print('Resuming from checkpoint: %s' % chkpt_path)

    # Directory of the run being resumed.
    dir_to_save = job_dir + chkpt_model

    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # written on a GPU can still be restored in a CPU-only session.
    checkpoint = torch.load(chkpt_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start_idx = checkpoint['epoch'] + 1
    mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
else:  # First learning
    print('Starting new training run...')
    epoch_start_idx = 1
    mse_vali_total = np.zeros(max_epochs)

    # Directory name: <expr_num>_<month>.<day>_<model>_<loss>. Month and day
    # come from one clock sample (fields 1 and 2 of the struct_time).
    dir_to_save = job_dir + expr_num + '_%d.%d' % time.localtime()[1:3] \
        + '_%s' % current_model + '_%s' % current_loss

# Create the output directory (and any missing parents); no error if it
# already exists.
os.makedirs(dir_to_save, exist_ok=True)

# logging
log_fname = str(dir_to_save + '/log.txt')

###############################################################################
#                             Main program start !!                           #
###############################################################################
# The context manager closes (and flushes) the log even if training raises.
with open(log_fname, 'w') as fp:
    write_status_to_log_file(fp, total_params)

    ###########################################################################
    #                                  Train                                  #
    ###########################################################################
    # Epochs are 1-based and mse_vali_total[epoch - 1] is written for each of
    # them, so the range has to include max_epochs. Stopping one epoch early
    # left the last slot at its initial 0, which np.argmin below would select
    # for any positive loss although no checkpoint exists for that epoch.
    for epoch in range(epoch_start_idx, max_epochs + 1):
        start_time = time.time()
        # Training
        train_loss = model_train(model, optimizer, train_loader, DEVICE)

        # save checkpoint file to resume training
        save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch))
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }, save_path)

        # Validation
        vali_loss, vali_pesq, vali_stoi = \
            model_validate(model, validation_loader, dir_to_save, epoch, DEVICE)

        # Format once so the console and the log file report the same
        # elapsed time.
        elapsed = time.time() - start_time
        loss_line = 'Epoch [{}] | T {:.6f} | V {:.6} takes {:.2f} seconds\n' \
            .format(epoch, train_loss, vali_loss, elapsed)
        score_line = '          | V PESQ: {:.6f} | STOI: {:.6f} '.format(vali_pesq, vali_stoi)

        print(loss_line)
        print(score_line)
        # log file save
        fp.write(loss_line)
        fp.write(score_line + '\n')

        # float() yields a plain Python number whatever device the loss
        # tensor lives on.
        mse_vali_total[epoch - 1] = float(vali_loss)
        np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total)

print('Training has been finished.')

# Copy the checkpoint with the minimum validation loss to chkpt_opt.pt.
print('Save optimum models...')
min_index = np.argmin(mse_vali_total)
print('Minimum validation loss is at ' + str(min_index + 1) + '.')
src_file = str(dir_to_save + '/' + ('chkpt_%d.pt' % (min_index + 1)))
tgt_file = str(dir_to_save + '/chkpt_opt.pt')
shutil.copy(src_file, tgt_file)