Add Trainers as generators #559

Merged: 46 commits into master from trainers_as_generators, Mar 17, 2022
Commits
8bad065
add docstring :param buffer to offline_trainer in offline.py
jamartinh Mar 5, 2022
ff9c0c9
Add param yield_epoch to trainers. If True, converts the function int…
jamartinh Mar 5, 2022
2b72992
Add trainer generators for offline.py, offpolicy.py and onpolicy.py.
jamartinh Mar 5, 2022
9a6a72b
fix PEP8
jamartinh Mar 5, 2022
d05f0e0
fix PEP8
jamartinh Mar 5, 2022
5566be0
fix PEP8
jamartinh Mar 5, 2022
185c006
fix yapf
jamartinh Mar 5, 2022
79f050a
removed comments in format section of Makefile. It produces errors on…
jamartinh Mar 5, 2022
4cbc7c8
fix isort
jamartinh Mar 5, 2022
ffbe30a
fix rare error with dict with mypy
jamartinh Mar 5, 2022
23f00d2
fix rare error with dict with mypy
jamartinh Mar 5, 2022
f64eb2d
fix docstrings
jamartinh Mar 5, 2022
b6b0ed7
refactored offline.py to one iterator class
jamartinh Mar 6, 2022
0f39eac
drop test_sac_with_il_trainer_generator.py
jamartinh Mar 6, 2022
21cdbe6
improve offline.py with best practices on exhausting iterator and cle…
jamartinh Mar 6, 2022
2483dea
Create an Iterator class instead of a generator function, following t…
jamartinh Mar 6, 2022
88cb63c
Expose new _iter versions and Iterator Classes
jamartinh Mar 6, 2022
34feb5b
Add OffPolicyTrainer as Iterator and add testing in test_td3.py
jamartinh Mar 6, 2022
1c7eaef
fix doc format
jamartinh Mar 6, 2022
4067428
Merge branch 'master' into trainers_as_generators
Trinkle23897 Mar 8, 2022
5ca6fb8
* Refactored trainers into one BaseTrainer class.
jamartinh Mar 8, 2022
d705744
Merge remote-tracking branch 'jamh/trainers_as_generators' into train…
jamartinh Mar 8, 2022
b4fa395
fix formatting
jamartinh Mar 8, 2022
91c787c
Merge remote-tracking branch 'origin/master' into trainers_as_generators
Trinkle23897 Mar 8, 2022
c1f5f25
docs
Trinkle23897 Mar 8, 2022
b12beb1
fix missing import
Trinkle23897 Mar 9, 2022
a4ae2e3
* fix formatting
jamartinh Mar 12, 2022
0690d12
Merge branch 'thu-ml:master' into trainers_as_generators
jamartinh Mar 12, 2022
e2756f0
Merge branch 'master' into trainers_as_generators
Trinkle23897 Mar 12, 2022
c902d61
update docs
Trinkle23897 Mar 12, 2022
a3e7e2c
update rst
Trinkle23897 Mar 12, 2022
651726f
fix early stopping during train [train_step]
jamartinh Mar 12, 2022
e6b00e2
* fix early stopping during train train_step
jamartinh Mar 12, 2022
4d76843
* fix early stopping during train train_step
jamartinh Mar 12, 2022
23ce483
* fix early stopping during train train_step
jamartinh Mar 12, 2022
1d707f8
* fix early stopping during train train_step
jamartinh Mar 13, 2022
479b794
* fix early stopping during train train_step
jamartinh Mar 13, 2022
3adf0e1
Merge branch 'master' into trainers_as_generators
Trinkle23897 Mar 16, 2022
08f65a6
fix a bug in BaseTrainer.run (return value missing)
Trinkle23897 Mar 16, 2022
5ec4eb3
change seed to pass ci
Trinkle23897 Mar 16, 2022
89ce44f
learning_type: str
Trinkle23897 Mar 16, 2022
a320e68
fix ci
Trinkle23897 Mar 16, 2022
6df9365
reorg some code
Trinkle23897 Mar 17, 2022
7a00daf
revert
Trinkle23897 Mar 17, 2022
3ce4f6d
missing docs for on-policy trainer
Trinkle23897 Mar 17, 2022
a62cf84
missing docs
Trinkle23897 Mar 17, 2022
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE.md
@@ -7,6 +7,6 @@
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
-import tianshou, torch, numpy, sys
-print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
+import tianshou, gym, torch, numpy, sys
+print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
```
4 changes: 1 addition & 3 deletions Makefile
@@ -22,10 +22,8 @@ lint:
flake8 ${LINT_PATHS} --count --show-source --statistics

format:
-# sort imports
$(call check_install, isort)
isort ${LINT_PATHS}
-# reformat using yapf
$(call check_install, yapf)
yapf -ir ${LINT_PATHS}

Expand Down Expand Up @@ -57,6 +55,6 @@ doc-clean:

clean: doc-clean

-commit-checks: format lint mypy check-docstyle spelling
+commit-checks: lint check-codestyle mypy check-docstyle spelling

.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
44 changes: 43 additions & 1 deletion docs/api/tianshou.trainer.rst
@@ -1,7 +1,49 @@
tianshou.trainer
================

.. automodule:: tianshou.trainer

On-policy
---------

.. autoclass:: tianshou.trainer.OnpolicyTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.onpolicy_trainer

.. autoclass:: tianshou.trainer.onpolicy_trainer_iter


Off-policy
----------

.. autoclass:: tianshou.trainer.OffpolicyTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.offpolicy_trainer

.. autoclass:: tianshou.trainer.offpolicy_trainer_iter


Offline
-------

.. autoclass:: tianshou.trainer.OfflineTrainer
:members:
:undoc-members:
:show-inheritance:

.. autofunction:: tianshou.trainer.offline_trainer

.. autoclass:: tianshou.trainer.offline_trainer_iter


utils
-----

.. autofunction:: tianshou.trainer.test_episode

.. autofunction:: tianshou.trainer.gather_info
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
@@ -24,12 +24,15 @@ fqf
iqn
qrdqn
rl
offpolicy
onpolicy
quantile
quantiles
dqn
param
async
subprocess
deque
nn
equ
cql
20 changes: 20 additions & 0 deletions docs/tutorials/concepts.rst
@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho

Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.

We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, and :class:`~tianshou.trainer.OfflineTrainer`, so that users can write more flexible training logic:
::

trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
print(f"Epoch: {epoch}")
print(epoch_stat)
print(info)
do_something_with_policy()
query_something_about_policy()
make_a_plot_with(epoch_stat)
display(info)

# or even iterate on several trainers at the same time

trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
compare_results(result1, result2, ...)
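
Since each trainer is itself a one-shot iterator, you can also break out of the loop to stop training early, or consume the whole schedule in a single call via :meth:`~tianshou.trainer.BaseTrainer.run`. The sketch below is illustrative rather than normative: it assumes ``reward_threshold`` is a variable you define yourself, that the per-epoch ``info`` dict carries the ``best_reward`` key used in the tests, and that ``run()`` returns the same result dict as the functional trainers:
::

    trainer = OffpolicyTrainer(...)
    for epoch, epoch_stat, info in trainer:
        if info["best_reward"] >= reward_threshold:
            break  # custom early stopping between epochs

    # or run all remaining epochs in one blocking call
    result = OffpolicyTrainer(...).run()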


.. _pseudocode:

14 changes: 10 additions & 4 deletions test/continuous/test_ppo.py
@@ -11,7 +11,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
print("Fail to restore policy and optim.")

# trainer
-result = onpolicy_trainer(
+trainer = OnpolicyTrainer(
policy,
train_collector,
test_collector,
@@ -173,10 +173,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn
)
-assert stop_fn(result['best_reward'])

+for epoch, epoch_stat, info in trainer:
+    print(f"Epoch: {epoch}")
+    print(epoch_stat)
+    print(info)
+
+assert stop_fn(info["best_reward"])

if __name__ == '__main__':
-pprint.pprint(result)
+pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
2 changes: 1 addition & 1 deletion test/continuous/test_sac_with_il.py
@@ -24,7 +24,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--reward-threshold', type=float, default=None)
-parser.add_argument('--seed', type=int, default=0)
+parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
19 changes: 12 additions & 7 deletions test/continuous/test_td3.py
@@ -11,7 +11,7 @@
from tianshou.env import DummyVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
+from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def save_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

-# trainer
-result = offpolicy_trainer(
+# Iterator trainer
+trainer = OffpolicyTrainer(
policy,
train_collector,
test_collector,
@@ -148,12 +148,17 @@ def stop_fn(mean_rewards):
update_per_step=args.update_per_step,
stop_fn=stop_fn,
save_fn=save_fn,
-logger=logger
+logger=logger,
)
-assert stop_fn(result['best_reward'])
+for epoch, epoch_stat, info in trainer:
+    print(f"Epoch: {epoch}")
+    print(epoch_stat)
+    print(info)

-if __name__ == '__main__':
-    pprint.pprint(result)
+assert stop_fn(info["best_reward"])
+
+if __name__ == "__main__":
+    pprint.pprint(info)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
16 changes: 11 additions & 5 deletions test/offline/test_cql.py
@@ -12,7 +12,7 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import CQLPolicy
-from tianshou.trainer import offline_trainer
+from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def watch():
collector.collect(n_episode=1, render=1 / 35)

# trainer
-result = offline_trainer(
+trainer = OfflineTrainer(
policy,
buffer,
test_collector,
@@ -207,11 +207,17 @@ def watch():
stop_fn=stop_fn,
logger=logger,
)
-assert stop_fn(result['best_reward'])

+for epoch, epoch_stat, info in trainer:
+    print(f"Epoch: {epoch}")
+    print(epoch_stat)
+    print(info)
+
+assert stop_fn(info["best_reward"])

# Let's watch its performance!
-if __name__ == '__main__':
-    pprint.pprint(result)
+if __name__ == "__main__":
+    pprint.pprint(info)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
30 changes: 24 additions & 6 deletions tianshou/trainer/__init__.py
@@ -1,16 +1,34 @@
"""Trainer package."""

# isort:skip_file

-from tianshou.trainer.utils import test_episode, gather_info
-from tianshou.trainer.onpolicy import onpolicy_trainer
-from tianshou.trainer.offpolicy import offpolicy_trainer
-from tianshou.trainer.offline import offline_trainer
+from tianshou.trainer.base import BaseTrainer
+from tianshou.trainer.offline import (
+    OfflineTrainer,
+    offline_trainer,
+    offline_trainer_iter,
+)
+from tianshou.trainer.offpolicy import (
+    OffpolicyTrainer,
+    offpolicy_trainer,
+    offpolicy_trainer_iter,
+)
+from tianshou.trainer.onpolicy import (
+    OnpolicyTrainer,
+    onpolicy_trainer,
+    onpolicy_trainer_iter,
+)
+from tianshou.trainer.utils import gather_info, test_episode

__all__ = [
"BaseTrainer",
"offpolicy_trainer",
"offpolicy_trainer_iter",
"OffpolicyTrainer",
"onpolicy_trainer",
"onpolicy_trainer_iter",
"OnpolicyTrainer",
"offline_trainer",
"offline_trainer_iter",
"OfflineTrainer",
"test_episode",
"gather_info",
]
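
The expanded `__all__` keeps the functional helpers next to the new iterator classes, so existing call sites keep working while new code can step through training epoch by epoch. Below is a minimal sketch of the two styles side by side; it assumes a `policy`, `train_collector`, and `test_collector` already built as in the tests above, and the hyperparameter values are placeholders, not recommendations:

```python
from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer

# Functional style: one blocking call that returns the final result dict.
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=10, batch_size=64,
)

# Iterator style: same arguments, but control returns after every epoch,
# leaving room for checkpointing, plotting, or custom early stopping.
trainer = OffpolicyTrainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, step_per_collect=10,
    episode_per_test=10, batch_size=64,
)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}", epoch_stat, info)
```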