
Conversation

@NicolasHug NicolasHug commented Aug 11, 2023

This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:

  1. Make Generator weakref-able (C++ part)
  2. Create a registry of manually-created Generator objects (via weakrefs)
  3. Use this registry in the Dataloader's `_worker_loop` to re-seed all existing Generator instances: this extends to user-created Generators what is already done for the global Generator (see the sketch just below this list).
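Roughly, steps 2 and 3 could look like this sketch (illustrative names only, assuming step 1 has made `Generator` weakref-able; this is not the PR's exact code):

```python
import weakref

import torch

# Hypothetical module-level registry; names are illustrative, not the PR's exact API.
# Weak references let registered Generators be garbage-collected as usual.
_generator_registry = weakref.WeakSet()


def _register_generator(generator: torch.Generator) -> None:
    # Step 2 above: record every manually-created Generator.
    _generator_registry.add(generator)


def _reseed_registered_generators(base_seed: int, worker_id: int) -> None:
    # Step 3 above: in the worker loop, give each registered Generator its own seed,
    # derived from its current state so instances within a worker stay distinct.
    seed = base_seed + worker_id
    for generator in list(_generator_registry):
        offset = torch.empty((), dtype=torch.int64).random_(generator=generator).item()
        generator.manual_seed(offset + seed)
```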

TODO: a bit of docs and justification, which I'll do if this PR is mergeable.

CC @albanD as previously discussed

cc @ezyang @gchanan @ssnl @VitalyFedyunin @ejguan @dzhulgakov @pbelevich

@NicolasHug NicolasHug requested a review from ejguan as a code owner August 11, 2023 14:02

pytorch-bot bot commented Aug 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107034

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures

As of commit f73cc30:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

# We also need to create a generator_seed that depends on the current generator state, otherwise
# all Generator instances within a given worker would have the same RNG.
generator_seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item() + seed
generator.manual_seed(generator_seed)
Member Author

Technically this is a change of behaviour: on main, Generator instances have the same RNG across workers; now they don't.

I don't think users rely on (or should rely on) this behaviour anyway. The only use-case I can think of where users would want the same RNG across workers is using the Generator to shuffle the dataset: when shuffling, you want all workers to shuffle in the same way. But for those who need that, I think it's fair to say they should be relying on worker_init_fn anyway.
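A sketch of that worker_init_fn route, for users who do want identical RNG across workers (`shuffle_generator` and the dataset are placeholders, not part of this PR):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder Generator that the sampler/dataset uses for shuffling.
shuffle_generator = torch.Generator()


def worker_init_fn(worker_id):
    # worker_init_fn runs in each worker after the DataLoader's own seeding,
    # so re-seeding here keeps the shuffle-related RNG identical across workers.
    shuffle_generator.manual_seed(42)


dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, num_workers=2, worker_init_fn=worker_init_fn)
```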

Collaborator

This is technically BC-breaking, so we should properly document it (please add a small paragraph in the description for the release notes). But I think that's ok to do, yes.

@NicolasHug NicolasHug added module: dataloader Related to torch.utils.data.DataLoader and Sampler module: random Related to random number generation in PyTorch (rng generator) labels Aug 11, 2023
@albanD albanD added the module: bc-breaking Related to a BC-breaking change label Aug 11, 2023
@pytorch-bot pytorch-bot bot added the topic: bc breaking topic category label Aug 11, 2023
Collaborator

@albanD albanD left a comment

Small things, but the approach sounds good to me.
I'll let @ezyang give his opinion as well.


for from_global, from_g1, from_g2 in dl:
    # Assert RNG of all Generators are different within a given worker (each "batch" comes from a single worker)
    assert len(set([from_global, from_g1, from_g2])) == 3
Collaborator

Can you please remove all the plain asserts from the tests? You can use self.assertEqual(), self.assertNotEqual(), etc. as appropriate.
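For instance, the plain assert in the snippet above could become something like the following (a sketch, assuming the loop sits inside a TestCase method):

```python
for from_global, from_g1, from_g2 in dl:
    self.assertEqual(len({from_global, from_g1, from_g2}), 3)
```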

self->weakreflist = NULL;

static py::handle _generator_registry = py::module::import("torch").attr("random").attr("_generator_registry");
_generator_registry.attr("add")(py::cast<py::object>((PyObject*)self.get()));
Collaborator

Move the "add" above. You don't want to do the "add" dynamically every time, right?

}
self->weakreflist = NULL;

static py::handle _generator_registry = py::module::import("torch").attr("random").attr("_generator_registry");
Collaborator

You should have a "release()" at the end here to make sure we leak the reference.

for device, device_rng_state in zip(devices, device_rng_states):
    device_mod.set_rng_state(device_rng_state, device)

# We keep track of all Generator instances (except the default one) via a registry of weak references.
Collaborator

Ho, why not the default one? Because that is the global RNG and so it is already handled?

# would be the same as the global one.
# We also need to create a generator_seed that depends on the current generator state, otherwise
# all Generator instances within a given worker would have the same RNG.
generator_seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item() + seed
Collaborator

Why not use generator.initial_seed()?
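For context on the difference being discussed (an illustration, not part of the PR): initial_seed() is fixed at seeding time, while a drawn value also reflects how far the Generator has advanced:

```python
import torch

g = torch.Generator()
g.manual_seed(123)

print(g.initial_seed())       # 123
torch.randn(10, generator=g)  # advance the state
print(g.initial_seed())       # still 123: unaffected by how much has been drawn

# A value drawn from the current state depends on both the seed and the draws so far:
offset = torch.empty((), dtype=torch.int64).random_(generator=g).item()
print(offset)
```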

Contributor

ezyang commented Aug 13, 2023

As we discussed earlier, I am begrudgingly ok with this approach. There are some implementation details I'll talk about tomorrow; one in particular is having the registry in C++ and making it thread safe. EDIT: Actually, we cannot easily do this because you use the Python weakref functionality 🤔

seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
for generator in torch.random._generator_registry:
Contributor

It occurs to me, technically, you don't even need the generator registry; we could just gc.get_objects() and traverse the entire live heap to look for generators 😆
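A rough sketch of that gc-based alternative (assuming Generator objects are tracked by Python's GC, which is what the follow-up PR enables; `reseed_all_generators` is an illustrative name):

```python
import gc

import torch


def reseed_all_generators(seed: int) -> None:
    # Walk every object the GC tracks and re-seed each torch.Generator found,
    # mirroring the registry-based loop above but without any registry.
    for obj in gc.get_objects():
        if isinstance(obj, torch.Generator):
            offset = torch.empty((), dtype=torch.int64).random_(generator=obj).item()
            obj.manual_seed(offset + seed)
```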

Contributor

Actually, legit question, why don't we do this? The patch becomes super simple then.

Member Author

I gave it a try in #107131, LMK what you think

Collaborator

Ho yeah, that definitely makes it a lot less invasive if we're ok with traversing all the live objects on process creation. Which I think we are.

pytorchmergebot pushed a commit that referenced this pull request Aug 18, 2023
Alternative to #107034; implements @ezyang's suggestion from #107034 (comment).

This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:

- Make the `Generator` class support GC; this makes all `Generator` instances tracked and accessible through Python's GC.
- Use the GC to retrieve all existing Generator instances in Dataloader's `_worker_loop` and re-seed them: this extends what is already applied to the global/default Generator, which is already re-seeded.

~TODO: a bit of docs and justification, which I'll do if this PR is mergeable.~ -- Done

CC @albanD @ezyang  as previously discussed

BC-Breaking Note
-------------------

We now re-seed all `Generator` instances within the `Dataloader` workers' loop to ensure that their RNG is different across workers.
Previously, the RNG of user-defined `Generators` would be the same across workers, which could lead to wrong training procedures. This only affects user-defined `Generators`, not the default `Generator` (which was already re-seeded).

Pull Request resolved: #107131
Approved by: https://github.com/ezyang
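To illustrate the BC-Breaking Note above (a sketch, not code from the PR; `RandFromGenerator` is a made-up dataset): a user-defined Generator used inside a dataset now yields different values in each worker, whereas it previously yielded the same values in every worker:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandFromGenerator(Dataset):
    # Hypothetical dataset drawing from its own, user-defined Generator.
    def __init__(self):
        self.g = torch.Generator()
        self.g.manual_seed(0)

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.rand(1, generator=self.g)


if __name__ == "__main__":
    # Before this change, every worker's copy of self.g produced the same stream;
    # now each worker's copy is re-seeded differently in the worker loop.
    loader = DataLoader(RandFromGenerator(), num_workers=2)
    for batch in loader:
        print(batch)
```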
@NicolasHug
Member Author

Superseded by #107131

@NicolasHug NicolasHug closed this Aug 22, 2023