Dataloader worker seed fix #161
Conversation
recipes/finetune_llm.py
Outdated
@@ -54,6 +54,9 @@ def recipe(kwargs):
    # ---- Initialize components ---- #
    logger = get_logger()

    # ---- Initialize seed ---- #
    seed(kwargs["seed"])
So I'm not sure whether we need separate seeds for training outside of data loading (any random initialization, generation, etc.) and for the dataloader workers. Would be great to get @NicolasHug's review on this PR.
I followed the changes in this diff - https://fburl.com/k2o5zm7w. From a dataloading/transforms point of view, this is required to seed the random transforms that get executed on the trainer process (when the number of dataloader workers is set to 0, that is, no multiprocess dataloading). I did connect with @NicolasHug about this.
The other alternative is for ReproducibleDataLoader to take the seed value in its constructor (from the recipe kwargs) and then set the seed in its init method (that way it is not present in the recipe). But any components created before the DataLoader won't see that seed, and I am not sure whether this is a concern for other components - cc @daniellepintz.
I am supportive of the change to have a single seed control the entire RNG of a training run.
Enabling a separate RNG stream for data loading and another for everything else can be a nice plus, but it's only that: a nice plus. It's not critical at all, and as we've seen (from past PR discussions and ad-hoc chats), it leads to massive complexity that we shouldn't have to deal with in torchtune at this point.
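For illustration, a minimal sketch of what a single-seed helper like the seed() call in the diff above might do, assuming it simply wraps the stdlib, NumPy, and PyTorch generators (this is not the actual torchtune implementation):

import random

import numpy as np
import torch


def seed(value: int) -> None:
    """Seed every RNG a training run touches from a single base value (sketch)."""
    random.seed(value)        # Python stdlib RNG
    np.random.seed(value)     # NumPy RNG used by some transforms
    torch.manual_seed(value)  # CPU and all CUDA generators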
recipes/finetune_llm.py
Outdated
@@ -54,6 +54,9 @@ def recipe(kwargs):
    # ---- Initialize components ---- #
    logger = get_logger()

    # ---- Initialize seed ---- #
    seed(kwargs["seed"])
I think this should be

-    seed(kwargs["seed"])
+    seed(kwargs["seed"] + rank)

because
- This is what happens by default when not setting a manual seed.
- This will greatly simplify the internals of the ReproducibleDataLoader.
Thanks, yes, we need this so that transforms have a different RNG state on each trainer (same as the case when num_workers > 0).
That said, if we do this, we need to explicitly pass kwargs["seed"] to ReproducibleDataLoader, as the seed used for the DistributedSampler needs to be the same on all ranks.
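For illustration, a rough sketch of the scheme that emerges from this exchange (the init_seeds helper and the toy dataset are assumptions, not code from this PR): the training RNG is offset per rank, while the sampler receives the un-offset base seed so every rank shuffles identically.

import torch
from torch.utils.data import DataLoader, DistributedSampler


def init_seeds(base_seed: int, rank: int) -> int:
    """Seed the trainer process per rank, but return a rank-agnostic sampler seed."""
    torch.manual_seed(base_seed + rank)  # per-rank stream for init, dropout, transforms
    return base_seed                     # identical on every rank -> identical shuffle order


# Example for (hypothetical) rank 1 of a 2-process run; ranks are passed explicitly,
# so no process group is required to run this sketch.
sampler_seed = init_seeds(base_seed=42, rank=1)
dataset = list(range(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True, seed=sampler_seed)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)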
Discussed offline with @NicolasHug. A few things to complete:
@rohan-varma @NicolasHug The PR is ready for review; I have addressed the comments.
@@ -59,19 +59,24 @@ def __init__(  # noqa: DOC101
        if isinstance(dataset, IterableDataset):
            raise ValueError("ReproducibleDataLoader only supports Map style datasets.")

        # Ensure that the seed provided is the same across all ranks
        if dist.is_available() and dist.is_initialized():
@rohan-varma @daniellepintz Please review this section to see if it looks okay. I am not sure how to test this logic - any idea how to make sure dist is set up properly, either in a unit test or when I start finetune_llm.py?
We should probably run a distributed finetune here (it should be documented in the README) with the seed set, to make sure this codepath executes as expected. For example, I'm not sure things will work OOTB when creating a seed_tensor on the CPU and passing it into collectives that may require the tensor to live on the GPU. I can help debug any issues that pop up once we can run this.
@rohan-varma I see that CPU tensors are supported by the collective backends - https://pytorch.org/docs/stable/distributed.html#backends. I want to run the distributed job and verify it all works, but I'm blocked at the moment as the tune command is failing (due to an ImportError) on main.
Sorry, I have a very noob question - why do we need to make this check? IIUC, torch.manual_seed sets the same seed across all devices (ref: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735). This seems like it is checking PyTorch functionality? Let me know if I misunderstand.
torch.manual_seed sets the same seed across all devices

Yes. Except that by default, all ranks get a different base seed. So if torch.manual_seed() isn't called and users were to pass sampler_seed=torch.randint(...), sampler_seed would have a different value for each rank, which would be really bad.
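A small self-contained illustration (not from the PR) of the failure mode described here: a value drawn from the default generator is only rank-consistent if torch.manual_seed was called first.

import torch

# Without torch.manual_seed(), each rank seeds its default generator independently,
# so this value would differ from rank to rank -- a DistributedSampler seeded with
# it would shuffle differently on every rank.
risky_sampler_seed = torch.randint(0, 2**31, ()).item()

# After seeding from a shared base value, the same draw is identical on every rank.
torch.manual_seed(1234)
safe_sampler_seed = torch.randint(0, 2**31, ()).item()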
@NicolasHug we can get around that by just documenting this in code and making sure the recipe does call manual_seed? There are so many ways in which users can shoot themselves in the foot, but we can clearly articulate: do it this way if you want your data read in a deterministic fashion.
Sorry, I take this comment back. It makes sense to have this check since this component and the recipe should be decoupled.
Thanks for working on this! I have some comments / questions.
-            user. If no generator is provided, seed is also used to set the
-            base_seed for all dataloader workers to ensure transforms are
-            repeatable. If no seed is provided, a random number is used as the seed.
+            user. If no seed is provided, a random number is used as the seed.
Providing some details in this docstring itself on what this seed does in the distributed sampler would be great.
        # Ensure that the seed provided is the same across all ranks
        if dist.is_available() and dist.is_initialized():
            seed_tensor = torch.tensor(sampler_seed)
            dist.barrier()
We don't need any of the barriers, as .item() will sync the GPU with the CPU and we can be sure that the result is correct.
@@ -59,19 +59,24 @@ def __init__(  # noqa: DOC101
        if isinstance(dataset, IterableDataset):
            raise ValueError("ReproducibleDataLoader only supports Map style datasets.")

        # Ensure that the seed provided is the same across all ranks
        if dist.is_available() and dist.is_initialized():
So this will only run if init_process_group is called, which happens via init_from_env. If this call comes after that, then the logic should be okay. cc @daniellepintz
I see, so I don't require the check on L63? I tried a run without these checks and it failed, which is why I put the check in. Will confirm once again with the distributed run documented in the README.
Thanks @gokulavasan, made some comments below.
@gokulavasan @rohan-varma @kartikayk In its latest state (i.e. in this PR), the ReproducibleDataLoader does nothing other than pass a seed to the DistributedSampler and automatically call sampler.set_epoch().
In other words, it's a pretty shallow wrapper whose sole purpose is to save users from calling set_epoch(). We have already missed 2 bugs during code review, one of which is really serious (#177, #168).
Should we simplify this further and, instead of having a custom DataLoader in torchtune, simply have a custom DistributedSampler? This DistributedSampler would have the exact same benefits as the current ReproducibleDataLoader (i.e. automatically calling set_epoch() in __iter__()), but without the major problems that come with owning a data-loader.
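For illustration, a minimal sketch of that alternative (class name and details are assumptions, not torchtune code): a DistributedSampler subclass that advances the epoch itself, so callers never have to remember set_epoch().

from torch.utils.data import DistributedSampler


class AutoEpochDistributedSampler(DistributedSampler):
    """DistributedSampler that calls set_epoch() automatically after each full pass."""

    def __iter__(self):
        yield from super().__iter__()
        # Bump the epoch once a full pass has been yielded so the next
        # __iter__() reshuffles with a new epoch-dependent seed.
        self.set_epoch(self.epoch + 1)


# Standalone usage; rank/num_replicas are passed explicitly, so no process group is needed.
sampler = AutoEpochDistributedSampler(list(range(10)), num_replicas=2, rank=0, shuffle=True, seed=0)
first_pass = list(sampler)
second_pass = list(sampler)  # reshuffled, without any manual set_epoch() call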
recipes/finetune_llm.py
Outdated
@@ -54,6 +54,15 @@ def recipe(kwargs):
    # ---- Initialize components ---- #
    logger = get_logger()

    # ---- Initialize seed ---- #
    _, rank = get_world_size_and_rank()
    if kwargs.get("seed", None) is None:
Should kwargs always contain the "seed" key? In which case we could do a one-liner:

base_seed = kwargs["seed"] or torch.empty((), dtype=torch.int32).random_().item()
    for run in range(4):
        dataloader = ReproducibleDataLoader(
-           map_dataset, batch_size=2, shuffle=True, num_workers=4, seed=seed
+           map_dataset,
+           batch_size=2,
As @kartikayk noticed in #65 (comment), batch_size and num_workers are hard-coded here but also parametrized over.
            output_max = dist.all_reduce(seed_tensor, op=ReduceOp.MAX)
            dist.barrier()
            output_min = dist.all_reduce(seed_tensor, op=ReduceOp.MIN)
            if output_max.item() != sampler_seed or output_min.item() != sampler_seed:
Checking min and max values works, but there's a more direct way, e.g. using all_gather to collect all the individual values in a list and then just asserting something like all(v == sampler_seed for v in gathered_list).
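For illustration, a sketch of that all_gather variant (assuming an initialized process group with a CPU-capable backend such as gloo; the function name is made up):

import torch
import torch.distributed as dist


def check_sampler_seed_consistent(sampler_seed: int) -> None:
    """Raise if any rank was constructed with a different sampler seed (sketch)."""
    if not (dist.is_available() and dist.is_initialized()):
        return
    seed_tensor = torch.tensor(sampler_seed)
    gathered = [torch.zeros_like(seed_tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, seed_tensor)  # no explicit barrier needed
    if any(t.item() != sampler_seed for t in gathered):
        raise ValueError(f"sampler_seed differs across ranks: {[t.item() for t in gathered]}")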
Didn't mean to approve just yet, sorry
@NicolasHug Thanks for the call-out. I do agree that the RNG state is tough to get correct. I should have been more proactive in addressing these items (which were deferred during that original PR) - #85 (comment) and #85 (comment); having those unit tests would have caught these bugs. The issue is that the current unit tests in main give enough confidence about the data iterator order but missed the transform randomness control; I have addressed that in this PR by adding the required tests. I also want to point out that #177 would have happened even if we had used just the DistributedSampler, the reason being that the seed would have been set only to the dataloader_seed in the recipe/trainer process. The num_workers=0 case (#168) was certainly a miss, and the unit tests now cover it. I am fine with reverting back to just using a Sampler, but I am currently working on dataloader/sampler state checkpointing here - #176 (it is not complete yet, but I opened a draft PR to provide a flavor of the changes). That change will involve tapping into the dataloader's next calls so that we can keep track of the number of steps we have invoked and resume from the correct state (that is, seek to just after the number of items consumed [not just prefetched]).
Considering the dozens of hours we have collectively spent on this, and the risk of bugs, I believe it's worth reconsidering whether all these efforts are worth it. We are saving users a few lines of boilerplate here, but at what cost? At this early stage of the project, where time is such a scarce resource, I still think trading a bit of boilerplate for a lot of simplicity is safer and more efficient in the long run. A UX like this one, without a custom sampler and without a data-loader (at least not for anything RNG-related), seems reasonable to me for an alpha/MVP.
@NicolasHug As I mentioned in the comment above, I feel the issues are not specific to the DataLoader. The reason I say this is that the gist UX you shared above has the same two issues being referenced here: not setting the RNG for the num_workers=0 case, and not handling the RNG for the multi-rank + worker case. Again, this is about getting the RNG state correct, which is certainly not easy, and I have added unit tests to cover it in this diff. We can debate whether that is necessary - do we really need to provide reproducibility for random transforms? That is a valid question, and I am fine if we want to punt on it for the alpha/MVP. More importantly, I have added the reason why we need a custom dataloader (we need to override the dataloader's next method to keep track of the number of batches that were fetched, so that we can seek to that exact offset on checkpoint restore).
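For illustration, a stripped-down sketch of that batch-counting idea (this is not the code in #176; the class and method names are assumptions):

from torch.utils.data import DataLoader


class CheckpointableDataLoader:
    """Wraps a DataLoader and counts batches actually consumed by the training loop,
    so a checkpoint restore can skip exactly that many batches into the epoch."""

    def __init__(self, dataloader: DataLoader):
        self._dataloader = dataloader
        self._num_batches_consumed = 0

    def __iter__(self):
        self._num_batches_consumed = 0
        for batch in self._dataloader:
            # Counted as each batch is handed to the caller, i.e. consumed by the
            # training loop, not merely prefetched by dataloader workers.
            self._num_batches_consumed += 1
            yield batch

    def state_dict(self) -> dict:
        return {"num_batches_consumed": self._num_batches_consumed}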
It doesn't have the same issues as long as we call
Having a custom dataloader is one thing. Having a custom dataloader within which we have complex RNG logic, only to save users a few lines, is another thing. I'm suggesting we reconsider the latter. Perhaps, as you mentioned, we'll need a custom dataloader anyway. But hopefully that custom data-loader wouldn't need to handle anything RNG-related.
Yes, I agree. I was just trying to highlight that it is hard to get the RNG state correct, which we both agree on.
We have moved all the RNG-related handling to the main trainer process with this diff (that is, out of the dataloader). Let me know if you think that isn't the case.
@NicolasHug Fair point. Basically you suggest we move the DistributedSampler out of the ReproducibleDataLoader and let the ReproducibleDataLoader take a sampler param in its constructor. We can make that change, but I really feel it is not required, as the issues we identified were all with setting the RNG state of the libraries (torch, random, numpy) and NOT with the sample order/shuffling.
Yes, which would reduce
I don't think we can consider these two things to be disjoint. Doing the first one in the wrong place, or at the wrong time, or with the wrong seeds will lead to the second one being silently buggy.
@NicolasHug @kartikayk @pbontrager @rohan-varma Updated the PR by removing ReproducibleDataLoader and its associated tests. Please take a look.
Thanks so much for this change @gokulavasan! I think conceptually this looks a lot cleaner and makes sense to me:
This setup helps decouple the dataloader from the recipe and helps ensure the stateful dataloader is an independent component which can be used outside of TorchTune recipes.
@@ -1,6 +1,6 @@
 # Dataset and Dataloader
 dataset: alpaca
-dataloader_seed: null
+seed: 0
Had to make this change as seed won't accept a null value.
@@ -54,6 +54,12 @@ def recipe(kwargs):
    # ---- Initialize components ---- #
    logger = get_logger()

    # ---- Initialize seed ---- #
    world_size, rank = get_world_size_and_rank()
    if "seed" in kwargs:
If seed is not passed in, we don't have to generate a random int32 anymore, as torch.initial_seed() should already contain a random seed.
And importantly, that initial seed will already be different across ranks, as this is the default behavior.
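For illustration, roughly the logic these two comments describe (a sketch, not the exact recipe code):

import torch


def init_training_seed(kwargs: dict, rank: int) -> None:
    """Sketch of the recipe's seed handling as described above."""
    if "seed" in kwargs and kwargs["seed"] is not None:
        # Explicit seed: offset by rank so each trainer process gets its own RNG stream.
        torch.manual_seed(kwargs["seed"] + rank)
    # Otherwise do nothing: torch.initial_seed() already holds a randomly chosen seed,
    # and by default that seed already differs across ranks.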
            num_replicas=world_size,
            rank=rank,
            shuffle=kwargs["shuffle"],
            seed=0,
For anyone wondering, this is the default value for the seed of the DistributedSampler (look it up).
In follow-up work, if that is deemed relevant, we could make that seed depend on --seed by drawing a seed on rank 0 and propagating it to the other ranks.
That is not needed for this PR though and is best left as future work, as the way things are done here already addresses 90% (if not more) of use cases.
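For illustration, a sketch of that possible follow-up (assuming an initialized process group with a CPU-capable backend such as gloo; the helper name is made up): draw the sampler seed on rank 0 and broadcast it so every rank builds an identically-shuffling DistributedSampler.

import torch
import torch.distributed as dist


def shared_sampler_seed() -> int:
    """Draw a seed on rank 0 and propagate it to all other ranks (sketch)."""
    seed_tensor = torch.empty((), dtype=torch.int64)
    if dist.get_rank() == 0:
        seed_tensor.random_(0, 2**31)
    dist.broadcast(seed_tensor, src=0)
    return int(seed_tensor.item())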
Thank you Gokul, LGTM.
The recipe and the torchtune code-bases are now much, much simpler.
Had to update the loss values in this PR as the DistributedSampler shuffle seed is now always set to 0. Refer to #161 (comment) for the reasoning. Since the data order is now different, the loss values changed, and thus I had to update them for the CI to pass.
@@ -54,6 +54,12 @@ def recipe(kwargs):
    # ---- Initialize components ---- #
    logger = get_logger()

    # ---- Initialize seed ---- #
    world_size, rank = get_world_size_and_rank()
Sorry I didn't do a thorough review of this. I think it will always return world_size=1 and rank=0 here, even if we're doing distributed training, because the process group / distributed is not initialized until L64. So I think this is an issue for distributed training cases.
Good catch - sorry, I should have caught this in my review. I didn't pay attention to the non-dataloader stuff.
Changelog
- seed in finetune_llm.py #129.

Reasoning:
With what we have right now in torchtune main, we set the worker RNG state, and thus the user gets the same data in the dataloader batches as long as they use the same dataloader_seed and the same number of workers (changing the number of workers will change the random transform output).
But if the user sets num_workers to 0, the dataloader batches suddenly start changing on every invocation even though dataloader_seed is set. This is because it is now the trainer process's RNG state that matters, and since the dataset's methods are also called in that same trainer process, we don't end up setting the correct RNG state for those transforms.
Ideally we would control the RNG state for just the dataloader/dataset operations, but I am not sure that is feasible (or worth it). I feel this is hard for users to reason about, so we can probably go back to setting torch.manual_seed and then letting it control the worker RNG state.
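For reference, the worker-RNG seeding mentioned here usually follows the standard PyTorch worker_init_fn pattern; the sketch below shows that pattern (not torchtune's exact code) and makes clear why num_workers=0 falls back to the trainer process's own RNG state.

import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id: int) -> None:
    # Each worker process already gets base_seed + worker_id from PyTorch;
    # derive the other libraries' RNG state from it so random transforms repeat.
    worker_seed = torch.initial_seed() % 2**32
    random.seed(worker_seed)
    np.random.seed(worker_seed)


# With num_workers == 0 no worker_init_fn ever runs, which is why the trainer
# process itself must be seeded (e.g. via torch.manual_seed) instead.
loader = DataLoader(list(range(8)), batch_size=2, num_workers=2, worker_init_fn=seed_worker)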
Test plan