In [1]:
import numpy as np

import random
import h5py
from keras import backend as K
from nn_util import *
from keras.models import *
from keras.layers import *
from keras.layers.core import *
from keras.layers.normalization import *
from keras.optimizers import *
from keras.initializers import *
from keras.models import load_model
from keras.losses import *
import scipy.io.wavfile as sciwav
import multiprocessing

import os
import random
import time
import matplotlib
import matplotlib.pyplot as plt
import glob

import operator
import math
import re

# for reproducibility
# NOTE(review): only NumPy's and Python's RNGs are seeded here; no TensorFlow
# seed is set in the visible cells, so backend-side randomness may still vary.
np.random.seed(1337) 
random.seed(1337)

# increase recursion limit for adaptive VQ
import sys
sys.setrecursionlimit(40000)

# print NumPy floats in fixed-point notation
# NOTE(review): '{:4f}' sets a minimum field width of 4 (six decimals are still
# printed, as the captured outputs show); if four decimal places were intended
# the format spec would be '{:.4f}' -- confirm intent.
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})

Using TensorFlow backend.


In [2]:
# control amount of GPU memory used
# Builds a TensorFlow session whose GPU options have allow_growth enabled and
# installs it as the session used by the Keras TensorFlow backend.
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
set_session(tf.Session(config=config))

In [3]:
# external custom code I wrote
from load_data import *
from windowing import *
from pesq import *
from consts import *
from nn_blocks import *
from perceptual_loss import *
from evaluation import *

In [4]:
# Load the dataset. load_data() (from load_data.py) returns five triples, each
# split into train / validation / test: paths, waveforms, processed waveforms,
# per-waveform parameters, and windows.
# NOTE(review): the exact contents of each item are defined in load_data.py;
# later cells treat *_windows as a per-file collection of windows and
# *_wparams as the matching per-file parameters.
[train_paths, val_paths, test_paths], \
[train_waveforms, val_waveforms, test_waveforms], \
[train_procwave, val_procwave, test_procwave], \
[train_wparams, val_wparams, test_wparams], \
[train_windows, val_windows, test_windows] = load_data(TRAIN_SIZE, VAL_SIZE, TEST_SIZE)

In [5]:
# flatten all of the train windows into vectors
train_processed = np.array([i for z in train_windows for i in z])
train_processed = np.reshape(train_processed, (train_processed.shape[0], WINDOW_SIZE, 1))

# randomly shuffle data, if we want to
if (RANDOM_SHUFFLE):
    train_processed = np.random.permutation(train_processed)
    
print train_processed.shape
print np.mean(train_processed, axis=None)
print np.std(train_processed, axis=None)
print np.min(train_processed, axis = None)
print np.max(train_processed, axis = None)

(101544, 512, 1)
-2.40941e-06
0.104158
-1.0
1.0


In [6]:
# shape of one model input: a window of WINDOW_SIZE samples with one channel
input_dim = (WINDOW_SIZE, 1)

In [7]:
# softmax hardness variable
# Backend variable so its value can be changed with K.set_value() during
# training; code_entropy() multiplies the batch entropy by it, and the training
# loop adjusts it each epoch.
tau = K.variable(0.0001, name = "hardness")

In [8]:
# ratio between the input window length and the code length
DOWNSAMPLE_FACTOR = 2
# Number of code positions per window. This value is passed as a layer shape
# (Reshape / Input), so it must stay an integer: use floor division so the
# result does not become a float if true division is ever in effect
# (Python 3 or `from __future__ import division`).
CHANNEL_SIZE = WINDOW_SIZE // DOWNSAMPLE_FACTOR
    
# ---------------------------------------------------------------------------
# autoencoder: takes an audio window, compresses it, and tries to reconstruct it
# ---------------------------------------------------------------------------
def autoencoder_structure(dim):
    """Build the encoder and decoder halves of the quantizing autoencoder.

    Args:
        dim: shape tuple of one input window (passed to Input / Reshape).

    Returns:
        (encoder, decoder): two Keras Models. The encoder maps a window to the
        output of SoftmaxQuantization; the decoder takes an input of shape
        (CHANNEL_SIZE, NBINS), dequantizes it and maps it back through the
        mirrored stack of blocks (with an upsample block in place of the
        downsample block).
    """
    # - - - - - - - - - - - - - - - - - - - - -
    # parameters
    # - - - - - - - - - - - - - - - - - - - - -
    NCHAN = 48
    FILT_SIZE = 9
    # third argument given to each of the three stacked residual blocks
    # (presumably a dilation rate -- confirm in nn_blocks)
    RESIDUAL_STEPS = (1, 2, 4)

    def residual_stack(tensor):
        # apply three residual blocks in sequence, one per entry in RESIDUAL_STEPS
        for step in RESIDUAL_STEPS:
            tensor = residual_block(NCHAN, FILT_SIZE, step)(tensor)
        return tensor

    # - - - - - - - - - - - - - - - - - - - - -
    # encoder
    # - - - - - - - - - - - - - - - - - - - - -
    enc_input = Input(shape = dim)
    encoded = Reshape(dim, input_shape = dim)(enc_input)

    encoded = channel_change_block(NCHAN, FILT_SIZE)(encoded)
    encoded = residual_stack(encoded)
    encoded = downsample_block(NCHAN, FILT_SIZE)(encoded)
    encoded = residual_stack(encoded)
    encoded = channel_change_block(1, FILT_SIZE)(encoded)

    # quantization: drop the channel axis, then quantize
    encoded = Reshape((CHANNEL_SIZE,))(encoded)
    encoded = SoftmaxQuantization()(encoded)

    encoder_model = Model(inputs = enc_input, outputs = encoded)

    # - - - - - - - - - - - - - - - - - - - - -
    # decoder
    # - - - - - - - - - - - - - - - - - - - - -
    dec_input = Input(shape = (CHANNEL_SIZE, NBINS))

    # dequantization: back to one value per code position, plus a channel axis
    decoded = SoftmaxDequantization()(dec_input)
    decoded = Reshape((CHANNEL_SIZE, 1))(decoded)

    decoded = channel_change_block(NCHAN, FILT_SIZE)(decoded)
    decoded = residual_stack(decoded)
    decoded = upsample_block(NCHAN, FILT_SIZE)(decoded)
    decoded = residual_stack(decoded)
    decoded = channel_change_block(1, FILT_SIZE)(decoded)

    decoder_model = Model(inputs = dec_input, outputs = decoded)

    # return both encoder and decoder
    return encoder_model, decoder_model

In [9]:
# we can compute the entropy of a batch directly
def code_entropy(placeholder, code):
    """Loss term: tau times the base-2 entropy of the batch's bin usage.

    Args:
        placeholder: unused; present because Keras losses take (y_true, y_pred).
        code: tensor whose last axis has NBINS entries per code position.

    Returns:
        Scalar tensor `tau * entropy`, where entropy is computed over the
        normalized per-bin totals of the whole batch.
    """
    # sum every position's NBINS-vector into one histogram, then normalize it
    bin_totals = K.sum(K.reshape(code, (-1, NBINS)), axis = 0)
    bin_probs = bin_totals / K.sum(bin_totals)

    # epsilon guards log(0) for unused bins; dividing by log(2) converts to bits
    weighted_log = bin_probs * K.log(bin_probs + K.epsilon())
    entropy = -K.sum(weighted_log / K.log(2.0))
    return tau * entropy

def code_sparsity(placeholder, code):
    """Loss term: mean over the last remaining axis of (sum of sqrt(code)) - 1.

    Args:
        placeholder: unused; present because Keras losses take (y_true, y_pred).
        code: tensor of per-bin values along its last axis.

    Returns:
        Tensor with the last two axes of `code` reduced.
    """
    # epsilon keeps sqrt well-behaved at exactly zero
    root_sum = K.sum(K.sqrt(code + K.epsilon()), axis = -1)
    return K.mean(root_sum, axis = -1) - 1.0

In [10]:
# map for load_model
# Passed as the custom-objects argument to load_model() so that saved models
# referencing these project-defined layers, losses and helpers can be
# deserialized by name. code_entropy / code_sparsity are defined in this
# notebook; the remaining objects come from the project modules imported above.
KERAS_LOAD_MAP = {'PhaseShiftUp1D' : PhaseShiftUp1D,
                  'code_entropy' : code_entropy,
                  'code_sparsity' : code_sparsity,
                  'rmse' : rmse,
                  'SoftmaxQuantization' : SoftmaxQuantization,
                  'SoftmaxDequantization' : SoftmaxDequantization,
                  'MEL_FILTERBANK' : MEL_FILTERBANK,
                  'DFT_REAL' : DFT_REAL,
                  'DFT_IMAG' : DFT_IMAG,
                  'MFCC_DCT' : MFCC_DCT,
                  'keras_dft_mag' : keras_dft_mag,
                  'keras_dct' : keras_dct,
                  'perceptual_transform' : perceptual_transform,
                  'perceptual_distance' : perceptual_distance}

In [11]:
# construct autoencoder
# `autoencoder` is the plain input -> encoder -> decoder model with a single
# reconstruction output; it is the model handed to the evaluation helpers.
ac_input = Input(shape = input_dim)

encoder, decoder = autoencoder_structure(input_dim)
ac_reconstructed = decoder(encoder(ac_input))
autoencoder = Model(inputs = [ac_input], outputs = [ac_reconstructed])

In [12]:
# model parameters
# One weight per loss, in matching order: the first n_recons losses are applied
# to the reconstruction output, the last n_code losses to the code output
# (see the model specification, which repeats each output that many times).
loss_weights = [30.0, 1.0, 5.0, 1.0]
loss_functions = [rmse, perceptual_distance, code_sparsity, code_entropy]
n_recons = 2
n_code = 2
# sanity checks: every output has exactly one loss and one weight
assert(n_recons + n_code == len(loss_weights))
assert(len(loss_weights) == len(loss_functions))

In [13]:
# model specification
# Training model: shares `encoder` and `decoder` with `autoencoder`, but exposes
# the reconstruction n_recons times and the code n_code times so that a
# separate loss can be attached to each copy.
# NOTE(review): the captured output below this cell is the tail of a Keras
# warning mentioning `self.outputs`; presumably it is triggered by the repeated
# outputs here -- confirm.
model_input = Input(shape = input_dim)
model_embedding = encoder(model_input)
model_reconstructed = decoder(model_embedding)

model = Model(inputs = [model_input], outputs = [model_reconstructed] * n_recons + \
                                            [model_embedding] * n_code)

  ' Found: ' + str(self.outputs))


In [14]:
# Compile the training model with one loss and one weight per output; the Adam
# learning rate is overwritten on every batch by the training loop.
model.compile(loss = loss_functions,
              loss_weights = loss_weights,
              optimizer = Adam())

model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 512, 1)            0         
_________________________________________________________________
model_1 (Model)              (None, 256, 32)           313486.0  
_________________________________________________________________
model_2 (Model)              (None, 512, 1)            396589    
Total params: 710,075
Trainable params: 710,075
Non-trainable params: 0
_________________________________________________________________


In [15]:
# get untrained baseline for model
# Runs the freshly initialized autoencoder on SA1.wav; the returned list matches
# the printed values in the order [MSE, avg err, PESQ].
test_model_on_wav("./SA1.wav", "./train_output/SA1_uninit", autoencoder)

MSE:         155322.0
Avg err:     210.913
PESQ:        1.02260279655


[155322.05, 210.91307, 1.0226027965545654]

In [16]:
# saves current model
def save_model(prefix = 'best'):
    """Save the training model and the autoencoder as HDF5 checkpoints.

    Writes './<prefix>_model.h5' (from the global `model`) and
    './<prefix>_auto.h5' (from the global `autoencoder`), replacing any existing
    files, then removes the 'optimizer_weights' group from the training-model
    file if it is present.

    Args:
        prefix: filename prefix for both checkpoint files.
    """
    model_path = './' + prefix + '_model.h5'
    auto_path = './' + prefix + '_auto.h5'

    # remove stale checkpoints without shelling out (no `rm` dependency, and
    # `prefix` is never interpreted by a shell)
    for path in (model_path, auto_path):
        if os.path.isfile(path):
            os.remove(path)

    model.save(model_path)
    autoencoder.save(auto_path)

    # Strip the optimizer state from the file that was just written. The path
    # follows `prefix` (it was previously hard-coded to 'best_model.h5'), the
    # handle is closed even if the deletion fails, and a checkpoint without
    # optimizer weights is not treated as an error.
    with h5py.File(model_path, 'r+') as f:
        if 'optimizer_weights' in f:
            del f['optimizer_weights']

In [17]:
def evaluate_training(autoencoder, lead = ""):
    def set_evaluation(windows, wparams, eval_idxs):
        before_after_pairs = np.array([run_model_on_windows(windows[i],
                                                    wparams[i],
                                                    autoencoder,
                                                    argmax = True)
                                       for i in eval_idxs])
        
        NUM_THREADS = 8
        list_range = np.arange(0, len(eval_idxs))
        slices = [list_range[i:None:NUM_THREADS]
                  for i in xrange(0, NUM_THREADS)]
        
        def thread_func(pairs, q):
            for p in pairs:
                q.put(evaluation_metrics(p[0], p[1]))
                
        q = multiprocessing.Queue()
        threads = [multiprocessing.Process(target = thread_func,
                                           args = (before_after_pairs[slices[i]], q))
                   for i in xrange(0, NUM_THREADS)]
        [t.start() for t in threads]
        [t.join() for t in threads]
        
        return np.array([q.get() for i in list_range])
    
    train_eval_idxs = random.sample(range(0, len(train_windows)), TRAIN_EVALUATE)
    val_eval_idxs = random.sample(range(0, len(val_windows)), VAL_EVALUATE)
    
    print lead + "Format: [MSE, avg err, PESQ]"
    
    # train set evaluation
    train_metrics = set_evaluation(train_windows, train_wparams,
                                   train_eval_idxs)
    print lead + "    Train: (mean)", np.mean(train_metrics, axis = 0)
    print lead + "    Train: (max) ", np.max(train_metrics, axis = 0)
    print lead + "    Train: (min) ", np.min(train_metrics, axis = 0)
    
    # validation set evaluation
    val_metrics = set_evaluation(val_windows, val_wparams,
                                 val_eval_idxs)
    print lead + "    Val:   (mean)", np.mean(val_metrics, axis = 0)
    print lead + "    Val:   (max) ", np.max(val_metrics, axis = 0)
    print lead + "    Val:   (min) ", np.min(val_metrics, axis = 0)
    
    # returns mean PESQ on validation
    return np.mean(val_metrics, axis = 0)[2]

In [18]:
# training data: an independent copy of the flattened training windows
X_train = np.copy(train_processed)
ntrain = X_train.shape[0]

BATCH_SIZE = 128
NUM_EPOCHS = 200

# Bitrates used to derive the entropy target.
# NOTE(review): units are not stated here (presumably kbps) -- confirm.
ORIG_BITRATE = 256.00
TARGET_BITRATE = 16.00
# original rate divided by the encoder's downsampling factor
PRE_ENTROPY_RATE = ORIG_BITRATE / DOWNSAMPLE_FACTOR

# Target entropy per code symbol: the fraction TARGET_BITRATE / PRE_ENTROPY_RATE
# scaled by 16.0 (presumably bits per original sample -- confirm), then scaled
# by STEP_SIZE / WINDOW_SIZE.
TARGET_ENTROPY = (TARGET_BITRATE / PRE_ENTROPY_RATE * 16.0)
TARGET_ENTROPY *= (STEP_SIZE / float(WINDOW_SIZE))
# half-width of the band around TARGET_ENTROPY inside which tau is left alone
TARGET_ENTROPY_FUZZ = 0.1

# per-epoch step applied to tau, and the floor tau is clamped to when lowered
TAU_CHANGE_RATE = 0.0125
MIN_TAU = 0.0125

# learning rate at the start of the cosine-annealing schedule
STARTING_LR = 0.001

print "Target entropy:", TARGET_ENTROPY

Target entropy: 1.875


In [19]:
# Training state that persists across runs of the training cell:
# best validation PESQ seen so far (used for checkpointing)
best_val_pesq = 0.0
# start the entropy weight at its floor
K.set_value(tau, MIN_TAU)
# fractional count of epochs trained so far; drives the cosine LR schedule
T_i = 0.0

In [20]:
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})
lead = "    "

# Main training loop. Each epoch:
#   1. trains on every mini-batch, setting a cosine-annealed learning rate
#      before each batch;
#   2. estimates the code entropy on a random sample of training windows and
#      moves the entropy weight tau toward TARGET_ENTROPY;
#   3. evaluates on two reference WAV files and on train/val subsets, and saves
#      the model when validation PESQ improves while entropy is on target.
for epoch in range(1, NUM_EPOCHS + 1):
    print "Epoch " + str(epoch) + ":"

    # present batches randomly each epoch
    # (only the order of the batch start offsets is shuffled; each batch is
    # still the same contiguous slice of X_train)
    lis = range(0, ntrain, BATCH_SIZE)
    random.shuffle(lis)
    num_batches = len(lis)
    
    # keep track of start time and current batch #
    i = 0
    startTime = time.time()
    for idx in lis:
        # cosine annealing for model's learning rate
        # T_i advances by 1/num_batches per batch, so train_pct is the fraction
        # of the NUM_EPOCHS schedule completed so far
        train_pct = T_i / float(NUM_EPOCHS)
        opt_lr = 0.5 * STARTING_LR * (1 + math.cos(3.14159 * train_pct))
        T_i += (1.0 / num_batches)
        K.set_value(model.optimizer.lr, opt_lr)
        
        batch = X_train[idx:idx+BATCH_SIZE, :,  :]
        nbatch = batch.shape[0]
               
        # train autoencoder
        # Targets: the input batch for each reconstruction output, and dummy
        # zeros for each code output (code_entropy / code_sparsity ignore their
        # first argument).
        # NOTE(review): the dummy targets use WINDOW_SIZE for the second axis,
        # while the code output is (CHANNEL_SIZE, NBINS) per sample; this
        # presumably works only because the values/shape are never used by the
        # custom losses -- confirm.
        a_y = [batch] * n_recons + \
              [np.zeros((nbatch, WINDOW_SIZE, NBINS))] * n_code       

        a_losses = model.train_on_batch(batch, a_y)
        
        # print statistics every 10 batches so we know what's going on
        if (i % 10 == 0):
            printStr = "        \r" + lead + str(i * BATCH_SIZE) + ": "
            print printStr,
            
            # raw values returned by train_on_batch
            loss_arr = np.asarray(a_losses)
            print loss_arr,
            
            # same values with entries 1..N scaled by their loss weights;
            # entry 0 is left as returned
            if (len(loss_weights) > 1 and len(loss_arr) > 1):
                for w in xrange(0, len(loss_weights)):
                    loss_arr[w + 1] *= loss_weights[w]
                print loss_arr,
            
            print K.get_value(tau), opt_lr,
        
        i += 1
    print ""
    
    # print elapsed time for epoch
    elapsed = time.time() - startTime
    print lead + "Total time for epoch: " + str(elapsed) + "s"
    
    # ---------------------------------------------------------
    # estimate code entropy from random samples (if quantization is on)
    # ---------------------------------------------------------
    # encode NUM randomly chosen training windows (sampled with replacement)
    # and build a normalized histogram of bin usage from the soft codes
    NUM = 500
    rows = np.random.randint(X_train.shape[0], size = NUM)
    code = encoder.predict(X_train[rows, :], verbose = 0)
    probs = np.reshape(code, (code.shape[0] * code.shape[1], NBINS))
    hist = np.sum(probs, axis = 0)
    hist /= np.sum(hist)

    # base-2 entropy of the histogram; bins with probability below 1e-4 are
    # skipped (this loop reuses the name `i`, which is reset at the top of the
    # next epoch)
    entropy = 0
    for i in hist:
        if (i < 1e-4): continue
        entropy += i * math.log(i, 2)
    entropy = -entropy

    print lead + "----------------"
    print lead + "Code entropy:", entropy
    
    # ---------------------------------------------------------
    # handle updating entropy weight (tau)
    # ---------------------------------------------------------
    old_tau = K.get_value(tau)

    if (entropy < TARGET_ENTROPY - TARGET_ENTROPY_FUZZ):
        # entropy below the target band: lower tau, but not below MIN_TAU
        new_tau = old_tau - TAU_CHANGE_RATE
        if (new_tau <= MIN_TAU):
            new_tau = MIN_TAU

        K.set_value(tau, new_tau)
        print lead + "Updated tau from", old_tau, "to", new_tau
    elif (entropy > TARGET_ENTROPY + TARGET_ENTROPY_FUZZ):
        # entropy above the target band: raise tau (no upper bound)
        new_tau = old_tau + TAU_CHANGE_RATE

        K.set_value(tau, new_tau)
        print lead + "Updated tau from", old_tau, "to", new_tau
    else:
        print lead + "Tau stays at", old_tau
    
    # ---------------------------------------------------------
    # evaluate autoencoder on training/validation data every epoch
    # ---------------------------------------------------------
    startTime = time.time()
    print lead + "----------------"
    print lead + "Evaluating autoencoder..."
    
    
    # NOTE(review): for each WAV file the argmax = False and argmax = True runs
    # are given the same output prefix; if test_model_on_wav writes files under
    # that prefix, the second run presumably overwrites the first -- confirm.
    metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                              autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SA1:         ", metrics
    metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                              autoencoder, lead = lead, verbose = False, argmax = True)
    print lead + "SA1 (arg):   ", metrics
    
    # metrics_tst is only printed; it is not read again in the visible cells
    metrics_tst = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                                  autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SX383:       ", metrics_tst
    metrics = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                              autoencoder, lead = lead, verbose = False, argmax = True)
    print lead + "SX383 (arg): ", metrics
    
    # checkpoint only when validation PESQ improves AND the estimated entropy
    # is at or below the target
    val_pesq = evaluate_training(autoencoder, lead)
    if (val_pesq > best_val_pesq and entropy <= TARGET_ENTROPY):
        print lead + "NEW best model! Validation mean-PESQ", val_pesq

        print lead + "Saving model..."
        save_model()
        best_val_pesq = val_pesq
        # NOTE(review): patience_epoch is recorded but never read in the
        # visible cells (no early stopping is implemented here)
        patience_epoch = epoch
    else:
        print lead + "Best validation mean-PESQ seen:", best_val_pesq

    
    elapsed = time.time() - startTime
    print lead + "Total time for evaluation: " + str(elapsed) + "s"

Epoch 1:
    101120:  [1.843466 0.016325 1.103907 0.041010 0.044768] [1.843466 0.489743 1.103907 0.205048 0.044768] 0.0125 0.000999938936265                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             
    Total time for epoch: 88.4746601582s
    ----------------
    Code entropy: 3.46722149312
    Updated tau from 0.0125 to 0.0250000001863
    -------------

        Train: (mean) [8348.027719 43.542152 3.259009]
        Train: (max)  [51647.621094 105.420296 3.834939]
        Train: (min)  [590.548889 15.285256 2.569921]
        Val:   (mean) [8957.331049 44.131442 3.357575]
        Val:   (max)  [66408.281250 100.465195 4.013525]
        Val:   (min)  [299.850403 9.698038 2.349842]
    Best validation mean-PESQ seen: 0.0
    Total time for evaluation: 13.4389770031s
Epoch 6:
    101120:  [1.578600 0.010117 0.900721 0.028269 0.233040] [1.578600 0.303495 0.900721 0.141343 0.233040] 0.075 0.000997784708034                                                                                                                                                                                                                                                                                                                                                                                                                                                            

    SA1 (arg):    [2997.2808, 34.609268, 3.238420009613037]
    SX383:        [3147.9775, 25.630877, 3.2412943840026855]
    SX383 (arg):  [3171.0674, 25.908075, 3.20552659034729]
    Format: [MSE, avg err, PESQ]
        Train: (mean) [6826.529619 40.206363 3.420714]
        Train: (max)  [49288.468750 114.222519 3.980240]
        Train: (min)  [844.431641 15.288953 1.981934]
        Val:   (mean) [7337.869198 41.188717 3.562171]
        Val:   (max)  [62555.281250 91.970467 4.133838]
        Val:   (min)  [281.905945 9.420243 2.414114]
    Best validation mean-PESQ seen: 0.0
    Total time for evaluation: 12.8149600029s
Epoch 11:
    101120:  [1.548590 0.009252 0.822200 0.022248 0.337586] [1.548590 0.277561 0.822200 0.111242 0.337586] 0.1375 0.00099256147673                                                                                                                                                                                                                                       

    101120:  [1.594662 0.009306 0.769592 0.021077 0.440515] [1.594662 0.279170 0.769592 0.105385 0.440515] 0.1875 0.000986194218544                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
    Total time for epoch: 86.0721609592s
    ----------------
    Code entropy: 2.54676670425
    Updated tau from 0.1875 to 0.200000029802
    ----------------
    Evalua

KeyboardInterrupt: 

In [None]:
# Reload the checkpoints written by save_model() with its default prefix and
# rebind `encoder` / `decoder` to the sub-models of the reloaded autoencoder
# (layers[0] is the input layer, so the two nested models are at 1 and 2).
# `if True:` acts as a manual on/off switch for this cell.
if True:
    model = load_model('best_model.h5', KERAS_LOAD_MAP)
    autoencoder = load_model('best_auto.h5', KERAS_LOAD_MAP)
    encoder = autoencoder.layers[1]
    decoder = autoencoder.layers[2]

In [None]:
# layer lists of the encoder and decoder sub-models of the training model,
# for inspection in later cells
enc = model.layers[1].layers
dec = model.layers[2].layers

In [None]:
#for i in xrange(0, len(enc)):
#    print i, enc[i]

In [None]:
# Final evaluation on three WAV files, each run without and with argmax.
# NOTE(review): both runs for a file use the same output prefix; if
# test_model_on_wav writes files under that prefix, the argmax run presumably
# overwrites the non-argmax one -- confirm.
test_model_on_wav("./SA1.wav", "SA1_final", autoencoder)
test_model_on_wav("./SA1.wav", "SA1_final", autoencoder, argmax = True)

test_model_on_wav("./SX383.wav", "SX383_final", autoencoder)
test_model_on_wav("./SX383.wav", "SX383_final", autoencoder, argmax = True)

test_model_on_wav("./fiveYears.wav", "fy_final", autoencoder)
test_model_on_wav("./fiveYears.wav", "fy_final", autoencoder, argmax = True) 

In [None]:
# encode the first 10000 training windows; used below to histogram bin usage
all_embed = encoder.predict(X_train[:10000], batch_size = BATCH_SIZE, verbose = 1)

In [None]:
probs = np.reshape(all_embed, (all_embed.shape[0] * all_embed.shape[1], NBINS))
hist = np.sum(probs, axis = 0)
hist /= np.sum(hist)

sample_hist_bins = np.linspace(0, NBINS - 1, NBINS)
plt.bar(sample_hist_bins, hist, align = 'center', width = 1)
plt.show()

entropy = 0
for i in hist:
    if (i < 1e-4): continue
    entropy += i * math.log(i, 2)
entropy = -entropy
print "Entropy of distribution:", entropy

print "Bins:"
print K.eval(QUANT_BINS)

In [None]:
# plot the values of QUANT_BINS in ascending order
plt.plot(np.sort(np.array(K.eval(QUANT_BINS)).flatten()))
plt.show()

In [None]:
# Read SA1.wav, preprocess it and cut it into windows with the project helpers,
# then encode every window.
[rate, data] = sciwav.read("./SA1.wav")
data = data.astype(np.float32)
processedWave, wparams = preprocess_waveform(data)
windows = extract_windows(processedWave, STEP_SIZE, OVERLAP_SIZE)

# add the trailing channel axis the encoder expects: (n_windows, WINDOW_SIZE, 1)
transformed = np.reshape(windows, (windows.shape[0], WINDOW_SIZE, 1))
embed = encoder.predict(transformed, batch_size = BATCH_SIZE, verbose = 1)

In [None]:
# decode the (soft) codes of SA1.wav back into windows
recons = decoder.predict(embed, batch_size = BATCH_SIZE, verbose = 1)

In [None]:
# Current value of the SOFTMAX_TEMP attribute on the last encoder layer.
# NOTE(review): enc[-1] is presumably the SoftmaxQuantization layer (it is the
# last layer added in autoencoder_structure) -- confirm for reloaded models.
K.eval(enc[-1].SOFTMAX_TEMP)

In [None]:
max_pct = np.max(embed[25], axis = -1)
print max_pct
print np.argmax(embed[25], axis = -1)
print np.sum(max_pct > 0.98) / float(max_pct.size)

In [None]:
embed_max = np.max(embed, axis = -1)
print np.mean(embed_max)
print np.sum(embed_max > 0.98) / float(embed_max.size)

In [None]:
# Diagnostic plots for a single window of SA1.wav: original vs reconstruction,
# the code under hard (argmax) and soft decoding, and the resulting errors.
idx = 25

orig = windows[idx].flatten()
recn = recons[idx].flatten()

print "Original"
plt.plot(orig)
# remember the y-range so the reconstruction is drawn on the same scale
ylim = plt.gca().get_ylim()
plt.show()

print "Reconstruction"
plt.plot(recn)
plt.ylim(ylim)
plt.show()

print "Code (argmax)"
# Harden the soft code: one-hot at the argmax bin of each position, then zero
# out positions whose soft values sum to less than 0.95, and map back to
# values with unquantize_vec (project helper).
argmax_code_vec = embed[idx]
embed_sum = np.sum(embed[idx], axis = -1)
argmax_code_vec = np.eye(NBINS)[np.argmax(argmax_code_vec, axis = -1)]
argmax_code_vec[embed_sum < 0.95] = np.zeros(NBINS)
argmax_code_vec = unquantize_vec(argmax_code_vec)
plt.plot(argmax_code_vec)
plt.show()

print "Code (non-argmax)"
# the same mapping applied to the soft code as-is
na_code_vec = embed[idx]
na_code_vec = unquantize_vec(na_code_vec)
plt.plot(na_code_vec)
plt.show()

print "Difference"
# absolute difference between the hard and soft decoded codes
plt.plot(abs(argmax_code_vec - na_code_vec))
plt.show()
    
print "Error"
# absolute per-sample reconstruction error for this window
plt.plot(abs(orig - recn))
plt.show()