forked from microsoft/CyberBattleSim
/
notebook_ctf_dql.py
116 lines (95 loc) · 2.98 KB
/
notebook_ctf_dql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Deep Q-learning (DQL) agent (notebook)
This notebook can be run directly from VSCode; to generate a
traditional Jupyter Notebook to open in your browser
you can run the VSCode command `Export Current Python File As Jupyter Notebook`.
"""
# pylint: disable=invalid-name
# %%
import sys
import logging
import gym
import cyberbattle.agents.baseline.learner as learner
import cyberbattle.agents.baseline.plotting as p
import cyberbattle.agents.baseline.agent_wrapper as w
import cyberbattle.agents.baseline.agent_dql as dqla
from cyberbattle.agents.baseline.agent_wrapper import Verbosity
import importlib

# Route log records to stdout and silence anything below ERROR.
logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format="%(levelname)s: %(message)s")

# Re-import the learner module so edits made to it during an
# interactive session are picked up without restarting the kernel.
importlib.reload(learner)
# %%
# Instantiate the toy capture-the-flag environment and derive the
# space bounds that the baseline agents are parameterized by.
ctf_env = gym.make('CyberBattleToyCtf-v0')
ep = w.EnvironmentBounds.of_identifiers(
    identifiers=ctf_env.identifiers,
    maximum_node_count=22,
    maximum_total_credentials=22,
)

# Shared run-length settings for every training/evaluation run below.
iteration_count = 2000
training_episode_count = 10
eval_episode_count = 10
# %%
# Train a Deep Q-learning agent under an epsilon-greedy exploration
# schedule (epsilon decays exponentially from 0.90 down to 0.10).
dql_policy = dqla.DeepQLearnerPolicy(
    ep=ep,
    gamma=0.015,  # discount factor (deliberately small in this benchmark)
    replay_memory_size=10000,
    target_update=10,
    batch_size=512,
    learning_rate=0.01,  # torch default is 1e-2
)

best_dqn_learning_run_10 = learner.epsilon_greedy_search(
    cyberbattle_gym_env=ctf_env,
    environment_properties=ep,
    learner=dql_policy,
    title="DQL",
    episode_count=training_episode_count,
    iteration_count=iteration_count,
    epsilon=0.90,
    # epsilon_multdecay=0.75,  # 0.999,
    epsilon_exponential_decay=5000,  # 10000
    epsilon_minimum=0.10,
    render=False,
    verbosity=Verbosity.Quiet,
)
# %% Plot episode length
# Visualize how long each training episode ran.
training_runs = [best_dqn_learning_run_10]
p.plot_episodes_length(training_runs)
# %%
# Evaluate the trained policy greedily: epsilon=0.0 means the agent
# always exploits its learned Q-values and never explores.
dql_exploit_run = learner.epsilon_greedy_search(
    ctf_env,
    ep,
    learner=best_dqn_learning_run_10['learner'],
    epsilon=0.0,  # 0.35,
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Exploiting DQL",
)
# %%
# Baseline: a uniformly random policy (epsilon=1.0 forces pure exploration).
random_run = learner.epsilon_greedy_search(
    ctf_env,
    ep,
    learner=learner.RandomPolicy(),
    epsilon=1.0,  # purely random
    episode_count=eval_episode_count,
    iteration_count=iteration_count,
    render=False,
    verbosity=Verbosity.Quiet,
    title="Random search",
)
# %%
# Plot averaged cumulative rewards for DQL vs Random vs DQL-Exploit.
# (The function name `plot_averaged_cummulative_rewards` is the
# project's public API spelling and must not be "corrected" here.)
themodel = dqla.CyberBattleStateActionModel(ep)
p.plot_averaged_cummulative_rewards(
    all_runs=[
        best_dqn_learning_run_10,
        random_run,
        dql_exploit_run
    ],
    title=f'Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\n'
          f'State: {[f.name() for f in themodel.state_space.feature_selection]} '
          # Fixed: the feature count was missing its closing parenthesis,
          # producing an unbalanced "(N" in the rendered plot title.
          f'({len(themodel.state_space.feature_selection)})\n'
          f"Action: abstract_action ({themodel.action_space.flat_size()})")
# %%
# Cumulative reward trace for every episode of the training run.
p.plot_all_episodes(best_dqn_learning_run_10)