## Nonlinear Regression using TensorFlow

In [None]:
from __future__ import division
import matplotlib
#matplotlib.use('Agg')
%matplotlib inline

import tensorflow as tf
import numpy as np
import scipy.signal as sig
from scipy import interpolate

import sys
sys.path.append('../MockData/')
sys.path.append('../pynoisesub/')
from mock_noise import f, starting_data
#from subtraction_plots import plot_results

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.colors
from params import *
from tf_2dhelpers import *

# Optional diagnostics. Each one reports more information during the run
# at the cost of extra work per iteration, so both default to off.
print_loss, animate = False, False

In [None]:
# Load the mock data set.
times, darm, wit1, wit2, data = starting_data(filt=True)

# The graph placeholders are float32, so cast the witness channels once here.
wit1 = wit1.astype(np.float32)
wit2 = wit2.astype(np.float32)

# Sample rate from the sampling interval. Taking the difference of the first
# two timestamps does not assume the series starts at t=0. Rounding stops
# int() from truncating e.g. 2047.9999 down to 2047.
fs = int(round(1.0/(times[1] - times[0])))

# 4th-order 30 Hz high-pass, applied forward and backward (filtfilt), then
# normalised by NOISE_LEVEL and cast to float32 for the graph.
b, a = sig.butter(4, 2*30.0/fs, btype='highpass')
darm_scaled = (sig.filtfilt(b, a, darm)/NOISE_LEVEL).astype(np.float32)
data_scaled = (sig.filtfilt(b, a, data)/NOISE_LEVEL).astype(np.float32)

# 50-400 Hz FIR band-pass (N+1 taps) applied to the residual in the loss.
fir_bandpass = sig.firwin(N+1, [50, 400], pass_zero=False, nyq=fs/2.0)

# Lattice parameters used by generate_grids:
#   lattice_res - number of grid points along each axis of the lattice
#   epsilon     - keeps grid coordinates strictly inside (-1, 1); the
#                 comparison grid maps them through arctanh, which diverges
#                 at +/-1
#   n           - number of Fourier frequencies along x (2*n along y)
lattice_res = 100
epsilon     = 1e-5
n           = 10

# Turn animation recording on, overriding the default flag. This must be a
# real bool: a string flag such as 'False' would also be truthy.
animate = True

In [None]:
def generate_grids(weights, phases):
    """Build one lattice per hidden unit from a truncated 2-D Fourier series.

    For unit h, the lattice value at grid point (x, y) is

        sum_{j,k} weights[h,j,k] * sin(pi*(fx_j*x + fy_k*y)/2 + phases[h,j,k])

    where fx_j runs over 0..n-1 and fy_k over -n..n-1. Both x and y take
    lattice_res evenly spaced values in [-1 + epsilon, 1 - epsilon].

    Args:
        weights: Fourier amplitudes, reshaped to [-1, n, 2*n].
        phases: Fourier phase offsets, reshaped to [-1, n, 2*n].

    Returns:
        Rank-2 tensor [num_units, lattice_res*lattice_res]. Each row is one
        flattened lattice (x index major, y index minor), laid out for the
        matmul in the next layer.
    """
    grid_pts = tf.linspace(-1 + epsilon, 1-epsilon, lattice_res)
    # Broadcast layout for every term: [unit, x-freq, y-freq, x, y].
    x_axis = tf.reshape(grid_pts, [1, 1, 1, lattice_res, 1])
    y_axis = tf.reshape(grid_pts, [1, 1, 1, 1, lattice_res])
    freq_x = tf.reshape(tf.linspace(0.0, n-1, n), [1, n, 1, 1, 1])
    freq_y = tf.reshape(tf.linspace(-n*1.0, n-1, 2*n), [1, 1, 2*n, 1, 1])

    term_shape = [-1, n, 2*n, 1, 1]
    phase = tf.reshape(phases, term_shape)
    amplitude = tf.reshape(weights, term_shape)

    # Rank-5 tensor of individual Fourier terms.
    terms = amplitude*tf.sin(np.pi*(freq_x*x_axis + freq_y*y_axis)/2 + phase)
    # Sum over both frequency axes -> [unit, x, y], then flatten each lattice.
    lattices = tf.reduce_sum(terms, [1, 2])
    return tf.reshape(lattices, [-1, lattice_res*lattice_res])

In [None]:
def init_weights(shape, init_method='xavier', xavier_params = (None, None)):
    """Create a float32 tf.Variable of the given shape.

    Args:
        shape: shape of the variable.
        init_method:
            'zeros'   -> all zeros.
            'uniform' -> NOTE: despite the name, this draws from a normal
                         distribution with stddev 0.01. Kept as-is because
                         the tuned learning rate was chosen with this
                         initialisation.
            otherwise -> Xavier/Glorot uniform initialisation.
        xavier_params: (fan_in, fan_out); required for the Xavier branch.

    Returns:
        A tf.Variable holding the initial values.

    Raises:
        ValueError: if the Xavier branch is selected without both fan_in
            and fan_out.
    """
    if init_method == 'zeros':
        return tf.Variable(tf.zeros(shape, dtype=tf.float32))
    if init_method == 'uniform':
        return tf.Variable(
            tf.random_normal(shape, stddev=0.01, dtype=tf.float32))

    # Xavier/Glorot uniform: U(-r, r) with r = k*sqrt(6/(fan_in + fan_out)),
    # where k is {sigmoid:4, tanh:1}.
    fan_in, fan_out = xavier_params
    if fan_in is None or fan_out is None:
        raise ValueError(
            "xavier initialisation requires xavier_params=(fan_in, fan_out)")
    limit = 4*np.sqrt(6.0/(fan_in + fan_out))
    # Symmetric bounds keep the initial weights zero-mean. The former upper
    # bound of 8*sqrt(...) biased every initial weight towards positive values.
    return tf.Variable(tf.random_uniform(shape, minval=-limit, maxval=limit))

In [None]:
def multilayer_perceptron(weights, biases):
    """Combine the per-unit Fourier lattices into one flattened lattice.

    Layer 1 is generate_grids, which yields a rank-2 tensor
    [units, lattice_res*lattice_res]. The two following layers are purely
    linear (matmul + bias, no activation) along the unit axis.

    Args:
        weights: dict with keys 'h1' (Fourier amplitudes), 'h2' and 'out'
            (mixing matrices).
        biases: dict with keys 'b1', 'b2' and 'out'. 'b1' is used as the
            Fourier phases of layer 1, not as an additive bias.

    Returns:
        Tensor [rows of weights['out'], lattice_res*lattice_res].
    """
    grids = generate_grids(weights['h1'], biases['b1'])
    hidden = tf.add(tf.matmul(weights['h2'], grids), biases['b2'])
    # Optional nonlinearity, currently disabled: hidden = tf.nn.tanh(hidden)
    return tf.add(tf.matmul(weights['out'], hidden), biases['out'])

In [None]:
# Network sizes: nHidden1 Fourier lattices (layer 1) are linearly mixed
# into nHidden2 units (layer 2), then into the single output lattice.
nHidden1 = 3
nHidden2 = 1

# Trainable weights consumed by multilayer_perceptron:
#   'h1'  - Fourier amplitudes, [nHidden1, n, 2*n]. Initialised with
#           'uniform', which draws normal values with stddev 0.01.
#   'h2'  - [nHidden2, nHidden1] mixing matrix.
#   'out' - [1, nHidden2] mixing matrix.
# The commented lines are alternative initialisations.
weights =  {
    'h1': init_weights(
        [nHidden1, n, 2*n],
        # 'zeros'),
        'uniform'),
        # 'xavier',
        # xavier_params=(2*n*n, nHidden1)),
    'h2': init_weights(
        [nHidden2, nHidden1],
        # 'zeros'),
        'xavier',
        xavier_params=(nHidden1, nHidden2)),
     'out': init_weights(
        [1, nHidden2],
        # 'zeros')
        'xavier',
        xavier_params=(nHidden2,1))
}

# Biases:
#   'b1'  - passed to generate_grids as the Fourier phases (same shape as
#           weights['h1']).
#   'b2'  - [nHidden2, 1] column bias broadcast across lattice points.
#   'out' - [1, 1] output bias.
biases = {
    # 'b1': init_weights([nHidden1, n, 2*n],'xavier', xavier_params=(n*n, nHidden1)),
    # 'b1': init_weights([nHidden1, n, 2*n],'zeros'),
    'b1': init_weights([nHidden1, n, 2*n],'uniform'),
    'b2': init_weights([nHidden2, 1],'zeros'),
    'out': init_weights([1, 1],'zeros')
}

# 1-D float32 graph inputs fed at run time: the two witness channels, the
# target series, and the FIR band-pass taps applied to the residual.
witness1 = tf.placeholder(tf.float32, [None])
witness2 = tf.placeholder(tf.float32, [None])
target   = tf.placeholder(tf.float32, [None])
bandpass = tf.placeholder(tf.float32, [None])

# The learned lattice, reshaped from the flat network output to a square grid.
nonlinearWitness = tf.reshape(
    multilayer_perceptron(weights, biases),
    [lattice_res, lattice_res])
# Look up lattice values at the witness samples. extract_wit comes from a
# star-imported helper module (presumably tf_2dhelpers); its exact
# lookup/interpolation scheme is not visible here -- TODO confirm.
interpolated_wit = extract_wit(nonlinearWitness, witness1, witness2)

# Loss: filter the interpolated witness with a filter fitted against the
# target (tf_wiener_fir, presumably a Wiener FIR fit), subtract it, band-pass
# the residual and take its mean absolute value.
# The first N target samples are dropped, presumably so the residual lines
# up with tf_filt's output length -- TODO confirm against tf_2dhelpers.
# loss = tf.reduce_mean(tf.abs(target-interpolated_wit), 0)
filt         = tf_wiener_fir(target, interpolated_wit)
filtered_wit = tf_filt(interpolated_wit, filt)
resid        = tf.slice(target, [N], [-1]) - filtered_wit
loss         = tf.reduce_mean(tf.abs(tf_filt(resid, bandpass)), 0)

# anytime you change the loss function, the learning rate needs to change
optimizer    = tf.train.AdamOptimizer(learning_rate=5e-2, epsilon=1e-9)
train_step   = optimizer.minimize(loss)
# Op that initialises every variable created above; run once per session.
init         = tf.initialize_all_variables()

## Run the TensorFlow regression

In [None]:
# Launch the graph.
sess  = tf.Session()
sess.run(init)
nIter = 20
if(animate):
    frames = np.zeros((nIter+1, lattice_res, lattice_res))
if(print_loss):
    best = sess.run(
        loss,
        feed_dict={
            witness1:wit1,
            witness2:wit2,
            target:darm_scaled,
            bandpass:fir_bandpass})
# perfect subtraction would just leave data
goal = sess.run(
    loss,
    feed_dict={
        witness1:np.zeros(wit1.size),
        witness2:np.zeros(wit2.size),
        target:data_scaled,
        bandpass:fir_bandpass})
print "Loss goal:", goal

for i in xrange(nIter):
    if(animate):
        frames[i,:,:] = nonlinearWitness.eval(session=sess)
    sess.run(
        train_step,
        feed_dict={
            witness1:wit1,
            witness2:wit2,
            target:darm_scaled,
            bandpass:fir_bandpass})
    print "ITERATION:", i
    # It's nice to have loss updates, but this does actually slow it down
    if(print_loss):
        loss_i = sess.run(
            loss,
            feed_dict={
                witness1:wit1,
                witness2:wit2,
                target:darm_scaled,
                bandpass:fir_bandpass})
        if loss_i < best:
            print loss_i, "**New best**"
            best = loss_i
            # best_lattice = nonlinearWitness.eval(session=sess)
        else:
            print loss_i


# get final model outputs
end_lattice = nonlinearWitness.eval(session=sess)
end_loss = sess.run(
    loss,
    feed_dict={
        witness1:wit1,
        witness2:wit2,
        target:darm_scaled,
        bandpass:fir_bandpass})
end_wit = interpolated_wit.eval(
    session=sess,
    feed_dict={witness1:wit1,
               witness2:wit2,
               target:darm_scaled,
               bandpass:fir_bandpass})

print "Final loss:", end_loss

# make 'perfect solution' lattice for comparison
lattice_real = np.zeros((lattice_res, lattice_res), dtype = np.float32)

coords  = np.linspace(-1+epsilon, 1-epsilon, lattice_res)
xs_grid = np.arctanh(coords)*scale_factor
ys_grid = np.arctanh(coords)*scale_factor

for i in range (0, lattice_res):
    lattice_real[:,i] = f(xs_grid, ys_grid[i])

# save data (maybe pickle them into a single file?)
np.save("tfend_darm", darm)
np.save("tfend_wit1", wit1)
np.save("tfend_wit2", wit2)
np.save("tfend_data", data)
np.save("tfend_wit_result", end_wit)
np.save("tfend_lattice_result", end_lattice)
    
#if __name__ == "__main__":
#    main()

In [None]:
# plot lattice values as images
plt.figure(figsize=(13,5))
plt.subplot(1,2,1)
plt.pcolor(np.real(end_lattice), cmap=matplotlib.cm.inferno)
plt.colorbar()

plt.subplot(1,2,2)
plt.pcolor(lattice_real, cmap=matplotlib.cm.inferno)
plt.colorbar()
plt.show()
plt.savefig("result_tf.png")

In [None]:
# make animation
if(animate):
    frames[-1, :, :] = end_lattice
    print 'Animating...'
    Writer = animation.writers['ffmpeg']
    writer = Writer(metadata=dict(artist='Me'), bitrate=1800)
    fig = plt.figure(figsize=(11,11))
    def plot_frame(frame):
        plt.pcolor(frame, cmap=matplotlib.cm.inferno)
    anim = animation.FuncAnimation(
        fig, plot_frame, frames, interval=500)
    anim.save('lattice_evolution.mp4', writer=writer)
    print 'Done.'
# plot ASDs and coherence
#plot_results(
#    darm, wit1, wit2, end_wit, data,
#    file_end="_2dft", plot_all=True, filt=True)