Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Trainers as generators #559

Merged
merged 46 commits into from
Mar 17, 2022

Conversation

jamartinh
Copy link
Contributor

The new proposed feature is to have trainers as generators.
The usage pattern is like:

trainer = onpolicy_trainer_generator(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(ingo)
  • epoch int: the epoch number
  • epoch_stat dict: a large collection of metrics of current epoch, including stat
  • info dict: the usual dict out of the non generator version of trainer

You can even iterate on several different trainers at the same time:

trainer1 = onpolicy_trainer_generator(...)
trainer2 = onpolicy_trainer_generator(...)
for result1,result2 in zip(trainer1,trainer2):
   compare_results(result1,result2) 
  • I have marked all applicable categories:
    • exception-raising fix
    • algorithm implementation fix
    • documentation modification
    • new feature
  • I have reformatted the code using make format (required)
  • I have checked the code using make commit-checks (required)
  • If applicable, I have mentioned the relevant/related issue(s)
  • If applicable, I have listed every items in this Pull Request below

@Trinkle23897
Copy link
Collaborator

Nice suggestion! But could you please remove the duplicated code in a single trainer file? I'm thinking about the following approach (or something similar):

class Trainer:
  def __init__(self, ...):
    ...
  def __iter__(self, ...):
    ...
  def run(self):
    ...

onpolicy_trainer = lambda *args, **kwargs: Trainer(*args, **kwargs).run()
onpolicy_trainer_gen = Trainer  # or another name

@jamartinh
Copy link
Contributor Author

Hi @Trinkle23897 lets see how looks offline.py

Copy link
Collaborator

@Trinkle23897 Trinkle23897 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, great work!

tianshou/trainer/__init__.py Show resolved Hide resolved
@Trinkle23897
Copy link
Collaborator

And is it possible to create BaseTrainer to further reduce code duplication?

@jamartinh
Copy link
Contributor Author

jamartinh commented Mar 6, 2022 via email

@jamartinh
Copy link
Contributor Author

Now pushed the OffPolicyTrainer

@jamartinh
Copy link
Contributor Author

W.r.t common things of the three Trainers.

The init it is very the same except from few things
There are many common parameters, so that parameters can go to Base class
The run() functions will be exactly the same for all three Classes.

It is just the next that is different, however it has things in common.
I way of refactoring common things is to extract as common methods the repeated code parts in next

jamartinh and others added 4 commits March 6, 2022 21:11
* All the procedures are so equal that separating them will make to much unnecessary duplicated complex code

* Included tests in test_ppo.py, test_cql.py and test_rd3.py

* It can be simplified even more, but would break backward Api compatibility
@jamartinh
Copy link
Contributor Author

Hi man @Trinkle23897 I think I have done what I can.
There is piece of code that any formatter can solve to pass lint:

https://github.com/jamartinh/tianshou/runs/5469774938?check_suite_focus=true#step:7:27

But I cannot go further, so help is needed if this refactoring is usefull.

Thanks,
JAMH

@codecov-commenter
Copy link

codecov-commenter commented Mar 12, 2022

Codecov Report

Merging #559 (a62cf84) into master (2336a7d) will decrease coverage by 0.11%.
The diff coverage is 98.32%.

@@            Coverage Diff             @@
##           master     #559      +/-   ##
==========================================
- Coverage   93.62%   93.50%   -0.12%     
==========================================
  Files          64       65       +1     
  Lines        4392     4419      +27     
==========================================
+ Hits         4112     4132      +20     
- Misses        280      287       +7     
Flag Coverage Δ
unittests 93.50% <98.32%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
tianshou/trainer/base.py 97.83% <97.83%> (ø)
tianshou/trainer/__init__.py 100.00% <100.00%> (ø)
tianshou/trainer/offline.py 100.00% <100.00%> (+3.77%) ⬆️
tianshou/trainer/offpolicy.py 100.00% <100.00%> (+2.50%) ⬆️
tianshou/trainer/onpolicy.py 100.00% <100.00%> (+6.17%) ⬆️
tianshou/utils/logger/tensorboard.py 73.80% <0.00%> (-21.43%) ⬇️
tianshou/policy/modelfree/trpo.py 88.52% <0.00%> (-4.92%) ⬇️
tianshou/data/collector.py 93.77% <0.00%> (-0.05%) ⬇️
tianshou/env/worker/subproc.py 94.03% <0.00%> (-0.04%) ⬇️

📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more

@jamartinh
Copy link
Contributor Author

@Trinkle23897
I get success in my computer wit GPU:

RHEL 7

test_sac_with_il.py::test_sac_with_il 

============================== 1 passed in 35.95s ==============================

Process finished with exit code 0
PASSED                             [100%]Epoch #1: test_reward: -480.222050 ± 144.561932, best_reward: -480.222050 ± 144.561932 in #1
Epoch #2: test_reward: -223.095769 ± 115.835893, best_reward: -223.095769 ± 115.835893 in #2
Epoch #1:  75%|#######4  | 17991/24000 [00:24<00:08, 725.82it/s, alpha=0.634, env_step=17990, len=200, loss/actor=115.555, loss/alpha=-0.534, loss/critic1=14.512, loss/critic2=17.633, n/ep=0, n/st=10, rew=-511.42]
Epoch #1: 501it [00:00, 506.55it/s, env_step=500, len=0, loss=0.004, n/ep=0, n/st=10, rew=0.00]                         
Epoch #2: 501it [00:00, 576.68it/s, env_step=1000, len=0, loss=0.001, n/ep=0, n/st=10, rew=0.00] 
CUDA Version: 11.4 
pytorch                   1.10.0          cuda112py39h3ad47f5_1    conda-forge
pytorch-gpu               1.10.0          cuda112py39h0bbbad9_1    conda-forge
torchaudio                0.10.0               py39_cu113    pytorch
torchmetrics              0.7.2              pyhd8ed1ab_0    conda-forge

@Trinkle23897
Copy link
Collaborator

2022-03-12 16-04-28 的屏幕截图

    def stop_fn(mean_rewards):
+       print("stop_fn", mean_rewards, args.reward_threshold)
        return mean_rewards >= args.reward_threshold

The upper half is the original version and the lower half is this version. It seems not exactly matched though...

@jamartinh
Copy link
Contributor Author

@Trinkle23897
Seems to be working for me, please report or help if you see any issue.
Thanks !

@jamartinh
Copy link
Contributor Author

@Trinkle23897 Ok now what?

@Trinkle23897
Copy link
Collaborator

@Trinkle23897 Ok now what?

Sorry about the delay because I have a deadline this evening and another deadline two days later. I'll have a look right after finishing those tasks.

Trinkle23897
Trinkle23897 previously approved these changes Mar 17, 2022
@Trinkle23897 Trinkle23897 merged commit 10d9190 into thu-ml:master Mar 17, 2022
@jamartinh jamartinh deleted the trainers_as_generators branch November 20, 2022 09:21
BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
The new proposed feature is to have trainers as generators.
The usage pattern is:

```python
trainer = OnPolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- epoch int: the epoch number
- epoch_stat dict: a large collection of metrics of the current epoch, including stat
- info dict: the usual dict out of the non-generator version of the trainer

You can even iterate on several different trainers at the same time:

```python
trainer1 = OnPolicyTrainer(...)
trainer2 = OnPolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants