# test_basic_rl_graph_manager.py (from a fork of IntelLabs/coach)
import gc
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
from rl_coach.core_types import EnvironmentSteps
from rl_coach.utils import get_open_port
from multiprocessing import Process
from tensorflow import logging
import pytest
# Show INFO-level TensorFlow log messages while these tests run.
# (`tf.logging` is the same module the file imports as `logging`.)
tf.logging.set_verbosity(tf.logging.INFO)
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_pong_a3c():
    """Build the TensorFlow graph for the Atari A3C preset on Pong.

    Only graph construction is exercised. Training is not run: the
    ``improve()`` call below is left commented out.
    """
    # Clear TensorFlow's default graph before building a new one.
    tf.reset_default_graph()

    from rl_coach.presets.Atari_A3C import graph_manager as a3c_manager
    assert a3c_manager

    a3c_manager.env_params.level = "PongDeterministic-v4"
    task_params = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
    )
    a3c_manager.create_graph(task_parameters=task_params)
    # a3c_manager.improve()
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_pong_nec():
    """Build the TensorFlow graph for the Atari NEC preset on Pong.

    Only graph construction is exercised. Training is not run: the
    ``improve()`` call below is left commented out.
    """
    # Clear TensorFlow's default graph before building a new one.
    tf.reset_default_graph()

    from rl_coach.presets.Atari_NEC import graph_manager as nec_manager
    assert nec_manager

    nec_manager.env_params.level = "PongDeterministic-v4"
    task_params = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
    )
    nec_manager.create_graph(task_parameters=task_params)
    # nec_manager.improve()
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn():
    """Build the TensorFlow graph for the CartPole DQN preset.

    Only graph construction is exercised. Training is not run: the
    ``improve()`` call below is left commented out.
    """
    # Clear TensorFlow's default graph before building a new one.
    tf.reset_default_graph()

    # NOTE(review): Python caches imported modules, so every test that
    # imports this preset receives the same `graph_manager` object.
    from rl_coach.presets.CartPole_DQN import graph_manager as dqn_manager
    assert dqn_manager

    task_params = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
    )
    dqn_manager.create_graph(task_parameters=task_params)
    # dqn_manager.improve()
# Scaffold for identifying a memory leak in restore_checkpoint.
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
    """Build the CartPole DQN graph with ``apply_stop_condition=True``.

    Despite the name, no checkpoint is saved or restored at the moment.
    The repeated-restore leak reproduction is kept below as commented-out
    code. As written it loops forever (``while True``), so it is only
    suitable for manual runs.
    """
    # Clear TensorFlow's default graph before building a new one.
    tf.reset_default_graph()

    # NOTE(review): Python caches imported modules, so this is the same
    # `graph_manager` object other CartPole DQN tests in this file use.
    from rl_coach.presets.CartPole_DQN import graph_manager as dqn_manager
    assert dqn_manager

    task_params = TaskParameters(
        framework_type=Frameworks.tensorflow,
        experiment_path="./experiments/test",
        apply_stop_condition=True,
    )
    dqn_manager.create_graph(task_parameters=task_params)

    # Manual leak reproduction (disabled; the loop never terminates):
    # dqn_manager.improve()
    # dqn_manager.evaluate(EnvironmentSteps(1000))
    # dqn_manager.save_checkpoint()
    #
    # dqn_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
    # while True:
    #     dqn_manager.restore_checkpoint()
    #     dqn_manager.evaluate(EnvironmentSteps(1000))
    #     gc.collect()
if __name__ == "__main__":
    # Executing this file directly does nothing; run it through pytest
    # instead (e.g. `pytest -m unit_test <this file>`). To debug a single
    # scenario by hand, uncomment one of these calls:
    # test_basic_rl_graph_manager_with_pong_a3c()
    # test_basic_rl_graph_manager_with_pong_nec()
    # test_basic_rl_graph_manager_with_cartpole_dqn()
    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
    #
    # NOTE(review): these were listed here previously, but no function with
    # these names is defined in this file, so uncommenting them would raise
    # NameError:
    # test_basic_rl_graph_manager_with_ant_a3c()
    # test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
    # test_basic_rl_graph_manager_with_doom_basic_dqn()
    pass