In [None]:
import os
import math
import numpy as np
import tensorflow as tf
import onnx
import tf2onnx
from scipy.io import wavfile
from scipy.io.wavfile import write

# Preprocessing

In [None]:
def prepare_training_data(config):
    """Split a paired input/target recording into train, test and val WAV files.

    Args:
        config: Mapping with the keys
            "input_audio_path"  - path of the input WAV file,
            "target_audio_path" - path of the target WAV file,
            "output_path"       - root directory for the generated files,
            "name"              - file name prefix for the generated files.

    Writes ``<output_path>/<subset>/<name>-input.wav`` and
    ``<output_path>/<subset>/<name>-target.wav`` for the subsets ``train``,
    ``test`` and ``val``. The ``test`` and ``val`` subsets receive the same
    held-out data.

    Raises:
        ValueError: If the two files differ in length or are not mono.
    """
    # The sample rates returned by wavfile.read are not used by this function.
    in_rate, in_data = wavfile.read(config["input_audio_path"])
    out_rate, out_data = wavfile.read(config["target_audio_path"])

    # Raise instead of calling sys.exit(): `sys` is not imported in this
    # notebook, and an exception carries the reason for the failure.
    if len(in_data) != len(out_data):
        raise ValueError("input and target files have different lengths")

    if in_data.ndim > 1 or out_data.ndim > 1:
        raise ValueError("expected mono files")

    # Convert PCM16 to FP32
    if in_data.dtype == "int16":
        in_data = in_data / 32767
        print("In data converted from PCM16 to FP32")
    if out_data.dtype == "int16":
        out_data = out_data / 32767
        print("Out data converted from PCM16 to FP32")

    clean_data = in_data.astype(np.float32).flatten()
    target_data = out_data.astype(np.float32).flatten()

    # Hold out every fifth of the 100 chunks for validation (slice_on_mod default)
    in_train, out_train, in_val, out_val = slice_on_mod(clean_data, target_data)

    # The test subset reuses the validation split.
    subsets = (
        ("train", in_train, out_train),
        ("test", in_val, out_val),
        ("val", in_val, out_val),
    )
    for subset, subset_input, subset_target in subsets:
        prefix = config["output_path"] + "/" + subset + "/" + config["name"]
        save_wav(prefix + "-input.wav", subset_input)
        save_wav(prefix + "-target.wav", subset_target)


In [None]:
def slice_on_mod(input_data, target_data, mod=5):
    """Split paired signals into training and validation parts.

    Both signals are cut into 100 chunks with ``np.array_split``. Every chunk
    whose index is a multiple of ``mod`` is held out for validation; the rest
    is kept for training.

    Args:
        input_data: 1-D array with the input signal.
        target_data: 1-D array with the target signal.
        mod: Hold-out period; cast to ``int`` before use.

    Returns:
        Tuple ``(training_input, training_target, val_input, val_target)``.
        Training chunks keep their original order; validation chunks are
        concatenated from the highest chunk index down to the lowest.
    """
    period = int(mod)

    input_chunks = np.array_split(input_data, 100)
    target_chunks = np.array_split(target_data, 100)

    # Validation chunk indices, highest first.
    held_out = [idx for idx in range(len(input_chunks) - 1, -1, -1) if idx % period == 0]
    held_out_lookup = set(held_out)

    val_input_data = np.concatenate([input_chunks[idx] for idx in held_out])
    val_target_data = np.concatenate([target_chunks[idx] for idx in held_out])

    training_input_data = np.concatenate(
        [chunk for idx, chunk in enumerate(input_chunks) if idx not in held_out_lookup]
    )
    training_target_data = np.concatenate(
        [chunk for idx, chunk in enumerate(target_chunks) if idx not in held_out_lookup]
    )
    return training_input_data, training_target_data, val_input_data, val_target_data

In [None]:
def save_wav(name, data, sample_rate=44100):
    """Write ``data`` to ``name`` as a flattened 32-bit float WAV file.

    Args:
        name: Destination path. Missing parent directories are created.
        data: Array-like audio samples; flattened to one dimension and cast
            to ``np.float32`` before writing.
        sample_rate: Sample rate stored in the WAV header. Defaults to 44100.
    """
    directory = os.path.dirname(name)

    # dirname is "" for a bare file name, and os.makedirs("") raises, so only
    # create a directory when there is one. exist_ok avoids an exists/create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    wavfile.write(name, sample_rate, np.asarray(data).flatten().astype(np.float32))

In [None]:
# Paths of the raw recordings and the destination of the generated data set.
importConfig = dict(
    input_audio_path="TrainingData/ts9-input.wav",
    target_audio_path="TrainingData/ts9-target.wav",
    output_path="Data",
    name="ts9",
)

# Write the train/test/val WAV files below importConfig["output_path"].
prepare_training_data(importConfig)

# Dataloader

In [None]:
class DataSet:
    """Container for named subsets of framed input/target audio.

    Each subset is a dict with the keys ``'input'``, ``'target'`` and
    ``'frame_len'``. ``'input'`` and ``'target'`` are ``None`` until
    ``load_file`` succeeds, after which they hold float32 tensors of shape
    ``(num_frames, frame_len, 1)``.
    """

    def __init__(self, data_dir='Data/'):
        """Remember the root directory that ``load_file`` resolves names against."""
        self.data_dir = data_dir
        self.subsets = {}

    def create_subset(self, name, frame_len=22050):
        """Register (or reset) an empty subset with the given frame length in samples."""
        self.subsets[name] = {'input': None, 'target': None, 'frame_len': frame_len}

    def load_file(self, subset_name, base_filename):
        """Load ``<base_filename>-input.wav`` and ``-target.wav`` into a subset.

        Args:
            subset_name: Name previously passed to ``create_subset``.
            base_filename: Path prefix relative to ``data_dir``.

        Raises:
            ValueError: If ``subset_name`` has not been created.

        A missing file is reported on stdout and the subset is left untouched.
        """
        if subset_name not in self.subsets:
            raise ValueError(f"Subset '{subset_name}' does not exist")

        subset = self.subsets[subset_name]
        frame_len = subset['frame_len']

        input_file = os.path.join(self.data_dir, f"{base_filename}-input.wav")
        target_file = os.path.join(self.data_dir, f"{base_filename}-target.wav")

        # Load both files before storing either, so a missing target file
        # cannot leave the subset with an input but no target.
        try:
            input_frames = self.load_and_process(input_file, frame_len)
            target_frames = self.load_and_process(target_file, frame_len)
        except FileNotFoundError as e:
            print(f"File Not Found: {e.filename}")
            return

        subset['input'] = input_frames
        subset['target'] = target_frames

    def load_and_process(self, file_path, frame_len):
        """Read a WAV file, cast it to float32 (no rescaling) and split it into frames."""
        sample_rate, data = wavfile.read(file_path)
        data = data.astype(np.float32)
        return self.framify(data, frame_len)

    def framify(self, audio, frame_len):
        """Zero-pad ``audio`` to a multiple of ``frame_len`` and reshape it.

        Returns:
            float32 tensor of shape ``(ceil(len(audio) / frame_len), frame_len, 1)``.
        """
        seg_num = math.ceil(audio.shape[0] / frame_len)
        padded_length = seg_num * frame_len
        # Pad only at the end so existing samples keep their positions.
        padded_audio = np.pad(audio, (0, padded_length - audio.shape[0]), mode='constant')

        reshaped_audio = np.reshape(padded_audio, (seg_num, frame_len, 1))
        return tf.convert_to_tensor(reshaped_audio, dtype=tf.float32)


# Training

In [None]:
class StatefulLSTM(tf.keras.Model):
    """Stateful LSTM followed by a linear Dense layer.

    Inputs and outputs are tensors of shape ``(batch, time, channels)``.
    Because the LSTM is stateful, every call must use exactly ``batch_size``
    sequences, and the recurrent state is carried from one call to the next
    until ``reset_hidden`` is called.

    Args:
        input_size: Number of input channels.
        output_size: Number of output channels.
        hidden_size: Number of LSTM units.
        skip: Stored on the instance only. NOTE(review): ``call`` does not
            implement a skip connection.
        bias_fl: Whether the Dense output layer uses a bias.
        batch_size: Fixed batch size of the stateful LSTM.
    """

    def __init__(self, input_size=1, output_size=1, hidden_size=32, skip=1, bias_fl=True, batch_size=4096):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.skip = skip
        self.batch_size = batch_size

        self.lstm = tf.keras.layers.LSTM(units=hidden_size, return_sequences=True, stateful=True, return_state=True, batch_size=batch_size)
        self.dense = tf.keras.layers.Dense(units=output_size, activation=None, batch_size=batch_size, use_bias=bias_fl)

        # Build the LSTM before training: a stateful LSTM needs the batch size
        # up front. The shape is (batch, time, features), so the feature axis
        # must be input_size; for input_size == 1 this equals the previous
        # (batch_size, input_size, 1).
        self.lstm.build((batch_size, 1, input_size))

    def call(self, x):
        """Run the LSTM over ``x`` and project every time step through the Dense layer."""
        # The LSTM also returns its final states; only the sequence is used.
        x, _, _ = self.lstm(x)
        x = self.dense(x)
        return x

    def reset_hidden(self, batch_size):
        """Reset the LSTM state. ``batch_size`` is accepted but not used."""
        self.lstm.reset_states()

    def train_epoch(self, input_data, target_data, loss_fcn, optim, bs, init_len=200, up_fr=1000):
        """Train for one epoch with truncated back-propagation through time.

        Args:
            input_data: Tensor of shape ``(segments, time, channels)``.
            target_data: Tensor with the same layout as ``input_data``.
            loss_fcn: Callable invoked as ``loss_fcn(y_true, y_pred)``.
            optim: Optimizer providing ``apply_gradients``.
            bs: Batch size; must match the batch size the model was built with.
            init_len: Samples processed without a weight update at the start
                of every batch to warm up the recurrent state.
            up_fr: Samples processed between two weight updates.

        Returns:
            Mean over batches of the mean loss per weight update.

        Raises:
            ValueError: If there are fewer than ``bs`` segments, or the
                segments are not longer than ``init_len``.
        """
        num_segments = input_data.shape[0]
        seq_len = input_data.shape[1]

        # The stateful LSTM only accepts batches of exactly `bs` sequences,
        # so an incomplete trailing batch is skipped.
        num_batches = num_segments // bs
        if num_batches == 0:
            raise ValueError(f"need at least {bs} segments for one batch, got {num_segments}")
        if seq_len <= init_len:
            raise ValueError(f"segment length {seq_len} must exceed init_len {init_len}")

        # Shuffle the segments at the start of the epoch
        shuffle = tf.random.shuffle(tf.range(num_segments))

        ep_loss = 0
        for batch_i in range(num_batches):
            # Use tf.gather to index the tensors
            batch_indices = shuffle[batch_i * bs:(batch_i + 1) * bs]
            input_batch = tf.gather(input_data, batch_indices, axis=0)
            target_batch = tf.gather(target_data, batch_indices, axis=0)

            # Each batch holds unrelated segments, so start from a zero state
            # rather than the state left behind by the previous batch.
            self.reset_hidden(bs)

            # Process an initial sequence without updating the weights so that
            # a usable hidden state is available; training only starts once
            # the hidden state has settled.
            self(input_batch[:, 0:init_len, :])

            batch_loss = 0
            num_updates = 0
            # Iterate over the remaining samples in the mini batch
            for start_i in range(init_len, seq_len, up_fr):
                with tf.GradientTape() as g:
                    output = self(input_batch[:, start_i:start_i + up_fr, :])
                    # Keras losses take (y_true, y_pred): the target comes first.
                    loss = loss_fcn(target_batch[:, start_i:start_i + up_fr, :], output)

                # Compute and apply gradients after leaving the tape context.
                dloss_dw = g.gradient(loss, self.trainable_variables)
                optim.apply_gradients(zip(dloss_dw, self.trainable_variables))

                print(f"loss: {loss}")

                batch_loss += loss
                num_updates += 1

            # Add the average loss of this batch to the epoch total
            ep_loss += batch_loss / num_updates

        return ep_loss / num_batches

In [None]:
class ESRLoss(tf.keras.losses.Loss):
    """Error-to-signal ratio: mean squared error divided by the mean power of ``y_true``."""

    def __init__(self):
        super().__init__()
        # Keeps the denominator away from zero for silent targets.
        self.epsilon = 1e-5

    def call(self, y_true, y_pred):
        """Return ``mean((y_true - y_pred)^2) / (mean(y_true^2) + epsilon)``."""
        residual = y_true - y_pred
        error_power = tf.reduce_mean(tf.square(residual))
        signal_power = tf.reduce_mean(tf.square(y_true)) + self.epsilon
        return error_power / signal_power

class DCLoss(tf.keras.losses.Loss):
    """DC-offset loss, normalised by the mean power of ``y_true``.

    Compares the per-sequence mean (DC component) of the target with that of
    the prediction. Tensors are expected as ``(batch, time, channels)``.
    """

    def __init__(self):
        super().__init__()
        # Keeps the denominator away from zero for silent targets.
        self.epsilon = 1e-5

    def call(self, y_true, y_pred):
        """Return the mean squared DC difference over ``mean(y_true^2) + epsilon``."""
        # Average over the time axis (axis 1). Axis 0 is the batch axis, and
        # averaging over it would mix unrelated sequences instead of measuring
        # each sequence's DC offset.
        dc_true = tf.reduce_mean(y_true, axis=1)
        dc_pred = tf.reduce_mean(y_pred, axis=1)
        loss = tf.reduce_mean(tf.square(dc_true - dc_pred))
        energy = tf.reduce_mean(tf.square(y_true)) + self.epsilon
        return loss / energy

class LossWrapper(tf.keras.losses.Loss):
    """Weighted sum of the named loss functions.

    Args:
        loss_weights: Mapping from loss name (``'ESR'``, ``'DC'``) to its
            weight. Only the losses named here are instantiated.

    Raises:
        ValueError: If ``loss_weights`` is empty or names an unknown loss.
    """

    def __init__(self, loss_weights):
        super().__init__()
        # Map the loss names to their corresponding classes
        loss_dict = {'ESR': ESRLoss, 'DC': DCLoss}

        unknown = [key for key in loss_weights if key not in loss_dict]
        if unknown:
            raise ValueError(f"Unknown loss function(s) {unknown}; available: {list(loss_dict)}")
        if not loss_weights:
            raise ValueError("loss_weights must name at least one loss function")

        # Walk the registry so the summation order is always ESR then DC,
        # independent of the key order of loss_weights.
        selected = [key for key in loss_dict if key in loss_weights]
        # Create instances of the loss functions
        self.loss_functions = [loss_dict[key]() for key in selected]
        # Assign the weights
        self.loss_factors = [loss_weights[key] for key in selected]

    def call(self, y_true, y_pred):
        """Return ``sum(weight_i * loss_i(y_true, y_pred))``."""
        total_loss = 0
        for loss_function, factor in zip(self.loss_functions, self.loss_factors):
            total_loss += loss_function(y_true, y_pred) * factor
        return total_loss


In [None]:
# Model and training hyper-parameters.
config = dict(
    input_size=1,           # Number of input channels
    output_size=1,          # Number of output channels
    skip_con=0,             # Is there a skip connection from the input to the output
    epochs=100,
    batch_size=16,
    init_length=200,        # Number of sequence samples to process before starting weight updates
    up_fr=1000,             # For recurrent models, number of samples to run in between updating network weights
    learning_rate=0.0005,
    hidden_size=32,
    loss_fcns={"ESR": 0.75, "DC": 0.25},
    hardware_device="ts9",
)

In [None]:
physical_devices = tf.config.list_physical_devices()
print(f"These are the physical devices available:\n{physical_devices}")

try:
    # Disable all GPUS
    tf.config.set_visible_devices([], 'GPU')
    visible_devices = tf.config.get_visible_devices()
    print(f"These are the visible devices:\n{visible_devices}")
except (ValueError, RuntimeError) as err:
    # set_visible_devices raises RuntimeError once the TensorFlow runtime has
    # been initialised and ValueError for invalid arguments. Hiding the GPUs
    # is best effort, so report the reason and carry on.
    print(f"Could not change the visible devices: {err}")

print("Creating Stateful LSTM")
network = StatefulLSTM(input_size=config["input_size"],
                       output_size=config["output_size"],
                       hidden_size=config["hidden_size"],
                       skip=config["skip_con"],
                       batch_size=config["batch_size"])

optimiser = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"], weight_decay=1e-4, epsilon=1e-8)

# Weighted combination of the losses named in config["loss_fcns"]
loss_functions = LossWrapper(config["loss_fcns"])

network.compile(optimizer=optimiser, loss=loss_functions)
network.build((config["batch_size"],1,1))
network.summary()

# Load the framed training and validation audio written by the preprocessing step
dataset = DataSet()
dataset.create_subset('train', frame_len=22050)
dataset.load_file('train', os.path.join('train', config["hardware_device"]))

dataset.create_subset('val')
dataset.load_file('val', os.path.join('val', config["hardware_device"]))

In [None]:
# Run the configured number of epochs over the training subset.
train_subset = dataset.subsets['train']
for epoch in range(1, config["epochs"] + 1):
    print("Epoch: ", epoch)

    epoch_loss = network.train_epoch(
        input_data=train_subset['input'],
        target_data=train_subset['target'],
        loss_fcn=loss_functions,
        optim=optimiser,
        bs=config['batch_size'],
        init_len=config['init_length'],
        up_fr=config['up_fr'],
    )

    print("Epoch loss:", epoch_loss)

In [None]:
# Make sure the output directory exists before writing either weight file.
os.makedirs('models', exist_ok=True)

# Save the trained weights twice: under the 'models/weights' prefix and as
# the HDF5 file that the inference model loads.
network.save_weights('models/weights')
network.save_weights('models/lstm_model.h5')

In [None]:
inference_batch_size = 1

# Rebuild the network with the hyper-parameters used for training so the saved
# weights always fit; only the batch size differs.
inference_model = StatefulLSTM(input_size=config["input_size"],
                               output_size=config["output_size"],
                               hidden_size=config["hidden_size"],
                               skip=config["skip_con"],
                               batch_size=inference_batch_size)

input_shape = (inference_batch_size,2048,1)

# The variables must exist before the HDF5 weights can be loaded.
inference_model.build(input_shape)

inference_model.load_weights('models/lstm_model.h5')

In [None]:
# Use the framed training input as the inference signal.
# NOTE: `input` shadows the builtin of the same name; later cells rely on it.
input = dataset.subsets['train']['input']
print(input.shape)
print(input.shape)

In [None]:
# Start from a zero state, then run the frames through the model one at a time.
inference_model.reset_hidden(inference_batch_size)
output = inference_model.predict(x=input, batch_size=inference_batch_size)

In [None]:
# Stitch the (frames, frame_len, channels) arrays back into (samples, channels).
# A single reshape replaces the frame-by-frame np.concatenate loop, which
# copied the growing array on every iteration, and it always yields NumPy
# arrays, even when there is only one frame.
output_concat = np.asarray(output).reshape(-1, output.shape[-1])
input_concat = np.asarray(input).reshape(-1, input.shape[-1])

In [None]:
# Store the inference input and the model output as WAV files in the working directory.
wavfile.write("input-tensorflow.wav", 44100, input_concat.reshape(-1, 1))
wavfile.write("train_out-tensorflow.wav", 44100, output_concat.reshape(-1, 1))

In [None]:
# Feed the same all-zero sample twice without resetting the state in between
# and print both predictions.
input_shape = (1,1,1)
test_sequence = tf.zeros(input_shape)

for label in ("prediction", "prediction2"):
    print("Running prediction..")
    prediction = inference_model.predict(test_sequence)
    print(f"{label} {prediction}")

print("test_sequence shape: ", test_sequence.shape)
print("prediction_2 shape: ", prediction.shape)

In [None]:
name = "model_0"
inference_model.reset_hidden(inference_batch_size)
input_shape = [1, 2048, 1]

# Trace the model for a fixed input shape and convert the traced function.
func = tf.function(inference_model).get_concrete_function(
    tf.TensorSpec(input_shape, dtype=tf.float32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([func], inference_model)
tflite_model = converter.convert()

# Save the model. The per-model directory is not created anywhere else, so
# create it here; otherwise open() fails on a fresh run.
tflite_path = "models/"+name+"/stateful-lstm.tflite"
os.makedirs(os.path.dirname(tflite_path), exist_ok=True)
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Export the inference model to ONNX for a fixed (1, 2048, 1) input.
input_shape = [1, 2048, 1]

# The tensor name 'input' is the key used when feeding the ONNX session.
input_signature = [tf.TensorSpec(input_shape, tf.float32, name='input')]

# Convert the model and write it next to the other exports of this model.
onnx_model, _ = tf2onnx.convert.from_keras(inference_model, input_signature, opset=13)
onnx.save(proto=onnx_model, f=f"models/{name}/stateful-lstm-tflite.onnx")

In [None]:
# Load the exported TFLite model and allocate its tensors.
tflite_model_path = f"models/{name}/stateful-lstm.tflite"
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

In [None]:
# Run the TFLite model once on an all-zero frame and print the result.
tflite_input_index = interpreter.get_input_details()[0]["index"]
interpreter.set_tensor(tflite_input_index, tf.zeros((1, 2048, 1)))
interpreter.invoke()
tflite_output_index = interpreter.get_output_details()[0]["index"]
prediction = interpreter.get_tensor(tflite_output_index)
print(f"prediction3 {prediction}")

In [None]:
import onnxruntime
import numpy as np

# Load the exported ONNX model
onnx_model_path = f"models/{name}/stateful-lstm-tflite.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# Create an input tensor with shape (1, 2048, 1) filled with zeros
test_input = np.zeros((1, 2048, 1), dtype=np.float32)

# Run inference twice on the same input
ort_inputs = {"input": test_input}
ort_outputs = ort_session.run(None, ort_inputs)
ort_outputs2 = ort_session.run(None, ort_inputs)

# Print the output of the first run
print(ort_outputs)
# Print the output of the second run
print(ort_outputs2)