diff --git a/src/spaces/discrete_state_environment.py b/src/spaces/discrete_state_environment.py index c2a7a59..4302803 100644 --- a/src/spaces/discrete_state_environment.py +++ b/src/spaces/discrete_state_environment.py @@ -45,6 +45,8 @@ class DiscreteStateEnvironment(object): to create bins where the average total distortion of the dataset falls in """ + IS_TILED_ENV_CONSTRAINT = False + def __init__(self, env_config: DiscreteEnvConfig) -> None: self.config = env_config self.n_rounds_below_min_distortion = 0 diff --git a/src/tests/test_suite.py b/src/tests/test_suite.py new file mode 100644 index 0000000..7531d1f --- /dev/null +++ b/src/tests/test_suite.py @@ -0,0 +1,18 @@ +import unittest + +from .test_trainer import TestTrainer +from .test_serial_hierarchy import TestSerialHierarchy +from .test_preprocessor import TestPreprocessor +from .test_actions import TestActions + +def suite(): + suite = unittest.TestSuite() + suite.addTest(TestTrainer) + suite.addTest(TestSerialHierarchy) + suite.addTest(TestPreprocessor) + suite.addTest(TestActions) + return suite + +if __name__ == '__main__': + runner = unittest.TextTestRunner() + runner.run(suite()) \ No newline at end of file diff --git a/src/tests/test_trainer.py b/src/tests/test_trainer.py new file mode 100644 index 0000000..5d885d7 --- /dev/null +++ b/src/tests/test_trainer.py @@ -0,0 +1,29 @@ +""" +Unit-tests for class Trainer +""" +import unittest + +from src.algorithms.trainer import Trainer +from src.algorithms.sarsa_semi_gradient import SARSAnConfig, SARSAn +from src.spaces.tiled_environment import TiledEnv + + +class TestTrainer(unittest.TestCase): + + def test_with_sarsa_semi_grad_agent(self): + + # create tiled environment + tiled_env = TiledEnv(env=None, num_tilings=10, max_size=4096, + tiling_dim=5) + + sarsa_config = SARSAnConfig() + agent = SARSAn(sarsa_config=sarsa_config) + + trainer = Trainer(agent=agent, env=tiled_env, + configuration={"n_episodes": 1}) + + trainer.train() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file