In [None]:
import sys
import os
import math
import torch
import torchinfo
import torch.nn as nn
import onnx
import numpy as np
from scipy.io import wavfile
from scipy.io.wavfile import write

# 1. Preprocessing

In [None]:
def prepare_training_data(config):
    """Read an input/target WAV pair, normalise it and write train/test/val splits.

    Args:
        config: mapping with the keys
            "input_audio_path"  - path of the unprocessed (input) recording,
            "target_audio_path" - path of the processed (target) recording,
            "output_path"       - root directory the splits are written below,
            "name"              - file-name prefix of the written files.

    Side effects:
        Writes "<output_path>/<subset>/<name>-input.wav" and "...-target.wav"
        for the subsets "train", "test" and "val". The "test" and "val"
        subsets receive the same held-out data.

    The function prints a message and calls sys.exit() when the two files
    differ in sample rate or length, or when either file is not mono.
    """
    in_rate, in_data = wavfile.read(config["input_audio_path"])
    out_rate, out_data = wavfile.read(config["target_audio_path"])

    # Sample-aligned pairs are only meaningful when both files share one rate.
    if in_rate != out_rate:
        print("input and target files have different sample rates")
        sys.exit()

    if len(in_data) != len(out_data):
        print("input and target files have different lengths")
        sys.exit()

    if len(in_data.shape) > 1 or len(out_data.shape) > 1:
        print("expected mono files")
        sys.exit()

    # save_wav writes its files with a 44100 Hz header by default, so any
    # other source rate would be mislabelled on disk.
    if in_rate != 44100:
        print(f"warning: source sample rate is {in_rate} Hz but the splits are written as 44100 Hz")

    # Convert PCM16 to FP32 (other integer formats are not rescaled here)
    if in_data.dtype == "int16":
        in_data = in_data / 32767
        print("In data converted from PCM16 to FP32")
    if out_data.dtype == "int16":
        out_data = out_data / 32767
        print("Out data converted from PCM16 to FP32")

    clean_data = in_data.astype(np.float32).flatten()
    target_data = out_data.astype(np.float32).flatten()

    # Hold out every fifth of 100 chunks (20 percent) for validation
    in_train, out_train, in_val, out_val = slice_on_mod(clean_data, target_data)

    # "test" and "val" intentionally share the same held-out data.
    splits = (
        ("train", in_train, out_train),
        ("test", in_val, out_val),
        ("val", in_val, out_val),
    )
    for subset, subset_input, subset_target in splits:
        prefix = config["output_path"] + "/" + subset + "/" + config["name"]
        save_wav(prefix + "-input.wav", subset_input)
        save_wav(prefix + "-target.wav", subset_target)


In [None]:
def slice_on_mod(input_data, target_data, mod=5):
    """Split paired signals into training and validation parts.

    Both signals are cut into 100 chunks with np.array_split. Every chunk
    whose index is a multiple of `mod` is held out for validation; the
    remaining chunks form the training data.

    Args:
        input_data: 1-D array of input samples.
        target_data: 1-D array of target samples, split the same way.
        mod: chunk-index modulus selecting validation chunks (cast to int).

    Returns:
        (training_input, training_target, val_input, val_target).
        Training chunks keep their original order; validation chunks are
        concatenated in descending chunk-index order.
    """
    step = int(mod)

    input_chunks = np.array_split(input_data, 100)
    target_chunks = np.array_split(target_data, 100)
    chunk_count = len(input_chunks)

    # Validation chunk indexes, highest first.
    held_out = [idx for idx in range(chunk_count - 1, -1, -1) if idx % step == 0]
    held_out_lookup = set(held_out)
    kept = [idx for idx in range(chunk_count) if idx not in held_out_lookup]

    val_input_data = np.concatenate([input_chunks[idx] for idx in held_out])
    val_target_data = np.concatenate([target_chunks[idx] for idx in held_out])

    training_input_data = np.concatenate([input_chunks[idx] for idx in kept])
    training_target_data = np.concatenate([target_chunks[idx] for idx in kept])

    return training_input_data, training_target_data, val_input_data, val_target_data

In [None]:
def save_wav(name, data, sample_rate=44100):
    """Write `data` to `name` as a 32-bit float WAV file.

    Args:
        name: destination path; missing parent directories are created.
        data: array-like of samples; it is flattened to one dimension.
        sample_rate: sample rate stored in the WAV header (default 44100).
    """
    directory = os.path.dirname(name)

    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    wavfile.write(name, sample_rate, np.asarray(data).flatten().astype(np.float32))

In [None]:
# Settings for the preprocessing step: the source recordings, the root
# directory that receives the splits, and the file-name prefix of the splits.
importConfig = dict(
    input_audio_path="TrainingData/ts9-input.wav",
    target_audio_path="TrainingData/ts9-target.wav",
    output_path="Data",
    name="ts9",
)

prepare_training_data(importConfig)

# Dataloader

In [None]:
class DataSet:
    """Container for named subsets of framed input/target audio.

    Each subset is a dict with the keys 'input', 'target' (tensors, or None
    until a file pair has been loaded) and 'frame_len' (samples per frame).
    """

    def __init__(self, data_dir='Data/'):
        """Remember the directory that file names are resolved against."""
        self.data_dir = data_dir
        self.subsets = {}

    def create_subset(self, name, frame_len=22050):
        """Register an empty subset, replacing any subset of the same name."""
        self.subsets[name] = {'input': None, 'target': None, 'frame_len': frame_len}

    def load_file(self, subset_name, base_filename):
        """Load "<base_filename>-input.wav" / "-target.wav" into a subset.

        Args:
            subset_name: name of a subset made with create_subset.
            base_filename: path prefix relative to data_dir.

        Raises:
            ValueError: if the subset does not exist.

        A missing file is reported with a printed message and leaves the
        subset unchanged; the pair is only stored when both files loaded.
        """
        if subset_name not in self.subsets:
            raise ValueError(f"Subset '{subset_name}' does not exist")

        subset = self.subsets[subset_name]
        frame_len = subset['frame_len']

        input_file = os.path.join(self.data_dir, f"{base_filename}-input.wav")
        target_file = os.path.join(self.data_dir, f"{base_filename}-target.wav")

        try:
            # Load into locals first so a missing target file cannot leave
            # the subset with a new input but a stale or empty target.
            input_frames = self.load_and_process(input_file, frame_len)
            target_frames = self.load_and_process(target_file, frame_len)
        except FileNotFoundError as e:
            print(f"File Not Found: {e.filename}")
            return

        subset['input'] = input_frames
        subset['target'] = target_frames

    def framify(self, audio, frame_len):
        """Cut audio into zero-padded frames of `frame_len` samples.

        Args:
            audio: array of shape (samples,) or (samples, channels).
            frame_len: number of samples per frame.

        Returns:
            Float tensor of shape (frames, frame_len, channels); the last
            frame is padded with zeros up to frame_len.
        """
        audio = np.expand_dims(audio, 1) if len(audio.shape) == 1 else audio
        seg_num = math.ceil(audio.shape[0] / frame_len)
        padded_length = seg_num * frame_len
        padded_audio = np.pad(audio, ((0, padded_length - audio.shape[0]), (0, 0)), mode='constant')

        reshaped_audio = np.reshape(padded_audio, (seg_num, frame_len, audio.shape[1]))
        return torch.from_numpy(reshaped_audio).float()

    def load_and_process(self, file_path, frame_len):
        """Read a WAV file as float32 and return it framed by framify.

        The sample rate stored in the file is read but not used.
        """
        sample_rate, data = wavfile.read(file_path)
        data = data.astype(np.float32)
        return self.framify(data, frame_len)

# Training

In [None]:
class StatefulLSTM(nn.Module):
    """Single-layer LSTM plus linear output layer that keeps its state between calls.

    The recurrent state is stored in `self.hidden` as an (h, c) tuple and is
    updated by every forward call. Use reset_hidden to zero it and
    detach_hidden to cut it from the autograd graph.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        hidden_size: number of LSTM hidden units.
        skip: number of leading input channels added to the output as a
            residual connection; 0 disables the skip connection.
        bias_fl: whether the linear output layer has a bias term.
    """

    def __init__(self, input_size=1, output_size=1, hidden_size=32, skip=0, bias_fl=True):
        super(StatefulLSTM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.skip = skip

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        self.lin = nn.Linear(hidden_size, output_size, bias=bias_fl)
        self.hidden = self._zero_state(1)

        torch.nn.init.xavier_uniform_(self.lstm.weight_hh_l0)
        torch.nn.init.xavier_uniform_(self.lstm.weight_ih_l0)
        torch.nn.init.zeros_(self.lstm.bias_hh_l0)
        torch.nn.init.zeros_(self.lstm.bias_ih_l0)
        torch.nn.init.xavier_uniform_(self.lin.weight)
        # The linear layer has no bias tensor when bias_fl is False.
        if self.lin.bias is not None:
            torch.nn.init.zeros_(self.lin.bias)

    def _zero_state(self, batch_size):
        """Return a zeroed (h, c) state for `batch_size` parallel sequences."""
        # The leading dimension of an LSTM state is the number of layers,
        # not the number of input channels.
        reference = self.lstm.weight_hh_l0
        shape = (self.lstm.num_layers, batch_size, self.hidden_size)
        return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(shape, dtype=reference.dtype, device=reference.device))

    def forward(self, x):
        """Process `x` of shape (sequence, batch, input_size) and update self.hidden."""
        if self.skip:
            # save the residual for the skip connection
            res = x[:, :, 0:self.skip]
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x) + res
        else:
            x, self.hidden = self.lstm(x, self.hidden)
            x = self.lin(x)
            return x

    def detach_hidden(self):
        """Detach the hidden state, which stops gradient tracking through it."""
        if isinstance(self.hidden, tuple):
            self.hidden = tuple([h.clone().detach() for h in self.hidden])
        else:
            self.hidden = self.hidden.clone().detach()

    def reset_hidden(self, batch_size):
        """Replace the hidden state with zeros sized for `batch_size` sequences."""
        self.hidden = self._zero_state(batch_size)

    def train_epoch(self, input_data, target_data, loss_fcn, optim, bs, init_len=200, up_fr=1000):
        """Run one epoch of truncated back-propagation-through-time training.

        Args:
            input_data: tensor of shape (sequence, segments, channels).
            target_data: tensor shaped like input_data.
            loss_fcn: callable (output, target) -> scalar loss tensor.
            optim: optimiser updating this module's parameters.
            bs: batch size; a trailing batch with fewer segments is skipped.
            init_len: samples processed to warm up the hidden state before
                any weight update.
            up_fr: samples processed between two weight updates.

        Returns:
            The mean, over the trained batches, of each batch's mean chunk
            loss (0.0 when there are fewer than `bs` segments).

        Raises:
            ValueError: if the sequences are not longer than init_len.
        """
        # shuffle the segments at the start of the epoch
        shuffle = torch.randperm(input_data.shape[1])

        # Only full batches are trained on, so only those are counted when
        # the epoch loss is averaged.
        n_batches = shuffle.shape[0] // bs
        n_chunks = math.ceil((input_data.shape[0] - init_len) / up_fr)
        if n_batches > 0 and n_chunks <= 0:
            raise ValueError("sequence length must be greater than init_len")

        self.reset_hidden(bs)

        ep_loss = 0
        for batch_i in range(n_batches):
            # Load batch of shuffled segments
            batch_idx = shuffle[batch_i * bs:(batch_i + 1) * bs]
            input_batch = input_data[:, batch_idx, :]
            target_batch = target_data[:, batch_idx, :]

            # Initialise the hidden state by processing a start sequence so
            # that a usable state exists; weight updates only begin after
            # this warm-up. Then zero the gradient buffers.
            _ = self(input_batch[0:init_len, :, :])

            self.zero_grad()

            # Process the rest of the sequence in chunks of up_fr samples
            start_i = init_len
            batch_loss = 0
            for _chunk in range(n_chunks):
                # Process input batch with neural network
                output = self(input_batch[start_i:start_i + up_fr, :, :])

                # Calculate loss and update network parameters
                loss = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :])
                loss.backward()
                optim.step()

                # The hidden state is connected to the computation graph of
                # the previous outputs. Only its latest value is needed for
                # the next chunk, so it is detached here, which drops that
                # gradient information.
                # https://discuss.pytorch.org/t/stupid-question-why-do-you-have-to-detach-the-hidden-state-of-lstms-but-not-the-hidden-state-of-a-linear-network/95089/3
                self.detach_hidden()
                self.zero_grad()

                print(f"loss: {loss}")

                # Advance to the next chunk; accumulate a detached copy so
                # the running total does not hold on to autograd history.
                start_i += up_fr
                batch_loss += loss.detach()

            # Add the average batch loss to the epoch loss and reset the hidden states to zeros
            ep_loss += batch_loss / n_chunks
            self.reset_hidden(bs)

        return ep_loss / max(n_batches, 1)

In [None]:
class ESRLoss(nn.Module):
    """Mean squared error normalised by the mean energy of the target."""

    def __init__(self):
        super().__init__()
        # Added to the target energy so the division is defined for silence.
        self.epsilon = 0.00001

    def forward(self, output, target):
        """Return mean((target - output)^2) / (mean(target^2) + epsilon)."""
        error_power = (target - output).pow(2).mean()
        target_power = target.pow(2).mean() + self.epsilon
        return error_power / target_power
class DCLoss(nn.Module):
    """Squared difference of the means along dim 0, normalised by target energy."""

    def __init__(self):
        super().__init__()
        # Added to the target energy so the division is defined for silence.
        self.epsilon = 0.00001

    def forward(self, output, target):
        """Return mean((mean_0(target) - mean_0(output))^2) / (mean(target^2) + epsilon)."""
        offset = target.mean(0) - output.mean(0)
        offset_power = offset.pow(2).mean()
        target_power = target.pow(2).mean() + self.epsilon
        return offset_power / target_power
    
class LossWrapper(nn.Module):
    """Weighted sum of named loss functions.

    Args:
        losses: mapping from loss name ('ESR' or 'DC') to its weight.

    Raises:
        KeyError: if `losses` names a loss that is not available.
    """

    def __init__(self, losses):
        super(LossWrapper, self).__init__()
        loss_dict = {'ESR': ESRLoss(), 'DC': DCLoss()}

        selected = list(losses.items())
        unknown = [key for key, _ in selected if key not in loss_dict]
        if unknown:
            raise KeyError(f"Unknown loss function(s) {unknown}; available: {sorted(loss_dict)}")

        # Every entry of a mapping carries a weight, so no fallback to
        # default weights is needed.
        self.loss_functions = tuple(loss_dict[key] for key, _ in selected)
        self.loss_factors = tuple(torch.Tensor([weight for _, weight in selected]))

    def forward(self, output, target):
        """Return the sum of each selected loss multiplied by its weight."""
        loss = 0
        for loss_function, factor in zip(self.loss_functions, self.loss_factors):
            loss += torch.mul(loss_function(output, target), factor)
        return loss

In [None]:
# Model and training settings read by the cells below.
config = dict(
    input_size=1,  # Number of channels
    output_size=1,  # Number of channels
    skip_con=1,  # is there a skip connection for the input to the output
    epochs=20,
    batch_size=50,
    init_length=200,  # Number of sequence samples to process before starting weight updates
    up_fr=1000,  # For recurrent models, number of samples to run in between updating network weights
    learning_rate=0.005,
    hidden_size=20,
    loss_fcns={"ESR": 0.75, "DC": 0.25},  # loss name -> weight
    hardware_device="ts9",  # file-name prefix of the dataset files to load
)

In [None]:
# Report whether a CUDA device is present and select device 0 if so.
if torch.cuda.is_available():
    print('CUDA device available')
    # torch.set_default_dtype() only accepts a floating-point torch.dtype;
    # passing the legacy tensor class torch.cuda.FloatTensor raises an error.
    torch.set_default_dtype(torch.float32)
    torch.cuda.set_device(0)
    # NOTE(review): the dataset tensors are built from NumPy arrays and the
    # summary call below is given the CPU device, so training presumably
    # still runs on the CPU. Running on the GPU would need the model and
    # data moved there explicitly - confirm the intended behaviour.
else:
    print('CUDA device not available')

print("Creating Stateful LSTM")
network = StatefulLSTM(input_size=config["input_size"],
                       output_size=config["output_size"],
                       hidden_size=config["hidden_size"],
                       skip=config["skip_con"])

optimiser = torch.optim.Adam(network.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5, verbose=True)
# loss_functions = torch.nn.MSELoss(reduction='sum')
loss_functions = LossWrapper(config["loss_fcns"])

# Summarise the network for a single sample: (sequence, batch, channel) = (1, 1, 1)
summary = torchinfo.summary(network, (1, 1, 1), device=torch.device("cpu"))
print(summary)

# Load the training and validation subsets written by the preprocessing step
dataset = DataSet()
dataset.create_subset('train', frame_len=22050)
dataset.load_file('train', os.path.join('train', config["hardware_device"]))

dataset.create_subset('val')
dataset.load_file('val', os.path.join('val', config["hardware_device"]))

In [None]:
# train_epoch expects (sequence, segment, channel) tensors, while the dataset
# stores (segment, sequence, channel); swap the first two axes once up front.
train_input = dataset.subsets['train']['input'].swapaxes(0, 1)
train_target = dataset.subsets['train']['target'].swapaxes(0, 1)

for epoch in range(1, config["epochs"] + 1):
    print("Epoch: ", epoch)

    epoch_loss = network.train_epoch(
        train_input,
        train_target,
        loss_functions,
        optimiser,
        config['batch_size'],
        config['init_length'],
        config['up_fr'],
    )

    # scheduler.step(epoch_loss)

    print("Epoch loss:", epoch_loss)

In [None]:
# Save the trained weights under models/<name>/.
name = "model_0"
network.reset_hidden(1)
# torch.save does not create missing parent directories.
os.makedirs(os.path.join("models", name), exist_ok=True)
torch.save(network.state_dict(), "models/"+name+"/"+"stateful-lstm.pth")

In [None]:
# Build a fresh network with the training configuration and restore the saved
# weights into it, mapping all tensors to the CPU.
model = StatefulLSTM(
    input_size=config["input_size"],
    output_size=config["output_size"],
    hidden_size=config["hidden_size"],
    skip=config["skip_con"],
)

checkpoint_path = f"models/{name}/stateful-lstm.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))

In [None]:
# Training input reordered from (segment, sequence, channel) to
# (sequence, segment, channel).
# Note: this name shadows Python's built-in input() for the rest of the notebook.
input = torch.swapaxes(dataset.subsets['train']['input'], 0, 1)
print(input.shape)

In [None]:
# Run the segments through the model one after another without tracking
# gradients. The hidden state is zeroed once and then carried from segment
# to segment by the model itself.
model.reset_hidden(1)
output = torch.zeros_like(input)
with torch.no_grad():
    for seg_idx in range(input.shape[1]):
        segment = input[:, seg_idx:seg_idx + 1, :]
        output[:, seg_idx:seg_idx + 1, :] = model(segment)

In [None]:
# Join the segments end to end: (sequence, segment, channel) becomes
# (segment * sequence, channel). One cat over all segments avoids re-copying
# the growing result on every loop iteration.
output_concat = torch.cat(torch.unbind(output, dim=1), dim=0)
input_concat = torch.cat(torch.unbind(input, dim=1), dim=0)

In [None]:
# Write the concatenated input and model output as single-channel 44.1 kHz WAV files.
for wav_name, signal in (("input-pytorch.wav", input_concat),
                         ("output-pytorch.wav", output_concat)):
    write(wav_name, 44100, signal.cpu().numpy().reshape(-1, 1))

In [None]:
# Compile the model to TorchScript with a zeroed single-sequence hidden state
# and store it next to the state dict.
name = "model_0"
model.reset_hidden(1)
model_scripted = torch.jit.script(model)
torch.jit.save(model_scripted, f"models/{name}/stateful-lstm.pt")

In [None]:
# Export the model to ONNX, traced with a random example input of shape
# (2048, 1, 1), starting from a zeroed hidden state.
filepath = f"models/{name}/stateful-lstm-libtorch.onnx"
example_input = torch.rand(2048, 1, 1)
model.reset_hidden(1)

export_options = dict(
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
)

# Export the model
torch.onnx.export(model=model, args=example_input, f=filepath, **export_options)

In [None]:
# Reload the TorchScript module and feed it the same all-zero input twice.
loaded_model = torch.jit.load(f"models/{name}/stateful-lstm.pt")

# Example input
sequence_length = 100  # adjust as per your model's training
batch_size = 1         # can be set to 1 for testing individual sequences
input_size = loaded_model.input_size  # should match the model's expected input size
print(input_size)

# forward stores the recurrent state on the module, so printing both results
# lets the first and second call be compared.
test_input = torch.zeros(2048, 1, 1)
for run_idx in range(2):
    test_output = loaded_model(test_input)
    print(test_output)

In [None]:
# Repeat the all-zero input check on the already loaded TorchScript module.
test_input = torch.zeros(2048, 1, 1)
for run_idx in range(2):
    test_output = loaded_model(test_input)
    print(test_output)

In [None]:
import onnxruntime
import numpy as np

# Load the exported ONNX model into an ONNX Runtime session
onnx_model_path = f"models/{name}/stateful-lstm-libtorch.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

# Create an all-zero input tensor with shape (2048, 1, 1)
test_input = np.zeros((2048, 1, 1), dtype=np.float32)

# Run inference twice on the same input so both results can be compared
ort_inputs = {"input": test_input}
ort_outputs, ort_outputs2 = (ort_session.run(None, ort_inputs) for run_idx in range(2))

# Print the output of the first and of the second run
for run_result in (ort_outputs, ort_outputs2):
    print(run_result)
