forked from chaitjo/learning-paradigms-for-tsp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensor_functions.py
34 lines (29 loc) · 1.57 KB
/
tensor_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
def compute_in_batches(f, calc_batch_size, *args, n=None):
    """
    Evaluate a memory-heavy function f(*args) chunk by chunk along the batch dimension.

    :param f: function to compute; takes only tensors (or sliceable objects) and
        returns a tensor or a tuple of tensors
    :param calc_batch_size: maximum batch size per call to f
    :param args: arguments with equally sized first (batch) dimension
    :param n: total number of elements; defaults to args[0].size(0)
    :return: f(*args) — one or multiple tensors with equally sized first batch dimension
    """
    total = args[0].size(0) if n is None else n
    num_chunks = -(-total // calc_batch_size)  # ceiling division

    # Fast path: everything fits in a single batch, call f directly
    if num_chunks == 1:
        return f(*args)

    def _sliced_args(idx):
        # Slicing (rather than torch.chunk) keeps this working for any
        # class that supports __getitem__ with slices
        lo = idx * calc_batch_size
        return (a[lo:lo + calc_batch_size] for a in args)

    results = [f(*_sliced_args(i)) for i in range(num_chunks)]

    def _cat_or_none(parts, dim=0):
        # f is allowed to return None for an output; propagate it through
        if parts[0] is None:
            assert all(p is None for p in parts)
            return None
        return torch.cat(parts, dim)

    # f may return a tuple of outputs: concatenate each position independently
    if isinstance(results[0], tuple):
        return tuple(_cat_or_none(group) for group in zip(*results))
    return _cat_or_none(results)