/
allocation_strategy.py
211 lines (182 loc) · 8.58 KB
/
allocation_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Live variable analysis.
A variable is "dead" at some point if the compiler can find a proof that no
future instruction will read the value before that value is overwritten; "live"
otherwise.
This module uses the liveness analysis implemented in liveness.py to select
inexpensive variable allocation strategies for programs in the IR defined in
instructions.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import six
from tensorflow_probability.python.experimental.auto_batching import instructions as inst
from tensorflow_probability.python.experimental.auto_batching import liveness
__all__ = [
'optimize'
]
def optimize(program):
  """Optimizes a `Program`'s variable allocation strategy.

  The variable allocation strategies determine how much memory the `Program`
  consumes, and how costly its memory access operations are (see
  `instructions.VariableAllocation`).  In general, a variable holding data with
  a longer or more complex lifetime will need a more expensive storage strategy.
  This analysis examines variables' liveness and opportunistically selects
  inexpensive sound allocation strategies.

  Specifically, the algorithm is to:
  - Run liveness analysis to determine the lifespan of each variable.
  - Assume optimistically that no variable needs to be stored at all
    (`instructions.VariableAllocation.NULL`).
  - Traverse the instructions and pattern-match conditions that require
    some storage:
    - If a variable is read by an instruction, it must be at least
      `instructions.VariableAllocation.TEMPORARY`.
    - If a variable is live out of some block (i.e., crosses a block boundary),
      it must be at least `instructions.VariableAllocation.REGISTER`.  This is
      because temporaries do not appear in the loop state in `execute`.
    - If a variable is alive across a call to an autobatched `Function` that
      writes it (directly, or indirectly through the functions it calls), it
      must be `instructions.VariableAllocation.FULL`, because that `Function`
      may push values to it that must not overwrite the value present at the
      call point.  A variable that is alive across a call whose callee never
      writes it only needs to be at least
      `instructions.VariableAllocation.REGISTER`.
  - The program counter is always treated as read, and is made
    `instructions.VariableAllocation.FULL` if any function call is present.

  Variables that remain `NULL` after the traversal are reported with a warning
  log message.

  Args:
    program: `Program` to optimize.

  Returns:
    program: A newly allocated `Program` with the same semantics but possibly
      different allocation strategies for some (or all) variables.  Each new
      strategy may be more efficient than the input `Program`'s allocation
      strategy for that variable (if the analysis can prove it safe), but will
      not be less efficient.
  """
  null = inst.VariableAllocation.NULL
  # Start from the optimistic assumption that nothing needs storage.
  alloc = {var: null for var in program.var_defs}
  # The program counter is always read
  _variable_is_read(inst.pc_var, alloc)
  _optimize_1(program.graph, inst.pattern_flatten(program.vars_out), alloc)
  for varname in program.vars_in:
    # Because there is a while_loop iteration between the inputs and the first
    # block.
    _variable_crosses_block_boundary(varname, alloc)
  for func in program.functions:
    _optimize_1(func.graph, inst.pattern_flatten(func.vars_out), alloc)
    for varname in func.vars_in:
      _variable_crosses_block_boundary(varname, alloc)
  # Anything still NULL was never observed being read anywhere.
  null_vars = [k for k, v in six.iteritems(alloc) if v is null]
  if null_vars:
    logging.warning('Found variables with NULL allocation.  These are written '
                    'but never read: %s', null_vars)
  return program.replace(var_alloc=alloc)
def _vars_read_by(op):
  """Returns the pattern of variables that the given instruction reads.

  Args:
    op: An instruction from instructions.py.

  Returns:
    `op.vars_in` for a `FunctionCallOp` or `PrimOp`, `op.cond_var` for a
    `BranchOp`, and an empty list for anything else.
  """
  takes_inputs = isinstance(op, (inst.FunctionCallOp, inst.PrimOp))
  if takes_inputs:
    return op.vars_in
  return op.cond_var if isinstance(op, inst.BranchOp) else []
def _variable_is_read(varname, alloc):
  """Records that `varname` is read, raising `NULL` to `TEMPORARY`.

  Any strategy other than `NULL` is left as it is.

  Args:
    varname: Name of the variable being read; must be a key of `alloc`.
    alloc: Dictionary of allocation strategies so far; mutated in place.
  """
  current = alloc[varname]
  if current is not inst.VariableAllocation.NULL:
    return
  alloc[varname] = inst.VariableAllocation.TEMPORARY
def _variable_crosses_block_boundary(varname, alloc):
  """Records that `varname` is live across a block boundary.

  Raises a `NULL` or `TEMPORARY` strategy to `REGISTER`; any other strategy is
  left as it is.

  Args:
    varname: Name of the variable; must be a key of `alloc`.
    alloc: Dictionary of allocation strategies so far; mutated in place.
  """
  current = alloc[varname]
  is_null = current is inst.VariableAllocation.NULL
  if is_null or current is inst.VariableAllocation.TEMPORARY:
    alloc[varname] = inst.VariableAllocation.REGISTER
def _variable_crosses_function_call_boundary(varname, alloc):
  """Records that `varname` needs the `FULL` allocation strategy.

  The strategy is set unconditionally, whatever was recorded before.

  Args:
    varname: Name of the variable.
    alloc: Dictionary of allocation strategies so far; mutated in place.
  """
  full = inst.VariableAllocation.FULL
  alloc[varname] = full
def _optimize_1(graph, live_out, alloc):
  """Optimize the variable allocation strategy for one CFG.

  Args:
    graph: `ControlFlowGraph` to traverse.
    live_out: Iterable of `str` variable names that are live out of this graph
      (i.e., returned by the function this graph represents).
    alloc: Dictionary of allocation strategy deductions made so far.
      This is mutated; but no variable is moved to a cheaper strategy.
  """
  liveness_map = liveness.liveness_analysis(graph, set(live_out))
  if graph.exit_index() > 0:
    _variable_crosses_block_boundary(inst.pc_var, alloc)
  # Transitive write sets per callee.  The same `Function` is often called from
  # several sites, and each `_indirectly_writes` call walks the call graph, so
  # compute each callee's answer only once per traversal.
  writes_by_callee = {}
  for i in range(graph.exit_index()):
    block = graph.block(i)
    # `live_after_op` is the set of variables live just after `op`.  It is
    # deliberately not named `live_out`, to avoid shadowing the parameter.
    for op, live_after_op in zip(
        block.instructions, liveness_map[block].live_out_instructions):
      for varname in inst.pattern_traverse(_vars_read_by(op)):
        _variable_is_read(varname, alloc)
      if isinstance(op, inst.FunctionCallOp):
        callee = op.function
        if callee not in writes_by_callee:
          writes_by_callee[callee] = _indirectly_writes(callee)
        callee_writes = writes_by_callee[callee]
        # Variables written by this very call are not carried across it.
        for varname in live_after_op - set(inst.pattern_flatten(op.vars_out)):
          # A variable only needs the conservative storage strategy if it
          # crosses a call to some function that writes it (e.g., a recursive
          # self-call).
          if varname in callee_writes:
            _variable_crosses_function_call_boundary(varname, alloc)
          else:
            _variable_crosses_block_boundary(varname, alloc)
        # TODO(axch): Actually, the PC only needs a stack at this site if this
        # is not a tail call.
        _variable_crosses_function_call_boundary(inst.pc_var, alloc)
    if isinstance(block.terminator, inst.BranchOp):
      # TODO(axch): Actually, being read by BranchOp only implies
      # _variable_is_read.  However, the downstream VM doesn't know how to pop a
      # condition variable that is not needed after the BranchOp, so for now we
      # have to allocate a register for it.
      _variable_crosses_block_boundary(block.terminator.cond_var, alloc)
    for varname in liveness_map[block].live_out_of_block:
      _variable_crosses_block_boundary(varname, alloc)
def _directly_writes(graph):
  """Set of variables directly written by the given graph.

  Args:
    graph: `ControlFlowGraph` whose blocks before the exit index are scanned.

  Returns:
    A `set` holding the outputs of every `PrimOp` and `FunctionCallOp`, plus
    the formal parameters of every function called.
  """
  written = set()
  for index in range(graph.exit_index()):
    for op in graph.block(index).instructions:
      if isinstance(op, inst.PrimOp):
        written.update(inst.pattern_flatten(op.vars_out))
      elif isinstance(op, inst.FunctionCallOp):
        written.update(inst.pattern_flatten(op.vars_out))
        # The callee's formal parameters count too, because the caller writes
        # them before the goto.
        written.update(inst.pattern_flatten(op.function.vars_in))
      # A `PopOp` contributes nothing.  If the stack discipline is followed,
      # any local variable is written by something before it is ever popped,
      # and formal parameters are written by the caller and popped before
      # returning.  Pops should not prevent a variable from being allocated as
      # a register instead of a full variable, because pops as such do not
      # cause registers to lose any data that a full variable would have kept.
  return written
def _directly_calls(graph):
  """Set of Functions directly called by the given graph.

  Args:
    graph: `ControlFlowGraph` whose blocks before the exit index are scanned.

  Returns:
    A `set` of the `function` attribute of every `FunctionCallOp` found.
  """
  # TODO(axch): Deduplicate this and _directly_writes into a generic CFG
  # traversal?
  blocks = (graph.block(index) for index in range(graph.exit_index()))
  return {op.function
          for block in blocks
          for op in block.instructions
          if isinstance(op, inst.FunctionCallOp)}
def _indirectly_writes(function):
  """Set of variables written by the given function including callees.

  Walks the call graph reachable from `function`, visiting each function once
  (so recursion terminates), and accumulates every function's direct writes.

  Args:
    function: The `Function` to start the traversal from.

  Returns:
    A `set` of the variables written by `function` or anything it calls.
  """
  pending = [function]
  seen = set()
  written = set()
  while pending:
    current = pending.pop()
    if current not in seen:
      seen.add(current)
      written.update(_directly_writes(current.graph))
      pending.extend(_directly_calls(current.graph))
  return written