Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/omnilib/ufmt
rev: v1.3.2
rev: v2.0.0b2
hooks:
- id: ufmt
additional_dependencies:
- black == 21.9b0
- usort == 0.6.4
- black == 22.3.0
- usort == 1.0.3
- libcst == 0.4.7

- repo: https://github.com/pycqa/flake8
rev: 3.9.2
rev: 4.0.1
hooks:
- id: flake8
args: [--config=setup.cfg]
additional_dependencies:
- flake8-bugbear==22.10.27
- flake8-comprehensions==3.10.1


- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/storage/benchmark_sample_latency_over_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
}

storage_arg_options = {
"LazyMemmapStorage": dict(scratch_dir="/tmp/", device=torch.device("cpu")),
"LazyTensorStorage": dict(),
"ListStorage": dict(),
"LazyMemmapStorage": {"scratch_dir": "/tmp/", "device": torch.device("cpu")},
"LazyTensorStorage": {},
"ListStorage": {},
}
parser = argparse.ArgumentParser(
description="RPC Replay Buffer Example",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/setup_helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .extension import get_ext_modules, CMakeBuild # noqa
from .extension import CMakeBuild, get_ext_modules # noqa
2 changes: 1 addition & 1 deletion build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import platform
import subprocess
from pathlib import Path
from subprocess import check_output, STDOUT, CalledProcessError
from subprocess import CalledProcessError, check_output, STDOUT

import torch
from setuptools import Extension
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
}


aafig_default_options = dict(scale=1.5, aspect=1.0, proportional=True)
aafig_default_options = {"scale": 1.5, "aspect": 1.0, "proportional": True}

# -- Generate knowledge base references -----------------------------------
current_path = os.path.dirname(os.path.realpath(__file__))
Expand Down
16 changes: 5 additions & 11 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
Expand All @@ -23,21 +23,15 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
parallel_env_constructor,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_ddpg_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_ddpg_actor,
DDPGModelConfig,
)
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.losses import LossConfig, make_ddpg_loss
from torchrl.trainers.helpers.models import DDPGModelConfig, make_ddpg_actor
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
Expand Down
16 changes: 5 additions & 11 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.modules import EGreedyWrapper
from torchrl.record import VideoRecorder
Expand All @@ -22,21 +22,15 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
parallel_env_constructor,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_dqn_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_dqn_actor,
DiscreteModelConfig,
)
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.losses import LossConfig, make_dqn_loss
from torchrl.trainers.helpers.models import DiscreteModelConfig, make_dqn_actor
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig


Expand Down
16 changes: 5 additions & 11 deletions examples/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import torch.cuda
import tqdm
from dreamer_utils import (
parallel_env_constructor,
transformed_env_constructor,
call_record,
EnvConfig,
grad_norm,
make_recorder_env,
EnvConfig,
parallel_env_constructor,
transformed_env_constructor,
)
from hydra.core.config_store import ConfigStore

Expand All @@ -38,14 +38,8 @@
get_stats_random_rollout,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.models import (
make_dreamer,
DreamerConfig,
)
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.models import DreamerConfig, make_dreamer
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import TrainerConfig
from torchrl.trainers.trainers import Recorder, RewardNormalizer

Expand Down
12 changes: 4 additions & 8 deletions examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from typing import Callable, Optional, Union, Any, Sequence
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Callable, Optional, Sequence, Union

from torchrl.data import NdUnboundedContinuousTensorSpec
from torchrl.envs import ParallelEnv
Expand All @@ -14,6 +13,7 @@
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
CatFrames,
CenterCrop,
DoubleToFloat,
GrayScale,
NoopResetEnv,
Expand All @@ -22,12 +22,8 @@
RewardScaling,
ToTensorImage,
TransformedEnv,
CenterCrop,
)
from torchrl.envs.transforms.transforms import (
FlattenObservation,
TensorDictPrimer,
)
from torchrl.envs.transforms.transforms import FlattenObservation, TensorDictPrimer
from torchrl.record.recorder import VideoRecorder
from torchrl.trainers.loggers import Logger

Expand Down
9 changes: 3 additions & 6 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.objectives.value import GAE
Expand All @@ -23,17 +23,14 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
parallel_env_constructor,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_ppo_loss, PPOLossConfig
from torchrl.trainers.helpers.models import (
make_ppo_model,
PPOModelConfig,
)
from torchrl.trainers.helpers.models import make_ppo_model, PPOModelConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
Expand Down
16 changes: 5 additions & 11 deletions examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
Expand All @@ -23,21 +23,15 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
parallel_env_constructor,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_redq_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_redq_model,
REDQModelConfig,
)
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.losses import LossConfig, make_redq_loss
from torchrl.trainers.helpers.models import make_redq_model, REDQModelConfig
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
Expand Down
16 changes: 5 additions & 11 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import hydra
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
Expand All @@ -23,21 +23,15 @@
)
from torchrl.trainers.helpers.envs import (
correct_for_frame_skip,
EnvConfig,
get_stats_random_rollout,
parallel_env_constructor,
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_sac_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_sac_model,
SACModelConfig,
)
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
)
from torchrl.trainers.helpers.losses import LossConfig, make_sac_loss
from torchrl.trainers.helpers.models import make_sac_model, SACModelConfig
from torchrl.trainers.helpers.replay_buffer import make_replay_buffer, ReplayArgsConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
Expand Down
20 changes: 10 additions & 10 deletions examples/torchrl_features/memmap_td_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,25 @@ def tensordict_add_noreturn():
time.sleep(1)
t0 = time.time()
for w in range(1, args.world_size):
fut0 = rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple())
fut0 = rpc.rpc_async(f"worker{w}", get_tensordict, args=())
fut0.wait()
fut1 = rpc.rpc_async(f"worker{w}", tensordict_add, args=tuple())
fut1 = rpc.rpc_async(f"worker{w}", tensordict_add, args=())
tensordict2 = fut1.wait()
tensordict2.clone()
print("time: ", time.time() - t0)
elif args.task == 1:
time.sleep(1)
t0 = time.time()
waiters = [
rpc.remote(f"worker{w}", get_tensordict, args=tuple())
rpc.remote(f"worker{w}", get_tensordict, args=())
for w in range(1, args.world_size)
]
td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous()
print("time: ", time.time() - t0)

t0 = time.time()
waiters = [
rpc.remote(f"worker{w}", tensordict_add, args=tuple())
rpc.remote(f"worker{w}", tensordict_add, args=())
for w in range(1, args.world_size)
]
td = torch.stack([waiter.to_here() for waiter in waiters], 0).contiguous()
Expand All @@ -118,9 +118,9 @@ def tensordict_add_noreturn():
elif args.task == 2:
time.sleep(1)
t0 = time.time()
# waiters = [rpc.rpc_async(f"worker{w}", get_tensordict, args=tuple()) for w in range(1, args.world_size)]
# waiters = [rpc.rpc_async(f"worker{w}", get_tensordict, args=()) for w in range(1, args.world_size)]
waiters = [
rpc.remote(f"worker{w}", get_tensordict, args=tuple())
rpc.remote(f"worker{w}", get_tensordict, args=())
for w in range(1, args.world_size)
]
# td = torch.stack([waiter.wait() for waiter in waiters], 0).clone()
Expand All @@ -129,7 +129,7 @@ def tensordict_add_noreturn():
t0 = time.time()
if args.memmap:
waiters = [
rpc.remote(f"worker{w}", tensordict_add_noreturn, args=tuple())
rpc.remote(f"worker{w}", tensordict_add_noreturn, args=())
for w in range(1, args.world_size)
]
print("temp t: ", time.time() - t0)
Expand All @@ -139,7 +139,7 @@ def tensordict_add_noreturn():
print("temp t: ", time.time() - t0)
else:
waiters = [
rpc.remote(f"worker{w}", tensordict_add, args=tuple())
rpc.remote(f"worker{w}", tensordict_add, args=())
for w in range(1, args.world_size)
]
print("temp t: ", time.time() - t0)
Expand All @@ -153,14 +153,14 @@ def tensordict_add_noreturn():
time.sleep(1)
t0 = time.time()
waiters = [
rpc.remote(f"worker{w}", get_tensordict, args=tuple())
rpc.remote(f"worker{w}", get_tensordict, args=())
for w in range(1, args.world_size)
]
td = torch.stack([waiter.to_here() for waiter in waiters], 0)
print("time to receive objs: ", time.time() - t0)
t0 = time.time()
waiters = [
rpc.remote(f"worker{w}", tensordict_add, args=tuple())
rpc.remote(f"worker{w}", tensordict_add, args=())
for w in range(1, args.world_size)
]
print("temp t: ", time.time() - t0)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#[tool.usort]
[tool.usort]
first_party_detection = false

[build-system]
requires = ["setuptools", "wheel", "torch"]
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ per-file-ignores =
test/opengl_rendering.py: F401

exclude = venv
extend-select = B901, C401, C408, C409

[pydocstyle]
;select = D417 # Missing argument descriptions in the docstring
Expand Down
Loading