In [None]:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Conv1D, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.backend import clear_session
from tensorflow.keras.activations import tanh, elu, relu
from tensorflow.keras.models import load_model
import tensorflow.keras.backend as K
from tensorflow.keras.utils import Sequence

import sys
from scipy import signal
from scipy.io import wavfile
import numpy as np
import matplotlib.pyplot as plt
import math
import h5py

In [None]:
# Training parameters

# -- Data ---------------------------------------------------------------
name = 'ht1'                      # prefix of the saved model file (name + '.h5')
DI_path = "ht1-input.wav"         # WAV file used as network input
target_path = "ht1-target.wav"    # WAV file used as network target
# Alternative recordings:
# DI_path = "DI.wav"
# target_path = "target.wav"
train_val_ratio = 0.9             # leading fraction of the samples used for training
features = 1                      # not referenced by the other cells shown in this notebook

# -- Batching -----------------------------------------------------------
sequence_length = 128             # samples per input window
batch_size = 4096                 # windows per batch

# -- Network ------------------------------------------------------------
conv_chan = 32                    # filters of each Conv1D layer
conv_strides = [4, 3]             # stride of the first / second Conv1D layer
hidden_size = 128                 # LSTM units

# -- Optimisation -------------------------------------------------------
lr = 1e-3                         # Adam learning rate
epochs = 10                       # passes over the training set

In [None]:
# wavfile utils
def read_wav_file(path, normalize=True, bitwidth=16):
    """Read a WAV file as float32 samples, optionally normalized.

    Args:
        path: Location of the WAV file, handed to ``scipy.io.wavfile.read``.
        normalize: When truthy, subtract the mean and rescale so that the
            peak-to-peak span of the result equals 2. The result is centred
            on its mean, so it is not guaranteed to lie inside [-1, 1].
        bitwidth: Unused by this function; kept so existing calls keep working.

    Returns:
        ``(rate, sig)`` where ``rate`` is the sample rate reported by
        ``wavfile.read`` and ``sig`` is a float32 array of the samples.
    """
    rate, sig = wavfile.read(path)
    print("Read {} with {} samples, rate is {}Hz".format(path, len(sig), rate))
    sig = np.float32(sig)
    if not normalize or sig.size == 0:
        # Nothing to scale; max()/min() would also raise on an empty array.
        return rate, sig

    centered = sig - sig.mean()
    span = sig.max() - sig.min()
    if span == 0:
        # Constant (e.g. silent) signal: dividing by the zero span would fill
        # the output with NaN, so return the mean-removed signal instead.
        return rate, centered
    return rate, 2 * centered / span

def write_wav_file(path, sig, rate, normalized=True, bitwidth=16):
    """Write a signal to a WAV file, converting normalized samples to PCM.

    Args:
        path: Destination file, handed to ``scipy.io.wavfile.write``.
        sig: Samples to write. The caller's array is never modified.
        rate: Sample rate in Hz.
        normalized: When truthy, multiply the samples by ``2 ** (bitwidth - 2)``
            and convert them to an integer dtype before writing. When falsy,
            ``sig`` is written unchanged.
        bitwidth: 16 selects int16; 24 and 32 both select int32. Any other
            value prints "Unknown bitwidth" and writes the scaled samples
            without integer conversion.
    """
    if normalized:
        scale = float(np.power(2, bitwidth - 2))
        # Multiplying yields a new array, so the caller's data (e.g. a
        # prediction that is still being plotted) keeps its original values.
        scaled = np.asarray(sig) * scale

        # NOTE(review): 24-bit output is stored in an int32 array, which scipy
        # writes as a 32-bit file - confirm that this is the intended format.
        pcm_dtype = {16: np.int16, 24: np.int32, 32: np.int32}.get(bitwidth)
        if pcm_dtype is None:
            print("Unknown bitwidth")
            sig = scaled
        else:
            limits = np.iinfo(pcm_dtype)
            # Clip in float64 (which represents the int32 limits exactly) so
            # that over-range samples saturate instead of wrapping around.
            sig = np.clip(scaled.astype(np.float64), limits.min, limits.max).astype(pcm_dtype)
    wavfile.write(path, rate, sig)

# Load data
_, DI_dataset = read_wav_file(DI_path)
_, target_dataset = read_wav_file(target_path)

# Verify dataset length: input and target are paired sample by sample.
# Raise instead of calling sys.exit() so the reason shows up in the cell's
# traceback and execution stops with an error rather than a silent exit.
if len(DI_dataset) != len(target_dataset):
    raise ValueError(
        "DI length and target length is not equal ({} vs {} samples)".format(
            len(DI_dataset), len(target_dataset)
        )
    )

# The leading train_val_ratio fraction is used for training, the remainder
# for validation.
trainset_size = int(len(DI_dataset) * train_val_ratio)
input_trainset, input_valset = DI_dataset[:trainset_size], DI_dataset[trainset_size:]
target_trainset, target_valset = target_dataset[:trainset_size], target_dataset[trainset_size:]

# Plot original datasets, one panel per split
fig, ax = plt.subplots(4)
fig.suptitle("Datasets")
for panel, samples in zip(ax, (input_trainset, target_trainset, input_valset, target_valset)):
    panel.plot(samples)
plt.show()

In [None]:
# Get data batches from dataset
class SequenceBatch(Sequence):
    """Keras ``Sequence`` that serves sliding-window batches of a signal.

    Sample ``i`` pairs the input window ``x[i : i + sequence_length]`` with
    the target ``y[i + sequence_length - 1]``, i.e. the target that lines up
    with the last element of the window. Only full batches are served; the
    windows left over after the last full batch are dropped.
    """

    def __init__(self, x, y, sequence_length=sequence_length, batch_size=batch_size):
        """Store the signals and the windowing parameters.

        Args:
            x: Input samples, indexed along the first axis.
            y: Target samples, same length as ``x``.
            sequence_length: Number of samples per input window.
            batch_size: Number of windows per batch.
        """
        # Initialise the Keras base class; newer Keras versions warn when a
        # dataset subclass skips this call.
        super().__init__()
        # asarray returns an ndarray input unchanged and lets __getitem__
        # gather windows with array indexing.
        self.x = np.asarray(x)
        # Drop the first sequence_length-1 targets so that y[i] matches the
        # window starting at x[i].
        self.y = y[sequence_length-1:]
        self.sequence_length = sequence_length
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of full batches (0 when the data is too short)."""
        window_count = len(self.x) - self.sequence_length + 1
        # Clamp at 0: len() must not return a negative number.
        return max(0, window_count // self.batch_size)

    def __getitem__(self, index):
        """Return ``(x_out, y_out)`` for batch number ``index``.

        ``x_out`` stacks ``batch_size`` consecutive windows along a new first
        axis; ``y_out`` holds the matching targets.
        """
        first = index * self.batch_size
        last = first + self.batch_size
        # Row r of window_idx lists the sample positions of window first + r,
        # so one indexing operation gathers the whole batch.
        window_idx = np.arange(first, last)[:, None] + np.arange(self.sequence_length)[None, :]
        x_out = self.x[window_idx]
        y_out = self.y[first:last]
        return x_out, y_out
    
# Pre-process datasets
# Reshape every split to a single column, i.e. shape (n, 1).
input_trainset, input_valset, target_trainset, target_valset = (
    split.reshape(-1, 1)
    for split in (input_trainset, input_valset, target_trainset, target_valset)
)

# Wrap each input/target pair in a batch generator.
train_dataset = SequenceBatch(x=input_trainset, y=target_trainset)
val_dataset = SequenceBatch(x=input_valset, y=target_valset)

In [None]:
# Create model
clear_session()
# Layer stack: two strided 1-D convolutions without activation, an LSTM, and
# a single linear output unit.
model = Sequential([
    Conv1D(
        filters=conv_chan,
        kernel_size=12,
        strides=conv_strides[0],
        padding='same',
        activation=None,
        input_shape=(sequence_length, 1),
    ),
    Conv1D(
        filters=conv_chan,
        kernel_size=12,
        strides=conv_strides[1],
        padding='same',
        activation=None,
    ),
    LSTM(units=hidden_size),
    Dense(units=1, activation=None),
])
model.compile(loss='mse', optimizer=Adam(learning_rate=lr))
model.summary()

In [None]:
# Train model
# Fit on the training batches, using the validation batches as validation data.
history = model.fit(
    x=train_dataset,
    epochs=epochs,
    shuffle=True,
    validation_data=val_dataset,
)
# Save the trained model as '<name>.h5'.
model.save(name + '.h5')

In [None]:
# Prediction
pred = model.predict(val_dataset)

# SequenceBatch pairs each window with the target at the window's last sample
# (it stores y[sequence_length-1:]), so the reference curve has to come from
# that shifted array; target_valset itself starts sequence_length-1 samples
# earlier and would be plotted out of step with the prediction.
target = val_dataset.y.flatten()[:len(pred)]

plt.figure(2)
plt.plot(target[:441000], label="target")
plt.plot(pred[:441000], label="predict")
plt.legend()
plt.show()

# Write the prediction at the sample rate of the target recording instead of
# assuming 44100 Hz; the rate was discarded when the datasets were loaded, so
# it is read from the file again here.
output_rate = wavfile.read(target_path)[0]
write_wav_file("output.wav", pred, output_rate)