In [2]:
pip install numpy scipy torch torchvision torchaudio

Collecting numpy
  Obtaining dependency information for numpy from https://files.pythonhosted.org/packages/69/65/0d47953afa0ad569d12de5f65d964321c208492064c38fe3b0b9744f8d44/numpy-1.24.4-cp38-cp38-win_amd64.whl.metadata
  Using cached numpy-1.24.4-cp38-cp38-win_amd64.whl.metadata (5.6 kB)
Collecting scipy
  Using cached scipy-1.10.1-cp38-cp38-win_amd64.whl (42.2 MB)
Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/2a/27/cd2e60d4accf81aa6279be8c5e9bca99a16bd21b3c0428dc515569e36561/torch-2.1.2-cp38-cp38-win_amd64.whl.metadata
  Using cached torch-2.1.2-cp38-cp38-win_amd64.whl.metadata (26 kB)
Collecting torchvision
  Obtaining dependency information for torchvision from https://files.pythonhosted.org/packages/54/4b/b0861005f5d4370b3529f31d5e6461c4faf9bbbcbe916480cceebc885aa8/torchvision-0.16.2-cp38-cp38-win_amd64.whl.metadata
  Using cached torchvision-0.16.2-cp38-cp38-win_amd64.whl.metadata (6.6 kB)
Collecting torchaudio
  Obtain


[notice] A new release of pip is available: 23.2.1 -> 23.3.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import sys
import os
import numpy as np
from scipy.io import wavfile
from scipy.io.wavfile import write
import torch
import torch.nn as nn
import torch.optim as optim
import time
import math
import json

# Preprocessing

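This code loads a mono input/target WAV pair, converts 16-bit PCM audio to 32-bit float, and splits it into a training set and a held-out set (the audio is cut into 100 pieces and every fifth piece is held out). The held-out audio is written to both the `test` and `val` folders.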

### Code

Define all necessary functions

In [2]:
def prepareTrainingData(config):
    """Split an input/target WAV pair into train, test and val files.

    Reads the two files named by ``config["input_audio_path"]`` and
    ``config["target_audio_path"]``, checks that they are mono, equally long
    and share a sample rate, converts int16 audio to float32, splits the audio
    with ``sliceOnMod`` and writes the pieces with ``save_wav`` to
    ``<output_path>/{train,test,val}/<name>-{input,target}.wav``.
    The held-out audio is written to both the ``test`` and ``val`` folders.

    Args:
        config: dict with the keys ``input_audio_path``, ``target_audio_path``,
            ``output_path`` and ``name``.

    Raises:
        SystemExit: after printing a message, when the files differ in length,
            are not mono, or have different sample rates.
    """
    in_rate, in_data = wavfile.read(config["input_audio_path"])
    out_rate, out_data = wavfile.read(config["target_audio_path"])

    if len(in_data) != len(out_data):
        print("input and target files have different lengths")
        sys.exit()

    if len(in_data.shape) > 1 or len(out_data.shape) > 1:
        print("expected mono files")
        sys.exit()

    # Equal sample counts do not mean equal durations if the rates differ.
    if in_rate != out_rate:
        print("input and target files have different sample rates")
        sys.exit()

    # save_wav is called without a sample rate below, so the output header is 44100 Hz.
    if in_rate != 44100:
        print("warning: source files are %d Hz but the split files are written as 44100 Hz" % in_rate)

    # Convert PCM16 to FP32
    if in_data.dtype == "int16":
        in_data = in_data / 32767
        print("In data converted from PCM16 to FP32")
    if out_data.dtype == "int16":
        out_data = out_data / 32767
        print("Out data converted from PCM16 to FP32")

    clean_data = in_data.astype(np.float32).flatten()
    target_data = out_data.astype(np.float32).flatten()

    # Hold out every fifth piece of the audio (sliceOnMod's default modulus)
    in_train, out_train, in_val, out_val = sliceOnMod(clean_data, target_data)

    # 'test' and 'val' deliberately receive the same held-out audio.
    subsets = (("train", in_train, out_train),
               ("test", in_val, out_val),
               ("val", in_val, out_val))
    for subset_name, subset_in, subset_out in subsets:
        prefix = config["output_path"] + "/" + subset_name + "/" + config["name"]
        save_wav(prefix + "-input.wav", subset_in)
        save_wav(prefix + "-target.wav", subset_out)


In [3]:
def sliceOnMod(input_data, target_data, mod=5):
    """Split paired signals into training and validation parts on a modulus.

    Both signals are cut into 100 pieces. Every piece whose index is a
    multiple of ``mod`` goes to validation; the rest stay in training.
    Validation pieces are concatenated from the highest index down to index 0;
    training pieces keep their original order.

    Returns:
        (training_input, training_target, val_input, val_target) as 1-D arrays.
    """
    mod = int(mod)

    input_chunks = np.array_split(input_data, 100)
    target_chunks = np.array_split(target_data, 100)
    chunk_count = len(input_chunks)

    # Validation indices are visited last-to-first, which fixes the output order.
    held_out = [idx for idx in range(chunk_count - 1, -1, -1) if idx % mod == 0]
    kept = [idx for idx in range(chunk_count) if idx % mod != 0]

    val_input_data = np.concatenate([input_chunks[idx] for idx in held_out])
    val_target_data = np.concatenate([target_chunks[idx] for idx in held_out])

    training_input_data = np.concatenate([input_chunks[idx] for idx in kept])
    training_target_data = np.concatenate([target_chunks[idx] for idx in kept])
    return (training_input_data, training_target_data, val_input_data, val_target_data)

In [4]:
def save_wav(name, data, sample_rate=44100):
    """Write ``data`` to ``name`` as a 32-bit float WAV file.

    Args:
        name: destination path; missing parent directories are created.
        data: array of samples; it is flattened and cast to float32.
        sample_rate: rate written to the WAV header (defaults to 44100).
    """
    directory = os.path.dirname(name)

    # A bare file name has an empty dirname, and os.makedirs('') would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)

    wavfile.write(name, sample_rate, data.flatten().astype(np.float32))

In [5]:
# Settings for the preprocessing step; every key below is read by prepareTrainingData.
importConfig = {
    "input_audio_path": "TrainingData/flanger-input.wav",  # unprocessed recording
    "target_audio_path": "TrainingData/flanger-target.wav",  # processed recording
    "output_path": "Data",  # root folder that receives the train/, test/ and val/ subfolders
    "name": "flanger"  # file-name prefix of the written "-input.wav" / "-target.wav" files
}

# Writes the split WAV files to disk and prints a line for each int16 -> float conversion.
prepareTrainingData(importConfig)

In data converted from PCM16 to FP32
Out data converted from PCM16 to FP32


# Training

In [46]:
class SimpleLSTM(nn.Module):
    """LSTM followed by a linear layer, with an optional input skip connection.

    The recurrent state is kept on the module (``self.hidden``) between calls
    to ``forward``, so callers must reset it when the batch size changes and
    detach it when truncating back-propagation.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        hidden_size: LSTM hidden size.
        skip: number of leading input channels added to the output
            (0 disables the skip connection).
        bias_fl: whether the linear output layer has a bias.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size=1, output_size=1, hidden_size=32, skip=1, bias_fl=True, num_layers=1):
        super(SimpleLSTM, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.lin = nn.Linear(hidden_size, output_size, bias=bias_fl)
        self.bias_fl = bias_fl
        self.skip = skip
        self.save_state = True
        # Size the state from the constructor arguments so that hidden_size / num_layers
        # other than 32 / 1 work; the batch dimension starts at 1.
        self.hidden = (torch.zeros(num_layers, 1, hidden_size),
                       torch.zeros(num_layers, 1, hidden_size))

    # Forward pass; input is indexed as (sequence, batch, channel)
    def forward(self, x):
        if self.skip:
            # save the residual for the skip connection
            res = x[:, :, 0:self.skip]
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x) + res
        else:
            x, self.hidden = self.lstm(x, self.hidden)
            return self.lin(x)

    # detach hidden state, this resets gradient tracking on the hidden state
    def detach_hidden(self):
        if isinstance(self.hidden, tuple):
            self.hidden = tuple(h.clone().detach() for h in self.hidden)
        else:
            self.hidden = self.hidden.clone().detach()

    # replaces the hidden state with all-zero tensors sized for 'batch_size' sequences
    def reset_hidden(self, batch_size):
        # Match the layer count, hidden size, device and dtype of the actual LSTM.
        reference = self.lin.weight
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        self.hidden = (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                       torch.zeros(shape, device=reference.device, dtype=reference.dtype))

    # This function saves the model and all its parameters to a json file, so it can be loaded by a JUCE plugin,
    # and also saves a TorchScript version as '<file_name>_scripted.pt'
    def save_model(self, file_name, direc=''):
        model_data = {'model_data': {'model': 'SimpleRNN', 'input_size': self.lstm.input_size, 'skip': self.skip,
                                     'output_size': self.lin.out_features, 'unit_type': self.lstm._get_name(),
                                     'num_layers': self.lstm.num_layers, 'hidden_size': self.lstm.hidden_size,
                                     'bias_fl': self.bias_fl}}

        if self.save_state:
            model_state = self.state_dict()
            for each in model_state:
                model_state[each] = model_state[each].tolist()
            model_data['state_dict'] = model_state

        json_save(model_data, file_name, direc)

        # Scripting the model for compatibility with LibTorch.
        # The hidden state is stored with the scripted module, so give it a batch size of 1 first.
        self.reset_hidden(1)

        model_scripted = torch.jit.script(self)

        # Saving the scripted model
        scripted_model_file = file_name + "_scripted.pt"
        if direc:
            scripted_model_file = os.path.join(direc, scripted_model_file)
        model_scripted.save(scripted_model_file)

    # train_epoch runs one epoch of training and returns the mean loss over all batches
    def train_epoch(self, input_data, target_data, loss_fcn, optim, bs, init_len=200, up_fr=1000):
        # shuffle the segments at the start of the epoch
        shuffle = torch.randperm(input_data.shape[1])
        n_batches = math.ceil(shuffle.shape[0] / bs)

        # Iterate over the batches
        ep_loss = 0
        for batch_i in range(n_batches):
            # Load batch of shuffled segments
            segment_ids = shuffle[batch_i * bs:(batch_i + 1) * bs]
            input_batch = input_data[:, segment_ids, :]
            target_batch = target_data[:, segment_ids, :]

            # Start every batch from a zero state sized for this batch. The last batch can hold
            # fewer than 'bs' segments, and state from unrelated segments must not carry over.
            self.reset_hidden(input_batch.shape[1])

            # Initialise network hidden state by processing some samples then zero the gradient buffers
            self(input_batch[0:init_len, :, :])
            self.zero_grad()

            # Process the rest of the batch sequence in chunks of up_fr samples
            start_i = init_len
            batch_loss = 0
            n_chunks = math.ceil((input_batch.shape[0] - init_len) / up_fr)
            for _ in range(n_chunks):
                # Process input batch with neural network
                output = self(input_batch[start_i:start_i + up_fr, :, :])

                # Calculate loss and update network parameters
                loss = loss_fcn(output, target_batch[start_i:start_i + up_fr, :, :])
                loss.backward()
                optim.step()

                # Detach the hidden state from the computation graph (truncated back-propagation)
                self.detach_hidden()
                self.zero_grad()

                # Update the start index for the next iteration and add the loss to the batch_loss total
                start_i += up_fr
                # Only the value is needed for reporting, so drop the graph reference
                batch_loss += loss.detach()

            # Add the average batch loss to the epoch loss
            ep_loss += batch_loss / max(n_chunks, 1)
        return ep_loss / max(n_batches, 1)

    # process_data runs the input through the network in chunks and calculates the loss over the whole output,
    # gradient tracking is enabled only when grad is True
    def process_data(self, input_data, target_data, loss_fcn, chunk, grad=False):
        with torch.set_grad_enabled(grad):
            self.reset_hidden(input_data.shape[1])
            output = torch.empty_like(target_data)
            total_len = output.size()[0]
            full_chunks = total_len // chunk
            for l in range(full_chunks):
                output[l * chunk:(l + 1) * chunk] = self(input_data[l * chunk:(l + 1) * chunk])
                self.detach_hidden()
            # If the data set doesn't divide evenly into the chunk length, process the remainder,
            # including the final sample, so no part of 'output' is left uninitialised
            tail_start = full_chunks * chunk
            if tail_start < total_len:
                output[tail_start:] = self(input_data[tail_start:total_len])
            loss = loss_fcn(output, target_data)
        return output, loss

In [7]:
class ESRLoss(nn.Module):
    """Error-to-signal ratio: mean squared error divided by the mean squared target.

    A small epsilon is added to the denominator so an all-zero target does not
    divide by zero.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        error_power = torch.mean((target - output) ** 2)
        signal_power = torch.mean(target ** 2) + self.epsilon
        return error_power / signal_power
class DCLoss(nn.Module):
    """DC-offset loss.

    Squares the difference between the means of target and output taken along
    dimension 0, averages the result, and divides by the mean squared target
    plus a small epsilon.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        offset = torch.mean(target, 0) - torch.mean(output, 0)
        offset_power = torch.mean(offset ** 2)
        signal_power = torch.mean(target ** 2) + self.epsilon
        return offset_power / signal_power
class LossWrapper(nn.Module):
    """Weighted sum of named loss functions.

    Args:
        losses: dict mapping a loss name ('ESR' or 'DC') to its weighting factor.
    """

    def __init__(self, losses):
        super().__init__()
        available = {'ESR': ESRLoss(), 'DC': DCLoss()}

        # Pair every requested loss with its weight, in the order given by 'losses'
        selected = [[available[name], weight] for name, weight in losses.items()]

        self.loss_functions = tuple(pair[0] for pair in selected)
        try:
            self.loss_factors = tuple(torch.Tensor([pair[1] for pair in selected]))
        except IndexError:
            self.loss_factors = torch.ones(len(self.loss_functions))

    def forward(self, output, target):
        total = 0
        for loss_fn, factor in zip(self.loss_functions, self.loss_factors):
            total = total + torch.mul(loss_fn(output, target), factor)
        return total

In [8]:
# Splits audio, each split marker determines the fraction of the total audio in that split, i.e [0.75, 0.25] will put
# 75% in the first split and 25% in the second
def audio_splitter(audio, split_markers):
    """Cut ``audio`` into consecutive slices sized by fractional markers.

    Each marker is the fraction of the total length assigned to that slice,
    e.g. ``[0.75, 0.25]`` puts the first 75% in slice one and the next 25% in
    slice two. Slice lengths are truncated to whole samples.

    Returns:
        list with one slice of ``audio`` per marker.
    """
    marker_total = sum(split_markers)
    assert marker_total <= 1.0
    if marker_total < 0.999:
        print("sum of split markers is less than 1, so not all audio will be included in dataset")

    total_samples = audio.shape[0]
    slices = []
    cursor = 0
    for fraction in split_markers:
        # convert the split marker to a sample count and advance the cursor
        next_cursor = cursor + int(fraction * total_samples)
        slices.append(audio[cursor:next_cursor])
        cursor = next_cursor
    return slices

# converts numpy audio into frames, and creates a torch tensor from them, frame_len = 0 just converts to a torch tensor
def framify(audio, frame_len):
    """Cut numpy audio into equal frames and return them as a torch tensor.

    Args:
        audio: 1-D (mono) or 2-D (samples, channels) numpy array.
        frame_len: frame length in samples; 0 keeps the audio as one frame.

    Returns:
        Tensor indexed as (frame_len, number_of_frames, channels). Samples
        left over after the last whole frame are dropped.
    """
    # Mono audio gets a dummy channel axis so both layouts share the same code path
    if len(audio.shape) == 1:
        audio = np.expand_dims(audio, 1)

    total_samples = audio.shape[0]
    channels = audio.shape[1]
    if frame_len:
        seg_num = math.floor(total_samples / frame_len)
    else:
        # No frame length given: a single frame spanning the whole signal
        seg_num = 1
        frame_len = total_samples

    dataset = torch.empty((frame_len, seg_num, channels))
    for seg in range(seg_num):
        first = seg * frame_len
        dataset[:, seg, :] = torch.from_numpy(audio[first:first + frame_len, :])
    return dataset

In [9]:
class DataSet:
    """Collection of named subsets of audio loaded from WAV files.

    Args:
        data_dir: directory that file names passed to ``load_file`` are relative to.
        extensions: file-name suffixes that make up one example (by default an
            'input' and a 'target' file); a falsy value means a single file
            with no suffix.
    """

    def __init__(self, data_dir='../Dataset/', extensions=('input', 'target')):
        self.extensions = extensions if extensions else ['']
        self.subsets = {}
        assert type(data_dir) == str, "data_dir should be string,not %r" % {type(data_dir)}
        self.data_dir = data_dir

    # add a subset called 'name'; 'frame_len' is forwarded unchanged to SubSet (0 means one long frame)
    def create_subset(self, name, frame_len=0):
        assert type(name) == str, "data subset name must be a string, not %r" %{type(name)}
        assert not (name in self.subsets), "subset %r already exists" %name
        self.subsets[name] = SubSet(frame_len)

    # read the wav file for one extension, trying '<filename>-<ext>.wav' first and '<filename><ext>.wav' second
    def _read_wav(self, filename, ext):
        candidates = []
        for stem in (filename + '-' + ext, filename + ext):
            file_loc = os.path.join(self.data_dir, stem)
            if not file_loc.endswith('.wav'):
                file_loc = file_loc + '.wav'
            candidates.append(file_loc)

        for file_loc in candidates:
            try:
                return wavfile.read(file_loc)
            except FileNotFoundError:
                continue
        # Neither naming scheme exists: fail loudly and list every path that was tried
        raise FileNotFoundError("File Not Found At: " + " or ".join(candidates))

    # load a file of 'filename' into existing subset/s 'set_names', split fractionally as specified by 'splits',
    # if 'cond_val' is provided the conditioning value will be saved along with the frames of the loaded data
    def load_file(self, filename, set_names='train', splits=None, cond_val=None):
        """Load one example (one wav file per extension) into one or more subsets.

        Raises:
            AssertionError: if a named subset does not exist, or the number of
                subset names does not match the number of split markers.
            FileNotFoundError: if no wav file is found for an extension.
        """
        # Assertions and checks
        if type(set_names) == str:
            set_names = [set_names]
        assert len(set_names) == 1 or (splits is not None and len(set_names) == len(splits)), \
            "number of subset names must equal number of split markers"
        # Check each name individually; a non-empty list of lookups would always be truthy
        assert all(each in self.subsets for each in set_names), "set_names contains subsets that don't exist yet"

        # Load each of the 'extensions'
        for ext in self.extensions:
            np_data = self._read_wav(filename, ext)
            raw_audio = np_data[1].astype(np.float32)
            # Split the audio if more than one subset name was provided
            if len(set_names) > 1:
                raw_audio = audio_splitter(raw_audio, splits)
                for n, set_name in enumerate(set_names):
                    self.subsets[set_name].add_data(np_data[0], raw_audio[n], ext, cond_val)
            elif len(set_names) == 1:
                self.subsets[set_names[0]].add_data(np_data[0], raw_audio, ext, cond_val)

In [10]:
# The SubSet class holds a subset of data,
# frame_len sets the length of audio per frame (in s), if set to 0 a single frame is used instead
class SubSet:
    """Holds the framed audio of one data subset.

    ``frame_len`` is handed to ``framify`` unchanged; 0 stores each file as a
    single frame.
    """

    def __init__(self, frame_len):
        self.data = {}
        self.cond_data = {}
        self.frame_len = frame_len
        self.conditioning = None
        self.fs = None

    # Frame 'audio' and store it under the key 'ext'; if cond_val is a number, store one copy of it per frame
    def add_data(self, fs, audio, ext, cond_val):
        # The first file fixes the subset's sample rate; later files must match it
        if not self.fs:
            self.fs = fs
        assert self.fs == fs, "data with different sample rate provided to subset"

        # With no 'ext', everything is stored under the 'data' key
        key = ext if ext else 'data'
        framed_data = framify(audio, self.frame_len)
        has_cond = isinstance(cond_val, (float, int))
        cond_data = cond_val * torch.ones(framed_data.shape[1]) if has_cond else None

        # First data for this key: create the one-element data and cond_data tuples
        if key not in self.data:
            self.data[key] = (framed_data,)
            self.cond_data[key] = (cond_data,)
            return

        # Otherwise append the new frames along the frame dimension
        existing_frames = self.data[key][0]
        self.data[key] = (torch.cat((existing_frames, framed_data), 1),)
        # Either all frames or no frames may carry a conditioning value
        if has_cond:
            assert torch.is_tensor(self.cond_data[key][0]), 'cond val provided, but previous data has no cond val'
            existing_cond = self.cond_data[key][0]
            self.cond_data[key] = (torch.cat((existing_cond, cond_data), 0),)

In [11]:
class TrainTrack(dict):
    """Dictionary that records losses and timing while a model trains."""

    def __init__(self):
        self.update(current_epoch=0, training_losses=[], validation_losses=[], train_av_time=0.0,
                    val_av_time=0.0, total_time=0.0, best_val_loss=1e12, test_loss=0)

    def restore_data(self, training_info):
        """Overwrite the tracked values with those in ``training_info``."""
        self.update(training_info)

    def train_epoch_update(self, loss, ep_st_time, ep_end_time, init_time, current_ep):
        """Record one training epoch: its loss, its duration and the epoch number."""
        previous = self['train_av_time']
        if not previous:
            self['train_av_time'] = ep_end_time - ep_st_time
        else:
            # Halfway between the previous value and this epoch's duration
            self['train_av_time'] = (previous + ep_end_time - ep_st_time) / 2
        self['training_losses'].append(loss)
        self['current_epoch'] = current_ep
        # Times arrive in seconds; total_time is kept in hours
        self['total_time'] += ((init_time + ep_end_time - ep_st_time) / 3600)

    def val_epoch_update(self, loss, ep_st_time, ep_end_time):
        """Record one validation pass and update the best validation loss."""
        previous = self['val_av_time']
        if not previous:
            self['val_av_time'] = ep_end_time - ep_st_time
        else:
            self['val_av_time'] = (previous + ep_end_time - ep_st_time) / 2
        self['validation_losses'].append(loss)
        if loss < self['best_val_loss']:
            self['best_val_loss'] = loss

In [12]:
# Function that saves 'data' to a json file. Constructs a file path is dir_name is provided.
def json_save(data, file_name, dir_name=''):
    """Write ``data`` to a json file.

    Args:
        data: json-serialisable object.
        file_name: target file name; '.json' is appended when missing.
        dir_name: optional directory, given as a string or a list of path parts.
    """
    # Wrap a single non-empty directory so it can be unpacked into os.path.join
    if dir_name and type(dir_name) != list:
        dir_name = [dir_name]
    assert type(file_name) == str
    if not file_name.endswith('.json'):
        file_name = file_name + '.json'
    with open(os.path.join(*dir_name, file_name), 'w') as handle:
        json.dump(data, handle)

In [47]:
def train(config):
    """Train a SimpleLSTM on the prepared data and save the results.

    Loads ``Data/train/<hardware_device>-{input,target}.wav`` and the matching
    files under ``Data/val``, trains for ``config["epochs"]`` epochs, and
    writes the models, the best validation output and the training statistics
    to ``<save_location>/<hardware_device>``.

    Args:
        config: dict with the keys ``save_location``, ``hardware_device``,
            ``input_size``, ``output_size``, ``hidden_size``, ``skip_con``,
            ``learning_rate``, ``loss_fcns``, ``epochs``, ``batch_size``,
            ``init_length``, ``up_fr``, ``validation_f`` and ``val_chunk``.
    """
    start_time = time.time()

    # Results go to <cwd>/<save_location>/<hardware_device>; create it if needed
    save_path = os.path.join(config["save_location"], config["hardware_device"])
    os.makedirs(os.path.join(os.getcwd(), save_path), exist_ok=True)

    # Check if a cuda device is available
    if torch.cuda.is_available():
        print('CUDA device available')
        torch.cuda.set_device(0)
        # Create new tensors and modules on the GPU by default.
        # (torch.set_default_dtype only accepts a dtype, not a CUDA tensor type.)
        torch.set_default_device('cuda')
    else:
        print('CUDA device not available')

    # Create the LSTM network with the specified configuration
    network = SimpleLSTM(input_size=config["input_size"],
                         output_size=config["output_size"],
                         hidden_size=config["hidden_size"],
                         skip=config["skip_con"])

    # Set up training optimiser + scheduler + loss fcns and training info tracker
    optimiser = torch.optim.Adam(network.parameters(), lr=config["learning_rate"], weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5, verbose=True)
    loss_functions = LossWrapper(config["loss_fcns"])
    train_track = TrainTrack()

    dataset = DataSet(data_dir='Data')

    # Training audio is cut into frames of 22050 samples
    dataset.create_subset('train', frame_len=22050)
    dataset.load_file(os.path.join('train', config["hardware_device"]), 'train')

    # Validation audio is kept as one long frame
    dataset.create_subset('val')
    dataset.load_file(os.path.join('val', config["hardware_device"]), 'val')

    # Set-up time so far, plus any time already stored in the tracker
    init_time = time.time() - start_time + train_track['total_time']*3600
    # Set network save_state flag to true, so when the save_model method is called the network weights are saved
    network.save_state = True

    for epoch in range(train_track['current_epoch'] + 1, config["epochs"] + 1):
        print("Epoch: ", epoch)
        ep_st_time = time.time()

        # Run 1 epoch of training,
        epoch_loss = network.train_epoch(dataset.subsets['train'].data['input'][0],
                                         dataset.subsets['train'].data['target'][0],
                                         loss_functions, optimiser, config['batch_size'], config['init_length'], config['up_fr'])

        if epoch % config['validation_f'] == 0:
            val_ep_st_time = time.time()
            val_output, val_loss = network.process_data(dataset.subsets['val'].data['input'][0],
                                                        dataset.subsets['val'].data['target'][0],
                                                        loss_functions, config["val_chunk"])
            scheduler.step(val_loss)
            print("Val loss:", val_loss)
            # Compare against the previous best before the tracker is updated below
            if val_loss < train_track['best_val_loss']:
                network.save_model('model_best', save_path)
                write(os.path.join(save_path, "best_val_out.wav"),
                      dataset.subsets['val'].fs, val_output.cpu().numpy()[:, 0, 0])

            train_track.val_epoch_update(val_loss.item(), val_ep_st_time, time.time())

        print('current learning rate: ' + str(optimiser.param_groups[0]['lr']))
        train_track.train_epoch_update(epoch_loss.item(), ep_st_time, time.time(), init_time, epoch)
        # The set-up time must be counted once only, not once per epoch
        init_time = 0
        network.save_model('model', save_path)
        json_save(train_track, 'training_stats', save_path)


In [None]:
# Settings for train(). NOTE(review): "export_json" and "export_torchscript" are not read by
# train() in this notebook; both exports always happen inside SimpleLSTM.save_model.
trainConfig = {
    "input_size": 1, # Number of channels
    "output_size": 1, # Number of channels
    "skip_con": 0, # Number of input channels added to the output as a skip connection (0 = none)
    "epochs": 500,
    "batch_size": 8,
    "init_length": 200, # Number of sequence samples to process before starting weight updates
    "up_fr": 1000, # For recurrent models, number of samples to run in between updating network weights
    "validation_f": 1, # Validation Frequency (in epochs)
    "val_chunk": 1000, # Number of sequence samples to process in each chunk of validation
    "learning_rate": 0.005,
    "hidden_size": 32,
    "loss_fcns": {"ESR": 0.75, "DC": 0.25}, # Loss name -> weighting factor
    "hardware_device": "ht1", # File-name stem of the data files and name of the results subfolder
    "save_location": "Results",
    "export_json": 1,
    "export_torchscript": 1
}

train(trainConfig)


CUDA device not available
Epoch:  1


In [82]:
# Smoke test: load the TorchScript export of the best model and run a short all-zero input through it.
loaded_model = torch.jit.load("Results/ht1/model_best_scripted.pt")

# Example input, indexed as (sequence, batch, channel)
sequence_length = 4    # number of samples to process
batch_size = 1         # must be 1: save_model resets the hidden state to batch size 1 before scripting
input_size = loaded_model.input_size  # should match the model's expected input size

test_input = torch.zeros(sequence_length, batch_size, input_size)

print(test_input)
# Inference only, so skip gradient tracking.
# The model keeps its hidden state between calls, so repeated calls continue the same sequence.
with torch.no_grad():
    test_output = loaded_model(test_input)
print(test_output)

tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
tensor([[[0.1637]],

        [[0.1720]],

        [[0.1719]],

        [[0.1682]]], grad_fn=<ViewBackward0>)
