Skip to content

Commit

Permalink
Save npz instead of npz_dict (#41)
Browse files Browse the repository at this point in the history
* hao25

* load pretrained
  • Loading branch information
zsdonghao committed Sep 5, 2018
1 parent 14ad54b commit 1ad3145
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,19 +231,11 @@ def generator():
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())

## restore pretrained weights TODO: use tl.models.VGG19
# npy_file = np.load('models', encoding='latin1').item()
# params = []
# for val in sorted(npy_file.items()):
# if val[0] == 'conv4_3':
# break
# W = np.asarray(val[1][0])
# b = np.asarray(val[1][1])
# print("Loading %s: %s, %s" % (val[0], W.shape, b.shape))
# params.extend([W, b])
# tl.files.assign_params(sess, params, cnn)
# print("Restoring model from npy file")
# cnn.restore_params(sess)
## restore pretrained weights
try:
tl.files.load_and_assign_npz(sess, os.path.join(model_path, 'pose.npz'), net)
except:
print("no pretrained model")

## train until the end
sess.run(tf.assign(lr_v, base_lr))
Expand Down Expand Up @@ -291,9 +283,9 @@ def generator():
## save intermedian results and model
if (step != 0) and (step % save_interval == 0):
draw_results(x_, confs_, conf_result, pafs_, paf_result, mask, 'train_%d_' % step)
tl.files.save_npz_dict(
tl.files.save_npz(
net.all_params, os.path.join(model_path, 'pose' + str(step) + '.npz'), sess=sess)
tl.files.save_npz_dict(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
tl.files.save_npz(net.all_params, os.path.join(model_path, 'pose.npz'), sess=sess)
if step == n_step: # training finished
break

Expand Down

0 comments on commit 1ad3145

Please sign in to comment.