From d9d3103a78920fd6df7cd17099650f3182757db1 Mon Sep 17 00:00:00 2001
From: cnheider
Date: Mon, 14 Sep 2020 23:21:49 +0200
Subject: [PATCH] tick

---
 .github/ISSUE_TEMPLATE.md                      |  33 +
 .github/utilities/collect_env.py               | 243 +++
 MANIFEST.in                                    |  17 +
 __init__.py                                    |   8 +
 docs/source/conf.py                            | 266 +-------
 neodroidagent/__init__.py                      |   6 +-
 .../model_free/off_policy/dqn_agent.py         |   7 +-
 .../model_free/off_policy/sac_agent.py         | 599 +++++++++---------
 .../model_free/on_policy/ddpg_agent.py         |   3 +-
 .../model_free/on_policy/ppo_agent.py          | 585 +++++++++--------
 .../agents/torch_agents/torch_agent.py         |  12 +-
 neodroidagent/common/architectures/README.md   |   3 +-
 .../common/architectures/architecture.py       |  64 +-
 .../architectures/experimental/recurrent.py    |   6 +-
 .../mlp_variants/concatination.py              |   7 +-
 .../expandable_circular_buffer.py              |   2 +-
 .../common/memory/rollout_storage.py           |  10 +-
 .../common/session_factory/vertical/linear.py  |   1 +
 .../session_factory/vertical/parallel.py       |   2 +-
 .../procedures/procedure_specification.py      |  13 +-
 .../procedures/training/off_policy_batched.py  |  20 +-
 .../training/off_policy_episodic.py            |   8 +-
 .../training/off_policy_step_wise.py           |  11 +-
 .../procedures/training/on_policy_episodic.py  | 186 +++---
 .../single_agent_environment_session.py        | 299 +++++----
 .../entry_points/agent_tests/random_test.py    |   5 +-
 .../torch_agent_tests/ddpg_test.py             |  28 +-
 .../agent_tests/torch_agent_tests/dqn_test.py  |  16 +-
 .../agent_tests/torch_agent_tests/pg_test.py   |  14 +-
 .../agent_tests/torch_agent_tests/ppo_test.py  |  12 +-
 .../agent_tests/torch_agent_tests/sac_test.py  |  14 +-
 neodroidagent/entry_points/clean.py            |  70 +-
 neodroidagent/entry_points/cli.py              |  77 +--
 neodroidagent/entry_points/session_factory.py  |  72 ++-
 .../torch_isp/curiosity/icm.py                 |   2 +-
 neodroidagent/utilities/misc/__init__.py       |   1 -
 neodroidagent/utilities/misc/bool_tests.py     | 121 ----
 setup.py                                       |  17 +-
 38 files changed, 1431 insertions(+), 1429 deletions(-)
 create mode 100644 .github/ISSUE_TEMPLATE.md
 create mode 100644 .github/utilities/collect_env.py
 create mode 100644 MANIFEST.in
 create mode 100644 __init__.py
 delete mode 100644 neodroidagent/utilities/misc/bool_tests.py

diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 00000000..cf69c972
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,33 @@
+If you are submitting a feature request, please preface the title with [feature request].
+If you are submitting a bug report, please fill in the following details.
+
+## Issue description
+
+Provide a short description.
+
+## Code example
+
+Please try to provide a minimal example to reproduce the bug.
+Error messages and stack traces are also helpful.
+
+## System Info
+Please copy and paste the output from our
+[environment collection script](https://github.com/sintefneodroid/agent/tree/master/.github/utilities/collect_env.py)
+(or fill out the checklist below manually).
+
+You can get the script and run it with:
+```
+wget https://github.com/sintefneodroid/neo/tree/master/.github/utilities/collect_env.py
+# For security purposes, please check the contents of collect_env.py before running it.
+python collect_env.py
+```
+
+- How you installed NeodroidAgent (conda, pip, source):
+- Build command you used (if compiling from source):
+- OS:
+- NeodroidAgent version:
+- Python version:
+- GPU models and configuration:
+- GCC version (if compiling from source):
+- CMake version:
+- Versions of any other relevant libraries:
diff --git a/.github/utilities/collect_env.py b/.github/utilities/collect_env.py
new file mode 100644
index 00000000..668e1d18
--- /dev/null
+++ b/.github/utilities/collect_env.py
@@ -0,0 +1,243 @@
+# Borrowed from the PyTorch repo
+# This script outputs relevant system environment info
+# Run it with `python collect_env.py`.
+import re
+import subprocess
+import sys
+from collections import namedtuple
+
+import neodroidagent
+from setup import NeodroidAgentPackage
+
+PY3 = sys.version_info >= (3, 0)
+
+# System Environment Information
+SystemEnv = namedtuple(
+    "SystemEnv",
+    [
+        "neo_version",
+        "is_a_development_build",
+        "os",
+        "python_version",
+        "pip_version",  # 'pip' or 'pip3'
+        "pip_packages",
+    ],
+)
+
+
+def run_cmd(command):
+    """Returns (return-code, stdout, stderr)"""
+    p = subprocess.Popen(
+        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
+    )
+    output, err = p.communicate()
+    rc = p.returncode
+    if PY3:
+        output = output.decode("ascii")
+        err = err.decode("ascii")
+    return rc, output.strip(), err.strip()
+
+
+def run_and_read_all(run_lambda, command):
+    """Runs command using run_lambda; reads and returns entire output if rc is 0"""
+    rc, out, _ = run_lambda(command)
+    if rc != 0:
+        return None
+    return out
+
+
+def run_and_parse_first_match(run_lambda, command, regex):
+    """Runs command using run_lambda, returns the first regex match if it exists"""
+    rc, out, _ = run_lambda(command)
+    if rc != 0:
+        return None
+    match = re.search(regex, out)
+    if match is None:
+        return None
+    return match.group(1)
+
+
+def get_platform():
+    if sys.platform.startswith("linux"):
+        return "linux"
+    elif sys.platform.startswith("win32"):
+        return "win32"
+    elif sys.platform.startswith("cygwin"):
+        return "cygwin"
+    elif sys.platform.startswith("darwin"):
+        return "darwin"
+    else:
+        return sys.platform
+
+
+def get_mac_version(run_lambda):
+    return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
+
+
+def get_windows_version(run_lambda):
+    return run_and_read_all(run_lambda, "wmic os get Caption | findstr /v Caption")
+
+
+def get_lsb_version(run_lambda):
+    return run_and_parse_first_match(
+        run_lambda, "lsb_release -a", r"Description:\t(.*)"
+    )
+
+
+def check_release_file(run_lambda):
+    return run_and_parse_first_match(
+        run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
+    )
+
+
+def get_os(run_lambda):
+    platform = get_platform()
+
+    if platform == "win32" or platform == "cygwin":
+        return get_windows_version(run_lambda)
+
+    if platform == "darwin":
+        version = get_mac_version(run_lambda)
+        if version is None:
+            return None
+        return f"Mac OSX {version}"
+
+    if platform == "linux":
+        # Ubuntu/Debian based
+        desc = get_lsb_version(run_lambda)
+        if desc is not None:
+            return desc
+
+        # Try reading /etc/*-release
+        desc = check_release_file(run_lambda)
+        if desc is not None:
+            return desc
+
+        return platform
+
+    # Unknown platform
+    return platform
+
+
+def req_grep_fmt():
+    r = "\|".join(
+        [
+            f'{req.split(">")[0].split("=")[0]}'
+            for req in (
+                NeodroidAgentPackage.extras["all"] + NeodroidAgentPackage.requirements
+            )
+        ]
+    )
+    return r
+
+
+def get_pip_packages(run_lambda):
+    # People generally have `pip` as `pip` or `pip3`
+    def run_with_pip(pip):
+        return run_and_read_all(
+            run_lambda,
+            pip + f' list --format=legacy | grep "Neodroid\|{req_grep_fmt()}"',
+        )
+
+    if not PY3:
+        return "pip", run_with_pip("pip")
+
+    # Try to figure out if the user is running pip or pip3.
+    out2 = run_with_pip("pip")
+    out3 = run_with_pip("pip3")
+
+    number_of_pips = len([x for x in [out2, out3] if x is not None])
+    if number_of_pips == 0:
+        return "pip", out2
+
+    if number_of_pips == 1:
+        if out2 is not None:
+            return "pip", out2
+        return "pip3", out3
+
+    # num_pips is 2. Return pip3 by default b/c that most likely
+    # is the one associated with Python 3
+    return "pip3", out3
+
+
+def get_env_info():
+    run_lambda = run_cmd
+    pip_version, pip_list_output = get_pip_packages(run_lambda)
+
+    return SystemEnv(
+        neo_version=neodroidagent.__version__,
+        is_a_development_build=neodroidagent.IS_DEVELOP,
+        python_version=f"{sys.version_info[0]}.{sys.version_info[1]}",
+        pip_version=pip_version,
+        pip_packages=pip_list_output,
+        os=get_os(run_lambda),
+    )
+
+
+def pretty_str(env_info):
+    def replace_all_none_objects(dct, replacement="Could not collect"):
+        for key in dct.keys():
+            if dct[key] is not None:
+                continue
+            dct[key] = replacement
+        return dct
+
+    def replace_bools(dct, true="Yes", false="No"):
+        for key in dct.keys():
+            if dct[key] is True:
+                dct[key] = true
+            elif dct[key] is False:
+                dct[key] = false
+        return dct
+
+    def prepend(text, tag="[prepend]"):
+        lines = text.split("\n")
+        updated_lines = [tag + line for line in lines]
+        return "\n".join(updated_lines)
+
+    def replace_if_empty(text, replacement="No relevant packages"):
+        if text is not None and len(text) == 0:
+            return replacement
+        return text
+
+    mutable_dict = env_info._asdict()
+
+    mutable_dict = replace_bools(mutable_dict)  # Replace True with Yes, False with No
+
+    mutable_dict = replace_all_none_objects(
+        mutable_dict
+    )  # Replace all None objects with 'Could not collect'
+
+    mutable_dict["pip_packages"] = replace_if_empty(
+        mutable_dict["pip_packages"]
+    )  # If either of these are '', replace with 'No relevant packages'
+
+    if mutable_dict["pip_packages"]:
+        mutable_dict["pip_packages"] = prepend(
+            mutable_dict["pip_packages"], f"[{env_info.pip_version}] "
+        )
+    return r"""
+Neo version: {neo_version}
+Is a development build: {is_a_development_build}
+OS: {os}
+Python version: {python_version}
+Versions of relevant libraries:
+{pip_packages}
+""".format(
+        **mutable_dict
+    ).strip()
+
+
+def get_pretty_env_info():
+    return pretty_str(get_env_info())
+
+
+def main():
+    print(get_pip_packages(run_cmd))
+    print("Collecting environment information...")
+    output = get_pretty_env_info()
+
print(output) + + +if __name__ == "__main__": + main() diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..b7c01653 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,17 @@ +include requirements*.txt # Include requirements + +#include pyproject.toml + +# Include the EMDS +#include *.md +#recursive-include . *.md +global-include *.md + + + + +# Include the license file +#include LICENSE.txt + +# Include the data files +#recursive-include data * \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..85b40c42 --- /dev/null +++ b/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +__author__ = 'Christian Heider Nielsen' +__doc__ = r''' + + Created on 02-09-2020 + ''' diff --git a/docs/source/conf.py b/docs/source/conf.py index 5a319442..975f32b9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,263 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# Neo documentation build configuration file, created by -# sphinx-quickstart on Tue Jul 25 10:23:12 2017. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys - -sys.path.insert(0, os.path.abspath(".")) - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. - -extensions = [ - "m2r", - # 'recommonmark', - "sphinxcontrib.programoutput", - "sphinx.ext.autodoc", - "sphinx.ext.autosummary", - "sphinx.ext.napoleon", - "sphinx.ext.doctest", - "sphinx.ext.intersphinx", - "sphinx.ext.todo", - "sphinx.ext.coverage", - "sphinx.ext.mathjax", - "sphinx.ext.viewcode", - "sphinx.ext.githubpages", - "sphinx.ext.graphviz", -] - -napoleon_use_ivar = True - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -# source_suffix = '.rst' -source_suffix = {".rst": "restructuredtext", ".txt": "markdown", ".md": "markdown"} - -# source_parsers = { -# '.md': CommonMarkParser, -# } - -# The master toctree document. -master_doc = "index" - -# General information about the project. -project = "Agent" -author = "Christian Heider Nielsen" -copyright = f"2017, {author}" - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# version = 'master (' + neodroid.__version__ + ' )' -release = "master" - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set 'language' from the command line for these cases. 
-language = None - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = [] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "alabaster" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named 'default.css' will overwrite the builtin 'default.css'. -html_static_path = ["_static"] - -html_baseurl = "agent.neodroid.ml" - -# -- Options for HTMLHelp output ------------------------------------------ - -# Output file base name for HTML help builder. -htmlhelp_basename = "Agentdoc" - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - ( - master_doc, - "Agent.tex", - "Agent Documentation", - "Christian Heider Nielsen", - "manual", - ) -] - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [(master_doc, "agent", "Agent Documentation", [author], 1)] - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - master_doc, - "agent", - "Agent Documentation", - author, - "Agent", - "One line description of project.", - "Miscellaneous", - ) -] - -# -- Options for Epub output ---------------------------------------------- - -# Bibliographic Dublin Core info. -epub_title = project -epub_author = author -epub_publisher = author -epub_copyright = copyright - -# The unique identifier of the text. This can be a ISBN number -# or the project homepage. -# -# epub_identifier = '' - -# A unique identification for the text. -# -# epub_uid = '' - -# A list of files that should not be packed into the epub file. -epub_exclude_files = ["search.html"] - -# Example configuration for intersphinx: refer to the Python standard library. 
-intersphinx_mapping = { - "python": ("https://docs.python.org/", None), - "numpy": ("http://docs.scipy.org/doc/numpy/", None), -} - -# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- -# See http://stackoverflow.com/a/41184353/3343043 - -from docutils import nodes -from sphinx import addnodes -from sphinx.util.docfields import TypedField - - -def patched_make_field(self, types, domain, items, **kw): - # `kw` catches `env=None` needed for newer sphinx while maintaining - # backwards compatibility when passed along further down! - - # type: (List, unicode, Tuple) -> nodes.field - def handle_item(fieldarg, content): - par = nodes.paragraph() - par += addnodes.literal_strong("", fieldarg) # Patch: this line added - # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, - # addnodes.literal_strong)) - if fieldarg in types: - par += nodes.Text(" (") - # NOTE: using .pop() here to prevent a single type node to be - # inserted twice into the doctree, which leads to - # inconsistencies later when references are resolved - fieldtype = types.pop(fieldarg) - if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = "".join(n.astext() for n in fieldtype) - typename = typename.replace("int", "python:int") - typename = typename.replace("long", "python:long") - typename = typename.replace("float", "python:float") - typename = typename.replace("type", "python:type") - par.extend( - self.make_xrefs( - self.typerolename, - domain, - typename, - addnodes.literal_emphasis, - **kw, - ) - ) - else: - par += fieldtype - par += nodes.Text(")") - par += nodes.Text(" -- ") - par += content - return par - - field_name = nodes.field_name("", self.label) - if len(items) == 1 and self.can_collapse: - field_arg, content = items[0] - body_node = handle_item(field_arg, content) - else: - body_node = self.list_type() - for field_arg, content in items: - body_node += nodes.list_item("", handle_item(field_arg, content)) - field_body = nodes.field_body("", body_node) - return nodes.field("", field_name, field_body) - - -TypedField.make_field = patched_make_field +version https://git-lfs.github.com/spec/v1 +oid sha256:5d0dde4f86d40844aaf2c0d64da35cdf2318ca7e211994419d0e5f10a2b61a39 +size 8411 diff --git a/neodroidagent/__init__.py b/neodroidagent/__init__.py index 1a480590..3ad532fa 100644 --- a/neodroidagent/__init__.py +++ b/neodroidagent/__init__.py @@ -43,12 +43,12 @@ def dist_is_editable(dist: Any) -> bool: distributions = {v.key: v for v in pkg_resources.working_set} if PROJECT_NAME in distributions: distribution = distributions[PROJECT_NAME] - DEVELOP = dist_is_editable(distribution) + IS_DEVELOP = dist_is_editable(distribution) else: - DEVELOP = True + IS_DEVELOP = True -def get_version(append_time: Any = DEVELOP) -> str: +def get_version(append_time: Any = IS_DEVELOP) -> str: version = __version__ if not version: version = os.getenv("VERSION", "0.0.0") diff --git a/neodroidagent/agents/torch_agents/model_free/off_policy/dqn_agent.py b/neodroidagent/agents/torch_agents/model_free/off_policy/dqn_agent.py index 54357f12..d7008107 100644 --- a/neodroidagent/agents/torch_agents/model_free/off_policy/dqn_agent.py +++ b/neodroidagent/agents/torch_agents/model_free/off_policy/dqn_agent.py @@ -11,7 +11,7 @@ from torch.nn.functional import smooth_l1_loss from torch.optim import Optimizer -from draugr.torch_utilities import to_tensor +from draugr.torch_utilities import to_scalar, to_tensor from draugr.writers import MockWriter, Writer from neodroid.utilities import ActionSpace, 
ObservationSpace, SignalSpace from neodroidagent.agents.torch_agents.torch_agent import TorchAgent @@ -25,10 +25,9 @@ from neodroidagent.utilities import ( ActionSpaceNotSupported, ExplorationSpecification, - is_zero_or_mod_zero, update_target, ) -from warg import GDKC, drop_unused_kws, super_init_pass_on_kws +from warg import GDKC, drop_unused_kws, super_init_pass_on_kws,is_zero_or_mod_zero __author__ = "Christian Heider Nielsen" __doc__ = r""" @@ -323,7 +322,7 @@ def _update(self, *, metric_writer: Writer = MockWriter()) -> None: self.post_process_gradients(self.value_model.parameters()) self._optimiser.step() - loss_ = loss.detach().cpu().item() + loss_ = to_scalar(loss) if metric_writer: metric_writer.scalar("td_error", td_error.mean(), self.update_i) metric_writer.scalar("loss", loss_, self.update_i) diff --git a/neodroidagent/agents/torch_agents/model_free/off_policy/sac_agent.py b/neodroidagent/agents/torch_agents/model_free/off_policy/sac_agent.py index 853dfbad..e2a0d0e9 100644 --- a/neodroidagent/agents/torch_agents/model_free/off_policy/sac_agent.py +++ b/neodroidagent/agents/torch_agents/model_free/off_policy/sac_agent.py @@ -13,26 +13,25 @@ from tqdm import tqdm from typing import Any, Dict, Sequence, Tuple -from draugr.torch_utilities import freeze_model, frozen_parameters, to_tensor +from draugr.torch_utilities import freeze_model, frozen_parameters, to_scalar, to_tensor from draugr.writers import MockWriter, Writer from neodroid.utilities import ActionSpace, ObservationSpace, SignalSpace from neodroidagent.agents.torch_agents.torch_agent import TorchAgent from neodroidagent.common import ( - Architecture, - PreConcatInputMLP, - Memory, - SamplePoint, - ShallowStdNormalMLP, - TransitionPoint, - TransitionPointBuffer, -) + Architecture, + PreConcatInputMLP, + Memory, + SamplePoint, + ShallowStdNormalMLP, + TransitionPoint, + TransitionPointBuffer, + ) from neodroidagent.utilities import ( - ActionSpaceNotSupported, - is_zero_or_mod_zero, - update_target, -) + ActionSpaceNotSupported, + update_target, + ) from neodroidagent.utilities.misc.sampling import normal_tanh_reparameterised_sample -from warg import GDKC, drop_unused_kws, super_init_pass_on_kws +from warg import GDKC, drop_unused_kws, super_init_pass_on_kws,is_zero_or_mod_zero __author__ = "Christian Heider Nielsen" __doc__ = r""" @@ -44,37 +43,37 @@ @super_init_pass_on_kws class SoftActorCriticAgent(TorchAgent): - """ + """ Soft Actor Critic Agent https://arxiv.org/pdf/1801.01290.pdf https://arxiv.org/pdf/1812.05905.pdf """ - def __init__( - self, - *, - copy_percentage: float = 1e-2, - batch_size: int = 100, - discount_factor: float = 0.95, - target_update_interval: int = 1, - num_inner_updates: int = 20, - sac_alpha: float = 1e-2, - memory_buffer: Memory = TransitionPointBuffer(1000000), - auto_tune_sac_alpha: bool = False, - auto_tune_sac_alpha_optimiser_spec: GDKC = GDKC( - constructor=torch.optim.Adam, lr=3e-4 - ), - actor_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), - critic_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), - actor_arch_spec: GDKC = GDKC( - ShallowStdNormalMLP, mean_head_activation=torch.tanh - ), - critic_arch_spec: GDKC = GDKC(PreConcatInputMLP), - critic_criterion: callable = mse_loss, - **kwargs - ): - """ + def __init__( + self, + *, + copy_percentage: float = 1e-2, + batch_size: int = 100, + discount_factor: float = 0.999, + target_update_interval: int = 1, + num_inner_updates: int = 20, + sac_alpha: float = 1e-2, + memory_buffer: Memory = 
TransitionPointBuffer(1000000), + auto_tune_sac_alpha: bool = False, + auto_tune_sac_alpha_optimiser_spec: GDKC = GDKC( + constructor=torch.optim.Adam, lr=3e-4 + ), + actor_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), + critic_optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), + actor_arch_spec: GDKC = GDKC( + ShallowStdNormalMLP, mean_head_activation=torch.tanh + ), + critic_arch_spec: GDKC = GDKC(PreConcatInputMLP), + critic_criterion: callable = mse_loss, + **kwargs + ): + """ :param copy_percentage: :param signal_clipping: @@ -87,40 +86,40 @@ def __init__( :param random_process_spec: :param kwargs: """ - super().__init__(**kwargs) - - assert 0 <= discount_factor <= 1.0 - assert 0 <= copy_percentage <= 1.0 - - self._batch_size = batch_size - self._discount_factor = discount_factor - self._target_update_interval = target_update_interval - self._sac_alpha = sac_alpha - self._copy_percentage = copy_percentage - self._memory_buffer = memory_buffer - self._actor_optimiser_spec: GDKC = actor_optimiser_spec - self._critic_optimiser_spec: GDKC = critic_optimiser_spec - self._actor_arch_spec = actor_arch_spec - self._critic_arch_spec = critic_arch_spec - - self._num_inner_updates = num_inner_updates - self._critic_criterion = critic_criterion - - self._auto_tune_sac_alpha = auto_tune_sac_alpha - self._auto_tune_sac_alpha_optimiser_spec = auto_tune_sac_alpha_optimiser_spec - self.inner_update_i = 0 - - @drop_unused_kws - def _remember( - self, - *, - signal: Any, - terminated: Any, - state: Any, - successor_state: Any, - sample: Any - ) -> None: - """ + super().__init__(**kwargs) + + assert 0 <= discount_factor <= 1.0 + assert 0 <= copy_percentage <= 1.0 + + self._batch_size = batch_size + self._discount_factor = discount_factor + self._target_update_interval = target_update_interval + self._sac_alpha = sac_alpha + self._copy_percentage = copy_percentage + self._memory_buffer = memory_buffer + self._actor_optimiser_spec: GDKC = actor_optimiser_spec + self._critic_optimiser_spec: GDKC = critic_optimiser_spec + self._actor_arch_spec = actor_arch_spec + self._critic_arch_spec = critic_arch_spec + + self._num_inner_updates = num_inner_updates + self._critic_criterion = critic_criterion + + self._auto_tune_sac_alpha = auto_tune_sac_alpha + self._auto_tune_sac_alpha_optimiser_spec = auto_tune_sac_alpha_optimiser_spec + self.inner_update_i = 0 + + @drop_unused_kws + def _remember( + self, + *, + signal: Any, + terminated: Any, + state: Any, + successor_state: Any, + sample: Any + ) -> None: + """ @param signal: @param terminated: @@ -130,41 +129,41 @@ def _remember( @param kwargs: @return: """ - a = [ - TransitionPoint(*s) - for s in zip(state, sample[0], successor_state, signal, terminated) + a = [ + TransitionPoint(*s) + for s in zip(state, sample[0], successor_state, signal, terminated) ] - for a_ in a: - self._memory_buffer.add_transition_point(a_) + for a_ in a: + self._memory_buffer.add_transition_point(a_) - @property - def models(self) -> Dict[str, Architecture]: - """ + @property + def models(self) -> Dict[str, Architecture]: + """ @return: """ - return { - "critic_1": self.critic_1, - "critic_2": self.critic_2, - "actor": self.actor, + return { + "critic_1":self.critic_1, + "critic_2":self.critic_2, + "actor": self.actor, } - @property - def optimisers(self) -> Dict[str, Optimizer]: - return { - "actor_optimiser": self.actor_optimiser, - "critic_optimiser": self.critic_optimiser, + @property + def optimisers(self) -> Dict[str, Optimizer]: + return { + 
"actor_optimiser": self.actor_optimiser, + "critic_optimiser":self.critic_optimiser, } - @drop_unused_kws - def _sample( - self, - state: Any, - *args, - deterministic: bool = False, - metric_writer: Writer = MockWriter() - ) -> Tuple[torch.Tensor, Any]: - """ + @drop_unused_kws + def _sample( + self, + state: Any, + *args, + deterministic: bool = False, + metric_writer: Writer = MockWriter() + ) -> Tuple[torch.Tensor, Any]: + """ @param state: @param args: @@ -173,29 +172,29 @@ def _sample( @param kwargs: @return: """ - distribution = self.actor(to_tensor(state, device=self._device)) + distribution = self.actor(to_tensor(state, device=self._device)) - with torch.no_grad(): - return (torch.tanh(distribution.sample().detach()), distribution) + with torch.no_grad(): + return (torch.tanh(distribution.sample().detach()), distribution) - def extract_action(self, sample: SamplePoint) -> numpy.ndarray: - """ + def extract_action(self, sample: SamplePoint) -> numpy.ndarray: + """ @param sample: @return: """ - return sample[0].to("cpu").numpy() - - @drop_unused_kws - def __build__( - self, - observation_space: ObservationSpace, - action_space: ActionSpace, - signal_space: SignalSpace, - metric_writer: Writer = MockWriter(), - print_model_repr: bool = True, - ) -> None: - """ + return sample[0].to("cpu").numpy() + + @drop_unused_kws + def __build__( + self, + observation_space: ObservationSpace, + action_space: ActionSpace, + signal_space: SignalSpace, + metric_writer: Writer = MockWriter(), + print_model_repr: bool = True, + ) -> None: + """ @param observation_space: @param action_space: @@ -204,231 +203,225 @@ def __build__( @param print_model_repr: @return: """ - if action_space.is_discrete: - raise ActionSpaceNotSupported( - "discrete action space not supported in this implementation" - ) - - self._critic_arch_spec.kwargs["input_shape"] = ( - self._input_shape + self._output_shape + if action_space.is_discrete: + raise ActionSpaceNotSupported( + "discrete action space not supported in this implementation" + ) + + self._critic_arch_spec.kwargs["input_shape"] = ( + self._input_shape + self._output_shape + ) + self._critic_arch_spec.kwargs["output_shape"] = 1 + + self.critic_1 = self._critic_arch_spec().to(self._device) + self.critic_1_target = copy.deepcopy(self.critic_1).to(self._device) + freeze_model(self.critic_1_target, True, True) + + self.critic_2 = self._critic_arch_spec().to(self._device) + self.critic_2_target = copy.deepcopy(self.critic_2).to(self._device) + freeze_model(self.critic_2_target, True, True) + + self.critic_optimiser = self._critic_optimiser_spec( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()) ) - self._critic_arch_spec.kwargs["output_shape"] = 1 - self.critic_1 = self._critic_arch_spec().to(self._device) - self.critic_1_target = copy.deepcopy(self.critic_1).to(self._device) - freeze_model(self.critic_1_target, True, True) - - self._critic_arch_spec.kwargs["input_shape"] = ( - self._input_shape + self._output_shape - ) - self._critic_arch_spec.kwargs["output_shape"] = 1 - self.critic_2 = self._critic_arch_spec().to(self._device) - self.critic_2_target = copy.deepcopy(self.critic_2).to(self._device) - freeze_model(self.critic_2_target, True, True) - - self.critic_optimiser = self._critic_optimiser_spec( - itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()) - ) - - self._actor_arch_spec.kwargs["input_shape"] = self._input_shape - self._actor_arch_spec.kwargs["output_shape"] = self._output_shape - self.actor = 
self._actor_arch_spec().to(self._device) - self.actor_optimiser = self._actor_optimiser_spec(self.actor.parameters()) - - if self._auto_tune_sac_alpha: - self._target_entropy = -torch.prod( - to_tensor(self._output_shape, device=self._device) - ).item() - self._log_sac_alpha = nn.Parameter( - torch.log(to_tensor(self._sac_alpha, device=self._device)), - requires_grad=True, - ) - self.sac_alpha_optimiser = self._auto_tune_sac_alpha_optimiser_spec( - [self._log_sac_alpha] - ) - - def on_load(self) -> None: - """ + self._actor_arch_spec.kwargs["input_shape"] = self._input_shape + self._actor_arch_spec.kwargs["output_shape"] = self._output_shape + self.actor = self._actor_arch_spec().to(self._device) + self.actor_optimiser = self._actor_optimiser_spec(self.actor.parameters()) + + if self._auto_tune_sac_alpha: + self._target_entropy = -torch.prod( + to_tensor(self._output_shape, device=self._device) + ).item() + self._log_sac_alpha = nn.Parameter( + torch.log(to_tensor(self._sac_alpha, device=self._device)), + requires_grad=True, + ) + self.sac_alpha_optimiser = self._auto_tune_sac_alpha_optimiser_spec( + [self._log_sac_alpha] + ) + + def on_load(self) -> None: + """ @return: """ - self.update_targets(1.0) + self.update_targets(1.0) - def update_critics( - self, tensorised: TransitionPoint, metric_writer: Writer = None - ) -> float: - """ + def update_critics( + self, tensorised: TransitionPoint, metric_writer: Writer = None + ) -> float: + """ @param metric_writer: @param tensorised: @return: """ - with torch.no_grad(): - successor_action, successor_log_prob = normal_tanh_reparameterised_sample( - self.actor(tensorised.successor_state) - ) - - min_successor_q = ( - torch.min( - self.critic_1_target(tensorised.successor_state, successor_action), - self.critic_2_target(tensorised.successor_state, successor_action), - ) - - successor_log_prob * self._sac_alpha - ) - - successor_q_value = ( - tensorised.signal - + tensorised.non_terminal_numerical - * self._discount_factor - * min_successor_q - ).detach() - assert not successor_q_value.requires_grad - - q_value_loss1 = self._critic_criterion( - self.critic_1(tensorised.state, tensorised.action), successor_q_value + with torch.no_grad(): + successor_action, successor_log_prob = normal_tanh_reparameterised_sample( + self.actor(tensorised.successor_state) + ) + + min_successor_q = ( + torch.min( + self.critic_1_target(tensorised.successor_state, successor_action), + self.critic_2_target(tensorised.successor_state, successor_action), + ) + - successor_log_prob * self._sac_alpha + ) + + successor_q_value = ( + tensorised.signal + + tensorised.non_terminal_numerical + * self._discount_factor + * min_successor_q + ).detach() + assert not successor_q_value.requires_grad + + q_value_loss1 = self._critic_criterion( + self.critic_1(tensorised.state, tensorised.action), successor_q_value ) - q_value_loss2 = self._critic_criterion( - self.critic_2(tensorised.state, tensorised.action), successor_q_value + q_value_loss2 = self._critic_criterion( + self.critic_2(tensorised.state, tensorised.action), successor_q_value ) - critic_loss = q_value_loss1 + q_value_loss2 - assert critic_loss.requires_grad - self.critic_optimiser.zero_grad() - critic_loss.backward() - self.post_process_gradients(self.critic_1.parameters()) - self.post_process_gradients(self.critic_2.parameters()) - self.critic_optimiser.step() - - out_loss = critic_loss.detach().cpu().item() - - if metric_writer: - metric_writer.scalar("Critics_loss", out_loss) - metric_writer.scalar("q_value_loss1", 
q_value_loss1.cpu().mean().item()) - metric_writer.scalar("q_value_loss2", q_value_loss2.cpu().mean().item()) - metric_writer.scalar("min_successor_q", min_successor_q.cpu().mean().item()) - metric_writer.scalar( - "successor_q_value", successor_q_value.cpu().mean().item() - ) - - return out_loss - - def update_actor( - self, tensorised: torch.Tensor, metric_writer: Writer = None - ) -> float: - """ + critic_loss = q_value_loss1 + q_value_loss2 + assert critic_loss.requires_grad + self.critic_optimiser.zero_grad() + critic_loss.backward() + self.post_process_gradients(self.critic_1.parameters()) + self.post_process_gradients(self.critic_2.parameters()) + self.critic_optimiser.step() + + out_loss = to_scalar(critic_loss) + + if metric_writer: + metric_writer.scalar("Critics_loss", out_loss, self.update_i) + metric_writer.scalar("q_value_loss1", to_scalar(q_value_loss1), self.update_i) + metric_writer.scalar("q_value_loss2", to_scalar(q_value_loss2), self.update_i) + metric_writer.scalar("min_successor_q", to_scalar(min_successor_q), self.update_i) + metric_writer.scalar("successor_q_value", to_scalar(successor_q_value), self.update_i) + + return out_loss + + def update_actor( + self, tensorised: torch.Tensor, metric_writer: Writer = None + ) -> float: + """ @param tensorised: @param metric_writer: @return: """ - dist = self.actor(tensorised.state) - action, log_prob = normal_tanh_reparameterised_sample(dist) + dist = self.actor(tensorised.state) + action, log_prob = normal_tanh_reparameterised_sample(dist) - # Check gradient paths - assert action.requires_grad - assert log_prob.requires_grad + # Check gradient paths + assert action.requires_grad + assert log_prob.requires_grad - q_values = ( - self.critic_1(tensorised.state, action), - self.critic_2(tensorised.state, action), + q_values = ( + self.critic_1(tensorised.state, action), + self.critic_2(tensorised.state, action), ) - assert q_values[0].requires_grad and q_values[1].requires_grad - - policy_loss = torch.mean(self._sac_alpha * log_prob - torch.min(*q_values)) - self.actor_optimiser.zero_grad() - policy_loss.backward() - self.post_process_gradients(self.actor.parameters()) - self.actor_optimiser.step() - - out_loss = policy_loss.detach().cpu().item() - - if metric_writer: - metric_writer.scalar("Policy_loss", out_loss) - metric_writer.scalar("q_value_1", q_values[0].cpu().mean().item()) - metric_writer.scalar("q_value_2", q_values[1].cpu().mean().item()) - metric_writer.scalar("policy_stddev", dist.stddev.cpu().mean().item()) - metric_writer.scalar("policy_log_prob", log_prob.cpu().mean().item()) - - if self._auto_tune_sac_alpha: - out_loss += self.update_alpha( - log_prob.detach(), metric_writer=metric_writer - ) + assert q_values[0].requires_grad and q_values[1].requires_grad - return out_loss + policy_loss = torch.mean(self._sac_alpha * log_prob - torch.min(*q_values)) + self.actor_optimiser.zero_grad() + policy_loss.backward() + self.post_process_gradients(self.actor.parameters()) + self.actor_optimiser.step() - def update_alpha( - self, log_prob: torch.Tensor, metric_writer: Writer = None - ) -> float: - """ + out_loss = to_scalar(policy_loss) - @param log_prob: - @type log_prob: + if metric_writer: + metric_writer.scalar("Policy_loss", out_loss) + metric_writer.scalar("q_value_1", to_scalar(q_values[0])) + metric_writer.scalar("q_value_2", to_scalar(q_values[1])) + metric_writer.scalar("policy_stddev", to_scalar(dist.stddev)) + metric_writer.scalar("policy_log_prob", to_scalar(log_prob)) + + if self._auto_tune_sac_alpha: + 
out_loss += self.update_alpha( + log_prob.detach(), metric_writer=metric_writer + ) + + return out_loss + + def update_alpha( + self, log_prob: torch.Tensor, metric_writer: Writer = None + ) -> float: + """ + +@param log_prob: +@type log_prob: @param tensorised: @param metric_writer: @return: """ - assert not log_prob.requires_grad + assert not log_prob.requires_grad - alpha_loss = -torch.mean( - self._log_sac_alpha * (log_prob + self._target_entropy) + alpha_loss = -torch.mean( + self._log_sac_alpha * (log_prob + self._target_entropy) ) - self.sac_alpha_optimiser.zero_grad() - alpha_loss.backward() - self.post_process_gradients(self._log_sac_alpha) - self.sac_alpha_optimiser.step() + self.sac_alpha_optimiser.zero_grad() + alpha_loss.backward() + self.post_process_gradients(self._log_sac_alpha) + self.sac_alpha_optimiser.step() - self._sac_alpha = self._log_sac_alpha.exp() + self._sac_alpha = self._log_sac_alpha.exp() - out_loss = alpha_loss.detach().cpu().item() + out_loss = alpha_loss.detach().cpu().item() - if metric_writer: - metric_writer.scalar("Sac_Alpha_Loss", out_loss) - metric_writer.scalar("Sac_Alpha", self._sac_alpha.cpu().mean().item()) + if metric_writer: + metric_writer.scalar("Sac_Alpha_Loss", out_loss, self.update_i) + metric_writer.scalar("Sac_Alpha", to_scalar(self._sac_alpha), self.update_i) - return out_loss + return out_loss - def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float: - """ + def _update(self, *args, metric_writer: Writer = MockWriter(), **kwargs) -> float: + """ @param args: @param metric_writer: @param kwargs: @return: """ - accum_loss = 0 - for i in tqdm( - range(self._num_inner_updates), desc="#Inner update", leave=False + accum_loss = 0 + for ith_inner_update in tqdm( + range(self._num_inner_updates), desc="Inner update #", leave=False, postfix=f"Agent update #{self.update_i}" ): - self.inner_update_i += 1 - batch = self._memory_buffer.sample(self._batch_size) - tensorised = TransitionPoint( - *[to_tensor(a, device=self._device) for a in batch] + self.inner_update_i += 1 + batch = self._memory_buffer.sample(self._batch_size) + tensorised = TransitionPoint( + *[to_tensor(a, device=self._device) for a in batch] + ) + + with frozen_parameters(self.actor.parameters()): + accum_loss += self.update_critics( + tensorised, metric_writer=metric_writer ) - with frozen_parameters(self.actor.parameters()): - accum_loss += self.update_critics( - tensorised, metric_writer=metric_writer - ) - - with frozen_parameters( - itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()) - ): - accum_loss += self.update_actor(tensorised, metric_writer=metric_writer) + with frozen_parameters( + itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()) + ): + accum_loss += self.update_actor(tensorised, metric_writer=metric_writer) - if is_zero_or_mod_zero(self._target_update_interval, self.inner_update_i): - self.update_targets(self._copy_percentage, metric_writer=metric_writer) + if is_zero_or_mod_zero(self._target_update_interval, self.inner_update_i): + self.update_targets(self._copy_percentage, metric_writer=metric_writer) - if metric_writer: - metric_writer.scalar("Accum_loss", accum_loss) - metric_writer.scalar("num_inner_updates_i", i) + if metric_writer: + metric_writer.scalar("Accum_loss", accum_loss, self.update_i) + metric_writer.scalar("num_inner_updates_i", ith_inner_update, self.update_i) - return accum_loss + return accum_loss - def update_targets( - self, copy_percentage: float = 0.005, *, metric_writer: 
Writer = None - ) -> None: - """ + def update_targets( + self, copy_percentage: float = 0.005, *, metric_writer: Writer = None + ) -> None: + """ Interpolation factor in polyak averaging for target networks. Target networks are updated towards main networks according to: @@ -438,22 +431,22 @@ def update_targets( where \rho is polyak. (Always between 0 and 1, usually close to 1.) - @param metric_writer: - @type metric_writer: +@param metric_writer: +@type metric_writer: @param copy_percentage: @return: """ - if metric_writer: - metric_writer.blip("Target Models Synced", self.update_i) + if metric_writer: + metric_writer.blip("Target Models Synced", self.update_i) - update_target( - target_model=self.critic_1_target, - source_model=self.critic_1, - copy_percentage=copy_percentage, + update_target( + target_model=self.critic_1_target, + source_model=self.critic_1, + copy_percentage=copy_percentage, ) - update_target( - target_model=self.critic_2_target, - source_model=self.critic_2, - copy_percentage=copy_percentage, + update_target( + target_model=self.critic_2_target, + source_model=self.critic_2, + copy_percentage=copy_percentage, ) diff --git a/neodroidagent/agents/torch_agents/model_free/on_policy/ddpg_agent.py b/neodroidagent/agents/torch_agents/model_free/on_policy/ddpg_agent.py index fe6960ce..1e54f32d 100644 --- a/neodroidagent/agents/torch_agents/model_free/on_policy/ddpg_agent.py +++ b/neodroidagent/agents/torch_agents/model_free/on_policy/ddpg_agent.py @@ -21,12 +21,11 @@ from neodroidagent.utilities import ( ActionSpaceNotSupported, OrnsteinUhlenbeckProcess, - is_zero_or_mod_zero, update_target, ) from numpy import mean from tqdm import tqdm -from warg import GDKC, drop_unused_kws, super_init_pass_on_kws +from warg import GDKC, drop_unused_kws, is_zero_or_mod_zero, super_init_pass_on_kws __author__ = "Christian Heider Nielsen" __doc__ = r""" diff --git a/neodroidagent/agents/torch_agents/model_free/on_policy/ppo_agent.py b/neodroidagent/agents/torch_agents/model_free/on_policy/ppo_agent.py index cde4412a..d96ac31b 100644 --- a/neodroidagent/agents/torch_agents/model_free/on_policy/ppo_agent.py +++ b/neodroidagent/agents/torch_agents/model_free/on_policy/ppo_agent.py @@ -5,31 +5,28 @@ import numpy import torch -from torch.distributions import Distribution -from torch.nn.functional import mse_loss -from torch.optim import Optimizer -from tqdm import tqdm - -from draugr.writers import MockWriter, Writer -from draugr.torch_utilities import freeze_model, to_tensor - from draugr import mean_accumulator, shuffled_batches +from draugr.torch_utilities import freeze_model, to_scalar, to_tensor +from draugr.writers import MockWriter, Writer from neodroid.utilities import ActionSpace, ObservationSpace, SignalSpace from neodroidagent.agents.agent import TogglableLowHigh from neodroidagent.agents.torch_agents.torch_agent import TorchAgent from neodroidagent.common import ( - ActorCriticMLP, - CategoricalActorCriticMLP, - TransitionPointTrajectoryBuffer, - ValuedTransitionPoint, -) + ActorCriticMLP, + CategoricalActorCriticMLP, + TransitionPointTrajectoryBuffer, + ValuedTransitionPoint, + ) from neodroidagent.utilities import ( - ActionSpaceNotSupported, - is_none_or_zero_or_negative_or_mod_zero, - torch_compute_gae, - update_target, -) -from warg import GDKC, drop_unused_kws, super_init_pass_on_kws + ActionSpaceNotSupported, + torch_compute_gae, + update_target, + ) +from torch.distributions import Distribution +from torch.nn.functional import mse_loss +from torch.optim import Optimizer +from 
tqdm import tqdm +from warg import GDKC, drop_unused_kws, is_none_or_zero_or_negative_or_mod_zero, super_init_pass_on_kws __author__ = "Christian Heider Nielsen" __doc__ = r""" @@ -39,7 +36,7 @@ @super_init_pass_on_kws class ProximalPolicyOptimizationAgent(TorchAgent): - r""" + r""" PPO, Proximal Policy Optimization method https://arxiv.org/abs/1707.06347 - PPO @@ -49,27 +46,27 @@ class ProximalPolicyOptimizationAgent(TorchAgent): """ - def __init__( - self, - discount_factor: float = 0.95, - gae_lambda: float = 0.95, - entropy_reg_coef: float = 0, - value_reg_coef: float = 5e-1, - num_inner_updates: int = 10, - mini_batch_size: int = 64, - update_target_interval: int = 1, - surrogate_clipping_value: float = 2e-1, - copy_percentage: float = 1.0, - target_kl: float = 1e-2, - memory_buffer: Any = TransitionPointTrajectoryBuffer(), - critic_criterion: callable = mse_loss, - optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), - continuous_arch_spec: GDKC = GDKC(constructor=ActorCriticMLP), - discrete_arch_spec: GDKC = GDKC(constructor=CategoricalActorCriticMLP), - gradient_norm_clipping: TogglableLowHigh = TogglableLowHigh(True, 0, 0.5), - **kwargs - ) -> None: - """ + def __init__( + self, + discount_factor: float = 0.95, + gae_lambda: float = 0.95, + entropy_reg_coef: float = 0, + value_reg_coef: float = 5e-1, + num_inner_updates: int = 10, + mini_batch_size: int = 64, + update_target_interval: int = 1, + surrogate_clipping_value: float = 2e-1, + copy_percentage: float = 1.0, + target_kl: float = 1e-2, + memory_buffer: Any = TransitionPointTrajectoryBuffer(), + critic_criterion: callable = mse_loss, + optimiser_spec: GDKC = GDKC(constructor=torch.optim.Adam, lr=3e-4), + continuous_arch_spec: GDKC = GDKC(constructor=ActorCriticMLP), + discrete_arch_spec: GDKC = GDKC(constructor=CategoricalActorCriticMLP), + gradient_norm_clipping: TogglableLowHigh = TogglableLowHigh(True, 0, 0.5), + **kwargs + ) -> None: + """ :param discount_factor: :param gae_lambda: @@ -94,40 +91,40 @@ def __init__( :param exploration_epsilon_decay: :param kwargs: """ - super().__init__(gradient_norm_clipping=gradient_norm_clipping, **kwargs) - - assert 0 <= discount_factor <= 1.0 - assert 0 <= gae_lambda <= 1.0 - - self._copy_percentage = copy_percentage - self._memory_buffer = memory_buffer - self._optimiser_spec: GDKC = optimiser_spec - self._continuous_arch_spec = continuous_arch_spec - self._discrete_arch_spec = discrete_arch_spec - - self._discount_factor = discount_factor - self._gae_lambda = gae_lambda - self._target_kl = target_kl - - self._mini_batch_size = mini_batch_size - self._entropy_reg_coef = entropy_reg_coef - self._value_reg_coef = value_reg_coef - self._num_inner_updates = num_inner_updates - self._update_target_interval = update_target_interval - self._critic_criterion = critic_criterion - self._surrogate_clipping_value = surrogate_clipping_value - self.inner_update_i = 0 - - @drop_unused_kws - def __build__( - self, - observation_space: ObservationSpace, - action_space: ActionSpace, - signal_space: SignalSpace, - metric_writer: Writer = MockWriter(), - print_model_repr: bool = True, - ) -> None: - """ + super().__init__(gradient_norm_clipping=gradient_norm_clipping, **kwargs) + + assert 0 <= discount_factor <= 1.0 + assert 0 <= gae_lambda <= 1.0 + + self._copy_percentage = copy_percentage + self._memory_buffer = memory_buffer + self._optimiser_spec: GDKC = optimiser_spec + self._continuous_arch_spec = continuous_arch_spec + self._discrete_arch_spec = discrete_arch_spec + + 
self._discount_factor = discount_factor + self._gae_lambda = gae_lambda + self._target_kl = target_kl + + self._mini_batch_size = mini_batch_size + self._entropy_reg_coefficient = entropy_reg_coef + self._value_reg_coefficient = value_reg_coef + self._num_inner_updates = num_inner_updates + self._update_target_interval = update_target_interval + self._critic_criterion = critic_criterion + self._surrogate_clipping_value = surrogate_clipping_value + self.inner_update_i = 0 + + @drop_unused_kws + def __build__( + self, + observation_space: ObservationSpace, + action_space: ActionSpace, + signal_space: SignalSpace, + metric_writer: Writer = MockWriter(), + print_model_repr: bool = True, + ) -> None: + """ @param observation_space: @param action_space: @@ -136,278 +133,270 @@ def __build__( @param print_model_repr: @return: """ - if action_space.is_mixed: - raise ActionSpaceNotSupported() - elif action_space.is_continuous: - self._continuous_arch_spec.kwargs["input_shape"] = self._input_shape - self._continuous_arch_spec.kwargs["output_shape"] = self._output_shape - self.actor_critic = self._continuous_arch_spec().to(self._device) - else: - self._discrete_arch_spec.kwargs["input_shape"] = self._input_shape - self._discrete_arch_spec.kwargs["output_shape"] = self._output_shape - self.actor_critic = self._discrete_arch_spec().to(self._device) - - self._target_actor_critic = copy.deepcopy(self.actor_critic).to(self._device) - freeze_model(self._target_actor_critic, True, True) - - self._optimiser = self._optimiser_spec(self.actor_critic.parameters()) - - @property - def models(self) -> dict: - """ + if action_space.is_mixed: + raise ActionSpaceNotSupported() + elif action_space.is_continuous: + self._continuous_arch_spec.kwargs["input_shape"] = self._input_shape + self._continuous_arch_spec.kwargs["output_shape"] = self._output_shape + self.actor_critic = self._continuous_arch_spec().to(self._device) + else: + self._discrete_arch_spec.kwargs["input_shape"] = self._input_shape + self._discrete_arch_spec.kwargs["output_shape"] = self._output_shape + self.actor_critic = self._discrete_arch_spec().to(self._device) + + self._target_actor_critic = copy.deepcopy(self.actor_critic).to(self._device) + freeze_model(self._target_actor_critic, True, True) + + self._optimiser = self._optimiser_spec(self.actor_critic.parameters()) + + @property + def models(self) -> dict: + """ @return: """ - return {"actor_critic": self.actor_critic} + return {"actor_critic":self.actor_critic} - @property - def optimisers(self) -> Dict[str, Optimizer]: - return {"_optimiser": self._optimiser} + @property + def optimisers(self) -> Dict[str, Optimizer]: + return {"_optimiser":self._optimiser} - # region Protected + # region Protected - @drop_unused_kws - def _sample(self, state: numpy.ndarray, deterministic: bool = False) -> Tuple: - """ + @drop_unused_kws + def _sample(self, state: numpy.ndarray, deterministic: bool = False) -> Tuple: + """ @param state: @return: """ - with torch.no_grad(): - dist, val_est = self._target_actor_critic( - to_tensor(state, device=self._device, dtype=torch.float) - ) + with torch.no_grad(): + dist, val_est = self._target_actor_critic( + to_tensor(state, device=self._device, dtype=torch.float) + ) - if deterministic: - if self.action_space.is_discrete: - action = dist.logits.max(-1)[-1] - else: - action = dist.mean - else: - action = dist.sample() + if deterministic: + if self.action_space.is_discrete: + action = dist.logits.max(-1)[-1] + else: + action = dist.mean + else: + action = dist.sample() - 
if self.action_space.is_discrete: - action = action.unsqueeze(-1) - else: - pass # TODO: Figure out - # action = torch.clamp(action, -1, 1) + if self.action_space.is_discrete: + action = action.unsqueeze(-1) - return action.detach(), dist, val_est.detach() + return action.detach(), dist, val_est.detach() - def extract_action(self, sample: torch.tensor) -> numpy.ndarray: - """ + def extract_action(self, sample: torch.tensor) -> numpy.ndarray: + """ @param sample: @return: """ - return sample[0].to("cpu").numpy() - - @drop_unused_kws - def _remember( - self, - *, - signal: Any, - terminated: Any, - state: Any, - successor_state: Any, - sample: Any - ) -> None: - self._memory_buffer.add_transition_point( - ValuedTransitionPoint( - state, - sample[0], - successor_state, - signal, - terminated, - sample[1], - sample[2], + return sample[0].to("cpu").numpy() + + @drop_unused_kws + def _remember( + self, + *, + signal: Any, + terminated: Any, + state: Any, + successor_state: Any, + sample: Any + ) -> None: + self._memory_buffer.add_transition_point( + ValuedTransitionPoint( + state, + sample[0], + successor_state, + signal, + terminated, + sample[1], + sample[2], ) ) - def _update_targets( - self, copy_percentage: float, *, metric_writer: Writer = None - ) -> None: - """ + def _update_targets( + self, copy_percentage: float, *, metric_writer: Writer = None + ) -> None: + """ @param copy_percentage: @return: """ - if metric_writer: - metric_writer.blip("Target Model Synced", self.update_i) + if metric_writer: + metric_writer.blip("Target Model Synced", self.update_i) - update_target( - target_model=self._target_actor_critic, - source_model=self.actor_critic, - copy_percentage=copy_percentage, + update_target( + target_model=self._target_actor_critic, + source_model=self.actor_critic, + copy_percentage=copy_percentage, ) - def get_log_prob(self, dist: Distribution, action: torch.tensor) -> torch.tensor: - if self.action_space.is_discrete: - return dist.log_prob(action.squeeze(-1)).unsqueeze(-1) - else: - return dist.log_prob(action).sum(axis=-1, keepdims=True) + def get_log_prob(self, dist: Distribution, action: torch.tensor) -> torch.tensor: + if self.action_space.is_discrete: + return dist.log_prob(action.squeeze(-1)).unsqueeze(-1) + else: + return dist.log_prob(action).sum(axis=-1, keepdims=True) - def _prepare_transitions(self): - transitions = self._memory_buffer.sample() - self._memory_buffer.clear() + def _prepare_transitions(self): + transitions = self._memory_buffer.sample() + self._memory_buffer.clear() - signal = to_tensor(transitions.signal, device=self._device) - non_terminal = to_tensor( - transitions.non_terminal_numerical, device=self._device + signal = to_tensor(transitions.signal, device=self._device) + non_terminal = to_tensor( + transitions.non_terminal_numerical, device=self._device ) - state = to_tensor(transitions.state, device=self.device) - action = to_tensor(transitions.action, device=self.device) - value_estimate_target = to_tensor( - transitions.value_estimate, device=self._device + state = to_tensor(transitions.state, device=self.device) + action = to_tensor(transitions.action, device=self.device) + value_estimate_target = to_tensor( + transitions.value_estimate, device=self._device ) - action_log_prob_old = to_tensor( - [ - self.get_log_prob(dist, a) - for dist, a in zip(transitions.distribution, transitions.action) + action_log_prob_old = to_tensor( + [ + self.get_log_prob(dist, a) + for dist, a in zip(transitions.distribution, transitions.action) ], - 
device=self.device, + device=self.device, ) - with torch.no_grad(): - *_, successor_value_estimate = self.actor_critic( - to_tensor((transitions.successor_state[-1],), device=self.device) - ) - value_estimate_target = torch.cat( - (value_estimate_target, successor_value_estimate), dim=0 - ) - - discounted_signal, advantage = torch_compute_gae( - signal, - non_terminal, - value_estimate_target, - discount_factor=self._discount_factor, - gae_lambda=self._gae_lambda, - device=self.device, - ) - - return ( - state.flatten(0, 1), - action.flatten(0, 1), - action_log_prob_old.flatten(0, 1), - discounted_signal.flatten(0, 1), - advantage.flatten(0, 1), + with torch.no_grad(): + *_, successor_value_estimate = self.actor_critic( + to_tensor((transitions.successor_state[-1],), device=self.device) + ) + value_estimate_target = torch.cat( + (value_estimate_target, successor_value_estimate), dim=0 + ) + + discounted_signal, advantage = torch_compute_gae( + signal, + non_terminal, + value_estimate_target, + discount_factor=self._discount_factor, + gae_lambda=self._gae_lambda, + device=self.device, + ) + + return ( + state.flatten(0, 1), + action.flatten(0, 1), + action_log_prob_old.flatten(0, 1), + discounted_signal.flatten(0, 1), + advantage.flatten(0, 1), ) - @drop_unused_kws - def _update(self, metric_writer: Writer = MockWriter()) -> float: - """ + @drop_unused_kws + def _update(self, metric_writer: Writer = MockWriter()) -> float: + """ @param metric_writer: @return: """ - transitions = self._prepare_transitions() + transitions = self._prepare_transitions() - accum_loss = mean_accumulator() - for i in tqdm( - range(self._num_inner_updates), desc="#Inner updates", leave=False + accum_loss = mean_accumulator() + for ith_inner_update in tqdm( + range(self._num_inner_updates), desc="#Inner updates", leave=False ): - self.inner_update_i += 1 - loss, early_stop_inner = self.inner_update( - *transitions, metric_writer=metric_writer - ) - accum_loss.send(loss) - - if is_none_or_zero_or_negative_or_mod_zero( - self._update_target_interval, self.inner_update_i - ): - self._update_targets(self._copy_percentage, metric_writer=metric_writer) - - if early_stop_inner: - break - - mean_loss = next(accum_loss) - - if metric_writer: - metric_writer.scalar("Inner Updates", i) - metric_writer.scalar("Mean Loss", mean_loss) - - return mean_loss - - def _policy_loss( - self, - new_distribution, - action_batch, - log_prob_batch_old, - adv_batch, - *, - metric_writer: Writer = None - ): - action_log_probs_new = self.get_log_prob(new_distribution, action_batch) - ratio = torch.exp(action_log_probs_new - log_prob_batch_old) - # if ratio explodes to (inf or Nan) due to the residual being to large check initialisation! - # Generated action probs from (new policy) and (old policy). 
- # Values of [0..1] means that actions less likely with the new policy, - # while values [>1] mean action a more likely now - clamped_ratio = torch.clamp( - ratio, - min=1.0 - self._surrogate_clipping_value, - max=1.0 + self._surrogate_clipping_value, + self.inner_update_i += 1 + loss, early_stop_inner = self.inner_update( + *transitions, metric_writer=metric_writer + ) + accum_loss.send(loss) + + if is_none_or_zero_or_negative_or_mod_zero( + self._update_target_interval, self.inner_update_i + ): + self._update_targets(self._copy_percentage, metric_writer=metric_writer) + + if early_stop_inner: + break + + mean_loss = next(accum_loss) + + if metric_writer: + metric_writer.scalar("Inner Updates", ith_inner_update) + metric_writer.scalar("Mean Loss", mean_loss) + + return mean_loss + + def _policy_loss( + self, + new_distribution, + action_batch, + log_prob_batch_old, + adv_batch, + *, + metric_writer: Writer = None + ): + action_log_probs_new = self.get_log_prob(new_distribution, action_batch) + ratio = torch.exp(action_log_probs_new - log_prob_batch_old) + # if ratio explodes to (inf or Nan) due to the residual being to large check initialisation! + # Generated action probabilities from (new policy) and (old policy). + # Values of [0..1] means that actions less likely with the new policy, + # while values [>1] mean action a more likely now + clamped_ratio = torch.clamp( + ratio, + min=1.0 - self._surrogate_clipping_value, + max=1.0 + self._surrogate_clipping_value, ) - policy_loss = -torch.min(ratio * adv_batch, clamped_ratio * adv_batch).mean() - entropy_loss = new_distribution.entropy().mean() * self._entropy_reg_coef + policy_loss = -torch.min(ratio * adv_batch, clamped_ratio * adv_batch).mean() + entropy_loss = new_distribution.entropy().mean() * self._entropy_reg_coefficient - with torch.no_grad(): - approx_kl = ( - (log_prob_batch_old - action_log_probs_new).mean().detach().cpu().item() - ) + with torch.no_grad(): + approx_kl = to_scalar((log_prob_batch_old - action_log_probs_new)) - if metric_writer: - metric_writer.scalar("ratio", ratio.mean().detach().cpu().item()) - metric_writer.scalar("entropy_loss", entropy_loss.detach().cpu().item()) - metric_writer.scalar( - "clamped_ratio", clamped_ratio.mean().detach().cpu().item() - ) + if metric_writer: + metric_writer.scalar("ratio", to_scalar(ratio)) + metric_writer.scalar("entropy_loss", to_scalar(entropy_loss)) + metric_writer.scalar("clamped_ratio", to_scalar(clamped_ratio) + ) - return policy_loss - entropy_loss, approx_kl + return policy_loss - entropy_loss, approx_kl - def inner_update(self, *transitions, metric_writer: Writer = None) -> Tuple: - batch_generator = shuffled_batches( - *transitions, size=transitions[0].size(0), batch_size=self._mini_batch_size + def inner_update(self, *transitions, metric_writer: Writer = None) -> Tuple: + batch_generator = shuffled_batches( + *transitions, size=transitions[0].size(0), batch_size=self._mini_batch_size ) - for ( - state, - action, - log_prob_old, - discounted_signal, - advantage, + for ( + state, + action, + log_prob_old, + discounted_signal, + advantage, ) in batch_generator: - new_distribution, value_estimate = self.actor_critic(state) - - policy_loss, approx_kl = self._policy_loss( - new_distribution, - action, - log_prob_old, - advantage, - metric_writer=metric_writer, - ) - critic_loss = ( - self._critic_criterion(value_estimate, discounted_signal) - * self._value_reg_coef - ) - - loss = policy_loss + critic_loss - - self._optimiser.zero_grad() - loss.backward() - 
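# A minimal sketch of what a helper named to_scalar (used below in place of the
# .mean().detach().cpu().item() chains) is assumed to do; the actual draugr utility
# may differ.
import torch


def to_scalar(tensor: torch.Tensor) -> float:
    """Reduce a tensor to a plain python float, averaging when it is not zero-dimensional."""
    return tensor.detach().mean().cpu().item()

# e.g. to_scalar(log_prob_batch_old - action_log_probs_new) yields the approximate KL
# divergence between the old and new policies that drives the early stop further down.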
self.post_process_gradients(self.actor_critic.parameters()) - self._optimiser.step() - - if metric_writer: - metric_writer.scalar( - "policy_stddev", new_distribution.stddev.cpu().mean().item() - ) - metric_writer.scalar("policy_loss", policy_loss.detach().cpu().item()) - metric_writer.scalar("critic_loss", critic_loss.detach().cpu().item()) - metric_writer.scalar("policy_approx_kl", approx_kl) - metric_writer.scalar("merged_loss", loss.detach().cpu().item()) - - if approx_kl > 1.5 * self._target_kl: - return loss.detach().cpu().item(), True - return loss.detach().cpu().item(), False + new_distribution, value_estimate = self.actor_critic(state) + + policy_loss, approx_kl = self._policy_loss( + new_distribution, + action, + log_prob_old, + advantage, + metric_writer=metric_writer, + ) + critic_loss = ( + self._critic_criterion(value_estimate, discounted_signal) + * self._value_reg_coefficient + ) + + loss = policy_loss + critic_loss + + self._optimiser.zero_grad() + loss.backward() + self.post_process_gradients(self.actor_critic.parameters()) + self._optimiser.step() + + if metric_writer: + metric_writer.scalar("policy_stddev", to_scalar(new_distribution.stddev)) + metric_writer.scalar("policy_loss", to_scalar(policy_loss)) + metric_writer.scalar("critic_loss", to_scalar(critic_loss)) + metric_writer.scalar("policy_approx_kl", approx_kl) + metric_writer.scalar("merged_loss", to_scalar(loss)) + + if approx_kl > 1.5 * self._target_kl: + return to_scalar(loss), True + return to_scalar(loss), False diff --git a/neodroidagent/agents/torch_agents/torch_agent.py b/neodroidagent/agents/torch_agents/torch_agent.py index 8e29aea7..b48bd784 100644 --- a/neodroidagent/agents/torch_agents/torch_agent.py +++ b/neodroidagent/agents/torch_agents/torch_agent.py @@ -131,17 +131,17 @@ def build( if metric_writer: try: + model = copy.deepcopy(w).to("cpu") + dummy_input = model.sample_input() + sprint(f'{k} input: {dummy_input.shape}') + import contextlib with contextlib.redirect_stdout( None - ): # So much use frame info printed... Suppress it - model = copy.deepcopy(w) - model.to("cpu") - - dummy_in = torch.empty(1, *model.input_shape, device="cpu") + ): # So much useless frame info printed... Suppress it if isinstance(metric_writer, GraphWriterMixin): - metric_writer.graph(model, dummy_in, verbose=verbose) + metric_writer.graph(model, dummy_input, verbose=verbose) # No naming available at moment... except RuntimeError as ex: sprint( f"Tensorboard(Pytorch) does not support you model! 
No graph added: {str(ex).splitlines()[0]}", diff --git a/neodroidagent/common/architectures/README.md b/neodroidagent/common/architectures/README.md index eac503af..4823bae7 100644 --- a/neodroidagent/common/architectures/README.md +++ b/neodroidagent/common/architectures/README.md @@ -1,3 +1,2 @@ -# Architectures - +# Torch Architectures - [Multi Layer Perceptron](mlp.py) \ No newline at end of file diff --git a/neodroidagent/common/architectures/architecture.py b/neodroidagent/common/architectures/architecture.py index b514e33c..2eb395c9 100644 --- a/neodroidagent/common/architectures/architecture.py +++ b/neodroidagent/common/architectures/architecture.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from abc import ABC -from typing import Sequence +from abc import ABC, abstractmethod +from typing import Any, Sequence +import torch from torch import nn __author__ = "Christian Heider Nielsen" @@ -16,48 +17,51 @@ class Architecture(nn.Module, ABC): - """ - """ - @drop_unused_kws - def __init__(self, input_shape: Sequence[int], output_shape: Sequence[int]): - super().__init__() - self._input_shape = input_shape - self._output_shape = output_shape +""" - @property - def input_shape(self): - """ + @drop_unused_kws + def __init__(self, input_shape: Sequence[int], output_shape: Sequence[int]): + super().__init__() + self._input_shape = input_shape + self._output_shape = output_shape - @return: - @rtype: + @property + def input_shape(self) -> Sequence[int]: """ - return self._input_shape - @property - def output_shape(self): - """ +@return: +@rtype: +""" + return self._input_shape - @return: - @rtype: + @property + def output_shape(self) -> Sequence[int]: """ - return self._output_shape - def __repr__(self): - num_trainable_params = get_num_parameters(self, only_trainable=True) - num_params = get_num_parameters(self, only_trainable=False) +@return: +@rtype: +""" + return self._output_shape + + def sample_input(self)-> Any: + return torch.empty(1, *self.input_shape, device="cpu") + + def __repr__(self): + num_trainable_params = get_num_parameters(self, only_trainable=True) + num_params = get_num_parameters(self, only_trainable=False) - dict_repr = indent_lines(f"{self.__dict__}") + dict_repr = indent_lines(f"{self.__dict__}") - trainable_params_str = indent_lines( - f"trainable/num_params: {num_trainable_params}/{num_params}\n" + trainable_params_str = indent_lines( + f"trainable/num_params: {num_trainable_params}/{num_params}\n" ) - return f"{super().__repr__()}\n{dict_repr}\n{trainable_params_str}" + return f"{super().__repr__()}\n{dict_repr}\n{trainable_params_str}" if __name__ == "__main__": - a = Architecture() + a = Architecture() - print(a) + print(a) diff --git a/neodroidagent/common/architectures/experimental/recurrent.py b/neodroidagent/common/architectures/experimental/recurrent.py index e32847a4..fadae98d 100644 --- a/neodroidagent/common/architectures/experimental/recurrent.py +++ b/neodroidagent/common/architectures/experimental/recurrent.py @@ -64,10 +64,10 @@ def _forward_gru(self, x, hxs, masks): T = int(x.size(0) / N) # unflatten - x = x.view(T, N, x.size(1)) + x = x.reshape(T, N, x.size(1)) # Same deal with masks - masks = masks.view(T, N, 1) + masks = masks.reshape(T, N, 1) outputs = [] for i in range(T): @@ -78,6 +78,6 @@ def _forward_gru(self, x, hxs, masks): # x is a (T, N, -1) tensor x = torch.stack(outputs, dim=0) # flatten - x = x.view(T * N, -1) + x = x.reshape(T * N, -1) return x, hxs diff --git 
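# A minimal sketch of the idea behind Architecture.sample_input() above: tracing a
# model graph needs a correctly shaped dummy batch. Shown with the stock
# torch.utils.tensorboard writer rather than the draugr GraphWriterMixin used in
# torch_agent.build; the toy model and shapes are illustrative only.
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
dummy_input = torch.empty(1, 4)  # analogous to torch.empty(1, *self.input_shape, device="cpu")

writer = SummaryWriter()
writer.add_graph(model, dummy_input)  # traced on CPU, mirroring the deepcopy to "cpu" above
writer.close()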
a/neodroidagent/common/architectures/mlp_variants/concatination.py b/neodroidagent/common/architectures/mlp_variants/concatination.py index 72229520..f4627739 100644 --- a/neodroidagent/common/architectures/mlp_variants/concatination.py +++ b/neodroidagent/common/architectures/mlp_variants/concatination.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Iterable, List, Sequence +from typing import Any, Iterable, List, Sequence import numpy import torch @@ -12,6 +12,8 @@ __all__ = ["PreConcatInputMLP", "LateConcatInputMLP"] +from warg import passes_kws_to + class PreConcatInputMLP(MLP): """ @@ -24,6 +26,7 @@ def __init__(self, input_shape: Sequence = (2,), **kwargs): super().__init__(input_shape=input_shape, **kwargs) + @passes_kws_to(MLP.forward) def forward(self, *x, **kwargs) -> List: """ @@ -36,7 +39,6 @@ def forward(self, *x, **kwargs) -> List: """ return super().forward(torch.cat(x, dim=-1), **kwargs) - class LateConcatInputMLP(MLP): """ Late fusion, quite a botch job, only a single addition block fusion supported for now @@ -67,6 +69,7 @@ def __init__( torch.nn.Linear(s, t), torch.nn.ReLU(), torch.nn.Linear(t, output_shape[-1]) ) + @passes_kws_to(MLP.forward) def forward(self, *x, **kwargs) -> torch.tensor: """ diff --git a/neodroidagent/common/memory/data_structures/expandable_circular_buffer.py b/neodroidagent/common/memory/data_structures/expandable_circular_buffer.py index 0704c153..5a24a46a 100644 --- a/neodroidagent/common/memory/data_structures/expandable_circular_buffer.py +++ b/neodroidagent/common/memory/data_structures/expandable_circular_buffer.py @@ -12,7 +12,7 @@ __all__ = ["ExpandableCircularBuffer"] from neodroidagent.common.memory.memory import Memory -from neodroidagent.utilities import is_none_or_zero_or_negative +from warg import is_none_or_zero_or_negative class ExpandableCircularBuffer(Memory): diff --git a/neodroidagent/common/memory/rollout_storage.py b/neodroidagent/common/memory/rollout_storage.py index fb4eab98..07c40a42 100644 --- a/neodroidagent/common/memory/rollout_storage.py +++ b/neodroidagent/common/memory/rollout_storage.py @@ -152,12 +152,12 @@ def feed_forward_generator(self, advantages, mini_batches): drop_last=False, ) for i in sampler: - observations_batch = self.observations[:-1].view( + observations_batch = self.observations[:-1].reshape( -1, *self.observations.size()[2:] )[i] - actions_batch = self.actions.view(-1, self.actions.size(-1))[i] - return_batch = self.returns[:-1].view(-1, 1)[i] - old_action_log_probs_batch = self.action_log_probs.view(-1, 1)[i] - adv_targ = advantages.view(-1, 1)[i] + actions_batch = self.actions.reshape(-1, self.actions.size(-1))[i] + return_batch = self.returns[:-1].reshape(-1, 1)[i] + old_action_log_probs_batch = self.action_log_probs.reshape(-1, 1)[i] + adv_targ = advantages.reshape(-1, 1)[i] yield observations_batch, actions_batch, return_batch, old_action_log_probs_batch, adv_targ diff --git a/neodroidagent/common/session_factory/vertical/linear.py b/neodroidagent/common/session_factory/vertical/linear.py index e70f856b..829a2c9a 100644 --- a/neodroidagent/common/session_factory/vertical/linear.py +++ b/neodroidagent/common/session_factory/vertical/linear.py @@ -54,3 +54,4 @@ def __init__( environments = environment super().__init__(environments=environments, procedure=procedure, **kwargs) + diff --git a/neodroidagent/common/session_factory/vertical/parallel.py b/neodroidagent/common/session_factory/vertical/parallel.py index f6156b5a..73f0ae28 100644 --- 
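# A brief illustration of why the .view(...) calls above were changed to .reshape(...):
# view requires contiguous memory, while reshape falls back to a copy when necessary.
import torch

x = torch.arange(6).reshape(2, 3).t()  # the transpose makes the tensor non-contiguous
try:
    x.view(-1)
except RuntimeError as error:
    print(f"view failed: {error}")
print(x.reshape(-1))  # reshape succeeds by copying the data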
a/neodroidagent/common/session_factory/vertical/parallel.py +++ b/neodroidagent/common/session_factory/vertical/parallel.py @@ -42,7 +42,7 @@ def __init__( assert environment_name != "" environments = NeodroidVectorGymEnvironment( environment_name=environment_name, - num__envs=num_envs, + num_envs=num_envs, auto_reset_on_terminal_state=auto_reset_on_terminal_state, ) elif isinstance(environment, bool): diff --git a/neodroidagent/common/session_factory/vertical/procedures/procedure_specification.py b/neodroidagent/common/session_factory/vertical/procedures/procedure_specification.py index 7c827c52..507433a5 100644 --- a/neodroidagent/common/session_factory/vertical/procedures/procedure_specification.py +++ b/neodroidagent/common/session_factory/vertical/procedures/procedure_specification.py @@ -23,20 +23,25 @@ def __init__( agent: Agent, *, environment: Environment, - on_improvement_callbacks: List = [], - save_best_throughtout_training: bool = True + on_improvement_callbacks=None, + save_best_throughout_training: bool = True, + train_agent: bool = True ): """ @param agent: @param environment: @param on_improvement_callbacks: -@param save_best_throughtout_training: +@param save_best_throughout_training: """ + if on_improvement_callbacks is None: + on_improvement_callbacks = [] + self.agent = agent self.environment = environment - if save_best_throughtout_training: + if save_best_throughout_training and train_agent: on_improvement_callbacks.append(self.agent.save) + print('Saving best model throughout training') self.on_improvement_callbacks = on_improvement_callbacks @staticmethod diff --git a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_batched.py b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_batched.py index 0e359648..001ebc1e 100644 --- a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_batched.py +++ b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_batched.py @@ -1,15 +1,11 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from pathlib import Path -from typing import Union +import logging -import torch -import torchsnooper from draugr.writers import MockWriter, Writer from tqdm import tqdm from draugr.metrics.accumulation import mean_accumulator -from draugr.torch_utilities import TensorBoardPytorchWriter __author__ = "Christian Heider Nielsen" @@ -19,8 +15,7 @@ from neodroidagent.common.session_factory.vertical.procedures.procedure_specification import ( Procedure, ) -from neodroidagent.utilities import is_positive_and_mod_zero -from warg.context_wrapper import ContextWrapper +from warg import is_positive_and_mod_zero class OffPolicyBatched(Procedure): @@ -28,7 +23,7 @@ def __call__( self, *, batch_size=1000, - device: Union[str, torch.device], + iterations=10000, stat_frequency=10, render_frequency=10, @@ -39,7 +34,7 @@ def __call__( ) -> None: """ -:param device: + :param log_directory: :param num_steps: :param iterations: @@ -56,7 +51,6 @@ def __call__( @param disable_stdout: @param train_agent: @param kwargs: -@type device: object """ state = self.agent.extract_features(self.environment.reset()) @@ -66,10 +60,10 @@ def __call__( running_mean_action = mean_accumulator() for batch_i in tqdm( - range(1, iterations), leave=False, disable=disable_stdout, desc="Batch #" + range(1, iterations), leave=False, disable=disable_stdout, desc="Batch #",postfix=f"Agent update #{self.agent.update_i}" ): for _ in tqdm( - range(batch_size), leave=False, disable=disable_stdout, 
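# A minimal illustration of the pitfall the on_improvement_callbacks=None change above
# avoids: a mutable default argument is created once at definition time and shared
# between calls. The function names are illustrative only.
def register(callback, callbacks=[]):  # anti-pattern
    callbacks.append(callback)
    return callbacks


print(register("save"))  # ['save']
print(register("plot"))  # ['save', 'plot']  <- state leaked from the previous call


def register_fixed(callback, callbacks=None):
    if callbacks is None:
        callbacks = []
    callbacks.append(callback)
    return callbacks


print(register_fixed("save"))  # ['save']
print(register_fixed("plot"))  # ['plot']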
desc="Step #" + range(batch_size), leave=False, disable=disable_stdout, desc="Step #", ): sample = self.agent.sample(state) @@ -109,7 +103,7 @@ def __call__( best_running_signal = sig self.call_on_improvement_callbacks(loss=loss, **kwargs) else: - print("no update") + logging.info("no update") if self.early_stop: break diff --git a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_episodic.py b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_episodic.py index 1a0f29e0..d93550a6 100644 --- a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_episodic.py +++ b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_episodic.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import math from itertools import count from pathlib import Path @@ -19,8 +20,7 @@ from neodroidagent.common.session_factory.vertical.procedures.procedure_specification import ( Procedure, ) -from neodroidagent.utilities import is_positive_and_mod_zero -from warg import drop_unused_kws, passes_kws_to +from warg import drop_unused_kws, passes_kws_to,is_positive_and_mod_zero __author__ = "Christian Heider Nielsen" __all__ = ["rollout_off_policy", "OffPolicyEpisodic"] @@ -53,7 +53,7 @@ def rollout_off_policy( episode_buffer = [] for step_i in tqdm( - count(), desc="Step #", leave=False, disable=not render_environment + count(), desc="Step #", leave=False, disable=not render_environment, postfix=f"Agent update #{agent.update_i}" ): sample = agent.sample( state, deterministic=disallow_random_sample, metric_writer=metric_writer @@ -104,7 +104,7 @@ def rollout_off_policy( if train_agent: agent.update(metric_writer=metric_writer) else: - print("no update") + logging.info("no update") esig = next(episode_signal) diff --git a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_step_wise.py b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_step_wise.py index a6fe0663..57225248 100644 --- a/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_step_wise.py +++ b/neodroidagent/common/session_factory/vertical/procedures/training/off_policy_step_wise.py @@ -5,6 +5,7 @@ import torch import torchsnooper +from draugr.drawers import MplDrawer, MockDrawer from draugr.writers import MockWriter, Writer from tqdm import tqdm @@ -18,8 +19,7 @@ from neodroidagent.common.session_factory.vertical.procedures.procedure_specification import ( Procedure, ) -from neodroidagent.utilities import is_positive_and_mod_zero, is_zero_or_mod_below -from warg.context_wrapper import ContextWrapper +from warg import is_positive_and_mod_zero, is_zero_or_mod_below class OffPolicyStepWise(Procedure): @@ -28,7 +28,6 @@ def __call__( *, num_environment_steps=500000, batch_size=128, - device: Union[str, torch.device], stat_frequency=10, render_frequency=10000, initial_observation_period=1000, @@ -37,11 +36,11 @@ def __call__( disable_stdout: bool = False, train_agent: bool = True, metric_writer: Writer = MockWriter(), + rollout_drawer: MplDrawer = MockDrawer(), **kwargs ) -> None: """ -:param device: :param log_directory: :param num_environment_steps: :param stat_frequency: @@ -64,6 +63,8 @@ def __call__( sample = self.agent.sample(state) action = self.agent.extract_action(sample) + + snapshot = self.environment.react(action) successor_state = self.agent.extract_features(snapshot) signal = self.agent.extract_signal(snapshot) @@ -124,6 +125,8 @@ def __call__( and 
render_frequency != 0 ): self.environment.render() + if rollout_drawer: + rollout_drawer.draw(action) if self.early_stop: break diff --git a/neodroidagent/common/session_factory/vertical/procedures/training/on_policy_episodic.py b/neodroidagent/common/session_factory/vertical/procedures/training/on_policy_episodic.py index 302d382b..452600f4 100644 --- a/neodroidagent/common/session_factory/vertical/procedures/training/on_policy_episodic.py +++ b/neodroidagent/common/session_factory/vertical/procedures/training/on_policy_episodic.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import math from itertools import count from pathlib import Path @@ -8,20 +9,19 @@ import numpy import torch import torchsnooper +from draugr.drawers import DiscreteScrollPlot -from draugr.drawers.drawer import Drawer, MockDrawer +from draugr.drawers.mpldrawer import MplDrawer, MockDrawer from draugr.metrics.accumulation import mean_accumulator, total_accumulator from draugr.writers import MockWriter, Writer from draugr.torch_utilities import TensorBoardPytorchWriter from neodroid.environments.environment import Environment -from neodroid.utilities import EnvironmentSnapshot +from neodroid.utilities import EnvironmentSnapshot, to_one_hot from neodroidagent.agents.agent import Agent from neodroidagent.common.session_factory.vertical.procedures.procedure_specification import ( - Procedure, -) -from neodroidagent.utilities.misc import is_positive_and_mod_zero -from warg.context_wrapper import ContextWrapper -from warg.decorators.kw_passing import drop_unused_kws, passes_kws_to + Procedure, + ) +from warg import is_positive_and_mod_zero, drop_unused_kws, passes_kws_to __author__ = "Christian Heider Nielsen" __all__ = ["rollout_on_policy", "OnPolicyEpisodic"] @@ -36,15 +36,17 @@ def rollout_on_policy( initial_snapshot: EnvironmentSnapshot, env: Environment, *, + rollout_ith: int = None, render_environment: bool = False, metric_writer: Writer = MockWriter(), - rollout_drawer: Drawer = MockDrawer(), + rollout_drawer: MplDrawer = MockDrawer(), train_agent: bool = True, max_length: int = None, disable_stdout: bool = False, -): - """Perform a single rollout until termination in environment + ): + """Perform a single rollout until termination in environment + :param rollout_ith: :param agent: :param rollout_drawer: :param disable_stdout: @@ -63,73 +65,78 @@ def rollout_on_policy( -average_episode_entropy- """ - state = agent.extract_features(initial_snapshot) - running_mean_action = mean_accumulator() - episode_signal = total_accumulator() + state = agent.extract_features(initial_snapshot) + running_mean_action = mean_accumulator() + episode_signal = total_accumulator() - for step_i in tqdm( - count(1), f"Update #{agent.update_i}", leave=False, disable=disable_stdout - ): - sample = agent.sample(state) - action = agent.extract_action(sample) + rollout_description = f'Rollout' + if rollout_ith: + rollout_description += f" #{rollout_ith}" + for step_i in tqdm( + count(1), rollout_description, unit='th step', leave=False, disable=disable_stdout, postfix=f"Agent update #{agent.update_i}" + ): + sample = agent.sample(state) + action = agent.extract_action(sample) - snapshot = env.react(action) + snapshot = env.react(action) - successor_state = agent.extract_features(snapshot) - terminated = snapshot.terminated - signal = agent.extract_signal(snapshot) + successor_state = agent.extract_features(snapshot) + terminated = snapshot.terminated + signal = agent.extract_signal(snapshot) - if train_agent: - 
agent.remember( - state=state, - signal=signal, - terminated=terminated, - sample=sample, - successor_state=successor_state, - ) + if train_agent: + agent.remember( + state=state, + signal=signal, + terminated=terminated, + sample=sample, + successor_state=successor_state, + ) - state = successor_state + state = successor_state - running_mean_action.send(action.mean()) - episode_signal.send(signal.mean()) + running_mean_action.send(action.mean()) + episode_signal.send(signal.mean()) - if render_environment: - env.render() - # if env.action_space.is_discrete and rollout_drawer: - # rollout_drawer.draw(to_one_hot(agent.output_shape, action)[0]) + if render_environment: + env.render() + if rollout_drawer: + if env.action_space.is_discrete: + action = to_one_hot(agent.output_shape, action) + rollout_drawer.draw(action) - if numpy.array(terminated).all() or (max_length and step_i > max_length): - break + if numpy.array(terminated).all() or (max_length and step_i > max_length): + break - if train_agent: - agent.update(metric_writer=metric_writer) - else: - print("no update") + if train_agent: + agent.update(metric_writer=metric_writer) + else: + logging.info("no update") - episode_return = next(episode_signal) - rma = next(running_mean_action) + episode_return = next(episode_signal) + rma = next(running_mean_action) - if metric_writer: - metric_writer.scalar("duration", step_i, agent.update_i) - metric_writer.scalar("running_mean_action", rma, agent.update_i) - metric_writer.scalar("signal", episode_return, agent.update_i) + if metric_writer: + metric_writer.scalar("duration", step_i, agent.update_i) + metric_writer.scalar("running_mean_action", rma, agent.update_i) + metric_writer.scalar("signal", episode_return, agent.update_i) - return episode_return, step_i + return episode_return, step_i class OnPolicyEpisodic(Procedure): - @passes_kws_to(rollout_on_policy) - def __call__( - self, - *, - iterations: int = 1000, - render_frequency: int = 100, - stat_frequency: int = 10, - disable_stdout: bool = False, - metric_writer: Writer = MockWriter(), - **kwargs, - ): - r""" + @passes_kws_to(rollout_on_policy) + def __call__( + self, + *, + iterations: int = 1000, + render_frequency: int = 100, + stat_frequency: int = 10, + disable_stdout: bool = False, + metric_writer: Writer = MockWriter(), + **kwargs, + ): + r""" :param log_directory: :param disable_stdout: Whether to disable stdout statements or not :type disable_stdout: bool @@ -143,30 +150,31 @@ def __call__( :rtype: TR """ - E = range(1, iterations) - E = tqdm(E, desc="Rollout #", leave=False) - - best_episode_return = -math.inf - for episode_i in E: - initial_state = self.environment.reset() - - kwargs.update( - render_environment=is_positive_and_mod_zero(render_frequency, episode_i) - ) - ret, *_ = rollout_on_policy( - self.agent, - initial_state, - self.environment, - metric_writer=is_positive_and_mod_zero( - stat_frequency, episode_i, ret=metric_writer - ), - disable_stdout=disable_stdout, - **kwargs, - ) - - if best_episode_return < ret: - best_episode_return = ret - self.call_on_improvement_callbacks(**kwargs) - - if self.early_stop: - break + E = range(1, iterations) + E = tqdm(E, desc="Rollout #", leave=False) + + best_episode_return = -math.inf + for episode_i in E: + initial_state = self.environment.reset() + + kwargs.update( + render_environment=is_positive_and_mod_zero(render_frequency, episode_i) + ) + ret, *_ = rollout_on_policy( + self.agent, + initial_state, + self.environment, + rollout_íth=episode_i, + 
metric_writer=is_positive_and_mod_zero( + stat_frequency, episode_i, ret=metric_writer + ), + disable_stdout=disable_stdout, + **kwargs, + ) + + if best_episode_return < ret: + best_episode_return = ret + self.call_on_improvement_callbacks(**kwargs) + + if self.early_stop: + break diff --git a/neodroidagent/common/session_factory/vertical/single_agent_environment_session.py b/neodroidagent/common/session_factory/vertical/single_agent_environment_session.py index b0b4b392..20ab4e13 100644 --- a/neodroidagent/common/session_factory/vertical/single_agent_environment_session.py +++ b/neodroidagent/common/session_factory/vertical/single_agent_environment_session.py @@ -3,12 +3,13 @@ import inspect import time from contextlib import suppress +from os import cpu_count from typing import Any, Type import torch import torchsnooper - -from draugr import CaptureEarlyStop, add_early_stopping_key_combination, sprint +from draugr import CaptureEarlyStop, MockWriter, add_early_stopping_key_combination, sprint +from draugr.drawers import DiscreteScrollPlot, SeriesScrollPlot from draugr.torch_utilities import TensorBoardPytorchWriter, torch_seed from neodroidagent import PROJECT_APP_PATH from neodroidagent.agents import Agent @@ -16,6 +17,7 @@ from warg import GDKC, passes_kws_to from warg.context_wrapper import ContextWrapper from warg.decorators.timing import StopWatch + from .environment_session import EnvironmentSession from .procedures.procedure_specification import Procedure @@ -27,25 +29,28 @@ class SingleAgentEnvironmentSession(EnvironmentSession): - @passes_kws_to( - add_early_stopping_key_combination, - Agent.__init__, - Agent.save, - Procedure.__call__, - ) - def __call__( - self, - agent: Type[Agent], - *, - load_time: Any, - seed: int, - save_ending_model: bool = False, - continue_training: bool = True, - train_agent: bool = True, - debug: bool = False, - **kwargs, - ): - """ + @passes_kws_to( + add_early_stopping_key_combination, + Agent.__init__, + Agent.save, + Procedure.__init__, + Procedure.__call__, + ) + def __call__( + self, + agent: Type[Agent], + *, + load_time: Any = str(int(time.time())), + seed: int = 0, + save_ending_model: bool = False, + save_training_resume: bool = False, + continue_training: bool = True, + train_agent: bool = True, + debug: bool = False, + num_envs: int = cpu_count(), + **kwargs, + ): + """ Start a session, builds Agent and starts/connect environment(s), and runs Procedure @@ -53,129 +58,155 @@ def __call__( :param kwargs: :return: """ - with ContextWrapper(torchsnooper.snoop, debug): - with ContextWrapper(torch.autograd.detect_anomaly, debug): - - if agent is None: - raise NoAgent + kwargs.update(num_envs=num_envs) + kwargs.update(train_agent=train_agent) + kwargs.update(debug=debug) + kwargs.update(environment=self._environment) + + with ContextWrapper(torchsnooper.snoop, debug): + with ContextWrapper(torch.autograd.detect_anomaly, debug): + + if agent is None: + raise NoAgent + + if inspect.isclass(agent): + sprint( + "Instantiating Agent", color="crimson", bold=True, italic=True + ) + torch_seed(seed) + self._environment.seed(seed) + + agent = agent(load_time=load_time, seed=seed, **kwargs) + + agent_class_name = agent.__class__.__name__ + + total_shape = "_".join( + [ + str(i) + for i in ( + self._environment.observation_space.shape + + self._environment.action_space.shape + + self._environment.signal_space.shape + ) + ] + ) + + environment_name = f"{self._environment.environment_name}_{total_shape}" + + save_directory = ( + PROJECT_APP_PATH.user_data / 
environment_name / agent_class_name + ) + log_directory = ( + PROJECT_APP_PATH.user_log + / environment_name + / agent_class_name + / load_time + ) + + if self._environment.action_space.is_discrete: + rollout_drawer = GDKC(DiscreteScrollPlot, + num_actions=self._environment.action_space.discrete_steps, + default_delta=None + ) + else: + rollout_drawer = GDKC(SeriesScrollPlot, + window_length=100, + default_delta=None + ) - if inspect.isclass(agent): - sprint( - "Instantiating Agent", color="crimson", bold=True, italic=True - ) - torch_seed(seed) - self._environment.seed(seed) - - agent = agent(load_time=load_time, seed=seed, **kwargs) - - agent_class_name = agent.__class__.__name__ - - total_shape = "_".join( - [ - str(i) - for i in ( - self._environment.observation_space.shape - + self._environment.action_space.shape - + self._environment.signal_space.shape - ) - ] + if train_agent: # TODO: allow metric writing while not training with flag + metric_writer = GDKC(TensorBoardPytorchWriter, + path=log_directory + ) + else: + metric_writer = GDKC(MockWriter) + + with ContextWrapper(metric_writer, train_agent) as metric_writer: + with ContextWrapper(rollout_drawer, num_envs == 1) as rollout_drawer: + + agent.build( + self._environment.observation_space, + self._environment.action_space, + self._environment.signal_space, + metric_writer=metric_writer, ) - environment_name = f"{self._environment.environment_name}_{total_shape}" - - save_directory = ( - PROJECT_APP_PATH.user_data / environment_name / agent_class_name + kwargs.update( + environment_name=(self._environment.environment_name,), + save_directory=save_directory, + log_directory=log_directory, + load_time=load_time, + seed=seed, + train_agent=train_agent, ) - log_directory = ( - PROJECT_APP_PATH.user_log - / environment_name - / agent_class_name - / load_time - ) - - with TensorBoardPytorchWriter(log_directory) as metric_writer: - - agent.build( - self._environment.observation_space, - self._environment.action_space, - self._environment.signal_space, - metric_writer=metric_writer, - ) - kwargs.update( - environment_name=(self._environment.environment_name,), - save_directory=save_directory, - log_directory=log_directory, - load_time=load_time, - seed=seed, - train_agent=train_agent, + found = False + if continue_training: + sprint( + "Searching for previously trained models for initialisation for this configuration " + "(Architecture, Action Space, Observation Space, ...)", + color="crimson", + bold=True, + italic=True, + ) + found = agent.load( + save_directory=save_directory, evaluation=not train_agent + ) + if not found: + sprint( + "Did not find any previously trained models for this configuration", + color="crimson", + bold=True, + italic=True, ) - found = False - if continue_training: - sprint( - "Searching for previously trained models for initialisation for this configuration " - "(Architecture, Action Space, Observation Space, ...)", - color="crimson", - bold=True, - italic=True, - ) - found = agent.load( - save_directory=save_directory, evaluation=not train_agent - ) - if not found: - sprint( - "Did not find any previously trained models for this configuration", - color="crimson", - bold=True, - italic=True, - ) - - if not train_agent: - agent.eval() - else: - agent.train() - - if not found: - sprint( - "Training from new initialisation", - color="crimson", - bold=True, - italic=True, - ) - - session_proc = self._procedure(agent, environment=self._environment) - - with CaptureEarlyStop( - 
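# A hedged sketch (assumed semantics, not the warg implementation) of the two patterns
# combined above: GDKC defers construction by capturing a constructor plus its kwargs,
# and ContextWrapper enters the wrapped context only when its condition is truthy.
# Plotter and context_if are hypothetical stand-ins.
import contextlib
from functools import partial


class Plotter:
    """Hypothetical stand-in for a drawer/writer that also acts as a context manager."""

    def __init__(self, window_length: int = 100):
        self.window_length = window_length

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


deferred_plotter = partial(Plotter, window_length=100)  # roughly what GDKC(SeriesScrollPlot, ...) captures


@contextlib.contextmanager
def context_if(factory, condition: bool):
    """Enter factory()'s context when condition holds, otherwise yield None."""
    if condition:
        with factory() as obj:
            yield obj
    else:
        yield None


with context_if(deferred_plotter, condition=True) as plotter:
    print(plotter.window_length if plotter else "no plotter")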
callbacks=self._procedure.stop_procedure, **kwargs - ): - with StopWatch() as timer: - with suppress(KeyboardInterrupt): - training_resume = session_proc( - metric_writer=metric_writer, **kwargs - ) - if training_resume and "stats" in training_resume: - training_resume.stats.save(**kwargs) - - end_message = f"Training ended, time elapsed: {timer // 60:.0f}m {timer % 60:.0f}s" - line_width = 9 - sprint( - f'\n{"-" * line_width} {end_message} {"-" * line_width}\n', - color="crimson", - bold=True, - italic=True, - ) + if not train_agent: + agent.eval() + else: + agent.train() + + if not found: + sprint( + "Training from new initialisation", + color="crimson", + bold=True, + italic=True, + ) + + session_proc = self._procedure(agent, **kwargs) + + with CaptureEarlyStop( + callbacks=self._procedure.stop_procedure, **kwargs + ): + with StopWatch() as timer: + with suppress(KeyboardInterrupt): + training_resume = session_proc( + metric_writer=metric_writer, + rollout_drawer=rollout_drawer, + **kwargs + ) + if training_resume and "stats" in training_resume and save_training_resume: + training_resume.stats.save(**kwargs) + + end_message = f"Training ended, time elapsed: {timer // 60:.0f}m {timer % 60:.0f}s" + line_width = 9 + sprint( + f'\n{"-" * line_width} {end_message} {"-" * line_width}\n', + color="crimson", + bold=True, + italic=True, + ) - if save_ending_model: - agent.save(**kwargs) + if save_ending_model: + agent.save(**kwargs) - try: - self._environment.close() - except BrokenPipeError: - pass + try: + self._environment.close() + except BrokenPipeError: + pass - exit(0) + exit(0) if __name__ == "__main__": - print(SingleAgentEnvironmentSession) + print(SingleAgentEnvironmentSession) diff --git a/neodroidagent/entry_points/agent_tests/random_test.py b/neodroidagent/entry_points/agent_tests/random_test.py index 1826c033..c454c819 100644 --- a/neodroidagent/entry_points/agent_tests/random_test.py +++ b/neodroidagent/entry_points/agent_tests/random_test.py @@ -20,8 +20,11 @@ def random_run( rollouts=None, skip_confirmation: bool = True, environment_type: Union[bool, str] = True, - config=random_config, + config=None, ) -> None: + if config is None: + config = random_config + if rollouts: config.ROLLOUTS = rollouts diff --git a/neodroidagent/entry_points/agent_tests/torch_agent_tests/ddpg_test.py b/neodroidagent/entry_points/agent_tests/torch_agent_tests/ddpg_test.py index f1124b44..485575f8 100644 --- a/neodroidagent/entry_points/agent_tests/torch_agent_tests/ddpg_test.py +++ b/neodroidagent/entry_points/agent_tests/torch_agent_tests/ddpg_test.py @@ -32,20 +32,24 @@ def ddpg_run( skip_confirmation: bool = True, environment_type: Union[bool, str] = True, - config=ddpg_config, -): - session_factory( - DeepDeterministicPolicyGradientAgent, - config, - session=ParallelSession, - skip_confirmation=skip_confirmation, - environment=environment_type, - ) + config=None, **kwargs + ): + if config is None: + config = ddpg_config + session_factory( + DeepDeterministicPolicyGradientAgent, + config, + session=ParallelSession, + skip_confirmation=skip_confirmation, + environment=environment_type, **kwargs + ) -def ddpg_test(config=ddpg_config): - ddpg_run(environment_type="gym", config=config) +def ddpg_test(config=None, **kwargs): + if config is None: + config = ddpg_config + ddpg_run(environment_type="gym", config=config, **kwargs) if __name__ == "__main__": - ddpg_test() + ddpg_test() diff --git a/neodroidagent/entry_points/agent_tests/torch_agent_tests/dqn_test.py 
b/neodroidagent/entry_points/agent_tests/torch_agent_tests/dqn_test.py index d753d2a7..be2541b5 100644 --- a/neodroidagent/entry_points/agent_tests/torch_agent_tests/dqn_test.py +++ b/neodroidagent/entry_points/agent_tests/torch_agent_tests/dqn_test.py @@ -37,23 +37,27 @@ def dqn_run( skip_confirmation: bool = True, environment_type: Union[bool, str] = True, - config=dqn_config, + config=None,**kwargs ) -> None: + if config is None: + config = dqn_config session_factory( DeepQNetworkAgent, config, - session=ParallelSession( + session=GDKC(ParallelSession, environment_name=ENVIRONMENT_NAME, procedure=OffPolicyEpisodic, - environment=environment_type, + environment=environment_type,**kwargs ), skip_confirmation=skip_confirmation, - environment=environment_type, + environment=environment_type,**kwargs ) -def dqn_test(config=dqn_config): - dqn_run(environment_type="gym", config=config) +def dqn_test(config=None,**kwargs): + if config is None: + config = dqn_config + dqn_run(environment_type="gym", config=config,**kwargs) if __name__ == "__main__": diff --git a/neodroidagent/entry_points/agent_tests/torch_agent_tests/pg_test.py b/neodroidagent/entry_points/agent_tests/torch_agent_tests/pg_test.py index b5453d49..768e2aed 100644 --- a/neodroidagent/entry_points/agent_tests/torch_agent_tests/pg_test.py +++ b/neodroidagent/entry_points/agent_tests/torch_agent_tests/pg_test.py @@ -32,19 +32,25 @@ def pg_run( skip_confirmation: bool = True, environment_type: Union[bool, str] = True, *, - config=pg_config + config=None,**kwargs + ) -> None: + if config is None: + config = pg_config + session_factory( PolicyGradientAgent, config, session=ParallelSession, skip_confirmation=skip_confirmation, - environment=environment_type, + environment=environment_type,**kwargs ) -def pg_test(config=pg_config) -> None: - pg_run(environment_type="gym", config=config) +def pg_test(config=None,**kwargs) -> None: + if config is None: + config = pg_config + pg_run(environment_type="gym", config=config,**kwargs) if __name__ == "__main__": diff --git a/neodroidagent/entry_points/agent_tests/torch_agent_tests/ppo_test.py b/neodroidagent/entry_points/agent_tests/torch_agent_tests/ppo_test.py index 2a690c83..da052913 100644 --- a/neodroidagent/entry_points/agent_tests/torch_agent_tests/ppo_test.py +++ b/neodroidagent/entry_points/agent_tests/torch_agent_tests/ppo_test.py @@ -43,19 +43,21 @@ ppo_config = globals() -def ppo_test(config=ppo_config): +def ppo_test(config=None,**kwargs): """ @param config: @type config: """ - ppo_run(environment_type="gym", config=config) + if config is None: + config = ppo_config + ppo_run(environment_type="gym", config=config,**kwargs) def ppo_run( skip_confirmation: bool = True, environment_type: Union[bool, str] = True, - config=ppo_config, + config=None,**kwargs ): """ @@ -66,12 +68,14 @@ def ppo_run( @param config: @type config: """ + if config is None: + config = ppo_config session_factory( ProximalPolicyOptimizationAgent, config, session=ParallelSession, environment=environment_type, - skip_confirmation=skip_confirmation, + skip_confirmation=skip_confirmation,**kwargs ) diff --git a/neodroidagent/entry_points/agent_tests/torch_agent_tests/sac_test.py b/neodroidagent/entry_points/agent_tests/torch_agent_tests/sac_test.py index d998946d..18c6328b 100644 --- a/neodroidagent/entry_points/agent_tests/torch_agent_tests/sac_test.py +++ b/neodroidagent/entry_points/agent_tests/torch_agent_tests/sac_test.py @@ -37,25 +37,31 @@ sac_config = globals() -def sac_test(config=sac_config): - 
sac_run(environment_type="gym", config=config) +def sac_test(config=None,**kwargs): + if config is None: + config = sac_config + sac_run(environment_type="gym", config=config,**kwargs) def sac_run( skip_confirmation: bool = True, environment_type: Union[bool, str] = True, - config=sac_config, + config=None,**kwargs ): + if config is None: + config = sac_config session_factory( SoftActorCriticAgent, config, - session=ParallelSession( + session=GDKC(ParallelSession, procedure=OffPolicyStepWise, environment_name=ENVIRONMENT_NAME, auto_reset_on_terminal_state=True, environment=environment_type, + **kwargs ), skip_confirmation=skip_confirmation, + **kwargs ) diff --git a/neodroidagent/entry_points/clean.py b/neodroidagent/entry_points/clean.py index 7bc0fc62..1b7a8fd9 100644 --- a/neodroidagent/entry_points/clean.py +++ b/neodroidagent/entry_points/clean.py @@ -11,47 +11,59 @@ def clean_data() -> None: - print(f"Wiping {PROJECT_APP_PATH.user_data}") - if PROJECT_APP_PATH.user_data.exists(): - data_dir = str(PROJECT_APP_PATH.user_data) - rmtree(data_dir) - else: - PROJECT_APP_PATH.user_data.mkdir() + """ + + """ + print(f"Wiping {PROJECT_APP_PATH.user_data}") + if PROJECT_APP_PATH.user_data.exists(): + data_dir = str(PROJECT_APP_PATH.user_data) + rmtree(data_dir) + else: + PROJECT_APP_PATH.user_data.mkdir() def clean_log() -> None: - print(f"Wiping {PROJECT_APP_PATH.user_log}") - if PROJECT_APP_PATH.user_log.exists(): - log_dir = str(PROJECT_APP_PATH.user_log) - rmtree(log_dir) - else: - PROJECT_APP_PATH.user_log.mkdir() + """ + + """ + print(f"Wiping {PROJECT_APP_PATH.user_log}") + if PROJECT_APP_PATH.user_log.exists(): + log_dir = str(PROJECT_APP_PATH.user_log) + rmtree(log_dir) + else: + PROJECT_APP_PATH.user_log.mkdir() def clean_cache() -> None: - print(f"Wiping {PROJECT_APP_PATH.user_cache}") - if PROJECT_APP_PATH.user_cache.exists(): - cache_dir = str(PROJECT_APP_PATH.user_cache) - rmtree(cache_dir) - else: - PROJECT_APP_PATH.user_cache.mkdir() + """ + + """ + print(f"Wiping {PROJECT_APP_PATH.user_cache}") + if PROJECT_APP_PATH.user_cache.exists(): + cache_dir = str(PROJECT_APP_PATH.user_cache) + rmtree(cache_dir) + else: + PROJECT_APP_PATH.user_cache.mkdir() def clean_config() -> None: - print(f"Wiping {PROJECT_APP_PATH.user_config}") - if PROJECT_APP_PATH.user_config.exists(): - config_dir = str(PROJECT_APP_PATH.user_config) - rmtree(config_dir) - else: - PROJECT_APP_PATH.user_config.mkdir() + """ + + """ + print(f"Wiping {PROJECT_APP_PATH.user_config}") + if PROJECT_APP_PATH.user_config.exists(): + config_dir = str(PROJECT_APP_PATH.user_config) + rmtree(config_dir) + else: + PROJECT_APP_PATH.user_config.mkdir() def clean_all() -> None: - clean_config() - clean_data() - clean_cache() - clean_log() + clean_config() + clean_data() + clean_cache() + clean_log() if __name__ == "__main__": - clean_all() + clean_all() diff --git a/neodroidagent/entry_points/cli.py b/neodroidagent/entry_points/cli.py index b0997eb6..b2522c97 100644 --- a/neodroidagent/entry_points/cli.py +++ b/neodroidagent/entry_points/cli.py @@ -26,62 +26,67 @@ class RunAgent: - def __init__(self, agent_key: str, agent_callable: callable): - self.agent_key = agent_key - self.agent_callable = agent_callable + def __init__(self, agent_key: str, agent_callable: callable): + self.agent_key = agent_key + self.agent_callable = agent_callable - def train(self, **overrides) -> None: - """ + def train(self, **explicit_overrides) -> None: + """ -@param overrides: Accepts kwarg overrides to config +@param explicit_overrides: Accepts 
kwarg overrides to config @return: """ - default_config = NOD(AGENT_CONFIG[self.agent_key]) + default_config = NOD(AGENT_CONFIG[self.agent_key]) - overrides = upper_dict(overrides) - for key, arg in overrides.items(): - setattr(default_config, key, arg) + config_overrides = upper_dict(explicit_overrides) + for key, arg in config_overrides.items(): + setattr(default_config, key, arg) - print("Overrides:") - print(overrides) - print(default_config) + print("Explicit Overrides:") + print(explicit_overrides) + #print(default_config) - self.agent_callable(config=default_config) + self.agent_callable(config=default_config, **explicit_overrides) - def run(self): - self.train(train_agent=False,render_frequency=1,save=False) + def run(self): + self.train(train_agent=False, + render_frequency=1, + save=False, + save_ending=False, + num_envs=1, + save_best_throughout_training=False) class NeodroidAgentCLI: - def __init__(self): - for k, v in AGENT_OPTIONS.items(): - setattr(self, k, RunAgent(k, v)) + def __init__(self): + for k, v in AGENT_OPTIONS.items(): + setattr(self, k, RunAgent(k, v)) - @staticmethod - def version() -> None: - """ + @staticmethod + def version() -> None: + """ Prints the version of this Neodroid installation. """ - draw_cli_header() - print(f"Version: {get_version()}") + draw_cli_header() + print(f"Version: {get_version()}") - @staticmethod - def sponsors() -> None: - print(sponsors) + @staticmethod + def sponsors() -> None: + print(sponsors) -def draw_cli_header(*, title: str = "Neodroid Agent", font: str = "big"): - figlet = Figlet(font=font, justify="center", width=terminal_width) - description = figlet.renderText(title) +def draw_cli_header(*, title: str = "Neodroid Agent", font: str = "big") -> None: + figlet = Figlet(font=font, justify="center", width=terminal_width) + description = figlet.renderText(title) - print(f"{description}{underline}\n") + print(f"{description}{underline}\n") -def main(*, always_draw_header: bool = False): - if always_draw_header: - draw_cli_header() - fire.Fire(NeodroidAgentCLI, name="neodroid-agent") +def main(*, always_draw_header: bool = False) -> None: + if always_draw_header: + draw_cli_header() + fire.Fire(NeodroidAgentCLI, name="neodroid-agent") if __name__ == "__main__": - main() + main() diff --git a/neodroidagent/entry_points/session_factory.py b/neodroidagent/entry_points/session_factory.py index c88bb7f9..28284bd8 100644 --- a/neodroidagent/entry_points/session_factory.py +++ b/neodroidagent/entry_points/session_factory.py @@ -16,11 +16,11 @@ from draugr import sprint from neodroidagent.agents import Agent from neodroidagent.common.session_factory.vertical.environment_session import ( - EnvironmentSession, -) + EnvironmentSession, + ) from neodroidagent.utilities import NoProcedure -from warg import NOD, config_to_mapping +from warg import GDKC, NOD, config_to_mapping AgentType = TypeVar("AgentType", bound=Agent) EnvironmentSessionType = TypeVar("EnvironmentSessionType", bound=EnvironmentSession) @@ -28,51 +28,57 @@ def session_factory( agent: Type[AgentType] = None, - config: Union[object, dict] = {}, + config=None, *, session: Union[Type[EnvironmentSessionType], EnvironmentSession], save: bool = True, has_x_server: bool = True, skip_confirmation: bool = True, **kwargs, -): - r""" + ): + r""" Entry point start a starting a training session with the functionality of parsing cmdline arguments and confirming configuration to use before training and overwriting of default training configurations """ - if isinstance(config, dict): - 
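# A hedged sketch of the override step in RunAgent.train above: CLI kwargs arrive
# lower-case while config attributes are upper-case (e.g. ROLLOUTS), so the keys are
# assumed to be upper-cased before being set on the config. upper_dict here is a
# stand-in for the warg helper of the same name.
def upper_dict(mapping: dict) -> dict:
    return {key.upper(): value for key, value in mapping.items()}


class Config:
    ROLLOUTS = 1000


config = Config()
for key, value in upper_dict({"rollouts": 50, "render_frequency": 1}).items():
    setattr(config, key, value)

print(config.ROLLOUTS, config.RENDER_FREQUENCY)  # 50 1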
config = NOD(**config) - else: - config = NOD(config.__dict__) + if config is None: + config = {} - if has_x_server: - display_env = getenv("DISPLAY", None) - if display_env is None: - config.RENDER_ENVIRONMENT = False - has_x_server = False + if isinstance(config, dict): + config = NOD(**config) + else: + config = NOD(config.__dict__) - config_mapping = config_to_mapping(config) - config_mapping.update(**kwargs) + if has_x_server: + display_env = getenv("DISPLAY", None) + if display_env is None: + config.RENDER_ENVIRONMENT = False + has_x_server = False - if not skip_confirmation: - sprint(f"\nUsing config: {config}\n", highlight=True, color="yellow") - for key, arg in config_mapping: - print(f"{key} = {arg}") + config_mapping = config_to_mapping(config) + config_mapping.update(**kwargs) - sprint(f"\n.. Also save:{save}," f" has_x_server:{has_x_server}") - input("\nPress Enter to begin... ") + config_mapping.update(save=save, has_x_server=has_x_server) - if session is None: - raise NoProcedure - elif inspect.isclass(session): - session = session(**config_mapping) + if not skip_confirmation: + sprint(f"\nUsing config: {config}\n", highlight=True, color="yellow") + for key, arg in config_mapping: + print(f"{key} = {arg}") - try: - session(agent, save=save, has_x_server=has_x_server, **config_mapping) - except KeyboardInterrupt: - print("Stopping") + input("\nPress Enter to begin... ") - torch.cuda.empty_cache() + if session is None: + raise NoProcedure + elif inspect.isclass(session): + session = session(**config_mapping) # Use passed config arguments + elif isinstance(session,GDKC): + session = session(**kwargs) # Assume some kw parameters is set prior to passing session, only override with explicit overrides - exit(0) + try: + session(agent, **config_mapping) + except KeyboardInterrupt: + print("Stopping") + + torch.cuda.empty_cache() + + exit(0) diff --git a/neodroidagent/utilities/exploration/intrinsic_signals/torch_isp/curiosity/icm.py b/neodroidagent/utilities/exploration/intrinsic_signals/torch_isp/curiosity/icm.py index fda778ab..ef327f8e 100644 --- a/neodroidagent/utilities/exploration/intrinsic_signals/torch_isp/curiosity/icm.py +++ b/neodroidagent/utilities/exploration/intrinsic_signals/torch_isp/curiosity/icm.py @@ -8,7 +8,7 @@ from torch.distributions import Categorical from torch.nn import CrossEntropyLoss, MSELoss -from draugr.torch_utilities.tensors.to_tensor import to_tensor +from draugr.torch_utilities import to_tensor from draugr.writers import Writer from neodroid.utilities import ActionSpace, ObservationSpace, SignalSpace diff --git a/neodroidagent/utilities/misc/__init__.py b/neodroidagent/utilities/misc/__init__.py index b431c59b..dca1a01b 100644 --- a/neodroidagent/utilities/misc/__init__.py +++ b/neodroidagent/utilities/misc/__init__.py @@ -5,7 +5,6 @@ __doc__ = r""" """ -from .bool_tests import * from .checks import * from .sampling import * from .target_updates import * diff --git a/neodroidagent/utilities/misc/bool_tests.py b/neodroidagent/utilities/misc/bool_tests.py deleted file mode 100644 index 1f2ae49e..00000000 --- a/neodroidagent/utilities/misc/bool_tests.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -__author__ = "Christian Heider Nielsen" -__doc__ = r""" - - Created on 02/01/2020 - """ - -__all__ = [ - "is_positive_and_mod_zero", - "is_zero_or_mod_zero", - "is_none_or_zero_or_negative", - "is_zero_or_mod_below", - "is_none_or_zero_or_negative_or_mod_zero", -] - -from typing import Any - -from warg import 
drop_unused_kws, passes_kws_to - - -@drop_unused_kws -def is_positive_and_mod_zero( - mod: int, counter: int, *, ret: Any = True, alt: Any = False -) -> Any: - """ - -test if mod is positive -then test if counter % mod is 0 -if both tests are true return ret -else return alt - - -@param mod: -@param counter: -@param ret: -@param alt: -@return: -""" - - return ret if (mod > 0 and (counter % mod == 0)) else alt - - -@drop_unused_kws -def is_zero_or_mod_below( - mod: int, below: int, counter: int, *, ret: Any = True, alt: Any = False -) -> Any: - """ - -test if mod is zero or if counter % mod is 0 -if any of the tests are true return ret -else return alt - - - @param below: - @type below: -@param mod: -@param counter: -@param ret: -@param alt: -@return: -""" - return ret if (mod == 0 or (counter % mod < below)) else alt - - -@drop_unused_kws -def is_zero_or_mod_zero( - mod: int, counter: int, *, ret: Any = True, alt: Any = False -) -> Any: - """ - -test if mod is zero or if counter % mod is 0 -if any of the tests are true return ret -else return alt - - -@param mod: -@param counter: -@param ret: -@param alt: -@return: -""" - return ret if (mod == 0 or (counter % mod == 0)) else alt - - -def is_none_or_zero_or_negative(obj: Any) -> bool: - """ - -@param obj: -@return: -""" - is_none = obj is None - is_negative = False - if isinstance(obj, (int, float)): - is_negative = obj <= 0 - - return is_none or is_negative - - -@passes_kws_to(is_zero_or_mod_zero) -def is_none_or_zero_or_negative_or_mod_zero(mod: int, counter: int, **kwargs) -> bool: - """ - -@param mod: -@param counter: -@param kwargs: -@return: -""" - return is_none_or_zero_or_negative(mod) or is_zero_or_mod_zero( - mod, counter, **kwargs - ) - - -if __name__ == "__main__": - assert is_zero_or_mod_below(5, 3, 7) == True - assert is_zero_or_mod_below(5, 2, 4) == False - for i in range(9): - print(is_zero_or_mod_zero(1, i)) - for i in range(9): - print(is_zero_or_mod_zero(2, i)) diff --git a/setup.py b/setup.py index e08a3265..d707e226 100644 --- a/setup.py +++ b/setup.py @@ -29,8 +29,9 @@ def python_version_check(major=3, minor=6): __author__ = author +__all__ = ['NeodroidAgentPackage'] -class NeodroidAgentPackage: +class NeodroidAgentPackageMeta(type): @property def dependencies_testing(self) -> list: return ["pytest", "mock"] @@ -81,13 +82,14 @@ def maintainer_email(self) -> str: @property def package_data(self) -> dict: - # data = glob.glob('data/', recursive=True) + emds = [str(p) for p in pathlib.Path(__file__).parent.rglob('.md')] return { - # 'PackageName':[ - # *data - # ] + 'neodroidagent':[ + *emds + ] } + @property def entry_points(self) -> dict: return { @@ -181,10 +183,13 @@ def classifiers(self) -> List[str]: def version(self) -> str: return version +class NeodroidAgentPackage(metaclass=NeodroidAgentPackageMeta): + pass + if __name__ == "__main__": - pkg = NeodroidAgentPackage() + pkg = NeodroidAgentPackage setup( name=pkg.package_name,
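# A minimal sketch of the metaclass-property pattern adopted for NeodroidAgentPackage
# above: properties defined on the metaclass are resolved on the class object itself,
# so the setup metadata can be read without instantiating anything. Names here are
# illustrative.
class PackageMeta(type):
    @property
    def package_name(cls) -> str:
        return "neodroidagent"


class Package(metaclass=PackageMeta):
    pass


print(Package.package_name)  # "neodroidagent", no instance needed

# Separately, note that pathlib's rglob takes a glob pattern: rglob("*.md") collects
# markdown files, whereas rglob(".md") only matches files literally named ".md".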