Skip to content

Commit

Permalink
Compile
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Feb 23, 2017
1 parent e5a5b45 commit 9ad0a08
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 4 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def test_no_data_found(self, model_no_data, network_unsupervised):

def test_no_network(self, model):
from yadll.exceptions import NoNetworkFoundException
with pytest.raises(NoNetworkFoundException):
model.compile(compile_arg='train')
with pytest.raises(NoNetworkFoundException):
model.pretrain()
with pytest.raises(NoNetworkFoundException):
Expand Down Expand Up @@ -130,6 +132,8 @@ def test_model(self, model, network, network_unsupervised):
model.network = network_unsupervised
model.pretrain()
model.train()
model.hp.patience = 10 # test early stop
model.train()

def test_predict(self, data, model, network):
model.network = network
Expand Down
20 changes: 18 additions & 2 deletions tests/test_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,27 @@ def test_mean_absolute_error():

def test_binary_hinge_error():
x_val = np.asarray(np.random.uniform(size=(10, 1)), dtype=yadll.utils.floatX)
y_val = np.asarray(np.random.binomial(n=1,p=0.5, size=(10, 1)), dtype=yadll.utils.floatX)
y_val = np.asarray(np.random.binomial(n=1, p=0.5, size=(10, 1)), dtype=yadll.utils.floatX)
y_val = 2 * y_val - 1
f = theano.function([x, y], yadll.objectives.binary_hinge_error(x, y))
actual = f(x_val, y_val)
desired = np.maximum(1. - x_val * y_val, 0.).flatten()
assert_allclose(actual, desired, rtol=eps)


def test_categorical_hinge_error():
x_val = np.asarray(np.random.uniform(size=(10, 1)), dtype=yadll.utils.floatX)
y_val = np.asarray(np.random.binomial(n=1, p=0.5, size=(10, 1)), dtype=yadll.utils.floatX)
y_val = 2 * y_val - 1
f = theano.function([x, y], yadll.objectives.categorical_hinge_error(x, y))
actual = f(x_val, y_val)
desired = np.maximum(1. - x_val * y_val, 0.).flatten()
assert_allclose(actual, desired, rtol=eps)


def test_binary_crossentropy_error():
x_val = np.asarray(np.random.uniform(size=(10, 1)), dtype=yadll.utils.floatX)
y_val = np.asarray(np.random.binomial(n=1,p=0.5, size=(10, 1)), dtype=yadll.utils.floatX)
y_val = np.asarray(np.random.binomial(n=1, p=0.5, size=(10, 1)), dtype=yadll.utils.floatX)
f = theano.function([x, y], yadll.objectives.binary_crossentropy_error(x, y))
actual = f(x_val, y_val)
desired = np.mean(-(y_val * np.log(x_val) + (1 - y_val) * np.log(1 - x_val)), axis=-1)
Expand All @@ -63,3 +73,9 @@ def test_categorical_crossentropy_error():
desired = np.mean(-np.sum(y_val * np.log(x_val), axis=-1))
assert_allclose(actual, desired, rtol=eps)


def test_kullback_leibler_divergence():
f = theano.function([x, y], yadll.objectives.kullback_leibler_divergence(x, y))
actual = f(x_val, y_val)
desired = np.sum(y_val * np.log(y_val/x_val), axis=-1)
assert_allclose(actual, desired, rtol=eps)

0 comments on commit 9ad0a08

Please sign in to comment.