In [1]:
import contextlib, glob, os, pickle, platform, random, sys, wave
import numpy as np
from scipy.io.wavfile import read
import soundfile as sf
from scipy.io import loadmat, savemat


def save_wav(path, wav, f_s):
    """Write a waveform to `path` using `soundfile`.

    Args:
        path: Destination file path.
        wav: Waveform samples. Singleton dimensions are squeezed out first.
            Floating-point samples are scaled by 32768 and stored as int16;
            any other dtype is handed to `soundfile` unchanged.
        f_s: Sampling frequency passed to `soundfile.write`.
    """
    wav = np.squeeze(wav)
    # Inspect the array dtype rather than the first element, so that empty or
    # 0-d inputs do not raise IndexError.
    if np.issubdtype(wav.dtype, np.floating):
        # Clip before the cast: a sample at or above +1.0 scales to >= 32768,
        # which does not fit in int16 and would otherwise wrap around to a
        # large negative value (an audible click).
        wav = np.clip(np.multiply(wav, 32768.0), -32768.0, 32767.0).astype(np.int16)
    sf.write(path, wav, f_s)

def read_wav(path):
    """Read an audio file with `soundfile`, requesting int16 samples.

    Args:
        path: Path of the audio file to read.

    Returns:
        Tuple `(samples, sample_rate)` exactly as returned by `soundfile.read`.
    """
    samples, sample_rate = sf.read(path, dtype='int16')
    return samples, sample_rate

# CAN BE REMOVED
# def save_mat(path, data, name):
#     if not path.endswith('.mat'): path = path + '.mat'
#     savemat(path, {name: data})

def Batch(fdir, snr_l=None):
    """Load every audio file in a directory into one zero-padded int16 batch.

    Files matching `*.wav`, `*.flac` and `*.mp3` directly inside `fdir` are
    read with `read_wav`. Within each pattern the paths are processed in
    sorted order so the batch layout is reproducible between runs.

    Args:
        fdir: Directory to search for audio files.
        snr_l: Optional iterable of SNR levels. For every level whose tag
            `_<snr>dB` occurs in a file path, that level is appended to the
            returned SNR array.

    Returns:
        Tuple `(wav_np, len_l, snr_test_l, fname_l)`:
            wav_np: int16 array of shape [n_files, max_len], zero padded.
            len_l: int32 array with the length of each waveform.
            snr_test_l: int32 array of the SNR levels matched in file paths.
            fname_l: list of file base names without extension.

    Raises:
        ValueError: If no audio file is found in `fdir`, or if a waveform
            contains NaN or Inf values.
    """
    snr_l = [] if snr_l is None else snr_l  # avoid a shared mutable default.
    fname_l = [] # list of file names.
    wav_l = [] # list for waveforms.
    snr_test_l = [] # list of SNR levels for the test set.
    patterns = ['*.wav', '*.flac', '*.mp3']

    for pattern in patterns:
        for fpath in sorted(glob.glob(os.path.join(fdir, pattern))):
            for snr in snr_l:
                if fpath.find('_' + str(snr) + 'dB') != -1:
                    snr_test_l.append(snr)
            (wav, _) = read_wav(fpath)
            if np.isnan(wav).any() or np.isinf(wav).any():
                # The original formatted this message with an undefined name,
                # which raised NameError instead of the intended ValueError.
                raise ValueError('Error: NaN or Inf value. File path: %s.' % (fpath))
            wav_l.append(wav)
            fname_l.append(os.path.basename(os.path.splitext(fpath)[0]))

    if not wav_l:
        # Fail with an actionable message instead of `max()` of an empty sequence.
        raise ValueError('No audio files (%s) found in directory: %s.'
            % (', '.join(patterns), fdir))

    len_l = [len(wav) for wav in wav_l]
    wav_np = np.zeros([len(wav_l), max(len_l)], np.int16)
    for i, wav in enumerate(wav_l):
        wav_np[i,:len(wav)] = wav
    return wav_np, np.array(len_l, np.int32), np.array(snr_test_l, np.int32), fname_l


In [2]:
from tensorflow.python.ops.signal import window_ops
import functools
import numpy as np
import scipy.special as spsp
import tensorflow as tf

class STFT:
    """Short-time Fourier analysis and synthesis built on `tf.signal`.

    Attributes:
        N_d: frame length handed to `tf.signal.stft` / `inverse_stft`.
        N_s: frame step handed to `tf.signal.stft` / `inverse_stft`.
        NFFT: FFT length handed to `tf.signal.stft` / `inverse_stft`.
        f_s: sampling frequency (stored only; not used by this class).
        W: window function, a Hamming window with `periodic=False`.
        ten: the constant 10.0 as a float32 tensor.
    """

    def __init__(self, N_d, N_s, NFFT, f_s):
        """Store the framing parameters and build the analysis window function."""
        self.N_d, self.N_s = N_d, N_s
        self.NFFT, self.f_s = NFFT, f_s
        self.W = functools.partial(window_ops.hamming_window, periodic=False)
        self.ten = tf.cast(10.0, tf.float32)

    def polar_analysis(self, x):
        """Return the magnitude and phase of the STFT of `x` (computed with `pad_end=True`)."""
        spectrum = tf.signal.stft(
            x,
            frame_length=self.N_d,
            frame_step=self.N_s,
            fft_length=self.NFFT,
            window_fn=self.W,
            pad_end=True,
        )
        magnitude = tf.abs(spectrum)
        phase = tf.math.angle(spectrum)
        return magnitude, phase

    def polar_synthesis(self, STMS, STPS):
        """Rebuild a waveform from magnitude `STMS` and phase `STPS` via the inverse STFT."""
        magnitude = tf.cast(STMS, tf.complex64)
        phasor = tf.exp(1j*tf.cast(STPS, tf.complex64))
        spectrum = magnitude*phasor
        # Synthesis window derived from the forward window and the frame step.
        synthesis_window = tf.signal.inverse_stft_window_fn(self.N_s, self.W)
        return tf.signal.inverse_stft(spectrum, self.N_d, self.N_s, self.NFFT, synthesis_window)

class DeepXiInput(STFT):
    """Input features and a priori SNR mappings built on top of `STFT`.

    Attributes:
        mu: mean used by `xi_bar` / `xi_hat` to map the a priori SNR (dB).
        sigma: standard deviation used by `xi_bar` / `xi_hat`.
    """

    def __init__(self, N_d, N_s, NFFT, f_s, mu=None, sigma=None):
        """Initialise the STFT parameters and the optional mapping statistics."""
        super().__init__(N_d, N_s, NFFT, f_s)
        self.mu = mu
        self.sigma = sigma

    def observation(self, x):
        """Normalise the waveform `x` and return its STFT magnitude and phase."""
        x = self.normalise(x)
        x_STMS, x_STPS = self.polar_analysis(x)
        return x_STMS, x_STPS

#     def example(self, s, d, s_len, d_len, snr):
#         s_STMS, d_STMS, x_STMS, n_frames = self.mix(s, d, s_len, d_len, snr)
#         mask = tf.expand_dims(tf.cast(tf.sequence_mask(n_frames), tf.float32), 2)
#         xi_bar = tf.multiply(self.xi_bar(s_STMS, d_STMS), mask)
#         return x_STMS, xi_bar, n_frames

#     def mix(self, s, d, s_len, d_len, snr):
#         s, d = self.normalise(s), self.normalise(d)
#         n_frames = self.n_frames(s_len)
#         (x, s, d) = self.add_noise_batch(s, d, s_len, d_len, snr)
#         s_STMS, _ = self.polar_analysis(s)
#         d_STMS, _ = self.polar_analysis(d)
#         x_STMS, _ = self.polar_analysis(x)
#         return s_STMS, d_STMS, x_STMS, n_frames

    def normalise(self, x):
        """Cast `x` to float32 and divide by 32768.0."""
        return tf.truediv(tf.cast(x, tf.float32), 32768.0)

    def n_frames(self, N):
        """Return ceil(N / N_s) as an int32 tensor."""
        return tf.cast(tf.math.ceil(tf.truediv(tf.cast(N, tf.float32), tf.cast(self.N_s, tf.float32))), tf.int32)

#     def add_noise_batch(self, s, d, s_len, d_len, snr):
#         return tf.map_fn(lambda z: self.add_noise_pad(z[0], z[1], z[2], z[3], z[4],
#             tf.reduce_max(s_len)), (s, d, s_len, d_len, snr), dtype=(tf.float32, tf.float32,
#             tf.float32), back_prop=False)

#     def add_noise_pad(self, s, d, s_len, d_len, snr, pad_len):
#         s, d = s[:s_len], d[:d_len]
#         (x, d) = self.add_noise(s, d, s_len, d_len, snr)
#         total_zeros = tf.subtract(pad_len, s_len)
#         x = tf.pad(x, [[0, total_zeros]], "CONSTANT")
#         s = tf.pad(s, [[0, total_zeros]], "CONSTANT")
#         d = tf.pad(d, [[0, total_zeros]], "CONSTANT")
#         return (x, s, d)

#     def add_noise(self, s, d, s_len, d_len, snr):
#         snr = tf.cast(snr, tf.float32)
#         snr = tf.pow(self.ten, tf.truediv(snr, self.ten)) # inverse of dB.
#         i = tf.random.uniform([1], 0, tf.add(1, tf.subtract(d_len, s_len)), tf.int32)
#         d = tf.slice(d, [i[0]], [s_len])
#         P_s = tf.reduce_mean(tf.math.square(s), 0) # average power of clean speech.
#         P_d = tf.reduce_mean(tf.math.square(d), 0) # average power of noise.
#         alpha = tf.math.sqrt(tf.truediv(P_s,
#             tf.maximum(tf.multiply(P_d, snr), 1e-12))) # scaling factor.
#         d = tf.multiply(d, alpha)
#         x = tf.add(s, d)
#         return (x, d)

#     def snr_db(self, s, d):
#         P_s = tf.reduce_mean(tf.math.square(s), 0) # average power of clean speech.
#         P_d = tf.reduce_mean(tf.math.square(d), 0) # average power of noise.
#         return tf.multiply(self.ten, self.log_10(tf.truediv(P_s, P_d)))

    def log_10(self, x):
        """Return the base-10 logarithm of `x`, computed as ln(x) / ln(10)."""
        return tf.truediv(tf.math.log(x), tf.math.log(self.ten))

    def xi(self, s_STMS, d_STMS):
        """Return s_STMS^2 / max(d_STMS^2, 1e-12) (the floor avoids division by zero)."""
        return tf.truediv(tf.square(s_STMS), tf.maximum(tf.square(d_STMS), 1e-12))

    def xi_db(self, s_STMS, d_STMS):
        """Return `xi` in dB: 10*log10(max(xi, 1e-12))."""
        return tf.multiply(10.0, self.log_10(tf.maximum(self.xi(s_STMS, d_STMS), 1e-12)))

    def xi_bar(self, s_STMS, d_STMS):
        """Map `xi_db` through 0.5*(1 + erf((xi_db - mu) / (sigma*sqrt(2))))."""
        return tf.multiply(0.5, tf.add(1.0, tf.math.erf(tf.truediv(tf.subtract(self.xi_db(s_STMS, d_STMS), self.mu),
            tf.multiply(self.sigma, tf.sqrt(2.0))))))

    def xi_hat(self, xi_bar_hat):
        """Invert the `xi_bar` mapping: mapped estimate -> a priori SNR (linear scale).

        Args:
            xi_bar_hat: Mapped a priori SNR estimate; values are expected in [0, 1].

        Returns:
            NumPy array `10 ** (xi_db_hat / 10)`, where
            `xi_db_hat = sigma*sqrt(2)*erfinv(2*xi_bar_hat - 1) + mu`.
        """
        xi_bar_hat = np.asarray(xi_bar_hat)
        if not np.issubdtype(xi_bar_hat.dtype, np.floating):
            xi_bar_hat = xi_bar_hat.astype(np.float64)
        # Keep the argument of erfinv strictly inside (-1, 1). An estimate that
        # has saturated to exactly 0 or 1 would otherwise give erfinv = -/+inf,
        # i.e. an a priori SNR of 0 or inf, which turns into NaN in the gain
        # function. The margin is the machine epsilon of the working dtype so
        # that it survives the `2*x - 1` step without rounding back to +/-1.
        eps = np.finfo(xi_bar_hat.dtype).eps
        xi_bar_hat = np.clip(xi_bar_hat, eps, 1.0 - eps)
        xi_db_hat = np.add(np.multiply(np.multiply(self.sigma, np.sqrt(2.0)),
            spsp.erfinv(np.subtract(np.multiply(2.0, xi_bar_hat), 1))), self.mu)
        return np.power(10.0, np.divide(xi_db_hat, 10.0))

#     def mel_filter_bank(self, M):
#         f_l = 0 # lowest frequency (Hz).
#         f_h = self.f_s/2 # highest frequency (Hz).
#         K = self.NFFT//2 + 1 # number of frequency bins.
#         H = np.zeros([M, K], dtype=np.float32) # mel filter bank.
#         for m in range(1, M + 1):
#             bl = self.bpoint(m - 1, M, f_l, f_h) # lower boundary point, f(m - 1) for m-th filterbank.
#             c = self.bpoint(m, M, f_l, f_h) # m-th filterbank centre point, f(m).
#             bh = self.bpoint(m + 1, M, f_l, f_h) # higher boundary point f(m + 1) for m-th filterbank.
#             for k in range(K):
#                 if k >= bl and k <= c:
#                     H[m-1,k] = (2*(k - bl))/((bh - bl)*(c - bl)) # m-th filterbank up-slope.
#                 if k >= c and k <= bh:
#                     H[m-1,k] = (2*(bh - k))/((bh - bl)*(bh - c)) # m-th filterbank down-slope.
#         return H

#     def bpoint(self, m, M, f_l, f_h):
#         K = self.NFFT//2 + 1 # number of frequency bins.
#         return ((2*K)/self.f_s)*self.mel2hz(self.hz2mel(f_l) + \
#             m*((self.hz2mel(f_h) - self.hz2mel(f_l))/(M + 1))) # boundary point.

#     def hz2mel(self, f):
#         return 2595*np.log10(1 + (f/700))

#     def mel2hz(self, m):
#         return 700*((10**(m/2595)) - 1)


In [3]:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Activation, Add, Conv1D, Conv2D, Dense, Dropout,Flatten, LayerNormalization, MaxPooling2D, ReLU
import numpy as np

class ResNet:
    """Stack of bottleneck residual blocks built from dilated 1-D convolutions.

    The graph is built in the constructor; the resulting tensors are exposed
    as `first_layer`, `layer_list`, `logits` and `outp` (sigmoid of `logits`).
    """

    def __init__(self, inp, n_outp, n_blocks, d_model, d_f, k, max_d_rate, padding):
        """Build the network on top of the tensor `inp`.

        Args:
            inp: input tensor.
            n_outp: number of filters of the output convolution.
            n_blocks: number of residual blocks.
            d_model: number of filters of the first layer and of each block output.
            d_f: number of filters of the first two convolutions inside a block.
            k: kernel size of the dilated convolution inside a block.
            max_d_rate: the dilation rate cycles through 1, 2, 4, ... up to this
                value (expected to be a power of two).
            padding: padding mode passed to the block convolutions.
        """
        self.d_model = d_model
        self.d_f = d_f
        self.k = k
        self.n_outp = n_outp
        self.padding = padding
        self.first_layer = self.feedforward(inp)
        self.layer_list = [self.first_layer]
        # Length of one dilation cycle: exponents 0 .. log2(max_d_rate).
        cycle_len = np.log2(max_d_rate) + 1
        for block_idx in range(n_blocks):
            previous = self.layer_list[-1]
            d_rate = int(2**(block_idx % cycle_len))
            self.layer_list.append(self.block(previous, d_rate))
        self.logits = Conv1D(self.n_outp, 1, dilation_rate=1, use_bias=True)(self.layer_list[-1])
        self.outp = Activation('sigmoid')(self.logits)

    def feedforward(self, inp):
        """First layer: pointwise convolution -> layer normalisation -> ReLU."""
        projected = Conv1D(self.d_model, 1, dilation_rate=1, use_bias=False)(inp)
        normalised = LayerNormalization(axis=2, epsilon=1e-6)(projected)
        return ReLU()(normalised)

    def block(self, inp, d_rate):
        """Bottleneck residual block; the middle convolution uses dilation `d_rate`."""
        self.conv_1 = self.unit(inp, self.d_f, 1, 1, False)
        self.conv_2 = self.unit(self.conv_1, self.d_f, self.k, d_rate, False)
        self.conv_3 = self.unit(self.conv_2, self.d_model, 1, 1, True)
        return Add()([inp, self.conv_3])

    def unit(self, inp, n_filt, k, d_rate, use_bias):
        """Pre-activation unit: layer normalisation -> ReLU -> Conv1D."""
        normalised = LayerNormalization(axis=2, epsilon=1e-6)(inp)
        activated = ReLU()(normalised)
        return Conv1D(
            n_filt,
            k,
            padding=self.padding,
            dilation_rate=d_rate,
            use_bias=use_bias,
        )(activated)


In [4]:
import numpy as np
from scipy.special import exp1, i0, i1

def mmse_lsa(xi, gamma):
    """MMSE-LSA gain function.

    Computes `G = xi/(1 + xi) * exp(0.5 * E1(nu))` with
    `nu = xi/(1 + xi) * gamma`, where `E1` is `scipy.special.exp1`.

    Args:
        xi: a priori SNR (linear scale), scalar or array, in [0, inf].
        gamma: a posteriori SNR (linear scale), broadcastable against `xi`.

    Returns:
        The gain, with the same broadcast shape as the inputs. The boundary
        inputs `xi == 0` and `xi == inf` yield finite gains instead of NaN.
    """
    # inf/(1 + inf) evaluates to NaN; the limit of xi/(1 + xi) is 1.
    with np.errstate(invalid='ignore'):
        ratio = np.divide(xi, np.add(1, xi))
    ratio = np.where(np.isposinf(xi), np.ones_like(ratio), ratio)
    # E1(0) is inf, so xi == 0 would give 0*inf = NaN. Flooring nu keeps the
    # exponential finite and the gain then correctly evaluates to 0.
    nu = np.maximum(np.multiply(ratio, gamma), 1e-30)
    return np.multiply(ratio, np.exp(np.multiply(0.5, exp1(nu)))) # MMSE-LSA gain function.

def gfunc(xi, gamma=None, gtype='mmse-lsa'):
    """Compute the gain for the requested gain function type.

    Args:
        xi: a priori SNR (linear scale).
        gamma: a posteriori SNR (linear scale); required for 'mmse-lsa'.
        gtype: gain function type. Only 'mmse-lsa' is implemented.

    Returns:
        The gain computed by `mmse_lsa(xi, gamma)`.

    Raises:
        ValueError: If `gtype` is not 'mmse-lsa'.
        TypeError: If `gamma` is not provided.
    """
    # Reject unknown types instead of silently applying MMSE-LSA under a
    # different name.
    if gtype != 'mmse-lsa':
        raise ValueError("Invalid gain function type: %r. Only 'mmse-lsa' is supported." % (gtype,))
    if gamma is None:
        raise TypeError("gamma (a posteriori SNR) is required for the 'mmse-lsa' gain function.")
    return mmse_lsa(xi, gamma)


In [5]:
from pesq import pesq
from pystoi import stoi
from tensorflow.keras.callbacks import Callback, CSVLogger, ModelCheckpoint
from tensorflow.keras.layers import Input, Masking
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.python.lib.io import file_io
from tqdm import tqdm
import csv, math, os, random 
import numpy as np
import tensorflow as tf

class DeepXi(DeepXiInput):
    """Deep Xi inference wrapper: builds the ResNet model and enhances noisy speech.

    Attributes:
        min_snr, max_snr: bounds used to build `snr_levels`.
        snr_levels: list of SNR levels from `min_snr` to `max_snr` in steps of `snr_inter`.
        ver: version string, used to name the model summary file.
        n_feat: number of input features, ceil(NFFT/2 + 1).
        n_outp: number of outputs (equal to `n_feat`).
        inp, mask, network, model: the Keras input, masking layer output,
            `ResNet` instance and assembled `Model`.
    """

    def __init__(self,N_d,N_s,NFFT,f_s,min_snr,max_snr,snr_inter,ver="resnet-1.0c",**kwargs):
        """Build the model and write its summary to `log/summary/<ver>.txt`.

        Args:
            N_d, N_s, NFFT, f_s: STFT parameters, see `STFT`.
            min_snr, max_snr, snr_inter: define `snr_levels`.
            ver: version string.
            **kwargs: must contain `n_blocks`, `d_model`, `d_f`, `k`,
                `max_d_rate` and `padding`, which are forwarded to `ResNet`.
                Other keys are ignored.
        """
        super().__init__(N_d, N_s, NFFT, f_s)
        self.min_snr = min_snr
        self.max_snr = max_snr
        self.ver=ver
        self.snr_levels = list(range(self.min_snr, self.max_snr + 1, snr_inter))
        self.n_feat = math.ceil(self.NFFT/2 + 1)
        self.n_outp = self.n_feat
        self.inp = Input(name='inp', shape=[None, self.n_feat], dtype='float32')
        self.mask = Masking(mask_value=0.0)(self.inp)
        self.network = ResNet(
            inp=self.mask,
            n_outp=self.n_outp,
            n_blocks=kwargs['n_blocks'],
            d_model=kwargs['d_model'],
            d_f=kwargs['d_f'],
            k=kwargs['k'],
            max_d_rate=kwargs['max_d_rate'],
            padding=kwargs['padding'],
            )
        self.model = Model(inputs=self.inp, outputs=self.network.outp)
        self.model.summary()
        summary_dir = os.path.join("log", "summary")
        os.makedirs(summary_dir, exist_ok=True)
        with open(os.path.join(summary_dir, self.ver + ".txt"), "w") as f:
            self.model.summary(print_fn=lambda x: f.write(x + '\n'))


    def infer(self,test_x,test_x_len,test_x_base_names,model_path='model/saved_model/',out_type='y',gain='mmse-lsa',out_path='out/denoised/',stats_path=None,n_filters=40,):
        """Enhance a batch of noisy waveforms and save them as .wav files.

        The enhanced files are written to `<out_path>/y/<gain>/<base_name>.wav`;
        the directory is created if it does not exist.

        Args:
            test_x: batch of noisy waveforms, indexed as [utterance, sample].
            test_x_len: length (samples) of each waveform in `test_x`.
            test_x_base_names: output base name for each waveform.
            model_path: directory containing `variables/variables` weights.
            out_type: currently unused; only waveform ('y') output is implemented.
            gain: gain function type passed to `gfunc`.
            out_path: root output directory.
            stats_path: directory containing `stats.npz`; `None` selects the
                default of `sample_stats` ('data').
            n_filters: currently unused.
        """
#         if out_type == 'xi_hat': out_path = out_path + '/xi_hat'
#         elif out_type == 'y':
        # os.path.join keeps the path correct whether or not `out_path` ends
        # with a separator.
        out_path = os.path.join(out_path, 'y', gain)
#         elif out_type == 'deepmmse': out_path = out_path + '/deepmmse'
#         elif out_type == 'ibm_hat': out_path = out_path + '/ibm_hat'
#         elif out_type == 'subband_ibm_hat': out_path = out_path + '/subband_ibm_hat'
#         else: raise ValueError('Invalid output type.')
        # Without this the first save_wav call fails on a fresh checkout.
        os.makedirs(out_path, exist_ok=True)

#         if out_type == 'subband_ibm_hat':
#             mel_filter_bank = self.mel_filter_bank(n_filters)

        if stats_path is None:
            self.sample_stats()
        else:
            self.sample_stats(stats_path)
        self.model.load_weights(os.path.join(model_path, 'variables', 'variables'))

        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(test_x, test_x_len)
        print("Performing inference...")
        xi_bar_hat_batch = self.model.predict(x_STMS_batch, batch_size=1, verbose=1)

        print("Performing synthesis...")
        batch_size = len(test_x_len)
        for i in tqdm(range(batch_size)):
            base_name = test_x_base_names[i]
            # Drop the zero-padded frames of this utterance.
            x_STMS = x_STMS_batch[i,:n_frames[i],:]
            x_STPS = x_STPS_batch[i,:n_frames[i],:]
            xi_bar_hat = xi_bar_hat_batch[i,:n_frames[i],:]
            xi_hat = self.xi_hat(xi_bar_hat)
#             if out_type == 'xi_hat': save_mat(args.out_path + '/' + base_name + '.mat',
#                 xi_hat, 'xi_hat')
#             elif out_type == 'y':
            # The a posteriori SNR passed to the gain function is xi_hat + 1.
            y_STMS = np.multiply(x_STMS, gfunc(xi_hat, xi_hat+1, gtype=gain))
            # Synthesis reuses the phase of the noisy observation.
            y = self.polar_synthesis(y_STMS, x_STPS).numpy()
            save_wav(os.path.join(out_path, base_name + '.wav'), y, self.f_s)
#             elif out_type == 'deepmmse':
#                 d_PSD_hat = np.multiply(np.square(x_STMS), gfunc(xi_hat, xi_hat+1,
#                     gtype='deepmmse'))
#                 save_mat(out_path + '/' + base_name + '.mat', d_PSD_hat, 'd_psd_hat')
#             elif out_type == 'ibm_hat':
#                 ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
#                 save_mat(out_path + '/' + base_name + '.mat', ibm_hat, 'ibm_hat')
#             elif out_type == 'subband_ibm_hat':
#                 xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
#                 subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
#                 save_mat(out_path + '/' + base_name + '.mat', subband_ibm_hat,
#                     'subband_ibm_hat')
#             else: raise ValueError('Invalid output type.')

    def sample_stats(self,stats_path='data',sample_size=1000,train_s_list=None,train_d_list=None):
        """Load, or estimate and save, the statistics `mu` and `sigma` of the a priori SNR (dB).

        If `<stats_path>/stats.npz` exists, `mu_hat` and `sigma_hat` are loaded
        from it into `self.mu` and `self.sigma`. Otherwise they are estimated
        from `sample_size` randomly mixed training examples and saved as
        `stats.npz` and `stats.mat` in `stats_path`.

        Args:
            stats_path: directory of the statistics file; `None` means 'data'.
            sample_size: number of clean/noise examples used for estimation.
            train_s_list: clean speech training list (needed for estimation only).
            train_d_list: noise training list (needed for estimation only).

        Raises:
            ValueError: If no statistics file exists and no training lists are given.
            NotImplementedError: If estimation is requested but the training
                helpers `wav_batch` / `mix` are not defined on this object.
        """
        if stats_path is None: stats_path = 'data'
        stats_file = os.path.join(stats_path, 'stats.npz')
        if os.path.exists(stats_file):
            print('Loading sample statistics...')
            with np.load(stats_file) as stats:
                self.mu = stats['mu_hat']
                self.sigma = stats['sigma_hat']
        elif train_s_list is None or train_d_list is None:
            raise ValueError('No stats.npz file exists. data/stats.npz is available here: https://github.com/anicolson/DeepXi/blob/master/data/stats.npz.')
        else:
            missing = [name for name in ('wav_batch', 'mix') if not hasattr(self, name)]
            if missing:
                raise NotImplementedError('Cannot estimate sample statistics: the training '
                    'helper(s) %s are not defined (they are commented out in this '
                    'inference-only notebook).' % ', '.join(missing))
            print('Finding sample statistics...')
            # Use the lists that were passed in; `wav_batch` reads
            # `self.train_d_list`, so keep them on the instance as well.
            self.train_s_list, self.train_d_list = train_s_list, train_d_list
            s_sample_list = random.sample(self.train_s_list, sample_size)
            d_sample_list = random.sample(self.train_d_list, sample_size)
            s_sample, d_sample, s_sample_len, d_sample_len, snr_sample = self.wav_batch(s_sample_list, d_sample_list)
            snr_sample = np.array(random.choices(self.snr_levels, k=sample_size))
            samples = []
            for i in tqdm(range(s_sample.shape[0])):
                s_STMS, d_STMS, _, _ = self.mix(s_sample[i:i+1], d_sample[i:i+1], s_sample_len[i:i+1],
                    d_sample_len[i:i+1], snr_sample[i:i+1])
                xi_db = self.xi_db(s_STMS, d_STMS) # instantaneous a priori SNR (dB).
                samples.append(np.squeeze(xi_db.numpy()))
            samples = np.vstack(samples)
            if len(samples.shape) != 2: raise ValueError('Incorrect shape for sample.')
            stats = {'mu_hat': np.mean(samples, axis=0), 'sigma_hat': np.std(samples, axis=0)}
            self.mu, self.sigma = stats['mu_hat'], stats['sigma_hat']
            os.makedirs(stats_path, exist_ok=True)
            np.savez(stats_file, mu_hat=stats['mu_hat'], sigma_hat=stats['sigma_hat'])
            # Equivalent to the removed helper save_mat(path, stats, 'stats').
            savemat(os.path.join(stats_path, 'stats.mat'), {'stats': stats})
            print('Sample statistics saved.')


    def observation_batch(self, x_batch, x_batch_len):
        """Compute zero-padded STFT magnitude and phase for a batch of waveforms.

        Args:
            x_batch: batch of waveforms, indexed as [utterance, sample].
            x_batch_len: length (samples) of each waveform.

        Returns:
            Tuple `(x_STMS_batch, x_STPS_batch, n_frames_batch)`: two float32
            arrays of shape [batch_size, max_n_frames, n_feat] and the list of
            frame counts per utterance.
        """
        batch_size = len(x_batch)
        # n_frames returns tensors; convert to plain ints for NumPy shapes and slices.
        max_n_frames = int(self.n_frames(max(x_batch_len)))
        x_STMS_batch = np.zeros([batch_size, max_n_frames, self.n_feat], np.float32)
        x_STPS_batch = np.zeros([batch_size, max_n_frames, self.n_feat], np.float32)
        n_frames_batch = [int(self.n_frames(i)) for i in x_batch_len]
        for i in tqdm(range(batch_size)):
            x_STMS, x_STPS = self.observation(x_batch[i,:x_batch_len[i]])
            x_STMS_batch[i,:n_frames_batch[i],:] = x_STMS.numpy()
            x_STPS_batch[i,:n_frames_batch[i],:] = x_STPS.numpy()
        return x_STMS_batch, x_STPS_batch, n_frames_batch

#     def wav_batch(self, s_list, d_list):
#         batch_size = len(s_list)
#         max_len = max([dic['wav_len'] for dic in s_list])
#         s_batch = np.zeros([batch_size, max_len], np.int16)
#         d_batch = np.zeros([batch_size, max_len], np.int16)
#         s_batch_len = np.zeros(batch_size, np.int32)
#         for i in range(batch_size):
#             (wav, _) = read_wav(s_list[i]['file_path'])
#             s_batch[i,:s_list[i]['wav_len']] = wav
#             s_batch_len[i] = s_list[i]['wav_len']
#             flag = True
#             while flag:
#                 if d_list[i]['wav_len'] < s_batch_len[i]: d_list[i] = random.choice(self.train_d_list)
#                 else: flag = False
#             (wav, _) = read_wav(d_list[i]['file_path'])
#             rand_idx = np.random.randint(0, 1+d_list[i]['wav_len']-s_batch_len[i])
#             d_batch[i,:s_batch_len[i]] = wav[rand_idx:rand_idx+s_batch_len[i]]
#         d_batch_len = s_batch_len
#         # snr_batch = np.random.randint(self.min_snr, self.max_snr+1, batch_size)
#         snr_batch = np.array(random.choices(self.snr_levels, k=batch_size))
#         return s_batch, d_batch, s_batch_len, d_batch_len, snr_batch


In [6]:
# from tensorflow.keras.layers import Input
# from tensorflow.keras.models import Model
# from tensorflow.keras.optimizers import Adam
# import math
# import numpy as np
# import tensorflow as tf

# class Prelim():
#     def __init__(self,n_feat,network):
#         self.n_feat = n_feat
#         self.n_outp = self.n_feat
#         if self.n_feat < 5: raise ValueError('More input features are required for this exampple.')
#         self.inp = Input(name='inp', shape=[None, self.n_feat], dtype='float32')
#         self.mask = tf.keras.layers.Masking(mask_value=0.0)(self.inp)
#         if network == 'ResNet': self.network = ResNet(self.mask, self.n_outp, B=40, d_model=256, d_f=64, k=3, max_d_rate=16)
#         elif network == 'ResLSTM': self.network = ResLSTM(self.mask, self.n_outp, n_blocks=3, d_model=256)
#         else: raise ValueError('Invalid network type.')
#         self.model = Model(inputs=self.inp, outputs=self.network.outp)
#         self.model.summary()

#     def dataset(self, buffer_size=16):
#         dataset = tf.data.Dataset.from_generator(
#             self.mbatch_gen,
#             (tf.float32, tf.float32, tf.float32),
#             (tf.TensorShape([None, None, self.n_feat]),
#                 tf.TensorShape([None, None, self.n_outp]),
#                 tf.TensorShape([None, None]))
#             )
#         dataset = dataset.prefetch(buffer_size)
#         return dataset

#     def mbatch_gen(self):
#         for _ in range(self.max_epochs):
#             for _ in range(math.ceil(self.batch_size/self.mbatch_size)):
#                 max_seq_len = 75
#                 min_seq_len = 45
#                 x_train = np.random.rand(self.mbatch_size, max_seq_len, self.n_feat)
#                 y_frame = np.zeros(self.n_feat)
#                 y_frame[0] = 0.05
#                 y_frame[1] = 0.99
#                 y_frame[2] = 0.5
#                 y_frame[3] = 0.01
#                 y_frame[4] = 0.75
#                 y_train = np.tile(y_frame, (self.mbatch_size, max_seq_len, 1))
#                 seq_len = np.random.randint(min_seq_len, max_seq_len+1, self.mbatch_size)
#                 seq_mask = tf.cast(tf.sequence_mask(seq_len, maxlen=max_seq_len), tf.float32)
#                 x_train = tf.multiply(x_train, tf.expand_dims(seq_mask, 2))
#                 y_train = tf.multiply(y_train, tf.expand_dims(seq_mask, 2))
#                 yield x_train, y_train, seq_mask


In [7]:
# --- Model variables ---
d_model = 256
n_blocks = 40
d_f = 64
k = 3
max_d_rate = 16
causal = 1
ver = "resnet-1.0c"
f_s = 16000
T_d = 32
T_s = 16
min_snr = -10 # CAN BE REMOVED
max_snr = 20 # CAN BE REMOVED
snr_inter = 1 # CAN BE REMOVED

# --- Inference variables ---
data_path = 'data'
test_x_path = 'set/test_noisy_speech'
out_path = 'out/denoised/'
model_path = 'model/saved_model/'
gain = 'mmse-lsa'

# Causal padding when `causal` is truthy, otherwise "same" padding.
padding = "causal" if causal else "same"

N_d = int(f_s*T_d*0.001) # window duration (samples).
N_s = int(f_s*T_s*0.001) # window shift (samples).
NFFT = int(pow(2, np.ceil(np.log2(N_d)))) # number of DFT components.

# Load the noisy test set (the SNR array returned by Batch is not needed here).
test_x, test_x_len, _, test_x_base_names = Batch(test_x_path)

deepxi = DeepXi(
    N_d=N_d,
    N_s=N_s,
    NFFT=NFFT,
    f_s=f_s,
    min_snr=min_snr,
    max_snr=max_snr,
    snr_inter=snr_inter,
    d_model=d_model,
    n_blocks=n_blocks,
    d_f=d_f,
    k=k,
    max_d_rate=max_d_rate,
    padding=padding,
    causal=causal,
    ver=ver,
)

deepxi.infer(
    test_x=test_x,
    test_x_len=test_x_len,
    test_x_base_names=test_x_base_names,
    model_path=model_path,
    gain=gain,
    out_path=out_path,
    stats_path=data_path,
)


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inp (InputLayer)                [(None, None, 257)]  0                                            
__________________________________________________________________________________________________
masking (Masking)               (None, None, 257)    0           inp[0][0]                        
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 256)    65792       masking[0][0]                    
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, None, 256)    512         conv1d[0][0]                     
______________________________________________________________________________________________

100%|██████████| 2/2 [00:00<00:00, 22.68it/s]

Processing observations...
Performing inference...







  0%|          | 0/2 [00:00<?, ?it/s]

Performing synthesis...


100%|██████████| 2/2 [00:00<00:00,  5.40it/s]


**Removed in code:**

- Removed the argument **'network_type'** from the DeepXi class `__init__` method (as we do not need the ResLSTM network for our application).
- Removed the class, **'class SaveWeights'** as it is redundant. (It was not used anywhere in our required code)
- Removed the variables of the main cell, **'train_s_path','train_d_path','val_s_path','val_d_path'** as they are redundant.
- Removed the variable of the main cell, **'set_path'**, and the corresponding if condition as they are redundant (it was only used for the variables above, which are also redundant).
- Removed the variables of the main cell **'save_model','max_epochs','resume_epoch'** as they are redundant.
- Removed the argument of deepxi.infer, **'test_epoch'**, and the corresponding if condition as they are redundant.
- Removed the methods of deepxi, **'dataset', 'mbatch_gen', 'val_batch','add_score', 'TransformerSchedular'** as they are redundant. ( used during training and validation of the model)
- Removed the variables of main cell **'eval_example','log_iter','mbatch_size','sample_size'** as they are redundant.
- The if statement **'if infer or test:'** of the main cell is removed, as we are preparing the code for inference only. So, the corresponding flag variable **'infer'** of the main cell is also removed.


**Changed in code:**

- Changed the output path, **'out_path'**, to **'out/denoised/'**.
- Changed the default value of the argument **'model_path'** of the `deepxi.infer` method to 'model/saved_model/' (the older path had more nested directories).

**Can be Removed(Commented as of now):** [Can be removed after a discussion]

- The variable **'out_type'** in deepxi.infer method. Can be removed as we are using only one out_type (**'y'**) for our inference currently.  
- Commented (can be removed as well after discussion) the out_types : **deepmmse, xi_hat, ibm_hat, sub_band_ibm_hat** and corresponding if-else conditions from **deepxi.infer** method.   
NOTE: (We are using only one **out_type = 'y'**, which outputs a .wav file. The other types are redundant, as all of their output files are MATLAB files, which I assume are used for performance visualization. I'm not sure if we need those for analysis, but they are not required for inference.)  
- The **'Class Prelim'** seems to be redundant.  

- The deepxi method **'wav_batch'** can be removed if we are going to stick with only loading the **sample stats** file. Because it is only needed when we want to find **sample_stats** from scratch.
- The variables **snr_levels, min_snr, max_snr, snr_inter** of the deepxi class can be removed if we are not going to compute **sample_stats**.
- The gain - **'mmse-lsa'** is what we are using. So, we can remove the related variables.
- The function **'save_mat'** can be removed if we are not going to compute **sample_stats**.
- In the **Class DeepXiInput**, The methods **'example','mix','add_noise_batch','add_noise_pad','add_noise','snr_db','mel_filter_bank','bpoint','hz2mel','mel2hz'** are redundant for the inference part.
- In the main cell, the **if condition** for **padding** can be removed, as we are sure that we are using **padding='causal'**.

**No idea and need to clarify**

- **'sample_stats'** method called by deepxi.infer: I'm not sure what it is. I can see that it loads a **.npz file** from the **data** directory. I believe it is one of the files we need as a prerequisite for running inference with the model. It affects the variables that depend on **'xi_hat'**.