From 298704c2ab4ed673b8384280d019935beb1bd42c Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 11 Feb 2022 14:58:10 +0000 Subject: [PATCH 1/2] Add constraint --- src/spaces/discrete_state_environment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spaces/discrete_state_environment.py b/src/spaces/discrete_state_environment.py index c2a7a59..4302803 100644 --- a/src/spaces/discrete_state_environment.py +++ b/src/spaces/discrete_state_environment.py @@ -45,6 +45,8 @@ class DiscreteStateEnvironment(object): to create bins where the average total distortion of the dataset falls in """ + IS_TILED_ENV_CONSTRAINT = False + def __init__(self, env_config: DiscreteEnvConfig) -> None: self.config = env_config self.n_rounds_below_min_distortion = 0 From 13a0ad8c0cd4484b05428c01b1e5eef7ac40c9ac Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 11 Feb 2022 14:59:51 +0000 Subject: [PATCH 2/2] #30 Add test_suite and test for Trainer class --- src/tests/test_suite.py | 18 ++++++++++++++++++ src/tests/test_trainer.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 src/tests/test_suite.py create mode 100644 src/tests/test_trainer.py diff --git a/src/tests/test_suite.py b/src/tests/test_suite.py new file mode 100644 index 0000000..7531d1f --- /dev/null +++ b/src/tests/test_suite.py @@ -0,0 +1,18 @@ +import unittest + +from .test_trainer import TestTrainer +from .test_serial_hierarchy import TestSerialHierarchy +from .test_preprocessor import TestPreprocessor +from .test_actions import TestActions + +def suite(): + suite = unittest.TestSuite() + suite.addTest(TestTrainer) + suite.addTest(TestSerialHierarchy) + suite.addTest(TestPreprocessor) + suite.addTest(TestActions) + return suite + +if __name__ == '__main__': + runner = unittest.TextTestRunner() + runner.run(suite()) \ No newline at end of file diff --git a/src/tests/test_trainer.py b/src/tests/test_trainer.py new file mode 100644 index 0000000..5d885d7 --- /dev/null +++ b/src/tests/test_trainer.py @@ -0,0 +1,29 @@ +""" +Unit-tests for class Trainer +""" +import unittest + +from src.algorithms.trainer import Trainer +from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn +from src.spaces.tiled_environment import TiledEnv + + +class TestTrainer(unittest.TestCase): + + def test_with_sarsa_semi_grad_agent(self): + + # create tiled environment + tiled_env = TiledEnv(env=None, num_tilings=10, max_size=4096, + tiling_dim=5) + + sarsa_config = SARSAnConfig() + agent = SARSAn(sarsa_config=sarsa_config) + + trainer = Trainer(agent=agent, env=tiled_env, + configuration={"n_episodes": 1}) + + trainer.train() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file