
How to switch batch size during training? #153

Closed
cclvr opened this issue Dec 14, 2021 · 10 comments

cclvr commented Dec 14, 2021

@takuseno, first of all, thanks a lot for your clear and complete codebase for offline RL. Recently I have been trying to build new algorithms on top of this codebase, and I want to switch the batch size during training, but I don't know how to do that with the smallest possible changes. Could you give me some clues? Looking forward to your reply.

takuseno (Owner) commented

@cclvr Hi, thanks for the issue. You can do it by setting a callback:

import d3rlpy

cql = d3rlpy.algos.CQL(batch_size=256)

def callback(algo, epoch, total_step):
    if total_step > 10000:
        algo.set_params(batch_size=1024)

cql.fit(..., callback=callback)

cclvr (Author) commented Dec 18, 2021


@takuseno Thanks for your kind and patient reply. I have tried this way, but it doesn't work. I downloaded this codebase about 4 months ago, so I'm wondering whether this feature existed at that time? Thanks again for your patience.

takuseno (Owner) commented

How did you confirm it didn't work? You can probably print the value inside the callback:

def callback(algo, epoch, total_step):
    if total_step > 10000:
        algo.set_params(batch_size=1024)
    print(algo.batch_size)

cclvr (Author) commented Dec 18, 2021


@takuseno, I printed batch_size in the callback and it does change. But when I print the shape of batch.observations in functions such as compute_critic_loss() in cql_impl.py, it doesn't change. Also, the time cost of one epoch is the same for different batch_size values, which also suggests that the batch size is not actually being applied.
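
For reference, the debug print looks roughly like this; the signature of compute_critic_loss is a sketch based on the d3rlpy version I am using and may differ in others:

# temporary debug print inside d3rlpy's cql_impl.py (illustrative sketch;
# the exact signature of compute_critic_loss may vary between versions)
def compute_critic_loss(self, batch, q_tpn):
    # even after the callback calls set_params(batch_size=1024), this
    # still reports the original batch size, e.g. torch.Size([256, ...])
    print(batch.observations.shape)
    ...  # original loss computation continues unchanged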

takuseno (Owner) commented

Ah, yes. You're right. Currently, the batch sampling is done in the Iterator class:

def __next__(self) -> TransitionMiniBatch:

If you don't mind, you can hack around there. However, there is no way to change the mini-batch size without the hack for now. Sorry for the inconvenience.
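
For reference, a hack along those lines could be a monkey-patch of the iterator's sampling step. This is only a sketch: the names RandomIterator and _batch_size follow the d3rlpy internals of this era and may differ in other versions.

from d3rlpy.iterators import RandomIterator

# keep a reference to the original sampling implementation
_original_next = RandomIterator.__next__

def _patched_next(self):
    # count sampling calls on this iterator instance and enlarge the
    # mini-batch after 10000 steps; _batch_size is assumed to be the
    # private attribute the iterator samples with
    self._total_calls = getattr(self, "_total_calls", 0) + 1
    if self._total_calls > 10000:
        self._batch_size = 1024
    return _original_next(self)

RandomIterator.__next__ = _patched_next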

cclvr (Author) commented Dec 18, 2021


OK, thanks. That's good enough for now.

jamartinh (Contributor) commented

Hi, try this instead.

fitter = algo.fitter(
    dataset,
    n_epochs=10,
    verbose=False,
    tensorboard_dir=None,
    save_metrics=False,
    shuffle=True,
)

for epoch, metrics in fitter:
    algo.batch_size += 1

Let me know what happens.

cclvr (Author) commented Dec 28, 2021


Hi @jamartinh, thanks for the suggestion. I have tried your method, but it still doesn't work.

takuseno (Owner) commented

@cclvr When you set n_epochs instead of n_steps, an epoch means the set of iterations needed to consume all of the training data once, so the 3899 you saw comes from dataset size / batch size. If you want to set an exact number of steps per epoch, you should use n_steps instead.
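
For example, with batch_size=256, an epoch of 3899 iterations implies a dataset of roughly 3899 * 256 ≈ 998k transitions. A sketch of pinning the epoch length instead, assuming a d3rlpy version whose fit() accepts n_steps and n_steps_per_epoch:

# run exactly 100000 gradient steps, with metrics reported every 1000
# steps, independent of dataset size and batch size (keyword names per
# the d3rlpy fit() API around v1.x; they may differ in other versions)
cql.fit(
    dataset,
    n_steps=100000,
    n_steps_per_epoch=1000,
)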

cclvr (Author) commented Dec 28, 2021


OK thanks!
