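# Custom forward and backward pass in TensorFlow

A NumPy function g and its hand-written Jacobian are wrapped with `tf.py_func` and given a custom gradient. TensorFlow can then learn the affine map f(z) = zW + b from samples of y = g(f(z)).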

In [1]:
import math, random
import uuid

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

from util import suppress_tf_warning
# Helper from the local util module; presumably silences TF warnings/logging.
suppress_tf_warning()
print(tf.__version__)  # the outputs below were produced with this TF version

1.15.0


<img src="custom_gradient.jpeg">

### Define custom operation

In [2]:
# Seed NumPy's global RNG so the reference parameters are reproducible.
np.random.seed(0)
W_ref = np.random.rand(3, 3)  # true weights of f, shape (3, 3)
b_ref = np.random.rand(3, 1)  # true bias of f, shape (3, 1) column
print('W_ref:\n',W_ref,'\nb_ref:\n',b_ref)
def f_unit(z, W=None, b=None): # f: z[3 x 1] -> x[3 x 1]
    """Affine map f(z) = W^T z + b for a single 3-vector.

    Args:
        z: 3 values; anything that reshapes to (3, 1), e.g. one row of a batch.
        W: (3, 3) weight matrix. Defaults to the global ``W_ref``.
        b: bias with 3 elements, (3, 1) or (3,). Defaults to the global ``b_ref``.

    Returns:
        (3, 1) column vector x. This is the column form of the row-vector
        model ``x = z @ W + b`` used in the TF graph.
    """
    if W is None:
        W = W_ref
    if b is None:
        b = b_ref
    # Force b into a column: a flat (3,) bias would otherwise broadcast
    # against the (3, 1) product into a (3, 3) matrix.
    return np.matmul(np.transpose(W), np.reshape(z, (3, 1))) + np.reshape(b, (3, 1))
def g_unit(x): # g: x[3 x 1] -> y[2 x 1]
    """Nonlinear map g for a single 3-vector x = (x0, x1, x2).

        y0 =  x0      - 2*x1**2 + 3*x2**3
        y1 = -4*x0**2 + 5*x1**3 - 6*x2

    Args:
        x: 3 values; anything that reshapes to (3, 1).

    Returns:
        float (2, 1) column vector [[y0], [y1]].
    """
    a, b, c = np.reshape(x, (3, 1))[:, 0]
    first = a - 2 * b**2 + 3 * c**3
    second = -4 * a**2 + 5 * b**3 - 6 * c
    return np.array([[first], [second]], dtype=float)
def g_grad_unit(x): # Jacobian rows of g at x: [3] (a [3 x 1] column also works)
    """Rows of the Jacobian of g (see g_unit) at a single point.

    Args:
        x: 3 values. Any shape with 3 elements is accepted: flat [3], or a
           (3, 1) column like the ones f_unit/g_unit produce.

    Returns:
        (dy0_dx, dy1_dx): two float arrays of shape [3], the gradients of
        y0 and y1 with respect to (x0, x1, x2).
    """
    # Flatten so element access yields scalars for both (3,) and (3, 1) input.
    # Indexing a (3, 1) array directly gives length-1 arrays.
    x0, x1, x2 = np.ravel(x)
    # y0 = x0 - 2*x1**2 + 3*x2**3
    dy0_dx = np.array([1, -4 * x1, 9 * x2**2], dtype=float)
    # y1 = -4*x0**2 + 5*x1**3 - 6*x2
    dy1_dx = np.array([-8 * x0, 15 * x1**2, -6], dtype=float)
    return dy0_dx, dy1_dx # dy0_dx: [3], dy1_dx: [3]
def fg_batch(z): # z[n x 3] -> x[n x 3] -> y[n x 2]
    """Apply the reference composition g(f(z)) to every row of a batch.

    Args:
        z: [n x 3] array of inputs.

    Returns:
        [n x 2] float array whose row i is g(f(z[i])) laid out as a row.
    """
    out = np.zeros(shape=(z.shape[0], 2))
    for row, z_row in enumerate(z):
        out[row] = g_unit(f_unit(z_row)).ravel()
    return out

W_ref:
 [[0.5488135  0.71518937 0.60276338]
 [0.54488318 0.4236548  0.64589411]
 [0.43758721 0.891773   0.96366276]] 
b_ref:
 [[0.38344152]
 [0.79172504]
 [0.52889492]]


In [3]:
def custom_g_func(x): # x -> y
    """NumPy forward pass of g over a batch, executed inside tf.py_func.

    Args:
        x: [n x 3] array; each row is passed to g_unit.

    Returns:
        [y0, y1]: two float32 arrays of shape [n x 1] holding the first and
        second output of g for every row (one array per py_func output).
    """
    n_rows = x.shape[0]
    stacked = np.zeros(shape=(n_rows, 2))
    for row in range(n_rows):
        stacked[row, :] = g_unit(x[row, :])[:, 0]
    return [stacked[:, 0:1].astype(np.float32), stacked[:, 1:2].astype(np.float32)]

def custom_g_derv(x,grads0,grads1): # x -> dy
    """NumPy backward pass of g over a batch, executed inside tf.py_func.

    For every row, the upstream gradient of each output (one scalar per row)
    scales the matching Jacobian row returned by g_grad_unit.

    Args:
        x: [n x 3] forward inputs.
        grads0: [n x 1] upstream gradients with respect to y0.
        grads1: [n x 1] upstream gradients with respect to y1.

    Returns:
        [dy0, dy1]: float32 [n x 3] arrays with the contributions of y0 and
        y1 to the gradient with respect to x.
    """
    count = x.shape[0]
    part0 = np.zeros(shape=(count, 3))
    part1 = np.zeros(shape=(count, 3))
    for row in range(count):
        jac0, jac1 = g_grad_unit(x[row, :])
        part0[row, :] = grads0[row, 0] * jac0
        part1[row, :] = grads1[row, 0] * jac1
    return [part0.astype(np.float32), part1.astype(np.float32)]

def grad_wrapper(op, grads0, grads1):
    """Gradient function for the custom py_func op computing g.

    Args:
        op: the forward op; op.inputs[0] is x.
        grads0, grads1: upstream gradients for the op's two outputs (y0, y1).

    Returns:
        The gradient with respect to x: the sum of both outputs'
        contributions, computed in NumPy by custom_g_derv.
    """
    x = op.inputs[0]
    dy0, dy1 = tf.py_func(func=custom_g_derv, inp=[x, grads0, grads1],
                          Tout=[tf.float32, tf.float32])
    grad_x = dy0 + dy1
    # py_func outputs carry no static shape. Restore x's shape so the
    # gradient ops upstream of x can use it for shape inference.
    grad_x.set_shape(x.get_shape())
    return grad_x

def py_func_wrapper(func, inp, Tout, stateful=True, name=None, grad=None):
    """Create a tf.py_func op whose gradient is a custom Python function.

    Args:
        func, inp, Tout, stateful, name: forwarded unchanged to tf.py_func.
        grad: gradient function called as ``grad(op, *output_grads)``.

    Returns:
        The result of tf.py_func (a list of tensors when Tout is a list).
    """
    # tf.RegisterGradient raises KeyError when a name is registered twice.
    # A fixed name therefore breaks on a second call (e.g. re-running the
    # model cell), so generate a fresh name for every call.
    rnd_name = 'custom_gradient_sj_' + uuid.uuid4().hex
    tf.RegisterGradient(rnd_name)(grad)
    g = tf.get_default_graph()
    # Route both py_func op types to the registered gradient for ops created here.
    with g.gradient_override_map({"PyFunc": rnd_name, "PyFuncStateless": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    
print("Ready.")  # marks that the definitions above ran

Ready.


### Training data

In [4]:
# Inputs are sampled uniformly from [Z_LOW, Z_HIGH)^3. Targets come from the
# reference composition g(f(z)) with the true W_ref, b_ref.
Z_LOW, Z_HIGH = -3.0, 3.0
n_train = 1000
z_train = Z_LOW+(Z_HIGH-Z_LOW)*np.random.rand(n_train,3) # [n_train x 3] = [1000 x 3]
y_train = fg_batch(z_train) # [n_train x 2] = [1000 x 2]
n_test = 100
z_test = Z_LOW+(Z_HIGH-Z_LOW)*np.random.rand(n_test,3) # [n_test x 3] = [100 x 3]
y_test = fg_batch(z_test) # [n_test x 2] = [100 x 2]
assert z_train.shape == (n_train, 3) and y_train.shape == (n_train, 2)
assert z_test.shape == (n_test, 3) and y_test.shape == (n_test, 2)

### Define model

In [5]:
tf.set_random_seed(0)  # graph-level seed so the initial W, b are reproducible
init = tf.random_normal_initializer(stddev=0.1)
W,b = tf.Variable(init(shape=[3,3])),tf.Variable(init(shape=[3]))
z_ph = tf.placeholder(tf.float32,shape=(None,3)) # input z
y_ph = tf.placeholder(tf.float32,shape=(None,2)) # target y
x_pred = tf.matmul(z_ph,W) + b # f: z->x (row form: z @ W + b)
# g: x->y runs in NumPy (custom_g_func); its gradient comes from grad_wrapper
y_preds = py_func_wrapper(func=custom_g_func, # function
                          inp=[x_pred], # input
                          Tout=[tf.float32,tf.float32], # output type of 'custom_func'
                          name='custom_g_func',
                          grad=grad_wrapper)
y_pred_concat= tf.concat((y_preds[0],y_preds[1]),axis=1) # [n x 2]
# Mean squared error over every element (batch rows and both outputs).
cost = tf.reduce_mean(tf.square(y_ph - y_pred_concat))
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cost)
print ("Ready.")

Ready.


### Train

In [6]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
max_iter = 10000
batch_size = 32
for it in range(max_iter):
    batch_idx = np.random.permutation(n_train)[:batch_size]
    z_batch,y_batch = z_train[batch_idx,:],y_train[batch_idx,:]
    # One Adam step on the mini-batch; cost_val is the mini-batch cost here.
    cost_val,_,y_pred_concat_val = sess.run([cost,optimizer,y_pred_concat],
                                            feed_dict={z_ph: z_batch, y_ph: y_batch})
    if (it==0) or (((it+1)%100)==0):
        # The logged cost is evaluated on the held-out test set.
        cost_val = sess.run(cost,feed_dict={z_ph:z_test,y_ph:y_test})
        print ("Iter:[%d/%d] cost:[%.3e] "% (it+1,max_iter,cost_val,))
    if (it==0) or (((it+1)%1000)==0):
        # Read W, b in their own run. If fetched in the same run as
        # `optimizer`, the read is not ordered against the update.
        W_val,b_val = sess.run([W,b])
        print ('W_val:\n',W_val,'\nb_val:\n',b_val)
        print ('W_ref:\n',W_ref,'\nb_ref:\n',b_ref.T)
W_val,b_val = sess.run([W,b]) # final learned parameters
print ("Done.")

Iter:[1/10000] cost:[1.294e+04] 
W_val:
 [[ 0.13522008  0.18572854  0.00611573]
 [ 0.01448664  0.16701524  0.04245107]
 [-0.0743758   0.08850946  0.03756813]] 
b_val:
 [ 0.08928414 -0.04453174 -0.07976732]
W_ref:
 [[0.5488135  0.71518937 0.60276338]
 [0.54488318 0.4236548  0.64589411]
 [0.43758721 0.891773   0.96366276]] 
b_ref:
 [[0.38344152 0.79172504 0.52889492]]
Iter:[100/10000] cost:[1.233e+04] 
Iter:[200/10000] cost:[1.081e+04] 
Iter:[300/10000] cost:[8.584e+03] 
Iter:[400/10000] cost:[6.929e+03] 
Iter:[500/10000] cost:[6.248e+03] 
Iter:[600/10000] cost:[5.876e+03] 
Iter:[700/10000] cost:[5.617e+03] 
Iter:[800/10000] cost:[5.444e+03] 
Iter:[900/10000] cost:[5.225e+03] 
Iter:[1000/10000] cost:[3.421e+03] 
W_val:
 [[0.26110274 0.7488662  0.39169502]
 [0.38215727 0.4436264  0.7469938 ]
 [0.13420638 0.8218259  0.2869    ]] 
b_val:
 [ 0.04189656  0.60155076 -0.09727084]
W_ref:
 [[0.5488135  0.71518937 0.60276338]
 [0.54488318 0.4236548  0.64589411]
 [0.43758721 0.891773   0.96366276]]