Skip to content

Commit

Permalink
simplify metrics tests. fix arg order bug in binary_crossentropy. ren…
Browse files Browse the repository at this point in the history
…ame euclidian -> euclidean.
  • Loading branch information
lucaskolstad committed Nov 15, 2016
1 parent 92d75e0 commit 0c50c21
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 37 deletions.
8 changes: 4 additions & 4 deletions mla/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from mla.base import BaseEstimator
from mla.metrics.distance import euclidian_distance
from mla.metrics.distance import euclidean_distance

random.seed(1111)

Expand Down Expand Up @@ -98,7 +98,7 @@ def _closest(self, fpoint, centroids):
closest_index = None
closest_distance = None
for i, point in enumerate(centroids):
dist = euclidian_distance(self.X[fpoint], point)
dist = euclidean_distance(self.X[fpoint], point)
if closest_index is None or dist < closest_distance:
closest_index = i
closest_distance = dist
Expand All @@ -109,7 +109,7 @@ def _get_centroid(self, cluster):
return [np.mean(np.take(self.X[:, i], cluster)) for i in range(self.n_features)]

def _dist_from_centers(self):
return np.array([min([euclidian_distance(x, c) for c in self.centroids]) for x in self.X])
return np.array([min([euclidean_distance(x, c) for c in self.centroids]) for x in self.X])

def _choose_next_center(self):
distances = self._dist_from_centers()
Expand All @@ -120,7 +120,7 @@ def _choose_next_center(self):
return self.X[ind]

def _is_converged(self, centroids_old, centroids):
return True if sum([euclidian_distance(centroids_old[i], centroids[i]) for i in range(self.K)]) == 0 else False
return True if sum([euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]) == 0 else False

def plot(self, data=None):
sns.set(style="white")
Expand Down
2 changes: 0 additions & 2 deletions mla/metrics/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


import numpy as np


Expand Down
3 changes: 1 addition & 2 deletions mla/metrics/distance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

import numpy as np
import math


def euclidian_distance(a, b):
def euclidean_distance(a, b):
if isinstance(a, list) and isinstance(b, list):
a = np.array(a)
b = np.array(b)
Expand Down
7 changes: 3 additions & 4 deletions mla/metrics/metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@

import autograd.numpy as np
from autograd import grad

EPS = 1e-15

Expand Down Expand Up @@ -68,9 +66,10 @@ def hinge(actual, predicted):
return np.mean(np.max(1. - actual * predicted, 0.))


def binary_crossentropy(predicted, actual):
def binary_crossentropy(actual, predicted):
predicted = np.clip(predicted, EPS, 1 - EPS)
return np.mean(-np.sum(actual * np.log(predicted) + (1 - actual) * np.log(1 - predicted)))
return np.mean(-np.sum(actual * np.log(predicted) +
(1 - actual) * np.log(1 - predicted)))


# aliases
Expand Down
60 changes: 35 additions & 25 deletions mla/metrics/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.testing import assert_almost_equal

from mla.metrics.base import check_data, validate_input
from mla.metrics.metrics import *
from mla.metrics.metrics import get_metric


def test_data_validation():
Expand All @@ -26,53 +26,63 @@ def metric(name):


def test_classification_error():
assert metric('classification_error')([1, 2, 3, 4], [1, 2, 3, 4]) == 0
assert metric('classification_error')([1, 2, 3, 4], [1, 2, 3, 5]) == 0.25
assert metric('classification_error')([1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]) == (1.0 / 6)
f = metric('classification_error')
assert f([1, 2, 3, 4], [1, 2, 3, 4]) == 0
assert f([1, 2, 3, 4], [1, 2, 3, 5]) == 0.25
assert f([1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0]) == (1.0 / 6)


def test_absolute_error():
assert metric('absolute_error')([3], [5]) == [2]
assert metric('absolute_error')([-1], [-4]) == [3]
f = metric('absolute_error')
assert f([3], [5]) == [2]
assert f([-1], [-4]) == [3]


def test_mean_absolute_error():
assert metric('mean_absolute_error')([1, 2, 3], [1, 2, 3]) == 0
assert metric('mean_absolute_error')([1, 2, 3], [3, 2, 1]) == 4 / 3
f = metric('mean_absolute_error')
assert f([1, 2, 3], [1, 2, 3]) == 0
assert f([1, 2, 3], [3, 2, 1]) == 4 / 3


def test_squared_error():
assert metric('squared_error')([1], [1]) == [0]
assert metric('squared_error')([3], [1]) == [4]
f = metric('squared_error')
assert f([1], [1]) == [0]
assert f([3], [1]) == [4]


def test_squared_log_error():
assert metric('squared_log_error')([1], [1]) == [0]
assert metric('squared_log_error')([3], [1]) == [np.log(2) ** 2]
assert metric('squared_log_error')([np.exp(2) - 1], [np.exp(1) - 1]) == [1.0]
f = metric('squared_log_error')
assert f([1], [1]) == [0]
assert f([3], [1]) == [np.log(2) ** 2]
assert f([np.exp(2) - 1], [np.exp(1) - 1]) == [1.0]


def test_mean_squered_error():
assert metric('mean_squared_log_error')([1, 2, 3], [1, 2, 3]) == 0
assert metric('mean_squared_log_error')([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.25
def test_mean_squared_log_error():
f = metric('mean_squared_log_error')
assert f([1, 2, 3], [1, 2, 3]) == 0
assert f([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.25


def test_root_mean_squared_log_error():
assert metric('root_mean_squared_log_error')([1, 2, 3], [1, 2, 3]) == 0
assert metric('root_mean_squared_log_error')([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.5
f = metric('root_mean_squared_log_error')
assert f([1, 2, 3], [1, 2, 3]) == 0
assert f([1, 2, 3, np.exp(1) - 1], [1, 2, 3, np.exp(2) - 1]) == 0.5


def test_mean_squared_error():
assert metric('mean_squared_error')([1, 2, 3], [1, 2, 3]) == 0
assert metric('mean_squared_error')(range(1, 5), [1, 2, 3, 6]) == 1
f = metric('mean_squared_error')
assert f([1, 2, 3], [1, 2, 3]) == 0
assert f(range(1, 5), [1, 2, 3, 6]) == 1


def test_root_mean_squared_error():
assert metric('root_mean_squared_error')([1, 2, 3], [1, 2, 3]) == 0
assert metric('root_mean_squared_error')(range(1, 5), [1, 2, 3, 5]) == 0.5
f = metric('root_mean_squared_error')
assert f([1, 2, 3], [1, 2, 3]) == 0
assert f(range(1, 5), [1, 2, 3, 5]) == 0.5


def test_multiclass_logloss():
assert_almost_equal(metric('logloss')([1], [1]), 0)
assert_almost_equal(metric('logloss')([1, 1], [1, 1]), 0)
assert_almost_equal(metric('logloss')([1], [0.5]), -np.log(0.5))
f = metric('logloss')
assert_almost_equal(f([1], [1]), 0)
assert_almost_equal(f([1, 1], [1, 1]), 0)
assert_almost_equal(f([1], [0.5]), -np.log(0.5))

0 comments on commit 0c50c21

Please sign in to comment.