fix logger.write error in atari script (#444)
- fix a bug in #427: logger.write should pass a dict
- change SubprocVectorEnv to ShmemVectorEnv in atari
- increase logger interval for eps
Trinkle23897 committed Sep 8, 2021
1 parent fc251ab commit e8f8cdf
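For context, the signature fix this commit applies in every affected script: the scalar value is now wrapped in a dict keyed by its tag, and the write is throttled. A before/after sketch distilled from the hunks below (eps and env_step are the local names used in those scripts):

    # before (from #427): a bare scalar as the third argument, which does not
    # match the logger interface; per the commit message, logger.write should
    # pass a dict
    logger.write('train/eps', env_step, eps)

    # after: the step key goes under "train/env_step" and the value goes in a
    # dict, logged only every 1000 environment steps
    if env_step % 1000 == 0:
        logger.write("train/env_step", env_step, {"train/eps": eps})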
Showing 11 changed files with 44 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -148,3 +148,4 @@ MUJOCO_LOG.TXT
*.pkl
*.hdf5
wandb/
+videos/
4 changes: 2 additions & 2 deletions examples/atari/atari_bcq.py
@@ -11,7 +11,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_bcq(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
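The SubprocVectorEnv to ShmemVectorEnv swap is a drop-in change: both classes take a list of environment factory callables. A minimal sketch of the new construction, reusing the make_atari_env_watch helper and args namespace already defined in these example scripts; ShmemVectorEnv still runs one worker process per environment but exchanges observations through shared memory, which should avoid pickling stacked Atari frames on every step:

    from tianshou.env import ShmemVectorEnv

    # one factory per parallel test environment; observations come back
    # through shared-memory buffers instead of pipes
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
    )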
9 changes: 5 additions & 4 deletions examples/atari/atari_c51.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -75,10 +75,10 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -141,7 +141,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
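In the off-policy scripts (atari_c51.py, atari_dqn.py, atari_fqf.py, atari_iqn.py, atari_qrdqn.py) the touched lines sit inside the train_fn callback. A sketch of how the whole callback reads after this commit, with the linear-decay branch reconstructed from the surrounding example code (treat the 1e6-step decay window as illustrative rather than part of this diff):

    def train_fn(epoch, env_step):
        # linearly decay exploration noise over the first decay window,
        # then hold it at the final value
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        # log eps every 1000 env steps rather than on every call, keeping
        # the TensorBoard event file small
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})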
4 changes: 2 additions & 2 deletions examples/atari/atari_cql.py
@@ -11,7 +11,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -76,7 +76,7 @@ def test_discrete_cql(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
4 changes: 2 additions & 2 deletions examples/atari/atari_crr.py
@@ -11,7 +11,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
@@ -77,7 +77,7 @@ def test_discrete_crr(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
9 changes: 5 additions & 4 deletions examples/atari/atari_dqn.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -72,10 +72,10 @@ def test_dqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -135,7 +135,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
9 changes: 5 additions & 4 deletions examples/atari/atari_fqf.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_fqf(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -158,7 +158,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
9 changes: 5 additions & 4 deletions examples/atari/atari_iqn.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -78,10 +78,10 @@ def test_iqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -153,7 +153,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
9 changes: 5 additions & 4 deletions examples/atari/atari_qrdqn.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -73,10 +73,10 @@ def test_qrdqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -137,7 +137,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
12 changes: 7 additions & 5 deletions examples/atari/atari_rainbow.py
@@ -10,7 +10,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import RainbowPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -85,10 +85,10 @@ def test_rainbow(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[lambda: make_atari_env(args) for _ in range(args.training_num)]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
)
# seed
@@ -174,15 +174,17 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})
if not args.no_priority:
if env_step <= args.beta_anneal_step:
beta = args.beta - env_step / args.beta_anneal_step * \
(args.beta - args.beta_final)
else:
beta = args.beta_final
buffer.set_beta(beta)
-logger.write('train/beta', env_step, beta)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/beta": beta})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
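atari_rainbow.py applies the same throttling to the prioritized-replay beta schedule. Reassembling the hunk above as plain code (buffer is the PrioritizedVectorReplayBuffer created earlier in that script):

    if not args.no_priority:
        # linearly anneal beta from args.beta to args.beta_final over
        # args.beta_anneal_step environment steps, then hold it constant
        if env_step <= args.beta_anneal_step:
            beta = args.beta - env_step / args.beta_anneal_step * \
                (args.beta - args.beta_final)
        else:
            beta = args.beta_final
        buffer.set_beta(beta)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/beta": beta})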
9 changes: 5 additions & 4 deletions examples/vizdoom/vizdoom_c51.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import ShmemVectorEnv
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
@@ -72,13 +72,13 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
-train_envs = SubprocVectorEnv(
+train_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)
]
)
-test_envs = SubprocVectorEnv(
+test_envs = ShmemVectorEnv(
[
lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))
@@ -144,7 +144,8 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
-logger.write('train/eps', env_step, eps)
+if env_step % 1000 == 0:
+    logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
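vizdoom_c51.py follows the same pattern, with one detail that predates this commit: the number of parallel test workers is capped at os.cpu_count() - 1, presumably to avoid oversubscribing cores when args.test_num is large. A sketch of the test-env construction after the swap, assuming the Env wrapper and args namespace from that script:

    import os

    from tianshou.env import ShmemVectorEnv

    # cap test workers so the machine is not oversubscribed; each worker
    # optionally records a .lmp demo when args.save_lmp is set
    test_envs = ShmemVectorEnv(
        [
            lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
            for _ in range(min(os.cpu_count() - 1, args.test_num))
        ]
    )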
