reseed all Generators in Dataloader's _worker_loop() #107034
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107034
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures as of commit f73cc30.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# We also need to create a generator_seed that depends on the current generator state, otherwise
# all Generator instances within a given worker would have the same RNG.
generator_seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item() + seed
generator.manual_seed(generator_seed)
Technically this is a change of behaviour: in `main`, the `Generator` instances have the same RNG across workers; now they don't.
I don't think users rely on, or should rely on, this behaviour anyway. The only use I can think of where users would want the same RNG across workers is using the Generator to shuffle the dataset: when shuffling, you want all workers to shuffle in the same way. But for those who need that, I think it's fair to say they should rely on `worker_init_fn` anyway.
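For readers who do need identical shuffling across workers, a minimal sketch of the `worker_init_fn` route (the dataset class, function name, attribute, and seed value below are illustrative assumptions, not part of this PR):

```python
import torch
from torch.utils.data import DataLoader, Dataset

SHARED_SHUFFLE_SEED = 0  # hypothetical seed shared by every worker


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


def seed_shuffle_generator(worker_id):
    # Give every worker an identically-seeded Generator, so any shuffling
    # done with it inside the dataset is the same across workers,
    # independent of the per-worker re-seeding this PR introduces.
    info = torch.utils.data.get_worker_info()
    info.dataset.shuffle_generator = torch.Generator()
    info.dataset.shuffle_generator.manual_seed(SHARED_SHUFFLE_SEED)


loader = DataLoader(ToyDataset(), num_workers=2, worker_init_fn=seed_shuffle_generator)
```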
This is technically BC-breaking, so we should properly document it (please add a small paragraph in the description for the release notes). But I think it's ok to do, yes.
Small things but the approach sounds good to me.
Will let @ezyang give his opinion as well
for from_global, from_g1, from_g2 in dl:
    # Assert RNG of all Generators are different within a given worker (each "batch" comes from a single worker)
    assert len(set([from_global, from_g1, from_g2])) == 3
Can you please remove all the plain `assert`s from the tests? You can use `self.assertEqual()`, `self.assertNotEqual()`, etc. as appropriate.
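A minimal sketch of what the requested change could look like for the loop shown above, assuming the code sits inside a `unittest.TestCase` method (same variables as the diff):

```python
for from_global, from_g1, from_g2 in dl:
    # Each batch comes from a single worker, so all three values should differ.
    self.assertEqual(len({from_global, from_g1, from_g2}), 3)
```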
self->weakreflist = NULL;

static py::handle _generator_registry = py::module::import("torch").attr("random").attr("_generator_registry");
_generator_registry.attr("add")(py::cast<py::object>((PyObject*)self.get()));
Move the `add` above: you don't want to do the `"add"` lookup dynamically every time, right?
}
self->weakreflist = NULL;

static py::handle _generator_registry = py::module::import("torch").attr("random").attr("_generator_registry");
You should have a `release()` at the end here to make sure to leak the reference (i.e., intentionally keep it alive forever instead of letting the temporary `py::object` decref it).
for device, device_rng_state in zip(devices, device_rng_states):
    device_mod.set_rng_state(device_rng_state, device)

# We keep track of all Generator instances (except the default one) via a registry of weak references.
Oh, why not the default one? Because that is the global RNG and so it is already handled?
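For context, a Python-only sketch of how a weak-reference registry behaves. The real registry is populated from C++, and `torch.Generator` only becomes weakref-able with this PR's C++ change, so a stand-in class is used here; the default Generator stays out of the registry because, as the comment says, the global RNG is already re-seeded directly:

```python
import weakref


class FakeGenerator:
    """Stand-in for torch.Generator; the PR makes the real Generator
    weakref-able on the C++ side so it can live in a registry like this."""


_generator_registry = weakref.WeakSet()

g = FakeGenerator()
_generator_registry.add(g)       # in the PR this happens at construction, from C++

print(len(_generator_registry))  # 1 while g is alive
del g                            # on CPython, refcounting collects g immediately
print(len(_generator_registry))  # 0: the weak reference vanished with the object
```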
# would be the same as the global one.
# We also need to create a generator_seed that depends on the current generator state, otherwise
# all Generator instances within a given worker would have the same RNG.
generator_seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item() + seed
Why not use `generator.initial_seed()`?
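One observable difference, which may or may not be the deciding factor here: `initial_seed()` only reports the last seed that was set, not the generator's current state. A small sketch (not code from the PR):

```python
import torch

g = torch.Generator()
g.manual_seed(42)

print(g.initial_seed())        # 42
torch.rand(1000, generator=g)  # advance the generator's state
print(g.initial_seed())        # still 42: initial_seed() ignores how far g has advanced

# Sampling from the generator, as the diff does, depends on the current state:
print(torch.empty((), dtype=torch.int64).random_(generator=g).item())
```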
As we discussed earlier, I am begrudgingly ok with this approach. There are some implementation details I'll talk about tomorrow; one in particular is having the registry in C++ and thread-safe.
EDIT: Actually, we cannot easily do this because you use the Python weakref functionality 🤔
seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
for generator in torch.random._generator_registry:
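Combining this hunk with the loop body shown earlier in the diff, the re-seeding presumably reads roughly as below. This is a reconstruction from the snippets in this PR, not a verbatim copy of the patch; `base_seed` and `worker_id` are placeholders for the arguments `_worker_loop` receives, and `torch.random._generator_registry` only exists with this PR applied:

```python
import random

import torch

base_seed, worker_id = 12345, 0  # placeholders; in _worker_loop these come from the main process

seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
for generator in torch.random._generator_registry:
    # Derive a per-generator seed from both the worker seed and the generator's
    # current state, so Generators differ across workers and within a worker.
    generator_seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item() + seed
    generator.manual_seed(generator_seed)
```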
It occurs to me, technically, you don't even need the generator registry; we could just gc.get_objects() and traverse the entire live heap to look for generators 😆
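A minimal sketch of that gc-based alternative (roughly what #107131 ended up doing; the helper below is an illustration with a made-up name, not the merged code):

```python
import gc

import torch


def reseed_live_generators(seed):
    """Hypothetical helper: walk every GC-tracked object and re-seed the
    torch.Generator instances found (requires Generator to be GC-tracked)."""
    for obj in gc.get_objects():
        if isinstance(obj, torch.Generator):
            new_seed = torch.empty((), dtype=torch.int64).random_(generator=obj).item() + seed
            obj.manual_seed(new_seed)
```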
Actually, legit question, why don't we do this? The patch becomes super simple then.
I gave it a try in #107131, LMK what you think
Oh yeah, that definitely makes it a lot less invasive if we're ok with traversing all the live objects on process creation. Which I think we are.
Alternative to #107034, implements @ezyang's suggestion from #107034 (comment). This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:
- Make the `Generator` class support GC; this makes all `Generator` instances tracked and accessible through Python's GC.
- Use the GC to retrieve all existing `Generator` instances in the Dataloader's `_worker_loop` and re-seed them: this extends what is already applied to the global/default `Generator`, which is already re-seeded.

~TODO: a bit of docs and justification, which I'll do if this PR is mergeable.~ -- Done

CC @albanD @ezyang as previously discussed

BC-Breaking Note
-------------------
We now re-seed all `Generator` instances within the `Dataloader` workers' loop to ensure that their RNG differs across workers. Previously, the RNG of user-defined `Generator`s was the same across workers, which could lead to wrong training procedures. This only affects user-defined `Generator`s, not the default `Generator` (which was already re-seeded).

Pull Request resolved: #107131
Approved by: https://github.com/ezyang
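To make the BC-breaking note concrete, a sketch of the kind of user code whose behaviour changes; the dataset class and generator name are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, Dataset

noise_gen = torch.Generator()  # user-defined Generator; name is illustrative
noise_gen.manual_seed(0)


class NoisyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Before this change: noise_gen kept the same state in every worker, so
        # different workers could draw identical noise. After: each worker's copy
        # of noise_gen is re-seeded differently in _worker_loop.
        return idx + torch.randn(1, generator=noise_gen)


loader = DataLoader(NoisyDataset(), num_workers=2)
```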
Superseded by #107131
This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:
- Make `Generator` weakref-able (C++ part)
- Keep a registry of all `Generator` objects (via weakrefs)
- Modify `Dataloader`'s `_worker_loop` to re-seed all existing `Generator` instances: this extends what is already applied to the global `Generator`, which is already re-seeded.

TODO: a bit of docs and justification, which I'll do if this PR is mergeable.

CC @albanD as previously discussed

cc @ezyang @gchanan @ssnl @VitalyFedyunin @ejguan @dzhulgakov @pbelevich