# Loading and Re-training a TensorFlow Model

Load the part of the graph that contains the add routine.  Add another regression layer to the graph.  We will train the additional regression step later.

In [2]:
import tensorflow as tf

reg_saver = tf.train.import_meta_graph('/Users/patrickmedina/Desktop/regression/model-891.meta')
graph = tf.get_default_graph()

tf_x = graph.get_tensor_by_name("x:0")
tf_y = graph.get_tensor_by_name("y:0")
y_hat = graph.get_tensor_by_name("regression:0")

# extend the graph
a_1 = tf.Variable(0.5, dtype=tf.float32, name="a_1")
b_1 = tf.Variable(0.0, dtype=tf.float32, name="b_1")
y_hat_2 = tf.add(tf.multiply(a_1, y_hat), b_1, name="regression_2")

# define a new loss function
loss = tf.subtract(y_hat_2, tf_y) ** 2
loss = tf.reduce_mean(loss)

ao = tf.train.AdamOptimizer(learning_rate=1e-2, name='Adam_2')
ts = ao.minimize(loss)


We can reload the variables from the initial training and run inference on the first graph.

In [3]:
sess = tf.Session()
reg_saver.restore(sess, "/Users/patrickmedina/Desktop/regression/model-891")
result = sess.run(y_hat, feed_dict={tf_x: [0], tf_y: [0]})[0]
print("[INFO] Prediction for value 0 is {0:4f}.  The true value for 0 is 5.".format(result))

INFO:tensorflow:Restoring parameters from /Users/patrickmedina/Desktop/regression/model-891


[INFO] Prediction for value 0 is 4.817059.  The true value for 0 is 5.


We can initialize the variables for the remaining parts of the graph and perform inference on the extended model.

In [4]:
sess.run(tf.global_variables_initializer())
result = sess.run(y_hat_2, feed_dict={tf_x: [0], tf_y: [0]})
print(result)

[0.]


In [5]:
import numpy as np

reg_saver.save(sess, "/Users/patrickmedina/Desktop/regression_2/model", 0)
    
# generate the data
x = np.random.normal(size=(100, )).astype(np.float32)
y = 10 * x + 5 + np.random.normal(size=(100, )).astype(np.float32)

for ep in range(100):
    print("[INFO] Starting epoch {}".format(ep))
    idx = np.arange(100)
    np.random.shuffle(idx)
    x = x[idx]
    y = y[idx]

    for i in range(10):
        start = 10 * i
        stop = start + 10
        feed_dict = {tf_x:x[start:stop], tf_y:y[start:stop]}
        
        l, _ = sess.run([loss, ts], feed_dict=feed_dict)

        if i % 10 == 0:
            print("[INFO] Loss at step {0}: {1}".format(i, l))

reg_saver.save(sess, "/Users/patrickmedina/Desktop/regression_2/model", ep * i)

sess.close()

[INFO] Starting epoch 0
[INFO] Loss at step 0: 88.03340911865234
[INFO] Starting epoch 1
[INFO] Loss at step 0: 120.70565032958984
[INFO] Starting epoch 2
[INFO] Loss at step 0: 36.99068069458008
[INFO] Starting epoch 3
[INFO] Loss at step 0: 36.51053237915039
[INFO] Starting epoch 4
[INFO] Loss at step 0: 97.08419036865234
[INFO] Starting epoch 5
[INFO] Loss at step 0: 74.45780944824219
[INFO] Starting epoch 6
[INFO] Loss at step 0: 40.493648529052734
[INFO] Starting epoch 7
[INFO] Loss at step 0: 57.15752029418945
[INFO] Starting epoch 8
[INFO] Loss at step 0: 29.602497100830078
[INFO] Starting epoch 9
[INFO] Loss at step 0: 22.47392463684082
[INFO] Starting epoch 10
[INFO] Loss at step 0: 20.718765258789062
[INFO] Starting epoch 11
[INFO] Loss at step 0: 12.054939270019531
[INFO] Starting epoch 12
[INFO] Loss at step 0: 5.797328472137451
[INFO] Starting epoch 13
[INFO] Loss at step 0: 5.908519744873047
[INFO] Starting epoch 14
[INFO] Loss at step 0: 5.344995975494385
[INFO] Starting

[INFO] Starting epoch 35
[INFO] Loss at step 0: 1.8382459878921509
[INFO] Starting epoch 36
[INFO] Loss at step 0: 0.42135682702064514
[INFO] Starting epoch 37
[INFO] Loss at step 0: 1.2868313789367676
[INFO] Starting epoch 38
[INFO] Loss at step 0: 1.6082706451416016
[INFO] Starting epoch 39
[INFO] Loss at step 0: 0.5982078313827515
[INFO] Starting epoch 40
[INFO] Loss at step 0: 1.115146279335022
[INFO] Starting epoch 41
[INFO] Loss at step 0: 0.7623821496963501
[INFO] Starting epoch 42
[INFO] Loss at step 0: 1.119942545890808
[INFO] Starting epoch 43
[INFO] Loss at step 0: 0.9253376722335815
[INFO] Starting epoch 44
[INFO] Loss at step 0: 0.6657480597496033
[INFO] Starting epoch 45
[INFO] Loss at step 0: 0.733609676361084
[INFO] Starting epoch 46
[INFO] Loss at step 0: 1.1765916347503662
[INFO] Starting epoch 47
[INFO] Loss at step 0: 2.3074824810028076
[INFO] Starting epoch 48
[INFO] Loss at step 0: 1.1137111186981201
[INFO] Starting epoch 49
[INFO] Loss at step 0: 2.52878713607788

[INFO] Starting epoch 68
[INFO] Loss at step 0: 0.5194608569145203
[INFO] Starting epoch 69
[INFO] Loss at step 0: 1.2160893678665161
[INFO] Starting epoch 70
[INFO] Loss at step 0: 0.19496949017047882
[INFO] Starting epoch 71
[INFO] Loss at step 0: 1.569798469543457
[INFO] Starting epoch 72
[INFO] Loss at step 0: 0.711682915687561
[INFO] Starting epoch 73
[INFO] Loss at step 0: 0.5142080783843994
[INFO] Starting epoch 74
[INFO] Loss at step 0: 0.7309017181396484
[INFO] Starting epoch 75
[INFO] Loss at step 0: 1.1554571390151978
[INFO] Starting epoch 76
[INFO] Loss at step 0: 0.768258810043335
[INFO] Starting epoch 77
[INFO] Loss at step 0: 1.3205249309539795
[INFO] Starting epoch 78
[INFO] Loss at step 0: 0.8266788721084595
[INFO] Starting epoch 79
[INFO] Loss at step 0: 0.5578247904777527
[INFO] Starting epoch 80
[INFO] Loss at step 0: 0.6673489212989807
[INFO] Starting epoch 81
[INFO] Loss at step 0: 2.00838303565979
[INFO] Starting epoch 82
[INFO] Loss at step 0: 1.061743140220642


'/Users/patrickmedina/Desktop/regression/model-891'

In [15]:
a = graph.get_tensor_by_name("a:0")
b = graph.get_tensor_by_name("b:0")

print(a)
print(b)

Tensor("a:0", shape=(), dtype=float32_ref)
Tensor("b:0", shape=(), dtype=float32_ref)


In [16]:
sess.run([a, b, a_1, b_1])

[6.1656513, 2.1693344, 1.6297995, 1.5964162]

We see that things got foobar...

In [17]:
sess.close()