Move generator state APIs to ATen (#49589)
Summary:
## Rationale

While most `torch.Generator` properties and methods are implemented as thin wrappers around the corresponding `at::Generator` methods, `torch.Generator.get_state()` and `torch.Generator.set_state()` are implemented in legacy Torch code and are not dispatched through the `c10::GeneratorImpl` interface. This structure makes implementing generators for new backends (e.g. `XLAGeneratorImpl` for the XLA backend) inconvenient. This pull request therefore moves the generator state APIs into c10 and ATen.

## What is being refactored?
* Interfaces
  - Added `c10::GeneratorImpl::set_state` and `c10::GeneratorImpl::get_state` for setting and getting the internal state of a random number generator.
  - `at::Generator::set_state` and `at::Generator::get_state` wrap the above-mentioned APIs, since `at::Generator` is essentially a PIMPL around `c10::GeneratorImpl`.
  - Added the helper function `at::detail::check_rng_state` for validating a new RNG state tensor.
* CPU Generator
  - Renamed and moved `THTensor_(setRNGState)` and `THTensor_(getRNGState)` to `CPUGeneratorImpl::set_state` and `CPUGeneratorImpl::get_state`.
  - Renamed and moved `THGeneratorState` and `THGeneratorStateNew` to `CPUGeneratorImplStateLegacy` and `CPUGeneratorImplState`.
* CUDA Generator
  - Renamed and moved `THCRandom_setRNGState` and `THCRandom_getRNGState` to `CUDAGeneratorImpl::set_state` and `CUDAGeneratorImpl::get_state`.
* PyTorch Bindings
  - `THPGenerator_setState` and `THPGenerator_getState` now simply forward to `at::Generator::set_state` and `at::Generator::get_state`.
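
The contract a new backend now has to satisfy can be sketched standalone. This is a minimal toy analogue, not the real ATen types: `ByteBuffer` stands in for the `c10::TensorImpl` byte tensor, and `ToyGeneratorImpl` is an invented stand-in for a backend generator such as a hypothetical `XLAGeneratorImpl`.

```cpp
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Stand-in for the CPU byte tensor used to carry RNG state.
using ByteBuffer = std::vector<uint8_t>;

// Shape of the interface added to c10::GeneratorImpl (simplified).
struct GeneratorImplSketch {
  virtual ~GeneratorImplSketch() = default;
  virtual void set_state(const ByteBuffer& new_state) = 0;
  virtual ByteBuffer get_state() const = 0;
};

// A backend implements the pair by serializing its RNG state into a
// byte buffer and restoring it from one, validating the size first.
struct ToyGeneratorImpl : GeneratorImplSketch {
  uint64_t seed_ = 0;
  uint64_t offset_ = 0;

  void set_state(const ByteBuffer& new_state) override {
    if (new_state.size() != sizeof(seed_) + sizeof(offset_)) {
      throw std::runtime_error("RNG state is wrong size");
    }
    std::memcpy(&seed_, new_state.data(), sizeof(seed_));
    std::memcpy(&offset_, new_state.data() + sizeof(seed_), sizeof(offset_));
  }

  ByteBuffer get_state() const override {
    ByteBuffer state(sizeof(seed_) + sizeof(offset_));
    std::memcpy(state.data(), &seed_, sizeof(seed_));
    std::memcpy(state.data() + sizeof(seed_), &offset_, sizeof(offset_));
    return state;
  }
};
```

Round-tripping `get_state()` through `set_state()` on a fresh instance reproduces the original generator state, which is the property the checkpointing code relies on.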

Pull Request resolved: #49589

Reviewed By: H-Huang

Differential Revision: D25785774

Pulled By: pbelevich

fbshipit-source-id: 8ed79209c4ffb1a0ae8b19952ac8871ac9e0255f
lqf96 authored and facebook-github-bot committed Jan 7, 2021
1 parent b6b76a1 commit 4e2ab2c
Showing 16 changed files with 294 additions and 252 deletions.
160 changes: 160 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.cpp
@@ -1,11 +1,49 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/C++17.h>
#include <algorithm>

namespace at {

namespace detail {

/**
* CPUGeneratorImplStateLegacy is a POD class needed for memcpys
* in torch.get_rng_state() and torch.set_rng_state().
* It is a legacy class and even though it is replaced with
* at::CPUGeneratorImpl, we need this class and some of its fields
* to support backward compatibility on loading checkpoints.
*/
struct CPUGeneratorImplStateLegacy {
/* The initial seed. */
uint64_t the_initial_seed;
int left; /* = 1; */
int seeded; /* = 0; */
uint64_t next;
uint64_t state[at::MERSENNE_STATE_N]; /* the array for the state vector */

/********************************/

/* For normal distribution */
double normal_x;
double normal_y;
double normal_rho;
int normal_is_valid; /* = 0; */
};

/**
* CPUGeneratorImplState is a POD class containing
* new data introduced in at::CPUGeneratorImpl and the legacy state. It is used
* as a helper for torch.get_rng_state() and torch.set_rng_state()
* functions.
*/
struct CPUGeneratorImplState {
CPUGeneratorImplStateLegacy legacy_pod;
float next_float_normal_sample;
bool is_next_float_normal_sample_valid;
};

/**
* PyTorch maintains a collection of default generators that get
* initialized once. The purpose of these default generators is to
@@ -75,6 +113,128 @@ uint64_t CPUGeneratorImpl::seed() {
return random;
}

/**
* Sets the internal state of CPUGeneratorImpl. The new internal state
* must be a strided CPU byte tensor and of the same size as either
* CPUGeneratorImplStateLegacy (for legacy CPU generator state) or
* CPUGeneratorImplState (for new state).
*
* FIXME: Remove support of the legacy state in the future?
*/
void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
using detail::CPUGeneratorImplState;
using detail::CPUGeneratorImplStateLegacy;

static_assert(std::is_pod<CPUGeneratorImplStateLegacy>::value, "CPUGeneratorImplStateLegacy is not a PODType");
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");

static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy);
static const size_t size_current = sizeof(CPUGeneratorImplState);
static_assert(size_legacy != size_current, "CPUGeneratorImplStateLegacy and CPUGeneratorImplState can't be of the same size");

detail::check_rng_state(new_state);

at::mt19937 engine;
auto float_normal_sample = c10::optional<float>();
auto double_normal_sample = c10::optional<double>();

// Construct the state of at::CPUGeneratorImpl based on input byte tensor size.
CPUGeneratorImplStateLegacy* legacy_pod;
auto new_state_size = new_state.numel();
if (new_state_size == size_legacy) {
legacy_pod = (CPUGeneratorImplStateLegacy*)new_state.data();
// Note that in CPUGeneratorImplStateLegacy, we didn't have float version
// of normal sample and hence we leave the c10::optional<float> as is

// Update next_double_normal_sample.
// Note that CPUGeneratorImplStateLegacy stores two uniform values (normal_x, normal_y)
// and a rho value (normal_rho). These three values were redundant and in the new
// DistributionsHelper.h, we store the actual extra normal sample, rather than three
// intermediate values.
if (legacy_pod->normal_is_valid) {
auto r = legacy_pod->normal_rho;
auto theta = 2.0 * M_PI * legacy_pod->normal_x;
// we return the sin version of the normal sample when in caching mode
double_normal_sample = c10::optional<double>(r * ::sin(theta));
}
} else if (new_state_size == size_current) {
auto rng_state = (CPUGeneratorImplState*)new_state.data();
legacy_pod = &rng_state->legacy_pod;
// update next_float_normal_sample
if (rng_state->is_next_float_normal_sample_valid) {
float_normal_sample = c10::optional<float>(rng_state->next_float_normal_sample);
}

// Update next_double_normal_sample.
// Note that in getRNGState, we now return the actual normal sample in normal_y
// and if it's valid in normal_is_valid. The redundant normal_x and normal_rho
// are squashed to 0.0.
if (legacy_pod->normal_is_valid) {
double_normal_sample = c10::optional<double>(legacy_pod->normal_y);
}
} else {
AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy,
" or a CPUGeneratorImplState of size ", size_current,
" but found the input RNG state size to be ", new_state_size);
}

// construct engine_
// Note that CPUGeneratorImplStateLegacy stored a state array of 64 bit uints, whereas in our
// redefined mt19937, we have changed to a state array of 32 bit uints. Hence, we are
// doing a std::copy.
at::mt19937_data_pod rng_data;
std::copy(std::begin(legacy_pod->state), std::end(legacy_pod->state), rng_data.state_.begin());
rng_data.seed_ = legacy_pod->the_initial_seed;
rng_data.left_ = legacy_pod->left;
rng_data.seeded_ = legacy_pod->seeded;
rng_data.next_ = static_cast<uint32_t>(legacy_pod->next);
engine.set_data(rng_data);
TORCH_CHECK(engine.is_valid(), "Invalid mt19937 state");
this->engine_ = engine;
this->next_float_normal_sample_ = float_normal_sample;
this->next_double_normal_sample_ = double_normal_sample;
}

/**
* Gets the current internal state of CPUGeneratorImpl. The internal
* state is returned as a CPU byte tensor.
*/
c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
using detail::CPUGeneratorImplState;

static const size_t size = sizeof(CPUGeneratorImplState);
static_assert(std::is_pod<CPUGeneratorImplState>::value, "CPUGeneratorImplState is not a PODType");

auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr();

// accumulate generator data to be copied into byte tensor
auto accum_state = std::make_unique<CPUGeneratorImplState>();
auto rng_data = this->engine_.data();
accum_state->legacy_pod.the_initial_seed = rng_data.seed_;
accum_state->legacy_pod.left = rng_data.left_;
accum_state->legacy_pod.seeded = rng_data.seeded_;
accum_state->legacy_pod.next = rng_data.next_;
std::copy(rng_data.state_.begin(), rng_data.state_.end(), std::begin(accum_state->legacy_pod.state));
accum_state->legacy_pod.normal_x = 0.0; // we don't use it anymore and this is just a dummy
accum_state->legacy_pod.normal_rho = 0.0; // we don't use it anymore and this is just a dummy
accum_state->legacy_pod.normal_is_valid = false;
accum_state->legacy_pod.normal_y = 0.0;
accum_state->next_float_normal_sample = 0.0f;
accum_state->is_next_float_normal_sample_valid = false;
if (this->next_double_normal_sample_) {
accum_state->legacy_pod.normal_is_valid = true;
accum_state->legacy_pod.normal_y = *(this->next_double_normal_sample_);
}
if (this->next_float_normal_sample_) {
accum_state->is_next_float_normal_sample_valid = true;
accum_state->next_float_normal_sample = *(this->next_float_normal_sample_);
}

memcpy(rng_state, accum_state.get(), size);
return state_tensor.getIntrusivePtr();
}

/**
* Gets the DeviceType of CPUGeneratorImpl.
* Used for type checking during run time.
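
The size-based dispatch in `CPUGeneratorImpl::set_state` above (legacy layout vs. new layout, told apart purely by byte count) can be illustrated in isolation. The structs below are invented toy stand-ins for `CPUGeneratorImplStateLegacy` and `CPUGeneratorImplState`, not the real layouts:

```cpp
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

// Toy stand-ins for the legacy and current state PODs.
struct LegacyState {
  uint64_t seed;
  double normal_y;
  int normal_is_valid;
};

struct NewState {
  LegacyState legacy_pod;
  float next_float_normal_sample;
  bool is_next_float_normal_sample_valid;
};

// As in the real code, the two layouts must differ in size so the
// incoming byte buffer unambiguously identifies its own format.
static_assert(sizeof(LegacyState) != sizeof(NewState),
              "layouts must be distinguishable by size");

// Mirrors the branch structure of set_state: pick a layout from the
// byte count, or reject the buffer outright.
const char* classify_state(const std::vector<uint8_t>& buf) {
  if (buf.size() == sizeof(LegacyState)) return "legacy";
  if (buf.size() == sizeof(NewState)) return "new";
  throw std::runtime_error("input RNG state has unexpected size");
}
```

This is why the real code carries a `static_assert(size_legacy != size_current, ...)`: if the two structs ever collapsed to the same size, checkpoints could no longer be disambiguated.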
2 changes: 2 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.h
@@ -17,6 +17,8 @@ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
static DeviceType device_type();
uint32_t random();
uint64_t random64();
4 changes: 3 additions & 1 deletion aten/src/ATen/CUDAGeneratorImpl.h
@@ -129,8 +129,10 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void set_philox_offset_per_thread(uint64_t offset);
uint64_t philox_offset_per_thread();
uint64_t philox_offset_per_thread() const;
void capture_prologue(int64_t* offset_extragraph);
uint64_t capture_epilogue();
PhiloxCudaState philox_cuda_state(uint64_t increment);
16 changes: 16 additions & 0 deletions aten/src/ATen/core/Generator.cpp
@@ -0,0 +1,16 @@
#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

namespace at {

void Generator::set_state(const at::Tensor& new_state) {
TORCH_CHECK(new_state.defined(), "Undefined tensor is not allowed");
this->impl_->set_state(*new_state.unsafeGetTensorImpl());
}

at::Tensor Generator::get_state() const {
return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
}

} // namespace at
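
The two forwarding methods above are the PIMPL pattern at work: `at::Generator` is a thin public handle that owns an impl pointer and delegates to it, which is also why the definitions live in the `.cpp` rather than inline in the header. A minimal standalone sketch of that split, with invented names (`GenImplBase`, `StubImpl`, `GeneratorSketch`) rather than the real ATen classes:

```cpp
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for the c10::GeneratorImpl side of the split.
struct GenImplBase {
  virtual ~GenImplBase() = default;
  virtual void set_state(const std::vector<uint8_t>& s) = 0;
  virtual std::vector<uint8_t> get_state() const = 0;
};

// A trivial backend impl that just stores the bytes it is given.
struct StubImpl : GenImplBase {
  std::vector<uint8_t> state_;
  void set_state(const std::vector<uint8_t>& s) override { state_ = s; }
  std::vector<uint8_t> get_state() const override { return state_; }
};

// Public handle (the at::Generator role): copyable, owns the impl,
// and forwards every call through the pointer.
class GeneratorSketch {
 public:
  explicit GeneratorSketch(std::shared_ptr<GenImplBase> impl)
      : impl_(std::move(impl)) {}
  void set_state(const std::vector<uint8_t>& s) { impl_->set_state(s); }
  std::vector<uint8_t> get_state() const { return impl_->get_state(); }

 private:
  std::shared_ptr<GenImplBase> impl_;
};
```

Keeping the handle's methods out of line also mirrors the cycle-avoidance noted in `Generator.h`: the header only needs a forward declaration of the state type, not its full definition.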
28 changes: 28 additions & 0 deletions aten/src/ATen/core/Generator.h
@@ -56,6 +56,8 @@

namespace at {

class Tensor;

struct TORCH_API Generator {
Generator() {}

@@ -96,6 +98,12 @@ struct TORCH_API Generator {

uint64_t seed() { return impl_->seed(); }

// Implementation not inlined to prevent cycle reference between
// `ATen/core/Generator.h` and `ATen/core/Tensor.h`
void set_state(const at::Tensor& new_state);

at::Tensor get_state() const;

std::mutex& mutex() {
return impl_->mutex_;
}
@@ -130,4 +138,24 @@ Generator make_generator(Args&&... args) {
return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
}

namespace detail {

/**
* Helper function for checking the validity of new random generator
* state. Right now following conditions are checked:
*
* - The new state tensor must be a torch.ByteTensor
* - Data of the new state tensor must be contiguous
*/
static inline void check_rng_state(const c10::TensorImpl& new_state) {
TORCH_CHECK_TYPE(
new_state.layout() == kStrided && new_state.device().type() == kCPU && new_state.dtype() == kByte,
"RNG state must be a torch.ByteTensor"
);

TORCH_CHECK(new_state.is_contiguous(), "RNG state must be contiguous");
}

} // namespace detail

} // namespace at
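
The two conditions enforced by `check_rng_state` above (a CPU byte tensor, stored contiguously) can be sketched without the real `c10::TensorImpl`. The `TensorDesc` struct and `check_rng_state_sketch` function below are invented for illustration:

```cpp
#include <stdexcept>
#include <string>

// Toy description of a tensor, standing in for c10::TensorImpl.
struct TensorDesc {
  std::string device;   // "cpu", "cuda", ...
  std::string dtype;    // "byte", "float", ...
  bool is_contiguous;
};

// Mirrors at::detail::check_rng_state: the incoming state must be a
// contiguous CPU byte tensor, otherwise the call is rejected.
void check_rng_state_sketch(const TensorDesc& t) {
  if (t.device != "cpu" || t.dtype != "byte") {
    throw std::invalid_argument("RNG state must be a torch.ByteTensor");
  }
  if (!t.is_contiguous) {
    throw std::invalid_argument("RNG state must be contiguous");
  }
}
```

Both backend `set_state` implementations call this helper first, so the memcpy-based deserialization that follows can assume a flat, dense byte buffer.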
63 changes: 62 additions & 1 deletion aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -130,6 +130,67 @@ uint64_t CUDAGeneratorImpl::seed() {
return random;
}

/**
* Gets the current internal state of CUDAGeneratorImpl. The internal
* state is returned as a CPU byte tensor.
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
// The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120.
// It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
// MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here
// because this is just host side code and we don't want to worry about linking with cuda
static const size_t states_size = 200 * sizeof(4120);
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = states_size + seed_size + offset_size;

auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
// since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1
// gen_states in THCGenerator struct was an array of curandStateMtgp32s.
memset(rng_state, -1, states_size);
auto current_seed = this->current_seed();
auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
memcpy(rng_state + states_size, &current_seed, seed_size);
memcpy(rng_state + states_size + seed_size, &offset, offset_size);

return state_tensor.getIntrusivePtr();
}

/**
* Sets the internal state of CUDAGeneratorImpl. The new internal state
* must be a strided CPU byte tensor and have appropriate size. See
* comments of CUDAGeneratorImpl::state for information about the layout
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = states_size + seed_size + offset_size;

detail::check_rng_state(new_state);

bool no_philox_seed = false;
auto new_state_size = new_state.numel();
if (new_state_size == total_size - offset_size) {
no_philox_seed = true;
} else {
TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
}

uint64_t input_seed;
auto new_rng_state = new_state.data<uint8_t>();
memcpy(&input_seed, new_rng_state + states_size, seed_size);
this->set_current_seed(input_seed);
int64_t philox_offset = 0;
if (!no_philox_seed) {
memcpy(&philox_offset, new_rng_state + states_size + seed_size, offset_size);
}
this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
}

/**
* Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10
*
@@ -143,7 +204,7 @@ void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
/**
* Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() {
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::philox_offset_per_thread");
return philox_offset_per_thread_;
}
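
The CUDA state layout above — an unused legacy-MTGP region kept for backward compatibility, then the seed, then a Philox offset that older checkpoints may omit — can be packed and unpacked with plain memcpys. This is a toy sketch with an invented, much smaller `kStatesSize`; the real code keeps the historical size for BC:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Layout: [ legacy states region | seed (uint64) | philox offset (int64) ].
constexpr size_t kStatesSize = 16;  // toy value; not the real BC size
constexpr size_t kSeedSize = sizeof(uint64_t);
constexpr size_t kOffsetSize = sizeof(int64_t);

std::vector<uint8_t> pack_cuda_state(uint64_t seed, int64_t offset) {
  std::vector<uint8_t> buf(kStatesSize + kSeedSize + kOffsetSize);
  // Legacy region: filled with a deterministic garbage value, as in
  // CUDAGeneratorImpl::get_state, since curandStateMtgp32 is unused.
  std::memset(buf.data(), -1, kStatesSize);
  std::memcpy(buf.data() + kStatesSize, &seed, kSeedSize);
  std::memcpy(buf.data() + kStatesSize + kSeedSize, &offset, kOffsetSize);
  return buf;
}

void unpack_cuda_state(const std::vector<uint8_t>& buf,
                       uint64_t& seed, int64_t& offset) {
  std::memcpy(&seed, buf.data() + kStatesSize, kSeedSize);
  // Old checkpoints without a Philox offset fall back to 0, mirroring
  // the no_philox_seed branch in CUDAGeneratorImpl::set_state.
  offset = 0;
  if (buf.size() == kStatesSize + kSeedSize + kOffsetSize) {
    std::memcpy(&offset, buf.data() + kStatesSize + kSeedSize, kOffsetSize);
  }
}
```

As with the CPU generator, the buffer length alone distinguishes the two accepted formats (with and without the trailing offset), so no version tag needs to live inside the state itself.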
2 changes: 2 additions & 0 deletions aten/src/ATen/test/cpu_rng_test.cpp
@@ -28,6 +28,8 @@ struct TestCPUGenerator : public c10::GeneratorImpl {
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
uint64_t seed() override { throw std::runtime_error("not implemented"); }
void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }

static DeviceType device_type() { return DeviceType::CPU; }
1 change: 0 additions & 1 deletion aten/src/TH/CMakeLists.txt
@@ -79,7 +79,6 @@ install(FILES
THHalf.h
THTensor.hpp
THStorageFunctions.hpp
THGenerator.hpp
DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH")

install(FILES
39 changes: 0 additions & 39 deletions aten/src/TH/THGenerator.hpp

This file was deleted.
