In [85]:
import sys
import os
import tensorflow as tf
import numpy as np
import librosa
import math
sys.path.append('../')
from wavenet.model_ibab import WaveNetModel
from wavenet.ops import mu_law_encode, mu_law_decode
from IPython.display import Audio
from time import time

In [23]:
def generator_conv(model, input_batch, state_batch, weights):
    '''Perform convolution for a single convolutional processing step.

    The causal convolution is unrolled into one matmul per tap. The first
    ``model.filter_width - 1`` taps come from ``state_batch`` (tap k is paired
    with ``weights[k]``). The final tap is the current ``input_batch``, paired
    with ``weights[model.filter_width - 1]``.

    Args:
        model: object exposing ``filter_width``; expected to be >= 2 because
            ``state_batch[0]`` is always read.
        input_batch: the current step's input, passed straight to
            ``tf.matmul`` (assumed 2-D, e.g. (batch, channels) — TODO confirm).
        state_batch: the ``filter_width - 1`` stored past inputs, oldest first.
        weights: per-tap weight matrices, indexable by tap.

    Returns:
        The sum of all per-tap matmul results.
    '''
    n_stored = model.filter_width - 1
    total = tf.matmul(state_batch[0], weights[0])
    for tap in range(1, n_stored):
        total += tf.matmul(state_batch[tap], weights[tap])
    # Newest tap is the live input. max(..., 1) reproduces the original index
    # arithmetic, which equals filter_width - 1 for every filter_width >= 2.
    newest = max(n_stored, 1)
    total += tf.matmul(input_batch, weights[newest])
    return total

def generator_causal_layer(model, input_batch, state_batch):
    '''Apply the initial causal convolution for one generation step.

    Uses the filter stored at ``model.variables['causal_layer']['filter']``
    and delegates the per-tap arithmetic to ``generator_conv``. The ops are
    created inside a ``causal_layer`` name scope.

    Returns:
        The causal-layer output for the current step.
    '''
    causal_filter = model.variables['causal_layer']['filter']
    with tf.name_scope('causal_layer'):
        return generator_conv(model, input_batch, state_batch, causal_filter)

def generator_dilation_layer(model, input_batch, state_batch, layer_index,
                              dilation, local_condition_batch, global_condition_batch):
    '''Run one step of a gated, residual dilated-convolution layer.

    Computes ``tanh(filter) * sigmoid(gate)``. Both pre-activations come from
    ``generator_conv`` plus optional local/global conditioning terms and
    optional biases. The gated activation is then projected twice with 1x1
    weights, once for the skip path and once for the residual path.

    Args:
        model: model exposing ``variables['dilated_stack']`` and ``use_biases``.
        input_batch: input to this layer for the current step.
        state_batch: stored past inputs for this layer's active queue slot.
        layer_index: index into ``model.variables['dilated_stack']``.
        dilation: accepted for interface compatibility; not used here.
        local_condition_batch: optional local conditioning (``None`` to skip).
        global_condition_batch: optional global conditioning (``None`` to
            skip). It is reshaped to a single row ``(1, -1)``.
            NOTE(review): this appears to assume batch_size == 1 — confirm.

    Returns:
        ``(skip_contribution, input_batch + dense_projection)``.
    '''
    layer_vars = model.variables['dilated_stack'][layer_index]

    def _add_bias(tensor, key):
        # Bias variables are only used when the model is built with biases.
        return tensor + layer_vars[key] if model.use_biases else tensor

    pre_filter = generator_conv(model, input_batch, state_batch, layer_vars['filter'])
    pre_gate = generator_conv(model, input_batch, state_batch, layer_vars['gate'])

    if local_condition_batch is not None:
        # [0, :, :] takes the single spatial slice of the stored kernel.
        pre_filter += tf.matmul(local_condition_batch, layer_vars['cond_filter'][0, :, :])
        pre_gate += tf.matmul(local_condition_batch, layer_vars['cond_gate'][0, :, :])

    if global_condition_batch is not None:
        gc_row = tf.reshape(global_condition_batch, shape=(1, -1))
        pre_filter += tf.matmul(gc_row, layer_vars['gc_filtweights'][0, :, :])
        pre_gate += tf.matmul(gc_row, layer_vars['gc_gateweights'][0, :, :])

    pre_filter = _add_bias(pre_filter, 'filter_bias')
    pre_gate = _add_bias(pre_gate, 'gate_bias')

    gated = tf.tanh(pre_filter) * tf.sigmoid(pre_gate)

    dense_out = _add_bias(tf.matmul(gated, layer_vars['dense'][0, :, :]), 'dense_bias')
    skip_out = _add_bias(tf.matmul(gated, layer_vars['skip'][0, :, :]), 'skip_bias')

    return skip_out, input_batch + dense_out


def create_queue(model, dilation, n_channel, name=None):
    '''Create a zero-filled, non-trainable float32 queue variable.

    Shape is ``(dilation * (model.filter_width - 1), model.batch_size,
    n_channel)``. Each of the ``dilation`` slots holds the
    ``filter_width - 1`` past inputs that a strided step reads.

    Args:
        model: model exposing ``filter_width`` and ``batch_size``.
        dilation: number of interleaved slots in the queue.
        n_channel: channel count of each stored input.
        name: optional variable name.

    Returns:
        The new ``tf.Variable``.
    '''
    rows = dilation * (model.filter_width - 1)
    zeros = tf.zeros((rows, model.batch_size, n_channel), dtype=tf.float32)
    return tf.Variable(zeros, trainable=False, name=name)

def create_q_ops(model):
    '''Create every queue variable used by step-by-step generation.

    Entry 0 (``"Q_L0"``) serves the causal layer. It has dilation 1 and
    ``model.quantization_channels`` channels. Entry ``k + 1``
    (``"Q_L{k+1}"``) serves dilated layer ``k``, is created under the
    ``dilated_stack/layer{k}`` name scope, and has
    ``model.residual_channels`` channels.

    Returns:
        list of queue variables ordered [causal, dilated_0, dilated_1, ...].
    '''
    queues = [create_queue(model, 1, model.quantization_channels, "Q_L0")]
    with tf.name_scope('dilated_stack'):
        for position, dil in enumerate(model.dilations, start=1):
            with tf.name_scope('layer{}'.format(position - 1)):
                queues.append(create_queue(model, dil, model.residual_channels,
                                           "Q_L{}".format(position)))
    return queues

def update(q, current_q_idx, dil, filter_width, x):
    '''Shift one queue slot by one row and write ``x`` into its last row.

    The slot starts at row ``current_q_idx`` and spans ``filter_width - 1``
    rows. Row ``k`` receives row ``k + 1``, and the freed last row receives
    ``x``. ``dil`` is accepted for interface compatibility but not used.

    Returns:
        The result of the final ``tf.scatter_update``. Each scatter consumes
        the previous one's result, so the updates form a single chain.
    '''
    last = filter_width - 2
    # Dequeue: move every row except the last one step towards the front.
    for offset in range(last):
        q = tf.scatter_update(q, current_q_idx + offset,
                              q[current_q_idx + offset + 1])
    # Enqueue: the newest value goes into the slot's last row.
    return tf.scatter_update(q, current_q_idx + last, x)

def create_update_q_ops(model, qs, initial, others, gen_num):
    '''Build the ops that push one generation step's layer inputs into the queues.

    Args:
        model: model exposing ``batch_size``, ``filter_width``,
            ``quantization_channels``, ``residual_channels`` and ``dilations``.
        qs: queues as returned by ``create_q_ops`` (causal queue first).
            The list is not modified; a new list is returned.
        initial: input of the causal layer for this step; reshaped to
            ``(batch_size, quantization_channels)``.
        others: indexable of per-dilated-layer inputs; ``others[k]`` is
            reshaped to ``(batch_size, residual_channels)``.
        gen_num: current generation step; dilated queue ``k`` is updated in
            slot ``gen_num % model.dilations[k]``.

    Returns:
        list of update tensors, one per queue, in the same order as ``qs``.
    '''
    taps = model.filter_width - 1
    # Work on a copy so the caller's `qs` keeps referring to the queue
    # variables (e.g. for re-initialisation or another `predict` call)
    # instead of being overwritten in place with update ops.
    updated = list(qs)

    # Causal-layer queue has dilation 1, so its only slot starts at row 0.
    updated[0] = update(
        updated[0], 0, 1, model.filter_width,
        tf.reshape(initial, [model.batch_size, model.quantization_channels],
                   name="up_reshape0"))

    for layer_index, dilation in enumerate(model.dilations):
        # Same slot that `predict` reads for this step.
        slot_start = (gen_num % dilation) * taps
        layer_input = tf.reshape(others[layer_index],
                                 [model.batch_size, model.residual_channels],
                                 name="up_reshape{}".format(layer_index + 1))
        updated[layer_index + 1] = update(
            updated[layer_index + 1], slot_start, dilation, model.filter_width,
            layer_input)

    return updated
    
def predict(model, qs, x, c, g, gen_num):
    '''Build the graph for one step of fast (queue-based) generation.

    Runs the causal layer, every dilated layer and the post-processing stack
    on a single time step. Past layer inputs are read from the queues in
    ``qs`` instead of being recomputed.

    Args:
        model: model exposing ``filter_width``, ``dilations``, ``use_biases``
            and the ``variables`` dict.
        qs: queues from ``create_q_ops`` (causal queue first). For each queue,
            only ``filter_width - 1`` rows of the active slot are read.
        x: current-step input to the causal layer.
        c: optional local conditioning (``None`` to disable).
        g: optional global conditioning (``None`` to disable).
        gen_num: generation step; dilated queue ``k`` is read at slot
            ``gen_num % model.dilations[k]``.

    Returns:
        ``(output, layer_inputs)``. ``output`` is the post-processed result
        with no softmax applied. ``layer_inputs`` lists the input to the
        causal layer followed by the input to each dilated layer, in queue
        order.

    NOTE(review): the causal layer is read with ``model.filter_width`` taps;
    confirm this matches the model when ``scalar_input`` /
    ``initial_filter_width`` are in use.
    '''
    taps = model.filter_width - 1  # rows of queue state read per layer
    layer_inputs = [x]
    current_layer = generator_causal_layer(model, x, qs[0][0:taps])

    skip_outputs = []
    with tf.name_scope('dilated_stack'):
        for layer_index, dilation in enumerate(model.dilations):
            with tf.name_scope('layer{}'.format(layer_index)):
                # A dilated queue holds `dilation` slots of `taps` rows; this
                # step uses slot gen_num % dilation.
                slot_start = (gen_num % dilation) * taps
                layer_inputs.append(current_layer)
                slot_state = qs[layer_index + 1][slot_start:slot_start + taps]
                skip, current_layer = generator_dilation_layer(
                    model, current_layer, slot_state, layer_index, dilation, c, g)
                skip_outputs.append(skip)

    with tf.name_scope('postprocessing'):
        post = model.variables['postprocessing']
        # Sum the skip connections, then ReLU -> 1x1 conv -> ReLU -> 1x1 conv.
        hidden = tf.matmul(tf.nn.relu(sum(skip_outputs)), post['postprocess1'][0, :, :])
        if model.use_biases:
            hidden = hidden + post['postprocess1_bias']
        output = tf.matmul(tf.nn.relu(hidden), post['postprocess2'][0, :, :])
        if model.use_biases:
            output = output + post['postprocess2_bias']

    return output, layer_inputs

In [114]:
# Build two computation paths on the same WaveNetModel `net`:
#   * `proba`: softmax of the full network applied to a whole waveform;
#   * `next_sample`: softmax of a single generation step that reads past layer
#     inputs from per-layer queue variables, plus `update_q_ops` that push the
#     step's layer inputs into those queues.
tf.reset_default_graph()
batch_size = 1
filter_width = 3
n_stack = 2
max_dilation = 8  # layers per stack; dilations run 1, 2, ..., 2 ** (max_dilation - 1)
dilations = [2 ** i for _ in range(n_stack) for i in range(max_dilation)]
print("dilations:", dilations)
residual_channels, dilation_channels, skip_channels = 256, 256, 256
use_biases = True
quantization_channels = 256
gc_cardinality = None
gc_channels = None
scalar_input = False
initial_filter_width = 32  # the value handed to the constructor below

net = WaveNetModel(batch_size=batch_size,
                   dilations=dilations,
                   filter_width=filter_width,
                   scalar_input=scalar_input,
                   initial_filter_width=initial_filter_width,
                   residual_channels=residual_channels,
                   dilation_channels=dilation_channels,
                   quantization_channels=quantization_channels,
                   skip_channels=skip_channels,
                   global_condition_channels=gc_channels,
                   global_condition_cardinality=gc_cardinality,
                   use_biases=use_biases)

# Full-sequence path.
waveform = tf.placeholder(tf.float32)
ml_encoded = mu_law_encode(waveform, quantization_channels)
encoded = net._one_hot(ml_encoded)
encoded = tf.reshape(encoded, [batch_size, -1, quantization_channels])

raw = net._create_network(encoded, None)
raw = tf.reshape(raw, [batch_size, -1, quantization_channels])
# Softmax computed in float64 (presumably for precision), cast back to float32.
proba = tf.cast(tf.nn.softmax(tf.cast(raw, tf.float64)), tf.float32)

# Step-by-step generation path.
sample_placeholder = tf.placeholder(tf.int32)
encoded_sample = net._one_hot(sample_placeholder)
encoded_sample = tf.reshape(encoded_sample, [-1, quantization_channels])
gen_num = tf.placeholder(tf.int32)
qs = create_q_ops(net)
# Keep handles on the queue variables themselves for re-initialisation.
# Copy them before create_update_q_ops runs, because that call may replace
# the entries of `qs` with update ops. This also avoids picking variables
# out of the global collection by name prefix.
var_q = list(qs)

next_sample, layers_out = predict(net, qs, encoded_sample, None, None, gen_num)
next_sample = tf.reshape(next_sample, [-1, quantization_channels])
next_sample = tf.cast(tf.nn.softmax(tf.cast(next_sample, tf.float64)), tf.float32)

initial = tf.placeholder(tf.float32)
others = tf.placeholder(tf.float32)
update_q_ops = create_update_q_ops(net, qs, initial, others, gen_num)
print("created.")

dilations: [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128]
created.


In [106]:
# Consistency check on random input: step-by-step generation with the queues
# must reproduce the argmax of the full-network predictions exactly.
rng = np.random.RandomState(0)  # fixed seed so the check is reproducible
data = np.vstack([np.concatenate([np.zeros(net.receptive_field), rng.randn(10)])
                  for _ in range(batch_size)])

sess_config = tf.ConfigProto(
    device_count={'GPU': 0}
)
with tf.Session(config=sess_config) as sess:
    sess.run(tf.global_variables_initializer())

    result, _encoded = sess.run([proba, ml_encoded], feed_dict={waveform: data})
    _encoded = _encoded.reshape(batch_size, -1)
    result = np.argmax(result, axis=-1)

    sess.run(tf.variables_initializer(var_q))
    # Run all queue updates as one op. Fetching the `update_q_ops` list
    # directly would also copy every queue's contents back to numpy each step.
    push_queues = tf.group(*update_q_ops)

    t = time()
    # Steps before receptive_field - 1 only prime the queues; from there on
    # each step's output lines up with a column of `result`.
    first_output = net.receptive_field - 1
    samples = []
    for j in range(_encoded.shape[-1]):
        prob, _layers = sess.run([next_sample, layers_out],
                                 feed_dict={sample_placeholder: _encoded[:, j], gen_num: j})
        sess.run(push_queues,
                 feed_dict={initial: _layers[0], others: np.array(_layers[1:]), gen_num: j})
        if j >= first_output:
            samples.append(np.argmax(prob, axis=-1))
    samples = np.array(samples).T
    print("elapsed:", time() - t)

# Guard against silent broadcasting in the subtraction below.
assert samples.shape == result.shape, (samples.shape, result.shape)
print("result:", result)
print("generated samples:", samples)
print("difference between result and samples:", np.abs(result - samples).sum())

elapsed: 9.298052072525024
result: [[225 225 246  57  34 246 225 143 225 225 246]
 [225 225 225 225 225 225 143 246 202 246 246]]
generated samples: [[225 225 246  57  34 246 225 143 225 225 246]
 [225 225 225 225 225 225 143 246 202 246 246]]
difference between result and samples: 0


In [115]:
# Same consistency check on real speech: step-by-step generation over a whole
# clip must reproduce the full-network argmax predictions.
sample_rate = 16000
src, _ = librosa.load("voice.wav", sr=sample_rate)
n_samples = len(src)
# Left-pad the 1-D clip with receptive_field zeros, then repeat it once per
# batch row so `data` is (batch_size, time), like the random-input check.
src = np.pad(src, (net.receptive_field, 0), 'constant')
data = np.tile(src, (batch_size, 1))

sess_config = tf.ConfigProto(
    device_count={'GPU': 0}
)
with tf.Session(config=sess_config) as sess:
    sess.run(tf.global_variables_initializer())
    result, _encoded = sess.run([proba, ml_encoded], feed_dict={waveform: data})
    _encoded = _encoded.reshape(batch_size, -1)
    result = np.argmax(result, axis=-1)

    sess.run(tf.variables_initializer(var_q))
    # Run all queue updates as one op. Fetching the `update_q_ops` list
    # directly would also copy every queue's contents back to numpy each step.
    push_queues = tf.group(*update_q_ops)

    t = time()
    # Steps before receptive_field - 1 only prime the queues; from there on
    # each step's output lines up with a column of `result`.
    first_output = net.receptive_field - 1
    samples = []
    for j in range(_encoded.shape[-1]):
        prob, _layers = sess.run([next_sample, layers_out],
                                 feed_dict={sample_placeholder: _encoded[:, j], gen_num: j})
        sess.run(push_queues,
                 feed_dict={initial: _layers[0], others: np.array(_layers[1:]), gen_num: j})
        if j >= first_output:
            samples.append(np.argmax(prob, axis=-1))
    samples = np.array(samples).T
    print("elapsed:", time() - t)

# Guard against silent broadcasting in the subtraction below.
assert samples.shape == result.shape, (samples.shape, result.shape)
print("difference between result and samples:", np.abs(result - samples).sum())

elapsed: 109.17055201530457
difference between result and samples: 0
