-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move generator state APIs to ATen (#49589)
Summary: ## Rationale While most of the `torch.Generator` properties and methods are implemented as a thin wrapper of the corresponding `at::Generator` methods, `torch.Generator.get_state()` and `torch.Generator.set_state()` are implemented in legacy Torch code and are not dispatched through the `c10::GeneratorImpl` interface. This is not structured well and makes implementing generators for new backends (e.g. `XLAGeneratorImpl` for the XLA backend) inconvenient. As such, this pull request seeks to move these generator state APIs to c10 and ATen. ## What is being refactored? * Interfaces - Added `c10::GeneratorImpl::set_state` and `c10::GeneratorImpl::state` for getting and setting the internal state of a random number generator. - `at::Generator::set_state` and `at::Generator::state` wraps the above-mentioned APIs, as it's basically a PIMPL. - Added helper function `at::detail::check_rng_state` for checking the validity of new RNG state tensor. * CPU Generator - Renamed and moved `THTensor_(setRNGState)` and `THTensor_(getRNGState)` to `CPUGeneratorImpl::set_state` and `CPUGenerator::state`. - Renamed and moved `THGeneratorState` and `THGeneratorStateNew` to `CPUGeneratorStateLegacy` and `CPUGeneratorState`. * CUDA Generator - Renamed and moved `THCRandom_setRNGState` and `THCRandom_getRNGState` to `CUDAGeneratorImpl::set_state` and `CUDAGeneratorImpl::state`. * PyTorch Bindings - `THPGenerator_setState` and `THPGenerator_getState` now simply forward to `at::Generator::set_state` and `at::Generator::state`. Pull Request resolved: #49589 Reviewed By: H-Huang Differential Revision: D25785774 Pulled By: pbelevich fbshipit-source-id: 8ed79209c4ffb1a0ae8b19952ac8871ac9e0255f
- Loading branch information
1 parent
b6b76a1
commit 4e2ab2c
Showing
16 changed files
with
294 additions
and
252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#include <ATen/core/Generator.h> | ||
#include <ATen/core/Tensor.h> | ||
#include <c10/util/Exception.h> | ||
|
||
namespace at { | ||
|
||
void Generator::set_state(const at::Tensor& new_state) { | ||
TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed"); | ||
this->impl_->set_state(*new_state.unsafeGetTensorImpl()); | ||
} | ||
|
||
at::Tensor Generator::get_state() const { | ||
return at::Tensor::wrap_tensor_impl(this->impl_->get_state()); | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.