-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Replace Generator* with Generator that holds std::shared_ptr<GeneratorImpl> #34468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
💊 CircleCI build failures summary and remediationsAs of commit 52b450d (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following build failures do not appear to be due to upstream breakages:
|
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
This PR adds c-tor `Generator(Device device)` to C++ API to match Python API, to avoid code duplication the logic was moved from `THPGenerator_pynew` to corresponding `at::Generator` c-tor. Unfortunately it makes Generator.h depend on CPUGenerator.h and CUDAGenerator.h, but this is how our Python API works :( * #34468 Replace Generator* with Generator that holds std::shared_ptr<GeneratorImpl> [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) [ghstack-poisoned]
This PR adds c-tor `Generator(Device device)` to C++ API to match Python API, to avoid code duplication the logic was moved from `THPGenerator_pynew` to corresponding `at::Generator` c-tor. Unfortunately it makes Generator.h depend on CPUGenerator.h and CUDAGenerator.h, but this is how our Python API works :( * #34468 Replace Generator* with Generator that holds std::shared_ptr<GeneratorImpl> [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) [ghstack-poisoned]
This PR adds c-tor `Generator(Device device)` to C++ API to match Python API, to avoid code duplication the logic was moved from `THPGenerator_pynew` to corresponding `at::Generator` c-tor. Unfortunately it makes Generator.h depend on CPUGenerator.h and CUDAGenerator.h, but this is how our Python API works :( * #34468 Replace Generator* with Generator that holds std::shared_ptr<GeneratorImpl> [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) [ghstack-poisoned]
This PR adds c-tor `Generator(Device device)` to C++ API to match Python API, to avoid code duplication the logic was moved from `THPGenerator_pynew` to corresponding `at::Generator` c-tor. Unfortunately it makes Generator.h depend on CPUGenerator.h and CUDAGenerator.h, but this is how our Python API works :( * #34468 Replace Generator* with Generator that holds std::shared_ptr<GeneratorImpl> [ghstack-poisoned]
std::lock_guard<std::mutex> lock(cuda_gen.mutex_); | ||
cuda_gen.set_current_seed(seed); | ||
std::lock_guard<std::mutex> lock(cuda_gen->mutex_); | ||
cuda_gen->set_current_seed(seed); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also doesn't have to be this PR, but we should consider adding methods like set_current_seed
to Generator directly, so dot-syntax works on them (making Generator consistent with Tensor)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do in #34987
aten/src/ATen/Utils.h
Outdated
return static_cast<T*>(expr); | ||
static inline T* get_generator_or_default(const Generator& expr, const Generator& defaultValue) { | ||
T* result = expr.defined() ? check_generator<T>(expr) : check_generator<T>(defaultValue); | ||
if (result == nullptr) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When can this be nullptr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I rewrote this part, please take a look
virtual ~Generator() = default; | ||
std::shared_ptr<Generator> clone() const; | ||
// TODO(pbelevich): delete this after replace Generator generator = nullptr with c10::optional<at::Generator> = c10::nullopt | ||
Generator(std::nullptr_t gen_impl) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's pretty important for us to get rid of this TODO soon, as it's pretty confusing that:
std::shared_ptr<c10::GeneratorImpl> impl = nullptr;
Generator g(impl); // errors
Generator g(nullptr); // this is OK?!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally agree, #35067
aten/src/ATen/core/Generator.h
Outdated
virtual uint64_t seed() = 0; | ||
Device device() const; | ||
bool operator==(const Generator& rhs) const { | ||
return (!(this->impl_) && !(rhs.impl_)) || (this->impl_ == rhs.impl_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't this just be this->impl_ == rhs->impl_
? Because if they're both nullptr they'll compare equal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
aten/src/ATen/core/Generator.h
Outdated
|
||
DispatchKeySet key_set() const { return key_set_; } | ||
bool defined() const { | ||
return (bool)impl_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: static_cast<bool>(impl_)
for extra safety. (Or just let implicit conversion on pointer take care of it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Please take care of replacing the |
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) Differential Revision: [D20549420](https://our.internmc.facebook.com/intern/diff/D20549420) [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) Differential Revision: [D20549420](https://our.internmc.facebook.com/intern/diff/D20549420) [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) Differential Revision: [D20549420](https://our.internmc.facebook.com/intern/diff/D20549420) [ghstack-poisoned]
…tr<GeneratorImpl>" This PR prepares `at::Generator` for pybind11's `type_caster<at::Generator>` which is required to implement custom RNG in python. The following changes are done: 1. `at::Generator` was moved to `c10::GeneratorImpl` (similar to `c10::TensorImpl`) 2. `at::Generator` was recreated as a holder of `std::shared_ptr<c10::GeneratorImpl>` (similar to `at::Tensor` that holds `c10::intrusive_ptr<c10::TensorImpl>`) 3. Most of `at::Generator*` usages were replaced with `at::Generator` TBD: replacing `Generator generator = nullptr` with `{}` requires JIT changes(adding Generator to IValue?) Differential Revision: [D20549420](https://our.internmc.facebook.com/intern/diff/D20549420) [ghstack-poisoned]
@pbelevich merged this pull request in 5306713. |
This PR prepares
at::Generator
for pybind11'stype_caster<at::Generator>
which is required to implement custom RNG in python. The following changes are done:at::Generator
was moved toc10::GeneratorImpl
(similar toc10::TensorImpl
)at::Generator
was recreated as a holder ofstd::shared_ptr<c10::GeneratorImpl>
(similar toat::Tensor
that holdsc10::intrusive_ptr<c10::TensorImpl>
)at::Generator*
usages were replaced withat::Generator
TBD: replacing
Generator generator = nullptr
with{}
requires JIT changes(adding Generator to IValue?)Stack from ghstack:
Differential Revision: D20549420