In [None]:
# !pip install keras
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt
import sys
import math
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Dense, GRU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.backend import clear_session
from tensorflow.keras.activations import tanh, elu, relu
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.utils import Sequence
from sklearn.metrics import mean_squared_error

In [None]:
# ---- Training parameters ----------------------------------------------------
# Data files: the model input (DI) and the signal the model learns to reproduce.
DI_path = "DI.wav"
target_path = "target.wav"

# Pre-processing
need_compensate_delay = True  # align target to DI before splitting
train_val_ratio = 0.8         # leading fraction of samples used for training
dataset_frame_size = 221      # input samples per window fed to the GRU

# Model / optimisation
hidden_size = 32              # GRU units and width of the hidden Dense layer
lr = 0.001                    # Adam learning rate
epochs = 10
batch_size = 4410             # windows per batch

# Tag used to build the saved-model path
name = 'test'

In [None]:
# wavfile utils
def read_wav_file(path, normalize=True):
  """Read a WAV file and return its samples as float32.

  Args:
    path: Path of the WAV file to read.
    normalize: When true, subtract the mean and scale so the peak-to-peak
      range of the signal becomes 2.

  Returns:
    (rate, sig): the sample rate in Hz stored in the file and the samples as a
    float32 numpy array.
  """
  rate, sig = wavfile.read(path)
  print("Read {} with {} samples, rate is {}Hz".format(path, len(sig), rate))
  sig = np.asarray(sig, dtype=np.float32)
  if not normalize or sig.size == 0:
    # An empty file has no max/min to normalise by; return it unchanged.
    return rate, sig
  centered = sig - np.mean(sig)
  peak_to_peak = sig.max() - sig.min()
  if peak_to_peak == 0:
    # Constant (e.g. silent) input: dividing by a zero range would fill the
    # signal with NaN, so return the centred (all-zero) signal instead.
    return rate, centered
  return rate, centered * 2 / peak_to_peak  # Normalized

def write_wav_file(path, sig, rate, normalized=True, bitwidth=16, plot=True):
  """Write ``sig`` to a WAV file at ``path``.

  Args:
    path: Destination file path.
    sig: Array-like of samples. It is never modified.
    rate: Sample rate in Hz stored in the file header.
    normalized: When true, ``sig`` is treated as floating-point audio in
      [-1, 1]. It is scaled to ``bitwidth``-bit PCM, rounded, clipped to the
      integer range and written as integers. When false the samples are
      written unchanged, and their dtype decides the WAV sample format.
    bitwidth: PCM bit depth used when ``normalized`` is true (16 or 32).
    plot: When true (the default), plot the samples being written.

  Raises:
    ValueError: if ``normalized`` is true and ``bitwidth`` is not 16 or 32.
  """
  data = np.asarray(sig)
  if normalized:
    int_types = {16: np.int16, 32: np.int32}
    if bitwidth not in int_types:
      raise ValueError("bitwidth must be 16 or 32, got {}".format(bitwidth))
    full_scale = float(2 ** (bitwidth - 1))
    # Scale into a new array so the caller's buffer is left untouched.
    scaled = np.asarray(data, dtype=np.float64) * full_scale
    # +1.0 maps to 2**(bitwidth-1), one past the largest representable value;
    # clip before converting so it saturates instead of wrapping around.
    scaled = np.clip(np.round(scaled), -full_scale, full_scale - 1)
    data = scaled.astype(int_types[bitwidth])
  if plot:
    plt.figure(1)
    plt.plot(data)
    plt.show()
  wavfile.write(path, rate, data)

In [None]:
class WindowArray(Sequence):
  """Keras ``Sequence`` yielding batches of sliding windows over ``x``.

  Sample ``i`` is the window ``x[i : i + window_len]`` paired with
  ``y[i + window_len - 1]``, the target at the window's last sample. Batches
  are stacked on demand rather than materialising every window up front.

  Args:
    x: Input signal; the loading cell passes arrays shaped (n_samples, 1).
    y: Target signal aligned sample-for-sample with ``x``.
    window_len: Number of consecutive input samples per window.
    batch_size: Windows per batch. The final batch may be smaller, so every
      window (and every target sample from index ``window_len - 1`` on) is
      covered.
  """

  def __init__(self, x, y, window_len, batch_size=32):
    super().__init__()
    self.x = x
    self.y = y[window_len-1:]
    self.window_len = window_len
    self.batch_size = batch_size
    # Number of complete windows in x (0 if x is shorter than one window).
    self.n_windows = max(0, len(x) - window_len + 1)

  def __len__(self):
    # Ceiling division: keep the trailing partial batch instead of dropping it.
    return (self.n_windows + self.batch_size - 1) // self.batch_size

  def __getitem__(self, index):
    start = index * self.batch_size
    stop = min(start + self.batch_size, self.n_windows)
    x_out = np.stack([self.x[i:i + self.window_len] for i in range(start, stop)])
    y_out = self.y[start:stop]
    return x_out, y_out

In [None]:
def IIR_LP_filter(cutoff_freq, trans_band=500, sample_rate=44100):
    """Design a Blackman-windowed sinc low-pass kernel.

    Despite the name, the result is a finite set of taps (an FIR filter): an
    ideal sinc response truncated by a Blackman window, normalised so the
    taps sum to 1 (unity DC gain).

    Args:
        cutoff_freq: Cutoff frequency in Hz.
        trans_band: Transition-band width in Hz; the tap count is
            ceil(4 / (trans_band / sample_rate)), rounded up to an odd number.
        sample_rate: Sampling rate in Hz.

    Returns:
        1-D numpy array of filter taps with odd length.
    """
    rate = float(sample_rate)
    cutoff_ratio = float(cutoff_freq) / rate  # cutoff as a fraction of the rate
    band_ratio = float(trans_band) / rate     # transition band as a fraction of the rate
    taps = int(np.ceil((4 / band_ratio)))
    taps += 1 - taps % 2  # bump even lengths to odd so the kernel has a centre tap
    idx = np.arange(taps)
    span = taps - 1
    kernel = np.sinc(2 * cutoff_ratio * (idx - span / 2))
    blackman = 0.42 - 0.5 * np.cos(2 * np.pi * idx / span) + \
        0.08 * np.cos(4 * np.pi * idx / span)
    kernel = kernel * blackman
    return kernel / np.sum(kernel)

def delay_compensation(DI_dataset, target_dataset, search_len=10000,
                       plot_len=50000, plot=True):
    """Estimate the lag of ``target_dataset`` relative to ``DI_dataset`` and
    trim both signals so they line up sample-for-sample.

    The lag is the peak of the cross-correlation of the first ``search_len``
    samples of each signal. A positive delay means the target lags the DI, so
    its first ``delay`` samples are dropped. A negative delay means the target
    leads, so the first ``-delay`` DI samples are dropped instead. Both results
    are then cut to a common length.

    Args:
        DI_dataset: 1-D input signal.
        target_dataset: 1-D target signal.
        search_len: Number of leading samples used to estimate the delay.
        plot_len: Number of leading samples drawn in the before/after plots.
        plot: When true (the default), show before/after overlay plots.

    Returns:
        (DI_aligned, target_aligned): equal-length slices of the inputs.
    """
    print("Before delay compensation")
    if plot:
        _plot_alignment(DI_dataset, target_dataset, 1, plot_len)
    n = min(search_len, len(DI_dataset), len(target_dataset))
    if n == 0:
        return DI_dataset, target_dataset
    correlated = np.correlate(target_dataset[:n], DI_dataset[:n], mode='full')
    index = int(correlated.argmax())
    # In 'full' mode output index i corresponds to lag i - (n - 1): zero lag
    # sits at n - 1, not n.
    delay = index - (n - 1)
    print("DI max index: {}, delay: {} samples".format(index, delay))
    if delay >= 0:
        target_aligned = target_dataset[delay:]
        DI_aligned = DI_dataset
    else:
        # A negative slice start would take the tail of the target, so drop
        # the leading DI samples instead.
        target_aligned = target_dataset
        DI_aligned = DI_dataset[-delay:]
    length = min(len(DI_aligned), len(target_aligned))
    DI_aligned = DI_aligned[:length]
    target_aligned = target_aligned[:length]
    print("After delay compensation")
    if plot:
        _plot_alignment(DI_aligned, target_aligned, 2, plot_len)
    return DI_aligned, target_aligned


def _plot_alignment(DI_dataset, target_dataset, fig_num, plot_len):
    """Overlay the first ``plot_len`` samples of DI and target in figure ``fig_num``."""
    plt.figure(fig_num)
    plt.plot(DI_dataset[:plot_len], label="DI")
    plt.plot(target_dataset[:plot_len], label="target")
    plt.legend()
    plt.show()

In [None]:
# Load datasets
tf.random.set_seed(7)
DI_rate, DI_dataset = read_wav_file(DI_path)
target_rate, target_dataset = read_wav_file(target_path)
# Verify the two recordings can be paired sample-for-sample. Raise instead of
# calling sys.exit(): inside a notebook kernel SystemExit only prints a terse
# notice rather than a readable error.
if DI_rate != target_rate:
  raise ValueError("DI rate ({}Hz) and target rate ({}Hz) differ".format(DI_rate, target_rate))
if len(DI_dataset) != len(target_dataset):
  raise ValueError("DI length ({}) and target length ({}) is not equal".format(
      len(DI_dataset), len(target_dataset)))
if need_compensate_delay:
  # Use the trimmed (DI, target) pair when delay_compensation returns one; a
  # version that returns None leaves the datasets untouched.
  aligned = delay_compensation(DI_dataset, target_dataset)
  if aligned is not None:
    DI_dataset, target_dataset = aligned
# Chronological split: the leading train_val_ratio of samples trains, the rest validates.
trainset_size = int(len(DI_dataset)*train_val_ratio)
input_trainset = DI_dataset[:trainset_size]
input_valset = DI_dataset[trainset_size:]
target_trainset = target_dataset[:trainset_size]
target_valset = target_dataset[trainset_size:]
# Plot original datasets
fig, ax = plt.subplots(4)
fig.suptitle("Datasets")
panels = [(input_trainset, "DI train"), (target_trainset, "target train"),
          (input_valset, "DI val"), (target_valset, "target val")]
for axis, (series, label) in zip(ax, panels):
  axis.plot(series)
  axis.set_ylabel(label)
plt.show()
# Add a trailing channel axis so each window is (dataset_frame_size, 1),
# matching the GRU input_shape.
input_trainset = input_trainset.reshape(-1, 1)
input_valset = input_valset.reshape(-1, 1)
target_trainset = target_trainset.reshape(-1, 1)
target_valset = target_valset.reshape(-1, 1)
train_arr = WindowArray(input_trainset, target_trainset, dataset_frame_size, batch_size=batch_size)
val_arr = WindowArray(input_valset, target_valset, dataset_frame_size, batch_size=batch_size)
output_arr = WindowArray(DI_dataset.reshape(-1, 1), target_dataset.reshape(-1, 1), dataset_frame_size, batch_size=batch_size)

In [None]:
# Create model
model = Sequential([
    GRU(hidden_size, input_shape=(dataset_frame_size, 1)),  # one channel per time step
    Dense(hidden_size, activation=tanh),
    Dense(1, activation=None),  # linear head: one predicted sample per window
])
model.compile(loss='mse', optimizer=Adam(learning_rate=lr))
model.summary()

In [None]:
# Train model
history = model.fit(
    train_arr,
    epochs=epochs,
    validation_data=val_arr,
    shuffle=True,
)

In [None]:
# Save model
model_path = 'models/{0}/{0}.h5'.format(name)  # models/<name>/<name>.h5
model.save(model_path)

In [None]:
# Predict
# One prediction per window produced by output_arr.
prediction = model.predict(x=output_arr)

In [None]:
# prediction[i] is the output for the window ending at input sample
# i + dataset_frame_size - 1 (WindowArray pairs each window with the target at
# its last sample), so shift it onto the target's time axis before overlaying.
output_start = dataset_frame_size - 1
output_samples = prediction.flatten()
fig, ax = plt.subplots()
ax.plot(target_dataset, label="target")
ax.plot(np.arange(output_start, output_start + len(output_samples)),
        output_samples, label="output")
ax.set_xlabel("sample")
ax.set_title("Target vs. model output")
ax.legend()
plt.show()

In [None]:
output_signal = prediction.flatten()  # (n_windows,) float samples
write_wav_file("output.wav", output_signal, rate=44100, normalized=False)