In [None]:
# %load generate.py

import argparse
from time import time

from datetime import datetime
import json
import os
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

from wavenet import WaveNet

# Defaults carried over from generate.py (loaded into this cell).
# Only WAVENET_PARAMS is read below; the `args` cell sets its own
# samples / logdir / window values instead of these.
SAMPLES = 16000                            # presumably default number of samples to generate
LOGDIR = './logdir'                        # presumably default log directory
WINDOW = 80000                             # presumably default cap on samples fed per step
WAVENET_PARAMS = './wavenet_params.json'   # JSON file holding the model hyperparameters

%matplotlib inline

In [None]:
class Args(object):
    """Bare attribute container; presumably stands in for the argparse
    namespace that generate.py builds from its command line."""


args = Args()

# Model / checkpoint inputs.
args.wavenet_params = WAVENET_PARAMS   # JSON hyperparameter file read when building the net
args.checkpoint = './model.ckpt-0'     # checkpoint restored before generation
args.fast_generation = True            # passed to WaveNet(fast_generation=...) and picks the loop branch

# Generation settings.
args.samples = 1000                    # total shown in the per-sample progress message
args.window = 512                      # max number of recent samples fed when not fast-generating

# Output locations.
args.logdir = ''                       # joined into `logdir` in the next cell
args.wav_out_path = None               # not read by any cell in this notebook

In [None]:
# Build the WaveNet inference graph in a fresh session and restore weights.
# NOTE(review): `logdir` is computed but not used by any cell in this notebook.
logdir = os.path.join(args.logdir, 'train', str(datetime.now()))
# Model hyperparameters come from the JSON file named by args.wavenet_params.
with open(args.wavenet_params, 'r') as config_file:
    wavenet_params = json.load(config_file)

sess = tf.Session()

net = WaveNet(
    1,  # presumably the batch size — TODO confirm against the WaveNet constructor
    wavenet_params['quantization_steps'],
    wavenet_params['dilations'],
    wavenet_params['filter_width'],
    wavenet_params['residual_channels'],
    wavenet_params['dilation_channels'],
    fast_generation=args.fast_generation)

# Unshaped int32 placeholder: later cells feed either a single sample value
# (fast generation) or a list of recent sample values.
samples = tf.placeholder(tf.int32)


# predict_proba returns the next-sample prediction together with `push_ops`,
# which later cells run to advance the fast-generation state queues.
# (Older single-return form: next_sample = net.predict_proba(samples))
next_sample, push_ops = net.predict_proba(samples)


# Initialise every variable first, so that anything not restored below
# (names without 'Variable', presumably the queue state) still has a value.
sess.run(tf.initialize_all_variables())

# Restore only variables whose name contains 'Variable'. The last two
# characters (the ':0' output suffix) are stripped so keys match the
# checkpoint's op names.
variables_to_restore = {var.name[:-2]: var for var in tf.all_variables() if 'Variable' in var.name}
saver = tf.train.Saver(variables_to_restore)
print('Restoring model from {}'.format(args.checkpoint))
saver.restore(sess, args.checkpoint)

In [None]:
# Step-by-step debug trace. Generate a few samples greedily and print:
# - the first two values of each layer/output activation;
# - the queue state before and after running the push ops.
N_DEBUG_STEPS = 10

quantization_steps = wavenet_params['quantization_steps']
# Start from a fixed sample value rather than a random one.
waveform = [20]

# No variables are created inside the loop, so the queue variables only
# need to be looked up once.
queue_pointer_vars = [var for var in tf.all_variables() if 'pointer' in var.name]
queue_buffer_vars = [var for var in tf.all_variables() if 'state_buffer' in var.name]


def _print_queue_states(session, pointer_vars, buffer_vars):
    """Print each queue's pointer and the first two values of the slot it points at.

    Pointer and buffer variables are paired by position. Buffers are indexed
    as buffer[pointer, 0, 0:2]; the meaning of each axis is assumed, not
    verified (TODO confirm against the Queue implementation).
    """
    pointers = session.run(pointer_vars)
    buffers = session.run(buffer_vars)
    for item_i, (pointer, buffer_) in enumerate(zip(pointers, buffers)):
        print('state {}: {} {}'.format(item_i, pointer, buffer_[pointer, 0, 0:2]))


for n_i in range(N_DEBUG_STEPS):
    if args.fast_generation:
        # Fast generation consumes only the newest sample.
        window = waveform[-1]
    else:
        # At most the last args.window samples (the whole list while shorter).
        window = waveform[-args.window:]

    print(window)

    feed_dict = {samples: window}

    layers_ = sess.run(net.current_layers, feed_dict)
    outputs_ = sess.run(net.outputs, feed_dict)

    for i in range(len(outputs_)):
        if args.fast_generation:
            print('layer {}: {}'.format(i, layers_[i][0, 0:2]))
            print('output {}: {}'.format(i, outputs_[i][0, 0:2]))
        else:
            # Time index n_i is the newest position while the window is not truncated.
            print('layer {}: {}'.format(i, layers_[i][0, n_i, 0:2]))
            print('output {}: {}'.format(i, outputs_[i][0, n_i, 0:2]))

    prediction = sess.run(next_sample, feed_dict)
    print('prediction: {}'.format(prediction[0:2]))
    sample = np.argmax(prediction)  # greedy decoding
    waveform.append(sample)

    # Show how the push ops move each queue's pointer/buffer.
    _print_queue_states(sess, queue_pointer_vars, queue_buffer_vars)
    sess.run(push_ops, feed_dict)
    print('run')
    _print_queue_states(sess, queue_pointer_vars, queue_buffer_vars)

    print('sample: {}'.format(sample))
    print('---')

print(waveform)
plt.plot(waveform);

In [None]:
# Smoke test of a single state Queue: push one (1, STATE_SIZE) vector, all of
# whose entries share one random value, and print what the push op returns.
# NOTE(review): `Queue` is not imported anywhere in this notebook, so this
# cell raises NameError on a fresh kernel — TODO import it from the module
# that defines it (presumably part of the wavenet package; confirm).
# NOTE(review): if Queue creates TF variables, they were created after the
# initializer ran in the graph-building cell and may be uninitialised here.
STATE_SIZE = 32

state_input = tf.placeholder(tf.float32, shape=(1, STATE_SIZE))

q = Queue(batch_size=1,
          state_size=STATE_SIZE,
          buffer_size=1)

current_state = q.pop()
push = q.push(state_input)

print(sess.run(push, feed_dict={state_input: np.ones((1, STATE_SIZE)) * np.random.randn()}))

In [None]:

# Timed generation loop; extends `waveform` from the cells above.
# Kept at 1 step while debugging — set to args.samples for a full run.
N_GENERATE_STEPS = 1

times = []
for step in range(N_GENERATE_STEPS):
    if args.fast_generation:
        # Feed only the newest sample; the push ops are fetched in the same
        # run so the state queues advance together with the prediction.
        window = waveform[-1]
        tic = time()
        outputs_list = sess.run(
            [next_sample] + list(push_ops),
            feed_dict={samples: window})
        toc = time()
        prediction = outputs_list[0]
    else:
        # At most the last args.window samples (the whole list while shorter).
        window = waveform[-args.window:]
        tic = time()
        prediction = sess.run(
            next_sample,
            feed_dict={samples: window})
        toc = time()
    times.append(toc - tic)
    print('Average sample took {} seconds.'.format(np.mean(times)))
    sample = np.argmax(prediction)  # greedy decoding
    print(sample)
    # Stochastic alternative:
    # sample = np.random.choice(np.arange(quantization_steps), p=prediction)
    waveform.append(sample)
    print(waveform)
    # `{:3d}` right-aligns the counters to width 3; the previous `{:3<d}`
    # parsed as fill='3', align='<' with no width, so it never padded.
    print('Sample {:3d}/{:3d}: {}'.format(step + 1, args.samples, sample))

In [None]:
# Queue-state variables (names containing 'state_buffer' or 'pointer'), in
# tf.all_variables() order. Later cells index this list by position.
# NOTE(review): this name shadows the built-in `vars`.
vars = [candidate for candidate in tf.all_variables()
        if any(tag in candidate.name for tag in ('state_buffer', 'pointer'))]

In [None]:
# Display the current value of the first queue-state variable.
sess.run(fetches=vars[0])

In [None]:
# Save every variable in the session (a Saver with no var_list covers all of
# them, including the queue-state variables) to the working directory.
# NOTE(review): global_step=0 appends '-0', so this writes './model.ckpt-0'.
# That is the same path as args.checkpoint restored earlier, so it overwrites
# that checkpoint — confirm this is intended.
checkpoint_path = os.path.join(os.curdir, 'model.ckpt')
saver = tf.train.Saver()
print('Storing checkpoint to {}'.format(checkpoint_path))
saver.save(sess, checkpoint_path, global_step=0)

In [None]:
# Display the current value of the third queue-state variable.
sess.run(fetches=vars[2])

In [None]:
vars = {var.name: var.get_shape() for var in tf.all_variables() if 'state_buffer' in var.name or 'pointer' in var.name}

In [None]:
for key, value in vars.iteritems():
    print key, value