/
device.py
64 lines (49 loc) · 1.4 KB
/
device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
Device handling: a module-wide default device (get/set, also as a scoped context manager)
and copying tensors to a device.
"""
from __future__ import annotations
from typing import Optional
from contextlib import contextmanager
from returnn.tensor import Tensor
# Public API of this module.
__all__ = ["copy_to_device", "get_default_device", "set_default_device", "set_default_device_ctx"]
# Module-wide default device, returned by get_default_device() and changed by
# set_default_device() / set_default_device_ctx().
# None means no default: copy_to_device() without an explicit device then returns its input unchanged.
_default_device: Optional[str] = None
def copy_to_device(x: Tensor, device: Optional[str] = None) -> Tensor:
    """
    Copy tensor to device.

    :param x: tensor to place on the target device
    :param device: target device; if empty/None, :func:`get_default_device` is used instead
    :return: ``x`` itself when there is no target device, when ``x`` has no raw tensor,
        or when ``x`` is already on the target device; otherwise the backend's copy of ``x`` on that device
    """
    target = device or get_default_device()
    # Skip the copy when there is nowhere to go, nothing to copy, or it is already there.
    if not target or x.raw_tensor is None or x.device == target:
        return x
    # noinspection PyProtectedMember
    return x._raw_backend.copy_to_device(x, target)
def get_default_device() -> Optional[str]:
    """
    Look up the module-wide default device.

    :return: default device, where to put new tensors (via random number generators, constant, range_over_dim, etc),
        or None when no default has been set
    """
    current = _default_device
    return current
def set_default_device(device: Optional[str]):
    """
    Overwrite the module-wide default device. The previous value is not restored;
    use :func:`set_default_device_ctx` for a scoped change.

    :param device: see :func:`get_default_device`; None means no default device
    """
    global _default_device
    _default_device = device
@contextmanager
def set_default_device_ctx(device: Optional[str]):
    """
    Temporarily set the module-wide default device for the duration of a ``with`` block.
    The previous default is restored on exit, including when the body raises.

    :param device: see :func:`get_default_device`
    """
    global _default_device
    saved = _default_device
    _default_device = device
    try:
        yield
    finally:
        # Always put back whatever was active before entering.
        _default_device = saved