Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spaces/discrete_state_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class DiscreteStateEnvironment(object):
to create bins where the average total distortion of the dataset falls in
"""

IS_TILED_ENV_CONSTRAINT = False

def __init__(self, env_config: DiscreteEnvConfig) -> None:
self.config = env_config
self.n_rounds_below_min_distortion = 0
Expand Down
18 changes: 18 additions & 0 deletions src/tests/test_suite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest

from .test_trainer import TestTrainer
from .test_serial_hierarchy import TestSerialHierarchy
from .test_preprocessor import TestPreprocessor
from .test_actions import TestActions

def suite():
suite = unittest.TestSuite()
suite.addTest(TestTrainer)
suite.addTest(TestSerialHierarchy)
suite.addTest(TestPreprocessor)
suite.addTest(TestActions)
return suite

if __name__ == '__main__':
runner = unittest.TextTestRunner()
runner.run(suite())
29 changes: 29 additions & 0 deletions src/tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Unit-tests for class Trainer
"""
import unittest

from src.algorithms.trainer import Trainer
from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn
from src.spaces.tiled_environment import TiledEnv


class TestTrainer(unittest.TestCase):

def test_with_sarsa_semi_grad_agent(self):

# create tiled environment
tiled_env = TiledEnv(env=None, num_tilings=10, max_size=4096,
tiling_dim=5)

sarsa_config = SARSAnConfig()
agent = SARSAn(sarsa_config=sarsa_config)

trainer = Trainer(agent=agent, env=tiled_env,
configuration={"n_episodes": 1})

trainer.train()


if __name__ == '__main__':
unittest.main()