From 1f5fe6bdf29483f3f64addb03c77dc3b009a85b8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 18 Nov 2022 11:39:35 +0000 Subject: [PATCH 1/5] init --- examples/ddpg/ddpg.py | 2 +- examples/dqn/dqn.py | 2 +- examples/dreamer/dreamer.py | 2 +- examples/dreamer/dreamer_utils.py | 2 +- examples/ppo/ppo.py | 2 +- examples/redq/redq.py | 2 +- examples/sac/sac.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 37632b66115..599417558e8 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -117,7 +117,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py index ad63711e4a0..64e2b3ca058 100644 --- a/examples/dqn/dqn.py +++ b/examples/dqn/dqn.py @@ -107,7 +107,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 4e903731219..602d49b7256 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_environment=transformed_env_constructor(cfg)(), - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) stats = {k: v.clone() for k, v in stats.items()} elif cfg.from_pixels: diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 0c69edc5c60..0eef5be0413 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -307,7 +307,7 @@ def call_record( if cfg.record_video and record._count % cfg.record_interval == 0: world_model_td = sampled_tensordict - true_pixels = recover_pixels(world_model_td["next_pixels"], stats) + true_pixels = recover_pixels(world_model_td[("next", "pixels")], stats) reco_pixels = recover_pixels(world_model_td["next", "reco_pixels"], stats) with autocast(dtype=torch.float16): diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index 93b73c82bfe..a992b127306 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -103,7 +103,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/redq/redq.py b/examples/redq/redq.py index e911f91fd78..654ee082cb4 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -118,7 +118,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/sac/sac.py b/examples/sac/sac.py index d400ec64701..2aa61d3df68 100644 --- a/examples/sac/sac.py +++ 
b/examples/sac/sac.py @@ -118,7 +118,7 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key="next_pixels" if cfg.from_pixels else "next_observation_vector", + key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() From f98d4c6b03ad58567497d76a0e38e7c75f6030bd Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 18 Nov 2022 13:27:30 +0000 Subject: [PATCH 2/5] tests1 --- .../linux_examples/scripts/environment.yml | 30 ++ .../linux_examples/scripts/install.sh | 46 +++ .../linux_examples/scripts/post_process.sh | 6 + .../scripts/run-clang-format.py | 356 ++++++++++++++++++ .../linux_examples/scripts/run_test.sh | 38 ++ .../linux_examples/scripts/setup_env.sh | 88 +++++ 6 files changed, 564 insertions(+) create mode 100644 .circleci/unittest/linux_examples/scripts/environment.yml create mode 100755 .circleci/unittest/linux_examples/scripts/install.sh create mode 100755 .circleci/unittest/linux_examples/scripts/post_process.sh create mode 100755 .circleci/unittest/linux_examples/scripts/run-clang-format.py create mode 100755 .circleci/unittest/linux_examples/scripts/run_test.sh create mode 100755 .circleci/unittest/linux_examples/scripts/setup_env.sh diff --git a/.circleci/unittest/linux_examples/scripts/environment.yml b/.circleci/unittest/linux_examples/scripts/environment.yml new file mode 100644 index 00000000000..3987e1f4e9b --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/environment.yml @@ -0,0 +1,30 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - protobuf + - pip: + - hypothesis + - future + - cloudpickle + - gym + - pygame + - gym[accept-rom-license] + - gym[atari] + - moviepy + - tqdm + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - expecttest + - pyyaml + - scipy + - hydra-core + - tensorboard + - wandb + - dm_control + - mlflow + - av + - coverage diff --git a/.circleci/unittest/linux_examples/scripts/install.sh b/.circleci/unittest/linux_examples/scripts/install.sh new file mode 100755 index 00000000000..ad126f23b0a --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/install.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
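+#
+# CU_VERSION (read below) is either "cpu" or a compact CUDA tag: a 4-character
+# tag such as "cu92" parses to "9.2", a 5-character tag such as "cu113" parses
+# to "11.3", via the bash substring expansions in the block that follows.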
+ +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" + echo "Using cpu build" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu +else + pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113 +fi + +# smoke test +python -c "import functorch" + +# install snapshot +pip install git+https://github.com/pytorch/torchsnapshot + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict + +printf "* Installing torchrl\n" +python setup.py develop diff --git a/.circleci/unittest/linux_examples/scripts/post_process.sh b/.circleci/unittest/linux_examples/scripts/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_examples/scripts/run-clang-format.py b/.circleci/unittest/linux_examples/scripts/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. 
+ +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh new file mode 100755 index 00000000000..3ac0e1585db --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# this code is supposed to run on CPU +# rendering with the combination of packages we have here in headless mode +# is hard to nail. +# IMPORTANT: As a consequence, we can't guarantee TorchRL compatibility with +# rendering with this version of gym / mujoco-py. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU + +coverage run -m pytest test/smoke_test.py -v --durations 20 +coverage run -m pytest test/smoke_test_deps.py -v --durations 20 +coverage run -m python examples/ddpg/ddpg.py \ + total_frames=14 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_devices=cpu \ + optim_steps_per_batch=1 +coverage xml -i diff --git a/.circleci/unittest/linux_examples/scripts/setup_env.sh b/.circleci/unittest/linux_examples/scripts/setup_env.sh new file mode 100755 index 00000000000..00a21db6cc1 --- /dev/null +++ b/.circleci/unittest/linux_examples/scripts/setup_env.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. 
+# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install mujoco +printf "* Installing mujoco and related\n" +mkdir -p $root_dir/.mujoco +cd $root_dir/.mujoco/ +wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz +tar -xf mujoco210-linux-x86_64.tar.gz +cd $this_dir + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + + +if [[ $OSTYPE == 'darwin'* ]]; then + PRIVATE_MUJOCO_GL=glfw +elif [ "${CU_VERSION:-}" == cpu ]; then + PRIVATE_MUJOCO_GL=osmesa +else + PRIVATE_MUJOCO_GL=egl +fi + +export MUJOCO_GL=$PRIVATE_MUJOCO_GL +conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \ + DISPLAY=unix:0.0 \ + MJLIB_PATH=$root_dir/.mujoco/mujoco-2.1.1/lib/libmujoco.so.2.1.1 \ + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$root_dir/.mujoco/mujoco210/bin \ + SDL_VIDEODRIVER=dummy \ + MUJOCO_GL=$PRIVATE_MUJOCO_GL \ + PYOPENGL_PLATFORM=$PRIVATE_MUJOCO_GL + +# Software rendering requires GLX and OSMesa. 
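+# Rough rationale for the backend choice above: glfw needs a display server
+# (fine on macOS), osmesa renders in software (safe on CPU-only CI), and egl
+# renders headlessly on the GPU; the headless backends rely on the Mesa/GLEW
+# packages installed below.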
+if [ $PRIVATE_MUJOCO_GL == 'egl' ] || [ $PRIVATE_MUJOCO_GL == 'osmesa' ] ; then + yum makecache + yum install -y glfw + yum install -y glew + yum install -y mesa-libGL + yum install -y mesa-libGL-devel + yum install -y mesa-libOSMesa-devel + yum -y install egl-utils + yum -y install freeglut +fi + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune From f3f174a5245edf1415189dc9b0c48b771e00e221 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 18 Nov 2022 13:42:06 +0000 Subject: [PATCH 3/5] run examples in tests --- .circleci/config.yml | 56 +++++++++++++++++++ .../linux_examples/scripts/run_test.sh | 15 ++++- examples/ddpg/ddpg.py | 2 +- examples/dqn/dqn.py | 2 +- torchrl/envs/transforms/transforms.py | 13 ++++- 5 files changed, 81 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6fe31e1ecf4..35e1edd7ecf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -347,6 +347,57 @@ jobs: - store_test_results: path: test-results + + unittest_linux_examples_gpu: + <<: *binary_common + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda113" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + CU_VERSION: << parameters.cu_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_examples/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + - run: + name: Setup + command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_examples/scripts/setup_env.sh + - save_cache: + + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_examples/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }} + + paths: + - conda + - env + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_examples/scripts/install.sh + - run: + name: Run tests + command: bash .circleci/unittest/linux_examples/scripts/run_test.sh +# command: docker run --env-file ./env.list -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/run_test.sh + - run: + name: Codecov upload + command: | + bash <(curl -s https://codecov.io/bash) -Z -F linux_examples-gpu + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_examples/scripts/post_process.sh + - store_test_results: + path: test-results + unittest_linux_habitat_gpu: <<: *binary_common machine: @@ -888,3 +939,8 @@ workflows: cu_version: cu113 name: unittest_linux_olddeps_gpu_py3.9 python_version: '3.9' + + - unittest_linux_examples_gpu: + cu_version: cu113 + name: unittest_linux_examples_gpu_py3.9 + python_version: '3.9' diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 3ac0e1585db..4ac07a2c2d5 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -33,6 +33,17 @@ coverage run -m python examples/ddpg/ddpg.py \ 
frames_per_batch=16 \ num_workers=2 \ env_per_collector=1 \ - collector_devices=cpu \ - optim_steps_per_batch=1 + collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True +coverage run -m python examples/dqn/dqn.py \ + total_frames=14 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True coverage xml -i diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index 599417558e8..ad65b4cd737 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -191,7 +191,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py index 64e2b3ca058..c776a52fecb 100644 --- a/examples/dqn/dqn.py +++ b/examples/dqn/dqn.py @@ -161,7 +161,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b3c92d15932..e2ae5663b86 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -215,12 +215,20 @@ def __repr__(self) -> str: def set_parent(self, parent: Union[Transform, EnvBase]) -> None: if self.__dict__["_parent"] is not None: - raise AttributeError("parent of transform already set") + raise AttributeError( + "parent of transform already set. " + "Call `transform.clone()` to get a similar transform with no parent set." + ) self.__dict__["_parent"] = parent def reset_parent(self) -> None: self.__dict__["_parent"] = None + def clone(self): + self_copy = copy(self) + self_copy.reset_parent() + return self_copy + @property def parent(self) -> EnvBase: if not hasattr(self, "_parent"): @@ -242,8 +250,7 @@ def parent(self) -> EnvBase: f"Compose parent was of type {type(compose_parent)} but expected TransformedEnv." ) if compose_parent.transform is not compose: - comp_parent_trans = copy(compose_parent.transform) - comp_parent_trans.reset_parent() + comp_parent_trans = compose_parent.transform.clone() else: comp_parent_trans = None out = TransformedEnv( From c60ace54e0bb9deab4ef4cd517e676d3355c2a2d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Nov 2022 11:26:11 +0000 Subject: [PATCH 4/5] [Feature] MPPI Planner (#694) * amend * [BugFix] ConvNet forward method with tensors of more than 4 dimensions (#686) * cnn forward fix * more general code * cnn testing * precommit run check * convnet tests * [Feature] add `standard_normal` for RewardScaling (#682) * Add standard_normal * give attribute access * Update standard_normal * Update tests * Fix tests * Address in-place scaling of reward * Improvise tests * [Feature] Jumanji envs (#674) * amend * [Feature] Default collate_fn (#688) * init * amend * amend * [BugFix] Fix Examples (#687) * amend * [Refactoring] Replace direct gym version checks with decorated functions (#691) * [Refactoring] Replace gym version checking with decorated functions (#) Initial commit. Only tests. 
* Refactoring in gym.py * More refactoring in gym.py * Completed refactoring * amend * amend * Version 0.0.3 (#696) * [Docs] Host TensorDict docs inside TorchRL docs (#693) * Pull tensordict docs into TorchRL docs * Add banner for tensordict docs * [BugFix] Fix docs build (#698) * [BugFix] Proper error messages for orphan transform creation (#697) * amend * [Feature] Append, init and insert transforms in ReplayBuffer (#695) * lint Co-authored-by: albertbou92 Co-authored-by: Aditya Gandhamal <61016383+adityagandhamal@users.noreply.github.com> Co-authored-by: yingchenlin Co-authored-by: Sergey Ordinskiy <113687736+ordinskiy@users.noreply.github.com> Co-authored-by: Tom Begley Co-authored-by: Alan Schelten --- .circleci/config.yml | 49 +++ .../linux_examples/scripts/run_test.sh | 64 ++- .../scripts_jumanji/environment.yml | 18 + .../linux_libs/scripts_jumanji/install.sh | 45 ++ .../scripts_jumanji/post_process.sh | 6 + .../scripts_jumanji/run-clang-format.py | 356 ++++++++++++++++ .../linux_libs/scripts_jumanji/run_test.sh | 32 ++ .../linux_libs/scripts_jumanji/setup_env.sh | 62 +++ .../unittest/linux_optdeps/scripts/install.sh | 2 +- .github/workflows/docs.yml | 4 + .github/workflows/wheels.yml | 6 +- docs/source/index.rst | 3 + docs/source/reference/envs.rst | 3 + examples/ddpg/ddpg.py | 4 +- examples/dqn/dqn.py | 4 +- examples/dreamer/dreamer.py | 5 +- examples/dreamer/dreamer_utils.py | 4 +- examples/ppo/ppo.py | 6 +- examples/redq/redq.py | 6 +- examples/sac/sac.py | 6 +- setup.py | 2 +- test/_utils_internal.py | 76 ++-- test/smoke_test_deps.py | 13 +- test/test_collector.py | 24 +- test/test_env.py | 31 +- test/test_libs.py | 157 +++++-- test/test_modules.py | 51 ++- test/test_rb.py | 186 +++++++-- test/test_trainer.py | 36 +- test/test_transforms.py | 89 +++- torchrl/data/replay_buffers/rb_prototype.py | 50 ++- torchrl/data/replay_buffers/replay_buffers.py | 25 +- torchrl/data/replay_buffers/storages.py | 74 +++- torchrl/data/tensor_specs.py | 43 +- torchrl/envs/common.py | 6 +- torchrl/envs/libs/gym.py | 134 +++--- torchrl/envs/libs/jumanji.py | 391 ++++++++++++++++++ torchrl/envs/transforms/__init__.py | 3 +- torchrl/envs/transforms/transforms.py | 103 +++-- torchrl/modules/models/models.py | 9 + torchrl/modules/planners/cem.py | 126 +++--- torchrl/modules/planners/mppi.py | 252 +++++++++++ torchrl/trainers/helpers/envs.py | 11 +- torchrl/trainers/helpers/replay_buffer.py | 2 - torchrl/trainers/loggers/csv.py | 3 + version.txt | 2 +- 46 files changed, 2178 insertions(+), 406 deletions(-) create mode 100644 .circleci/unittest/linux_libs/scripts_jumanji/environment.yml create mode 100755 .circleci/unittest/linux_libs/scripts_jumanji/install.sh create mode 100755 .circleci/unittest/linux_libs/scripts_jumanji/post_process.sh create mode 100755 .circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py create mode 100755 .circleci/unittest/linux_libs/scripts_jumanji/run_test.sh create mode 100755 .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh create mode 100644 torchrl/envs/libs/jumanji.py create mode 100644 torchrl/modules/planners/mppi.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 35e1edd7ecf..ed602992ee4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -453,6 +453,51 @@ jobs: - store_test_results: path: test-results + unittest_linux_jumanji_gpu: + <<: *binary_common + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda113" + TAR_OPTIONS: 
--no-same-owner + PYTHON_VERSION: << parameters.python_version >> + CU_VERSION: << parameters.cu_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_jumanji/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh + - save_cache: + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_jumanji/environment.yml" }}-{{ checksum ".circleci-weekly" }} + paths: + - conda + - env + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_libs/scripts_jumanji/install.sh + - run: + name: Run tests + command: bash .circleci/unittest/linux_libs/scripts_jumanji/run_test.sh + - run: + name: Codecov upload + command: | + bash <(curl -s https://codecov.io/bash) -Z -F linux-jumanji + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_jumanji/post_process.sh + - store_test_results: + path: test-results + unittest_linux_gym_gpu: <<: *binary_common machine: @@ -879,6 +924,10 @@ workflows: cu_version: cu113 name: unittest_linux_habitat_gpu_py3.8 python_version: '3.8' + - unittest_linux_jumanji_gpu: + cu_version: cu113 + name: unittest_linux_jumanji_gpu_py3.8 + python_version: '3.8' - unittest_linux_gym_gpu: cu_version: cu113 name: unittest_linux_gym_gpu_py3.8 diff --git a/.circleci/unittest/linux_examples/scripts/run_test.sh b/.circleci/unittest/linux_examples/scripts/run_test.sh index 4ac07a2c2d5..880d4baa43d 100755 --- a/.circleci/unittest/linux_examples/scripts/run_test.sh +++ b/.circleci/unittest/linux_examples/scripts/run_test.sh @@ -26,8 +26,8 @@ export MKL_THREADING_LAYER=GNU coverage run -m pytest test/smoke_test.py -v --durations 20 coverage run -m pytest test/smoke_test_deps.py -v --durations 20 -coverage run -m python examples/ddpg/ddpg.py \ - total_frames=14 \ +coverage run examples/ddpg/ddpg.py \ + total_frames=48 \ init_random_frames=10 \ batch_size=10 \ frames_per_batch=16 \ @@ -35,9 +35,11 @@ coverage run -m python examples/ddpg/ddpg.py \ env_per_collector=1 \ collector_devices=cuda:0 \ optim_steps_per_batch=1 \ - record_video=True -coverage run -m python examples/dqn/dqn.py \ - total_frames=14 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +coverage run examples/dqn/dqn.py \ + total_frames=48 \ init_random_frames=10 \ batch_size=10 \ frames_per_batch=16 \ @@ -45,5 +47,55 @@ coverage run -m python examples/dqn/dqn.py \ env_per_collector=1 \ collector_devices=cuda:0 \ optim_steps_per_batch=1 \ - record_video=True + record_video=True \ + record_frames=4 \ + buffer_size=120 +coverage run examples/redq/redq.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +coverage run examples/sac/sac.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + 
collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 +coverage run examples/ppo/ppo.py \ + total_frames=48 \ + batch_size=10 \ + frames_per_batch=16 \ + num_workers=2 \ + env_per_collector=1 \ + collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + lr_scheduler= +coverage run examples/dreamer/dreamer.py \ + total_frames=48 \ + init_random_frames=10 \ + batch_size=10 \ + frames_per_batch=200 \ + num_workers=2 \ + env_per_collector=1 \ + collector_devices=cuda:0 \ + optim_steps_per_batch=1 \ + record_video=True \ + record_frames=4 \ + buffer_size=120 \ + rssm_hidden_dim=17 coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/environment.yml b/.circleci/unittest/linux_libs/scripts_jumanji/environment.yml new file mode 100644 index 00000000000..a7456ada46f --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/environment.yml @@ -0,0 +1,18 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - expecttest + - pyyaml + - scipy + - hydra-core + - jumanji diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/install.sh b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh new file mode 100755 index 00000000000..c0f97977649 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/install.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# smoke test +python -c "import functorch" + +printf "* Installing torchrl\n" +pip3 install -e . 
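+# (editable install: the checkout itself is the importable package, so local
+# changes take effect without reinstalling)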
+ +# smoke test +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/post_process.sh b/.circleci/unittest/linux_libs/scripts_jumanji/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py b/.circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. 
+ dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + 
print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/run_test.sh b/.circleci/unittest/linux_libs/scripts_jumanji/run_test.sh new file mode 100755 index 00000000000..ba250511c90 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/run_test.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +#wget https://github.com/openai/mujoco-py/blob/master/vendor/10_nvidia.json +#mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json + +# this workflow only tests the libs +python -c "import jumanji" + +coverage run -m pytest test/test_libs.py --instafail -v --durations 20 --capture no -k TestJumanji +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh b/.circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh new file mode 100755 index 00000000000..705bd9a3814 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. 
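+# PyTorch is installed separately by install.sh, so a fresh nightly wheel is
+# fetched on every run instead of being frozen into the cached environment.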
+ +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +## 3. Install mujoco +#printf "* Installing mujoco and related\n" +#mkdir $root_dir/.mujoco +#cd $root_dir/.mujoco/ +#wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz +#tar -xf mujoco-2.1.1-linux-x86_64.tar.gz +#wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz +#tar -xf mujoco210-linux-x86_64.tar.gz +#cd $this_dir + +# 4. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +#yum makecache +#yum -y install glfw-devel +#yum -y install libGLEW +#yum -y install gcc-c++ diff --git a/.circleci/unittest/linux_optdeps/scripts/install.sh b/.circleci/unittest/linux_optdeps/scripts/install.sh index e5bfe588fb5..84951e95f24 100755 --- a/.circleci/unittest/linux_optdeps/scripts/install.sh +++ b/.circleci/unittest/linux_optdeps/scripts/install.sh @@ -42,7 +42,7 @@ pip install git+https://github.com/pytorch-labs/tensordict python -c "import functorch" printf "* Installing torchrl\n" -python setup.py develop +pip3 install -e . # smoke test python -c "import torchrl" diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 033f41de9fc..8b9b3fe21d5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -74,6 +74,10 @@ jobs: cd ./docs conda run -n build_binary make html cd .. + - name: Pull TensorDict docs + run: | + git clone --branch gh-pages https://github.com/pytorch-labs/tensordict.git docs/build/html/tensordict + rm -rf docs/build/html/tensordict/.git - name: Get output time run: echo "The time was ${{ steps.build.outputs.time }}" - name: Deploy diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 4633c83aac0..b080a8c8c8c 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -4,7 +4,7 @@ on: types: [opened, synchronize, reopened] push: branches: - - release/0.0.2a + - release/0.0.3 jobs: @@ -26,7 +26,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.0.2a python3 setup.py bdist_wheel + BUILD_VERSION=0.0.3 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . 
-name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -66,7 +66,7 @@ jobs: run: | export CC=clang CXX=clang++ python3 -mpip install wheel - BUILD_VERSION=0.0.2a python3 setup.py bdist_wheel + BUILD_VERSION=0.0.3 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/docs/source/index.rst b/docs/source/index.rst index 2eecbe8cf1a..39c90e1841b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,6 +6,9 @@ Welcome to the TorchRL Documentation! ===================================== +.. note:: + The TensorDict class has been moved out of TorchRL into a dedicated library. Take a look at `the documentation <./tensordict>`_ or find the source code `on GitHub `_. + TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch. It provides pytorch and python-first, low and high level abstractions for RL that are intended to be efficient, modular, documented and properly tested. diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 7112d8a59f2..ad84d38bdc1 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -248,3 +248,6 @@ Libraries gym.GymWrapper dm_control.DMControlEnv dm_control.DMControlWrapper + jumanji.JumanjiEnv + jumanji.JumanjiWrapper + habitat.HabitatEnv diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py index ad65b4cd737..5fb84a2b25f 100644 --- a/examples/ddpg/ddpg.py +++ b/examples/ddpg/ddpg.py @@ -117,7 +117,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/dqn/dqn.py b/examples/dqn/dqn.py index c776a52fecb..fea51d0667c 100644 --- a/examples/dqn/dqn.py +++ b/examples/dqn/dqn.py @@ -107,7 +107,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() diff --git a/examples/dreamer/dreamer.py b/examples/dreamer/dreamer.py index 602d49b7256..f6926638182 100644 --- a/examples/dreamer/dreamer.py +++ b/examples/dreamer/dreamer.py @@ -126,7 +126,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_environment=transformed_env_constructor(cfg)(), - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) stats = {k: v.clone() for k, v in stats.items()} elif cfg.from_pixels: @@ -246,6 +248,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # To be closer to the paper, we would need to fill it with trajectories of lentgh 1000 and then sample subsequences of length batch_length. 
# tensordict = tensordict.reshape(-1, cfg.batch_length) + print(tensordict.shape) replay_buffer.extend(tensordict.cpu()) logger.log_scalar( "r_training", diff --git a/examples/dreamer/dreamer_utils.py b/examples/dreamer/dreamer_utils.py index 0eef5be0413..341f97739cc 100644 --- a/examples/dreamer/dreamer_utils.py +++ b/examples/dreamer/dreamer_utils.py @@ -92,7 +92,7 @@ def make_env_transforms( env.append_transform(Resize(cfg.image_size, cfg.image_size)) if cfg.grayscale: env.append_transform(GrayScale()) - env.append_transform(FlattenObservation()) + env.append_transform(FlattenObservation(0)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = {"loc": 0.0, "scale": 1.0} @@ -354,7 +354,7 @@ def make_recorder_env(cfg, video_tag, stats, logger, create_env_fn): recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/examples/ppo/ppo.py b/examples/ppo/ppo.py index a992b127306..a5b7ea1da96 100644 --- a/examples/ppo/ppo.py +++ b/examples/ppo/ppo.py @@ -103,7 +103,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() @@ -162,7 +164,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/examples/redq/redq.py b/examples/redq/redq.py index 654ee082cb4..398d8610368 100644 --- a/examples/redq/redq.py +++ b/examples/redq/redq.py @@ -118,7 +118,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() @@ -191,7 +193,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/examples/sac/sac.py b/examples/sac/sac.py index 2aa61d3df68..e23f567dc36 100644 --- a/examples/sac/sac.py +++ b/examples/sac/sac.py @@ -118,7 +118,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stats = get_stats_random_rollout( cfg, proof_env, - key=("next", "pixels") if cfg.from_pixels else ("next", "observation_vector"), + key=("next", "pixels") + if cfg.from_pixels + else ("next", "observation_vector"), ) # make sure proof_env is closed proof_env.close() @@ -187,7 +189,7 @@ def main(cfg: "DictConfig"): # noqa: F821 recorder_rm = TransformedEnv(recorder.base_env) for transform in recorder.transform: if not isinstance(transform, VideoRecorder): - recorder_rm.append_transform(transform) + recorder_rm.append_transform(transform.clone()) else: recorder_rm = recorder diff --git a/setup.py b/setup.py index bba67f6c94f..341f012e247 100644 --- a/setup.py +++ b/setup.py @@ -207,7 
+207,7 @@ def _main(argv): "numpy", "packaging", "cloudpickle", - "tensordict-nightly", + "tensordict", ], extras_require={ "atari": [ diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 421c61b08a4..c5e7bb6ea45 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -11,14 +11,43 @@ # this returns relative path from current file. import pytest import torch.cuda -from tensordict.tensordict import TensorDictBase -from torchrl._utils import seed_generator +from torchrl._utils import seed_generator, implement_for from torchrl.envs import EnvBase - +from torchrl.envs.libs.gym import _has_gym # Specified for test_utils.py __version__ = "0.3" +# Default versions of the environments. +CARTPOLE_VERSIONED = "CartPole-v1" +HALFCHEETAH_VERSIONED = "HalfCheetah-v4" +PENDULUM_VERSIONED = "Pendulum-v1" +PONG_VERSIONED = "ALE/Pong-v5" + + +@implement_for("gym", None, "0.21.0") +def _set_gym_environments(): # noqa: F811 + global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + + CARTPOLE_VERSIONED = "CartPole-v0" + HALFCHEETAH_VERSIONED = "HalfCheetah-v2" + PENDULUM_VERSIONED = "Pendulum-v0" + PONG_VERSIONED = "Pong-v4" + + +@implement_for("gym", "0.21.0", None) +def _set_gym_environments(): # noqa: F811 + global CARTPOLE_VERSIONED, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, PONG_VERSIONED + + CARTPOLE_VERSIONED = "CartPole-v1" + HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + PENDULUM_VERSIONED = "Pendulum-v1" + PONG_VERSIONED = "ALE/Pong-v5" + + +if _has_gym: + _set_gym_environments() + def get_relative_path(curr_file, *path_components): return os.path.join(os.path.dirname(curr_file), *path_components) @@ -48,10 +77,13 @@ def _test_fake_tensordict(env: EnvBase): keys1 = set(fake_tensordict.keys()) keys2 = set(real_tensordict.keys()) assert keys1 == keys2 - fake_tensordict = fake_tensordict.expand(3).to_tensordict() - fake_tensordict.zero_() - real_tensordict.zero_() - assert (fake_tensordict == real_tensordict).all() + fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) + fake_tensordict = fake_tensordict.expand(*real_tensordict.shape) + fake_tensordict = fake_tensordict.to_tensordict() + assert ( + fake_tensordict.apply(lambda x: torch.zeros_like(x)) + == real_tensordict.apply(lambda x: torch.zeros_like(x)) + ).all() for key in keys2: assert fake_tensordict[key].shape == real_tensordict[key].shape @@ -61,26 +93,20 @@ def _test_fake_tensordict(env: EnvBase): def _check_dtype(key, value, obs_spec, input_spec): - if isinstance(value, TensorDictBase) and key == "next": - for _key, _value in value.items(): - _check_dtype(_key, _value, obs_spec, input_spec=None) - elif isinstance(value, TensorDictBase) and key in obs_spec.keys(): - for _key, _value in value.items(): - _check_dtype(_key, _value, obs_spec=obs_spec[key], input_spec=None) - elif isinstance(value, TensorDictBase) and key in input_spec.keys(): + if key in {"reward", "done"}: + return + elif key == "next": for _key, _value in value.items(): - _check_dtype(_key, _value, obs_spec=None, input_spec=input_spec[key]) + _check_dtype(_key, _value, obs_spec, input_spec) + return + elif key in input_spec.keys(yield_nesting_keys=True): + assert input_spec[key].is_in(value), (input_spec[key], value) + return + elif key in obs_spec.keys(yield_nesting_keys=True): + assert obs_spec[key].is_in(value), (obs_spec[key], value) + return else: - if obs_spec is not None and key in obs_spec.keys(): - assert ( - obs_spec[key].dtype is value.dtype - ), f"{obs_spec[key].dtype} vs
{value.dtype} for {key}" - elif input_spec is not None and key in input_spec.keys(): - assert ( - input_spec[key].dtype is value.dtype - ), f"{input_spec[key].dtype} vs {value.dtype} for {key}" - else: - assert key in {"done", "reward"}, (key, obs_spec, input_spec) + raise KeyError(key) # Decorator to retry upon certain Exceptions. diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index 03caa3e8d39..56463039bf4 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -2,21 +2,10 @@ import tempfile import pytest +from _utils_internal import PONG_VERSIONED from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) -else: - # placeholders - PONG_VERSIONED = "ALE/Pong-v5" - try: from torch.utils.tensorboard import SummaryWriter diff --git a/test/test_collector.py b/test/test_collector.py index a0b122fd563..f7f94035e0e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -8,7 +8,7 @@ import numpy as np import pytest import torch -from _utils_internal import generate_seeds +from _utils_internal import generate_seeds, PENDULUM_VERSIONED, PONG_VERSIONED from mocking_classes import ( ContinuousActionVecMockEnv, DiscreteActionConvMockEnv, @@ -42,22 +42,6 @@ TensorDictModule, ) -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) -else: - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - PONG_VERSIONED = "ALE/Pong-v5" - # torch.set_default_dtype(torch.double) @@ -314,7 +298,11 @@ def env_fn(seed): @pytest.mark.skipif(not _has_gym, reason="gym library is not installed") def test_collector_env_reset(): torch.manual_seed(0) - env = SerialEnv(2, lambda: GymEnv(PONG_VERSIONED, frame_skip=4)) + + def make_env(): + return GymEnv(PONG_VERSIONED, frame_skip=4) + + env = SerialEnv(2, make_env) # env = SerialEnv(3, lambda: GymEnv("CartPole-v1", frame_skip=4)) env.set_seed(0) collector = SyncDataCollector( diff --git a/test/test_env.py b/test/test_env.py index aefeb4f36de..fa1607041ae 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -11,7 +11,13 @@ import pytest import torch import yaml -from _utils_internal import get_available_devices +from _utils_internal import ( + get_available_devices, + CARTPOLE_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + HALFCHEETAH_VERSIONED, +) from mocking_classes import ( ActionObsMergeLinear, DiscreteActionConvMockEnv, @@ -49,30 +55,11 @@ ) from torchrl.modules.tensordict_module import WorldModelWrapper +gym_version = None if _has_gym: import gym gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - CARTPOLE_VERSIONED = ( - "CartPole-v1" if gym_version > version.parse("0.20.0") else "CartPole-v0" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) - HALFCHEETAH_VERSIONED = ( - "HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2" - ) -else: - # placeholder - gym_version = version.parse("0.0.1") - - # placeholders - PENDULUM_VERSIONED = 
"Pendulum-v1" - CARTPOLE_VERSIONED = "CartPole-v1" - PONG_VERSIONED = "ALE/Pong-v5" try: this_dir = os.path.dirname(os.path.realpath(__file__)) @@ -1048,7 +1035,7 @@ def test_batch_unlocked_with_batch_size(device): @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.skipif( - gym_version < version.parse("0.20.0"), + gym_version is None or gym_version < version.parse("0.20.0"), reason="older versions of half-cheetah do not have 'x_position' info key.", ) def test_info_dict_reader(seed=0): diff --git a/test/test_libs.py b/test/test_libs.py index 8667b158934..bb853f9642d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3,18 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +from sys import platform import numpy as np import pytest import torch -from _utils_internal import _test_fake_tensordict -from _utils_internal import get_available_devices +from _utils_internal import ( + _test_fake_tensordict, + get_available_devices, + HALFCHEETAH_VERSIONED, + PONG_VERSIONED, + PENDULUM_VERSIONED, +) from packaging import version +from tensordict.tensordict import assert_allclose_td +from torchrl._utils import implement_for from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.collectors import RandomPolicy +from torchrl.envs import EnvCreator, ParallelEnv +from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper from torchrl.envs.libs.dm_control import _has_dmc +from torchrl.envs.libs.gym import GymEnv, GymWrapper from torchrl.envs.libs.gym import _has_gym, _is_from_pixels from torchrl.envs.libs.habitat import HabitatEnv, _has_habitat +from torchrl.envs.libs.jumanji import JumanjiEnv, _has_jumanji if _has_gym: import gym @@ -31,39 +43,8 @@ from dm_control import suite from dm_control.suite.wrappers import pixels -from sys import platform - -from tensordict.tensordict import assert_allclose_td -from torchrl.envs import EnvCreator, ParallelEnv -from torchrl.envs.libs.dm_control import DMControlEnv, DMControlWrapper -from torchrl.envs.libs.gym import GymEnv, GymWrapper - IS_OSX = platform == "darwin" -if _has_gym: - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) - HC_VERSIONED = ( - "HalfCheetah-v4" if gym_version > version.parse("0.20.0") else "HalfCheetah-v2" - ) - PONG_VERSIONED = ( - "ALE/Pong-v5" if gym_version > version.parse("0.20.0") else "Pong-v4" - ) - - # if gym_version < version.parse("0.24.0") and torch.cuda.device_count() > 0: - # from opengl_rendering import create_opengl_context - # - # create_opengl_context() -else: - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - HC_VERSIONED = "HalfCheetah-v4" - PONG_VERSIONED = "ALE/Pong-v5" - @pytest.mark.skipif(not _has_gym, reason="no gym library found") @pytest.mark.parametrize( @@ -122,10 +103,7 @@ def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): base_env = gym.make(env_name, frameskip=frame_skip) frame_skip = 1 else: - if gym_version < version.parse("0.26.0"): - base_env = gym.make(env_name) - else: - base_env = gym.make(env_name, render_mode="rgb_array") + base_env = _make_gym_environment(env_name) if from_pixels and not _is_from_pixels(base_env): base_env = PixelObservationWrapper(base_env, pixels_only=pixels_only) @@ -163,6 +141,16 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): 
_test_fake_tensordict(env) +@implement_for("gym", None, "0.26") +def _make_gym_environment(env_name): # noqa: F811 + return gym.make(env_name) + + +@implement_for("gym", "0.26", None) +def _make_gym_environment(env_name): # noqa: F811 + return gym.make(env_name, render_mode="rgb_array") + + @pytest.mark.skipif(not _has_dmc, reason="no dm_control library found") @pytest.mark.parametrize("env_name,task", [["cheetah", "run"]]) @pytest.mark.parametrize("frame_skip", [1, 3]) @@ -269,9 +257,9 @@ def test_faketd(self, env_name, task, frame_skip, from_pixels, pixels_only): "env_lib,env_args,env_kwargs", [ [DMControlEnv, ("cheetah", "run"), {"from_pixels": True}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": True}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}], [DMControlEnv, ("cheetah", "run"), {"from_pixels": False}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": False}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}], [GymEnv, (PONG_VERSIONED,), {}], ], ) @@ -300,15 +288,15 @@ def test_td_creation_from_spec(env_lib, env_args, env_kwargs): assert fake_td.get(key).device == td0.get(key).device -@pytest.mark.skipif(IS_OSX, reason="rendeing unstable on osx, skipping") +@pytest.mark.skipif(IS_OSX, reason="rendering unstable on osx, skipping") @pytest.mark.skipif(not (_has_dmc and _has_gym), reason="gym or dm_control not present") @pytest.mark.parametrize( "env_lib,env_args,env_kwargs", [ [DMControlEnv, ("cheetah", "run"), {"from_pixels": True}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": True}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": True}], [DMControlEnv, ("cheetah", "run"), {"from_pixels": False}], - [GymEnv, (HC_VERSIONED,), {"from_pixels": False}], + [GymEnv, (HALFCHEETAH_VERSIONED,), {"from_pixels": False}], [GymEnv, (PONG_VERSIONED,), {}], ], ) @@ -354,6 +342,91 @@ def test_habitat(self, envname): _test_fake_tensordict(env) +@pytest.mark.skipif(not _has_jumanji, reason="jumanji not installed") +@pytest.mark.parametrize("envname", ["Snake-6x6-v0", "TSP50-v0"]) +class TestJumanji: + def test_jumanji_seeding(self, envname): + final_seed = [] + tdreset = [] + tdrollout = [] + for _ in range(2): + env = JumanjiEnv(envname) + torch.manual_seed(0) + np.random.seed(0) + final_seed.append(env.set_seed(0)) + tdreset.append(env.reset()) + tdrollout.append(env.rollout(max_steps=50)) + env.close() + del env + assert final_seed[0] == final_seed[1] + assert_allclose_td(*tdreset) + assert_allclose_td(*tdrollout) + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_jumanji_batch_size(self, envname, batch_size): + env = JumanjiEnv(envname, batch_size=batch_size) + env.set_seed(0) + tdreset = env.reset() + tdrollout = env.rollout(max_steps=50) + env.close() + del env + assert tdreset.batch_size == batch_size + assert tdrollout.batch_size[:-1] == batch_size + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_jumanji_spec_rollout(self, envname, batch_size): + env = JumanjiEnv(envname, batch_size=batch_size) + env.set_seed(0) + _test_fake_tensordict(env) + + @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + def test_jumanji_consistency(self, envname, batch_size): + import jax + import jax.numpy as jnp + import numpy as onp + + env = JumanjiEnv(envname, batch_size=batch_size) + obs_keys = list(env.observation_spec.keys(True)) + env.set_seed(1) + rollout = env.rollout(10) + + env.set_seed(1) + key = env.key + base_env = env._env + key, *keys = jax.random.split(key, np.prod(batch_size) + 1) + state, timestep = 
jax.vmap(base_env.reset)(jnp.stack(keys)) + # state = env._reshape(state) + # timesteps.append(timestep) + for i in range(rollout.shape[-1]): + action = rollout[..., i]["action"] + # state = env._flatten(state) + action = env._flatten(env.read_action(action)) + state, timestep = jax.vmap(base_env.step)(state, action) + # state = env._reshape(state) + # timesteps.append(timestep) + checked = False + for _key in obs_keys: + if isinstance(_key, str): + _key = (_key,) + try: + t2 = getattr(timestep, _key[0]) + except AttributeError: + try: + t2 = getattr(timestep.observation, _key[0]) + except AttributeError: + continue + t1 = rollout[..., i][("next", *_key)] + for __key in _key[1:]: + t2 = getattr(t2, __key) + t2 = torch.tensor(onp.asarray(t2)).view_as(t1) + torch.testing.assert_close(t1, t2) + checked = True + if not checked: + raise AttributeError( + f"None of the keys matched: {rollout}, {list(timestep.__dict__.keys())}" + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_modules.py b/test/test_modules.py index 57cbd4d99a4..4590c393f50 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -40,6 +40,8 @@ RSSMRollout, ) from torchrl.modules.models.utils import SquashDims +from torchrl.modules.planners.mppi import MPPIPlanner +from torchrl.objectives.value import TDLambdaEstimate @pytest.fixture @@ -124,7 +126,9 @@ def test_mlp( ) @pytest.mark.parametrize("squeeze_output", [False]) @pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("batch", [(2,), (2, 2)]) def test_convnet( + batch, in_features, depth, num_cells, @@ -145,7 +149,6 @@ def test_convnet( seed=0, ): torch.manual_seed(seed) - batch = 2 convnet = ConvNet( in_features=in_features, depth=depth, @@ -165,9 +168,9 @@ def test_convnet( ) if in_features is None: in_features = 5 - x = torch.randn(batch, in_features, input_size, input_size, device=device) + x = torch.randn(*batch, in_features, input_size, input_size, device=device) y = convnet(x) - assert y.shape == torch.Size([batch, expected_features]) + assert y.shape == torch.Size([*batch, expected_features]) @pytest.mark.parametrize( @@ -468,9 +471,9 @@ def test_func_transformer(self): torch.testing.assert_close(fmodule(params, buffers, x, x), module(x, x)) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("batch_size", [3, 5]) class TestPlanner: - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("batch_size", [3, 5]) def test_CEM_model_free_env(self, device, batch_size, seed=1): env = MockBatchedUnLockedEnv(device=device) torch.manual_seed(seed) @@ -479,13 +482,47 @@ def test_CEM_model_free_env(self, device, batch_size, seed=1): planning_horizon=10, optim_steps=2, num_candidates=100, - num_top_k_candidates=2, + top_k=2, ) td = env.reset(TensorDict({}, batch_size=batch_size).to(device)) td_copy = td.clone() td = planner(td) - assert td.get("action").shape[1:] == env.action_spec.shape + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) + assert env.action_spec.is_in(td.get("action")) + + for key in td.keys(): + if key != "action": + assert torch.allclose(td[key], td_copy[key]) + def test_MPPI(self, device, batch_size, seed=1): + torch.manual_seed(seed) + env = MockBatchedUnLockedEnv(device=device) + value_net = nn.LazyLinear(1, device=device) + value_net = ValueOperator(value_net,
in_keys=["observation"]) + advantage_module = TDLambdaEstimate( + 0.99, + 0.95, + value_net, + ) + planner = MPPIPlanner( + env, + advantage_module, + temperature=1.0, + planning_horizon=10, + optim_steps=2, + num_candidates=100, + top_k=2, + ) + td = env.reset(TensorDict({}, batch_size=batch_size).to(device)) + td_copy = td.clone() + td = planner(td) + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) assert env.action_spec.is_in(td.get("action")) for key in td.keys(): diff --git a/test/test_rb.py b/test/test_rb.py index 36915407450..0ac6228fcc2 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -4,6 +4,9 @@ # LICENSE file in the root directory of this source tree. import argparse +import importlib +from functools import partial +from unittest import mock import numpy as np import pytest @@ -28,14 +31,29 @@ ListStorage, ) from torchrl.data.replay_buffers.writers import RoundRobinWriter +from torchrl.envs.transforms.transforms import ( + CatTensors, + FlattenObservation, + SqueezeTransform, + ToTensorImage, + RewardClipping, + BinarizeReward, + Resize, + CenterCrop, + UnsqueezeTransform, + GrayScale, + ObservationNorm, + CatFrames, + RewardScaling, + DoubleToFloat, + VecNorm, + DiscreteActionProjection, + FiniteTensorDictCheck, + gSDENoise, + PinMemoryTransform, +) - -collate_fn_dict = { - ListStorage: lambda x: torch.stack(x, 0), - LazyTensorStorage: lambda x: x, - LazyMemmapStorage: lambda x: x, - None: lambda x: torch.stack(x, 0), -} +_has_tv = importlib.util.find_spec("torchvision") is not None @pytest.mark.parametrize( @@ -54,7 +72,6 @@ @pytest.mark.parametrize("size", [3, 100]) class TestPrototypeBuffers: def _get_rb(self, rb_type, size, sampler, writer, storage): - collate_fn = collate_fn_dict[storage] if storage is not None: storage = storage(size) @@ -65,9 +82,7 @@ def _get_rb(self, rb_type, size, sampler, writer, storage): sampler = sampler(**sampler_args) writer = writer() - rb = rb_type( - collate_fn=collate_fn, storage=storage, sampler=sampler, writer=writer - ) + rb = rb_type(storage=storage, sampler=sampler, writer=writer) return rb def _get_datum(self, rb_type): @@ -152,7 +167,6 @@ def test_sample(self, rb_type, sampler, writer, storage, size): for d in new_data: found_similar = False for b in data: - print(b, d) if isinstance(b, TensorDictBase): keys = set(d.keys()).intersection(b.keys()) b = b.exclude("index").select(*keys, strict=False) @@ -192,7 +206,6 @@ def test_prototype_prb(priority_key, contiguous, device): np.random.seed(0) rb = rb_prototype.TensorDictReplayBuffer( sampler=samplers.PrioritizedSampler(5, alpha=0.7, beta=0.9), - collate_fn=None if contiguous else lambda x: torch.stack(x, 0), priority_key=priority_key, ) td1 = TensorDict( @@ -271,7 +284,6 @@ def test_rb_prototype_trajectories(stack): alpha=0.7, beta=0.9, ), - collate_fn=lambda x: torch.stack(x, 0), priority_key="td_error", ) rb.extend(traj_td) @@ -315,7 +327,6 @@ class TestBuffers: _default_params_td_prb = {"alpha": 0.8, "beta": 0.9} def _get_rb(self, rbtype, size, storage, prefetch): - collate_fn = collate_fn_dict[storage] if storage is not None: storage = storage(size) if rbtype is ReplayBuffer: @@ -328,13 +339,7 @@ def _get_rb(self, rbtype, size, storage, prefetch): params = self._default_params_td_prb else: raise NotImplementedError(rbtype) - rb = rbtype( - size=size, - storage=storage, - prefetch=prefetch, - collate_fn=collate_fn, - **params - ) + rb = rbtype(size=size, storage=storage, prefetch=prefetch, **params) return rb def _get_datum(self, 
rbtype): @@ -460,7 +465,6 @@ def test_prb(priority_key, contiguous, device): 5, alpha=0.7, beta=0.9, - collate_fn=None if contiguous else lambda x: torch.stack(x, 0), priority_key=priority_key, ) td1 = TensorDict( @@ -537,7 +541,6 @@ def test_rb_trajectories(stack): 5, alpha=0.7, beta=0.9, - collate_fn=lambda x: torch.stack(x, 0), priority_key="td_error", ) rb.extend(traj_td) @@ -565,10 +568,14 @@ def test_shared_storage_prioritized_sampler(): sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1) rb0 = rb_prototype.ReplayBuffer( - storage=storage, writer=writer, sampler=sampler0, collate_fn=lambda x: x + storage=storage, + writer=writer, + sampler=sampler0, ) rb1 = rb_prototype.ReplayBuffer( - storage=storage, writer=writer, sampler=sampler1, collate_fn=lambda x: x + storage=storage, + writer=writer, + sampler=sampler1, ) data = TensorDict({"a": torch.arange(50)}, [50]) @@ -593,9 +600,11 @@ def test_legacy_rb_does_not_attach(): storage = LazyMemmapStorage(n) writer = RoundRobinWriter() sampler = RandomSampler() - rb = ReplayBuffer(storage=storage, size=n, prefetch=0, collate_fn=lambda x: x) + rb = ReplayBuffer(storage=storage, size=n, prefetch=0) prb = rb_prototype.ReplayBuffer( - storage=storage, writer=writer, sampler=sampler, collate_fn=lambda x: x + storage=storage, + writer=writer, + sampler=sampler, ) assert len(storage._attached_entities) == 1 @@ -603,6 +612,127 @@ def test_legacy_rb_does_not_attach(): assert rb not in storage._attached_entities +def test_append_transform(): + rb = rb_prototype.ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0)) + td = TensorDict( + { + "observation": torch.randn(2, 4, 3, 16), + "observation2": torch.randn(2, 4, 3, 16), + }, + [], + ) + rb.add(td) + flatten = CatTensors( + in_keys=["observation", "observation2"], out_key="observation_cat" + ) + + rb.append_transform(flatten) + + sampled, _ = rb.sample(1) + assert sampled.get("observation_cat").shape[-1] == 32 + + +def test_init_transform(): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) + + rb = rb_prototype.ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten + ) + + td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) + rb.add(td) + sampled, _ = rb.sample(1) + assert sampled.get("flattened").shape[-1] == 48 + + +def test_insert_transform(): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) + rb = rb_prototype.ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten + ) + td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) + rb.add(td) + + rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) + + sampled, _ = rb.sample(1) + assert sampled.get("flattened").shape[-1] == 48 + + with pytest.raises(ValueError): + rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"])) + + +transforms = [ + ToTensorImage, + pytest.param( + partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" + ), + BinarizeReward, + pytest.param( + partial(Resize, w=2, h=2), + id="Resize", + marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), + ), + pytest.param( + partial(CenterCrop, w=1), + id="CenterCrop", + marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), + ), + pytest.param( + partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" + ), + pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + GrayScale, + 
pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), + CatFrames, + pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), + DoubleToFloat, + VecNorm, +] + + +@pytest.mark.parametrize("transform", transforms) +def test_smoke_replay_buffer_transform(transform): + rb = rb_prototype.ReplayBuffer( + transform=transform(in_keys="observation"), + ) + + td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) + rb.add(td) + rb.sample(1) + + rb._transform = mock.MagicMock() + rb.sample(1) + assert rb._transform.called + + +transforms = [ + partial(DiscreteActionProjection, max_n=1, m=1), + FiniteTensorDictCheck, + gSDENoise, + PinMemoryTransform, +] + + +@pytest.mark.parametrize("transform", transforms) +def test_smoke_replay_buffer_transform_no_inkeys(transform): + rb = rb_prototype.ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=transform() + ) + + td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1)}, []) + rb.add(td) + rb.sample(1) + + rb._transform = mock.MagicMock() + rb.sample(1) + assert rb._transform.called + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_trainer.py b/test/test_trainer.py index 7ed0766a84f..1fd576dd58d 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -249,20 +249,22 @@ def test_rb_trainer_state_dict(self, prioritized, storage_type): S = 100 if storage_type == "list": storage = ListStorage(S) - collate_fn = lambda x: torch.stack(x, 0) elif storage_type == "memmap": storage = LazyMemmapStorage(S) - collate_fn = lambda x: x else: raise NotImplementedError if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, collate_fn=collate_fn + S, + 1.1, + 0.9, + storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, collate_fn=collate_fn + S, + storage=storage, ) N = 9 @@ -371,16 +373,13 @@ def test_rb_trainer_save( def make_storage(): if storage_type == "list": storage = ListStorage(S) - collate_fn = lambda x: torch.stack(x, 0) elif storage_type == "tensor": storage = LazyTensorStorage(S) - collate_fn = lambda x: x elif storage_type == "memmap": storage = LazyMemmapStorage(S) - collate_fn = lambda x: x else: raise NotImplementedError - return storage, collate_fn + return storage with tempfile.TemporaryDirectory() as tmpdirname: if backend == "torch": @@ -391,14 +390,18 @@ def make_storage(): raise NotImplementedError trainer = mocking_trainer(file) - storage, collate_fn = make_storage() + storage = make_storage() if prioritized: replay_buffer = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage, collate_fn=collate_fn + S, + 1.1, + 0.9, + storage=storage, ) else: replay_buffer = TensorDictReplayBuffer( - S, storage=storage, collate_fn=collate_fn + S, + storage=storage, ) rb_trainer = ReplayBufferTrainer(replay_buffer=replay_buffer, batch_size=N) @@ -413,6 +416,7 @@ def make_storage(): [batch], ) trainer._process_batch_hook(td) + # sample from rb td_out = trainer._process_optim_batch_hook(td) if prioritized: td_out.set(replay_buffer.priority_key, torch.rand(N)) @@ -420,14 +424,18 @@ def make_storage(): trainer.save_trainer(True) trainer2 = mocking_trainer() - storage2, _ = make_storage() + storage2 = make_storage() if prioritized: replay_buffer2 = TensorDictPrioritizedReplayBuffer( - S, 1.1, 0.9, storage=storage2, collate_fn=collate_fn + S, + 1.1, + 0.9, + storage=storage2, 
) else: replay_buffer2 = TensorDictReplayBuffer( - S, storage=storage2, collate_fn=collate_fn + S, + storage=storage2, ) N = 9 rb_trainer2 = ReplayBufferTrainer( diff --git a/test/test_transforms.py b/test/test_transforms.py index 050d1831fed..d5ec4da981c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. import argparse from copy import copy, deepcopy +from functools import partial import numpy as np import pytest import torch -from _utils_internal import get_available_devices, retry, dtype_fixture # noqa +from _utils_internal import ( # noqa + get_available_devices, + retry, + dtype_fixture, + PENDULUM_VERSIONED, +) from mocking_classes import ( ContinuousActionVecMockEnv, DiscreteActionConvMockEnvNumpy, @@ -49,6 +55,7 @@ from torchrl.envs.transforms import TransformedEnv, VecNorm from torchrl.envs.transforms.r3m import _R3MNet from torchrl.envs.transforms.transforms import ( + DiscreteActionProjection, _has_tv, CenterCrop, NoopResetEnv, @@ -56,21 +63,10 @@ SqueezeTransform, TensorDictPrimer, UnsqueezeTransform, + gSDENoise, ) from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -if _has_gym: - import gym - from packaging import version - - gym_version = version.parse(gym.__version__) - PENDULUM_VERSIONED = ( - "Pendulum-v1" if gym_version > version.parse("0.20.0") else "Pendulum-v0" - ) -else: - # placeholders - PENDULUM_VERSIONED = "Pendulum-v1" - TIMEOUT = 10.0 @@ -1289,13 +1285,16 @@ def test_binarized_reward(self, device, batch): @pytest.mark.parametrize("loc", [1, 5]) @pytest.mark.parametrize("keys", [None, ["reward_1"]]) @pytest.mark.parametrize("device", get_available_devices()) - def test_reward_scaling(self, batch, scale, loc, keys, device): + @pytest.mark.parametrize("standard_normal", [True, False]) + def test_reward_scaling(self, batch, scale, loc, keys, device, standard_normal): torch.manual_seed(0) if keys is None: keys_total = set([]) else: keys_total = set(keys) - reward_scaling = RewardScaling(in_keys=keys, scale=scale, loc=loc) + reward_scaling = RewardScaling( + in_keys=keys, scale=scale, loc=loc, standard_normal=standard_normal + ) td = TensorDict( { **{key: torch.randn(*batch, 1, device=device) for key in keys_total}, @@ -1308,13 +1307,17 @@ def test_reward_scaling(self, batch, scale, loc, keys, device): td_copy = td.clone() reward_scaling(td) for key in keys_total: - assert (td.get(key) == td_copy.get(key).mul_(scale).add_(loc)).all() + if standard_normal: + original_key = td.get(key) + scaled_key = (td_copy.get(key) - loc) / scale + torch.testing.assert_close(original_key, scaled_key) + else: + original_key = td.get(key) + scaled_key = td_copy.get(key) * scale + loc + torch.testing.assert_close(original_key, scaled_key) assert (td.get("dont touch") == td_copy.get("dont touch")).all() - if len(keys_total) == 0: - assert ( - td.get("reward") == td_copy.get("reward").mul_(scale).add_(loc) - ).all() - elif len(keys_total) == 1: + + if len(keys_total) == 1: reward_spec = UnboundedContinuousTensorSpec(device=device) reward_spec = reward_scaling.transform_reward_spec(reward_spec) assert reward_spec.shape == torch.Size([1]) @@ -1975,6 +1978,50 @@ def test_batch_unlocked_with_batch_size_transformed(device): env.step(td_expanded) +transforms = [ + ToTensorImage, + pytest.param( + partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" + ), + BinarizeReward, + pytest.param( + partial(Resize, w=2, h=2), + id="Resize", + marks=pytest.mark.skipif(not 
_has_tv, reason="needs torchvision dependency"), + ), + pytest.param( + partial(CenterCrop, w=1), + id="CenterCrop", + marks=pytest.mark.skipif(not _has_tv, reason="needs torchvision dependency"), + ), + pytest.param(partial(FlattenObservation, first_dim=-3), id="FlattenObservation"), + pytest.param( + partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" + ), + pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + GrayScale, + ObservationNorm, + CatFrames, + pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), + FiniteTensorDictCheck, + DoubleToFloat, + CatTensors, + pytest.param( + partial(DiscreteActionProjection, max_n=1, m=1), id="DiscreteActionProjection" + ), + NoopResetEnv, + TensorDictPrimer, + PinMemoryTransform, + gSDENoise, + VecNorm, +] + + +@pytest.mark.parametrize("transform", transforms) +def test_smoke_compose_transform(transform): + Compose(transform()) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/replay_buffers/rb_prototype.py b/torchrl/data/replay_buffers/rb_prototype.py index 65fddaa1257..a80e1abdf28 100644 --- a/torchrl/data/replay_buffers/rb_prototype.py +++ b/torchrl/data/replay_buffers/rb_prototype.py @@ -6,9 +6,10 @@ import torch from tensordict.tensordict import TensorDictBase, LazyStackedTensorDict -from .replay_buffers import pin_memory_output, stack_tensors, stack_td +from torchrl.envs.transforms.transforms import Compose, Transform +from .replay_buffers import pin_memory_output from .samplers import Sampler, RandomSampler -from .storages import Storage, ListStorage +from .storages import Storage, ListStorage, _get_default_collate from .utils import INT_CLASSES, _to_numpy, accept_remote_rref_udf_invocation from .writers import Writer, RoundRobinWriter @@ -30,6 +31,8 @@ class ReplayBuffer: samples. prefetch (int, optional): number of next batches to be prefetched using multithreading. + transform (Transform, optional): Transform to be executed when sample() is called. + To chain transforms use the :obj:`Compose` class. 
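+ +        Examples: +            A minimal sketch of sampling through a transform; the +            FlattenObservation arguments here are illustrative, borrowed from +            the tests added in this patch: + +            >>> rb = ReplayBuffer( +            ...     transform=FlattenObservation(-2, -1, in_keys=["observation"]), +            ... ) +            >>> rb.add(TensorDict({"observation": torch.randn(2, 4, 3, 16)}, [])) +            >>> sampled, info = rb.sample(1)  # the transform runs on the sampled data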
""" def __init__( @@ -40,6 +43,7 @@ def __init__( collate_fn: Optional[Callable] = None, pin_memory: bool = False, prefetch: Optional[int] = None, + transform: Optional[Transform] = None, ) -> None: self._storage = storage if storage is not None else ListStorage(max_size=1_000) self._storage.attach(self) @@ -47,7 +51,11 @@ def __init__( self._writer = writer if writer is not None else RoundRobinWriter() self._writer.register_storage(self._storage) - self._collate_fn = collate_fn or stack_tensors + self._collate_fn = ( + collate_fn + if collate_fn is not None + else _get_default_collate(self._storage) + ) self._pin_memory = pin_memory self._prefetch = bool(prefetch) @@ -58,6 +66,12 @@ def __init__( self._replay_lock = threading.RLock() self._futures_lock = threading.RLock() + if transform is None: + transform = Compose() + elif not isinstance(transform, Compose): + transform = Compose(transform) + transform.eval() + self._transform = transform def __len__(self) -> int: with self._replay_lock: @@ -127,6 +141,7 @@ def _sample(self, batch_size: int) -> Tuple[Any, dict]: data = self._storage[index] if not isinstance(index, INT_CLASSES): data = self._collate_fn(data) + data = self._transform(data) return data, info def sample(self, batch_size: int) -> Tuple[Any, dict]: @@ -159,6 +174,29 @@ def sample(self, batch_size: int) -> Tuple[Any, dict]: def mark_update(self, index: Union[int, torch.Tensor]) -> None: self._sampler.mark_update(index) + def append_transform(self, transform: Transform) -> None: + """Appends transform at the end. + + Transforms are applied in order when `sample` is called. + + Args: + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.append(transform) + + def insert_transform(self, index: int, transform: Transform) -> None: + """Inserts transform. + + Transforms are executed in order when `sample` is called. + + Args: + index (int): Position to insert the transform. + transform (Transform): The transform to be appended + """ + transform.eval() + self._transform.insert(index, transform) + class TensorDictReplayBuffer(ReplayBuffer): """TensorDict-specific wrapper around the ReplayBuffer class. 
@@ -169,12 +207,6 @@ class TensorDictReplayBuffer(ReplayBuffer): """ def __init__(self, priority_key: str = "td_error", **kw) -> None: - if not kw.get("collate_fn"): - - def collate_fn(x): - return stack_td(x, 0, contiguous=True) - - kw["collate_fn"] = collate_fn super().__init__(**kw) self.priority_key = priority_key diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index b26c701e24d..8abd435738d 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -13,7 +13,6 @@ import torch from tensordict.tensordict import ( TensorDictBase, - _stack as stack_td, LazyStackedTensorDict, ) from torch import Tensor @@ -24,7 +23,11 @@ SumSegmentTreeFp32, SumSegmentTreeFp64, ) -from torchrl.data.replay_buffers.storages import Storage, ListStorage +from torchrl.data.replay_buffers.storages import ( + Storage, + ListStorage, + _get_default_collate, +) from torchrl.data.replay_buffers.utils import INT_CLASSES from torchrl.data.replay_buffers.utils import ( _to_numpy, @@ -118,9 +121,11 @@ def __init__( self._storage = storage self._capacity = size self._cursor = 0 - if collate_fn is None: - collate_fn = stack_tensors - self._collate_fn = collate_fn + self._collate_fn = ( + collate_fn + if collate_fn is not None + else _get_default_collate(self._storage) + ) self._pin_memory = pin_memory self._prefetch = prefetch is not None and prefetch > 0 @@ -558,11 +563,6 @@ def __init__( prefetch: Optional[int] = None, storage: Optional[Storage] = None, ): - if collate_fn is None: - - def collate_fn(x): - return stack_td(x, 0, contiguous=True) - super().__init__(size, collate_fn, pin_memory, prefetch, storage=storage) @@ -606,11 +606,6 @@ def __init__( prefetch: Optional[int] = None, storage: Optional[Storage] = None, ) -> None: - if collate_fn is None: - - def collate_fn(x): - return stack_td(x, 0, contiguous=True) - super(TensorDictPrioritizedReplayBuffer, self).__init__( size=size, alpha=alpha, diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 28e600c2a72..b37cd2fb589 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -368,28 +368,35 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: f"The storage was created in {out.filename} and occupies {filesize} Mb of storage." 
) else: - out = TensorDict({}, [self.max_size, *data.shape]) + # out = TensorDict({}, [self.max_size, *data.shape]) print("The storage is being created: ") - for key, tensor in sorted(data.items()): - if isinstance(tensor, TensorDictBase): - out[key] = ( - tensor.expand(self.max_size) - .clone() - .zero_() - .memmap_(prefix=self.scratch_dir) - .to(self.device) - ) - else: - out[key] = _value = MemmapTensor( - self.max_size, - *tensor.shape, - device=self.device, - dtype=tensor.dtype, - prefix=self.scratch_dir, - ) - filesize = os.path.getsize(_value.filename) / 1024 / 1024 + out = ( + data.expand(self.max_size, *data.shape) + .to_tensordict() + .zero_() + .memmap_(prefix=self.scratch_dir) + .to(self.device) + ) + for key, tensor in sorted(out.flatten_keys(".").items()): + # if isinstance(tensor, TensorDictBase): + # out[key] = ( + # tensor.expand(self.max_size, *tensor.shape) + # .clone() + # .zero_() + # .memmap_(prefix=self.scratch_dir) + # .to(self.device) + # ) + # else: + # out[key] = _value = MemmapTensor( + # self.max_size, + # *tensor.shape, + # device=self.device, + # dtype=tensor.dtype, + # prefix=self.scratch_dir, + # ) + filesize = os.path.getsize(tensor.filename) / 1024 / 1024 print( - f"\t{key}: {_value.filename}, {filesize} Mb of storage (size: {[self.max_size, *tensor.shape]})." + f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {[self.max_size, *tensor.shape]})." ) self._storage = out self.initialized = True @@ -414,3 +421,30 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: ) elif _CKPT_BACKEND == "torch": return mem_map_tensor._tensor + + +def _collate_list_tensordict(x): + out = torch.stack(x, 0) + if isinstance(out, TensorDictBase): + return out.to_tensordict() + return out + + +def _collate_list_tensors(*x): + return tuple(torch.stack(_x, 0) for _x in zip(*x)) + + +def _collate_contiguous(x): + if isinstance(x, TensorDictBase): + return x.to_tensordict() + return x.clone() + + +def _get_default_collate(storage, _is_tensordict=True): + if isinstance(storage, ListStorage): + if _is_tensordict: + return _collate_list_tensordict + else: + return _collate_list_tensors + elif isinstance(storage, (LazyTensorStorage, LazyMemmapStorage)): + return _collate_contiguous diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index e1c455fb569..dd8df096854 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -419,8 +419,8 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) a, b = self.space - shape = [*shape, *self.shape] if self.dtype in (torch.float, torch.double, torch.half): + shape = [*shape, *self.shape] out = ( torch.zeros(shape, dtype=self.dtype, device=self.device).uniform_() * (b - a) @@ -433,7 +433,7 @@ def rand(self, shape=None) -> torch.Tensor: return out else: interval = self.space.maximum - self.space.minimum - r = torch.rand(*interval.shape, device=interval.device) + r = torch.rand(*shape, *interval.shape, device=interval.device) r = interval * r r = self.space.minimum + r r = r.to(self.dtype).to(self.device) @@ -647,7 +647,7 @@ def rand(self, shape=None) -> torch.Tensor: if shape is None: shape = torch.Size([]) interval = self.space.maximum - self.space.minimum - r = torch.rand(interval.shape, device=interval.device) + r = torch.rand(*shape, *interval.shape, device=interval.device) r = r * interval r = self.space.minimum + r r = r.to(self.dtype) @@ -798,9 +798,15 @@ def __init__( shape = torch.Size([shape]) dtype, device = 
_default_dtype_and_device(dtype, device) + if dtype == torch.bool: + min_value = False + max_value = True + else: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max space = ContinuousBox( - torch.full(shape, torch.iinfo(dtype).min, device=device), - torch.full(shape, torch.iinfo(dtype).max, device=device), + torch.full(shape, min_value, device=device), + torch.full(shape, max_value, device=device), ) super(UnboundedDiscreteTensorSpec, self).__init__( @@ -1101,6 +1107,15 @@ class CompositeSpec(TensorSpec): b: None, c: None)) + CompositeSpec supports nested indexing: + >>> spec = CompositeSpec(obs=None) + >>> spec["nested", "x"] = None + >>> print(spec) + CompositeSpec( + nested: CompositeSpec( + x: None), + obs: None) + """ domain: str = "composite" @@ -1212,7 +1227,12 @@ def encode(self, vals: Dict[str, Any]) -> Dict[str, torch.Tensor]: raise RuntimeError( "CompositeSpec.encode cannot be used with missing values." ) - out[key] = self[key].encode(item) + try: + out[key] = self[key].encode(item) + except KeyError: + raise KeyError( + f"The CompositeSpec instance with keys {self.keys()} does not have a '{key}' key." + ) return out def __repr__(self) -> str: @@ -1258,12 +1278,13 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def rand(self, shape=None) -> TensorDictBase: if shape is None: shape = torch.Size([]) + _dict = { + key: self[key].rand(shape) + for key in self.keys(True) + if isinstance(key, str) and self[key] is not None + } return TensorDict( - { - key: self[key].rand(shape) - for key in self.keys(True) - if isinstance(key, str) and self[key] is not None - }, + _dict, batch_size=shape, ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 58e46579ca9..d5b3517dd76 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -646,11 +646,11 @@ def to(self, device: DEVICE_TYPING) -> EnvBase: def fake_tensordict(self) -> TensorDictBase: """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout.""" input_spec = self.input_spec - fake_input = input_spec.zero(self.batch_size) + fake_input = input_spec.rand(self.batch_size) observation_spec = self.observation_spec - fake_obs = observation_spec.zero(self.batch_size) + fake_obs = observation_spec.rand(self.batch_size) reward_spec = self.reward_spec - fake_reward = reward_spec.zero(self.batch_size) + fake_reward = reward_spec.rand(self.batch_size) fake_td = TensorDict( { **fake_obs, diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 494de15e92a..cf284f4e7db 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -19,13 +19,13 @@ TensorSpec, UnboundedContinuousTensorSpec, ) +from ..._utils import implement_for from ...data.utils import numpy_to_torch_dtype_dict from ..gym_like import GymLikeEnv, default_info_dict_reader from ..utils import _classproperty try: import gym - from packaging import version _has_gym = True except ImportError: @@ -48,9 +48,6 @@ from torchrl.envs.libs.utils import ( GymPixelObservationWrapper as PixelObservationWrapper, ) - gym_version = version.parse(gym.__version__) - if gym_version >= version.parse("0.26.0"): - from gym.wrappers.compatibility import EnvCompatibility __all__ = ["GymWrapper", "GymEnv"] @@ -103,16 +100,22 @@ def _gym_to_torchrl_spec_transform( def _get_envs(to_dict=False) -> List: - if gym_version < version.parse("0.26.0"): - envs = gym.envs.registration.registry.env_specs.keys() - else: - envs =
gym.envs.registration.registry.keys() - + envs = _get_gym_envs() envs = list(envs) envs = sorted(envs) return envs +@implement_for("gym", None, "0.26.0") +def _get_gym_envs(): # noqa: F811 + return gym.envs.registration.registry.env_specs.keys() + + +@implement_for("gym", "0.26.0", None) +def _get_gym_envs(): # noqa: F811 + return gym.envs.registration.registry.keys() + + def _get_gym(): if _has_gym: return gym @@ -186,20 +189,30 @@ def _build_env( "PixelObservationWrapper cannot be used to wrap an environment" "that is already a PixelObservationWrapper instance." ) - if gym_version >= version.parse("0.26.0") and not env.render_mode: - warnings.warn( - "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper " - "should be created with `gym.make(env_name, render_mode=mode)` where possible," - 'where mode is either "rgb_array" or any other supported mode.' - ) - # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper - env = EnvCompatibility(env) - env.reset() - env = LegacyPixelObservationWrapper(env, pixels_only=pixels_only) - else: - env = PixelObservationWrapper(env, pixels_only=pixels_only) + env = self._build_gym_env(env, pixels_only) return env + @implement_for("gym", None, "0.26.0") + def _build_gym_env(self, env, pixels_only): # noqa: F811 + return PixelObservationWrapper(env, pixels_only=pixels_only) + + @implement_for("gym", "0.26.0", None) + def _build_gym_env(self, env, pixels_only): # noqa: F811 + from gym.wrappers.compatibility import EnvCompatibility + + if env.render_mode: + return PixelObservationWrapper(env, pixels_only=pixels_only) + + warnings.warn( + "Environments provided to GymWrapper that need to be wrapped in PixelObservationWrapper " + "should be created with `gym.make(env_name, render_mode=mode)` where possible," + 'where mode is either "rgb_array" or any other supported mode.' + ) + # resetting as 0.26 comes with a very 'nice' OrderEnforcing wrapper + env = EnvCompatibility(env) + env.reset() + return LegacyPixelObservationWrapper(env, pixels_only=pixels_only) + @_classproperty def available_envs(cls) -> List[str]: return _get_envs() @@ -208,29 +221,35 @@ def available_envs(cls) -> List[str]: def lib(self) -> ModuleType: return gym - def _set_seed(self, seed: int) -> int: - skip = False + def _set_seed(self, seed: int) -> int: # noqa: F811 if self._seed_calls_reset is None: - if gym_version < version.parse("0.19.0"): - self._seed_calls_reset = False - self._env.seed(seed=seed) - else: - try: - self.reset(seed=seed) - skip = True - self._seed_calls_reset = True - except TypeError as err: - warnings.warn( - f"reset with seed kwarg returned an exception: {err}.\n" - f"Calling env.seed from now on." - ) - self._seed_calls_reset = False - if self._seed_calls_reset and not skip: + # Determine basing on gym version whether `reset` is called when setting seed. + self._set_seed_initial(seed) + elif self._seed_calls_reset: self.reset(seed=seed) - elif not self._seed_calls_reset: + else: self._env.seed(seed=seed) + return seed + @implement_for("gym", None, "0.19.0") + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + self._seed_calls_reset = False + self._env.seed(seed=seed) + + @implement_for("gym", "0.19.0", None) + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + try: + self.reset(seed=seed) + self._seed_calls_reset = True + except TypeError as err: + warnings.warn( + f"reset with seed kwarg returned an exception: {err}.\n" + f"Calling env.seed from now on." 
+ ) + self._seed_calls_reset = False + self._env.seed(seed=seed) + def _make_specs(self, env: "gym.Env") -> None: self.action_spec = _gym_to_torchrl_spec_transform( env.action_space, @@ -294,15 +313,25 @@ class GymEnv(GymWrapper): def __init__(self, env_name, disable_env_checker=None, **kwargs): kwargs["env_name"] = env_name - if gym_version >= version.parse("0.24.0"): - kwargs["disable_env_checker"] = ( - disable_env_checker if disable_env_checker is not None else True - ) - elif disable_env_checker is not None: + self._set_gym_args(kwargs, disable_env_checker) + super().__init__(**kwargs) + + @implement_for("gym", None, "0.24.0") + def _set_gym_args( # noqa: F811 + self, kwargs, disable_env_checker: bool = None + ) -> None: + if disable_env_checker is not None: raise RuntimeError( "disable_env_checker should only be set if gym version is > 0.24" ) - super().__init__(**kwargs) + + @implement_for("gym", "0.24.0", None) + def _set_gym_args( # noqa: F811 + self, kwargs, disable_env_checker: bool = None + ) -> None: + kwargs["disable_env_checker"] = ( + disable_env_checker if disable_env_checker is not None else True + ) def _build_env( self, @@ -312,12 +341,11 @@ def _build_env( if not _has_gym: raise RuntimeError( f"gym not found, unable to create {env_name}. " - f"Consider downloading and installing dm_control from" + f"Consider downloading and installing gym from" f" {self.git_url}" ) from_pixels = kwargs.get("from_pixels", False) - if from_pixels and gym_version > version.parse("0.25.0"): - kwargs.setdefault("render_mode", "rgb_array") + self._set_gym_default(kwargs, from_pixels) if "from_pixels" in kwargs: del kwargs["from_pixels"] pixels_only = kwargs.get("pixels_only", True) @@ -350,6 +378,16 @@ def _build_env( raise err return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + @implement_for("gym", None, "0.25.1") + def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 + # Do nothing for older gym versions. + pass + + @implement_for("gym", "0.25.1", None) + def _set_gym_default(self, kwargs, from_pixels: bool) -> None: # noqa: F811 + if from_pixels: + kwargs.setdefault("render_mode", "rgb_array") + @property def env_name(self): return self._constructor_kwargs["env_name"] diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py new file mode 100644 index 00000000000..1b6f691cef1 --- /dev/null +++ b/torchrl/envs/libs/jumanji.py @@ -0,0 +1,391 @@ +import dataclasses +from typing import Optional, Dict, Union + +import numpy as np +import torch +from tensordict.tensordict import TensorDict, TensorDictBase, make_tensordict + +from torchrl.data import ( + DEVICE_TYPING, + TensorSpec, + CompositeSpec, + DiscreteTensorSpec, + OneHotDiscreteTensorSpec, + NdBoundedTensorSpec, + NdUnboundedContinuousTensorSpec, + NdUnboundedDiscreteTensorSpec, +) +from torchrl.data.utils import numpy_to_torch_dtype_dict +from torchrl.envs import GymLikeEnv + +try: + import jax + import jumanji + from jax import numpy as jnp + + _has_jumanji = True +except ImportError as err: + _has_jumanji = False + IMPORT_ERR = str(err) + + +def _ndarray_to_tensor(value: Union["jnp.ndarray", np.ndarray], device) -> torch.Tensor: + # tensor doesn't support conversion from jnp.ndarray. + if isinstance(value, jnp.ndarray): + value = np.asarray(value) + # tensor doesn't support unsigned dtypes. 
+ if value.dtype == np.uint16: + value = value.astype(np.int16) + elif value.dtype == np.uint32: + value = value.astype(np.int32) + elif value.dtype == np.uint64: + value = value.astype(np.int64) + # convert to tensor. + return torch.tensor(value).to(device) + + +def _object_to_tensordict(obj: Union, device, batch_size) -> TensorDictBase: + """Converts a namedtuple or a dataclass to a TensorDict.""" + t = {} + if isinstance(obj, tuple) and hasattr(obj, "_fields"): # named tuple + _iter = obj._fields + elif dataclasses.is_dataclass(obj): + _iter = (field.name for field in dataclasses.fields(obj)) + else: + raise NotImplementedError(f"unsupported data type {type(obj)}") + for name in _iter: + value = getattr(obj, name) + if isinstance(value, (jnp.ndarray, np.ndarray)): + t[name] = _ndarray_to_tensor(value, device=device) + else: + t[name] = _object_to_tensordict(value, device, batch_size) + return make_tensordict(**t, device=device, batch_size=batch_size) + + +def _tensordict_to_object(tensordict: TensorDictBase, object_example): + """Converts a TensorDict to a namedtuple or a dataclass.""" + object_type = type(object_example) + t = {} + for name in tensordict.keys(): + value = tensordict[name] + if isinstance(value, TensorDictBase): + t[name] = _tensordict_to_object(value, getattr(object_example, name)) + else: + example = getattr(object_example, name) + t[name] = ( + value.detach().numpy().reshape(example.shape).astype(example.dtype) + ) + return object_type(**t) + + +def _jumanji_to_torchrl_spec_transform( + spec, + dtype: Optional[torch.dtype] = None, + device: DEVICE_TYPING = None, + categorical_action_encoding: bool = True, +) -> TensorSpec: + if isinstance(spec, jumanji.specs.DiscreteArray): + action_space_cls = ( + DiscreteTensorSpec + if categorical_action_encoding + else OneHotDiscreteTensorSpec + ) + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + return action_space_cls(spec.num_values, dtype=dtype, device=device) + elif isinstance(spec, jumanji.specs.BoundedArray): + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + return NdBoundedTensorSpec( + shape=spec.shape, + minimum=np.asarray(spec.minimum), + maximum=np.asarray(spec.maximum), + dtype=dtype, + device=device, + ) + elif isinstance(spec, jumanji.specs.Array): + if dtype is None: + dtype = numpy_to_torch_dtype_dict[spec.dtype] + if dtype in (torch.float, torch.double, torch.half): + return NdUnboundedContinuousTensorSpec( + shape=spec.shape, dtype=dtype, device=device + ) + else: + return NdUnboundedDiscreteTensorSpec( + shape=spec.shape, dtype=dtype, device=device + ) + elif isinstance(spec, jumanji.specs.Spec) and hasattr(spec, "__dict__"): + new_spec = {} + for key, value in spec.__dict__.items(): + if isinstance(value, jumanji.specs.Spec): + if key.endswith("_obs"): + key = key[:-4] + if key.endswith("_spec"): + key = key[:-5] + new_spec[key] = _jumanji_to_torchrl_spec_transform( + value, dtype, device, categorical_action_encoding + ) + return CompositeSpec(**new_spec) + else: + raise TypeError(f"Unsupported spec type {type(spec)}") + + +def _torchrl_data_to_spec_transform(data) -> TensorSpec: + if isinstance(data, torch.Tensor): + if data.dtype in (torch.float, torch.double, torch.half): + return NdUnboundedContinuousTensorSpec( + shape=data.shape, dtype=data.dtype, device=data.device + ) + else: + return NdUnboundedDiscreteTensorSpec( + shape=data.shape, dtype=data.dtype, device=data.device + ) + elif isinstance(data, TensorDict): + return CompositeSpec( + **{ + key: 
_torchrl_data_to_spec_transform(value) + for key, value in data.items() + } + ) + else: + raise TypeError(f"Unsupported data type {type(data)}") + + +class JumanjiWrapper(GymLikeEnv): + """Jumanji environment wrapper. + + Examples: + >>> env = jumanji.make("Snake-6x6-v0") + >>> env = JumanjiWrapper(env) + >>> td0 = env.reset() + >>> print(td0) + >>> td1 = env.rand_step(td0) + >>> print(td1) + >>> print(env.available_envs) + """ + + git_url = "https://github.com/instadeepai/jumanji" + + @property + def lib(self): + return jumanji + + def __init__(self, env: "jumanji.env.Environment" = None, **kwargs): + if env is not None: + kwargs["env"] = env + super().__init__(**kwargs) + + def _build_env( + self, + env, + _seed: Optional[int] = None, + from_pixels: bool = False, + render_kwargs: Optional[dict] = None, + pixels_only: bool = False, + camera_id: Union[int, str] = 0, + **kwargs, + ): + self.from_pixels = from_pixels + self.pixels_only = pixels_only + + if from_pixels: + raise NotImplementedError("TODO") + return env + + def _make_state_example(self, env): + key = jax.random.PRNGKey(0) + keys = jax.random.split(key, self.batch_size.numel()) + state, _ = jax.vmap(env.reset)(jnp.stack(keys)) + state = self._reshape(state) + return state + + def _make_state_spec(self, env) -> TensorSpec: + key = jax.random.PRNGKey(0) + state, _ = env.reset(key) + state_dict = _object_to_tensordict(state, self.device, batch_size=()) + state_spec = _torchrl_data_to_spec_transform(state_dict) + return state_spec + + def _make_input_spec(self, env) -> TensorSpec: + return CompositeSpec( + action=_jumanji_to_torchrl_spec_transform( + env.action_spec(), device=self.device + ), + ) + + def _make_observation_spec(self, env) -> TensorSpec: + spec = env.observation_spec() + new_spec = _jumanji_to_torchrl_spec_transform(spec, device=self.device) + if isinstance(spec, jumanji.specs.Array): + return CompositeSpec(observation=new_spec) + elif isinstance(spec, jumanji.specs.Spec): + return CompositeSpec(**{k: v for k, v in new_spec.items()}) + else: + raise TypeError(f"Unsupported spec type {type(spec)}") + + def _make_reward_spec(self, env) -> TensorSpec: + return _jumanji_to_torchrl_spec_transform(env.reward_spec(), device=self.device) + + def _make_specs(self, env: "jumanji.env.Environment") -> None: # noqa: F821 + + # extract spec from jumanji definition + self._input_spec = self._make_input_spec(env) + self._observation_spec = self._make_observation_spec(env) + self._reward_spec = self._make_reward_spec(env) + + # extract state spec from instance + self._state_spec = self._make_state_spec(env) + self._input_spec["state"] = self._state_spec + + # build state example for data conversion + self._state_example = self._make_state_example(env) + + def _check_kwargs(self, kwargs: Dict): + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, (jumanji.env.Environment,)): + raise TypeError("env is not of type 'jumanji.env.Environment'.") + + def _init_env(self): + pass + + def _set_seed(self, seed): + if seed is None: + raise RuntimeError("Jumanji requires an integer seed.") + self.key = jax.random.PRNGKey(seed) + + def read_state(self, state): + state_dict = _object_to_tensordict(state, self.device, self.batch_size) + return self._state_spec.encode(state_dict) + + def read_obs(self, obs): + if isinstance(obs, (list, jnp.ndarray, np.ndarray)): + obs_dict = _ndarray_to_tensor(obs, self.device) + else: + obs_dict = _object_to_tensordict(obs,
self.device, self.batch_size) + return super().read_obs(obs_dict) + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + + # prepare inputs + state = _tensordict_to_object(tensordict.get("state"), self._state_example) + action = self.read_action(tensordict.get("action")) + reward = self.reward_spec.zero(self.batch_size) + + # flatten batch size into vector + state = self._flatten(state) + action = self._flatten(action) + + # jax vectorizing map on env.step + state, timestep = jax.vmap(self._env.step)(state, action) + + # reshape batch size from vector + state = self._reshape(state) + timestep = self._reshape(timestep) + + # collect outputs + state_dict = self.read_state(state) + obs_dict = self.read_obs(timestep.observation) + reward = self.read_reward(reward, np.asarray(timestep.reward)) + done = torch.tensor( + np.asarray(timestep.step_type == self.lib.types.StepType.LAST) + ) + + self._is_done = done + + # build results + tensordict_out = TensorDict( + source=obs_dict, + batch_size=tensordict.batch_size, + device=self.device, + ) + tensordict_out.set("reward", reward) + tensordict_out.set("done", done) + tensordict_out["state"] = state_dict + + return tensordict_out + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + # generate random keys + self.key, *keys = jax.random.split(self.key, self.numel() + 1) + + # jax vectorizing map on env.reset + state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys)) + + # reshape batch size from vector + state = self._reshape(state) + timestep = self._reshape(timestep) + + # collect outputs + state_dict = self.read_state(state) + obs_dict = self.read_obs(timestep.observation) + done = torch.zeros(self.batch_size, dtype=torch.bool) + + self._is_done = done + + # build results + tensordict_out = TensorDict( + source=obs_dict, + batch_size=self.batch_size, + device=self.device, + ) + tensordict_out.set("done", done) + tensordict_out["state"] = state_dict + + return tensordict_out + + def _reshape(self, x): + shape, n = self.batch_size, 1 + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + def _flatten(self, x): + shape, n = (self.batch_size.numel(),), len(self.batch_size) + return jax.tree_util.tree_map(lambda x: x.reshape(shape + x.shape[n:]), x) + + +class JumanjiEnv(JumanjiWrapper): + """Jumanji environment wrapper. + + Examples: + >>> env = JumanjiEnv(env_name="Snake-6x6-v0", frame_skip=4) + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + """ + + def __init__(self, env_name, **kwargs): + kwargs["env_name"] = env_name + super().__init__(**kwargs) + + def _build_env( + self, + env_name: str, + **kwargs, + ) -> "jumanji.env.Environment": + if not _has_jumanji: + raise RuntimeError( + f"jumanji not found, unable to create {env_name}. " + f"Consider installing jumanji. More info:" + f" {self.git_url}. (Original error message during import: {IMPORT_ERR})." 
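+ # the import error captured at module load (IMPORT_ERR) is surfaced here + # so the root cause, e.g. a missing jax dependency, is not masked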
+ ) + from_pixels = kwargs.pop("from_pixels", False) + pixels_only = kwargs.pop("pixels_only", True) + assert not kwargs + self.wrapper_frame_skip = 1 + env = self.lib.make(env_name, **kwargs) + return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + + @property + def env_name(self): + return self._constructor_kwargs["env_name"] + + def _check_kwargs(self, kwargs: Dict): + if "env_name" not in kwargs: + raise TypeError("Expected 'env_name' to be part of kwargs") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})" diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 9a702c8d290..cdb4d35f4ce 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -28,5 +28,6 @@ VecNorm, gSDENoise, TensorDictPrimer, + SqueezeTransform, ) -from .vip import VIPTransform +from .vip import VIPTransform, VIPRewardTransform diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index e2ae5663b86..410cb639fb3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -230,12 +230,13 @@ def clone(self): return self_copy @property - def parent(self) -> EnvBase: + def parent(self) -> Optional[EnvBase]: if not hasattr(self, "_parent"): raise AttributeError("transform parent uninitialized") parent = self._parent if parent is None: return parent + out = None if not isinstance(parent, EnvBase): # if it's not an env, it should be a Compose transform if not isinstance(parent, Compose): @@ -243,26 +244,23 @@ def parent(self) -> EnvBase: "A transform parent must be either another Compose transform or an environment object." ) compose = parent - # the parent of the compose must be a TransformedEnv - compose_parent = compose.parent - if not isinstance(compose_parent, TransformedEnv): - raise ValueError( - f"Compose parent was of type {type(compose_parent)} but expected TransformedEnv." 
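+ # sketch of the intent: walk up from this transform through the enclosing + # Compose to its TransformedEnv, then rebuild a partial TransformedEnv that + # contains only the transforms preceding `self`, so that spec lookups made + # through `parent` reflect upstream transforms only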
+ if compose.parent: + # the parent of the compose must be a TransformedEnv + compose_parent = compose.parent + if compose_parent.transform is not compose: + comp_parent_trans = compose_parent.transform.clone() + else: + comp_parent_trans = None + out = TransformedEnv( + compose_parent.base_env, + transform=comp_parent_trans, ) - if compose_parent.transform is not compose: - comp_parent_trans = compose_parent.transform.clone() - else: - comp_parent_trans = None - out = TransformedEnv( - compose_parent.base_env, - transform=comp_parent_trans, - ) - for orig_trans in compose.transforms: - if orig_trans is self: - break - transform = copy(orig_trans) - transform.reset_parent() - out.append_transform(transform) + for orig_trans in compose.transforms: + if orig_trans is self: + break + transform = copy(orig_trans) + transform.reset_parent() + out.append_transform(transform) elif isinstance(parent, TransformedEnv): out = TransformedEnv(parent.base_env) else: @@ -1033,13 +1031,14 @@ class FlattenObservation(ObservationTransform): def __init__( self, - first_dim: int = 0, + first_dim: int, last_dim: int = -3, in_keys: Optional[Sequence[str]] = None, + out_keys: Optional[Sequence[str]] = None, ): if in_keys is None: in_keys = IMAGE_KEYS # default - super().__init__(in_keys=in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) self.first_dim = first_dim self.last_dim = last_dim @@ -1049,15 +1048,26 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: def set_parent(self, parent: Union[Transform, EnvBase]) -> None: out = super().set_parent(parent) - observation_spec = self.parent.observation_spec - for key in self.in_keys: - if key in observation_spec: - observation_spec = observation_spec[key] - if self.first_dim >= 0: - self.first_dim = self.first_dim - len(observation_spec.shape) - if self.last_dim >= 0: - self.last_dim = self.last_dim - len(observation_spec.shape) - break + try: + observation_spec = self.parent.observation_spec + for key in self.in_keys: + if key in observation_spec: + observation_spec = observation_spec[key] + if self.first_dim >= 0: + self.first_dim = self.first_dim - len(observation_spec.shape) + if self.last_dim >= 0: + self.last_dim = self.last_dim - len(observation_spec.shape) + break + except AttributeError: + if self.first_dim >= 0 or self.last_dim >= 0: + raise ValueError( + f"FlattenObservation got first and last dim {self.first_dim} and {self.last_dim}. " + f"Those values assume that the observation spec is known, which requires the " + f"parent environment to be set. " + f"Consider setting the parent environment beforehand (i.e. passing the transform " + f"to `TransformedEnv.append_transform()`) or passing strictly negative " + f"flatten dimensions to the transform." + ) return out @_apply_to_composite @@ -1119,7 +1129,15 @@ def set_parent(self, parent: Union[Transform, EnvBase]) -> None: self._unsqueeze_dim = self._unsqueeze_dim_orig else: parent = self.parent - batch_size = parent.batch_size + try: + batch_size = parent.batch_size + except AttributeError: + raise ValueError( + f"Got the unsqueeze dimension {self._unsqueeze_dim_orig} which is greater than or equal to zero. " + f"However, this requires the parent environment to be known, but it has not been provided. " + f"Consider providing a negative dimension or setting the transform using the " + f"`TransformedEnv.append_transform()` method."
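+ # note: only a non-negative unsqueeze dim needs the parent batch_size; + # negative dims index from the right and can be resolved without a parent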
+ ) self._unsqueeze_dim = self._unsqueeze_dim_orig + len(batch_size) return super().set_parent(parent) @@ -1507,6 +1525,12 @@ class RewardScaling(Transform): Args: loc (number or torch.Tensor): location of the affine transform scale (number or torch.Tensor): scale of the affine transform + standard_normal (bool, optional): if True, the transform will be + + .. math:: + reward = (reward-loc)/scale + + as it is done for standardization. Default is `False`. """ inplace = True @@ -1516,10 +1540,13 @@ def __init__( loc: Union[float, torch.Tensor], scale: Union[float, torch.Tensor], in_keys: Optional[Sequence[str]] = None, + standard_normal: bool = False, ): if in_keys is None: in_keys = ["reward"] super().__init__(in_keys=in_keys) + self.standard_normal = standard_normal + if not isinstance(loc, torch.Tensor): loc = torch.tensor(loc) if not isinstance(scale, torch.Tensor): @@ -1529,8 +1556,16 @@ def __init__( self.register_buffer("scale", scale.clamp_min(1e-6)) def _apply_transform(self, reward: torch.Tensor) -> torch.Tensor: - reward.mul_(self.scale).add_(self.loc) - return reward + if self.standard_normal: + loc = self.loc + scale = self.scale + reward = (reward - loc) / scale + return reward + else: + scale = self.scale + loc = self.loc + reward = reward * scale + loc + return reward def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: if isinstance(reward_spec, UnboundedContinuousTensorSpec): diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 2d892aba456..9e57a7c6bc3 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -463,6 +463,15 @@ def _make_net(self, device: Optional[DEVICE_TYPING]) -> nn.Module: layers.append(Squeeze2dLayer()) return layers + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + *batch, C, L, W = inputs.shape + if len(batch) > 1: + inputs = inputs.flatten(0, len(batch) - 1) + out = super(ConvNet, self).forward(inputs) + if len(batch) > 1: + out = out.unflatten(0, batch) + return out + class DuelingMlpDQNet(nn.Module): """Creates a Dueling MLP Q-network. diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index d11c9ab12fd..60c01aa5328 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDictBase +from tensordict.tensordict import TensorDictBase, TensorDict from torchrl.envs import EnvBase from torchrl.modules.planners.common import MPCPlannerBase @@ -36,7 +36,7 @@ class CEMPlanner(MPCPlannerBase): planner num_candidates (int): The number of candidates to sample from the Gaussian distributions. - num_top_k_candidates (int): The number of top candidates to use to + top_k (int): The number of top candidates to use to update the mean and standard deviation of the Gaussian distribution. reward_key (str, optional): The key in the TensorDict to use to retrieve the reward. Defaults to "reward". @@ -52,7 +52,7 @@ class CEMPlanner(MPCPlannerBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) ... self.observation_spec = CompositeSpec( - ... next_hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) ... ) ... self.input_spec = CompositeSpec( ... 
hidden_observation=NdUnboundedContinuousTensorSpec((4,)), @@ -61,13 +61,17 @@ class CEMPlanner(MPCPlannerBase): ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) ... ... def _reset(self, tensordict: TensorDict) -> TensorDict: - ... tensordict = TensorDict({}, + ... tensordict = TensorDict( + ... {}, ... batch_size=self.batch_size, ... device=self.device, ... ) - ... tensordict = tensordict.update(self.input_spec.rand(self.batch_size)) - ... tensordict = tensordict.update(self.observation_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.input_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.observation_spec.rand(self.batch_size)) ... return tensordict + ... >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn >>> world_model = WorldModelWrapper( @@ -91,7 +95,12 @@ class CEMPlanner(MPCPlannerBase): action: Tensor(torch.Size([5, 1]), dtype=torch.float32), done: Tensor(torch.Size([5, 1]), dtype=torch.bool), hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), - next_hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + next: LazyStackedTensorDict( + fields={ + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), reward: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, batch_size=torch.Size([5]), device=cpu, @@ -104,7 +113,7 @@ def __init__( planning_horizon: int, optim_steps: int, num_candidates: int, - num_top_k_candidates: int, + top_k: int, reward_key: str = "reward", action_key: str = "action", ): @@ -112,46 +121,66 @@ def __init__( self.planning_horizon = planning_horizon self.optim_steps = optim_steps self.num_candidates = num_candidates - self.num_top_k_candidates = num_top_k_candidates + self.top_k = top_k self.reward_key = reward_key def planning(self, tensordict: TensorDictBase) -> torch.Tensor: batch_size = tensordict.batch_size - expanded_original_tensordict = ( - tensordict.unsqueeze(-1) - .expand(*batch_size, self.num_candidates) - .reshape(-1) + action_shape = ( + *batch_size, + self.num_candidates, + self.planning_horizon, + *self.action_spec.shape, ) - flatten_batch_size = batch_size.numel() - actions_means = torch.zeros( - flatten_batch_size, + action_stats_shape = ( + *batch_size, 1, self.planning_horizon, *self.action_spec.shape, - device=tensordict.device, - dtype=self.env.action_spec.dtype, ) - actions_stds = torch.ones( - flatten_batch_size, - 1, + action_topk_shape = ( + *batch_size, + self.top_k, self.planning_horizon, *self.action_spec.shape, + ) + TIME_DIM = len(self.action_spec.shape) - 3 + K_DIM = len(self.action_spec.shape) - 4 + expanded_original_tensordict = ( + tensordict.unsqueeze(-1) + .expand(*batch_size, self.num_candidates) + .to_tensordict() + ) + _action_means = torch.zeros( + *action_stats_shape, device=tensordict.device, dtype=self.env.action_spec.dtype, ) + _action_stds = torch.ones_like(_action_means) + container = TensorDict( + { + "tensordict": expanded_original_tensordict, + "stats": TensorDict( + { + "_action_means": _action_means, + "_action_stds": _action_stds, + }, + [*batch_size, 1, self.planning_horizon], + ), + }, + batch_size, + ) for _ in range(self.optim_steps): + actions_means = container.get(("stats", "_action_means")) + actions_stds = container.get(("stats", "_action_stds")) actions = actions_means + actions_stds * torch.randn( - flatten_batch_size, - self.num_candidates, - self.planning_horizon, - *self.action_spec.shape, - 
device=tensordict.device, - dtype=self.env.action_spec.dtype, + *action_shape, + device=actions_means.device, + dtype=actions_means.dtype, ) - actions = actions.flatten(0, 1) actions = self.env.action_spec.project(actions) - optim_tensordict = expanded_original_tensordict.to_tensordict() + optim_tensordict = container.get("tensordict").clone() policy = _PrecomputedActionsSequentialSetter(actions) optim_tensordict = self.env.rollout( max_steps=self.planning_horizon, @@ -159,23 +188,21 @@ def planning(self, tensordict: TensorDictBase) -> torch.Tensor: auto_reset=False, tensordict=optim_tensordict, ) - rewards = ( - optim_tensordict.get(self.reward_key) - .sum(dim=1) - .reshape(flatten_batch_size, self.num_candidates) - ) - _, top_k = rewards.topk(self.num_top_k_candidates, dim=1) - best_actions = actions.unflatten( - 0, (flatten_batch_size, self.num_candidates) + sum_rewards = optim_tensordict.get(self.reward_key).sum( + dim=TIME_DIM, keepdim=True + ) + _, top_k = sum_rewards.topk(self.top_k, dim=K_DIM) + top_k = top_k.expand(action_topk_shape) + best_actions = actions.gather(K_DIM, top_k) + container.set_( + ("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True) + ) + container.set_( + ("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True) ) - best_actions = best_actions[ - torch.arange(flatten_batch_size, device=tensordict.device).unsqueeze(1), - top_k, - ] - actions_means = best_actions.mean(dim=1, keepdim=True) - actions_stds = best_actions.std(dim=1, keepdim=True) - return actions_means[:, :, 0].reshape(*batch_size, *self.action_spec.shape) + action_means = container.get(("stats", "_action_means")) + return action_means[..., 0, 0, :] class _PrecomputedActionsSequentialSetter: @@ -183,9 +210,10 @@ def __init__(self, actions): self.actions = actions self.cmpt = 0 - def __call__(self, td): - if self.cmpt >= self.actions.shape[1]: - raise ValueError("Precomputed actions are too short") - td = td.set("action", self.actions[:, self.cmpt]) + def __call__(self, tensordict): + # check that the step count does not exceed the horizon + if self.cmpt >= self.actions.shape[-2]: + raise ValueError("Precomputed actions sequence is too short") + tensordict = tensordict.set("action", self.actions[..., self.cmpt, :]) self.cmpt += 1 - return td + return tensordict diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py new file mode 100644 index 00000000000..b57b87770a5 --- /dev/null +++ b/torchrl/modules/planners/mppi.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from tensordict.tensordict import TensorDictBase, TensorDict +from torch import nn + +from torchrl.envs import EnvBase +from torchrl.modules.planners.common import MPCPlannerBase + + +class MPPIPlanner(MPCPlannerBase): + """MPPI Planner Module. + + Reference: + - Model predictive path integral control using covariance variable importance + sampling. (Williams, G., Aldrich, A., and Theodorou, E. A.) https://arxiv.org/abs/1509.01149 + - Temporal Difference Learning for Model Predictive Control + (Hansen N., Wang X., Su H.) https://arxiv.org/abs/2203.04955 + + This module will perform an MPPI planning step when given a TensorDict + containing initial states.
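+ + As a sketch of the update rule (in the notation of the references, not an + exact transcription): the top-k candidate sequences are re-weighted with + Omega_k = exp(temperature * advantage_k), and the sampling distribution is + refit to their Omega-weighted mean and standard deviation.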
+ + A call to the module returns the actions that empirically maximised the + returns given a planning horizon. + + Args: + env (EnvBase): The environment to perform the planning step on (can be + `ModelBasedEnv` or :obj:`EnvBase`). + advantage_module (nn.Module): The module used to compute the advantage + of the simulated trajectories. + temperature (float): The temperature of the exponential weights given + to the top-k candidate trajectories. + planning_horizon (int): The length of the simulated trajectories. + optim_steps (int): The number of optimization steps used by the MPC + planner + num_candidates (int): The number of candidates to sample from the + Gaussian distributions. + top_k (int): The number of top candidates to use to + update the mean and standard deviation of the Gaussian distribution. + reward_key (str, optional): The key in the TensorDict to use to + retrieve the reward. Defaults to "reward". + action_key (str, optional): The key in the TensorDict to use to store + the action. Defaults to "action". + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.envs.model_based import ModelBasedEnvBase + >>> from torchrl.modules import TensorDictModule, ValueOperator + >>> from torchrl.objectives.value import TDLambdaEstimate + >>> class MyMBEnv(ModelBasedEnvBase): + ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): + ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) + ... self.observation_spec = CompositeSpec( + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... ) + ... self.input_spec = CompositeSpec( + ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), + ... action=NdUnboundedContinuousTensorSpec((1,)), + ... ) + ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + ... + ... def _reset(self, tensordict: TensorDict) -> TensorDict: + ... tensordict = TensorDict( + ... {}, + ... batch_size=self.batch_size, + ... device=self.device, + ... ) + ... tensordict = tensordict.update( + ... self.input_spec.rand(self.batch_size)) + ... tensordict = tensordict.update( + ... self.observation_spec.rand(self.batch_size)) + ... return tensordict + >>> from torchrl.modules import MLP, WorldModelWrapper + >>> import torch.nn as nn + >>> world_model = WorldModelWrapper( + ... TensorDictModule( + ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), + ... in_keys=["hidden_observation", "action"], + ... out_keys=["hidden_observation"], + ... ), + ... TensorDictModule( + ... nn.Linear(4, 1), + ... in_keys=["hidden_observation"], + ... out_keys=["reward"], + ... ), + ... ) + >>> env = MyMBEnv(world_model) + >>> value_net = nn.Linear(4, 1) + >>> value_net = ValueOperator(value_net, in_keys=["hidden_observation"]) + >>> adv = TDLambdaEstimate( + ... 0.99, + ... 0.95, + ... value_net, + ... ) + >>> # Build a planner and use it as actor + >>> planner = MPPIPlanner( + ... env, + ... adv, + ... temperature=1.0, + ... planning_horizon=10, + ... optim_steps=11, + ... num_candidates=7, + ...
top_k=3) + >>> env.rollout(5, planner) + TensorDict( + fields={ + action: Tensor(torch.Size([5, 1]), dtype=torch.float32), + done: Tensor(torch.Size([5, 1]), dtype=torch.bool), + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32), + next: LazyStackedTensorDict( + fields={ + hidden_observation: Tensor(torch.Size([5, 4]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + reward: Tensor(torch.Size([5, 1]), dtype=torch.float32)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + """ + + def __init__( + self, + env: EnvBase, + advantage_module: nn.Module, + temperature: float, + planning_horizon: int, + optim_steps: int, + num_candidates: int, + top_k: int, + reward_key: str = "reward", + action_key: str = "action", + ): + super().__init__(env=env, action_key=action_key) + self.advantage_module = advantage_module + self.planning_horizon = planning_horizon + self.optim_steps = optim_steps + self.num_candidates = num_candidates + self.top_k = top_k + self.reward_key = reward_key + self.register_buffer("temperature", torch.tensor(temperature)) + + def planning(self, tensordict: TensorDictBase) -> torch.Tensor: + batch_size = tensordict.batch_size + action_shape = ( + *batch_size, + self.num_candidates, + self.planning_horizon, + *self.action_spec.shape, + ) + action_stats_shape = ( + *batch_size, + 1, + self.planning_horizon, + *self.action_spec.shape, + ) + action_topk_shape = ( + *batch_size, + self.top_k, + self.planning_horizon, + *self.action_spec.shape, + ) + adv_topk_shape = ( + *batch_size, + self.top_k, + 1, + 1, + ) + K_DIM = len(self.action_spec.shape) - 4 + expanded_original_tensordict = ( + tensordict.unsqueeze(-1) + .expand(*batch_size, self.num_candidates) + .to_tensordict() + ) + _action_means = torch.zeros( + *action_stats_shape, + device=tensordict.device, + dtype=self.env.action_spec.dtype, + ) + _action_stds = torch.ones_like(_action_means) + container = TensorDict( + { + "tensordict": expanded_original_tensordict, + "stats": TensorDict( + { + "_action_means": _action_means, + "_action_stds": _action_stds, + }, + [*batch_size, 1, self.planning_horizon], + ), + }, + batch_size, + ) + + for _ in range(self.optim_steps): + actions_means = container.get(("stats", "_action_means")) + actions_stds = container.get(("stats", "_action_stds")) + actions = actions_means + actions_stds * torch.randn( + *action_shape, + device=actions_means.device, + dtype=actions_means.dtype, + ) + actions = self.env.action_spec.project(actions) + optim_tensordict = container.get("tensordict").clone() + policy = _PrecomputedActionsSequentialSetter(actions) + optim_tensordict = self.env.rollout( + max_steps=self.planning_horizon, + policy=policy, + auto_reset=False, + tensordict=optim_tensordict, + ) + # compute advantage + self.advantage_module(optim_tensordict) + # get advantage of the current state + advantage = optim_tensordict["advantage"][..., :1, :] + # get top-k trajectories + _, top_k = advantage.topk(self.top_k, dim=K_DIM) + # get omega weights for each top-k trajectory + vals = advantage.gather(K_DIM, top_k.expand(adv_topk_shape)) + Omegas = (self.temperature * vals).exp() + + # gather best actions + best_actions = actions.gather(K_DIM, top_k.expand(action_topk_shape)) + + # compute weighted average + _action_means = (Omegas * best_actions).sum( + dim=K_DIM, keepdim=True + ) / Omegas.sum(K_DIM, True) + _action_stds = ( + (Omegas * (best_actions - _action_means).pow(2)).sum( + dim=K_DIM, keepdim=True + ) + / Omegas.sum(K_DIM, 
True) + ).sqrt() + container.set_(("stats", "_action_means"), _action_means) + container.set_(("stats", "_action_stds"), _action_stds) + action_means = container.get(("stats", "_action_means")) + return action_means[..., 0, 0, :] + + +class _PrecomputedActionsSequentialSetter: + def __init__(self, actions): + self.actions = actions + self.cmpt = 0 + + def __call__(self, tensordict): + # check that the step count does not exceed the horizon + if self.cmpt >= self.actions.shape[-2]: + raise ValueError("Precomputed actions sequence is too short") + tensordict = tensordict.set("action", self.actions[..., self.cmpt, :]) + self.cmpt += 1 + return tensordict diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 87bd86ec790..0046315aea6 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -123,7 +123,7 @@ def make_env_transforms( env.append_transform(Resize(cfg.image_size, cfg.image_size)) if cfg.grayscale: env.append_transform(GrayScale()) - env.append_transform(FlattenObservation()) + env.append_transform(FlattenObservation(0)) env.append_transform(CatFrames(N=cfg.catframes, in_keys=["pixels"])) if stats is None: obs_stats = {"loc": 0.0, "scale": 1.0} @@ -272,8 +272,15 @@ def make_transformed_env(**kwargs) -> TransformedEnv: "frame_skip": frame_skip, "from_pixels": from_pixels or len(video_tag), "pixels_only": from_pixels, - "categorical_action_encoding": categorical_action_encoding, } + if env_library is GymEnv: + env_kwargs.update( + {"categorical_action_encoding": categorical_action_encoding} + ) + elif categorical_action_encoding: + raise NotImplementedError( + "categorical_action_encoding=True is currently only compatible with GymEnvs." + ) if env_library is DMControlEnv: env_kwargs.update({"task_name": env_task}) env_kwargs.update(kwargs) diff --git a/torchrl/trainers/helpers/replay_buffer.py b/torchrl/trainers/helpers/replay_buffer.py index ef9ef9963ba..c7bbee12d82 100644 --- a/torchrl/trainers/helpers/replay_buffer.py +++ b/torchrl/trainers/helpers/replay_buffer.py @@ -24,7 +24,6 @@ def make_replay_buffer( if not cfg.prb: buffer = TensorDictReplayBuffer( cfg.buffer_size, - collate_fn=lambda x: x, pin_memory=device != torch.device("cpu"), prefetch=cfg.buffer_prefetch, storage=LazyMemmapStorage( @@ -38,7 +37,6 @@ def make_replay_buffer( cfg.buffer_size, alpha=0.7, beta=0.5, - collate_fn=lambda x: x, pin_memory=device != torch.device("cpu"), prefetch=cfg.buffer_prefetch, storage=LazyMemmapStorage( diff --git a/torchrl/trainers/loggers/csv.py b/torchrl/trainers/loggers/csv.py index 93126f6de22..3db8dbbe23a 100644 --- a/torchrl/trainers/loggers/csv.py +++ b/torchrl/trainers/loggers/csv.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree.
import os from collections import defaultdict +from pathlib import Path from typing import Optional import torch @@ -41,6 +42,8 @@ def add_video(self, tag, vid_tensor, global_step: Optional[int] = None, **kwargs filepath = os.path.join( self.log_dir, "videos", "_".join([tag, str(global_step)]) + ".pt" ) + path_to_create = Path(str(filepath)).parent + os.makedirs(path_to_create, exist_ok=True) torch.save(vid_tensor, filepath) def add_text(self, tag, text, global_step: Optional[int] = None): diff --git a/version.txt b/version.txt index da086c89b60..bcab45af15a 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.0.2a +0.0.3 From 953b186b13f98cc72943bdfb0205dcd3408f4e39 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 5 Jan 2023 14:14:34 +0000 Subject: [PATCH 5/5] lint --- test/test_modules.py | 5 +++-- torchrl/modules/planners/cem.py | 18 +++++++++--------- torchrl/modules/planners/mppi.py | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 4c88d0a954c..f3bb917aca7 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -439,9 +439,9 @@ def test_lstm_net_nobatch(device, out_features, hidden_size): torch.testing.assert_close(tds_vec["hidden1_out"][-1], tds_loop["hidden1_out"][-1]) +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("batch_size", [3, 5]) class TestPlanner: - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("batch_size", [3, 5]) def test_CEM_model_free_env(self, device, batch_size, seed=1): env = MockBatchedUnLockedEnv(device=device) torch.manual_seed(seed) @@ -475,6 +475,7 @@ def test_MPPI(self, device, batch_size, seed=1): 0.95, value_net, ) + value_net(env.reset()) planner = MPPIPlanner( env, advantage_module, diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 60c01aa5328..491f02f3391 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.envs import EnvBase from torchrl.modules.planners.common import MPCPlannerBase @@ -45,20 +45,20 @@ class CEMPlanner(MPCPlannerBase): Examples: >>> from tensordict import TensorDict - >>> from torchrl.data import CompositeSpec, NdUnboundedContinuousTensorSpec + >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec >>> from torchrl.envs.model_based import ModelBasedEnvBase - >>> from torchrl.modules import TensorDictModule + >>> from torchrl.modules import SafeModule >>> class MyMBEnv(ModelBasedEnvBase): ... def __init__(self, world_model, device="cpu", dtype=None, batch_size=None): ... super().__init__(world_model, device=device, dtype=dtype, batch_size=batch_size) ... self.observation_spec = CompositeSpec( - ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)) + ... hidden_observation=UnboundedContinuousTensorSpec((4,)) ... ) ... self.input_spec = CompositeSpec( - ... hidden_observation=NdUnboundedContinuousTensorSpec((4,)), - ... action=NdUnboundedContinuousTensorSpec((1,)), + ... hidden_observation=UnboundedContinuousTensorSpec((4,)), + ... action=UnboundedContinuousTensorSpec((1,)), ... ) - ... self.reward_spec = NdUnboundedContinuousTensorSpec((1,)) + ... self.reward_spec = UnboundedContinuousTensorSpec((1,)) ... ...
def _reset(self, tensordict: TensorDict) -> TensorDict: ... tensordict = TensorDict( @@ -75,12 +75,12 @@ class CEMPlanner(MPCPlannerBase): >>> from torchrl.modules import MLP, WorldModelWrapper >>> import torch.nn as nn >>> world_model = WorldModelWrapper( - ... TensorDictModule( + ... SafeModule( ... MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0), ... in_keys=["hidden_observation", "action"], ... out_keys=["hidden_observation"], ... ), - ... TensorDictModule( + ... SafeModule( ... nn.Linear(4, 1), ... in_keys=["hidden_observation"], ... out_keys=["reward"], diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index b57b87770a5..f1a5fe9b255 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch -from tensordict.tensordict import TensorDictBase, TensorDict +from tensordict.tensordict import TensorDict, TensorDictBase from torch import nn from torchrl.envs import EnvBase
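A minimal runnable sketch of the top-k/gather update that both planners now share (illustrative only; the toy shapes and the helper names `idx` and `best` are assumptions for the example, not code from this patch):

import torch

# toy dimensions: batch x candidates x horizon x action_dim
batch, candidates, horizon, adim, top_k = 2, 5, 3, 1, 2
actions = torch.randn(batch, candidates, horizon, adim)
# summed rewards keep singleton time/feature dims: [batch, candidates, 1, 1]
sum_rewards = torch.randn(batch, candidates, 1, 1)
# pick the indices of the top-k candidates along the candidate dim (K_DIM = -3)
_, idx = sum_rewards.topk(top_k, dim=-3)
# expand the indices to the action shape so gather selects whole action sequences
idx = idx.expand(batch, top_k, horizon, adim)
best = actions.gather(-3, idx)
# refit the sampling distribution on the selected sequences
means = best.mean(dim=-3, keepdim=True)
stds = best.std(dim=-3, keepdim=True)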