#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Core methods for building closures in torch and interfacing with numpy."""
from __future__ import annotations
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
import torch
from botorch.optim.utils import (
_handle_numerical_errors,
get_tensors_as_ndarray_1d,
set_tensors_from_ndarray_1d,
)
from botorch.optim.utils.numpy_utils import as_ndarray
from botorch.utils.context_managers import zero_grad_ctx
from numpy import float64 as np_float64, full as np_full, ndarray, zeros as np_zeros
from torch import Tensor
class ForwardBackwardClosure:
    r"""Wrapper for fused forward and backward closures.

    Calling an instance runs `forward`, optionally reduces its output, invokes
    `backward` on the (reduced) value, and returns the value together with the
    `grad` attributes of `parameters`.
    """

    def __init__(
        self,
        forward: Callable[[], Tensor],
        parameters: Dict[str, Tensor],
        backward: Callable[[Tensor], None] = Tensor.backward,
        reducer: Optional[Callable[[Tensor], Tensor]] = torch.sum,
        callback: Optional[Callable[[Tensor, Sequence[Optional[Tensor]]], None]] = None,
        context_manager: Optional[Callable] = None,
    ) -> None:
        r"""Initializes a ForwardBackwardClosure instance.

        Args:
            forward: Callable that returns a tensor.
            parameters: A dictionary of tensors whose `grad` fields are to be returned.
            backward: Callable that takes the (reduced) output of `forward` and sets the
                `grad` attributes of tensors in `parameters`.
            reducer: Optional callable used to reduce the output of the forward pass.
                When `None`, the forward output is passed to `backward` unreduced.
            callback: Optional callable that takes the reduced output of `forward` and
                the gradients of `parameters` as positional arguments.
            context_manager: A ContextManager used to wrap each forward-backward call.
                When passed as `None`, `context_manager` defaults to a `zero_grad_ctx`
                that zeroes the gradients of `parameters` upon entry.
        """
        if context_manager is None:
            context_manager = partial(zero_grad_ctx, parameters)

        self.forward = forward
        self.backward = backward
        self.parameters = parameters
        self.reducer = reducer
        self.callback = callback
        self.context_manager = context_manager

    def __call__(self, **kwargs: Any) -> Tuple[Tensor, Tuple[Optional[Tensor], ...]]:
        r"""Runs a fused forward-backward pass.

        Args:
            kwargs: Keyword arguments passed through to `forward`.

        Returns:
            A two-tuple containing the (reduced) output of `forward` and the
            gradients of `parameters` (entries may be `None` for parameters
            whose `grad` fields were not populated by `backward`).
        """
        with self.context_manager():
            values = self.forward(**kwargs)
            value = values if self.reducer is None else self.reducer(values)
            self.backward(value)

            # Gradient order follows the insertion order of `parameters`.
            grads = tuple(param.grad for param in self.parameters.values())
            if self.callback:
                self.callback(value, grads)

            return value, grads
class NdarrayOptimizationClosure:
    r"""Adds stateful behavior and a numpy.ndarray-typed API to a closure with an
    expected return type Tuple[Tensor, Union[Tensor, Sequence[Optional[Tensor]]]]."""

    def __init__(
        self,
        closure: Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]],
        parameters: Dict[str, Tensor],
        as_array: Callable[[Tensor], ndarray] = None,  # pyre-ignore [9]
        as_tensor: Callable[[ndarray], Tensor] = torch.as_tensor,
        get_state: Callable[[], ndarray] = None,  # pyre-ignore [9]
        set_state: Callable[[ndarray], None] = None,  # pyre-ignore [9]
        fill_value: float = 0.0,
        persistent: bool = True,
    ) -> None:
        r"""Initializes a NdarrayOptimizationClosure instance.

        Args:
            closure: A ForwardBackwardClosure instance.
            parameters: A dictionary of tensors representing the closure's state.
                Expected to correspond with the first `len(parameters)` optional
                gradient tensors returned by `closure`.
            as_array: Callable used to convert tensors to ndarrays. When passed as
                `None`, defaults to `as_ndarray` with float64 coercion.
            as_tensor: Callable used to convert ndarrays to tensors.
            get_state: Callable that returns the closure's state as an ndarray. When
                passed as `None`, defaults to calling `get_tensors_as_ndarray_1d`
                on `closure.parameters` while passing `as_array` (if given by the user).
            set_state: Callable that takes a 1-dimensional ndarray and sets the
                closure's state. When passed as `None`, `set_state` defaults to
                calling `set_tensors_from_ndarray_1d` with `closure.parameters` and
                a given ndarray while passing `as_tensor`.
            fill_value: Fill value for parameters whose gradients are None. In most
                cases, `fill_value` should either be zero or NaN.
            persistent: Boolean specifying whether an ndarray should be retained
                as a persistent buffer for gradients.
        """
        if get_state is None:
            # Note: Numpy supports copying data between ndarrays with different dtypes.
            # Hence, our default behavior need not coerce the ndarray representations
            # of tensors in `parameters` to float64 when copying over data.
            _as_array = as_ndarray if as_array is None else as_array
            get_state = partial(
                get_tensors_as_ndarray_1d,
                tensors=parameters,
                dtype=np_float64,
                as_array=_as_array,
            )

        if as_array is None:  # per the note, do this after resolving `get_state`
            as_array = partial(as_ndarray, dtype=np_float64)

        if set_state is None:
            set_state = partial(
                set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor
            )

        self.closure = closure
        self.parameters = parameters

        # Bug fix: store the resolved `as_array` (user-provided, or the float64
        # coercing default built above) rather than the raw `as_ndarray` helper,
        # which would silently ignore a user-supplied converter.
        self.as_array = as_array
        self.as_tensor = as_tensor
        self._get_state = get_state
        self._set_state = set_state

        self.fill_value = fill_value
        self.persistent = persistent
        self._gradient_ndarray: Optional[ndarray] = None

    def __call__(
        self, state: Optional[ndarray] = None, **kwargs: Any
    ) -> Tuple[ndarray, ndarray]:
        r"""Evaluates the closure, optionally setting its state first.

        Args:
            state: Optional 1-dimensional ndarray used to update the closure's
                state prior to evaluation.
            kwargs: Keyword arguments passed through to the wrapped closure.

        Returns:
            A two-tuple of the closure's value and concatenated gradients as
            ndarrays. Numerical errors raised as `RuntimeError` are converted
            to sentinel values by `_handle_numerical_errors`.
        """
        if state is not None:
            self.state = state

        try:
            value_tensor, grad_tensors = self.closure(**kwargs)
            value = self.as_array(value_tensor)
            grads = self._get_gradient_ndarray(fill_value=self.fill_value)
            index = 0
            for param, grad in zip(self.parameters.values(), grad_tensors):
                size = param.numel()
                if grad is not None:
                    grads[index : index + size] = self.as_array(grad.view(-1))
                # Always advance so `None` gradients leave a fill_value-filled
                # slot rather than shifting subsequent gradients.
                index += size
        except RuntimeError as e:
            value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)

        return value, grads

    @property
    def state(self) -> ndarray:
        r"""The closure's state as a 1-dimensional ndarray."""
        return self._get_state()

    @state.setter
    def state(self, state: ndarray) -> None:
        self._set_state(state)

    def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
        r"""Returns a gradient buffer, reusing the persistent one when enabled.

        Args:
            fill_value: Optional value used to (re)initialize the buffer.

        Returns:
            A 1-dimensional float64 ndarray sized to the total number of
            elements across `parameters`.
        """
        if self.persistent and self._gradient_ndarray is not None:
            if fill_value is not None:
                self._gradient_ndarray.fill(fill_value)
            return self._gradient_ndarray

        size = sum(param.numel() for param in self.parameters.values())
        array = (
            np_zeros(size, dtype=np_float64)
            if fill_value is None or fill_value == 0.0
            else np_full(size, fill_value, dtype=np_float64)
        )
        if self.persistent:
            self._gradient_ndarray = array

        return array