How to switch batch size during training? #153
@cclvr Hi, thanks for the issue. You can do it by setting a callback:

```py
import d3rlpy

cql = d3rlpy.algos.CQL(batch_size=256)

def callback(algo, epoch, total_step):
    if total_step > 10000:
        algo.set_params(batch_size=1024)

cql.fit(..., callback=callback)
```
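For readers unfamiliar with this pattern, the idea can be sketched in plain Python. The `Algo` class, `set_params`, and `fit` below are simplified hypothetical stand-ins, not the real d3rlpy API:

```python
# Minimal sketch of the callback pattern (hypothetical stand-in class,
# not the real d3rlpy implementation).
class Algo:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def set_params(self, **params):
        # update hyperparameters in place
        for name, value in params.items():
            setattr(self, name, value)

    def fit(self, n_steps, callback=None):
        for total_step in range(1, n_steps + 1):
            # ... one gradient step would happen here ...
            if callback is not None:
                callback(self, epoch=0, total_step=total_step)

algo = Algo(batch_size=256)

def callback(algo, epoch, total_step):
    # grow the batch size after 10 steps (10000 in the real example)
    if total_step > 10:
        algo.set_params(batch_size=1024)

algo.fit(n_steps=20, callback=callback)
print(algo.batch_size)  # prints 1024: the parameter itself does change
```

As the rest of the thread shows, changing the attribute is only half the story: the component that actually samples batches must also pick up the new value.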
@takuseno Thanks for your kind and patient reply. I have tried this approach, but it doesn't work. Since I downloaded this codebase about 4 months ago, I'm wondering whether the feature existed at that time. Thanks again for your patience.
How did you confirm it didn't work? You could print the value inside the callback to check.
@takuseno, I printed batch_size during the callback and it does change. But when I print the shape of batch.observations in functions such as compute_critic_loss() in cql_impl.py, it doesn't change. The time cost of one epoch is also the same for different batch_size values, which again suggests that the batch size is not being applied.
Ah, yes. You're right. Currently, the batch sampling is done at d3rlpy/d3rlpy/iterators/base.py line 46 (commit 8eb11db). If you don't mind, you can hack around there. However, there is no way to change the mini-batch size without the hack for now. Sorry for the inconvenience.
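The reason the printed shape never changes can be illustrated in a few lines: the iterator reads `batch_size` once when it is constructed, so later `set_params` calls on the algorithm never reach it. The classes below are simplified hypothetical stand-ins, not the actual d3rlpy internals:

```python
import numpy as np

class BatchIterator:
    """Simplified stand-in: batch_size is captured once, at creation."""
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size  # fixed here, never re-read from the algo

    def sample(self):
        idx = np.random.randint(0, len(self.data), size=self.batch_size)
        return self.data[idx]

class Algo:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)

data = np.zeros((1000, 4))
algo = Algo(batch_size=256)
it = BatchIterator(data, batch_size=algo.batch_size)

algo.set_params(batch_size=1024)   # changes the algorithm attribute...
print(it.sample().shape[0])        # ...but the iterator still yields 256

# the "hack": mutate the iterator's own attribute directly
it.batch_size = algo.batch_size
print(it.sample().shape[0])        # now 1024
```

This is why the suggested workaround is to patch the iterator itself rather than the algorithm object.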
OK thanks, it is good enough right now.
Hi, try this instead.
Let me know what happens.
Hi @jamartinh, thanks for your clue. I have tried your method but it still doesn't work.
@cclvr When you set
OK thanks! |
@takuseno, firstly thanks a lot for your clear and complete code base for offline RL. Recently I have been trying to build new algorithms on top of this code base, and I want to switch the batch size during the training process, but I don't know how to do that with the smallest possible changes. Could you give some clue? Looking forward to your reply.