"""
VariableLayer and related
"""
from __future__ import annotations
from typing import Optional, Sequence, List, TypeVar
import contextlib
import tensorflow as tf
import returnn.tf.compat as tf_compat
import returnn.tf.util.basic as tf_util
from returnn.tensor import Tensor, Dim
from .base import LayerBase
T = TypeVar("T")


class VariableLayer(LayerBase):
    """
    Represents a variable. Can add batch/time dimension if wanted. Can be trainable.
    See defaults.
    """

    layer_class = "variable"

    def __init__(
        self,
        shape,
        dtype="float32",
        add_batch_axis=False,
        add_time_axis=False,
        trainable=True,
        saveable=True,
        non_critical_for_restore=False,
        init=None,
        init_by_layer=None,
        param_name=None,
        **kwargs,
    ):
        """
        :param tuple[int|Dim]|list[int|Dim] shape:
        :param str dtype:
        :param bool add_batch_axis:
        :param bool add_time_axis:
        :param bool trainable: whether it is updated by grad descent
        :param bool saveable: whether it is stored in the checkpoint
        :param bool non_critical_for_restore: if True, and it cannot be found in a checkpoint, it will not be an error
        :param str|float|int|None init: see :func:`returnn.tf.util.basic.get_initializer`. 0 by default.
            Alternatively, you can also use option `init_by_layer`.
        :param LayerBase|None init_by_layer:
        :param str|None param_name: self.name (layer name) by default
        """
        shape  # noqa  # used in get_out_data_from_opts
        super(VariableLayer, self).__init__(trainable=trainable, **kwargs)
        assert not self.sources, "%s: does not expect any sources" % self
        self.init_by_layer = init_by_layer
        dim_tags = list(self.output.dim_tags)
        if add_batch_axis:
            assert dim_tags[0].is_batch_dim()
            dim_tags = dim_tags[1:]
        if add_time_axis:
            assert dim_tags[0].dimension == 1
            dim_tags = dim_tags[1:]
        shape_ = [d.dimension for d in dim_tags]
        assert all(shape_), self.output  # all static
        with self.var_creation_scope():
            if init_by_layer is None:
                if init is None:
                    init = 0
                initializer = tf_util.get_initializer(
                    init, dtype=dtype, seed=self.network.random.randint(2**31), eval_local_ns={"layer": self}
                )
            else:
                assert init_by_layer is not None
                out_data_base = Tensor(name=self.output.name, dim_tags=dim_tags, dtype=dtype)
                initializer = init_by_layer.output.copy_compatible_to(out_data_base).placeholder
                shape_ = None  # get_variable requires shape to be not defined when the initializer is another tensor
            self.var = self.add_param(
                tf_compat.v1.get_variable(
                    name=param_name or self.name,
                    shape=shape_,
                    dtype=dtype,
                    initializer=initializer,
                    trainable=trainable,
                ),
                axes_split_info=[d.axis_split_info() for d in dim_tags],
                trainable=trainable,
                saveable=saveable,
                non_critical_for_restore=non_critical_for_restore,
            )
            out = self.var
            if add_time_axis:
                out = tf.expand_dims(out, axis=0)
            if add_batch_axis:
                # Unbroadcast to not confuse some other layers
                batch_dim = self.output.get_batch_dim()
                out = tf_util.expand_dims_unbroadcast(out, axis=0, dim=batch_dim)
        self.output.placeholder = out

    def get_dep_layers(self):
        """
        :rtype: list[LayerBase]
        """
        deps = super(VariableLayer, self).get_dep_layers()
        if self.init_by_layer:
            deps.append(self.init_by_layer)
        return deps

    @classmethod
    def transform_config_dict(cls, d, network, get_layer):
        """
        :param dict[str] d: will modify inplace
        :param returnn.tf.network.TFNetwork network:
        :param ((str) -> LayerBase) get_layer: function to get or construct another layer
        """
        # Overwrite default behavior for default sources.
        # Here: none by default.
        d.setdefault("from", [])
        super(VariableLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
        if d.get("init_by_layer", None):
            d["init_by_layer"] = get_layer(d["init_by_layer"])

    @classmethod
    def get_out_data_from_opts(
        cls, name, network, shape, dtype="float32", add_batch_axis=False, add_time_axis=False, **kwargs
    ):
        """
        :param str name:
        :param returnn.tf.network.TFNetwork network:
        :param tuple[int|Dim]|list[int|Dim] shape:
        :param str dtype:
        :param bool add_batch_axis:
        :param bool add_time_axis:
        :rtype: Tensor
        """
        assert isinstance(shape, (list, tuple))
        assert len(shape) == 0 or all(shape)
        dim_tags = []
        for i, d in enumerate(shape):
            if isinstance(d, Dim):
                assert d.dimension is not None, "%r: need static dims but got %r" % (name, d)
            elif isinstance(d, int):
                d = Dim(
                    kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
                    description="%s:static:%i" % (name, i),
                    auto_generated=True,
                    dimension=d,
                )
            else:
                raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
            dim_tags.append(d)
        if add_time_axis:
            dim_tags.insert(
                0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True)
            )
        if add_batch_axis:
            from returnn.tensor.dim import batch_dim

            dim_tags.insert(0, batch_dim)
        return Tensor(
            name="%s_output" % name,
            dim_tags=dim_tags,
            dtype=dtype,
            batch=network.get_global_batch_info() if add_batch_axis else None,
        )
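

# Illustrative sketch (not used by this module): a "variable" layer as it could appear
# in a RETURNN network config dict. The option names mirror VariableLayer.__init__ above;
# the layer name "state" and the fragment structure are placeholders for illustration.
_example_variable_net_fragment = {
    "state": {"class": "variable", "shape": [10], "dtype": "float32", "init": 0, "trainable": False},
}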


class VariableAssignLayer(LayerBase):
    """
    Assigns a new value to a variable.
    """

    layer_class = "variable_assign"

    def __init__(
        self,
        var: LayerBase,
        value: LayerBase,
        control_dependencies: Optional[Sequence[LayerBase]] = None,
        op: str = "assign",
        **kwargs,
    ):
        """
        :param var:
        :param value:
        :param control_dependencies:
        :param op: "assign" or "add"
        """
        super().__init__(**kwargs)
        self.var = var
        self.value = value
        self.control_dependencies = list(control_dependencies) if control_dependencies else []
        deps = [src.output.placeholder.op for src in self.control_dependencies]
        while not isinstance(var, VariableLayer):
            if isinstance(var, VariableAssignLayer):
                deps.append(var.output.placeholder.op)
                var = var.var
            elif isinstance(var, VariableReadLayer):
                deps.append(var.output.placeholder.op)
                var = var.var
            else:
                raise TypeError(f"{self}: invalid var {var!r}")
        assert isinstance(var, VariableLayer), f"{self}: var must be a VariableLayer, got {var}"
        self.tf_var: tf.Variable = var.var
        assert isinstance(self.tf_var, tf.Variable), f"{self}: var must be a tf.Variable, got {self.tf_var}"
        value_data = value.output.copy_compatible_to(self.var.output)
        with tf.control_dependencies(deps) if deps else contextlib.nullcontext():
            if op == "assign":
                op_ = self.tf_var.assign(value_data.placeholder, read_value=False)
            elif op == "add":
                op_ = self.tf_var.assign_add(value_data.placeholder, read_value=False)
            else:
                raise ValueError(f"{self}: invalid op {op!r}")
        # op_ is only defined in graph mode; in eager mode, it is not relevant.
        with tf.control_dependencies([op_]) if op_ is not None else contextlib.nullcontext():
            self.output.placeholder = tf.zeros((), dtype="int32")

    def get_dep_layers(self) -> List[LayerBase]:
        """deps"""
        return super().get_dep_layers() + [self.var, self.value] + self.control_dependencies

    @classmethod
    def transform_config_dict(cls, d, network, get_layer):
        """transform"""
        d.setdefault("from", [])
        super().transform_config_dict(d, network=network, get_layer=get_layer)
        d["var"] = get_layer(d["var"])
        d["value"] = get_layer(d["value"])
        if d.get("control_dependencies"):
            d["control_dependencies"] = [get_layer(layer) for layer in d["control_dependencies"]]

    @classmethod
    def get_out_data_from_opts(cls, name: str, var: LayerBase, **kwargs):
        """out"""
        return Tensor(name, dims=(), dtype="int32")  # dummy, will be just the op
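

# Illustrative sketch (not used by this module): updating the variable through a
# "variable_assign" layer. "var" and "value" are layer names resolved via
# transform_config_dict; op is "assign" (default) or "add". Layer names, including
# "new_state", are placeholders for illustration.
_example_assign_net_fragment = {
    "state": {"class": "variable", "shape": [10], "trainable": False},
    "update_state": {"class": "variable_assign", "var": "state", "value": "new_state", "op": "add"},
}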


class VariableReadLayer(LayerBase):
    """
    Read a variable (currently expected from VariableLayer).
    Supports control dependencies to exactly specify when it should be read.
    """

    layer_class = "variable_read"

    def __init__(self, var: LayerBase, control_dependencies: Optional[Sequence[LayerBase]] = None, **kwargs):
        """
        :param var: e.g. VariableLayer
        :param control_dependencies: to control what ops must run before the var is read (e.g. assign ops)
        """
        super().__init__(**kwargs)
        self.var = var
        self.control_dependencies = list(control_dependencies) if control_dependencies else []
        deps = [src.output.placeholder.op for src in self.control_dependencies]
        while not isinstance(var, VariableLayer):
            if isinstance(var, VariableAssignLayer):
                deps.append(var.output.placeholder.op)
                var = var.var
            elif isinstance(var, VariableReadLayer):
                deps.append(var.output.placeholder.op)
                var = var.var
            else:
                raise TypeError(f"{self}: invalid var {var!r}")
        assert isinstance(var, VariableLayer), f"{self}: var must be a VariableLayer, got {var}"
        self.tf_var: tf.Variable = var.var
        assert isinstance(self.tf_var, tf.Variable), f"{self}: var must be a tf.Variable, got {self.tf_var}"
        with tf.control_dependencies(deps) if deps else contextlib.nullcontext():
            self.output.placeholder = self.tf_var.read_value()

    def get_dep_layers(self) -> List[LayerBase]:
        """deps"""
        return super().get_dep_layers() + [self.var] + self.control_dependencies

    @classmethod
    def transform_config_dict(cls, d, network, get_layer):
        """transform"""
        d.setdefault("from", [])
        super().transform_config_dict(d, network=network, get_layer=get_layer)
        d["var"] = get_layer(d["var"])
        if d.get("control_dependencies"):
            d["control_dependencies"] = [get_layer(layer) for layer in d["control_dependencies"]]

    @classmethod
    def get_out_data_from_opts(cls, name: str, var: LayerBase, **kwargs):
        """out"""
        return var.output.copy_template(name="%s_output" % name)
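

# Illustrative sketch (not used by this module): reading the variable only after the
# assignment has run, via "control_dependencies" (see VariableReadLayer.__init__).
# Layer names are placeholders; "new_state" would be defined elsewhere in a real config.
_example_read_net_fragment = {
    "state": {"class": "variable", "shape": [10], "trainable": False},
    "update_state": {"class": "variable_assign", "var": "state", "value": "new_state"},
    "read_state": {"class": "variable_read", "var": "state", "control_dependencies": ["update_state"]},
}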