Skip to content

Commit

Permalink
first yadll commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Jul 27, 2016
1 parent 46d802d commit 6752335
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,39 @@ def unsupervised_layer(self, layer):
return AutoEncoder(incoming=layer, nb_units=25, hyperparameters=hp)

@pytest.fixture(scope='module')
def logistic_regression(self, unsupervised_layer):
def logistic_regression_unsupervised(self, unsupervised_layer):
from yadll.layers import LogisticRegression
return LogisticRegression(incoming=unsupervised_layer, nb_class=10)

@pytest.fixture(scope='module')
def logistic_regression(self, layer):
from yadll.layers import LogisticRegression
return LogisticRegression(incoming=layer, nb_class=10)

@pytest.fixture(scope='module')
def network(self, input, layer, logistic_regression):
from yadll.network import Network
return Network(name='test_network', layers=[input, layer, logistic_regression])

@pytest.fixture(scope='module')
def network_unsupervised(self, input,unsupervised_layer, logistic_regression):
def network_unsupervised(self, input, unsupervised_layer, logistic_regression_unsupervised):
from yadll.network import Network
return Network(name='test_network', layers=[input, unsupervised_layer, logistic_regression])
return Network(name='test_network', layers=[input, unsupervised_layer, logistic_regression_unsupervised])

def test_no_data_found(self, model_no_data, network):
model_no_data.network = network
def test_no_data_found(self, model_no_data, network_unsupervised):
model_no_data.network = network_unsupervised
from yadll.exceptions import NoDataFoundException
with pytest.raises(NoDataFoundException):
model_no_data.pretrain()
with pytest.raises(NoDataFoundException):
model_no_data.train()
model_no_data.train(unsupervised_training=False)

def test_no_network(self, model):
from yadll.exceptions import NoNetworkFoundException
with pytest.raises(NoNetworkFoundException):
model.pretrain()
with pytest.raises(NoNetworkFoundException):
model.train()
model.train(unsupervised_training=False)

def test_save_model(self, model, network):
model.network = network
Expand All @@ -105,10 +110,13 @@ def test_model(self, model, network, network_unsupervised):
model.network = network
assert model.name == 'test_model'
model.train()
network_unsupervised.layers[0].input = None
model.network = network_unsupervised
model.pretrain()
model.train()

def test_predict(self, data, model, network):
network.layers[0].input = None
model.network = network
model.train()
model.predict(data.test_set_x.eval()[:10])
Expand Down

0 comments on commit 6752335

Please sign in to comment.