Skip to content

Commit

Permalink
latest
Browse files Browse the repository at this point in the history
  • Loading branch information
trigeorgis committed Oct 14, 2016
1 parent fc8d011 commit ce0baf4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
7 changes: 1 addition & 6 deletions mdm_model.py
Expand Up @@ -18,14 +18,9 @@ def norm(x):
return tf.image.resize_bilinear(tf.expand_dims(im, 0), new_size)[0, :, :, :], align_mean_shape / ratio, ratio

def normalized_rmse(pred, gt_truth):
pred_shape = pred.get_shape().as_list()
gt_truth_shape = gt_truth.get_shape().as_list()
num_lm = gt_truth_shape[1]
assert pred_shape == gt_truth_shape, "Conflicting predicted and ground truth shapes"

norm = tf.sqrt(tf.reduce_sum(((gt_truth[:, 36, :] - gt_truth[:, 45, :])**2), 1))

return tf.reduce_sum(tf.sqrt(tf.reduce_sum(tf.square(pred - gt_truth), 2)), 1) / (norm * num_lm)
return tf.reduce_sum(tf.sqrt(tf.reduce_sum(tf.square(pred - gt_truth), 2)), 1) / (norm * 68)


def conv_model(inputs, is_training=True, scope=''):
Expand Down
4 changes: 2 additions & 2 deletions mdm_train.py
Expand Up @@ -87,7 +87,7 @@ def get_random_sample(rotation_stddev=10):
if np.random.rand() < .5:
im = utils.mirror_image(im)

if np.random.rand() < .5 and False:
if np.random.rand() < .5:
theta = np.random.normal(scale=rotation_stddev)
rot = menpo.transform.rotate_ccw_about_centre(lms, theta)
im = im.warp_to_shape(im.shape, rot)
Expand All @@ -97,7 +97,7 @@ def get_random_sample(rotation_stddev=10):
return pixels, shape

image, shape = tf.py_func(get_random_sample, [],
[tf.float32, tf.float32])
[tf.float32, tf.float32], stateful=True)

initial_shape = data_provider.random_shape(shape, reference_shape,
pca_model)
Expand Down

0 comments on commit ce0baf4

Please sign in to comment.