In [None]:
import sys
import torch
import scipy.io.wavfile as wav
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import math
import os
import json

In [None]:
class Dataset(object):
  """Holds one input/target pair as torch tensors on a chosen device.

  Args:
    data_input: numpy array of network inputs.
    data_target: numpy array of expected network outputs.
    device: torch.device both tensors are moved to.
  """

  def __init__(self, data_input, data_target, device):
    def as_device_tensor(array):
      # torch.from_numpy wraps the numpy buffer; .to(device) then places it.
      return torch.from_numpy(array).to(device)

    self.input = as_device_tensor(data_input)
    self.target = as_device_tensor(data_target)

In [None]:
class WavDataset(object):
  """Loads wav files and slices them into fixed-size frames for the RNN.

  Training mode (``for_processing=False``) reads the DI track ``_DI_path`` and
  the target track ``_target_path``, which must have the same length. It
  splits the frames into train/test sets at ``ratio`` and wraps them in
  ``Dataset`` objects: ``train_set`` lives on ``device`` and ``test_set`` on
  the CPU.

  Processing mode (``for_processing=True``) reads the single file given by
  ``_DI_path`` and exposes its frames as the tensor ``process_set`` on
  ``device``.

  Samples are converted to float32 and divided by 32768. Frames have shape
  (frame_cnt, frame_size, 1). Trailing samples that do not fill a whole frame
  are dropped. The process exits via sys.exit on a format mismatch or a
  length mismatch.
  """

  def __init__(self, frame_size, device, for_processing=False, _wave_sample_rate=44100, _wav_dtype="int16", _wav_channels=1, _DI_path="DI.wav", _target_path="target.wav", ratio=0.8):
    # Supported wav format
    self._wave_sample_rate = _wave_sample_rate
    self._wav_dtype = _wav_dtype
    self._wav_channels = _wav_channels
    self.device = device
    self.frame_size = frame_size

    if not for_processing:
      # File paths
      self._DI_path = _DI_path
      self._target_path = _target_path
      self.__load_data(ratio=ratio)
    else:
      self._raw_path = _DI_path
      self.__load_data(for_processing=True)

  def __load_data(self, for_processing=False, ratio=0.8):
    """Read the configured wav file(s) and build the framed tensors."""
    if for_processing:
      input_raw = self.__read_wav_file(self._raw_path)
      # Convert once, after framing, so process_set also exists (empty) for a
      # file shorter than a single frame.
      self.process_set = torch.from_numpy(self.__framelize(input_raw)).to(self.device)
      return

    input_raw = self.__read_wav_file(self._DI_path)
    target_raw = self.__read_wav_file(self._target_path)

    if len(input_raw) != len(target_raw):
      print("DI track has {} frames but target track has {} frames".format(len(input_raw), len(target_raw)))
      print("Please make sure these 2 files match and have the same length")
      sys.exit()

    input_framelized = self.__framelize(input_raw)
    target_framelized = self.__framelize(target_raw)

    # Split by frame index: the first `ratio` of the frames train, the rest test.
    split_index = int(ratio * len(input_framelized))
    train_input = input_framelized[:split_index, :]
    train_target = target_framelized[:split_index, :]
    test_input = input_framelized[split_index:, :]
    test_target = target_framelized[split_index:, :]

    self.train_set = Dataset(train_input, train_target, device=self.device)
    self.test_set = Dataset(test_input, test_target, device=torch.device("cpu"))

  def __framelize(self, raw):
    """Cut a 1-D signal into a (frame_cnt, frame_size, 1) float32 array, dropping the tail."""
    frame_cnt = len(raw) // self.frame_size
    usable = raw[:frame_cnt * self.frame_size]
    # astype copies, so the result is contiguous and independent of `raw`.
    return usable.reshape(frame_cnt, self.frame_size, 1).astype(np.float32)

  def __read_wav_file(self, path):
    """Read a wav file, check its format and return it as float32 scaled by 1/32768.

    Exits via sys.exit when the sample rate, dtype or channel count does not
    match the configured format.
    """
    print("Reading {}...".format(path))
    rate, wavsignal = wav.read(path)

    # Mono data comes back 1-D; otherwise the channel count is the second axis,
    # so ndim by itself is not the number of channels.
    channels = 1 if wavsignal.ndim == 1 else wavsignal.shape[1]

    # Check wavfile data format
    if (rate != self._wave_sample_rate) or (wavsignal.dtype != self._wav_dtype) or (channels != self._wav_channels):
      print("Unsupported wav file format: {}Hz, {}bit, {} channels".format(rate, wavsignal.dtype, channels))
      print("Please provide a wav file with {}Hz, {}bit, {} channels".format(self._wave_sample_rate, self._wav_dtype, self._wav_channels))
      sys.exit()

    print("Successfully read {} with {} samples".format(path, wavsignal.shape[0]))

    # Normalize; the 32768 divisor matches the default int16 sample format.
    return wavsignal.astype(np.float32) / 32768.0

In [None]:
def dir_check(dir_name):
    """Ensure a directory exists, creating it (and any missing parents) if needed.

    Args:
        dir_name: a directory path string, or a list of path components that
            are joined with os.path.join, e.g. ['directory', 'subdir'] refers
            to 'directory/subdir'.
    """
    parts = dir_name if isinstance(dir_name, list) else [dir_name]
    # makedirs also creates missing intermediate directories; exist_ok makes the
    # call a no-op for an existing directory without a separate isdir check.
    os.makedirs(os.path.join(*parts), exist_ok=True)


def file_check(file_name, dir_name=''):
    """Return True if ``file_name`` exists as a regular file, else False.

    Args:
        file_name: name of the file to look for (must be a str).
        dir_name: directory to look in, either a path string or a list of path
            components such as ['directory', 'subdir']. The default '' means
            the current directory.
    """
    assert type(file_name) is str
    if type(dir_name) is not list and dir_name:
        dir_name = [dir_name]
    candidate = os.path.join(*dir_name, file_name)
    return os.path.isfile(candidate)


def json_save(data, file_name, dir_name=''):
    """Write ``data`` as JSON to ``file_name``, appending '.json' if it is missing.

    Args:
        data: any JSON-serializable object.
        file_name: output file name (must be a str).
        dir_name: directory to write into, either a path string or a list of
            path components. The default '' writes relative to the current
            directory.
    """
    if type(dir_name) is not list and dir_name:
        dir_name = [dir_name]
    assert type(file_name) is str
    if not file_name.endswith('.json'):
        file_name += '.json'
    destination = os.path.join(*dir_name, file_name)
    with open(destination, 'w') as handle:
        json.dump(data, handle)


def json_load(file_name, dir_name=''):
    """Read and return the JSON content of ``file_name`` ('.json' appended if missing).

    Args:
        file_name: file name, with or without the '.json' extension.
        dir_name: directory to read from, either a path string or a list of
            path components. The default '' means the current directory.
    """
    if type(dir_name) is not list and dir_name:
        dir_name = [dir_name]
    if not file_name.endswith('.json'):
        file_name = file_name + '.json'
    with open(os.path.join(*dir_name, file_name)) as handle:
        return json.load(handle)

In [None]:
class RNNModel(nn.Module):
  """Single-layer LSTM followed by a linear layer, with a residual skip.

  forward() returns ``linear(lstm(input)) + input``, so the network learns the
  difference between its input and the target. The LSTM state is kept in
  ``self.hidden`` between forward calls and cleared by ``reset_hidden()``.

  Args:
    input_size: features per time step fed to the LSTM.
    output_size: features produced by the linear layer. It must be
      broadcast-compatible with input_size for the residual addition; both
      default to 1.
    hidden_size: number of LSTM hidden units.
    bias: whether the linear layer has a bias term.
  """

  def __init__(self, input_size=1, output_size=1, hidden_size=64, bias=True):
    super().__init__()
    self.input_size = input_size
    self.output_size = output_size
    self.bias = bias
    self.hidden_size = hidden_size
    self.hidden = None

    # Create model
    # NOTE(review): nn.LSTM defaults to batch_first=False. A WavDataset tensor
    # of shape (frames, frame_size, 1) is therefore read as seq_len=frames and
    # batch=frame_size, so the recurrence runs across frames rather than across
    # consecutive samples. Confirm this is intended; batch_first=True would
    # recur over samples within a frame.
    self.rec = nn.LSTM(input_size, hidden_size, 1)
    self.linear = nn.Linear(hidden_size, output_size, bias=bias)

    # Init weights: zero biases, Xavier-normal weights (in-place *_ initializers).
    for name, param in self.named_parameters():
      if 'bias' in name:
        nn.init.constant_(param, 0.0)
      if 'weight' in name:
        nn.init.xavier_normal_(param)

  def forward(self, input):
    """Return ``linear(lstm(input)) + input`` and store the new LSTM state in ``self.hidden``."""
    rec_output, self.hidden = self.rec(input, self.hidden)
    return self.linear(rec_output) + input

  def reset_hidden(self):
    """Forget the stored LSTM state so the next forward call starts from zeros."""
    self.hidden = None

  def train_epoch(self, loss_function, optimizer, input, target, batch_size):
    """Run one shuffled pass over the training frames (dim 0), ``batch_size`` frames per step.

    Returns:
      Mean batch loss as a Python float, or 0.0 when ``input`` is empty.
    """
    # Shuffle train set; the permutation is built on the data's device.
    idx = torch.randperm(input.shape[0], device=input.device)
    input = input[idx]
    target = target[idx]
    # Iterate over the batches, each batch contains several frames
    epoch_loss = 0.0
    batch_cnt = math.ceil(input.shape[0] / batch_size)
    for i in range(batch_cnt):
      start = batch_size * i
      input_batch = input[start : start + batch_size, :, :]
      target_batch = target[start : start + batch_size, :, :]
      self.zero_grad()
      # Process input batch with neural network
      output = self(input_batch)
      # Calculate loss and update network parameters
      loss = loss_function(output, target_batch)
      loss.backward()
      optimizer.step()
      self.reset_hidden()
      # .item() yields a plain float, so the batch's autograd graph is not kept
      # alive by the running total.
      batch_loss = loss.item()
      if i % 10 == 0:
        print("Frame {}/{}: {:.2%}, loss = {}".format(i, batch_cnt, i/batch_cnt, batch_loss), end='\r')
      epoch_loss += batch_loss
    print("")
    return epoch_loss / batch_cnt if batch_cnt else 0.0

  def predict(self, input, batch_size=4096):
    """Run the model over ``input`` in chunks of ``batch_size`` along dim 0.

    Gradients are disabled. The hidden state is reset before the first chunk
    and after every chunk. The result is a CPU tensor shaped like ``input``.
    """
    output = torch.empty_like(input, device=torch.device("cpu"))
    batch_cnt = math.ceil(len(input) / batch_size)
    # torch.no_grad only takes effect when used as a context manager.
    with torch.no_grad():
      self.reset_hidden()
      for i in range(batch_cnt):
        print("Process batch {}/{}: {:.2%}".format(i, batch_cnt, i/batch_cnt))
        start = batch_size * i
        output[start : start + batch_size] = self(input[start : start + batch_size, :])
        self.reset_hidden()
    return output

  def save_model(self, file_name, direc=''):
    """Export architecture metadata and weights as JSON via json_save.

    Every state-dict tensor is converted to nested lists so that it is
    JSON-serializable. ``direc`` is created with dir_check when it is given.
    """
    if direc:
      dir_check(direc)
    model_data = {'model_data': {'model': 'RNN', 'input_size': self.rec.input_size,
                                 'output_size': self.linear.out_features, 'unit_type': self.rec._get_name(),
                                 'num_layers': self.rec.num_layers, 'hidden_size': self.rec.hidden_size,
                                 'bias_fl': self.bias}}
    model_data['state_dict'] = {name: tensor.tolist() for name, tensor in self.state_dict().items()}
    json_save(model_data, file_name, direc)

In [None]:
def train_model(model: RNNModel, loss_function, epochs,
          optimizer: torch.optim.Optimizer, wav_dataset: WavDataset, batch_size, lr):
  """Train ``model`` for ``epochs`` passes over ``wav_dataset.train_set``.

  The learning rate stays at ``lr`` for the first two epochs (indices 0 and 1).
  From then on it is divided by 1.5 at the start of every epoch, and the
  current value is written into every optimizer param group before that
  epoch runs. Each epoch's mean loss is printed.
  """
  for epoch in range(epochs):
    print("Train epoch {}".format(epoch + 1))
    # NOTE(review): the decay compounds, so after e epochs (e >= 2) the rate is
    # lr / 1.5**(e - 1); confirm this schedule is intended for long runs.
    lr = lr / 1.5 if epoch > 1 else lr
    for group in optimizer.param_groups:
      group['lr'] = lr
    train_set = wav_dataset.train_set
    epoch_loss = model.train_epoch(loss_function, optimizer, train_set.input, train_set.target, batch_size)
    print("Epoch loss = {}".format(epoch_loss))

In [None]:
# Train parameters
SEED = 0
# Seeds weight initialization and the shuffling in train_epoch; torch.manual_seed
# covers both the CPU and CUDA generators.
torch.manual_seed(SEED)
model = RNNModel(hidden_size=96)
print(model)
loss_function = nn.MSELoss()
epochs = 100
lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Pick the training device (first CUDA GPU when available).
if torch.cuda.is_available():
  device = torch.device("cuda:0")
  torch.cuda.empty_cache()
  print("Running on the GPU")
else:
  device = torch.device("cpu")
  print("Running on the CPU")


In [None]:
# Start training
FRAME_SIZE = 4410  # samples per frame: 100 ms at 44.1 kHz
BATCH_SIZE = 40    # frames per optimizer step

model.to(device)
# Load DI.wav / target.wav (the WavDataset defaults) as train/test datasets.
wav_dataset = WavDataset(frame_size=FRAME_SIZE, device=device)
train_model(model=model, loss_function=loss_function, epochs=epochs, optimizer=optimizer,
            wav_dataset=wav_dataset, batch_size=BATCH_SIZE, lr=lr)

In [None]:
# Test model
model.to(torch.device("cpu"))
# Tensor.to returns a new tensor rather than moving in place, so keep the results.
test_input = wav_dataset.test_set.input.to(torch.device("cpu"))
test_target = wav_dataset.test_set.target.to(torch.device("cpu"))

trained = torch.flatten(model.predict(test_input)).detach().numpy()
target = torch.flatten(test_target).detach().numpy()
before = torch.flatten(test_input).detach().numpy()

# Accuracy = fraction of samples whose absolute error is below the tolerance.
# Without abs(), every prediction below the target would count as a hit.
ACCURACY_TOLERANCE = 1e-4
accuracy = float(np.mean(np.abs(trained - target) < ACCURACY_TOLERANCE)) if len(trained) else 0.0
print("Test accuracy = {:.2%}".format(accuracy))

In [None]:
# Plot test result
plt.figure(1)
plt.plot(target[:500], color="green", label="Target")
plt.plot(trained[:500], color="red", label="Trained result")
plt.plot(before[:500], color="yellow", label="Before training")
plt.xlabel("Sample")
plt.ylabel("Normalized value")
plt.legend()
plt.show()
# for param in model.parameters():
#     print(param)
torch.save(model, "demo.pth")
model.save_model("marshall.json")

In [None]:
# Convert the normalized float signals back to 16-bit PCM before plotting/writing.
# Samples were divided by 32768 when read, so scale back up and clip to the
# int16 range. Casting the floats directly would truncate nearly every sample
# to 0 and produce a silent file.
PCM16_SCALE = 32768.0
trained_pcm = np.round(np.clip(trained * PCM16_SCALE, -32768, 32767)).astype(np.int16)
before_pcm = np.round(np.clip(before * PCM16_SCALE, -32768, 32767)).astype(np.int16)

fig, ax = plt.subplots()
ax.plot(trained_pcm, label="Trained wav file")
ax.plot(before_pcm, label="Raw wav file")
ax.set(title="Test set as 16-bit PCM", xlabel="Sample", ylabel="Amplitude (int16)")
ax.legend()
plt.show()
wav.write("output.wav", 44100, trained_pcm)

In [None]:
# Process raw sound
torch.cuda.empty_cache()
# demo.pth pickles the whole nn.Module. torch>=2.6 defaults to weights_only=True,
# which refuses to load it, so opt out explicitly; only do this for checkpoints
# you created yourself. map_location loads all tensors onto the CPU.
test_model = torch.load("demo.pth", map_location=torch.device("cpu"), weights_only=False)
for param in test_model.parameters():
    param.requires_grad = False
test_model.eval()

raw_frames = WavDataset(frame_size=4096, device=torch.device("cpu"), for_processing=True, _DI_path="raw.wav").process_set
processed_frames = test_model.predict(raw_frames)
processed_signal = torch.flatten(processed_frames.to(torch.device("cpu"))).detach().numpy()
raw_signal = torch.flatten(raw_frames.to(torch.device("cpu"))).detach().numpy()

# Undo the 1/32768 normalization applied on read and clip to int16; a bare
# int16 cast would truncate the normalized floats to (almost) all zeros.
PCM16_SCALE = 32768.0
output_pcm = np.round(np.clip(processed_signal * PCM16_SCALE, -32768, 32767)).astype(np.int16)
input_pcm = np.round(np.clip(raw_signal * PCM16_SCALE, -32768, 32767)).astype(np.int16)

PLOT_WINDOW = slice(100000, 101000)
fig, ax = plt.subplots()
ax.plot(output_pcm[PLOT_WINDOW], label="Output wav file")
ax.plot(input_pcm[PLOT_WINDOW], label="Input wav file")
ax.set(title="raw.wav before/after processing", xlabel="Sample", ylabel="Amplitude (int16)")
ax.legend()
plt.show()
wav.write("processed.wav", rate=44100, data=output_pcm)
