<a href="https://colab.research.google.com/github/tmontaj/WaveNet-tf2/blob/master/model/Wavenet_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from tensorflow.python import keras
from tensorflow.python.keras import backend as K

import numpy as np
import tensorflow as tf

import librosa
from scipy.io import wavfile



In [37]:
# --- Network widths -------------------------------------------------------
io_channels = 128      # channels at every layer interface
hidden_channels = 128  # channels inside each gated convolution

# --- Audio framing --------------------------------------------------------
sample_rate = 24000  # input will be standardised to this rate

fft_step = 12.5 / 1000.    # 12.5ms, expressed in seconds
fft_window = 50.0 / 1000.  # 50ms, expressed in seconds

n_fft = 4 * 512

# Frame step / window length converted from seconds to samples.
hop_length = int(sample_rate * fft_step)
win_length = int(sample_rate * fft_window)

# --- Mel filterbank -------------------------------------------------------
n_mels = 80
fmin = 125  # Hz
# fmax = ~8000

# Floor for spectral magnitudes; np.log(spectra_abs_min) is about -4.605.
# "Audio tests" suggest a min log of -4.605 (-6 confirmed fine); compare np.exp(-7.0).
spectra_abs_min = 0.01

# --- Model input/output sizes --------------------------------------------
mel_bins = n_mels              # 80
spectra_bins = 1 + n_fft // 2  # 1025

# Time steps fed to the model, and how many leading steps are cropped from its output.
# (A previous setting was: steps_total, steps_leadin = 1024, 64)
steps_total = 549
steps_leadin = 64


In [38]:
def wavenet_layer(channels, hidden_channels, kernel_size, dilation_rate, name):
    """Return a callable that applies one gated, dilated residual layer.

    Args:
        channels: Number of output channels of the residual and skip
            1x1 convolutions. The layer input must have the same number of
            channels, because it is added to the residual branch.
        hidden_channels: Number of channels of the tanh and sigmoid
            dilated convolutions.
        kernel_size: Kernel width of the two dilated convolutions.
        dilation_rate: Dilation rate of the two dilated convolutions.
        name: Suffix appended to every explicitly named Keras layer
            ('filter_', 'gate_', 'mult_', 'trans_', 'skip_', 'resid_').

    Returns:
        A function ``f(input_)`` that builds the layers on ``input_`` and
        returns the tuple ``(residual_output, skip_output)``.
    """
    # A 'valid' dilated convolution shortens the time axis by this many steps;
    # the same number of zeros is put back at the start of the sequence so the
    # result can be added to the (unshortened) layer input.
    lost_steps = dilation_rate * (kernel_size - 1)

    dilated_conv_args = dict(strides=1, dilation_rate=dilation_rate,
                             padding='valid', use_bias=True)
    pointwise_conv_args = dict(padding='same', use_bias=True)

    def f(input_):
        # Gated activation unit: tanh(filter) * sigmoid(gate).
        tanh_branch = keras.layers.Conv1D(hidden_channels, kernel_size,
                                          activation='tanh',
                                          name='filter_' + name,
                                          **dilated_conv_args)(input_)
        sigmoid_branch = keras.layers.Conv1D(hidden_channels, kernel_size,
                                             activation='sigmoid',
                                             name='gate_' + name,
                                             **dilated_conv_args)(input_)
        gated = keras.layers.Multiply(name='mult_' + name)(
            [tanh_branch, sigmoid_branch])

        # Left-pad only, restoring the original number of time steps.
        gated_full_length = keras.layers.ZeroPadding1D((lost_steps, 0))(gated)

        residual_branch = keras.layers.Conv1D(channels, 1,
                                              activation='linear',
                                              name='trans_' + name,
                                              **pointwise_conv_args)(gated_full_length)
        skip_branch = keras.layers.Conv1D(channels, 1,
                                          activation='relu',
                                          name='skip_' + name,
                                          **pointwise_conv_args)(gated_full_length)

        residual_sum = keras.layers.Add(name='resid_' + name)(
            [residual_branch, input_])
        return residual_sum, skip_branch

    return f

In [39]:
def model_mel_to_spec( input_shape=(steps_total, mel_bins) ):
    """Build the Keras model that maps a log-mel input to log-amplitude and phase outputs.

    The network batch-normalises the input, expands it to ``io_channels``
    with a 1x1 convolution, then runs three stacks of ``wavenet_layer``
    blocks with dilation rates 1, 2, 4, ..., 256. The expanded input and the
    skip output of every block are concatenated on the channel axis and fed
    to two 1x1 convolutions ('log_amp' and 'phase_shift'), each with
    ``spectra_bins`` channels. The first ``steps_leadin`` time steps of both
    are cropped and the results are concatenated on the last axis.

    Args:
        input_shape: ``(time_steps, features)`` of one input example. The
            default is evaluated once, when this function is defined.

    Returns:
        A ``keras.models.Model`` with a single input ('MelInput') and a
        single output ('spec_concat') whose last axis holds ``spectra_bins``
        log-amplitude values followed by ``spectra_bins`` phase values.
    """
    mel_log = keras.layers.Input(shape=input_shape, name='MelInput')

    x = keras.layers.BatchNormalization()(mel_log)

    # 'Resize' to make everything 'io_channels' big at the layer interfaces
    x = keras.layers.Conv1D(io_channels, 1,
                            padding='same', use_bias=True,
                            activation='linear', name='mel_log_expanded')(x)

    kernel_size = 3
    num_stacks = 3
    dilation_rates = [2 ** power for power in range(9)]  # 1, 2, 4, ..., 256

    # The expanded input itself is the first skip connection.
    skips = [x]
    for stack in range(num_stacks):
        for position, dilation_rate in enumerate(dilation_rates, start=1):
            # Yields the layer names '1'..'9', '11'..'19', '21'..'29'.
            layer_name = str(10 * stack + position)
            x, skip = wavenet_layer(io_channels, hidden_channels,
                                    kernel_size, dilation_rate, layer_name)(x)
            skips.append(skip)
    # The residual output x of the last block is not used; only the skips feed the heads.

    skip_overall = keras.layers.Concatenate( axis=-1 )( skips )

    log_amp     = keras.layers.Conv1D(spectra_bins, 1, padding='same',
                                      activation='linear', name='log_amp')(skip_overall)
    phase_shift = keras.layers.Conv1D(spectra_bins, 1, padding='same',
                                      activation='linear', name='phase_shift')(skip_overall)

    # Drop the first steps_leadin time steps of each head.
    log_amp_valid     = keras.layers.Cropping1D( (steps_leadin,0), name='crop_a' )( log_amp )
    phase_shift_valid = keras.layers.Cropping1D( (steps_leadin,0), name='crop_p' )( phase_shift )

    # Concat the amps and phases into one return value
    spec_concat = keras.layers.Concatenate( axis=-1, name='spec_concat')(
        [log_amp_valid, phase_shift_valid] )

    return keras.models.Model(inputs= mel_log, outputs=spec_concat)

# Build the model with its default input shape and print the per-layer summary.
keras_model = model_mel_to_spec()
keras_model.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
MelInput (InputLayer)           [(None, 549, 80)]    0                                            
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 549, 80)      320         MelInput[0][0]                   
__________________________________________________________________________________________________
mel_log_expanded (Conv1D)       (None, 549, 128)     10368       batch_normalization_8[0][0]      
__________________________________________________________________________________________________
filter_1 (Conv1D)               (None, 547, 128)     49280       mel_log_expanded[0][0]           
____________________________________________________________________________________________

In [40]:
def customLoss(spec_gold, spec_out):
    """Sum of the mean-squared errors of the log-amplitude and phase halves.

    Both tensors are split on their last axis: the first ``spectra_bins``
    entries are treated as log amplitudes and the remaining entries as
    phases, matching the order in which the model concatenates its outputs.

    Args:
        spec_gold: Target tensor, indexed as ``[batch, time, feature]``.
        spec_out: Predicted tensor with the same layout.

    Returns:
        ``mse(log_amp) + 1.0 * mse(phase)``, where each MSE is the value
        returned by ``keras.losses.mean_squared_error``.
    """
    # Plain slicing is enough here; wrapping each slice in a Lambda layer
    # would create four new layer objects every time the loss is called.
    gold_l_amp = spec_gold[:, :, :spectra_bins]
    gold_phase = spec_gold[:, :, spectra_bins:]

    spec_l_amp = spec_out[:, :, :spectra_bins]
    spec_phase = spec_out[:, :, spectra_bins:]

    l_amp_loss = keras.losses.mean_squared_error( gold_l_amp, spec_l_amp )
    phase_loss = keras.losses.mean_squared_error( gold_phase, spec_phase )

    # Relative weight of the phase term; 1.0 weights both terms equally.
    phase_weight = 1.0
    return l_amp_loss + phase_weight * phase_loss

# Train with Adam at its default learning rate (no learning rate is passed),
# minimising customLoss and reporting plain MSE over the whole output as a metric.
keras_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=customLoss,
    metrics=['mse'],
)

In [None]:
def load_wav(path, sampling_rate):
    """Load an audio file with librosa and return only its samples.

    Args:
        path: Location of the audio file.
        sampling_rate: Passed to librosa as ``sr``.

    Returns:
        The first element of the tuple returned by ``librosa.core.load``
        (the audio samples); the second element is discarded.
    """
    samples, _ = librosa.core.load(path, sr=sampling_rate)
    return samples

def normalize(wav):
    """Return ``wav`` normalised by ``librosa.util.normalize`` with its default settings."""
    normalized = librosa.util.normalize(wav)
    return normalized


def save_wav(wav, path, sr):
    """Peak-scale ``wav`` to the int16 range and write it as a WAV file.

    The samples are multiplied by ``32767 / max(0.0001, peak)``, where
    ``peak`` is the largest absolute sample value, then cast to int16.
    The 0.0001 floor avoids dividing by zero on silent input.

    The caller's array is left untouched: scaling is done on a new array
    rather than in place, so ``wav`` can be reused after saving and integer
    arrays are accepted as well.

    Args:
        wav: Array-like of audio samples.
        path: Destination passed to ``scipy.io.wavfile.write``.
        sr: Sample rate written to the file header.
    """
    samples = np.asarray(wav)
    peak = max(0.0001, np.max(np.abs(samples)))
    scaled = samples * (32767 / peak)  # new array; does not modify `wav`
    wavfile.write(path, sr, scaled.astype(np.int16))



def melspectrogram(wav, sampling_rate, num_mels, n_fft, hop_size, win_size):
    """Compute a log10 mel-scaled magnitude spectrogram of ``wav``.

    Args:
        wav: Audio samples passed to ``librosa.stft``.
        sampling_rate: Sample rate used to build the mel filterbank.
        num_mels: Number of mel bands.
        n_fft: FFT size for both the STFT and the mel filterbank.
        hop_size: STFT hop length.
        win_size: STFT window length.

    Returns:
        ``log10(max(mel_filter @ |STFT|, 1e-5))``; the 1e-5 floor keeps the
        logarithm finite. The first axis is the mel band (rows of the
        filterbank), the second is the STFT frame.
    """
    d = librosa.stft(y=wav, n_fft=n_fft, hop_length=hop_size,
                     win_length=win_size, pad_mode='constant')
    # Keyword arguments: librosa >= 0.10 no longer accepts sr / n_fft
    # positionally, while older releases accept the keyword form too.
    mel_filter = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft,
                                     n_mels=num_mels)
    s = np.dot(mel_filter, np.abs(d))

    return np.log10(np.maximum(s, 1e-5))

In [41]:
# Load the example clip, peak-normalise it and scale it by 0.95.
wav = normalize(load_wav("/content/Hol_After.wav", sample_rate)) * 0.95

# Log-mel spectrogram using the framing constants from the config cell.
mel_sp = melspectrogram(wav, sample_rate, n_mels,
                        n_fft=n_fft,
                        hop_size=hop_length,
                        win_size=win_length)

# Swap the two axes, then add a leading batch axis of size 1 for the model.
mel_sp = tf.expand_dims(tf.transpose(mel_sp), axis=0)

mel_sp.shape


TensorShape([1, 549, 80])

In [44]:
# Run the model on the single prepared mel spectrogram.
r= keras_model.predict(mel_sp)

In [45]:
# Display the raw prediction array.
r

array([[[-0.3698137 , -0.79270756, -0.43455195, ...,  0.34498277,
         -0.37351266, -0.0866404 ],
        [-0.26663965, -0.6877992 , -0.37443495, ...,  0.23802486,
         -0.3141052 ,  0.13533504],
        [-0.4044853 , -0.5826516 , -0.15386581, ...,  0.11927441,
         -0.454512  ,  0.14621621],
        ...,
        [-1.8331249 , -0.587446  , -0.4204126 , ..., -0.5169747 ,
         -0.37288076, -1.6947552 ],
        [-1.6884118 , -0.93916166, -0.26628405, ..., -0.2836218 ,
         -0.25958475, -1.5539627 ],
        [-1.8286526 , -1.0480152 , -0.3294282 , ..., -0.35248396,
         -0.17936859, -1.4426664 ]]], dtype=float32)

In [46]:
# Expected: (batch, steps_total - steps_leadin, 2 * spectra_bins).
r.shape

(1, 485, 2050)