Skip to content

Commit

Permalink
fix r2 metric
Browse files Browse the repository at this point in the history
  • Loading branch information
aymericdamien committed Oct 13, 2016
1 parent 3d4bce7 commit a8136af
Showing 1 changed file with 65 additions and 7 deletions.
72 changes: 65 additions & 7 deletions tflearn/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,43 @@ def __init__(self, name=None):
super(R2, self).__init__(name)
self.name = "R2" if not name else name

def build(self, predictions, targets, inputs):
def build(self, predictions, targets, inputs=None):
""" Build standard error tensor. """
self.built = True
self.tensor = r2_op(predictions, targets, inputs)
self.tensor = r2_op(predictions, targets)
# Add a special name to that tensor, to be used by monitors
self.tensor.m_name = self.name


class WeightedR2(Metric):
""" Weighted Standard Error.
Computes coefficient of determination. Useful to evaluate a linear
regression.
Examples:
```python
# To be used with TFLearn estimators
r2 = R2()
regression = regression(net, metric=r2)
```
Arguments:
name: The name to display.
"""

def __init__(self, name=None):
super(WeightedR2, self).__init__(name)
self.name = "R2" if not name else name

def build(self, predictions, targets, inputs):
""" Build standard error tensor. """
self.built = True
self.tensor = weighted_r2_op(predictions, targets, inputs)
# Add a special name to that tensor, to be used by monitors
self.tensor.m_name = self.name


class Prediction_Counts(Metric):
""" Prints the count of each category of prediction that is present in the predictions.
Expand Down Expand Up @@ -200,7 +229,6 @@ def build(self, predictions, targets, inputs=None):
prediction_counts = Prediction_Counts



# ----------
# Metric ops
# ----------
Expand Down Expand Up @@ -309,7 +337,37 @@ def top_k_op(predictions, targets, k=1):
return acc


def r2_op(predictions, targets, inputs):
def r2_op(predictions, targets):
""" r2_op.
An op that calculates the standard error.
Examples:
```python
input_data = placeholder(shape=[None, 784])
y_pred = my_network(input_data) # Apply some ops
y_true = placeholder(shape=[None, 10]) # Labels
stderr_op = r2_op(y_pred, y_true)
# Calculate standard error by feeding data X and labels Y
std_error = sess.run(stderr_op, feed_dict={input_data: X, y_true: Y})
```
Arguments:
predictions: `Tensor`.
targets: `Tensor`.
Returns:
`Float`. The standard error.
"""
with tf.name_scope('StandardError'):
a = tf.reduce_sum(tf.square(predictions))
b = tf.reduce_sum(tf.square(targets))
return tf.div(a, b)


def weighted_r2_op(predictions, targets, inputs):
""" r2_op.
An op that calculates the standard error.
Expand All @@ -334,12 +392,12 @@ def r2_op(predictions, targets, inputs):
`Float`. The standard error.
"""
with tf.name_scope('StandardError'):
with tf.name_scope('WeightedStandardError'):
if hasattr(inputs, '__len__'):
inputs = tf.add_n(inputs)
if inputs.get_shape().as_list() != targets.get_shape().as_list():
raise Exception("R2 metric requires Inputs and Targets to have "
"same shape.")
raise Exception("Weighted R2 metric requires Inputs and Targets to "
"have same shape.")
a = tf.reduce_sum(tf.square(predictions - inputs))
b = tf.reduce_sum(tf.square(targets - inputs))
return tf.div(a, b)

0 comments on commit a8136af

Please sign in to comment.