Skip to content

Commit

Permalink
first yadll commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Jul 27, 2016
1 parent 6752335 commit 4d3891d
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pytest

import logging

class TestModel:
@pytest.fixture(scope='module')
Expand All @@ -12,6 +13,14 @@ def data(self):
[np.random.random((50, 25)), np.random.random_integers(low=0, high=9, size=(500,))]]
return Data(data)

@pytest.fixture(scope='module')
def data_y_2D(self):
from yadll.data import Data
data = [[np.random.random((100, 25)), np.random.random_integers(low=0, high=9, size=(100,2))],
[np.random.random((50, 25)), np.random.random_integers(low=0, high=9, size=(50,2))],
[np.random.random((50, 25)), np.random.random_integers(low=0, high=9, size=(500,2))]]
return Data(data)

@pytest.fixture(scope='module')
def hp(self):
from yadll.hyperparameters import Hyperparameters
Expand All @@ -37,6 +46,11 @@ def model_no_data(self, hp):
from yadll.model import Model
return Model(name='test_model', hyperparameters=hp)

@pytest.fixture(scope='module')
def model_y_2D(self, data_y_2D, hp):
from yadll.model import Model
return Model(name='test_model', data=data_y_2D, hyperparameters=hp)

@pytest.fixture(scope='module')
def input(self):
from yadll.layers import InputLayer
Expand Down Expand Up @@ -93,11 +107,15 @@ def test_no_network(self, model):
with pytest.raises(NoNetworkFoundException):
model.train(unsupervised_training=False)

def test_save_model(self, model, network):
def test_save_model(self, model, network, caplog):
model.network = network
from yadll.model import save_model, load_model
caplog.setLevel(logging.ERROR)
save_model(model)
assert 'No file name. Model not saved.' in caplog.text()
model.train(save_mode='end')
model.train(save_mode='each')
model.train(save_mode='dummy')
model.file=('test_model.ym')
model.train()
model.train(save_mode='end')
Expand All @@ -106,19 +124,21 @@ def test_save_model(self, model, network):
save_model(model, 'test_model.ym')
test_model = load_model('test_model.ym')

def test_model(self, model, network, network_unsupervised):
def test_model(self, model, model_y_2D, network, network_unsupervised):
model.network = network
assert model.name == 'test_model'
model.train()
network_unsupervised.layers[0].input = None
model.network = network_unsupervised
model.pretrain()
model.train()
model_y_2D.network = network
#model_y_2D.train()

def test_predict(self, data, model, network):
network.layers[0].input = None
model.network = network
model.train()
model.network.layers[0].input = None
model.predict(data.test_set_x.eval()[:10])


Expand Down

0 comments on commit 4d3891d

Please sign in to comment.