-
Notifications
You must be signed in to change notification settings - Fork 14
/
utils.py
59 lines (46 loc) · 1.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import List, Union, Optional
import torch
def padded_tensor(
    items: List[Union[List[int], torch.LongTensor]],
    pad_idx: int = 0,
    pad_tail: bool = True,
    max_len: Optional[int] = None,
    debug: bool = False,
    device: torch.device = torch.device('cpu'),
    use_amp: bool = False
) -> torch.LongTensor:
    """Create a padded matrix from an uneven list of lists.

    Returns padded matrix.
    Matrix is right-padded (filled to the right) by default, but can be
    left padded if ``pad_tail`` is False.

    :param items: list of sequences (python lists of ints or 1-D LongTensors)
    :param pad_idx: the value to use for padding
    :param pad_tail: if True pad on the right; otherwise pad on the left
    :param max_len: minimum width of the output; only applied when
        ``debug`` is True (see NOTE below). If None, width is the maximum
        item length.
    :param debug: enables the ``max_len`` widening above
    :param device: device on which the output tensor is allocated
    :param use_amp: if True, round the time dimension up to a multiple
        of 8 so tensor-core kernels are used efficiently under AMP/fp16
    :returns: padded tensor of shape (len(items), t)
    :rtype: Tensor[int64]
    """
    # number of items
    n = len(items)
    # length of each item
    lens: List[int] = [len(item) for item in items]
    # max in time dimension
    t = max(lens)
    # if input tensors are empty, we should expand to nulls
    t = max(t, 1)
    # NOTE(review): max_len is only honored in debug mode — this mirrors
    # the original behavior; confirm whether it should apply unconditionally.
    if debug and max_len is not None:
        t = max(t, max_len)
    if use_amp:
        # Round UP to the next multiple of 8. Rounding down (t // 8 * 8)
        # would truncate t below the longest item length (e.g. t=3 -> 0)
        # and make the row assignments below fail.
        t = (t + 7) // 8 * 8
    output = torch.full((n, t), fill_value=pad_idx, dtype=torch.long, device=device)
    for i, (item, length) in enumerate(zip(items, lens)):
        if length == 0:
            # nothing to copy; row stays entirely pad_idx
            continue
        if not isinstance(item, torch.Tensor):
            item = torch.tensor(item, dtype=torch.long, device=device)
        if pad_tail:
            output[i, :length] = item
        else:
            output[i, t - length:] = item
    return output