In [28]:
import numpy as np
import scipy as sp

import torch
from torch import nn
import torch.optim as optim

import torchinfo
import os

## Choose computation device (CPU)

In [22]:
if torch.cuda.is_available():
    print('CUDA device available')
    torch.set_default_dtype(torch.cuda.FloatTensor)
    torch.cuda.set_device(0)
    cuda = 1
else:
    print('CUDA device not available')
    cuda = 0

CUDA device not available


## User inputs

In [42]:
# EDIT THIS SECTION FOR USER INPUTS
#
name = 'model_1'
in_file = 'TrainingData/flanger-input.wav'
out_file = 'TrainingData/flanger-target.wav'
epochs = 1

input_size = 1 
batch_size = 16 
test_size = 0.2
learning_rate = 0.0005 

if not os.path.exists('models/'+name):
    os.makedirs('models/'+name)
else:
    print("A model with the same name already exists. Please choose a new name.")
    exit

A model with the same name already exists. Please choose a new name.


## Define some helper functions

In [43]:
def save_wav(name, data):
    sp.io.wavfile.write(name, 44100, data.flatten().astype(np.float32))

def normalize(data):
    data_max = max(data)
    data_min = min(data)
    data_norm = max(data_max,abs(data_min))
    return data / data_norm

## Pre-processing the data

In [44]:
# Load and Preprocess Data ###########################################
in_rate, in_data = wavfile.read(in_file)
out_rate, out_data = wavfile.read(out_file)

X_all = in_data.astype(np.float32).flatten()  
X_all = normalize(X_all)
y_all = out_data.astype(np.float32).flatten() 
y_all = normalize(y_all)

# Get the last 20% of the wav data for testing and the rest for training
X_training, X_testing = np.split(X_all, [int(len(X_all) * (1 - test_size))])
y_training, y_testing = np.split(y_all, [int(len(y_all) * (1 - test_size))])
print(f"X_training shape (pre-processing): {X_training.shape}")
print(f"y_training shape (pre-processing): {y_training.shape}")
print(f"X_testing shape (pre-processing): {X_testing.shape}")
print(f"y_testing shape (pre-processing): {y_testing.shape}")

# Create a new array where each element is an array of input_size samples in time order
indices = np.arange(input_size) + np.arange(len(X_training) - input_size + 1)[:, np.newaxis]
indices = torch.from_numpy(indices).long()  # Convert indices to long
X_training = torch.from_numpy(X_training)
X_ordered_training = torch.zeros((len(indices), input_size), dtype=torch.float32)
for i, index in enumerate(indices):
    X_ordered_training[i] = torch.gather(X_training, 0, index)
X_ordered_training = X_ordered_training.unsqueeze(1)
print(f"X_ordered_training shape: {X_ordered_training.shape}")

indices = np.arange(input_size) + np.arange(len(X_testing) - input_size + 1)[:, np.newaxis]
indices = torch.from_numpy(indices).long()  # Convert indices to long
X_testing = torch.from_numpy(X_testing)
X_ordered_testing = torch.zeros((len(indices), input_size), dtype=torch.float32)
for i, index in enumerate(indices):
    X_ordered_testing[i] = torch.gather(X_testing, 0, index)
X_ordered_testing = X_ordered_testing.unsqueeze(1)
print(f"X_ordered_testing shape: {X_ordered_testing.shape}")

y_ordered_training = y_training[input_size - 1:]
y_ordered_training = torch.from_numpy(y_ordered_training).unsqueeze(1)
print(f"y_ordered_training shape: {y_ordered_training.shape}")

y_ordered_testing = y_testing[input_size - 1:]
y_ordered_testing = torch.from_numpy(y_ordered_testing).unsqueeze(1)
print(f"y_ordered_testing shape: {y_ordered_testing.shape}")

X_training shape (pre-processing): (352800,)
y_training shape (pre-processing): (352800,)
X_testing shape (pre-processing): (88200,)
y_testing shape (pre-processing): (88200,)
X_ordered_training shape: torch.Size([352800, 1, 1])
X_ordered_testing shape: torch.Size([88200, 1, 1])
y_ordered_training shape: torch.Size([352800, 1])
y_ordered_testing shape: torch.Size([88200, 1])


## Create dataloaders

In [45]:
training_dataset = torch.utils.data.TensorDataset(X_ordered_training, y_ordered_training)
training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

for batch, (X, y) in enumerate(training_dataloader):
    print(f"Batch: {batch}")
    print(f"Shape of X: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

testing_dataset = torch.utils.data.TensorDataset(X_ordered_testing, y_ordered_testing)
testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=batch_size, shuffle=True)

for batch, (X, y) in enumerate(testing_dataloader):
    print(f"Batch: {batch}")
    print(f"Shape of X: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Batch: 0
Shape of X: torch.Size([16, 1, 1])
Shape of y: torch.Size([16, 1]) torch.float32
Batch: 0
Shape of X: torch.Size([16, 1, 1])
Shape of y: torch.Size([16, 1]) torch.float32


## Define the model

In [19]:
class StatefulLSTM(nn.Module):
    def __init__(self, input_size=1, output_size=1, hidden_size=32, skip=1, bias_fl=True, num_layers=1):
        super(StatefulLSTM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Create dictionary of possible block types
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.lin = nn.Linear(hidden_size, output_size, bias=bias_fl)
        self.bias_fl = bias_fl
        self.skip = skip
        self.save_state = True
        self.hidden = (torch.zeros(self.input_size, 1, self.hidden_size),
                       torch.zeros(self.input_size, 1, self.hidden_size))

    # Origin forward function 
    def forward(self, x):    
        if self.skip:
            # save the residual for the skip connection
            res = x[:, :, 0:self.skip]
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x) + res
        else:
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x)

    # detach hidden state, this resets gradient tracking on the hidden state
    def detach_hidden(self):
        if self.hidden.__class__ == tuple:
            self.hidden = tuple([h.clone().detach() for h in self.hidden])
        else:
            self.hidden = self.hidden.clone().detach()

    def reset_hidden(self, batch_size):
        self.hidden = (torch.zeros(self.input_size, batch_size, self.hidden_size), 
                       torch.zeros(self.input_size, batch_size, self.hidden_size))

    def save_model(self, file_name, direc=''):
        model_data = {'model_data': {'model': 'SimpleRNN', 'input_size': self.lstm.input_size, 'skip': self.skip,
                                     'output_size': self.lin.out_features, 'unit_type': self.lstm._get_name(),
                                     'num_layers': self.lstm.num_layers, 'hidden_size': self.lstm.hidden_size,
                                     'bias_fl': self.bias_fl}}

        if self.save_state:
            model_state = self.state_dict()
            for each in model_state:
                model_state[each] = model_state[each].tolist()
            model_data['state_dict'] = model_state

        json_save(model_data, file_name, direc)

        # Scripting the model for compatibility with LibTorch
        
        self.reset_hidden(1)
        
        model_scripted = torch.jit.script(self)
        # 
        # # Saving the scripted model
        scripted_model_file = file_name + "_scripted.pt"
        if direc:
             scripted_model_file = os.path.join(direc, scripted_model_file)
        model_scripted.save(scripted_model_file)
        
        # An example input you would normally provide to your model's forward() method.
        example = torch.rand(1, 1, 1).to(torch.device("cpu"))
        onnx_model_file = file_name + ".onnx"
        if direc:
             onnx_model_file = os.path.join(direc, onnx_model_file)
        
        torch.onnx.export(model=self,
                          args=example,
                          f=onnx_model_file,
                          export_params=True,
                          opset_version=13,
                          do_constant_folding=True,
                          input_names = ['input'],
                          output_names = ['output'])

    # train_epoch runs one epoch of training
    def train_epoch(self, dataloader, loss_fcn, optim, bs, init_len=200, up_fr=1000):

        # shuffle the segments at the start of the epoch
        #shuffle = torch.randperm(input_data.shape[1])
        
        #print(shuffle)

        self.reset_hidden(bs)
        
        # Iterate over the batches
        ep_loss = 0
        for batch_i, (X, y) in enumerate(dataloader):
            # Load batch of shuffled segments
            input_batch = X.to(torch.device("cpu"))
            target_batch = y.to(torch.device("cpu"))


            # Initialise network hidden state by processing some samples then zero the gradient buffers
            # For training processing eine Anfangssequenz, damit ein brauchbarer hidden state vorliegt
            # Training startet erst nach! einem eingelaufen hidden state
            self(input_batch[0:init_len, :, :])
            self.zero_grad()

            # Choose the starting index for processing the rest of the batch sequence, in chunks of args.up_fr
            start_i = init_len
            batch_loss = 0
            # Iterate over the remaining samples in the mini batch
            for k in range(math.ceil((input_batch.shape[0] - init_len) / up_fr)):
                # Process input batch with neural network
                
                print(input_batch[start_i:start_i + up_fr, :, :].shape)
                
                output = self(input_batch[start_i:start_i + up_fr, :, :])

                # Calculate loss and update network parameters
                loss = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :])
                loss.backward()
                optim.step()

                # Set the network hidden state, to detach it from the computation graph
                # Die Gradienteninformation die einhergeht mit dem hidden state ist connected zu dem
                # Computiaonal Graf von den vorherigen outputs. Da wir immer nur den letzten hidden state
                # zur berechnung brauchen und den state auch nicht updaten wollen - wird er hier detached
                # detached = wir löchen die Gradienteninformation 
                # https://discuss.pytorch.org/t/stupid-question-why-do-you-have-to-detach-the-hidden-state-of-lstms-but-not-the-hidden-state-of-a-linear-network/95089/3

                
                self.detach_hidden()
                self.zero_grad()

                # Update the start index for the next iteration and add the loss to the batch_loss total
                start_i += up_fr
                batch_loss += loss

            # Add the average batch loss to the epoch loss and reset the hidden states to zeros
            ep_loss += batch_loss / (k + 1)
        return ep_loss / (batch_i + 1)

    # only proc processes a the input data and calculates the loss, optionally grad can be tracked or not
    def process_data(self, dataloader, loss_fcn, chunk, grad=False, validate=False):
        with (torch.no_grad() if not grad else nullcontext()):
            self.reset_hidden(input_data.shape[1])
            output = torch.empty_like(target_data)
            for l in range(int(output.size()[0] / chunk)):
                if validate:
                    output[l * chunk:(l + 1) * chunk] = input_data[l * chunk:(l + 1) * chunk]
                else:
                    output[l * chunk:(l + 1) * chunk] = self(input_data[l * chunk:(l + 1) * chunk])
                    self.detach_hidden()
            # If the data set doesn't divide evenly into the chunk length, process the remainder
            if not (output.size()[0] / chunk).is_integer():
                if validate:
                    output[(l + 1) * chunk:-1] = input_data[(l + 1) * chunk:-1]                    
                else:
                    output[(l + 1) * chunk:-1] = self(input_data[(l + 1) * chunk:-1])
            loss = loss_fcn(output, target_data)
        return output, loss

In [20]:
class ESRLoss(nn.Module):
    def __init__(self):
        super(ESRLoss, self).__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        loss = torch.add(target, -output)
        loss = torch.pow(loss, 2)
        loss = torch.mean(loss)
        energy = torch.mean(torch.pow(target, 2)) + self.epsilon
        loss = torch.div(loss, energy)
        return loss
class DCLoss(nn.Module):
    def __init__(self):
        super(DCLoss, self).__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        loss = torch.pow(torch.add(torch.mean(target, 0), -torch.mean(output, 0)), 2)
        loss = torch.mean(loss)
        energy = torch.mean(torch.pow(target, 2)) + self.epsilon
        loss = torch.div(loss, energy)
        return loss      
class LossWrapper(nn.Module):
    def __init__(self, losses):
        super(LossWrapper, self).__init__()
        loss_dict = {'ESR': ESRLoss(), 'DC': DCLoss()}

        loss_functions = [[loss_dict[key], value] for key, value in losses.items()]

        self.loss_functions = tuple([items[0] for items in loss_functions])
        try:
            self.loss_factors = tuple(torch.Tensor([items[1] for items in loss_functions]))
        except IndexError:
            self.loss_factors = torch.ones(len(self.loss_functions))

    def forward(self, output, target):
        loss = 0
        for i, losses in enumerate(self.loss_functions):
            loss += torch.mul(losses(output, target), self.loss_factors[i])
        return loss

## Train the model

In [39]:
network = StatefulLSTM(input_size=1, 
                       output_size=1,
                       hidden_size=32,
                       skip=0)

optimiser = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5, verbose=True)
loss_functions = LossWrapper({"ESR": 0.75, "DC": 0.25})

network.save_state = True

summary = torchinfo.summary(network, (1, 1, 1), device=torch.device("cpu"))
print(summary)

Layer (type:depth-idx)                   Output Shape              Param #
StatefulLSTM                             [1, 1, 1]                 --
├─LSTM: 1-1                              [1, 1, 32]                4,480
├─Linear: 1-2                            [1, 1, 1]                 33
Total params: 4,513
Trainable params: 4,513
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.02
Estimated Total Size (MB): 0.02


In [40]:
# Define training procedure ############################################
def train(dataloader, model, loss_functions, optimiser, batch_size, init_lenght, up_fr):
    model.train()
    epoch_loss = 0
    epoch_loss = network.train_epoch(dataloader,
                                     loss_functions,
                                     optimiser,
                                     batch_size,
                                     init_lenght,
                                     up_fr)
    print("Epoch loss:", epoch_loss)

In [38]:
best_val_los = 10000

for t in range(1):
    print(f"Epoch {t+1}\n-------------------------------")
    train(training_dataloader, network, loss_functions, optimiser, batch_size, init_lenght=200, up_fr=1000)
    
    # Validation
    val_output, val_loss = network.process_data(training_dataloader, loss_functions, chunk=1000)
    scheduler.step(val_loss)
    print("Val loss:", val_loss)
    
    if val_loss < best_val_los:
        network.reset_hidden(1)
        network.save_model('model_best', save_path)
    
    #test(testing_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


AttributeError: 'DataLoader' object has no attribute 'shape'

## Run predictions
### 0. Load the model

In [8]:
model.load_state_dict(torch.load("models/"+name+"/"+name+".pth", map_location=torch.device('cpu')))

<All keys matched successfully>

### 1. On the test audio data

In [29]:
# Set the model to evaluate mode #################################
model.eval()
# Run prediction ##################################################
prediction = torch.zeros(0).to(device)

print("Running prediction..")
with torch.no_grad():
    for X, _ in testing_dataloader:
        X = X.to(device)
        predicted_batch = model(X)
        prediction = torch.cat((prediction, predicted_batch.flatten()), 0)

save_wav('models/'+name+'/y_pred.wav', prediction.cpu().numpy())
save_wav('models/'+name+'/x_test.wav', X_testing.numpy())
save_wav('models/'+name+'/y_test.wav', y_testing)

print("X_testing shape: ", X_testing.shape)
print("X_ordered_testing shape: ", X_ordered_testing.shape)
print("y_testing shape: ", y_testing.shape)
print("prediction shape: ", prediction.shape)

print("Note that the prediction shape is smaller than the y_testing shape. This is because the first predicted sample needs input_size samples for prediction.\n")


Running prediction..
X_testing shape:  torch.Size([1646977])
X_ordered_testing shape:  torch.Size([1646828, 1, 150])
y_testing shape:  (1646977,)
prediction shape:  torch.Size([1646828])
Note that the prediction shape is smaller than the y_testing shape. This is because the first predicted sample needs input_size samples for prediction.


### 2. On a number sequence (to control inference)

In [9]:
batch_size_test2 = 2

# Test the model simple number sequence to compare with inference #
X_testing_2 = np.array([], dtype=np.float64)

for i in range(0, batch_size_test2 * input_size):
    X_testing_2 = np.append(X_testing_2, i*0.001)

X_testing_2 = np.expand_dims(X_testing_2, axis=0)
X_testing_2 = np.expand_dims(X_testing_2, axis=0)
X_testing_2 = np.reshape(X_testing_2, (batch_size_test2, 1, input_size))

X_testing_2 = torch.from_numpy(X_testing_2).double()

print(f"X_testing_2 shape: {X_testing_2.shape}")

print("Running prediction..")
model = model.float()

prediction_2 = model(X_testing_2.to(device).float())

print(f"prediction {prediction_2}")

print("X_testing_2 shape: ", X_testing_2.shape)
print("prediction_2 shape: ", prediction_2.shape)

X_testing_2 shape: torch.Size([2, 1, 150])
Running prediction..
prediction tensor([[-0.0911],
        [-0.2398]], grad_fn=<AddmmBackward0>)
X_testing_2 shape:  torch.Size([2, 1, 150])
prediction_2 shape:  torch.Size([2, 1])


## Export as pt model
### 1. for minimal examples (with batch size = 2)

In [12]:
batch_size_minimal = 2

In [13]:
# An example input you would normally provide to your model's forward() method.
example = torch.rand(batch_size_minimal, 1, input_size).to(device)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("models/"+name+"/"+name+"-minimal.pt")

In [14]:
# An example input you would normally provide to your model's forward() method.
example = torch.rand(batch_size_minimal, 1, input_size).to(device)
filepath = "models/"+name+"/"+name+"-libtorch"+"-minimal.onnx"

# Export the model
torch.onnx.export(model=model,
                  args=example,
                  f=filepath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names = ['input'],
                  output_names = ['output'])

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


### 2. for real-time streaming (with batch size = 64)

In [10]:
batch_size_streaming = 128

In [11]:
# An example input you would normally provide to your model's forward() method.
example = torch.rand(batch_size_streaming, 1, input_size).to(device)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("models/"+name+"/"+name+"-streaming.pt")

In [12]:
# An example input you would normally provide to your model's forward() method.
example = torch.rand(batch_size_streaming, 1, input_size).to(device)
filepath = "models/"+name+"/"+name+"-libtorch"+"-streaming.onnx"

# Export the model
torch.onnx.export(model=model,
                  args=example,
                  f=filepath,
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names = ['input'],
                  output_names = ['output'])

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
