/
tensor_utils.py
50 lines (41 loc) · 1.25 KB
/
tensor_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Tensor utils.
"""
from __future__ import annotations
import numpy
import torch
from returnn.tensor import Tensor, TensorDict
def tensor_dict_numpy_to_torch_(x: TensorDict):
    """
    Apply :func:`tensor_numpy_to_torch_` (which works in place) to each value of ``x.data``.

    :param x: tensor dict whose values get converted in place
    """
    tensors = x.data.values()
    for tensor in tensors:
        tensor_numpy_to_torch_(tensor)
def tensor_numpy_to_torch_(x: Tensor[numpy.ndarray]):
    """
    torch.from_numpy() on Tensor, including dims. Works in place.

    If ``x.raw_tensor`` is a :class:`numpy.ndarray`, it is replaced by
    ``torch.from_numpy(...)`` of it. In the common case this shares memory with the
    array. Arrays with negative strides (e.g. ``arr[::-1]``) are copied first,
    because ``torch.from_numpy`` rejects them.
    A raw tensor that is ``None`` or already a :class:`torch.Tensor` is left as is.
    Afterwards, ``dim.transform_tensors`` is called with this function for every dim
    of ``x``. This happens regardless of the raw tensor's type.

    :param x: tensor whose raw tensor is None, a numpy.ndarray or a torch.Tensor.
        Modified in place.
    """
    raw = x.raw_tensor
    if raw is not None and not isinstance(raw, torch.Tensor):
        assert isinstance(raw, numpy.ndarray), f"unexpected raw tensor type {type(raw)}"
        if any(stride < 0 for stride in raw.strides):
            # torch.from_numpy raises ValueError for negative strides.
            # A C-ordered copy only has non-negative strides (and keeps the shape, also 0-d).
            # Copy only in this case, so that the usual path stays zero-copy.
            raw = raw.copy()
        x.raw_tensor = torch.from_numpy(raw)
    for dim in x.dims:
        # Presumably converts tensors attached to the dim (e.g. dynamic sizes);
        # transform_tensors itself is not visible here.
        dim.transform_tensors(tensor_numpy_to_torch_)
def tensor_dict_torch_to_numpy_(x: TensorDict):
    """
    Apply :func:`tensor_torch_to_numpy_` (which works in place) to each value of ``x.data``.

    :param x: tensor dict whose values get converted in place
    """
    tensors = x.data.values()
    for tensor in tensors:
        tensor_torch_to_numpy_(tensor)
def tensor_torch_to_numpy_(x: Tensor[torch.Tensor]):
    """
    .numpy() on Tensor, including dims. Works in place.

    If ``x.raw_tensor`` is a :class:`torch.Tensor`, it is detached, moved to CPU,
    lazy conjugate/negative views are materialized (``.numpy()`` refuses those),
    and it is replaced by the resulting numpy array.
    A raw tensor that is ``None`` or already a :class:`numpy.ndarray` is left as is.
    Afterwards, ``dim.transform_tensors`` is called with this function for every dim
    of ``x``. This happens regardless of the raw tensor's type.

    :param x: tensor whose raw tensor is None, a torch.Tensor or a numpy.ndarray.
        Modified in place.
    """
    raw = x.raw_tensor
    if raw is not None and not isinstance(raw, numpy.ndarray):
        assert isinstance(raw, torch.Tensor), f"unexpected raw tensor type {type(raw)}"
        raw = raw.detach().cpu()
        # .numpy() raises RuntimeError when the conjugate or negative bit is set
        # (e.g. the result of complex_tensor.conj()). resolve_conj()/resolve_neg()
        # return the input itself when the bit is not set, so this is free otherwise.
        raw = raw.resolve_conj().resolve_neg()
        x.raw_tensor = raw.numpy()
    for dim in x.dims:
        # Presumably converts tensors attached to the dim (e.g. dynamic sizes);
        # transform_tensors itself is not visible here.
        dim.transform_tensors(tensor_torch_to_numpy_)