Skip to content

Commit

Permalink
[2/2] Intel GPU Runtime Upstreaming for Generator (#118613)
Browse files Browse the repository at this point in the history
# Motivation
According to [[1/2] Intel GPU Runtime Upstreaming for Generator](#118528), as mentioned in [[RFC] Intel GPU Runtime Upstreaming](#114842), the second PR covers the changes under `python frontend`.

# Design
Currently, it primarily offers geneartor-related APIs, including

- `torch.xpu.default_generators`
- `torch.xpu.get_rng_state`
- `torch.xpu.get_rng_state_all`
- `torch.xpu.initial_seed`
- `torch.xpu.manual_seed`
- `torch.xpu.manual_seed_all`
- `torch.xpu.seed`
- `torch.xpu.seed_all`
- `torch.xpu.set_rng_state`
- `torch.xpu.set_rng_state_all`

# Additional Context
The differences with CUDA:
The generator-related frontend python APIs are 1:1 mapping with CUDA.

Pull Request resolved: #118613
Approved by: https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD
  • Loading branch information
guangyey authored and pytorchmergebot committed Feb 28, 2024
1 parent 8ba4cb4 commit 12995a5
Show file tree
Hide file tree
Showing 11 changed files with 407 additions and 3 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/XPUHooksInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ struct TORCH_API XPUHooksInterface {
return 0;
}

virtual DeviceIndex current_device() const {
TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
}

virtual Device getDeviceFromPtr(void* /*data*/) const {
TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
}
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/xpu/detail/XPUHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ int XPUHooks::getGlobalIdxFromDevice(const at::Device& device) const {
return at::xpu::getGlobalIdxFromDevice(device.index());
}

Generator XPUHooks::getXPUGenerator(DeviceIndex device_index) const {
return make_generator<at::XPUGeneratorImpl>(device_index);
}

const Generator& XPUHooks::getDefaultXPUGenerator(
DeviceIndex device_index) const {
return at::xpu::detail::getDefaultXPUGenerator(device_index);
Expand All @@ -40,6 +44,10 @@ int XPUHooks::getNumGPUs() const {
return at::xpu::device_count();
}

DeviceIndex XPUHooks::current_device() const {
return c10::xpu::current_device();
}

void XPUHooks::deviceSynchronize(DeviceIndex device_index) const {
// Only the SYCL queues we have reserved will be synchronized, see Note
// [Synchronize Streams on Device].
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/xpu/detail/XPUHooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ struct XPUHooks : public at::XPUHooksInterface {
bool hasXPU() const override;
std::string showConfig() const override;
int getGlobalIdxFromDevice(const at::Device& device) const override;
Generator getXPUGenerator(DeviceIndex device_index = -1) const override;
const Generator& getDefaultXPUGenerator(
DeviceIndex device_index = -1) const override;
Device getDeviceFromPtr(void* data) const override;
int getNumGPUs() const override;
DeviceIndex current_device() const override;
void deviceSynchronize(DeviceIndex device_index) const override;
};

Expand Down
17 changes: 17 additions & 0 deletions docs/source/xpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ torch.xpu
stream
synchronize

Random Number Generator
-------------------------
.. autosummary::
:toctree: generated
:nosignatures:

get_rng_state
get_rng_state_all
initial_seed
manual_seed
manual_seed_all
seed
seed_all
set_rng_state
set_rng_state_all

Streams and events
------------------
.. autosummary::
Expand All @@ -37,4 +53,5 @@ Streams and events

.. This module needs to be documented. Adding here in the meantime
.. for tracking purposes
.. py:module:: torch.xpu.random
.. py:module:: torch.xpu.streams
18 changes: 18 additions & 0 deletions test/test_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ def test_events(self):
event.synchronize()
self.assertTrue(event.query())

def test_generator(self):
torch.manual_seed(2024)
g_state0 = torch.xpu.get_rng_state()
torch.manual_seed(1234)
g_state1 = torch.xpu.get_rng_state()
self.assertNotEqual(g_state0, g_state1)

torch.xpu.manual_seed(2024)
g_state2 = torch.xpu.get_rng_state()
self.assertEqual(g_state0, g_state2)

torch.xpu.set_rng_state(g_state1)
self.assertEqual(g_state1, torch.xpu.get_rng_state())

torch.manual_seed(1234)
torch.xpu.set_rng_state(g_state0)
self.assertEqual(2024, torch.xpu.initial_seed())


if __name__ == "__main__":
run_tests()
6 changes: 6 additions & 0 deletions torch/csrc/api/include/torch/xpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ size_t TORCH_API device_count();
/// Returns true if at least one XPU device is available.
bool TORCH_API is_available();

/// Sets the seed for the current GPU.
void TORCH_API manual_seed(uint64_t seed);

/// Sets the seed for all available GPUs.
void TORCH_API manual_seed_all(uint64_t seed);

/// Waits for all kernels in all streams on a XPU device to complete.
void TORCH_API synchronize(int64_t device_index);

Expand Down
25 changes: 25 additions & 0 deletions torch/csrc/api/src/xpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@ bool is_available() {
return xpu::device_count() > 0;
}

void manual_seed(uint64_t seed) {
if (is_available()) {
auto index = at::detail::getXPUHooks().current_device();
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(index);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
gen.set_current_seed(seed);
}
}
}

/// Sets the seed for all available GPUs.
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(i);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
gen.set_current_seed(seed);
}
}
}

void synchronize(int64_t device_index) {
TORCH_CHECK(is_available(), "No XPU are available");
at::detail::getXPUHooks().deviceSynchronize(
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/xpu/Module.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/xpu/XPUContext.h>
#include <ATen/xpu/XPUGeneratorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <c10/xpu/XPUFunctions.h>
Expand Down Expand Up @@ -275,6 +276,21 @@ static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) {
if (!m)
throw python_error();

auto set_module_attr = [&](const char* name, PyObject* v) {
if (PyObject_SetAttrString(m, name, v) < 0) {
throw python_error();
}
};

auto num_gpus = c10::xpu::device_count();
THPObjectPtr default_xpu_generators(
PyTuple_New(static_cast<Py_ssize_t>(num_gpus)));
for (const auto i : c10::irange(num_gpus)) {
const auto& gen = at::xpu::detail::getDefaultXPUGenerator(i);
auto* cast_gen = THPGenerator_initDefaultGenerator(gen);
PyTuple_SetItem(default_xpu_generators.get(), i, cast_gen);
}
set_module_attr("default_generators", default_xpu_generators.get());
bindGetDeviceProperties(m);

Py_RETURN_NONE;
Expand Down
8 changes: 8 additions & 0 deletions torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def manual_seed(seed) -> torch._C.Generator:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)

import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)

_seed_custom_device(seed)

return default_generator.manual_seed(seed)
Expand All @@ -62,6 +66,10 @@ def seed() -> int:
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)

import torch.xpu
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)

_seed_custom_device(seed)

return seed
Expand Down
130 changes: 127 additions & 3 deletions torch/xpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
:func:`is_available()` to determine if your system supports XPU.
"""
import threading
import traceback
from functools import lru_cache
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch._C
Expand All @@ -16,11 +17,40 @@
from .streams import Event, Stream

_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls: List[
Tuple[Callable[[], None], List[str]]
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]


class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []

def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

def get_calls(self) -> List:
return self.call_order


_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]


def _is_compiled() -> bool:
r"""Return true if compile with XPU support."""
return torch._C._has_xpu
Expand Down Expand Up @@ -65,6 +95,20 @@ def is_initialized():
return _initialized and not _is_in_bad_fork()


def _lazy_call(callable, **kwargs):
if is_initialized():
callable()
else:
global _lazy_seed_tracker
if kwargs.get("seed_all", False):
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
elif kwargs.get("seed", False):
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
else:
# Don't store the actual traceback to avoid memory cycle
_queued_calls.append((callable, traceback.format_stack()))


def init():
r"""Initialize PyTorch's XPU state.
This is a Python API about lazy initialization that avoids initializing
Expand All @@ -75,8 +119,8 @@ def init():


def _lazy_init():
global _initialized
if is_initialized():
global _initialized, _queued_calls
if is_initialized() or hasattr(_tls, "is_initializing"):
return
with _initialization_lock:
# This test was was protected via GIL. Double-check whether XPU has
Expand All @@ -93,6 +137,26 @@ def _lazy_init():
raise AssertionError("Torch not compiled with XPU enabled")
# This function inits XPU backend and detects bad fork processing.
torch._C._xpu_init()
# Some of the queued calls may reentrantly call _lazy_init(); We need to
# just return without initializing in that case.
_tls.is_initializing = True

for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)

try:
for queued_call, orig_traceback in _queued_calls:
try:
queued_call()
except Exception as e:
msg = (
f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
)
raise Exception(msg) from e
finally:
delattr(_tls, "is_initializing")
_initialized = True


Expand Down Expand Up @@ -357,25 +421,85 @@ def empty_cache() -> None:
torch._C._xpu_emptyCache()


def _get_generator(device: torch.device) -> torch._C.Generator:
r"""Return the XPU Generator object for the given device.
Args:
device (torch.device): selected device.
"""
idx = device.index
if idx is None:
idx = current_device()
return torch.xpu.default_generators[idx]


def _set_rng_state_offset(
offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
r"""Set the random number generator state offset of the specified GPU.
Args:
offset (int): The desired offset
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
"""
final_device = _get_device(device)

def cb():
default_generator = _get_generator(final_device)
default_generator.set_offset(offset)

_lazy_call(cb)


def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
r"""Return the random number generator state offset of the specified GPU.
Args:
device (torch.device or int, optional): The device to return the RNG state offset of.
Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
.. warning::
This function eagerly initializes XPU.
"""
_lazy_init()
final_device = _get_device(device)
default_generator = _get_generator(final_device)
return default_generator.get_offset()


from .random import * # noqa: F403


__all__ = [
"Event",
"Stream",
"StreamContext",
"current_device",
"current_stream",
"default_generators",
"device",
"device_of",
"device_count",
"empty_cache",
"get_device_capability",
"get_device_name",
"get_device_properties",
"get_rng_state",
"get_rng_state_all",
"get_stream",
"init",
"initial_seed",
"is_available",
"is_bf16_supported",
"is_initialized",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"set_device",
"set_rng_state",
"set_rng_state_all",
"set_stream",
"stream",
"streams",
Expand Down

0 comments on commit 12995a5

Please sign in to comment.