[doc][train] Clarify prepare_data_loader shuffle behavior and include set_epoch usage in all examples #41807
Conversation
What would Ray Train do if a user does have …
@woshiyyya If the dataloader already has a …
Got it. The logic is indeed quite convoluted. There are actually 3 factors that may affect the shuffling behavior.
We can mention that if the user provides a …
```
@@ -129,6 +129,8 @@ Compare a PyTorch training script with and without Ray Train.
    # Training
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
```
Also it surprised me that there's no `set_epoch` in any of these Accelerate examples. So it'd be great to show it in our example.
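A minimal sketch of how the guarded `set_epoch` call completes a loop like the one in the hunk above (an assumption, not the exact lines added by this PR; it presumes `train_dataloader` was already wrapped by `prepare_data_loader`):

```python
import ray.train

# Training loop (sketch)
for epoch in range(10):
    if ray.train.get_context().get_world_size() > 1:
        # Only the multi-worker case has an injected DistributedSampler to reseed.
        train_dataloader.sampler.set_epoch(epoch)
    for batch in train_dataloader:
        ...  # usual forward/backward/optimizer step
```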
@woshiyyya I added some extra notes talking about caveats in this section: https://anyscale-ray--41807.com.readthedocs.build/en/41807/train/getting-started-pytorch.html#set-up-a-dataset
😍
… (ray-project#41876) This PR fixes a previous typo in ray-project#41807 that updated the test, but CI didn't actually run the test due to no Ray Serve code being changed. Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Why are these changes needed?
`prepare_data_loader` adds a `DistributedSampler` to an existing PyTorch `DataLoader` object. To do this, it recreates a `DataLoader` object and passes most arguments through from the original object, but it also makes some implicit assumptions that are not configurable/visible to the caller.

For example, when using vanilla PyTorch by itself, it's possible to do: `train_dataloader = DataLoader(..., shuffle=False, sampler=DistributedSampler(shuffle=True))`. Here, the `DataLoader` sets `shuffle=False`, but the `DistributedSampler` will still shuffle on every epoch so that the training data order is not always the same. The `shuffle=False` argument of the `DataLoader` is pretty much ignored because a custom sampler is supplied.

However, with Ray Train, since the `prepare_data_loader` utility injects the `DistributedSampler` for the user, there's no visibility into the `shuffle` parameter. Ray Train detects the `shuffle` parameter set on the original dataloader and passes it along to the `DistributedSampler`, so it's not possible to have this `False`+`True` situation.
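To make the contrast concrete, here is a small sketch (not code from this PR) assuming a `train_dataset` and an already-initialized distributed worker group:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import ray.train.torch

# Vanilla PyTorch DDP: shuffle=False on the DataLoader is effectively ignored,
# because the explicitly supplied DistributedSampler still shuffles each epoch.
vanilla_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=False,
    sampler=DistributedSampler(train_dataset, shuffle=True),
)

# Ray Train: the DistributedSampler is injected by prepare_data_loader, and its
# shuffle setting is taken from the DataLoader's own shuffle flag, so the
# False (loader) + True (sampler) combination above cannot be expressed.
ray_loader = ray.train.torch.prepare_data_loader(
    DataLoader(train_dataset, batch_size=32, shuffle=True)
)
```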
Additionally, if `shuffle=True`, `DistributedSampler.set_epoch` must be called at the start of each epoch in order for the dataset ordering to be different for all workers on every epoch. This is because the seed of the sampler is determined at the epoch start (`epoch seed = base random seed + epoch number`).

Shuffling can be very important for training a model successfully -- if the data order stays the same every epoch, it's possible that training never converges (e.g., we ran into this issue training resnet18 on ImageNet).
Example:
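A minimal sketch of the intended pattern, assuming `train_dataset` and the model/optimizer setup are defined elsewhere:

```python
import ray.train
import ray.train.torch
from torch.utils.data import DataLoader

def train_loop_per_worker(config):
    # Express shuffling on the DataLoader; prepare_data_loader forwards it
    # to the DistributedSampler it injects.
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)

    for epoch in range(config["num_epochs"]):
        # See the note below: skip set_epoch when running with a single worker,
        # since no DistributedSampler is injected in that case.
        if ray.train.get_context().get_world_size() > 1:
            train_dataloader.sampler.set_epoch(epoch)
        for batch in train_dataloader:
            ...  # forward/backward/optimizer step
```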
Note: the `ray.train.get_context().get_world_size() > 1` condition is needed so that debugging with `num_workers=1` doesn't throw an error. `set_epoch` is only a valid method on `DistributedSampler`, and that only gets set when `num_workers > 1`.
.Related issue number
Checks
- I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.