In [None]:
#!pip install tensorflow
#!pip install pedalboard

# TO USE: 
#    1. Upload your input and output wav files to the current directory in Colab
#    2. Edit the USER INPUTS section to point to your wav files, and choose a
#         model name, and number of epochs for training. 
#    3. Run each section of code. The trained models and output wav files will be 
#         added to the "models" directory.
#
#     Note: Tested on CPU and GPU runtimes.
#     Note: Uses MSE as the training loss; Error to Signal with Pre-emphasis filter is only tracked as a metric

'''This is a Tensorflow/Keras implementation similar to the LSTM model from the paper:
    "Real-Time Guitar Amplifier Emulation with Deep Learning"
    https://www.mdpi.com/2076-3417/10/3/766/htm

    Uses a stack of two 1-D Convolutional layers, followed by LSTM, followed by 
    a Dense (fully connected) layer. Three preset training modes are available, 
    with further customization by editing the code. A Sequential tf.keras model 
    is implemented here.

    Note: RAM may be a limiting factor for the parameter "input_size". The wav data
      is preprocessed and stored in RAM, which improves training speed but quickly runs out
      if using a large number for "input_size".  Reduce this if you are experiencing
      RAM issues. 
    
    train_mode = 0   Speed training (default)
    train_mode = 1   Accuracy training
    train_mode = 2   Extended training (set epochs as desired, for example 50+)
'''

In [None]:
#import soundfile as sf
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Conv1D, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.backend import clear_session
#from tensorflow.keras.activations import tanh, elu, relu
#from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from tensorflow.keras.utils import Sequence
from sklearn.model_selection import train_test_split
from scipy.io import wavfile
import numpy as np
import os
import matplotlib.pyplot as plt
#import math
import h5py
import librosa
from pedalboard import (
    Pedalboard,
    Convolution,
    Compressor,
    Chorus,
    Distortion,
    Gain,
    Reverb,
    Limiter,
    LadderFilter,
    Phaser,
)

In [None]:
# EDIT THIS SECTION FOR USER INPUTS
#

name = 'Chorus_Test'
if os.path.exists('models/'+name):
    # A bare `exit` expression is only a name lookup and does not reliably stop
    # the notebook, so raise instead: this keeps the remaining cells from
    # silently overwriting a model that was trained earlier under this name.
    raise FileExistsError("A model with the same name already exists. Please choose a new name.")
os.makedirs('models/'+name)

epochs = 20        # number of training epochs passed to model.fit
fs = 44100         # sample rate (Hz) used by save_wav when writing wav files
train_mode = 0     # 0 = speed training,
                   # 1 = accuracy training
                   # 2 = extended training

batch_size = 4     # number of windows per batch
test_size = 0.2    # hold-out fraction (the later cells split the data 80/20)
input_size = 100   # number of consecutive samples in each input window

# Preset hyper-parameters for the selected training mode. Any value other than
# 0 or 1 selects the extended-training preset.
if train_mode == 0:         # Speed Training
    learning_rate = 0.01
    conv1d_strides = 12
    conv1d_filters = 16
    hidden_units = 36
elif train_mode == 1:       # Accuracy Training (~10x longer than Speed Training)
    learning_rate = 0.01
    conv1d_strides = 4
    conv1d_filters = 36
    hidden_units = 64
else:                       # Extended Training (~60x longer than Accuracy Training)
    learning_rate = 0.0005
    conv1d_strides = 3
    conv1d_filters = 36
    hidden_units = 96

In [None]:
class WindowArray(Sequence):
    """Keras ``Sequence`` that serves batches of sliding windows over a signal.

    Sample ``i`` is the slice ``x[i : i + window_len]``; its target is
    ``y[i + window_len - 1]``, i.e. the row of ``y`` aligned with the last row
    of the window. Windows are built lazily, one batch at a time.

    Only complete batches are served: trailing windows that do not fill a
    whole batch are dropped.

    Args:
        x: input signal, indexed along its first axis.
        y: target signal, indexed along its first axis like ``x``.
        window_len: number of consecutive rows of ``x`` in each window.
        batch_size: number of windows per batch.
    """

    def __init__(self, x, y, window_len, batch_size=32):
        # Initialise the Keras base class; newer Keras releases warn when a
        # Sequence/PyDataset subclass skips this call.
        super().__init__()
        self.x = x
        # Drop the first window_len-1 targets so that y[i] lines up with the
        # last sample of window i.
        self.y = y[window_len-1:]
        self.window_len = window_len
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of complete batches (never negative)."""
        n_windows = len(self.x) - self.window_len + 1
        # A signal shorter than one window would otherwise give a negative
        # length, which makes len() raise a confusing ValueError.
        return max(0, n_windows // self.batch_size)

    def __getitem__(self, index):
        """Return ``(x_batch, y_batch)`` for batch number ``index``.

        Raises:
            IndexError: if ``index`` is outside ``range(len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError("batch index %d out of range for %d batches" % (index, len(self)))
        first = index * self.batch_size
        last = first + self.batch_size
        x_out = np.stack([self.x[idx: idx+self.window_len] for idx in range(first, last)])
        y_out = self.y[first:last]
        return x_out, y_out

def pre_emphasis_filter(x, coeff=0.95):
    """Return ``x`` concatenated along axis 1 with ``x - coeff * x``.

    NOTE(review): a conventional pre-emphasis filter subtracts the *previous*
    sample (``x[n] - coeff * x[n-1]``). Here the same sample is used, so the
    appended half is just ``x`` scaled by ``1 - coeff`` -- confirm whether
    that is intended before relying on this as a loss.
    """
    emphasized = x - coeff * x
    return tf.concat([x, emphasized], axis=1)
    
def error_to_signal(y_true, y_pred):
    """
    Error-to-signal ratio computed on the pre-emphasis-filtered signals.

    Both tensors are passed through ``pre_emphasis_filter``; the squared error
    and the squared target are each summed over axis 0, and the ratio is
    returned. A small constant in the denominator avoids division by zero.
    """
    target = pre_emphasis_filter(y_true)
    estimate = pre_emphasis_filter(y_pred)
    error_energy = K.sum(tf.pow(target - estimate, 2), axis=0)
    signal_energy = K.sum(tf.pow(target, 2), axis=0)
    return error_energy / (signal_energy + 1e-10)
    
def save_wav(name, data):
    """Write ``data`` to the wav file ``name`` as flattened float32 samples.

    The sample rate is taken from the notebook-level ``fs`` setting.
    """
    samples = data.flatten()
    wavfile.write(name, fs, samples.astype(np.float32))

# normalize data to loudest signal
def normalize(data):
    """Scale ``data`` so that its largest absolute value becomes 1.

    Args:
        data: array-like of samples (any shape).

    Returns:
        ``data`` divided by its peak absolute value. An empty or all-zero
        input is returned unscaled (as a new array) instead of producing
        NaNs from a division by zero.
    """
    data = np.asarray(data)
    # Vectorised peak search; equivalent to max(max(data), abs(min(data)))
    # but without iterating over the array in Python.
    peak = np.max(np.abs(data)) if data.size else 0
    if peak == 0:
        # Silent (or empty) signal: nothing to scale.
        return data / 1.0
    return data / peak

# add fadeout with length samples
def apply_fadeout(audio, length):
    """Apply a linear fade-out to the last ``length`` samples of ``audio``.

    The fade is applied IN PLACE: ``audio`` itself is modified and also
    returned.

    Args:
        audio: sample array, time along the first axis; any further axes
            (e.g. channels) receive the same fade curve.
        length: fade duration in samples. Values larger than ``len(audio)``
            fade the whole signal; values <= 0 leave it untouched.

    Returns:
        The same ``audio`` object, faded.
    """
    end = len(audio)
    # Clamp so an over-long fade cannot produce a negative start index.
    length = max(0, min(int(length), end))
    if length == 0:
        return audio
    start = end - length

    # compute fade out curve
    # linear fade, shaped (length, 1, ..., 1) so it broadcasts over channels
    fade_curve = np.linspace(1.0, 0.0, length)
    fade_curve = fade_curve.reshape((length,) + (1,) * (np.ndim(audio) - 1))
    audio[start:end] = audio[start:end] * fade_curve
    return audio

In [None]:
# Create Sequential Model ###########################################
# Architecture: two strided 1-D convolutions -> LSTM -> single linear output.
clear_session()
model = Sequential([
    Conv1D(
        conv1d_filters,
        12,
        strides=conv1d_strides,
        activation=None,
        padding='same',
        input_shape=(input_size, 1),
    ),
    Conv1D(
        conv1d_filters,
        12,
        strides=conv1d_strides,
        activation=None,
        padding='same',
    ),
    LSTM(hidden_units),
    Dense(1, activation=None),
])
# MSE is the training loss; the error-to-signal ratio is only reported.
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss='mse',
    metrics=[error_to_signal],
)
model.summary()

In [None]:
# Prepare training / validation data ################################
# NOTE(review): `dry` and `wet` are expected to be defined by an earlier cell
# that is not visible in this section -- confirm they are loaded before this runs.

# Optional cap on the number of samples used (None = use the whole signal).
# The previous hard-coded cap of 100 samples left each split shorter than one
# input window (input_size), so no batch could be built.
max_samples = None

dry_samples = np.asarray(dry).astype(np.float32).flatten()
wet_samples = np.asarray(wet).astype(np.float32).flatten()

# Input and target must line up sample-for-sample; trim to the common length.
n_samples = min(len(dry_samples), len(wet_samples))
if max_samples is not None:
    n_samples = min(n_samples, max_samples)

X_all = normalize(dry_samples[:n_samples]).reshape(-1, 1)
y_all = normalize(wet_samples[:n_samples]).reshape(-1, 1)

# Chronological split: first part for training, last `test_size` for validation.
train_examples = int(len(X_all) * (1 - test_size))

# Each split needs at least one full batch of windows.
min_split = input_size + batch_size - 1
if train_examples < min_split or len(X_all) - train_examples < min_split:
    raise ValueError(
        "Not enough audio: each split needs at least %d samples "
        "(got %d for training and %d for validation)."
        % (min_split, train_examples, len(X_all) - train_examples)
    )

train_arr = WindowArray(X_all[:train_examples], y_all[:train_examples], input_size, batch_size=batch_size)
val_arr = WindowArray(X_all[train_examples:], y_all[train_examples:], input_size, batch_size=batch_size)
print(len(train_arr))

In [None]:
# Train Sequential Model ###################################################
# Single definition of the model path, used for saving and for the metadata below.
filename = 'models/'+name+'/'+name+'.h5'
history = model.fit(train_arr, validation_data=val_arr, epochs=epochs, shuffle=True)
model.save(filename)

# Run Prediction #################################################
print("Running prediction..")

# Get the held-out tail of the wav data (the last `test_size` fraction) to run
# prediction and plot results
split_index = int(len(X_all) * (1 - test_size))
y_the_rest, y_last_part = np.split(y_all, [split_index])
x_the_rest, x_last_part = np.split(X_all, [split_index])
test_arr = WindowArray(x_last_part, y_last_part, input_size, batch_size = batch_size)

prediction = model.predict(test_arr)

# WindowArray only serves complete batches, so the prediction can be a few
# samples shorter than the aligned targets; trim the targets so that
# y_pred.wav and y_test.wav have the same length.
y_test = y_last_part[input_size-1:][:len(prediction)]

save_wav('models/'+name+'/y_pred.wav', prediction)
save_wav('models/'+name+'/x_test.wav', x_last_part)
save_wav('models/'+name+'/y_test.wav', y_test)

# Add additional data to the saved model (like input_size)
# The context manager closes the file even if writing the metadata fails.
with h5py.File(filename, 'a') as f:
    grp = f.create_group("info")
    dset = grp.create_dataset("input_size", (1,), dtype='int16')
    dset[0] = input_size

In [None]:
# visualizing the loss and the error-to-signal metric
train_loss, val_loss = history.history['loss'], history.history['val_loss']
# The metric is recorded under the metric function's name; use .get so the
# cell still works if the model was compiled without it.
train_esr = history.history.get('error_to_signal')
val_esr = history.history.get('val_error_to_signal')

# Use the number of epochs actually recorded rather than the configured
# `epochs`, so the plot still works if training stopped early.
epoch_axis = range(len(train_loss))

# setup plot
fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(15,5))

# plot loss
ax[0].plot(epoch_axis, train_loss)
ax[0].plot(epoch_axis, val_loss)
ax[0].set_ylabel('loss')
ax[0].set_title('train_loss vs val_loss')

# plot error-to-signal ratio (hide the panel when the metric is unavailable)
if train_esr is not None and val_esr is not None:
    ax[1].plot(epoch_axis, train_esr)
    ax[1].plot(epoch_axis, val_esr)
    ax[1].set_ylabel('error to signal')
    ax[1].set_title('train_esr vs val_esr')
else:
    ax[1].set_visible(False)

# plot adjustment (only for panels that actually show data)
for a in ax:
    if not a.get_visible():
        continue
    a.grid(True)
    a.legend(['train','val'],loc=4)
    a.set_xlabel('num of Epochs')

plt.show()