-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[1/2] Intel GPU Runtime Upstreaming for Generator
ghstack-source-id: e99b8f26a89895d4d3473bdd6d40c348ddb5f3b2 Pull Request resolved: #118528
- Loading branch information
Showing
3 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
#include <ATen/xpu/XPUGeneratorImpl.h> | ||
#include <c10/core/StreamGuard.h> | ||
#include <c10/util/CallOnce.h> | ||
#include <c10/xpu/XPUFunctions.h> | ||
|
||
namespace at { | ||
namespace xpu::detail { | ||
namespace { | ||
|
||
/* | ||
* There is currently one pool of default XPU generators, with one slot per | ||
* device. The pool is sized once by initXPUGenVector(); each slot is filled | ||
* lazily the first time the default generator is requested for that device. | ||
*/ | ||
// Guards the one-time sizing of the pool done by initXPUGenVector(). | ||
c10::once_flag init_flag; | ||
// Device count captured by initXPUGenVector(); stays -1 until then. | ||
DeviceIndex num_gpus = -1; | ||
// One once-flag per device, guarding creation of that device's default generator. | ||
std::deque<c10::once_flag> xpu_gens_init_flag; | ||
// Default generator of each device, indexed by device index. | ||
std::vector<Generator> default_gens_xpu; | ||
|
||
void initXPUGenVector() { | ||
num_gpus = device_count(); | ||
xpu_gens_init_flag.resize(num_gpus); | ||
default_gens_xpu.resize(num_gpus); | ||
} | ||
|
||
// Raises via TORCH_CHECK unless `device` lies in [0, num_gpus).
// The indices are widened to int16_t for the error message.
inline void check_device(DeviceIndex device) {
  const bool in_range = (device >= 0) && (device < num_gpus);
  TORCH_CHECK(
      in_range,
      "device is out of range, device is ",
      static_cast<int16_t>(device),
      ", total number of device is ",
      static_cast<int16_t>(num_gpus),
      ".");
}
|
||
} // anonymous namespace | ||
|
||
// Returns the default generator of `device`, creating and seeding it on first
// use. A `device` of -1 selects c10::xpu::current_device(). The index is
// validated with check_device() before the pool is touched.
const Generator& getDefaultXPUGenerator(DeviceIndex device) {
  c10::call_once(init_flag, initXPUGenVector);
  const DeviceIndex idx =
      (device == -1) ? c10::xpu::current_device() : device;
  check_device(idx);

  auto& slot = default_gens_xpu[idx];
  // The per-device flag ensures the slot is built and seeded exactly once.
  c10::call_once(xpu_gens_init_flag[idx], [&slot, idx]() {
    slot = make_generator<XPUGeneratorImpl>(idx);
    slot.seed();
  });
  return slot;
}
|
||
// Builds a new, independent generator for `device` (-1 selects
// c10::xpu::current_device()), seeded with default_rng_seed_val and with a
// Philox offset of 0. The index is validated with check_device().
Generator createXPUGenerator(DeviceIndex device) {
  c10::call_once(init_flag, initXPUGenVector);
  const DeviceIndex idx =
      (device == -1) ? c10::xpu::current_device() : device;
  check_device(idx);

  Generator fresh = make_generator<XPUGeneratorImpl>(idx);
  auto* impl = check_generator<XPUGeneratorImpl>(fresh);
  impl->set_current_seed(default_rng_seed_val);
  impl->set_philox_offset_per_thread(0);
  return fresh;
}
|
||
} // namespace xpu::detail | ||
|
||
// Constructs a generator bound to XPU device `device_index`, registered under
// the XPU dispatch key.
XPUGeneratorImpl::XPUGeneratorImpl(DeviceIndex device_index)
    : GeneratorImpl(
          Device(DeviceType::XPU, device_index),
          DispatchKeySet(c10::DispatchKey::XPU)) {}
|
||
void XPUGeneratorImpl::set_current_seed(uint64_t seed) { | ||
seed_ = seed; | ||
philox_offset_per_thread_ = 0; | ||
} | ||
|
||
void XPUGeneratorImpl::set_offset(uint64_t offset) { | ||
set_philox_offset_per_thread(offset); | ||
} | ||
|
||
uint64_t XPUGeneratorImpl::get_offset() const { | ||
return philox_offset_per_thread_; | ||
} | ||
|
||
uint64_t XPUGeneratorImpl::current_seed() const { | ||
return seed_; | ||
} | ||
|
||
uint64_t XPUGeneratorImpl::seed() { | ||
auto random = c10::detail::getNonDeterministicRandom(true); | ||
this->set_current_seed(random); | ||
return random; | ||
} | ||
|
||
// Serializes the generator into a freshly allocated CPU byte tensor.
//
// Layout (16 bytes total):
//   bytes [0, 8)  - the current seed (uint64_t)
//   bytes [8, 16) - the Philox offset (uint64_t)
//
// Returns the TensorImpl backing that tensor. set_state() reads the same
// layout back.
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
  // The RNG state comprises the seed, and an offset used for Philox.
  constexpr size_t seed_size = sizeof(uint64_t);
  constexpr size_t offset_size = sizeof(uint64_t);
  constexpr size_t total_size = seed_size + offset_size;

  auto state_tensor = at::detail::empty_cpu(
      {static_cast<int64_t>(total_size)},
      ScalarType::Byte,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt,
      c10::nullopt);
  auto* rng_state = state_tensor.data_ptr<uint8_t>();

  // Copy through named locals so memcpy has addressable sources.
  const uint64_t current_seed = this->current_seed();
  const uint64_t offset = this->philox_offset_per_thread();
  memcpy(rng_state, &current_seed, seed_size);
  memcpy(rng_state + seed_size, &offset, offset_size);

  return state_tensor.getIntrusivePtr();
}
|
||
// Restores the generator from a byte tensor laid out as get_state() writes it:
// an 8-byte seed, optionally followed by an 8-byte Philox offset.
//
// A seed-only state (8 bytes) is accepted and leaves the offset at 0. Any
// other size except the full 16 bytes fails the TORCH_CHECK. The input is
// first validated with at::detail::check_rng_state().
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  constexpr size_t seed_size = sizeof(uint64_t);
  constexpr size_t offset_size = sizeof(uint64_t);
  constexpr size_t total_size = seed_size + offset_size;

  at::detail::check_rng_state(new_state);

  const auto state_numel = new_state.numel();
  const bool has_offset = !(state_numel == total_size - offset_size);
  if (has_offset) {
    TORCH_CHECK(state_numel == total_size, "RNG state is wrong size");
  }

  const auto* state_bytes = new_state.data_dtype_initialized<uint8_t>();

  // The seed is installed first; doing so also resets the offset to 0.
  uint64_t seed_from_state;
  memcpy(&seed_from_state, state_bytes, seed_size);
  this->set_current_seed(seed_from_state);

  uint64_t offset_from_state = 0;
  if (has_offset) {
    memcpy(&offset_from_state, state_bytes + seed_size, offset_size);
  }
  this->set_philox_offset_per_thread(offset_from_state);
}
|
||
void XPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { | ||
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); | ||
philox_offset_per_thread_ = offset; | ||
} | ||
|
||
uint64_t XPUGeneratorImpl::philox_offset_per_thread() const { | ||
return philox_offset_per_thread_; | ||
} | ||
|
||
std::pair<uint64_t, uint64_t> XPUGeneratorImpl::philox_engine_inputs( | ||
uint64_t increment) { | ||
increment = ((increment + 3) / 4) * 4; | ||
TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); | ||
uint64_t offset = this->philox_offset_per_thread_; | ||
this->philox_offset_per_thread_ += increment; | ||
return std::make_pair(this->seed_, offset); | ||
} | ||
|
||
// Device type served by this generator implementation: always XPU.
DeviceType XPUGeneratorImpl::device_type() {
  constexpr DeviceType kind = DeviceType::XPU;
  return kind;
}
|
||
std::shared_ptr<XPUGeneratorImpl> XPUGeneratorImpl::clone() const { | ||
return std::shared_ptr<XPUGeneratorImpl>(this->clone_impl()); | ||
} | ||
|
||
// Heap-allocates a generator on the same device index and copies this
// generator's seed and Philox offset into it. The caller owns the result.
XPUGeneratorImpl* XPUGeneratorImpl::clone_impl() const {
  auto* copy = new XPUGeneratorImpl(device().index());
  // Seed first: set_current_seed() zeroes the offset, which is then restored.
  copy->set_current_seed(seed_);
  copy->set_philox_offset_per_thread(philox_offset_per_thread_);
  return copy;
}
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#pragma once | ||
|
||
#include <ATen/core/Generator.h> | ||
|
||
namespace at { | ||
|
||
// Generator implementation for XPU devices. Its state is a 64-bit seed plus a | ||
// 64-bit Philox offset that is kept a multiple of 4. | ||
// NOTE(review): the free functions declared below use TORCH_XPU_API while this | ||
// struct uses TORCH_API - confirm which export macro is intended. | ||
struct TORCH_API XPUGeneratorImpl : public GeneratorImpl { | ||
// Constructors | ||
// Binds the generator to XPU device `device_index` (defaults to -1). | ||
XPUGeneratorImpl(DeviceIndex device_index = -1); | ||
~XPUGeneratorImpl() override = default; | ||
|
||
// XPUGeneratorImpl methods | ||
// Returns a shared_ptr owning a copy of this generator (same device index, seed, offset). | ||
std::shared_ptr<XPUGeneratorImpl> clone() const; | ||
// Stores `seed` and resets the Philox offset to 0. | ||
void set_current_seed(uint64_t seed) override; | ||
// Same as set_philox_offset_per_thread(offset). | ||
void set_offset(uint64_t offset) override; | ||
// Returns the current Philox offset. | ||
uint64_t get_offset() const override; | ||
// Returns the stored seed. | ||
uint64_t current_seed() const override; | ||
// Installs a non-deterministic random seed and returns it. | ||
uint64_t seed() override; | ||
// Restores seed (and, if present, offset) from a byte tensor of 8 or 16 elements. | ||
void set_state(const c10::TensorImpl& new_state) override; | ||
// Returns a 16-byte CPU tensor holding the seed followed by the Philox offset. | ||
c10::intrusive_ptr<c10::TensorImpl> get_state() const override; | ||
// Stores `offset`; it must be a multiple of 4. | ||
void set_philox_offset_per_thread(uint64_t offset); | ||
// Returns the current Philox offset. | ||
uint64_t philox_offset_per_thread() const; | ||
// Returns (seed, offset before the call) and advances the offset by | ||
// `increment` rounded up to a multiple of 4. | ||
std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment); | ||
// Always DeviceType::XPU. | ||
static c10::DeviceType device_type(); | ||
|
||
private: | ||
// Allocates the copy returned by clone(); the caller takes ownership. | ||
XPUGeneratorImpl* clone_impl() const override; | ||
// Current seed; starts at default_rng_seed_val. | ||
uint64_t seed_ = default_rng_seed_val; | ||
// Current Philox offset; starts at 0. | ||
uint64_t philox_offset_per_thread_ = 0; | ||
}; | ||
|
||
namespace xpu::detail { | ||
|
||
// Returns the default generator of `device`, creating and seeding it the first | ||
// time it is requested. A `device` of -1 selects the current XPU device. | ||
// An out-of-range index raises an error. | ||
TORCH_XPU_API const Generator& getDefaultXPUGenerator(DeviceIndex device = -1); | ||
|
||
// Creates a new generator for `device` (-1 selects the current XPU device), | ||
// seeded with default_rng_seed_val and a Philox offset of 0. | ||
// An out-of-range index raises an error. | ||
TORCH_XPU_API Generator createXPUGenerator(DeviceIndex device = -1); | ||
|
||
} // namespace xpu::detail | ||
} // namespace at |