Skip to content

Commit

Permalink
Update on "[2/2] Intel GPU Runtime Upstreaming for Generator"
Browse files Browse the repository at this point in the history
# Motivation
According to [[1/2] Intel GPU Runtime Upstreaming for Generator](#118528), as mentioned in [[RFC] Intel GPU Runtime Upstreaming](#114842), the second PR covers the changes under `python frontend`.

# Design
Currently, it primarily offers geneartor-related APIs, including

- `torch.xpu.default_generators`
- `torch.xpu.get_rng_state`
- `torch.xpu.get_rng_state_all`
- `torch.xpu.initial_seed`
- `torch.xpu.manual_seed`
- `torch.xpu.manual_seed_all`
- `torch.xpu.seed`
- `torch.xpu.seed_all`
- `torch.xpu.set_rng_state`
- `torch.xpu.set_rng_state_all`

# Additional Context
The differences with CUDA:
The generator-related frontend python APIs are 1:1 mapping with CUDA.

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
  • Loading branch information
guangyey committed Feb 26, 2024
2 parents 9b8b4db + 7b99f52 commit 6ac3726
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
6 changes: 5 additions & 1 deletion aten/src/ATen/detail/XPUHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ struct TORCH_API XPUHooksInterface {
TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
}

virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Cannot get XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
}

virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
TORCH_CHECK(false, "Cannot get default XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
}
Expand All @@ -52,7 +56,7 @@ struct TORCH_API XPUHooksInterface {
}

virtual DeviceIndex current_device() const {
return -1;
TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
}

virtual Device getDeviceFromPtr(void* /*data*/) const {
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/xpu/detail/XPUHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ int XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
return at::xpu::getGlobalIdxFromDevice(device.index());
}

Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
}

const Generator& XPUHooks::getDefaultXPUGenerator(
DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/xpu/detail/XPUHooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct XPUHooks : public at::XPUHooksInterface {
bool hasXPU() const override;
std::string showConfig() const override;
int getGlobalIdxFromDevice(const at::Device& device) const override;
Generator getXPUGenerator(DeviceIndex device_index = -1) const override;
const Generator& getDefaultXPUGenerator(
DeviceIndex device_index = -1) const override;
Device getDeviceFromPtr(void* data) const override;
Expand Down
11 changes: 2 additions & 9 deletions torch/csrc/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
#include <ATen/mps/MPSGeneratorImpl.h>
#endif

#ifdef USE_XPU
#include <ATen/xpu/XPUGeneratorImpl.h>
#endif

using namespace at;
using namespace torch;

Expand Down Expand Up @@ -76,12 +72,9 @@ static PyObject* THPGenerator_pynew(
self->cdata = make_generator<MPSGeneratorImpl>();
}
#endif
#ifdef USE_XPU
else if (device.type() == at::kXPU) {
self->cdata = make_generator<XPUGeneratorImpl>(device.index());
}
#endif
else if (device.type() == at::kIPU) {
self->cdata = at::detail::getXPUHooks().getXPUGenerator(device.index());
} else if (device.type() == at::kIPU) {
self->cdata = at::detail::getIPUHooks().newIPUGenerator(device.index());
} else if (device.type() == at::kPrivateUse1) {
self->cdata = at::GetGeneratorForPrivateuse1(device.index());
Expand Down

0 comments on commit 6ac3726

Please sign in to comment.