In [1]:
import numpy as np

import random
import h5py
from keras import backend as K
from nn_util import *
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.models import load_model
from keras.losses import *
import multiprocessing
from sklearn.cluster import KMeans, MiniBatchKMeans

import os
import random
import time
import matplotlib
import matplotlib.pyplot as plt
import gc
import psutil
import math

# for reproducibility: seed both NumPy's global RNG and the stdlib `random` module
np.random.seed(1337) 
random.seed(1337)

# format every float printed inside a NumPy array with the '{:4f}' spec
# NOTE(review): '{:4f}' sets a minimum field WIDTH of 4 and keeps the default
# 6 decimals; if 4 decimal places were intended it should be '{:.4f}' -- confirm
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})

Using TensorFlow backend.


In [2]:
# control amount of GPU memory used
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
# allow_growth is the TF 1.x GPUOptions flag that asks TensorFlow to grow its
# GPU allocation on demand instead of reserving memory up front
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
# make a session built from this config the one Keras uses
set_session(tf.Session(config=config))

In [3]:
# external custom code I wrote
from load_data import *
from windowing import *
from pesq import *
from consts import *
from nn_blocks import *
from perceptual_loss import *
from evaluation import *

In [4]:
# load the dataset: load_data returns five [train, val, test] triples, which
# are unpacked here into one variable per (kind, split) pair
# NOTE(review): the meaning of each triple (paths, raw waveforms, processed
# waveforms, window parameters, windows) is inferred from the variable names
# only -- confirm against load_data.py
[train_paths, val_paths, test_paths], \
[train_waveforms, val_waveforms, test_waveforms], \
[train_procwave, val_procwave, test_procwave], \
[train_wparams, val_wparams, test_wparams], \
[train_windows, val_windows, test_windows] = load_data(TRAIN_SIZE, VAL_SIZE, TEST_SIZE)

In [5]:
# flatten all of the train windows into vectors
train_processed = np.array([i for z in train_windows for i in z])
train_processed = np.reshape(train_processed, (train_processed.shape[0], WINDOW_SIZE,))

# randomly shuffle data, if we want to
if (RANDOM_SHUFFLE):
    train_processed = np.random.permutation(train_processed)
    
print train_processed.shape
print np.mean(train_processed, axis=None)
print np.std(train_processed, axis=None)
print np.min(train_processed, axis = None)
print np.max(train_processed, axis = None)

(205063, 512)
1.9761e-06
0.0991018
-1.0
1.0


In [6]:
# Length of the code sequence the decoder takes as input (half the window
# length). Floor division keeps this an integer regardless of the division
# semantics in effect, which matters because it is used as a Keras shape
# dimension.
CHANNEL_SIZE = WINDOW_SIZE // 2

# ---------------------------------------------------------------------------
# autoencoder: builds two Keras models -- an encoder that turns an audio window
#              into soft bin assignments, and a decoder that turns those
#              assignments back into an audio window
# ---------------------------------------------------------------------------
def autoencoder_structure():
    """Build and return the (encoder, decoder) pair of Keras models.

    The encoder takes a (WINDOW_SIZE,) input; the decoder takes a
    (CHANNEL_SIZE, NBINS) input and produces a (WINDOW_SIZE,) output.
    """
    # hyperparameters shared by every convolutional block below
    NCHAN = 64
    FILT_SIZE = 9
    # third argument given to each successive residual block
    # (presumably the dilation rate -- confirm in nn_blocks)
    BLOCK_RATES = (1, 2, 4, 8)

    # Feature extractor used in both halves of the network. Every call to
    # feature_extractor() yields a fresh stack of residual blocks, so the
    # structure is the same everywhere but the weights aren't shared.
    def feature_extractor():
        def apply_stack(tensor):
            for rate in BLOCK_RATES:
                tensor = residual_block(NCHAN, FILT_SIZE, rate)(tensor)
            return tensor

        return apply_stack

    # - - - - - - - - - - - - - - - - - - - - -
    # encoder
    # - - - - - - - - - - - - - - - - - - - - -
    enc_input = Input(shape = (WINDOW_SIZE,))
    enc_out = Reshape((WINDOW_SIZE, 1))(enc_input)

    # widen to NCHAN channels, extract features, downsample, extract features
    # again, then collapse back to a single channel
    enc_out = channel_change_block(NCHAN, FILT_SIZE)(enc_out)
    enc_out = feature_extractor()(enc_out)
    enc_out = downsample_block(NCHAN, FILT_SIZE)(enc_out)
    enc_out = feature_extractor()(enc_out)
    enc_out = channel_change_block(1, FILT_SIZE)(enc_out)

    # quantization (real numbers => soft bin assignments)
    enc_out = SoftmaxQuantization()(enc_out)

    encoder_model = Model(inputs = enc_input, outputs = enc_out, name = 'encoder')

    # - - - - - - - - - - - - - - - - - - - - -
    # decoder
    # - - - - - - - - - - - - - - - - - - - - -
    dec_input = Input(shape = (CHANNEL_SIZE, NBINS))

    # "dequantization" (soft bin assignments => real numbers)
    dec_out = SoftmaxDequantization()(dec_input)

    # mirror image of the encoder, with an upsample in place of the downsample
    dec_out = channel_change_block(NCHAN, FILT_SIZE)(dec_out)
    dec_out = feature_extractor()(dec_out)
    dec_out = upsample_block(NCHAN, FILT_SIZE)(dec_out)
    dec_out = feature_extractor()(dec_out)
    dec_out = channel_change_block(1, FILT_SIZE)(dec_out)

    dec_out = Reshape((WINDOW_SIZE,))(dec_out)
    decoder_model = Model(inputs = dec_input, outputs = dec_out, name = 'decoder')

    # hand back both halves so callers can use them separately or chained
    return encoder_model, decoder_model

In [7]:
# map for load_model: custom layers, loss functions and helper objects, each
# keyed by its own name
# NOTE(review): presumably meant to be passed as
# load_model(path, custom_objects = KERAS_LOAD_MAP) so saved models that
# reference these names can be deserialized -- no such call is visible here
KERAS_LOAD_MAP = {'PhaseShiftUp1D' : PhaseShiftUp1D,
                  'code_entropy' : code_entropy,
                  'code_sparsity' : code_sparsity,
                  'rmse' : rmse,
                  'SoftmaxQuantization' : SoftmaxQuantization,
                  'SoftmaxDequantization' : SoftmaxDequantization,
                  'DFT_REAL' : DFT_REAL,
                  'DFT_IMAG' : DFT_IMAG,
                  'MEL_FILTERBANKS' : MEL_FILTERBANKS,
                  'keras_dft_mag' : keras_dft_mag,
                  'keras_dct' : keras_dct,
                  'perceptual_transform' : perceptual_transform,
                  'perceptual_distance' : perceptual_distance}

In [8]:
# construct autoencoder: chains the encoder and decoder into one
# window -> code -> window model with a single output; this is the model
# handed to test_model_on_wav / evaluate_training and saved as '<prefix>_coder.h5'
ac_input = Input(shape = (WINDOW_SIZE,))

encoder, decoder = autoencoder_structure()
ac_reconstructed = decoder(encoder(ac_input))
autoencoder = Model(inputs = [ac_input], outputs = [ac_reconstructed])

In [9]:
# model parameters
# loss_weights[k] is the weight given to loss_functions[k] at compile time
loss_weights = [60.0, 5.0, 10.0, 1.0]
loss_functions = [rmse, perceptual_distance, code_sparsity, code_entropy]
# the training model repeats its reconstruction output n_recons times and its
# code (embedding) output n_code times, so the first n_recons losses apply to
# the reconstruction and the remaining n_code losses apply to the code
n_recons = 2
n_code = 2
# sanity checks: one weight and one loss function per model output
assert(n_recons + n_code == len(loss_weights))
assert(len(loss_weights) == len(loss_functions))

In [10]:
# model specification
# training model: reuses the same encoder/decoder instances as `autoencoder`,
# but exposes the reconstruction n_recons times and the embedding n_code times
# as outputs, giving each loss function its own output slot
# NOTE(review): the truncated "' Found: ' + str(self.outputs))" text in this
# cell's output is presumably a Keras warning about the repeated output
# tensors -- confirm
model_input = Input(shape = (WINDOW_SIZE,))
model_embedding = encoder(model_input)
model_reconstructed = decoder(model_embedding)

model = Model(inputs = [model_input], outputs = [model_reconstructed] * n_recons + \
                                            [model_embedding] * n_code)

  ' Found: ' + str(self.outputs))


In [11]:
# one loss function and one weight per model output, in output order.
# Adam is created with default arguments; the training loop overwrites
# model.optimizer.lr before every batch.
model.compile(loss = loss_functions,
              loss_weights = loss_weights,
              optimizer = Adam())

model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 512)               0         
_________________________________________________________________
encoder (Model)              (None, 256, 32)           742448.0  
_________________________________________________________________
decoder (Model)              (None, 512)               816399    
Total params: 1,558,847
Trainable params: 1,558,847
Non-trainable params: 0
_________________________________________________________________


In [12]:
# model without any training: baseline metrics for one reference file
# NOTE(review): the second argument is empty and save_recons is False, so
# presumably no reconstruction is written to disk -- confirm in evaluation.py
test_model_on_wav("./SA1.wav", "", autoencoder,
                  save_recons = False)

MSE:         154624.0
Avg err:     210.227
PESQ:        1.0355900526


[154623.94, 210.22746, 1.0355900526046753]

In [13]:
# saves current model
def save_model(prefix = 'best'):
    """Write the training model and the autoencoder to HDF5 files.

    Produces './<prefix>_model.h5' (the multi-output training `model`) and
    './<prefix>_coder.h5' (the single-output `autoencoder`), replacing any
    existing files of the same name. The 'optimizer_weights' entry is then
    removed from the training-model file if it is present.

    prefix -- file name prefix for both output files (default 'best').
    """
    model_path = './' + prefix + '_model.h5'
    coder_path = './' + prefix + '_coder.h5'

    # Remove stale files with os.remove rather than shelling out to `rm`:
    # no dependence on a POSIX shell and no shell interpretation of `prefix`.
    # As with plain `rm`, only regular files are removed.
    for path in (model_path, coder_path):
        if os.path.isfile(path):
            os.remove(path)

    model.save(model_path)
    autoencoder.save(coder_path)

    # Drop the optimizer state from the training-model file. The context
    # manager guarantees the HDF5 handle is closed even if the delete fails,
    # and the membership test replaces a blanket try/except that would have
    # hidden unrelated errors.
    with h5py.File(model_path, 'r+') as f:
        if 'optimizer_weights' in f:
            del f['optimizer_weights']

In [14]:
def evaluate_training(autoencoder, lead = ""):
    def set_evaluation(paths, eval_idxs):
        before_after_pairs = [run_model_on_wav(paths[i],
                                               autoencoder,
                                               argmax = True)
                              for i in eval_idxs]
        
        def thread_func(my_id, q):
            my_slice = np.arange(my_id, len(eval_idxs), NUM_THREADS)
            
            for i in my_slice:
                p = before_after_pairs[i]
                q.put(evaluation_metrics(p[0], p[1]))
        
        q = multiprocessing.Queue()
        threads = [multiprocessing.Process(target = thread_func,
                                           args = (i, q))
                   for i in xrange(0, NUM_THREADS)]
        [t.start() for t in threads]
        [t.join() for t in threads]
        results = np.array([q.get() for i in xrange(0, len(eval_idxs))])
        
        return results
    
    train_eval_idxs = random.sample(range(0, len(train_paths)), TRAIN_EVALUATE)
    val_eval_idxs = random.sample(range(0, len(val_paths)), VAL_EVALUATE)
    
    print lead + "Format: [MSE, avg err, PESQ]"
    
    # train set evaluation
    train_metrics = set_evaluation(train_paths,
                                   train_eval_idxs)
    print lead + "    Train: (mean)", np.mean(train_metrics, axis = 0)
    print lead + "    Train: (max) ", np.max(train_metrics, axis = 0)
    print lead + "    Train: (min) ", np.min(train_metrics, axis = 0)
    
    # validation set evaluation
    val_metrics = set_evaluation(val_paths,
                                 val_eval_idxs)
    print lead + "    Val:   (mean)", np.mean(val_metrics, axis = 0)
    print lead + "    Val:   (max) ", np.max(val_metrics, axis = 0)
    print lead + "    Val:   (min) ", np.min(val_metrics, axis = 0)
    
    # returns mean PESQ on validation
    return np.mean(val_metrics, axis = 0)[2]

In [15]:
# working copy of the training windows, and the number of rows in it
X_train = np.copy(train_processed)
ntrain = X_train.shape[0]

# total number of training epochs, and the epoch at whose end quantization
# is switched on by the training loop
NUM_EPOCHS = 300
EPOCHS_BEFORE_QUANT_ON = 10

# bitrates, in kbps (unit taken from the print at the end of this cell)
ORIG_BITRATE = 256.00
TARGET_BITRATE = 19.85
# ORIG_BITRATE scaled by the ratio of code length to window length
PRE_ENTROPY_RATE = ORIG_BITRATE * (float(CHANNEL_SIZE) / WINDOW_SIZE)
# half-width of the acceptable band around TARGET_BITRATE
TARGET_BITRATE_FUZZ = 0.45

def bitrate_to_entropy(bitrate):
    """Convert a bitrate into the corresponding code entropy.

    The bitrate is expressed as a fraction of PRE_ENTROPY_RATE, scaled by
    16.0, then scaled by STEP_SIZE / WINDOW_SIZE. This is the inverse of
    entropy_to_bitrate.
    """
    step_fraction = STEP_SIZE / float(WINDOW_SIZE)
    return (bitrate / PRE_ENTROPY_RATE * 16.0) * step_fraction

def entropy_to_bitrate(entropy):
    """Convert a code entropy into the corresponding bitrate.

    The entropy is divided by STEP_SIZE / WINDOW_SIZE, then multiplied by
    PRE_ENTROPY_RATE and divided by 16.0. This is the inverse of
    bitrate_to_entropy.
    """
    step_fraction = STEP_SIZE / float(WINDOW_SIZE)
    return (entropy / step_fraction) * PRE_ENTROPY_RATE / 16.0

# acceptable bitrate band around the target, and the same band expressed
# as code entropy
LOWER_BITRATE = TARGET_BITRATE - TARGET_BITRATE_FUZZ
UPPER_BITRATE = TARGET_BITRATE + TARGET_BITRATE_FUZZ
LOWER_ENTROPY = bitrate_to_entropy(LOWER_BITRATE)
UPPER_ENTROPY = bitrate_to_entropy(UPPER_BITRATE)

# amount tau is raised / lowered after an epoch whose measured bitrate falls
# outside the band, and the value tau is set to when quantization is turned on
TAU_CHANGE_RATE = 0.025
INITIAL_TAU = 0.5
# NOTE(review): only referenced from commented-out code in the training loop
TAU_DECAY_EPOCHS = 5

# number of random training windows encoded to initialise the quantization bins
NUM_QUANT_VECS = 5000

# start and end points of the cosine-annealed learning-rate schedule
STARTING_LR = 0.00025
ENDING_LR = 0.0001

print "Target bitrate range:", LOWER_BITRATE, "to", UPPER_BITRATE, "(kbps)"
print "Target entropy range:", LOWER_ENTROPY, "to", UPPER_ENTROPY, "(bits)"

Target bitrate range: 19.4 to 20.3 (kbps)
Target entropy range: 2.2734375 to 2.37890625 (bits)


In [None]:
# training state; run this cell to reset it before starting the training loop
# best validation mean-PESQ seen so far (a model is saved only when it is beaten)
best_val_pesq = 0.0
# NOTE(review): only referenced from commented-out code in the training loop
tau_decay_ctr = 0
# entropy weight (tau) starts at zero; the loop sets it to INITIAL_TAU when
# quantization is turned on
K.set_value(tau, 0.0)
# fractional number of epochs completed; drives the learning-rate schedule
T_i = 0.0
# quantization starts off and is enabled by the training loop
K.set_value(QUANTIZATION_ON, False)

In [None]:
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})
lead = "    "

for epoch in range(1, NUM_EPOCHS + 1):
    print "Epoch " + str(epoch) + ":"

    # present batches randomly each epoch
    lis = range(0, ntrain, BATCH_SIZE)
    random.shuffle(lis)
    num_batches = len(lis)
    
    # keep track of start time and current batch #
    i = 0
    startTime = time.time()
    for idx in lis:
        # cosine annealing for model's learning rate
        train_pct = T_i / float(NUM_EPOCHS)
        opt_lr = ENDING_LR + 0.5 * (STARTING_LR - ENDING_LR) * (1 + math.cos(3.14159 * train_pct))
        T_i += (1.0 / num_batches)
        K.set_value(model.optimizer.lr, opt_lr)
        
        batch = X_train[idx:idx+BATCH_SIZE, :]
        nbatch = batch.shape[0]
               
        # train autoencoder
        a_y = [batch] * n_recons + \
              [np.zeros((nbatch, 1, 1))] * n_code

        a_losses = model.train_on_batch(batch, a_y)
        
        # print statistics every 10 batches so we know what's going on
        if (i % 10 == 0):
            printStr = "        \r" + lead + str(i * BATCH_SIZE) + ": "
            print printStr,
            
            loss_arr = np.asarray(a_losses)
            print loss_arr,
            
            if (len(loss_weights) > 1 and len(loss_arr) > 1):
                for w in xrange(0, len(loss_weights)):
                    loss_arr[w + 1] *= loss_weights[w]
                print loss_arr,
            
            print K.get_value(tau), opt_lr,
        
        i += 1
    print ""
    
    # print elapsed time for epoch
    elapsed = time.time() - startTime
    print lead + "Total time for epoch: " + str(elapsed) + "s"
    
    # ---------------------------------------------------------
    # estimate network bitrate and code entropy from random samples
    #     (only if quantization is on)
    # ---------------------------------------------------------
    if (K.get_value(QUANTIZATION_ON) > 0):
        NUM = 20000
        rows = np.random.randint(X_train.shape[0], size = NUM)
        to_predict = np.copy(X_train[rows, :])
        code = encoder.predict(to_predict, verbose = 0, batch_size = 128)
        
        all_onehots = np.reshape(code, (-1, NBINS))
        onehot_hist = np.sum(all_onehots, axis = 0)
        onehot_hist /= np.sum(onehot_hist)

        entropy = 0
        for i in onehot_hist:
            if (i < 1e-5): continue
            entropy += i * math.log(i, 2)
        entropy = -entropy
        
        bitrate = entropy_to_bitrate(entropy)

        print lead + "----------------"
        print lead + "Code entropy:", entropy
        print lead + "     bitrate:", bitrate

        # ---------------------------------------------------------
        # handle updating entropy weight (tau)
        # ---------------------------------------------------------
        updated_tau = False
        old_tau = K.get_value(tau)
        
        if (bitrate > UPPER_BITRATE):
            new_tau = old_tau + TAU_CHANGE_RATE
            updated_tau = True
        elif (bitrate < LOWER_BITRATE):
            new_tau = old_tau - TAU_CHANGE_RATE
            if (new_tau < 0): new_tau = 0
            updated_tau = True
        
        #tau_decay_ctr += 1
        #if (tau_decay_ctr == TAU_DECAY_EPOCHS):
        #    tau_decay_ctr = 0
        #    new_tau = old_tau - TAU_CHANGE_RATE
        #    updated_tau = True
        
        if (updated_tau):
            K.set_value(tau, new_tau)
            print lead + "Updated tau from", old_tau, "to", new_tau
        else:
            print lead + "Tau stays at", old_tau
    
    # ---------------------------------------------------------
    # evaluate autoencoder on training/validation data evey epoch
    # ---------------------------------------------------------
    startTime = time.time()
    print lead + "----------------"
    print lead + "Evaluating autoencoder..."
    
    
    metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                                autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SA1:         ", metrics
    if (K.get_value(QUANTIZATION_ON) > 0):
        metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                                    autoencoder, lead = lead, verbose = False, argmax = True)
        print lead + "SA1 (arg):   ", metrics
    
    metrics = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                                autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SX383:       ", metrics
    if (K.get_value(QUANTIZATION_ON) > 0):
        metrics = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                                    autoencoder, lead = lead, verbose = False, argmax = True)
        print lead + "SX383 (arg): ", metrics
    
    if (K.get_value(QUANTIZATION_ON) > 0):
        val_pesq = evaluate_training(autoencoder, lead)
        if (val_pesq > best_val_pesq and bitrate <= UPPER_BITRATE):
            print lead + "NEW best model! Validation mean-PESQ", val_pesq

            print lead + "Saving model..."
            save_model()
            best_val_pesq = val_pesq
            patience_epoch = epoch
        else:
            print lead + "Best validation mean-PESQ seen:", best_val_pesq
    
    elapsed = time.time() - startTime
    print lead + "Total time for evaluation: " + str(elapsed) + "s"
    
    gc.collect()
    process = psutil.Process(os.getpid())
    mem_used = process.memory_info().rss
    print lead + "Total memory usage: " + str(mem_used)
    
    # ---------------------------------------------------------
    # turn quantization on after a certain # of epochs
    # ---------------------------------------------------------
    if (epoch == EPOCHS_BEFORE_QUANT_ON):
        print lead + "----------------"
        print lead + "Turning quantization on!"
        
        random_windows = []
        for i in xrange(0, NUM_QUANT_VECS):
            w_idx = random.randint(0, train_processed.shape[0] - 1)
            random_windows.append(train_processed[w_idx])

        random_windows = np.array(random_windows)
        print lead + "    Selecting random code vectors for analysis..."
        encoded_windows = encoder.predict(random_windows, batch_size = 128, verbose = 0)
        encoded_windows = encoded_windows[:, :, 0]
        encoded_windows = np.reshape(encoded_windows, (-1, 1))

        print lead + "    K means clustering for bins initialization..."
        km = MiniBatchKMeans(n_clusters = NBINS).fit(encoded_windows)
        clustered_bins = km.cluster_centers_.flatten()
        
        cluster_score = np.sqrt(np.median(np.min(km.transform(encoded_windows), axis = 1)))
        print lead + "    Done. Cluster score:", cluster_score
        
        K.set_value(QUANTIZATION_ON, True)
        K.set_value(QUANT_BINS, clustered_bins)
        K.set_value(tau, INITIAL_TAU)

Epoch 1:
    204800:  [2.833011 0.009559 0.451896 0.000000 0.000000] [2.833011 0.573530 2.259481 0.000000 0.000000] 0.0 0.000249995903087                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               

    204800:  [1.921737 0.006930 0.301186 0.000000 0.000000] [1.921737 0.415807 1.505930 0.000000 0.000000] 0.0 0.000249897292192                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

    204800:  [1.835961 0.007995 0.271256 0.000000 0.000000] [1.835961 0.479680 1.356281 0.000000 0.000000] 0.0 0.000249667286219                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

    SA1 (arg):    [3597.7932, 40.604294, 2.6593685150146484]
    SX383:        [2672.3713, 27.074221, 3.437892436981201]
    SX383 (arg):  [2934.0796, 29.694372, 2.8886590003967285]
    Format: [MSE, avg err, PESQ]
        Train: (mean) [8382.027496 49.441410 3.139492]
        Train: (max)  [40426.378906 114.234764 3.708477]
        Train: (min)  [855.461731 18.144606 2.202637]
        Val:   (mean) [7475.431203 46.568637 3.180356]
        Val:   (max)  [52338.433594 107.018097 3.940063]
        Val:   (min)  [352.602600 11.397947 2.246979]
    NEW best model! Validation mean-PESQ 3.18035595179
    Saving model...
    Total time for evaluation: 18.109208107s
    Total memory usage: 4848881664
Epoch 13:
    204800:  [4.632167 0.012859 0.438544 0.069064 0.977297] [4.632167 0.771511 2.192722 0.690636 0.977297] 0.45 0.000249306288675                                                                                                                                                               

    204800:  [4.114027 0.010744 0.409489 0.058116 0.840780] [4.114027 0.644638 2.047443 0.581165 0.840780] 0.375 0.000248949949663                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      

    204800:  [4.064592 0.009420 0.402346 0.049820 0.989462] [4.064592 0.565199 2.011728 0.498202 0.989462] 0.4 0.000248520631103                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

    204800:  [4.151114 0.010807 0.392621 0.040564 1.133967] [4.151114 0.648397 1.963106 0.405643 1.133967] 0.475 0.000248018756679                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      

    204800:  [4.384796 0.011555 0.396255 0.036948 1.340710] [4.384796 0.693329 1.981273 0.369484 1.340710] 0.55 0.000247444821679                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

    204800:  [4.356576 0.010564 0.423479 0.032837 1.276977] [4.356576 0.633835 2.117395 0.328369 1.276977] 0.575 0.000246799392507                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      

    204800:  [4.328568 0.011659 0.403100 0.032318 1.290324] [4.328568 0.699562 2.015500 0.323182 1.290324] 0.55 0.000246083106123                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

    204800:  [4.282331 0.010983 0.402643 0.032490 1.285211] [4.282331 0.659004 2.013213 0.324903 1.285211] 0.525 0.000245296669414                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      

    204800:  [4.337799 0.013003 0.374838 0.031573 1.367695] [4.337799 0.780181 1.874189 0.315733 1.367695] 0.55 0.000244440858496                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

    204800:  [4.240456 0.011743 0.386132 0.028715 1.318049] [4.240456 0.704599 1.930662 0.287145 1.318049] 0.55 0.000243516517958                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

    204800:  [4.214359 0.010560 0.387397 0.029383 1.349939] [4.214359 0.633603 1.936983 0.293833 1.349939] 0.55 0.000242524559988                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       

    204800:  [4.383275 0.013144 0.403354 0.026799 1.309906] [4.383275 0.788612 2.016769 0.267988 1.309906] 0.575 0.000241465963549                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      