Dataloader worker seed fix #161

Merged — 18 commits merged into main from seed-fix on Jan 12, 2024

Conversation

@gokulavasan (Contributor) commented Jan 9, 2024

Changelog

Reasoning:

With what we have right now in torchtune main, we set the worker RNG state, so the user gets the same data in dataloader batches as long as they use the same dataloader_seed and the same number of workers (changing the number of workers changes the random transform output).

But if the user sets num_workers to 0, the dataloader batches start changing on every invocation even though dataloader_seed is set. In that case the transforms run under the trainer process's RNG state, and because the dataset's methods are called in that same trainer process, we never set the correct RNG state for those transforms.

Ideally we would control the RNG state for just the dataloader/dataset operations, but I am not sure that is feasible (or worth it), and it would be hard for users to reason about. So we can probably go back to setting torch.manual_seed and letting it control the worker RNG state.
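For context, here is a minimal sketch (not code from this PR) of how the trainer-process seed relates to worker seeding; the toy dataset, seed value, and seed_worker helper are illustrative placeholders, and numpy is assumed to be available:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Seeding the trainer process covers the num_workers=0 case (transforms run
# in-process) and also determines the base_seed PyTorch hands to each worker
# when num_workers > 0.
torch.manual_seed(1234)

def seed_worker(worker_id: int) -> None:
    # Inside a worker, torch.initial_seed() returns base_seed + worker_id;
    # propagate it to the other RNGs that random transforms may use.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

dataset = TensorDataset(torch.arange(8))  # placeholder dataset
loader = DataLoader(dataset, batch_size=2, num_workers=0, worker_init_fn=seed_worker)
```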

Test plan

@facebook-github-bot added the CLA Signed label on Jan 9, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).

netlify bot commented Jan 9, 2024

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: e6602a7
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/65a1589eb61fe6000890f3de
😎 Deploy Preview: https://deploy-preview-161--torchtune-preview.netlify.app

@gokulavasan changed the title Seed fix → Dataloader worker seed fix on Jan 9, 2024
@gokulavasan marked this pull request as ready for review January 9, 2024 21:29
@@ -54,6 +54,9 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
seed(kwargs["seed"])
Member

So I'm not sure we need separate seeds for training outside of data loading (random initialization, generation, etc.) and for the dataloader workers. It would be great to get @NicolasHug's review on this PR.

Contributor Author

I followed the changes in this diff - https://fburl.com/k2o5zm7w. From a dataloading/transforms point of view, this is required to seed the random transforms that execute on the trainer process (when the number of dataloader workers is set to 0, i.e., no multiprocess dataloading). I did connect with @NicolasHug about this.

@gokulavasan (Contributor Author) commented Jan 10, 2024

The other alternative is for ReproducibleDataLoader to take the seed value in its constructor (from the recipe kwargs) and set the seed in its init method (that way it is not present in the recipe). But any components created before the DataLoader won't see that seed, and I am not sure whether this is a concern for other components - cc @daniellepintz

Member

I am supportive of the change to just have one single seed to control the entire RNG of a training run.

Enabling a separate RNG stream for data-loading and for all the rest could be a nice plus, but it's only that: a nice plus. It's not critical at all, and as we've seen (from past PR discussions and ad-hoc chats), it leads to massive complexity that we shouldn't have to deal with in torchtune at this point.

recipes/finetune_llm.py Outdated Show resolved Hide resolved
@@ -54,6 +54,9 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
seed(kwargs["seed"])
Member

I think this should be

Suggested change
seed(kwargs["seed"])
seed(kwargs["seed"] + rank)

because

  1. This is what happens by default when not setting a manual seed.
  2. This will greatly simplify the internals of the ReproducibleDataLoader

Contributor Author

Thanks, yes we need this so that the transforms have a different RNG state in each trainer process (same as the case when num_workers > 0).

That said, if we do this, we need to explicitly pass kwargs["seed"] to ReproducibleDataLoader, since the seed used for the DistributedSampler needs to be the same for all ranks.
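A hedged sketch of that split (not the PR's actual code; base_seed, the toy dataset, and the single-process rank/world_size values are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def build_sampler(base_seed: int, rank: int, world_size: int, dataset):
    # Offset the trainer-process seed by rank so random transforms differ per
    # rank (matching what already happens when num_workers > 0)...
    torch.manual_seed(base_seed + rank)
    # ...while the sampler keeps the un-offset seed so every rank shuffles the
    # global index list identically and the shards stay disjoint.
    return DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=base_seed
    )

dataset = TensorDataset(torch.arange(16))
sampler = build_sampler(base_seed=1234, rank=0, world_size=1, dataset=dataset)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
```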

@gokulavasan (Contributor Author) commented Jan 10, 2024

Discussed offline with @NicolasHug.

A few things to complete:
i) Within ReproducibleDataLoader, add a distributed barrier and ensure the seed passed in is the same across all ranks [Done]
ii) Add unit tests to thoroughly test all the combinations [Done]
iii) Open a task to raise awareness of the issue where the same worker id on different ranks gets the same RNG state, and how we fix it in this diff [https://github.com//issues/168]

@gokulavasan (Contributor Author) commented Jan 10, 2024

@rohan-varma @NicolasHug The PR is ready for review; I have addressed the comments.

@@ -59,19 +59,24 @@ def __init__( # noqa: DOC101
if isinstance(dataset, IterableDataset):
raise ValueError("ReproducibleDataLoader only supports Map style datasets.")

# Ensure that the seed provided is the same across all ranks
if dist.is_available() and dist.is_initialized():
Contributor Author

@rohan-varma @daniellepintz Please review this section and see if it looks okay. I am not sure how to test this logic - any ideas on how to make sure dist is set up properly, either in a unit test or when I start finetune_llm.py?

Member

We should probably run a finetune here with distributed (it should be documented in the README) and with the seed set, to make sure this codepath executes as expected.

For example, I'm not sure if things will work OOTB when creating a seed_tensor on the CPU and passing it into collectives that may require the tensor to live on the GPU. I can help debug any issues that pop up once we can run this.

@gokulavasan (Contributor Author) commented Jan 11, 2024

@rohan-varma I see that cpu (tensor) is a supported backend for collectives - https://pytorch.org/docs/stable/distributed.html#backends. I want to run the distributed case and verify it all works, but I'm blocked at the moment as the tune command is failing (due to an ImportError) on main.

Contributor

Sorry, I have a very noob question - why do we need this check? IIUC, torch.manual_seed sets the same seed across all devices (ref: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735). This seems like checking PyTorch functionality? Let me know if I misunderstand.

Member

torch.manual_seed sets the same seed across all devices

yes.

Except that by default, all ranks get a different base seed. So if torch.manual_seed() isn't called and users were to pass sampler_seed=torch.randint(...), sampler_seed would have a different value on each rank, which would be really bad.
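A tiny sketch of that pitfall (hypothetical values; rank stands in for the actual process rank):

```python
import torch

rank = 0  # placeholder for the actual process rank

# Without torch.manual_seed(), each rank starts from a different default base
# seed, so a sampler seed drawn like this would silently differ per rank:
# sampler_seed = int(torch.randint(0, 2**31, (1,)))

# Seeding first (with a per-rank offset for transforms) is fine, but the
# sampler seed should come straight from the config so it is identical
# on every rank.
torch.manual_seed(1234 + rank)
sampler_seed = 1234  # shared across ranks, not drawn from the rank-offset RNG
```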

Contributor

@NicolasHug can we get around that by just documenting this in code and making sure the recipe does call manual_seed? There are so many ways the user can shoot themselves in the foot, but we can clearly articulate: do it this way if you want your data read in a deterministic fashion.

Contributor

Sorry, I take this comment back. It makes sense to have this check since this component and the recipe should be decoupled.

@rohan-varma (Member) left a comment

Thanks for working on this! I have some comments / questions.

recipes/finetune_llm.py Outdated Show resolved Hide resolved
recipes/finetune_llm.py Outdated Show resolved Hide resolved
torchtune/trainer/reproducible_dataloader.py Outdated Show resolved Hide resolved
user. If no generator is provided, seed is also used to set the
base_seed for all dataloader workers to ensure transforms are
repeatable. If no seed is provided, a random number is used as the seed.
user. If no seed is provided, a random number is used as the seed.
Member

Providing some details in this docstring itself on what this seed does in the distributed sampler would be great.

torchtune/trainer/reproducible_dataloader.py Outdated Show resolved Hide resolved
# Ensure that the seed provided is the same across all ranks
if dist.is_available() and dist.is_initialized():
seed_tensor = torch.tensor(sampler_seed)
dist.barrier()
Member

We don't need any of the barriers, as .item() will sync the GPU with the CPU and we can be sure the result is correct.

@@ -59,19 +59,24 @@ def __init__( # noqa: DOC101
if isinstance(dataset, IterableDataset):
raise ValueError("ReproducibleDataLoader only supports Map style datasets.")

# Ensure that the seed provided is the same across all ranks
if dist.is_available() and dist.is_initialized():
Member

So this will only run if init_process_group has been called, which happens via init_from_env. If this call comes after that, then the logic should be okay. cc @daniellepintz

Contributor Author

I see, so I don't need the checks on L63? I tried a run without these checks and it failed, which is why I put the check in. Will confirm once again with the distributed run documented in the README.

NicolasHug previously approved these changes Jan 11, 2024

@NicolasHug (Member) left a comment

Thanks @gokulavasan, made some comments below.

@gokulavasan @rohan-varma @kartikayk in its latest state (i.e. in this PR), the ReproducibleDataLoader does nothing other than passing a seed to the DistributedSampler, and automatically calls sampler.set_epoch().

In other words: it's a pretty shallow wrapper whose sole purpose is to save users from calling set_epoch(). We have already missed 2 bugs during code review, one of which is really serious (#177, #168).

Should we simplify this further and, instead of having a custom DataLoader in torchtune, simply have a custom DistributedSampler? This DistributedSampler would have the exact same benefits as the current ReproducibleDataLoader (i.e. automatically calling set_epoch() in __iter__()), but without the major problems that come with owning a data-loader.
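A rough sketch of what that sampler-only alternative could look like (a hypothetical class, not code from this PR):

```python
from torch.utils.data.distributed import DistributedSampler

class AutoEpochDistributedSampler(DistributedSampler):
    """Bump the epoch on every new iteration so callers never have to
    remember to call set_epoch() themselves."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._auto_epoch = 0

    def __iter__(self):
        # Re-seed this pass's shuffle, then advance the epoch for the next one.
        self.set_epoch(self._auto_epoch)
        self._auto_epoch += 1
        return super().__iter__()

# Single-process usage (explicit num_replicas/rank avoid needing dist init):
# sampler = AutoEpochDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True, seed=0)
```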

@@ -54,6 +54,15 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
_, rank = get_world_size_and_rank()
if kwargs.get("seed", None) is None:
Member

Should kwargs always contain the "seed" key? In which case we could do a one-liner:

base_seed = kwargs["seed"] or torch.empty((), dtype=torch.int32).random_().item()


for run in range(4):
dataloader = ReproducibleDataLoader(
map_dataset, batch_size=2, shuffle=True, num_workers=4, seed=seed
map_dataset,
batch_size=2,
Member

As @kartikayk noticed in #65 (comment), batch_size and num_workers are hard-coded here but also parametrized-over.

output_max = dist.all_reduce(seed_tensor, op=ReduceOp.MAX)
dist.barrier()
output_min = dist.all_reduce(seed_tensor, op=ReduceOp.MIN)
if output_max.item() != sampler_seed or output_min.item() != sampler_seed:
Member

Checking the min and max values works, but there's a more direct way, e.g. using all_gather to collect all the individual values into a list and then asserting something like all(v == sampler_seed for v in gathered_list).
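A hedged sketch of that all_gather-based check (the helper name is illustrative; all_gather_object is used for simplicity since the seed is a plain Python int):

```python
import torch.distributed as dist

def assert_same_seed_across_ranks(sampler_seed: int) -> None:
    # Gather every rank's seed and compare, instead of reducing to min/max.
    if not (dist.is_available() and dist.is_initialized()):
        return
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, sampler_seed)
    if any(value != sampler_seed for value in gathered):
        raise ValueError(f"sampler_seed differs across ranks: {gathered}")
```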

@gokulavasan (Contributor Author) commented Jan 11, 2024

@NicolasHug Thanks for the callout. I do agree that getting the RNG state correct is tough. I should have been more proactive in addressing these (which were opened as follow-ups during that original PR) - #85 (comment) and #85 (comment).

Basically, having these unit tests would have caught these issues. The current unit tests in main provide enough confidence in the data iterator order, but they missed the control of transform randomness. I have addressed that in this PR by adding the required tests.

I also want to point out that #177 would have occurred even if we had used just the DistributedSampler, since the seed would have been set only to the dataloader_seed in the recipe/trainer process. The worker=0 case (#168) was certainly a miss, and the unit tests now cover it.

I am fine with reverting to just using a Sampler, but I am currently working on dataloader/sampler state checkpointing here - #176 (it is not complete just yet, but I opened a draft PR to give a flavor of the changes). That change involves tapping into the dataloader's next calls so that we can keep track of the number of steps invoked and resume from the correct state (that is, seek to just after the number of items consumed [not just prefetched]).
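As a rough illustration of why the dataloader's next calls need to be intercepted (a hypothetical wrapper, not the draft in #176):

```python
from torch.utils.data import DataLoader

class CheckpointableLoader:
    """Count batches as they are actually consumed (not just prefetched) so a
    resume can skip exactly that many batches on the next pass."""

    def __init__(self, loader: DataLoader):
        self._loader = loader
        self.batches_consumed = 0

    def __iter__(self):
        for batch in self._loader:
            self.batches_consumed += 1
            yield batch

    def state_dict(self) -> dict:
        return {"batches_consumed": self.batches_consumed}
```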

@NicolasHug (Member) commented Jan 11, 2024

Considering the dozens of hours we have collectively spent on this, and the risk of bugs, I believe it's worth reconsidering whether all these efforts are worth it. We are saving users a few lines of boilerplate here, but at what cost? At this early stage of the project, where time is such a scarce resource, I still think trading a bit of boilerplate for a lot of simplicity is safer and more efficient in the long run. A UX like this one, without a custom sampler and without a data-loader (at least not for anything RNG-related), seems reasonable to me for an alpha/MVP.

@gokulavasan (Contributor Author) commented Jan 11, 2024

@NicolasHug As I mentioned in the comment above, I feel the issues are not specific to the DataLoader. The reason I say this is that the gist UX you shared above has the same two issues referenced here - not setting the rng for the worker=0 case, and not handling the rng for the multirank+worker case.

Again, this is about getting the RNG state correct, which is certainly not easy, and I have added unit tests to cover it in this diff. We can debate whether that is necessary - do we really need to provide reproducibility for random transforms? That is a valid question, and I am fine if we want to punt on it for the alpha/MVP.

More importantly, I have added the reason why we need a custom dataloader (because we need to override the dataloader's next method to keep track of the number of batches fetched so that we can seek to that exact offset on checkpoint restore).

@NicolasHug (Member)

the gist UX you shared above has the same two issues referenced here - not setting the rng for the worker=0 case, and not handling the rng for the multirank+worker case

It doesn't have the same issues as long as we call .manual_seed(args.seed + rank) which is something we have to do regardless of the technical solution.

I have added the reason why we need a custom dataloader (because we need to override the dataloader next

Having a custom dataloader is one thing. Having a custom dataloader with complex RNG logic inside it, only to save users a few lines, is another. I'm suggesting we reconsider the latter. Perhaps, as you mentioned, we'll need a custom dataloader anyway. But hopefully this custom data-loader wouldn't need to handle anything RNG-related.

@gokulavasan (Contributor Author) commented Jan 11, 2024

@NicolasHug

It doesn't have the same issues as long as we call .manual_seed(args.seed + rank) which is something we have to do regardless of the technical solution.

Yes, I agree. I was just trying to highlight that it is hard to get the RNG state correct, which we both agree on.

But hopefully this custom data-loader wouldn't need to handle anything RNG-related.

We have moved all the RNG related handling to the main trainer process with this diff (that is, out of the dataloader). Let me know if you think that isn't the case.

@NicolasHug (Member)

This ReproducibleDataLoader accepts a sampler_seed parameter, so it's not "RNG-agnostic".

@gokulavasan (Contributor Author) commented Jan 11, 2024

@NicolasHug Fair point. Basically you suggest we move the DistributedSampler out of the ReproducibleDataLoader and let the ReproducibleDataLoader take a sampler param in its constructor.

We can make that change, but I really feel it is not required, as the issues we identified were all with setting the RNG state of the libraries (torch, random, numpy) and NOT with the sample order/shuffling.

@NicolasHug (Member)

you suggest we move the DistributedSampler out of the ReproducibleDataLoader and let the ReproducibleDataLoader take a sampler param in its constructor

Yes, which would reduce ReproducibleDataLoader to the builtin torch.utils.data.DataLoader and remove the need for its existence - at least until checkpointing (but we can still keep the checkpointing logic completely separate from the dataloading RNG).

issues we identified were all with setting the RNG state of the libraries (torch, random, numpy) and NOT with the sample order/shuffling

I don't think we can consider these two things to be disjoint. Doing the first one in the wrong place, at the wrong time, or with the wrong seeds will lead to the second one being silently buggy.

@gokulavasan (Contributor Author)

@NicolasHug @kartikayk @pbontrager @rohan-varma Updated the PR by removing ReproducibleDataLoader and its associated tests. Please take a look.

@kartikayk (Contributor)

Thanks so much for this change @gokulavasan! I think conceptually this looks a lot cleaner and makes sense to me:

  • We let the user/recipe manage the setting of the seeds
  • We let the user/recipe setup the right sampler
  • We pass the sampler to a stateful dataloader (WIP) that can help deal with correctly checkpointing the iterator state

This setup helps decouple the dataloader from the recipe and helps ensure the stateful dataloader is an independent component which can be used outside of TorchTune recipes.
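A minimal sketch of the recipe-side wiring under this setup (the seed value, toy dataset, and single-process rank handling are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

world_size, rank = 1, 0                  # placeholders for the real values
torch.manual_seed(1234 + rank)           # one seed controls the run, offset per rank

dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)             # the recipe, not a custom loader, owns this call
    for batch in loader:
        pass                             # training step
```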

recipes/finetune_llm.py Outdated Show resolved Hide resolved
recipes/finetune_llm.py Outdated Show resolved Hide resolved
recipes/finetune_llm.py Outdated Show resolved Hide resolved
@NicolasHug mentioned this pull request Jan 12, 2024
@@ -1,6 +1,6 @@
# Dataset and Dataloader
dataset: alpaca
dataloader_seed: null
seed: 0
Contributor Author

Had to make this change as seed won't accept a null value.

@@ -54,6 +54,12 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
world_size, rank = get_world_size_and_rank()
if "seed" in kwargs:
Contributor Author

If a seed is not passed in, we don't have to generate a random int32 anymore, as torch.initial_seed should already contain a random seed.

Member

And importantly, that initial seed will already be different across ranks, as that is the default behavior.

num_replicas=world_size,
rank=rank,
shuffle=kwargs["shuffle"],
seed=0,
Member

For anyone wondering, this is the default value for the seed of the DistributedSampler (look it up).

In follow-up work, if that is deemed relevant, we could make that seed depend on --seed by drawing a seed on rank 0 and propagating it to the other ranks.

That is not needed for this PR, though, and is best left as future work, as the way things are done here already addresses 90% (if not more) of use-cases.
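If that follow-up is ever pursued, a hedged sketch of the rank-0 draw plus broadcast (the helper name is illustrative; assumes the process group is already initialized when running distributed):

```python
import torch
import torch.distributed as dist

def broadcast_sampler_seed(base_seed: int) -> int:
    # Rank 0's value wins; every other rank overwrites its tensor with it.
    seed_tensor = torch.tensor([base_seed], dtype=torch.int64)
    if dist.is_available() and dist.is_initialized():
        dist.broadcast(seed_tensor, src=0)
    return int(seed_tensor.item())
```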

@NicolasHug (Member) left a comment

Thank you Gokul, LGTM.

The recipe and the torchtune code-bases are now much, much simpler.

@gokulavasan (Contributor Author)

Had to update the loss values in this PR as the DistributedSampler shuffle seed is now always set to 0. Refer to #161 (comment) for the reasoning. Since the data order is now different, the loss values changed, and thus I had to update them for the CI to pass.

@gokulavasan merged commit 5670109 into main Jan 12, 2024
15 checks passed
@gokulavasan deleted the seed-fix branch January 12, 2024 15:31
@@ -54,6 +54,12 @@ def recipe(kwargs):
# ---- Initialize components ---- #
logger = get_logger()

# ---- Initialize seed ---- #
world_size, rank = get_world_size_and_rank()
Member

Sorry I didn't do a thorough review of this. I think it will always return world_size=1 and rank=0 even if we're doing distributed training here, because the process group / distributed is not initialized until L64. So I think this is an issue for distributed training cases.

Contributor

Good catch - sorry, I should have caught this in my review. I didn't pay attention to the non-dataloader stuff.
