diff --git a/experiments/Alfredo-ex1/seq_training.py b/experiments/Alfredo-ex1/seq_training.py index 3f46bd0..b1d4b3a 100644 --- a/experiments/Alfredo-ex1/seq_training.py +++ b/experiments/Alfredo-ex1/seq_training.py @@ -23,6 +23,19 @@ from alfredo.agents.A0 import Alfredo from alfredo.train import ppo + +import wandb +# Initialize a new run +wandb.init(project="alfredo" + config = { + "env_name": "A0", + "backend": "positional", + "seed": 0, + "len_training": 1_000_000, + # add any other hyperparameters or configurations you'd like to track + } +) + # ============================== # Useful Functions & Data Defs # ============================== @@ -31,20 +44,14 @@ def progress(num_steps, metrics): - d_and_t = datetime.now() - print(f"Next Iteration: {num_steps}") - print(f"Datetime: {d_and_t}") - - print(f"Total Reward: {metrics['eval/episode_reward']}") - print(f"Target Reward: {metrics['eval/episode_reward_to_target']}") - print(f"Vel Reward: {metrics['eval/episode_reward_velocity']}") - print(f"Alive Reward: {metrics['eval/episode_reward_alive']}") - print(f"Ctrl Reward: {metrics['eval/episode_reward_ctrl']}") - print(f"a_vel_x: {metrics['eval/episode_agent_x_velocity']}") - print(f"a_vel_y: {metrics['eval/episode_agent_y_velocity']}") - - print("==========================================================") - + wandb.log({"step": num_steps, + "Total Reward": metrics['eval/episode_reward'], + "Target Reward": metrics['eval/episode_reward_to_target'], + "Vel Reward": metrics['eval/episode_reward_velocity'], + "Alive Reward": metrics['eval/episode_reward_alive'], + "Ctrl Reward": metrics['eval/episode_reward_ctrl'], + "a_vel_x": metrics['eval/episode_agent_x_velocity'], + "a_vel_y": metrics['eval/episode_agent_y_velocity']}) # ============================== # General Variable Defs @@ -55,14 +62,6 @@ def progress(num_steps, metrics): import alfredo.scenes as scenes scene_fp = os.path.dirname(scenes.__file__) - -env_name = "A0" -backend = "positional" - -seed = 0 - -len_training = 1_000_000 - # ============================ # Loading and Defining Envs # ============================ @@ -73,42 +72,11 @@ def progress(num_steps, metrics): global_key, local_key = jax.random.split(key) key_policy, key_value = jax.random.split(global_key) -env = Alfredo(backend=backend, paramFile_path=pf_paths[0]) -# print(env.__dict__) -# print(dir(env)) -# print(env.observation_size) -# print(dir(env._pipeline)) +env = Alfredo(backend=wandb.config.backend, paramFile_path=pf_paths[0]) rng = jax.random.PRNGKey(seed=1) state = env.reset(rng) -# ==================SINGLE STEP DEBUGGING ================== - -# com0 = env._com(state.pipeline_state) -# print(f"\n") -# nState = env.step(state, jp.zeros(env.action_size)) -# com1 = env._com(nState.pipeline_state) -# lcom = len(com0) -# print(f"{com0}") -# print(f"{lcom}") - -# print(f"\n") -# nState = env.step(nState, jp.zeros(env.action_size)) - -# print(f"\n") -# nState = env.step(nState, jp.zeros(env.action_size)) - -# print(f"\n") -# nState = env.step(nState, jp.zeros(env.action_size)) - -# print(f"\n") -# nState = env.step(nState, jp.zeros(env.action_size)) - -# print(f"\n") -# nState = env.step(nState, jp.zeros(env.action_size)) - -# ======================================================== - ppo_network = ppo_networks.make_ppo_networks( env.observation_size, env.action_size, normalize_fn ) @@ -151,27 +119,25 @@ def progress(num_steps, metrics): print(f"[{d_and_t}] jitting end for model: {i}") # define new training function - train_fn = { - "A0": functools.partial( - ppo.train, - num_timesteps=len_training, - num_evals=10, - reward_scaling=0.1, - episode_length=1000, - normalize_observations=True, - action_repeat=1, - unroll_length=10, - num_minibatches=32, - num_updates_per_batch=8, - discounting=0.97, - learning_rate=3e-4, - entropy_cost=1e-3, - num_envs=2048, - batch_size=1024, - seed=1, - in_params=mParams, - ) - }[env_name] + train_fn = functools.partial( + ppo.train, + num_timesteps=len_training, + num_evals=10, + reward_scaling=0.1, + episode_length=1000, + normalize_observations=True, + action_repeat=1, + unroll_length=10, + num_minibatches=32, + num_updates_per_batch=8, + discounting=0.97, + learning_rate=3e-4, + entropy_cost=1e-3, + num_envs=2048, + batch_size=1024, + seed=1, + in_params=mParams, + ) d_and_t = datetime.now() print(f"[{d_and_t}] training start for model: {i}") @@ -185,9 +151,3 @@ def progress(num_steps, metrics): d_and_t = datetime.now() print(f"[{d_and_t}] loop end for model: {i}") - -# ============================ -# Presenting Final Stats -# ============================ - -# none right now diff --git a/flake.lock b/flake.lock index 6ddf819..ddbe61c 100644 --- a/flake.lock +++ b/flake.lock @@ -17,12 +17,15 @@ } }, "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1667395993, - "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "lastModified": 1685518550, + "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", "owner": "numtide", "repo": "flake-utils", - "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", "type": "github" }, "original": { @@ -54,11 +57,11 @@ }, "nixlib": { "locked": { - "lastModified": 1681001314, - "narHash": "sha256-5sDnCLdrKZqxLPK4KA8+f4A3YKO/u6ElpMILvX0g72c=", + "lastModified": 1689469483, + "narHash": "sha256-2SBhY7rZQ/iNCxe04Eqxlz9YK9KgbaTMBssq3/BgdWY=", "owner": "nix-community", "repo": "nixpkgs.lib", - "rev": "367c0e1086a4eb4502b24d872cea2c7acdd557f4", + "rev": "02fea408f27186f139153e1ae88f8ab2abd9c22c", "type": "github" }, "original": { @@ -69,11 +72,11 @@ }, "nixos": { "locked": { - "lastModified": 1686392259, - "narHash": "sha256-hqSS9hKhWldIZr1bBp9xKhIznnGPICGKzuehd2LH0UA=", + "lastModified": 1688392541, + "narHash": "sha256-lHrKvEkCPTUO+7tPfjIcb7Trk6k31rz18vkyqmkeJfY=", "owner": "nixos", "repo": "nixpkgs", - "rev": "ef24b2fa0c5f290a35064b847bc211f25cb85c88", + "rev": "ea4c80b39be4c09702b0cb3b42eab59e2ba4f24b", "type": "github" }, "original": { @@ -91,11 +94,11 @@ ] }, "locked": { - "lastModified": 1683530131, - "narHash": "sha256-R0RSqj6JdZfru2x/cM19KJMHsU52OjtyxI5cccd+uFc=", + "lastModified": 1690133435, + "narHash": "sha256-YNZiefETggroaTLsLJG2M+wpF0pJPwiauKG4q48ddNU=", "owner": "nix-community", "repo": "nixos-generators", - "rev": "10079333313ff62446e6f2b0e7c5231c7431d269", + "rev": "b1171de4d362c022130c92d7c8adc4bf2b83d586", "type": "github" }, "original": { @@ -106,11 +109,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1686412476, - "narHash": "sha256-inl9SVk6o5h75XKC79qrDCAobTD1Jxh6kVYTZKHzewA=", + "lastModified": 1692447944, + "narHash": "sha256-fkJGNjEmTPvqBs215EQU4r9ivecV5Qge5cF/QDLVn3U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "21951114383770f96ae528d0ae68824557768e81", + "rev": "d680ded26da5cf104dd2735a51e88d2d8f487b4d", "type": "github" }, "original": { @@ -122,32 +125,32 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1678872516, - "narHash": "sha256-/E1YwtMtFAu2KUQKV/1+KFuReYPANM2Rzehk84VxVoc=", + "lastModified": 1685801374, + "narHash": "sha256-otaSUoFEMM+LjBI1XL/xGB5ao6IwnZOXc47qhIgJe8U=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "9b8e5abb18324c7fe9f07cb100c3cd4a29cda8b8", + "rev": "c37ca420157f4abc31e26f436c1145f8951ff373", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-22.11", + "ref": "nixos-23.05", "repo": "nixpkgs", "type": "github" } }, "nixpkgs_2": { "locked": { - "lastModified": 1681303793, - "narHash": "sha256-JEdQHsYuCfRL2PICHlOiH/2ue3DwoxUX7DJ6zZxZXFk=", + "lastModified": 1689261696, + "narHash": "sha256-LzfUtFs9MQRvIoQ3MfgSuipBVMXslMPH/vZ+nM40LkA=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "fe2ecaf706a5907b5e54d979fbde4924d84b65fc", + "rev": "df1eee2aa65052a18121ed4971081576b25d6b5c", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-unstable", + "ref": "nixpkgs-unstable", "repo": "nixpkgs", "type": "github" } @@ -161,11 +164,11 @@ "nixpkgs-stable": "nixpkgs-stable" }, "locked": { - "lastModified": 1685361114, - "narHash": "sha256-4RjrlSb+OO+e1nzTExKW58o3WRwVGpXwj97iCta8aj4=", + "lastModified": 1692274144, + "narHash": "sha256-BxTQuRUANQ81u8DJznQyPmRsg63t4Yc+0kcyq6OLz8s=", "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "ca2fdbf3edda2a38140184da6381d49f8206eaf4", + "rev": "7e3517c03d46159fdbf8c0e5c97f82d5d4b0c8fa", "type": "github" }, "original": { @@ -198,16 +201,31 @@ "type": "github" } }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, "utils": { "inputs": { - "systems": "systems" + "systems": "systems_2" }, "locked": { - "lastModified": 1681202837, - "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", + "lastModified": 1689068808, + "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", "owner": "numtide", "repo": "flake-utils", - "rev": "cfacdce06f30d2b68473a46042957675eebb3401", + "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", "type": "github" }, "original": { diff --git a/nix/chex.nix b/nix/chex.nix index da3ba78..1edb5f0 100644 --- a/nix/chex.nix +++ b/nix/chex.nix @@ -9,6 +9,7 @@ , pytestCheckHook , toolz , cloudpickle +, typing-extensions }: buildPythonPackage rec { @@ -30,6 +31,7 @@ buildPythonPackage rec { jax numpy toolz + typing-extensions ]; postPatch = '' diff --git a/nix/flax.nix b/nix/flax.nix index 07b1358..c8d1770 100644 --- a/nix/flax.nix +++ b/nix/flax.nix @@ -10,11 +10,12 @@ buildPythonPackage rec { name = "flax"; + format = "pyproject"; src = fetchFromGitHub { owner = "google"; repo = "flax"; - rev = "v0.6.5"; - hash = "sha256-Vv68BK83gTIKj0r9x+twdhqmRYziD0vxQCdHkYSeTak="; + rev = "v0.7.2"; + hash = "sha256-Zj2xwtUBYrr0lwSjKn8bLHiBtKB0ZUFif7byHoGSZvg="; }; propagatedBuildInputs = [ jax @@ -25,8 +26,8 @@ buildPythonPackage rec { matplotlib ]; postPatch = '' - sed -i '/tensorstore/d' setup.py + sed -i '/tensorstore/d' pyproject.toml + sed -i '/orbax/d' pyproject.toml ''; doCheck = false; - pythonRemoveDeps = [ "orbax" ]; } diff --git a/nix/jax.nix b/nix/jax.nix index d0e8136..e69de29 100644 --- a/nix/jax.nix +++ b/nix/jax.nix @@ -1,108 +0,0 @@ -{ lib -, absl-py -, blas -, buildPythonPackage -, etils -, fetchFromGitHub -, lapack -, matplotlib -, numpy -, opt-einsum -, pytestCheckHook -, pytest-xdist -, pythonOlder -, scipy -, typing-extensions -, jaxlib -}: - -let - usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl"; -in -buildPythonPackage rec { - pname = "jax"; - version = "0.4.10"; - format = "setuptools"; - - disabled = pythonOlder "3.7"; - - src = fetchFromGitHub { - owner = "google"; - repo = pname; - rev = "jax-v${version}"; - hash = "sha256-USdEVEcZ28YHDJQDzWzpWdBQQimi27xe5Quc9dESoXw="; - }; - - # jaxlib is _not_ included in propagatedBuildInputs because there are - # different versions of jaxlib depending on the desired target hardware. The - # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the - # CPU wheel is packaged. - propagatedBuildInputs = [ - absl-py - etils - numpy - opt-einsum - scipy - typing-extensions - jaxlib - ] ++ etils.optional-dependencies.epath; - - checkInputs = [ - jaxlib - matplotlib - pytestCheckHook - pytest-xdist - ]; - - # high parallelism will result in the tests getting stuck - dontUsePytestXdist = true; - - # NOTE: Don't run the tests in the expiremental directory as they require flax - # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2. - # Not a big deal, this is how the JAX docs suggest running the test suite - # anyhow. - pytestFlagsArray = [ - "--numprocesses=4" - "-W ignore::DeprecationWarning" - "tests/" - ]; - - disabledTests = [ - # Exceeds tolerance when the machine is busy - "test_custom_linear_solve_aux" - ] ++ lib.optionals usingMKL [ - # See - # * https://github.com/google/jax/issues/9705 - # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921 - # * https://github.com/NixOS/nixpkgs/issues/161960 - "test_custom_linear_solve_cholesky" - "test_custom_root_with_aux" - "testEigvalsGrad_shape" - ]; - - doCheck = false; # Disable running checks during the build process - - - # See https://github.com/google/jax/issues/11722. This is a temporary fix in - # order to unblock etils, and upgrading jax/jaxlib to the latest version. See - # https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993. - disabledTestPaths = [ - "tests/api_test.py" - "tests/core_test.py" - "tests/lax_numpy_indexing_test.py" - "tests/lax_numpy_test.py" - "tests/nn_test.py" - "tests/random_test.py" - "tests/sparse_test.py" - ]; - - # As of 0.3.22, `import jax` does not work without jaxlib being installed. - pythonImportsCheck = [ ]; - - meta = with lib; { - description = "Differentiate, compile, and transform Numpy code"; - homepage = "https://github.com/google/jax"; - license = licenses.asl20; - maintainers = with maintainers; [ samuela ]; - }; -} diff --git a/nix/jaxopt.nix b/nix/jaxopt.nix index 6417662..a1e84b8 100644 --- a/nix/jaxopt.nix +++ b/nix/jaxopt.nix @@ -8,6 +8,7 @@ , fetchFromGitHub , matplotlib , scikit-learn +, typing-extensions }: buildPythonPackage rec { @@ -27,6 +28,7 @@ buildPythonPackage rec { jax matplotlib scikit-learn + typing-extensions ]; doCheck = true; diff --git a/nix/orbax.nix b/nix/orbax.nix index 7ebb75a..80d5760 100644 --- a/nix/orbax.nix +++ b/nix/orbax.nix @@ -11,16 +11,17 @@ , numpy , pyyaml , tensorflow +, importlib-resources #, tensorstore }: buildPythonPackage rec { - name = "orbax"; + name = "orbax-checkpoint"; src = fetchFromGitHub { owner = "google"; repo = "orbax"; - rev = "v0.1.6"; - hash = "sha256-Vkqt2ovTan6bQJI4Il06hG0NlYmt60to4ue4U9qG9HY="; + rev = "v0.1.7"; + hash = "sha256-Zk9hbvSA82jt0wLR7AZWEmHDA4A1+9t0ezf74FYkqe0="; }; format = "pyproject"; @@ -36,6 +37,7 @@ buildPythonPackage rec { numpy pyyaml tensorflow + importlib-resources # tensorstore ];