In [1]:
import sys
import os
import numpy as np
from scipy.io import wavfile
from scipy.io.wavfile import write
import torch
import torch.nn as nn
import torch.optim as optim
import time
import math
from contextlib import nullcontext
import torchinfo

# 1. Preprocessing

*This code is identical to the code found in the TensorFlow training notebook.*

Basics:
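 - Provide an input and a target audio file in the config.
 - This creates three folders (`train`, `val`, `test`) inside the configured output folder (`Data` below), which are used for training. `val` and `test` receive the same held-out audio.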

In [2]:
def prepare_training_data(config):
    """Split an input/target WAV pair into train, val and test files.

    Args:
        config: dict with the keys
            "input_audio_path":  path of the input WAV file,
            "target_audio_path": path of the target WAV file,
            "output_path":       root folder that receives the split files,
            "name":              file-name prefix of the written files.

    Raises:
        ValueError: if the two files differ in length or are not mono.

    Side effects:
        Writes "<name>-input.wav" and "<name>-target.wav" into the
        "train", "test" and "val" sub-folders of "output_path".
        "test" and "val" receive the same held-out data.
    """
    # The sample rates are read but not forwarded to save_wav.
    _in_rate, in_data = wavfile.read(config["input_audio_path"])
    _out_rate, out_data = wavfile.read(config["target_audio_path"])

    # Fail with a real exception: a bare sys.exit() reports success and
    # hides the reason from whoever called this function.
    if len(in_data) != len(out_data):
        raise ValueError("input and target files have different lengths")

    if len(in_data.shape) > 1 or len(out_data.shape) > 1:
        raise ValueError("expected mono files")

    # Convert PCM16 to FP32
    if in_data.dtype == "int16":
        in_data = in_data / 32767
        print("In data converted from PCM16 to FP32")
    if out_data.dtype == "int16":
        out_data = out_data / 32767
        print("Out data converted from PCM16 to FP32")

    clean_data = in_data.astype(np.float32).flatten()
    target_data = out_data.astype(np.float32).flatten()

    # Hold out every fifth of 100 chunks (slice_on_mod default) for validation
    in_train, out_train, in_val, out_val = slice_on_mod(clean_data, target_data)

    # "test" and "val" are written from the same held-out chunks.
    subsets = (
        ("train", in_train, out_train),
        ("test", in_val, out_val),
        ("val", in_val, out_val),
    )
    for folder, subset_input, subset_target in subsets:
        prefix = config["output_path"] + "/" + folder + "/" + config["name"]
        save_wav(prefix + "-input.wav", subset_input)
        save_wav(prefix + "-target.wav", subset_target)


In [3]:
def slice_on_mod(input_data, target_data, mod=5):
    """Split both signals into a training part and a validation part.

    Each signal is cut into 100 chunks. Chunks whose index is a multiple of
    `mod` (cast to int) form the validation data, all other chunks form the
    training data. Validation chunks are joined in descending index order,
    training chunks in ascending index order.

    Returns:
        (training_input, training_target, validation_input, validation_target)
    """
    step = int(mod)
    chunk_count = 100

    input_chunks = np.array_split(input_data, chunk_count)
    target_chunks = np.array_split(target_data, chunk_count)

    # Highest index first, which is the order the validation chunks are joined in
    held_out = [idx for idx in range(chunk_count - 1, -1, -1) if idx % step == 0]
    kept = [idx for idx in range(chunk_count) if idx % step != 0]

    val_input_data = np.concatenate([input_chunks[idx] for idx in held_out])
    val_target_data = np.concatenate([target_chunks[idx] for idx in held_out])

    training_input_data = np.concatenate([input_chunks[idx] for idx in kept])
    training_target_data = np.concatenate([target_chunks[idx] for idx in kept])

    return training_input_data, training_target_data, val_input_data, val_target_data

In [4]:
def save_wav(name, data, rate=44100):
    """Write `data` to the file `name` as a 32-bit float WAV file.

    Args:
        name: target file path; missing parent folders are created.
        data: numpy array; it is flattened and cast to float32 before writing.
        rate: sample rate stored in the file header (default 44100).
    """
    directory = os.path.dirname(name)

    # dirname is "" for a bare file name, and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    wavfile.write(name, rate, data.flatten().astype(np.float32))

In [5]:
# Source recordings and destination of the train/val/test split.
importConfig = dict(
    input_audio_path="TrainingData/flanger-input.wav",
    target_audio_path="TrainingData/flanger-target.wav",
    output_path="Data",
    name="flanger",
)

# Writes Data/{train,test,val}/flanger-{input,target}.wav
prepare_training_data(importConfig)

In data converted from PCM16 to FP32
Out data converted from PCM16 to FP32


# Dataloader

In [6]:
def framify(audio, frame_len):
    """Cut numpy audio into equal-length frames and return them as a torch tensor.

    Args:
        audio: array of shape (samples,) or (samples, channels).
        frame_len: frame length in samples; 0 keeps the whole signal as a
            single frame.

    Returns:
        Tensor of shape (frame_len, number_of_frames, channels). Trailing
        samples that do not fill a complete frame are dropped.
    """
    # Give mono audio a channel axis so both layouts share one code path
    if len(audio.shape) == 1:
        audio = np.expand_dims(audio, 1)

    total_samples = audio.shape[0]
    channels = audio.shape[1]

    if frame_len:
        seg_num = math.floor(total_samples / frame_len)
    else:
        # No frame length given: one frame spanning the whole signal
        seg_num = 1
        frame_len = total_samples

    dataset = torch.empty((frame_len, seg_num, channels))
    if seg_num:
        # View the usable samples as (frame, time, channel), then swap the
        # first two axes and copy everything into the output in one go.
        usable = torch.from_numpy(audio[:seg_num * frame_len, :])
        dataset.copy_(usable.reshape(seg_num, frame_len, channels).transpose(0, 1))
    return dataset

In [7]:
class DataSet:
    """Collection of named subsets filled from '<name>-<extension>.wav' files.

    Attributes:
        extensions: file-name suffixes loaded per file (default 'input' and
            'target'); an empty value falls back to [''].
        subsets: dict mapping a subset name to its SubSet.
        data_dir: folder the files are read from.
    """

    def __init__(self, data_dir='../Dataset/', extensions=('input', 'target')):
        self.extensions = extensions if extensions else ['']
        self.subsets = {}
        assert type(data_dir) == str, "data_dir should be string,not %r" % {type(data_dir)}
        self.data_dir = data_dir

    def create_subset(self, name, frame_len=0):
        """Add an empty subset called `name`.

        `frame_len` is forwarded unchanged to SubSet; 0 means one long frame.
        NOTE(review): the training cell passes 22050, so this appears to be a
        sample count rather than seconds - confirm.
        """
        assert type(name) == str, "data subset name must be a string, not %r" %{type(name)}
        assert not (name in self.subsets), "subset %r already exists" %name
        self.subsets[name] = SubSet(frame_len)

    def load_file(self, filename, set_names='train', splits=None, cond_val=None):
        """Load '<filename>-<ext>.wav' for every extension into a subset.

        Args:
            filename: file name relative to data_dir, without the '-<ext>' part.
            set_names: one subset name, or a list of names.
            splits: must have one entry per subset name when several names
                are given.
            cond_val: passed through to SubSet.add_data.

        Raises:
            AssertionError: if a named subset does not exist, or if several
                names are given without a matching `splits`.

        If a file is missing, a message is printed and loading stops.
        """
        if type(set_names) == str:
            set_names = [set_names]
        # Check `splits is not None` first so a missing `splits` produces the
        # assertion message instead of a TypeError from len(None).
        assert len(set_names) == 1 or (splits is not None and len(set_names) == len(splits)), \
            "number of subset names must equal number of split markers"
        # A non-empty list is always truthy, so test every name explicitly.
        assert all(each in self.subsets for each in set_names), \
            "set_names contains subsets that don't exist yet"

        # Load each of the 'extensions'
        for ext in self.extensions:
            file_loc = os.path.join(self.data_dir, filename + '-' + ext)
            file_loc = file_loc + '.wav' if not file_loc.endswith('.wav') else file_loc
            try:
                sample_rate, samples = wavfile.read(file_loc)
            except FileNotFoundError:
                # Report the path that was actually tried
                print("File Not Found At: " + file_loc)
                return

            raw_audio = samples.astype(np.float32)

            # NOTE(review): only the single-subset case stores data. With
            # several set_names nothing is added and `splits` is not applied.
            if len(set_names) == 1:
                self.subsets[set_names[0]].add_data(sample_rate, raw_audio, ext, cond_val)

In [8]:
class SubSet:
    """Holds the framed audio tensors of one data subset.

    `frame_len` is passed unchanged to framify for every file that is added;
    0 stores each file as one single frame.
    """

    def __init__(self, frame_len):
        self.frame_len = frame_len
        self.fs = None            # sample rate of the first file added
        self.conditioning = None  # not written anywhere in this class
        self.data = {}            # maps an extension to a 1-tuple holding the frame tensor

    def add_data(self, fs, audio, ext, cond_val):
        """Frame `audio` and append the frames under the key `ext`.

        The first call fixes the subset's sample rate; a later call with a
        different `fs` fails the assertion. An empty `ext` is stored under
        the key 'data'. `cond_val` is accepted but not used here.
        """
        if not self.fs:
            self.fs = fs
        assert self.fs == fs, "data with different sample rate provided to subset"

        key = ext if ext else 'data'
        new_frames = framify(audio, self.frame_len)

        if key in self.data:
            # Append the new frames along the frame axis (dim 1)
            stored_frames = self.data[key][0]
            self.data[key] = (torch.cat((stored_frames, new_frames), 1),)
        else:
            # First data for this key
            self.data[key] = (new_frames,)

# Training

In [9]:
class StatefulLSTM(nn.Module):
    """Single-layer LSTM plus linear output layer that keeps its state between calls.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        hidden_size: number of LSTM units.
        skip: number of leading input channels added to the output as a skip
            connection (0 disables it).
        bias_fl: whether the linear layer has a bias.

    The state lives in `self.hidden` as a tuple (h, c). Each forward call
    starts from it and overwrites it with the state the LSTM returns.
    """

    def __init__(self, input_size=1, output_size=1, hidden_size=32, skip=0, bias_fl=True):
        super(StatefulLSTM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.skip = skip

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        self.lin = nn.Linear(hidden_size, output_size, bias=bias_fl)
        # Start with a zero state for a batch of one sequence
        self.reset_hidden(1)

    def forward(self, x):
        """Process `x` of shape (time, batch, input_size), continuing from the stored state."""
        if self.skip:
            # save the residual for the skip connection
            res = x[:, :, 0:self.skip]
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x) + res
        else:
            x, self.hidden = self.lstm(x, self.hidden)
            x = self.lin(x)
            return x

    def detach_hidden(self):
        """Cut the stored state loose from the autograd graph that produced it."""
        if isinstance(self.hidden, tuple):
            self.hidden = tuple([h.clone().detach() for h in self.hidden])
        else:
            self.hidden = self.hidden.clone().detach()

    def reset_hidden(self, batch_size):
        """Set the stored state to zeros for `batch_size` parallel sequences."""
        # nn.LSTM expects (num_layers, batch, hidden_size) for h and c. The
        # first dimension is the layer count, not the input size. Device and
        # dtype follow the model parameters.
        ref = self.lin.weight
        shape = (self.lstm.num_layers, batch_size, self.hidden_size)
        self.hidden = (torch.zeros(shape, device=ref.device, dtype=ref.dtype),
                       torch.zeros(shape, device=ref.device, dtype=ref.dtype))

    def train_epoch(self, input_data, target_data, loss_fcn, optim, bs, init_len=200, up_fr=1000):
        """Run one epoch of truncated back-propagation through time.

        Args:
            input_data, target_data: tensors of shape (time, segments, channels).
            loss_fcn: callable(output, target) returning a scalar loss tensor.
            optim: optimiser holding this network's parameters.
            bs: number of segments per batch.
            init_len: samples processed to warm up the state before any
                weight update.
            up_fr: samples processed between two weight updates.

        Returns:
            The mean over batches of the mean chunk loss, detached from the
            graph. 0 if there was nothing to train on.
        """
        # shuffle the segments at the start of the epoch
        shuffle = torch.randperm(input_data.shape[1])
        n_batches = math.ceil(shuffle.shape[0] / bs)

        ep_loss = 0
        for batch_i in range(n_batches):
            # Load batch of shuffled segments
            batch_idx = shuffle[batch_i * bs:(batch_i + 1) * bs]
            input_batch = input_data[:, batch_idx, :]
            target_batch = target_data[:, batch_idx, :]

            # Every batch holds different segments, so start from a zero state
            # sized for this batch (the last batch may be smaller than bs).
            self.reset_hidden(input_batch.shape[1])

            # Warm-up: process an initial stretch so that a usable hidden
            # state exists. Weight updates only start after this.
            if init_len > 0:
                self(input_batch[0:init_len, :, :])
            self.zero_grad()

            # Process the rest of the sequence in chunks of up_fr samples
            n_chunks = math.ceil((input_batch.shape[0] - init_len) / up_fr)
            start_i = init_len
            batch_loss = 0
            for _ in range(n_chunks):
                output = self(input_batch[start_i:start_i + up_fr, :, :])

                # Calculate loss and update network parameters
                loss = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :])
                loss.backward()
                optim.step()

                print(f"loss: {loss}")

                # The hidden state's gradient history is tied to the graph of
                # the previous chunk. Only its latest values are needed, so it
                # is detached here, which drops that history.
                # https://discuss.pytorch.org/t/stupid-question-why-do-you-have-to-detach-the-hidden-state-of-lstms-but-not-the-hidden-state-of-a-linear-network/95089/3
                self.detach_hidden()
                self.zero_grad()

                start_i += up_fr
                # Accumulate a detached value so no graph is kept alive
                batch_loss += loss.detach()

            # Add the average chunk loss of this batch to the epoch loss
            if n_chunks > 0:
                ep_loss += batch_loss / n_chunks

        return ep_loss / n_batches if n_batches > 0 else ep_loss

In [10]:
class ESRLoss(nn.Module):
    """Error-to-signal ratio: mean squared error over mean squared target.

    `epsilon` is added to the target energy so an all-zero target does not
    divide by zero.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        error_power = (target - output).pow(2).mean()
        target_power = target.pow(2).mean() + self.epsilon
        return error_power / target_power
class DCLoss(nn.Module):
    """Squared difference of the means along dim 0, relative to the target energy.

    `epsilon` is added to the target energy so an all-zero target does not
    divide by zero.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        mean_offset = target.mean(0) - output.mean(0)
        offset_power = mean_offset.pow(2).mean()
        target_power = target.pow(2).mean() + self.epsilon
        return offset_power / target_power
class LossWrapper(nn.Module):
    """Weighted sum of the loss functions selected by name.

    Args:
        losses: dict mapping a loss name ('ESR' or 'DC') to its weight.
    """

    def __init__(self, losses):
        super().__init__()
        available = {'ESR': ESRLoss(), 'DC': DCLoss()}

        # One [loss module, weight] pair per requested loss, in dict order
        selected = [[available[loss_name], weight] for loss_name, weight in losses.items()]

        self.loss_functions = tuple(pair[0] for pair in selected)
        try:
            self.loss_factors = tuple(torch.Tensor([pair[1] for pair in selected]))
        except IndexError:
            # NOTE(review): looks unreachable, every pair built above has two
            # entries - confirm before removing.
            self.loss_factors = torch.ones(len(self.loss_functions))

    def forward(self, output, target):
        total = 0
        for loss_fn, factor in zip(self.loss_functions, self.loss_factors):
            total = total + torch.mul(loss_fn(output, target), factor)
        return total

In [11]:
# Training configuration. Keys marked "not read below" are not used by any
# cell visible in this notebook.
config = {
    "input_size": 1, # Number of input channels
    "output_size": 1, # Number of output channels
    "skip_con": 0, # Passed as `skip` to StatefulLSTM; 0 means no skip connection from input to output
    "epochs": 10, # Number of training epochs
    "batch_size": 16, # Number of segments per batch
    "init_length": 200, # Number of sequence samples to process before starting weight updates
    "up_fr": 1000, # For recurrent models, number of samples to run in between updating network weights
    "validation_f": 1, # Validation frequency (in epochs); not read below
    "val_chunk": 1000, # Number of sequence samples to process in each chunk of validation; not read below
    "learning_rate": 0.0005,  # Learning rate of the Adam optimiser
    "hidden_size": 32, # Number of LSTM units
    "loss_fcns": {"ESR": 0.75, "DC": 0.25}, # Loss names and weights for LossWrapper (only used by a commented-out line below)
    "hardware_device": "flanger", # Name of the modelled device: data file prefix and results sub-folder
    "save_location": "Results-PyTorch", # Parent folder for results
    "export_json": 1, # not read below
    "export_torchscript": 1, # not read below
    "stateful_lstm": 1 # not read below
}

In [12]:
# Results folder: <cwd>/<save_location>/<hardware_device>
current_directory = os.getcwd()
result_parent_path = os.path.join(current_directory, config["save_location"])
result_path = os.path.join(result_parent_path, config["hardware_device"])
# makedirs creates the parent folder as well
os.makedirs(result_path, exist_ok=True)

save_path = os.path.join(config["save_location"], config["hardware_device"])

if torch.cuda.is_available():
    print('CUDA device available')
    torch.cuda.set_device(0)
    # set_default_dtype only accepts a torch.dtype, so the old tensor-type
    # class cannot be passed to it. The supported pair is a default dtype
    # plus a default device.
    torch.set_default_dtype(torch.float32)
    torch.set_default_device('cuda')
    device = torch.device('cuda')
else:
    print('CUDA device not available')
    device = torch.device('cpu')

print("Creating Stateful LSTM")
network = StatefulLSTM(input_size=config["input_size"],
                       output_size=config["output_size"],
                       hidden_size=config["hidden_size"],
                       skip=config["skip_con"])

optimiser = torch.optim.Adam(network.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5, verbose=True)
loss_functions = torch.nn.MSELoss(reduction='sum')
#loss_functions = LossWrapper(config["loss_fcns"])

# Run the summary on the device the network was created on. Forcing the CPU
# here would move the weights away from the tensors created on the default
# device.
summary = torchinfo.summary(network, (1, 1, 1), device=device)
print(summary)

# Training data is cut into frames of 22050 samples; validation data is kept
# as one long frame.
dataset = DataSet(data_dir='Data')
dataset.create_subset('train', frame_len=22050)
dataset.load_file(os.path.join('train', config["hardware_device"]), 'train')

dataset.create_subset('val')
dataset.load_file(os.path.join('val', config["hardware_device"]), 'val')

for epoch in range(1, config["epochs"] + 1):
    print("Epoch: ", epoch)
    # Run 1 epoch of training
    epoch_loss = network.train_epoch(dataset.subsets['train'].data['input'][0],
                                     dataset.subsets['train'].data['target'][0],
                                     loss_functions, optimiser, config['batch_size'], config['init_length'], config['up_fr'])

    print("Epoch loss:", epoch_loss)
    

CUDA device not available
Creating Stateful LSTM
Layer (type:depth-idx)                   Output Shape              Param #
StatefulLSTM                             [1, 1, 1]                 --
├─LSTM: 1-1                              [1, 1, 32]                4,480
├─Linear: 1-2                            [1, 1, 1]                 33
Total params: 4,513
Trainable params: 4,513
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.02
Estimated Total Size (MB): 0.02
Epoch:  1
loss: 375.7005615234375
loss: 318.1959228515625
loss: 335.1663818359375
loss: 330.03240966796875
loss: 315.40576171875
loss: 296.50946044921875
loss: 230.29856872558594
loss: 217.5951690673828
loss: 206.94898986816406
loss: 213.00393676757812
loss: 250.07321166992188
loss: 254.5784149169922
loss: 261.7413330078125
loss: 203.62075805664062
loss: 223.8120574951172
loss: 250.25856018066406
loss: 227.7099151611328
loss: 221.15005493164062
loss:

In [27]:
# No other cell creates the "models" folder, and torch.save does not create
# missing parent folders.
os.makedirs("models", exist_ok=True)
# Only parameters are stored; the LSTM state in network.hidden is a plain
# attribute and is not part of the state dict.
torch.save(network.state_dict(), "models/ht1.pth")

In [13]:
# Rebuild the architecture from the training config (arguments in the order
# input_size, output_size, hidden_size, skip).
model = StatefulLSTM(
    config["input_size"],
    config["output_size"],
    config["hidden_size"],
    config["skip_con"],
)

# Restore the trained weights onto the CPU; the result of load_state_dict is
# shown as the cell output.
model.load_state_dict(torch.load("models/ht1.pth", map_location="cpu"))

<All keys matched successfully>

In [14]:
# Shape of the first 96000 validation input samples: (time, frames, channels)
print(dataset.subsets['val'].data['input'][0][:96000].shape)

torch.Size([96000, 1, 1])


In [15]:
# First 96000 samples of the validation input
val_input = dataset.subsets['val'].data['input'][0][0:96000, ...]

# The model keeps its LSTM state between calls. Start from zeros, sized for
# the number of frames in the input, so re-running this cell gives the same
# output.
model.reset_hidden(val_input.shape[1])
with torch.no_grad():
    output = model(val_input)

In [16]:
# Save the first frame and channel of the model output, at the sample rate of
# the validation data.
rendered = output.cpu().numpy()[:, 0, 0]
wavfile.write("best_val_out.wav", dataset.subsets['val'].fs, rendered)

In [None]:
file_name = "test"
direc = ''

# os.path.join ignores an empty first component, so with direc == '' both
# files land in the working directory.
scripted_model_file = os.path.join(direc, file_name + "_scripted.pt")
onnx_model_file = os.path.join(direc, file_name + ".onnx")
# NOTE(review): the cells below load "Results-PyTorch/flanger/model_best*",
# which this cell does not write - confirm where those files come from.

# Scripting the model for compatibility with LibTorch
model.reset_hidden(1)
model_scripted = torch.jit.script(model)
# Saving the scripted model
model_scripted.save(scripted_model_file)

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 1, 1).to(torch.device("cpu"))

# Export the model object itself; there is no `self` at notebook top level.
torch.onnx.export(model=model,
                  args=example,
                  f=onnx_model_file,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'])

In [None]:
# Load a scripted model from the results folder.
# NOTE(review): no visible cell writes this file - confirm where it comes from.
loaded_model = torch.jit.load("Results-PyTorch/flanger/model_best_scripted.pt")

# Example input
# NOTE: sequence_length, batch_size and input_size are not used further in
# this cell; the test input below is hard-coded to shape (1, 1, 1).
sequence_length = 100  # adjust as per your model's training
batch_size = 1         # can be set to 1 for testing individual sequences
input_size = loaded_model.input_size  # should match the model's expected input size

test_input = torch.zeros(1, 1, 1)

# Run the same input twice and print both results, presumably to see whether
# the loaded module carries state from one call to the next.
print(test_input)
test_output = loaded_model(test_input)
print(test_output)
test_output = loaded_model(test_input)
print(test_output)

In [None]:
# onnxruntime is only needed by this cell; numpy is already imported at the
# top of the notebook.
import onnxruntime
import numpy as np

# Load the ONNX model
# NOTE(review): no visible cell writes this file - confirm where it comes from.
onnx_model_path = "Results-PyTorch/flanger/model_best.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# Create an input tensor with shape (1, 1, 1) filled with zeros
test_input = np.zeros((1, 1, 1), dtype=np.float32)

# Run inference twice with the same input; the key "input" matches the
# input name used by the ONNX export cell.
ort_inputs = {"input": test_input}
ort_outputs = ort_session.run(None, ort_inputs)
ort_outputs2 = ort_session.run(None, ort_inputs)

# Print the output of the first run
print(ort_outputs)
# Print the output of the second run
print(ort_outputs2)
