Skip to content

Commit

Permalink
Support single array input for metric (apache#9930)
Browse files Browse the repository at this point in the history
* fix apache#9865

* add unittest

* fix format

* fix format

* fix superfluous loop in metric

* fix lint
  • Loading branch information
hetong007 authored and szha committed Mar 13, 2018
1 parent acfa335 commit 8e9cdb3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 14 deletions.
59 changes: 45 additions & 14 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,25 @@
from . import registry


def check_label_shapes(labels, preds, shape=0):
if shape == 0:
def check_label_shapes(labels, preds, wrap=False, shape=False):
"""Helper function for checking shape of label and prediction
Parameters
----------
labels : list of `NDArray`
The labels of the data.
preds : list of `NDArray`
Predicted values.
wrap : boolean
If True, wrap labels/preds in a list if they are single NDArray
shape : boolean
If True, check the shape of labels and preds;
Otherwise only check their length.
"""
if not shape:
label_shape, pred_shape = len(labels), len(preds)
else:
label_shape, pred_shape = labels.shape, preds.shape
Expand All @@ -40,6 +57,13 @@ def check_label_shapes(labels, preds, shape=0):
raise ValueError("Shape of labels {} does not match shape of "
"predictions {}".format(label_shape, pred_shape))

if wrap:
if isinstance(labels, ndarray.ndarray.NDArray):
labels = [labels]
if isinstance(preds, ndarray.ndarray.NDArray):
preds = [preds]

return labels, preds

class EvalMetric(object):
"""Base class for all evaluation metrics.
Expand Down Expand Up @@ -386,15 +410,15 @@ def update(self, labels, preds):
Prediction values for samples. Each prediction value can either be the class index,
or a vector of likelihoods for all classes.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32')
label = label.asnumpy().astype('int32')

check_label_shapes(label, pred_label)
labels, preds = check_label_shapes(label, pred_label)

self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)
Expand Down Expand Up @@ -456,7 +480,7 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred_label in zip(labels, preds):
assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
Expand Down Expand Up @@ -614,7 +638,7 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
self.metrics.update_binary_stats(label, pred)
Expand Down Expand Up @@ -785,14 +809,16 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.asnumpy()
pred = pred.asnumpy()

if len(label.shape) == 1:
label = label.reshape(label.shape[0], 1)
if len(pred.shape) == 1:
pred = pred.reshape(pred.shape[0], 1)

self.sum_metric += numpy.abs(label - pred).mean()
self.num_inst += 1 # numpy.prod(label.shape)
Expand Down Expand Up @@ -843,14 +869,16 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.asnumpy()
pred = pred.asnumpy()

if len(label.shape) == 1:
label = label.reshape(label.shape[0], 1)
if len(pred.shape) == 1:
pred = pred.reshape(pred.shape[0], 1)

self.sum_metric += ((label - pred)**2.0).mean()
self.num_inst += 1 # numpy.prod(label.shape)
Expand Down Expand Up @@ -901,14 +929,16 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.asnumpy()
pred = pred.asnumpy()

if len(label.shape) == 1:
label = label.reshape(label.shape[0], 1)
if len(pred.shape) == 1:
pred = pred.reshape(pred.shape[0], 1)

self.sum_metric += numpy.sqrt(((label - pred)**2.0).mean())
self.num_inst += 1
Expand Down Expand Up @@ -969,7 +999,7 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.asnumpy()
Expand Down Expand Up @@ -1037,7 +1067,7 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.asnumpy()
Expand Down Expand Up @@ -1095,9 +1125,10 @@ def update(self, labels, preds):
preds : list of `NDArray`
Predicted values.
"""
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
check_label_shapes(label, pred, 1)
check_label_shapes(label, pred, False, True)
label = label.asnumpy()
pred = pred.asnumpy()
self.sum_metric += numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
Expand Down Expand Up @@ -1209,7 +1240,7 @@ def update(self, labels, preds):
Predicted values.
"""
if not self._allow_extra_outputs:
check_label_shapes(labels, preds)
labels, preds = check_label_shapes(labels, preds, True)

for pred, label in zip(preds, labels):
label = label.asnumpy()
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,27 @@ def test_pearsonr():
_, pearsonr = metric.get()
assert pearsonr == pearsonr_expected

def test_single_array_input():
pred = mx.nd.array([[1,2,3,4]])
label = pred + 0.1

mse = mx.metric.create('mse')
mse.update(label, pred)
_, mse_res = mse.get()
np.testing.assert_almost_equal(mse_res, 0.01)

mae = mx.metric.create('mae')
mae.update(label, pred)
mae.get()
_, mae_res = mae.get()
np.testing.assert_almost_equal(mae_res, 0.1)

rmse = mx.metric.create('rmse')
rmse.update(label, pred)
rmse.get()
_, rmse_res = rmse.get()
np.testing.assert_almost_equal(rmse_res, 0.1)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 8e9cdb3

Please sign in to comment.