In [1]:
import numpy as np

import random
import h5py
from keras import backend as K
from nn_util import *
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.models import load_model
from keras.losses import *
import multiprocessing
from sklearn.cluster import KMeans, MiniBatchKMeans

import os
import random
import time
import matplotlib
import matplotlib.pyplot as plt
import gc
import psutil
import math

# Print floats through a fixed '{:4f}' format so logged arrays line up.
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})

# for reproducibility: seed Python's RNG and NumPy's RNG with the same value
random.seed(1337)
np.random.seed(1337)

Using TensorFlow backend.


In [2]:
# control amount of GPU memory used: turn on the `allow_growth` GPU option
# and hand Keras a session built from that configuration
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

_gpu_session = tf.Session(config = config)
set_session(_gpu_session)

In [3]:
# external custom code I wrote
from load_data import *
from windowing import *
from pesq import *
from consts import *
from nn_blocks import *
from perceptual_loss import *
from evaluation import *

In [4]:
# load_data returns five groups (paths, waveforms, processed waveforms,
# window parameters, windows), each split into train / val / test.
(
    (train_paths, val_paths, test_paths),
    (train_waveforms, val_waveforms, test_waveforms),
    (train_procwave, val_procwave, test_procwave),
    (train_wparams, val_wparams, test_wparams),
    (train_windows, val_windows, test_windows),
) = load_data(TRAIN_SIZE, VAL_SIZE, TEST_SIZE)

In [5]:
# Gather every window of every entry in train_windows into one array and
# lay it out as one row of WINDOW_SIZE samples per window.
all_windows = [window for entry in train_windows for window in entry]
train_processed = np.array(all_windows)
train_processed = train_processed.reshape((train_processed.shape[0], WINDOW_SIZE))

# optionally shuffle the rows
if RANDOM_SHUFFLE:
    train_processed = np.random.permutation(train_processed)

# sanity statistics: shape, then mean / std / min / max over all samples
print(train_processed.shape)
for summary_stat in (np.mean, np.std, np.min, np.max):
    print(summary_stat(train_processed, axis = None))

(101814, 512)
2.44142e-06
0.0994002
-1.0
1.0


In [6]:
# Length of the code sequence consumed by the decoder (its Input shape is
# (CHANNEL_SIZE, NBINS)). Floor division keeps this an int regardless of the
# division semantics in effect; a float here would be an invalid Input shape.
CHANNEL_SIZE = WINDOW_SIZE // 2

# ---------------------------------------------------------------------------
# autoencoder: encodes an audio window into soft bin assignments and decodes
# those assignments back into an audio window
# ---------------------------------------------------------------------------
def autoencoder_structure():
    """Build the encoder and decoder as two independent Keras models.

    Returns:
        (encoder, decoder) where
          * encoder (named 'encoder') takes a (WINDOW_SIZE,) input and ends
            in a SoftmaxQuantization layer;
          * decoder (named 'decoder') takes a (CHANNEL_SIZE, NBINS) input,
            starts with SoftmaxDequantization and outputs (WINDOW_SIZE,).
    """
    n_channels = 64
    kernel_size = 9
    # third argument passed to each successive residual_block
    # (presumably a dilation rate -- confirm against nn_blocks)
    block_factors = (1, 2, 4, 8)

    def feature_extractor():
        """Return a callable applying four residual blocks in sequence.

        Encoder and decoder each call this twice; every call builds fresh
        blocks, so the structure is shared but the weights are not.
        """
        def apply_blocks(tensor):
            for factor in block_factors:
                tensor = residual_block(n_channels, kernel_size, factor)(tensor)
            return tensor

        return apply_blocks

    # ----- encoder -----
    enc_input = Input(shape = (WINDOW_SIZE,))
    x = Reshape((WINDOW_SIZE, 1))(enc_input)

    x = channel_change_block(n_channels, kernel_size)(x)
    x = feature_extractor()(x)
    x = downsample_block(n_channels, kernel_size)(x)
    x = feature_extractor()(x)
    x = channel_change_block(1, kernel_size)(x)

    # quantization (real numbers => soft bin assignments)
    x = SoftmaxQuantization()(x)

    encoder_model = Model(inputs = enc_input, outputs = x, name = 'encoder')

    # ----- decoder -----
    dec_input = Input(shape = (CHANNEL_SIZE, NBINS))

    # "dequantization" (soft bin assignments => real numbers)
    y = SoftmaxDequantization()(dec_input)

    y = channel_change_block(n_channels, kernel_size)(y)
    y = feature_extractor()(y)
    y = upsample_block(n_channels, kernel_size)(y)
    y = feature_extractor()(y)
    y = channel_change_block(1, kernel_size)(y)

    y = Reshape((WINDOW_SIZE,))(y)
    decoder_model = Model(inputs = dec_input, outputs = y, name = 'decoder')

    return encoder_model, decoder_model

In [7]:
# map for load_model: name -> object for every custom layer, loss and
# constant listed here
KERAS_LOAD_MAP = dict(
    PhaseShiftUp1D = PhaseShiftUp1D,
    code_entropy = code_entropy,
    code_sparsity = code_sparsity,
    rmse = rmse,
    SoftmaxQuantization = SoftmaxQuantization,
    SoftmaxDequantization = SoftmaxDequantization,
    DFT_REAL = DFT_REAL,
    DFT_IMAG = DFT_IMAG,
    MEL_FILTERBANKS = MEL_FILTERBANKS,
    keras_dft_mag = keras_dft_mag,
    keras_dct = keras_dct,
    perceptual_transform = perceptual_transform,
    perceptual_distance = perceptual_distance,
)

In [8]:
# construct autoencoder: window -> encoder -> decoder -> reconstructed window
# (the Input is created before the sub-models, as in the original ordering)
ac_input = Input(shape = (WINDOW_SIZE,))

encoder, decoder = autoencoder_structure()

ac_code = encoder(ac_input)
ac_reconstructed = decoder(ac_code)
autoencoder = Model(inputs = [ac_input], outputs = [ac_reconstructed])

In [9]:
# model parameters: one loss function and one weight per model output.
# The first n_recons entries are applied to the reconstruction output,
# the remaining n_code entries to the code (embedding) output.
loss_functions = [rmse, perceptual_distance, code_sparsity, code_entropy]
loss_weights = [30.0, 5.0, 10.0, 1.0]

n_recons = 2
n_code = 2

assert len(loss_weights) == n_recons + n_code
assert len(loss_functions) == len(loss_weights)

In [10]:
# model specification: the training model repeats the reconstruction output
# n_recons times and the embedding output n_code times, so that each entry
# of loss_functions gets its own output to act on
model_input = Input(shape = (WINDOW_SIZE,))
model_embedding = encoder(model_input)
model_reconstructed = decoder(model_embedding)

model_outputs = [model_reconstructed] * n_recons + [model_embedding] * n_code
model = Model(inputs = [model_input], outputs = model_outputs)

  ' Found: ' + str(self.outputs))


In [11]:
# compile with one weighted loss per output, then show the layer summary
model.compile(optimizer = Adam(),
              loss = loss_functions,
              loss_weights = loss_weights)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 512)               0         
_________________________________________________________________
encoder (Model)              (None, 256, 32)           742448.0  
_________________________________________________________________
decoder (Model)              (None, 512)               816399    
Total params: 1,558,847
Trainable params: 1,558,847
Non-trainable params: 0
_________________________________________________________________


In [12]:
# baseline: metrics of the untrained autoencoder on SA1.wav
# (empty output prefix, reconstruction not saved)
test_model_on_wav("./SA1.wav", "", autoencoder, save_recons = False)

MSE:         154586.0
Avg err:     210.205
PESQ:        1.05026233196


[154586.23, 210.20505, 1.0502623319625854]

In [13]:
# saves current model
def save_model(prefix = 'best'):
    """Save the training model and the autoencoder to HDF5 files.

    Writes './<prefix>_model.h5' (the global `model`) and
    './<prefix>_coder.h5' (the global `autoencoder`), replacing any existing
    files of the same name, then deletes the 'optimizer_weights' entry from
    the saved training-model file.

    Args:
        prefix: filename prefix shared by both output files.
    """
    model_path = './' + prefix + '_model.h5'
    coder_path = './' + prefix + '_coder.h5'

    # remove stale files directly rather than through a shell `rm`, so the
    # prefix is never interpreted by a shell
    for stale_path in (model_path, coder_path):
        if os.path.isfile(stale_path):
            os.remove(stale_path)

    model.save(model_path)
    autoencoder.save(coder_path)

    # drop the optimizer state from the saved training model; the context
    # manager closes the file even if the deletion raises, and the membership
    # check tolerates a file that has no optimizer weights
    with h5py.File(model_path, 'r+') as h5_file:
        if 'optimizer_weights' in h5_file:
            del h5_file['optimizer_weights']

In [14]:
def evaluate_training(autoencoder, lead = ""):
    def set_evaluation(paths, eval_idxs):
        before_after_pairs = [run_model_on_wav(paths[i],
                                               autoencoder,
                                               argmax = True)
                              for i in eval_idxs]
        
        def thread_func(my_id, q):
            my_slice = np.arange(my_id, len(eval_idxs), NUM_THREADS)
            
            for i in my_slice:
                p = before_after_pairs[i]
                q.put(evaluation_metrics(p[0], p[1]))
        
        q = multiprocessing.Queue()
        threads = [multiprocessing.Process(target = thread_func,
                                           args = (i, q))
                   for i in xrange(0, NUM_THREADS)]
        [t.start() for t in threads]
        [t.join() for t in threads]
        results = np.array([q.get() for i in xrange(0, len(eval_idxs))])
        
        return results
    
    train_eval_idxs = random.sample(range(0, len(train_paths)), TRAIN_EVALUATE)
    val_eval_idxs = random.sample(range(0, len(val_paths)), VAL_EVALUATE)
    
    print lead + "Format: [MSE, avg err, PESQ]"
    
    # train set evaluation
    train_metrics = set_evaluation(train_paths,
                                   train_eval_idxs)
    print lead + "    Train: (mean)", np.mean(train_metrics, axis = 0)
    print lead + "    Train: (max) ", np.max(train_metrics, axis = 0)
    print lead + "    Train: (min) ", np.min(train_metrics, axis = 0)
    
    # validation set evaluation
    val_metrics = set_evaluation(val_paths,
                                 val_eval_idxs)
    print lead + "    Val:   (mean)", np.mean(val_metrics, axis = 0)
    print lead + "    Val:   (max) ", np.max(val_metrics, axis = 0)
    print lead + "    Val:   (min) ", np.min(val_metrics, axis = 0)
    
    # returns mean PESQ on validation
    return np.mean(val_metrics, axis = 0)[2]

In [15]:
X_train = np.copy(train_processed)
ntrain = X_train.shape[0]

NUM_EPOCHS = 300
EPOCHS_BEFORE_QUANT_ON = 10

ORIG_BITRATE = 256.00
TARGET_BITRATE = 16.00
PRE_ENTROPY_RATE = ORIG_BITRATE * (float(CHANNEL_SIZE) / WINDOW_SIZE)

TARGET_ENTROPY = (TARGET_BITRATE / PRE_ENTROPY_RATE * 16.0)
TARGET_ENTROPY *= (STEP_SIZE / float(WINDOW_SIZE))
TARGET_ENTROPY_FUZZ = 0.1

TAU_CHANGE_RATE = 0.025
INITIAL_TAU = 0.5

NUM_QUANT_VECS = 5000

STARTING_LR = 0.00025
ENDING_LR = 0.0001

print "Target entropy:", TARGET_ENTROPY

Target entropy: 1.875


In [16]:
# Python-side training state: best validation PESQ seen so far and the
# fractional epoch counter that drives the learning-rate schedule
best_val_pesq = 0.0
T_i = 0.0

# backend-side state: no entropy weight and quantization off at the start
K.set_value(QUANTIZATION_ON, False)
K.set_value(tau, 0.0)

In [22]:
np.set_printoptions(formatter={'float_kind':'{:4f}'.format})
lead = "    "

for epoch in range(200, NUM_EPOCHS + 1):
    print "Epoch " + str(epoch) + ":"

    # present batches randomly each epoch
    lis = range(0, ntrain, BATCH_SIZE)
    random.shuffle(lis)
    num_batches = len(lis)
    
    # keep track of start time and current batch #
    i = 0
    startTime = time.time()
    for idx in lis:
        # cosine annealing for model's learning rate
        train_pct = T_i / float(NUM_EPOCHS)
        opt_lr = ENDING_LR + 0.5 * (STARTING_LR - ENDING_LR) * (1 + math.cos(3.14159 * train_pct))
        T_i += (1.0 / num_batches)
        K.set_value(model.optimizer.lr, opt_lr)
        
        batch = X_train[idx:idx+BATCH_SIZE, :]
        nbatch = batch.shape[0]
               
        # train autoencoder
        a_y = [batch] * n_recons + \
              [np.zeros((nbatch, 1, 1))] * n_code

        a_losses = model.train_on_batch(batch, a_y)
        
        # print statistics every 10 batches so we know what's going on
        if (i % 10 == 0):
            printStr = "        \r" + lead + str(i * BATCH_SIZE) + ": "
            print printStr,
            
            loss_arr = np.asarray(a_losses)
            print loss_arr,
            
            if (len(loss_weights) > 1 and len(loss_arr) > 1):
                for w in xrange(0, len(loss_weights)):
                    loss_arr[w + 1] *= loss_weights[w]
                print loss_arr,
            
            print K.get_value(tau), opt_lr,
        
        i += 1
    print ""
    
    # print elapsed time for epoch
    elapsed = time.time() - startTime
    print lead + "Total time for epoch: " + str(elapsed) + "s"
    
    # ---------------------------------------------------------
    # estimate code entropy from random samples (if quantization is on)
    # ---------------------------------------------------------
    if (K.get_value(QUANTIZATION_ON) > 0):
        NUM = 20000
        rows = np.random.randint(X_train.shape[0], size = NUM)
        to_predict = np.copy(X_train[rows, :])
        code = encoder.predict(to_predict, verbose = 0, batch_size = 128)
        
        all_onehots = np.reshape(code, (-1, NBINS))
        onehot_hist = np.sum(all_onehots, axis = 0)
        onehot_hist /= np.sum(onehot_hist)

        entropy = 0
        for i in onehot_hist:
            if (i < 1e-5): continue
            entropy += i * math.log(i, 2)
        entropy = -entropy

        print lead + "----------------"
        print lead + "Code entropy:", entropy

        # ---------------------------------------------------------
        # handle updating entropy weight (tau)
        # ---------------------------------------------------------
        old_tau = K.get_value(tau)

        if (entropy < TARGET_ENTROPY - TARGET_ENTROPY_FUZZ):
            new_tau = old_tau - TAU_CHANGE_RATE
            if (new_tau < 0.0):
                new_tau = 0.0

            K.set_value(tau, new_tau)
            print lead + "Updated tau from", old_tau, "to", new_tau
        elif (entropy > TARGET_ENTROPY + TARGET_ENTROPY_FUZZ):
            new_tau = old_tau + TAU_CHANGE_RATE

            K.set_value(tau, new_tau)
            print lead + "Updated tau from", old_tau, "to", new_tau
        else:
            print lead + "Tau stays at", old_tau
    
    # ---------------------------------------------------------
    # evaluate autoencoder on training/validation data evey epoch
    # ---------------------------------------------------------
    startTime = time.time()
    print lead + "----------------"
    print lead + "Evaluating autoencoder..."
    
    
    metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                                autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SA1:         ", metrics
    if (K.get_value(QUANTIZATION_ON) > 0):
        metrics = test_model_on_wav("./SA1.wav", "./train_output/SA1_train_epoch" + str(epoch),
                                    autoencoder, lead = lead, verbose = False, argmax = True)
        print lead + "SA1 (arg):   ", metrics
    
    metrics = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                                autoencoder, lead = lead, verbose = False, argmax = False)
    print lead + "SX383:       ", metrics
    if (K.get_value(QUANTIZATION_ON) > 0):
        metrics = test_model_on_wav("./SX383.wav", "./train_output/SX383_train_epoch" + str(epoch),
                                    autoencoder, lead = lead, verbose = False, argmax = True)
        print lead + "SX383 (arg): ", metrics
    
    if (K.get_value(QUANTIZATION_ON) > 0):
        val_pesq = evaluate_training(autoencoder, lead)
        if (val_pesq > best_val_pesq and entropy <= TARGET_ENTROPY):
            print lead + "NEW best model! Validation mean-PESQ", val_pesq

            print lead + "Saving model..."
            save_model()
            best_val_pesq = val_pesq
            patience_epoch = epoch
        else:
            print lead + "Best validation mean-PESQ seen:", best_val_pesq
    
    elapsed = time.time() - startTime
    print lead + "Total time for evaluation: " + str(elapsed) + "s"
    
    gc.collect()
    process = psutil.Process(os.getpid())
    mem_used = process.memory_info().rss
    print lead + "Total memory usage: " + str(mem_used)
    
    # ---------------------------------------------------------
    # turn quantization on after a certain # of epochs
    # ---------------------------------------------------------
    if (epoch == EPOCHS_BEFORE_QUANT_ON):
        print lead + "----------------"
        print lead + "Turning quantization on!"
        
        random_windows = []
        for i in xrange(0, NUM_QUANT_VECS):
            w_idx = random.randint(0, train_processed.shape[0] - 1)
            random_windows.append(train_processed[w_idx])

        random_windows = np.array(random_windows)
        print lead + "    Selecting random code vectors for analysis..."
        encoded_windows = encoder.predict(random_windows, batch_size = 128, verbose = 0)
        encoded_windows = encoded_windows[:, :, 0]
        encoded_windows = np.reshape(encoded_windows, (-1, 1))

        print lead + "    K means clustering for bins initialization..."
        km = MiniBatchKMeans(n_clusters = NBINS).fit(encoded_windows)
        clustered_bins = km.cluster_centers_.flatten()
        
        cluster_score = np.sqrt(np.median(np.min(km.transform(encoded_windows), axis = 1)))
        print lead + "    Done. Cluster score:", cluster_score
        
        K.set_value(QUANTIZATION_ON, True)
        K.set_value(QUANT_BINS, clustered_bins)
        K.set_value(tau, INITIAL_TAU)

 Epoch 200:
    101120:  [3.596668 0.014527 0.349130 0.015673 1.258474] [3.596668 0.435816 1.745652 0.156725 1.258474] 0.65 0.000100001251954                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
    Total time for epoch: 143.117285967s
    ----------------
    Code entropy: 1.86098487643
    Tau stays at 0.65
    ----------------
    Evaluating

        Train: (mean) [8430.802546 45.592243 4.019145]
        Train: (max)  [94693.328125 178.642487 4.442926]
        Train: (min)  [997.685608 17.455538 3.207587]
        Val:   (mean) [8728.262581 48.246197 4.152999]
        Val:   (max)  [50232.695312 108.515091 4.426256]
        Val:   (min)  [402.424957 11.418186 3.734662]
    Best validation mean-PESQ seen: 4.16707407475
    Total time for evaluation: 13.0283670425s
    Total memory usage: 3809247232
Epoch 205:
    101120:  [3.508477 0.012874 0.346714 0.015989 1.228805] [3.508477 0.386218 1.733568 0.159885 1.228805] 0.65 0.00010008135523                                                                                                                                                                                                                                                                                                                                                                                                              

    101120:  [3.459387 0.012263 0.351895 0.015315 1.178873] [3.459387 0.367888 1.759474 0.153151 1.178873] 0.65 0.000100293316758                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
    Total time for epoch: 142.83237505s
    ----------------
    Code entropy: 1.8505464285
    Tau stays at 0.65
    ----------------
    Evaluating autoencoder...
    SA1

        Train: (mean) [7084.563668 44.706720 4.102064]
        Train: (max)  [36891.125000 117.277687 4.443181]
        Train: (min)  [809.097229 17.476883 3.324991]
        Val:   (mean) [8789.030560 48.533211 4.144813]
        Val:   (max)  [49869.933594 108.544952 4.468220]
        Val:   (min)  [414.291870 11.534764 3.257351]
    Best validation mean-PESQ seen: 4.17526584387
    Total time for evaluation: 13.4500999451s
    Total memory usage: 3807772672
Epoch 214:
    101120:  [3.593310 0.014070 0.362025 0.015520 1.205894] [3.593310 0.422095 1.810125 0.155196 1.205894] 0.65 0.000100742507277                                                                                                                                                                                                                                                                                                                                                                                                             

    101120:  [3.563709 0.013026 0.361854 0.015299 1.210686] [3.563709 0.390769 1.809268 0.152985 1.210686] 0.65 0.000101248482252                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
    Total time for epoch: 142.858081102s
    ----------------
    Code entropy: 1.83853412217
    Tau stays at 0.65
    ----------------
    Evaluating autoencoder...
    SA1:

        Train: (mean) [8938.434194 48.531702 4.011656]
        Train: (max)  [44758.710938 106.473244 4.439744]
        Train: (min)  [1055.224976 17.525166 3.065242]
        Val:   (mean) [8830.951830 48.625332 4.130303]
        Val:   (max)  [50206.984375 108.688995 4.411980]
        Val:   (min)  [426.688904 11.657377 3.265307]
    Best validation mean-PESQ seen: 4.17526584387
    Total time for evaluation: 13.2996869087s
    Total memory usage: 3808145408
Epoch 223:
    101120:  [3.452085 0.010976 0.355138 0.015470 1.192415] [3.452085 0.329281 1.775691 0.154699 1.192415] 0.65 0.000102062772974                                                                                                                                                                                                                                                                                                                                                                                                            

    101120:  [3.397980 0.012594 0.340823 0.014552 1.170539] [3.397980 0.377808 1.704114 0.145519 1.170539] 0.65 0.000102858270334                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
    Total time for epoch: 142.140394926s
    ----------------
    Code entropy: 1.85905619574
    Tau stays at 0.65
    ----------------
    Evaluating autoencoder...
    SA

KeyboardInterrupt: 

In [20]:
# manual checkpoint of the current state under the 'haha' prefix
save_model(prefix = 'haha')


